datafusion_physical_plan/aggregates/order/
partial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::cmp::Ordering;
19use std::mem::size_of;
20use std::sync::Arc;
21
22use arrow::array::ArrayRef;
23use arrow::compute::SortOptions;
24use arrow_ord::partition::partition;
25use datafusion_common::utils::{compare_rows, get_row_at_idx};
26use datafusion_common::{Result, ScalarValue};
27use datafusion_execution::memory_pool::proxy::VecAllocExt;
28use datafusion_expr::EmitTo;
29
/// Tracks grouping state when the data is ordered by some subset of
/// the group keys.
///
/// Because the input is sorted on the *sort key*, once the next sort
/// key value is seen we will never see groups with the previous sort
/// key again, so we can emit all groups with the previous sort key
/// and earlier.
///
/// For example, given `SUM(amt) GROUP BY id, state` if the input is
/// sorted by `state`, when a new value of `state` is seen, all groups
/// with prior values of `state` can be emitted.
///
/// The state is tracked like this:
///
/// ```text
///                                            ┏━━━━━━━━━━━━━━━━━┓ ┏━━━━━━━┓
///     ┌─────┐    ┌───────────────────┐ ┌─────┃        9        ┃ ┃ "MD"  ┃
///     │┌───┐│    │ ┌──────────────┐  │ │     ┗━━━━━━━━━━━━━━━━━┛ ┗━━━━━━━┛
///     ││ 0 ││    │ │  123, "MA"   │  │ │        current_sort      sort_key
///     │└───┘│    │ └──────────────┘  │ │
///     │ ... │    │    ...            │ │      current_sort tracks the
///     │┌───┐│    │ ┌──────────────┐  │ │      smallest group index that had
///     ││ 8 ││    │ │  765, "MA"   │  │ │      the same sort_key as current
///     │├───┤│    │ ├──────────────┤  │ │
///     ││ 9 ││    │ │  923, "MD"   │◀─┼─┘
///     │├───┤│    │ ├──────────────┤  │        ┏━━━━━━━━━━━━━━┓
///     ││10 ││    │ │  345, "MD"   │  │  ┌─────┃      11      ┃
///     │├───┤│    │ ├──────────────┤  │  │     ┗━━━━━━━━━━━━━━┛
///     ││11 ││    │ │  124, "MD"   │◀─┼──┘         current
///     │└───┘│    │ └──────────────┘  │
///     └─────┘    └───────────────────┘
///
///  group indices
/// (in group value  group_values               current tracks the most
///      order)                                    recent group index
/// ```
#[derive(Debug)]
pub struct GroupOrderingPartial {
    /// State machine describing how far through the input we are and,
    /// while data is in progress, which sort key is currently open.
    state: State,

    /// The indexes of the group by columns that form the sort key.
    /// For example if grouping by `id, state` and ordered by `state`
    /// this would be `[1]`.
    order_indices: Vec<usize>,
}
75
/// Progress of a [`GroupOrderingPartial`] through its input.
///
/// `Taken` is the `Default` variant so that the current state can be
/// moved out with `std::mem::take` while it is being updated.
#[derive(Debug, Default, PartialEq)]
enum State {
    /// The ordering was temporarily taken.  `Self::Taken` is left
    /// when state must be temporarily taken to satisfy the borrow
    /// checker. If an error happens before the state can be restored,
    /// the ordering information is lost and execution can not
    /// proceed, but there is no undefined behavior.
    #[default]
    Taken,

    /// Seen no input yet
    Start,

    /// Data is in progress.
    InProgress {
        /// Smallest group index that has the sort key `sort_key`
        current_sort: usize,
        /// The sort key of group_index `current_sort`
        sort_key: Vec<ScalarValue>,
        /// index of the current group for which values are being
        /// generated
        current: usize,
    },

    /// Seen end of input, all groups can be emitted
    Complete,
}
103
104impl State {
105    fn size(&self) -> usize {
106        match self {
107            State::Taken => 0,
108            State::Start => 0,
109            State::InProgress { sort_key, .. } => sort_key
110                .iter()
111                .map(|scalar_value| scalar_value.size())
112                .sum(),
113            State::Complete => 0,
114        }
115    }
116}
117
118impl GroupOrderingPartial {
119    /// TODO: Remove unnecessary `input_schema` parameter.
120    pub fn try_new(order_indices: Vec<usize>) -> Result<Self> {
121        debug_assert!(!order_indices.is_empty());
122        Ok(Self {
123            state: State::Start,
124            order_indices,
125        })
126    }
127
128    /// Select sort keys from the group values
129    ///
130    /// For example, if group_values had `A, B, C` but the input was
131    /// only sorted on `B` and `C` this should return rows for (`B`,
132    /// `C`)
133    fn compute_sort_keys(&mut self, group_values: &[ArrayRef]) -> Vec<ArrayRef> {
134        // Take only the columns that are in the sort key
135        self.order_indices
136            .iter()
137            .map(|&idx| Arc::clone(&group_values[idx]))
138            .collect()
139    }
140
141    /// How many groups be emitted, or None if no data can be emitted
142    pub fn emit_to(&self) -> Option<EmitTo> {
143        match &self.state {
144            State::Taken => unreachable!("State previously taken"),
145            State::Start => None,
146            State::InProgress { current_sort, .. } => {
147                // Can not emit if we are still on the first row sort
148                // row otherwise we can emit all groups that had earlier sort keys
149                //
150                if *current_sort == 0 {
151                    None
152                } else {
153                    Some(EmitTo::First(*current_sort))
154                }
155            }
156            State::Complete => Some(EmitTo::All),
157        }
158    }
159
160    /// remove the first n groups from the internal state, shifting
161    /// all existing indexes down by `n`
162    pub fn remove_groups(&mut self, n: usize) {
163        match &mut self.state {
164            State::Taken => unreachable!("State previously taken"),
165            State::Start => panic!("invalid state: start"),
166            State::InProgress {
167                current_sort,
168                current,
169                sort_key: _,
170            } => {
171                // shift indexes down by n
172                assert!(*current >= n);
173                *current -= n;
174                assert!(*current_sort >= n);
175                *current_sort -= n;
176            }
177            State::Complete => panic!("invalid state: complete"),
178        }
179    }
180
181    /// Note that the input is complete so any outstanding groups are done as well
182    pub fn input_done(&mut self) {
183        self.state = match self.state {
184            State::Taken => unreachable!("State previously taken"),
185            _ => State::Complete,
186        };
187    }
188
189    fn updated_sort_key(
190        current_sort: usize,
191        sort_key: Option<Vec<ScalarValue>>,
192        range_current_sort: usize,
193        range_sort_key: Vec<ScalarValue>,
194    ) -> Result<(usize, Vec<ScalarValue>)> {
195        if let Some(sort_key) = sort_key {
196            let sort_options = vec![SortOptions::new(false, false); sort_key.len()];
197            let ordering = compare_rows(&sort_key, &range_sort_key, &sort_options)?;
198            if ordering == Ordering::Equal {
199                return Ok((current_sort, sort_key));
200            }
201        }
202
203        Ok((range_current_sort, range_sort_key))
204    }
205
206    /// Called when new groups are added in a batch. See documentation
207    /// on [`super::GroupOrdering::new_groups`]
208    pub fn new_groups(
209        &mut self,
210        batch_group_values: &[ArrayRef],
211        group_indices: &[usize],
212        total_num_groups: usize,
213    ) -> Result<()> {
214        assert!(total_num_groups > 0);
215        assert!(!batch_group_values.is_empty());
216
217        let max_group_index = total_num_groups - 1;
218
219        let (current_sort, sort_key) = match std::mem::take(&mut self.state) {
220            State::Taken => unreachable!("State previously taken"),
221            State::Start => (0, None),
222            State::InProgress {
223                current_sort,
224                sort_key,
225                ..
226            } => (current_sort, Some(sort_key)),
227            State::Complete => {
228                panic!("Saw new group after the end of input");
229            }
230        };
231
232        // Select the sort key columns
233        let sort_keys = self.compute_sort_keys(batch_group_values);
234
235        // Check if the sort keys indicate a boundary inside the batch
236        let ranges = partition(&sort_keys)?.ranges();
237        let last_range = ranges.last().unwrap();
238
239        let range_current_sort = group_indices[last_range.start];
240        let range_sort_key = get_row_at_idx(&sort_keys, last_range.start)?;
241
242        let (current_sort, sort_key) = if last_range.start == 0 {
243            // There was no boundary in the batch. Compare with the previous sort_key (if present)
244            // to check if there was a boundary between the current batch and the previous one.
245            Self::updated_sort_key(
246                current_sort,
247                sort_key,
248                range_current_sort,
249                range_sort_key,
250            )?
251        } else {
252            (range_current_sort, range_sort_key)
253        };
254
255        self.state = State::InProgress {
256            current_sort,
257            current: max_group_index,
258            sort_key,
259        };
260
261        Ok(())
262    }
263
264    /// Return the size of memory allocated by this structure
265    pub(crate) fn size(&self) -> usize {
266        size_of::<Self>() + self.order_indices.allocated_size() + self.state.size()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    use arrow::array::Int32Array;

    /// Builds the two `Int32` group-by columns `a` and `b` of a batch.
    fn group_columns(a: Vec<i32>, b: Vec<i32>) -> Vec<ArrayRef> {
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(a)),
            Arc::new(Int32Array::from(b)),
        ];
        columns
    }

    /// Builds the expected `InProgress` state for a single `Int32` sort key.
    fn in_progress(current_sort: usize, key: i32, current: usize) -> State {
        State::InProgress {
            current_sort,
            sort_key: vec![ScalarValue::Int32(Some(key))],
            current,
        }
    }

    #[test]
    fn test_group_ordering_partial() -> Result<()> {
        // Ordered on column a
        let mut group_ordering = GroupOrderingPartial::try_new(vec![0])?;

        // push with boundaries inside the batch
        group_ordering.new_groups(
            &group_columns(vec![1, 2, 3], vec![2, 1, 3]),
            &[0, 1, 2],
            3,
        )?;
        assert_eq!(group_ordering.state, in_progress(2, 3, 2));

        // push without a boundary
        group_ordering.new_groups(
            &group_columns(vec![3, 3, 3], vec![2, 1, 7]),
            &[3, 4, 5],
            6,
        )?;
        assert_eq!(group_ordering.state, in_progress(2, 3, 5));

        // push with only a boundary to previous batch
        group_ordering.new_groups(
            &group_columns(vec![4, 4, 4], vec![1, 1, 1]),
            &[6, 7, 8],
            9,
        )?;
        assert_eq!(group_ordering.state, in_progress(6, 4, 8));

        Ok(())
    }
}