datafusion_physical_plan/aggregates/order/partial.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::mem::size_of;
20use std::sync::Arc;
21
22use arrow::array::ArrayRef;
23use arrow::compute::SortOptions;
24use arrow_ord::partition::partition;
25use datafusion_common::utils::{compare_rows, get_row_at_idx};
26use datafusion_common::{Result, ScalarValue};
27use datafusion_execution::memory_pool::proxy::VecAllocExt;
28use datafusion_expr::EmitTo;
29
30/// Tracks grouping state when the data is ordered by some subset of
31/// the group keys.
32///
/// Once a new *sort key* value is seen, no further groups with the
/// previous sort key can appear, so all groups with the previous sort
/// key and earlier can be emitted.
36///
37/// For example, given `SUM(amt) GROUP BY id, state` if the input is
38/// sorted by `state`, when a new value of `state` is seen, all groups
39/// with prior values of `state` can be emitted.
40///
41/// The state is tracked like this:
42///
/// ```text
///                                          ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓
///  ┌─────┐    ┌───────────────────┐  ┌─────┃        9        ┃ ┃ "MD"  ┃
///  │┌───┐│    │ ┌──────────────┐  │  │     ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛
///  ││ 0 ││    │ │  123, "MA"   │  │  │        current_sort     sort_key
///  │└───┘│    │ └──────────────┘  │  │
///  │ ... │    │       ...         │  │     current_sort tracks the
///  │┌───┐│    │ ┌──────────────┐  │  │     smallest group index that had
///  ││ 8 ││    │ │  765, "MA"   │  │  │     the same sort_key as current
///  │├───┤│    │ ├──────────────┤  │  │
///  ││ 9 ││    │ │  923, "MD"   │◀─┼──┘
///  │├───┤│    │ ├──────────────┤  │         ┏━━━━━━━━━━━━━━┓
///  ││10 ││    │ │  345, "MD"   │  │   ┌─────┃      11      ┃
///  │├───┤│    │ ├──────────────┤  │   │     ┗━━━━━━━━━━━━━━┛
///  ││11 ││    │ │  124, "MD"   │◀─┼───┘         current
///  │└───┘│    │ └──────────────┘  │
///  └─────┘    └───────────────────┘
///
///   group indices
///  (in group value    group_values            current tracks the most
///       order)                                  recent group index
/// ```
#[derive(Debug)]
pub struct GroupOrderingPartial {
    /// State machine describing how far through the input we are and,
    /// while data is in progress, which groups share the current sort key.
    state: State,

    /// The indexes of the group by columns that form the sort key.
    /// For example if grouping by `id, state` and ordered by `state`
    /// this would be `[1]`.
    order_indices: Vec<usize>,
}
75
#[derive(Debug, Default, PartialEq)]
enum State {
    /// The ordering was temporarily taken. `Self::Taken` is left
    /// when state must be temporarily taken to satisfy the borrow
    /// checker. If an error happens before the state can be restored,
    /// the ordering information is lost and execution can not
    /// proceed, but there is no undefined behavior.
    ///
    /// This is the `Default` variant, so it is what `std::mem::take`
    /// leaves behind.
    #[default]
    Taken,

    /// Seen no input yet
    Start,

    /// Data is in progress.
    InProgress {
        /// Smallest group index that has the sort key `sort_key`
        current_sort: usize,
        /// The sort key of group index `current_sort`
        sort_key: Vec<ScalarValue>,
        /// Index of the current (most recent) group for which values
        /// are being generated
        current: usize,
    },

    /// Seen end of input, all groups can be emitted
    Complete,
}
103
104impl State {
105 fn size(&self) -> usize {
106 match self {
107 State::Taken => 0,
108 State::Start => 0,
109 State::InProgress { sort_key, .. } => sort_key
110 .iter()
111 .map(|scalar_value| scalar_value.size())
112 .sum(),
113 State::Complete => 0,
114 }
115 }
116}
117
118impl GroupOrderingPartial {
119 /// TODO: Remove unnecessary `input_schema` parameter.
120 pub fn try_new(order_indices: Vec<usize>) -> Result<Self> {
121 debug_assert!(!order_indices.is_empty());
122 Ok(Self {
123 state: State::Start,
124 order_indices,
125 })
126 }
127
128 /// Select sort keys from the group values
129 ///
130 /// For example, if group_values had `A, B, C` but the input was
131 /// only sorted on `B` and `C` this should return rows for (`B`,
132 /// `C`)
133 fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec<ArrayRef> {
134 // Take only the columns that are in the sort key
135 self.order_indices
136 .iter()
137 .map(|&idx| Arc::clone(&group_values[idx]))
138 .collect()
139 }
140
141 /// How many groups be emitted, or None if no data can be emitted
142 pub fn emit_to(&self) -> Option<EmitTo> {
143 match &self.state {
144 State::Taken => unreachable!("State previously taken"),
145 State::Start => None,
146 State::InProgress { current_sort, .. } => {
147 // Can not emit if we are still on the first row sort
148 // row otherwise we can emit all groups that had earlier sort keys
149 //
150 if *current_sort == 0 {
151 None
152 } else {
153 Some(EmitTo::First(*current_sort))
154 }
155 }
156 State::Complete => Some(EmitTo::All),
157 }
158 }
159
160 /// remove the first n groups from the internal state, shifting
161 /// all existing indexes down by `n`
162 pub fn remove_groups(&mut self, n: usize) {
163 match &mut self.state {
164 State::Taken => unreachable!("State previously taken"),
165 State::Start => panic!("invalid state: start"),
166 State::InProgress {
167 current_sort,
168 current,
169 sort_key: _,
170 } => {
171 // shift indexes down by n
172 assert!(*current >= n);
173 *current -= n;
174 assert!(*current_sort >= n);
175 *current_sort -= n;
176 }
177 State::Complete => panic!("invalid state: complete"),
178 }
179 }
180
181 /// Note that the input is complete so any outstanding groups are done as well
182 pub fn input_done(&mut self) {
183 self.state = match self.state {
184 State::Taken => unreachable!("State previously taken"),
185 _ => State::Complete,
186 };
187 }
188
189 fn updated_sort_key(
190 current_sort: usize,
191 sort_key: Option<Vec<ScalarValue>>,
192 range_current_sort: usize,
193 range_sort_key: Vec<ScalarValue>,
194 ) -> Result<(usize, Vec<ScalarValue>)> {
195 if let Some(sort_key) = sort_key {
196 let sort_options = vec![SortOptions::new(false, false); sort_key.len()];
197 let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?;
198 if ordering == Ordering::Equal {
199 return Ok((current_sort, sort_key));
200 }
201 }
202
203 Ok((range_current_sort, range_sort_key))
204 }
205
206 /// Called when new groups are added in a batch. See documentation
207 /// on [`super::GroupOrdering::new_groups`]
208 pub fn new_groups(
209 &mut self,
210 batch_group_values: &[ArrayRef],
211 group_indices: &[usize],
212 total_num_groups: usize,
213 ) -> Result<()> {
214 assert!(total_num_groups > 0);
215 assert!(!batch_group_values.is_empty());
216
217 let max_group_index = total_num_groups - 1;
218
219 let (current_sort, sort_key) = match std::mem::take(&mut self.state) {
220 State::Taken => unreachable!("State previously taken"),
221 State::Start => (0, None),
222 State::InProgress {
223 current_sort,
224 sort_key,
225 ..
226 } => (current_sort, Some(sort_key)),
227 State::Complete => {
228 panic!("Saw new group after the end of input");
229 }
230 };
231
232 // Select the sort key columns
233 let sort_keys = self.compute_sort_keys(batch_group_values);
234
235 // Check if the sort keys indicate a boundary inside the batch
236 let ranges = partition(&sort_keys)?.ranges();
237 let last_range = ranges.last().unwrap();
238
239 let range_current_sort = group_indices[last_range.start];
240 let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?;
241
242 let (current_sort, sort_key) = if last_range.start == 0 {
243 // There was no boundary in the batch. Compare with the previous sort_key (if present)
244 // to check if there was a boundary between the current batch and the previous one.
245 Self::updated_sort_key(
246 current_sort,
247 sort_key,
248 range_current_sort,
249 range_sort_key,
250 )?
251 } else {
252 (range_current_sort, range_sort_key)
253 };
254
255 self.state = State::InProgress {
256 current_sort,
257 current: max_group_index,
258 sort_key,
259 };
260
261 Ok(())
262 }
263
264 /// Return the size of memory allocated by this structure
265 pub(crate) fn size(&self) -> usize {
266 size_of::<Self>() + self.order_indices.allocated_size() + self.state.size()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Int32Array;

    /// Builds the two `Int32` group by columns for one batch.
    fn two_columns(first: Vec<i32>, second: Vec<i32>) -> Vec<ArrayRef> {
        vec![
            Arc::new(Int32Array::from(first)) as ArrayRef,
            Arc::new(Int32Array::from(second)) as ArrayRef,
        ]
    }

    /// Builds the expected `InProgress` state for a single `Int32` sort key.
    fn in_progress(current_sort: usize, key: i32, current: usize) -> State {
        State::InProgress {
            current_sort,
            sort_key: vec![ScalarValue::Int32(Some(key))],
            current,
        }
    }

    #[test]
    fn test_group_ordering_partial() -> Result<()> {
        // Ordered on column a
        let mut ordering = GroupOrderingPartial::try_new(vec![0])?;

        // First batch: every row has a distinct sort key, so the last
        // row starts the current run.
        let columns = two_columns(vec![1, 2, 3], vec![2, 1, 3]);
        ordering.new_groups(&columns, &[0, 1, 2], 3)?;
        assert_eq!(ordering.state, in_progress(2, 3, 2));

        // push without a boundary
        let columns = two_columns(vec![3, 3, 3], vec![2, 1, 7]);
        ordering.new_groups(&columns, &[3, 4, 5], 6)?;
        assert_eq!(ordering.state, in_progress(2, 3, 5));

        // push with only a boundary to previous batch
        let columns = two_columns(vec![4, 4, 4], vec![1, 1, 1]);
        ordering.new_groups(&columns, &[6, 7, 8], 9)?;
        assert_eq!(ordering.state, in_progress(6, 4, 8));

        Ok(())
    }
}
352}