datafusion_functions_aggregate/min_max/
min_max_struct.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::{cmp::Ordering, sync::Arc};
19
20use arrow::{
21    array::{
22        Array, ArrayData, ArrayRef, AsArray, BooleanArray, MutableArrayData, StructArray,
23    },
24    datatypes::DataType,
25};
26use datafusion_common::{
27    internal_err,
28    scalar::{copy_array_data, partial_cmp_struct},
29    Result,
30};
31use datafusion_expr::{EmitTo, GroupsAccumulator};
32use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
33
34/// Accumulator for MIN/MAX operations on Struct data types.
35///
36/// This accumulator tracks the minimum or maximum struct value encountered
37/// during aggregation, depending on the `is_min` flag.
38///
39/// The comparison is done based on the struct fields in order.
40pub(crate) struct MinMaxStructAccumulator {
41    /// Inner data storage.
42    inner: MinMaxStructState,
43    /// if true, is `MIN` otherwise is `MAX`
44    is_min: bool,
45}
46
47impl MinMaxStructAccumulator {
48    pub fn new_min(data_type: DataType) -> Self {
49        Self {
50            inner: MinMaxStructState::new(data_type),
51            is_min: true,
52        }
53    }
54
55    pub fn new_max(data_type: DataType) -> Self {
56        Self {
57            inner: MinMaxStructState::new(data_type),
58            is_min: false,
59        }
60    }
61}
62
63impl GroupsAccumulator for MinMaxStructAccumulator {
64    fn update_batch(
65        &mut self,
66        values: &[ArrayRef],
67        group_indices: &[usize],
68        opt_filter: Option<&BooleanArray>,
69        total_num_groups: usize,
70    ) -> Result<()> {
71        let array = &values[0];
72        assert_eq!(array.len(), group_indices.len());
73        assert_eq!(array.data_type(), &self.inner.data_type);
74        // apply filter if needed
75        let array = apply_filter_as_nulls(array, opt_filter)?;
76
77        fn struct_min(a: &StructArray, b: &StructArray) -> bool {
78            matches!(partial_cmp_struct(a, b), Some(Ordering::Less))
79        }
80
81        fn struct_max(a: &StructArray, b: &StructArray) -> bool {
82            matches!(partial_cmp_struct(a, b), Some(Ordering::Greater))
83        }
84
85        if self.is_min {
86            self.inner.update_batch(
87                array.as_struct(),
88                group_indices,
89                total_num_groups,
90                struct_min,
91            )
92        } else {
93            self.inner.update_batch(
94                array.as_struct(),
95                group_indices,
96                total_num_groups,
97                struct_max,
98            )
99        }
100    }
101
102    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
103        let (_, min_maxes) = self.inner.emit_to(emit_to);
104        let fields = match &self.inner.data_type {
105            DataType::Struct(fields) => fields,
106            _ => return internal_err!("Data type is not a struct"),
107        };
108        let null_array = StructArray::new_null(fields.clone(), 1);
109        let min_maxes_data: Vec<ArrayData> = min_maxes
110            .iter()
111            .map(|v| match v {
112                Some(v) => v.to_data(),
113                None => null_array.to_data(),
114            })
115            .collect();
116        let min_maxes_refs: Vec<&ArrayData> = min_maxes_data.iter().collect();
117        let mut copy = MutableArrayData::new(min_maxes_refs, true, min_maxes_data.len());
118
119        for (i, item) in min_maxes_data.iter().enumerate() {
120            copy.extend(i, 0, item.len());
121        }
122        let result = copy.freeze();
123        assert_eq!(&self.inner.data_type, result.data_type());
124        Ok(Arc::new(StructArray::from(result)))
125    }
126
127    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
128        // min/max are their own states (no transition needed)
129        self.evaluate(emit_to).map(|arr| vec![arr])
130    }
131
132    fn merge_batch(
133        &mut self,
134        values: &[ArrayRef],
135        group_indices: &[usize],
136        opt_filter: Option<&BooleanArray>,
137        total_num_groups: usize,
138    ) -> Result<()> {
139        // min/max are their own states (no transition needed)
140        self.update_batch(values, group_indices, opt_filter, total_num_groups)
141    }
142
143    fn convert_to_state(
144        &self,
145        values: &[ArrayRef],
146        opt_filter: Option<&BooleanArray>,
147    ) -> Result<Vec<ArrayRef>> {
148        // Min/max do not change the values as they are their own states
149        // apply the filter by combining with the null mask, if any
150        let output = apply_filter_as_nulls(&values[0], opt_filter)?;
151        Ok(vec![output])
152    }
153
154    fn supports_convert_to_state(&self) -> bool {
155        true
156    }
157
158    fn size(&self) -> usize {
159        self.inner.size()
160    }
161}
162
163#[derive(Debug)]
164struct MinMaxStructState {
165    /// The minimum/maximum value for each group
166    min_max: Vec<Option<StructArray>>,
167    /// The data type of the array
168    data_type: DataType,
169    /// The total bytes of the string data (for pre-allocating the final array,
170    /// and tracking memory usage)
171    total_data_bytes: usize,
172}
173
174#[derive(Debug, Clone)]
175enum MinMaxLocation {
176    /// the min/max value is stored in the existing `min_max` array
177    ExistingMinMax,
178    /// the min/max value is stored in the input array at the given index
179    Input(StructArray),
180}
181
182/// Implement the MinMaxStructState with a comparison function
183/// for comparing structs
184impl MinMaxStructState {
185    /// Create a new MinMaxStructState
186    ///
187    /// # Arguments:
188    /// * `data_type`: The data type of the arrays that will be passed to this accumulator
189    fn new(data_type: DataType) -> Self {
190        Self {
191            min_max: vec![],
192            data_type,
193            total_data_bytes: 0,
194        }
195    }
196
197    /// Set the specified group to the given value, updating memory usage appropriately
198    fn set_value(&mut self, group_index: usize, new_val: &StructArray) {
199        let new_val = StructArray::from(copy_array_data(&new_val.to_data()));
200        match self.min_max[group_index].as_mut() {
201            None => {
202                self.total_data_bytes += new_val.get_array_memory_size();
203                self.min_max[group_index] = Some(new_val);
204            }
205            Some(existing_val) => {
206                // Copy data over to avoid re-allocating
207                self.total_data_bytes -= existing_val.get_array_memory_size();
208                self.total_data_bytes += new_val.get_array_memory_size();
209                *existing_val = new_val;
210            }
211        }
212    }
213
214    /// Updates the min/max values for the given string values
215    ///
216    /// `cmp` is the  comparison function to use, called like `cmp(new_val, existing_val)`
217    /// returns true if the `new_val` should replace `existing_val`
218    fn update_batch<F>(
219        &mut self,
220        array: &StructArray,
221        group_indices: &[usize],
222        total_num_groups: usize,
223        mut cmp: F,
224    ) -> Result<()>
225    where
226        F: FnMut(&StructArray, &StructArray) -> bool + Send + Sync,
227    {
228        self.min_max.resize(total_num_groups, None);
229        // Minimize value copies by calculating the new min/maxes for each group
230        // in this batch (either the existing min/max or the new input value)
231        // and updating the owned values in `self.min_maxes` at most once
232        let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups];
233
234        // Figure out the new min value for each group
235        for (index, group_index) in (0..array.len()).zip(group_indices.iter()) {
236            let group_index = *group_index;
237            if array.is_null(index) {
238                continue;
239            }
240            let new_val = array.slice(index, 1);
241
242            let existing_val = match &locations[group_index] {
243                // previous input value was the min/max, so compare it
244                MinMaxLocation::Input(existing_val) => existing_val,
245                MinMaxLocation::ExistingMinMax => {
246                    let Some(existing_val) = self.min_max[group_index].as_ref() else {
247                        // no existing min/max, so this is the new min/max
248                        locations[group_index] = MinMaxLocation::Input(new_val);
249                        continue;
250                    };
251                    existing_val
252                }
253            };
254
255            // Compare the new value to the existing value, replacing if necessary
256            if cmp(&new_val, existing_val) {
257                locations[group_index] = MinMaxLocation::Input(new_val);
258            }
259        }
260
261        // Update self.min_max with any new min/max values we found in the input
262        for (group_index, location) in locations.iter().enumerate() {
263            match location {
264                MinMaxLocation::ExistingMinMax => {}
265                MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val),
266            }
267        }
268        Ok(())
269    }
270
271    /// Emits the specified min_max values
272    ///
273    /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes
274    ///
275    /// - `data_capacity`: the total length of all strings and their contents,
276    /// - `min_maxes`: the actual min/max values for each group
277    fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<StructArray>>) {
278        match emit_to {
279            EmitTo::All => {
280                (
281                    std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max
282                    std::mem::take(&mut self.min_max),
283                )
284            }
285            EmitTo::First(n) => {
286                let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect();
287                let first_data_capacity: usize = first_min_maxes
288                    .iter()
289                    .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0))
290                    .sum();
291                self.total_data_bytes -= first_data_capacity;
292                (first_data_capacity, first_min_maxes)
293            }
294        }
295    }
296
297    fn size(&self) -> usize {
298        self.total_data_bytes + self.min_max.len() * size_of::<Option<StructArray>>()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int32Array, StringArray, StructArray};
    use arrow::datatypes::{DataType, Field, Fields, Int32Type};
    use std::sync::Arc;

    /// Builds a `{int_field: Int32, str_field: Utf8}` struct array with no
    /// struct-level nulls.
    fn create_test_struct_array(
        int_values: Vec<Option<i32>>,
        str_values: Vec<Option<&str>>,
    ) -> StructArray {
        let fields = Fields::from(vec![
            Field::new("int_field", DataType::Int32, true),
            Field::new("str_field", DataType::Utf8, true),
        ]);
        let columns = vec![
            Arc::new(Int32Array::from(int_values)) as ArrayRef,
            Arc::new(StringArray::from(str_values)) as ArrayRef,
        ];

        StructArray::new(fields, columns, None)
    }

    /// Wraps the test struct array in an outer struct with one `inner` field.
    fn create_nested_struct_array(
        int_values: Vec<Option<i32>>,
        str_values: Vec<Option<&str>>,
    ) -> StructArray {
        let inner_struct = create_test_struct_array(int_values, str_values);
        let outer_fields = Fields::from(vec![Field::new(
            "inner",
            inner_struct.data_type().clone(),
            true,
        )]);

        StructArray::new(outer_fields, vec![Arc::new(inner_struct) as ArrayRef], None)
    }

    /// Feeds `array` to a MIN and a MAX accumulator and returns the
    /// `(min, max)` results of evaluating both with `EmitTo::All`.
    fn run_min_max(
        array: StructArray,
        group_indices: &[usize],
        filter: Option<&BooleanArray>,
        num_groups: usize,
    ) -> (ArrayRef, ArrayRef) {
        let mut min_accumulator =
            MinMaxStructAccumulator::new_min(array.data_type().clone());
        let mut max_accumulator =
            MinMaxStructAccumulator::new_max(array.data_type().clone());
        let values = vec![Arc::new(array) as ArrayRef];

        min_accumulator
            .update_batch(&values, group_indices, filter, num_groups)
            .unwrap();
        max_accumulator
            .update_batch(&values, group_indices, filter, num_groups)
            .unwrap();

        let min_result = min_accumulator.evaluate(EmitTo::All).unwrap();
        let max_result = max_accumulator.evaluate(EmitTo::All).unwrap();
        (min_result, max_result)
    }

    /// Asserts the `(int_field, str_field)` values found at `row` of `result`.
    fn assert_row(
        result: &StructArray,
        row: usize,
        expected_int: i32,
        expected_str: &str,
    ) {
        let int_array = result.column(0).as_primitive::<Int32Type>();
        let str_array = result.column(1).as_string::<i32>();
        assert_eq!(int_array.value(row), expected_int);
        assert_eq!(str_array.value(row), expected_str);
    }

    #[test]
    fn test_min_max_simple_struct() {
        let array = create_test_struct_array(
            vec![Some(1), Some(2), Some(3)],
            vec![Some("a"), Some("b"), Some("c")],
        );

        let (min_result, max_result) = run_min_max(array, &[0, 0, 0], None, 1);
        let min_result = min_result.as_struct();
        let max_result = max_result.as_struct();

        assert_eq!(min_result.len(), 1);
        assert_row(min_result, 0, 1, "a");

        assert_eq!(max_result.len(), 1);
        assert_row(max_result, 0, 3, "c");
    }

    #[test]
    fn test_min_max_nested_struct() {
        let array = create_nested_struct_array(
            vec![Some(1), Some(2), Some(3)],
            vec![Some("a"), Some("b"), Some("c")],
        );

        let (min_result, max_result) = run_min_max(array, &[0, 0, 0], None, 1);
        let min_result = min_result.as_struct();
        let max_result = max_result.as_struct();

        // The values of interest live one level down, in the `inner` column.
        assert_eq!(min_result.len(), 1);
        assert_row(min_result.column(0).as_struct(), 0, 1, "a");

        assert_eq!(max_result.len(), 1);
        assert_row(max_result.column(0).as_struct(), 0, 3, "c");
    }

    #[test]
    fn test_min_max_with_nulls() {
        let array = create_test_struct_array(
            vec![Some(1), None, Some(3)],
            vec![Some("a"), None, Some("c")],
        );

        let (min_result, max_result) = run_min_max(array, &[0, 0, 0], None, 1);
        let min_result = min_result.as_struct();
        let max_result = max_result.as_struct();

        assert_eq!(min_result.len(), 1);
        assert_row(min_result, 0, 1, "a");

        assert_eq!(max_result.len(), 1);
        assert_row(max_result, 0, 3, "c");
    }

    #[test]
    fn test_min_max_multiple_groups() {
        let array = create_test_struct_array(
            vec![Some(1), Some(2), Some(3), Some(4)],
            vec![Some("a"), Some("b"), Some("c"), Some("d")],
        );

        // Rows alternate between group 0 and group 1.
        let (min_result, max_result) = run_min_max(array, &[0, 1, 0, 1], None, 2);
        let min_result = min_result.as_struct();
        let max_result = max_result.as_struct();

        assert_eq!(min_result.len(), 2);
        assert_row(min_result, 0, 1, "a");
        assert_row(min_result, 1, 2, "b");

        assert_eq!(max_result.len(), 2);
        assert_row(max_result, 0, 3, "c");
        assert_row(max_result, 1, 4, "d");
    }

    #[test]
    fn test_min_max_with_filter() {
        let array = create_test_struct_array(
            vec![Some(1), Some(2), Some(3), Some(4)],
            vec![Some("a"), Some("b"), Some("c"), Some("d")],
        );

        // Create a filter that only keeps even numbers
        let filter = BooleanArray::from(vec![false, true, false, true]);

        let (min_result, max_result) =
            run_min_max(array, &[0, 0, 0, 0], Some(&filter), 1);
        let min_result = min_result.as_struct();
        let max_result = max_result.as_struct();

        assert_eq!(min_result.len(), 1);
        assert_row(min_result, 0, 2, "b");

        assert_eq!(max_result.len(), 1);
        assert_row(max_result, 0, 4, "d");
    }
}