datafusion_physical_plan/aggregates/group_values/multi_group_by/
primitive.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::aggregates::group_values::multi_group_by::{
19    nulls_equal_to, GroupColumn, Nulls,
20};
21use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
22use arrow::array::ArrowNativeTypeOp;
23use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
24use arrow::buffer::ScalarBuffer;
25use arrow::datatypes::DataType;
26use datafusion_common::Result;
27use datafusion_execution::memory_pool::proxy::VecAllocExt;
28use itertools::izip;
29use std::iter;
30use std::sync::Arc;
31
32/// An implementation of [`GroupColumn`] for primitive values
33///
34/// Optimized to skip null buffer construction if the input is known to be non nullable
35///
36/// # Template parameters
37///
38/// `T`: the native Rust type that stores the data
39/// `NULLABLE`: if the data can contain any nulls
40#[derive(Debug)]
41pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> {
42    data_type: DataType,
43    group_values: Vec<T::Native>,
44    nulls: MaybeNullBufferBuilder,
45}
46
47impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE>
48where
49    T: ArrowPrimitiveType,
50{
51    /// Create a new `PrimitiveGroupValueBuilder`
52    pub fn new(data_type: DataType) -> Self {
53        Self {
54            data_type,
55            group_values: vec![],
56            nulls: MaybeNullBufferBuilder::new(),
57        }
58    }
59}
60
61impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
62    for PrimitiveGroupValueBuilder<T, NULLABLE>
63{
64    fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
65        // Perf: skip null check (by short circuit) if input is not nullable
66        if NULLABLE {
67            let exist_null = self.nulls.is_null(lhs_row);
68            let input_null = array.is_null(rhs_row);
69            if let Some(result) = nulls_equal_to(exist_null, input_null) {
70                return result;
71            }
72            // Otherwise, we need to check their values
73        }
74
75        self.group_values[lhs_row].is_eq(array.as_primitive::<T>().value(rhs_row))
76    }
77
78    fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> {
79        // Perf: skip null check if input can't have nulls
80        if NULLABLE {
81            if array.is_null(row) {
82                self.nulls.append(true);
83                self.group_values.push(T::default_value());
84            } else {
85                self.nulls.append(false);
86                self.group_values.push(array.as_primitive::<T>().value(row));
87            }
88        } else {
89            self.group_values.push(array.as_primitive::<T>().value(row));
90        }
91
92        Ok(())
93    }
94
95    fn vectorized_equal_to(
96        &self,
97        lhs_rows: &[usize],
98        array: &ArrayRef,
99        rhs_rows: &[usize],
100        equal_to_results: &mut [bool],
101    ) {
102        let array = array.as_primitive::<T>();
103
104        let iter = izip!(
105            lhs_rows.iter(),
106            rhs_rows.iter(),
107            equal_to_results.iter_mut(),
108        );
109
110        for (&lhs_row, &rhs_row, equal_to_result) in iter {
111            // Has found not equal to in previous column, don't need to check
112            if !*equal_to_result {
113                continue;
114            }
115
116            // Perf: skip null check (by short circuit) if input is not nullable
117            if NULLABLE {
118                let exist_null = self.nulls.is_null(lhs_row);
119                let input_null = array.is_null(rhs_row);
120                if let Some(result) = nulls_equal_to(exist_null, input_null) {
121                    *equal_to_result = result;
122                    continue;
123                }
124                // Otherwise, we need to check their values
125            }
126
127            *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row));
128        }
129    }
130
131    fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> {
132        let arr = array.as_primitive::<T>();
133
134        let null_count = array.null_count();
135        let num_rows = array.len();
136        let all_null_or_non_null = if null_count == 0 {
137            Nulls::None
138        } else if null_count == num_rows {
139            Nulls::All
140        } else {
141            Nulls::Some
142        };
143
144        match (NULLABLE, all_null_or_non_null) {
145            (true, Nulls::Some) => {
146                for &row in rows {
147                    if array.is_null(row) {
148                        self.nulls.append(true);
149                        self.group_values.push(T::default_value());
150                    } else {
151                        self.nulls.append(false);
152                        self.group_values.push(arr.value(row));
153                    }
154                }
155            }
156
157            (true, Nulls::None) => {
158                self.nulls.append_n(rows.len(), false);
159                for &row in rows {
160                    self.group_values.push(arr.value(row));
161                }
162            }
163
164            (true, Nulls::All) => {
165                self.nulls.append_n(rows.len(), true);
166                self.group_values
167                    .extend(iter::repeat_n(T::default_value(), rows.len()));
168            }
169
170            (false, _) => {
171                for &row in rows {
172                    self.group_values.push(arr.value(row));
173                }
174            }
175        }
176
177        Ok(())
178    }
179
180    fn len(&self) -> usize {
181        self.group_values.len()
182    }
183
184    fn size(&self) -> usize {
185        self.group_values.allocated_size() + self.nulls.allocated_size()
186    }
187
188    fn build(self: Box<Self>) -> ArrayRef {
189        let Self {
190            data_type,
191            group_values,
192            nulls,
193        } = *self;
194
195        let nulls = nulls.build();
196        if !NULLABLE {
197            assert!(nulls.is_none(), "unexpected nulls in non nullable input");
198        }
199
200        let arr = PrimitiveArray::<T>::new(ScalarBuffer::from(group_values), nulls);
201        // Set timezone information for timestamp
202        Arc::new(arr.with_data_type(data_type))
203    }
204
205    fn take_n(&mut self, n: usize) -> ArrayRef {
206        let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
207
208        let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None };
209
210        Arc::new(
211            PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), first_n_nulls)
212                .with_data_type(self.data_type.clone()),
213        )
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder;
    use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder};
    use arrow::datatypes::{DataType, Float32Type, Int64Type};

    use super::GroupColumn;

    /// Runs the nullable scenario through the row-at-a-time API
    /// (`append_val` + `equal_to`).
    #[test]
    fn test_nullable_primitive_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            for &index in append_rows {
                builder.append_val(builder_array, index).unwrap();
            }
        };

        let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let iter = lhs_rows.iter().zip(rhs_rows.iter());
            for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
            }
        };

        test_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Runs the nullable scenario through the batch API
    /// (`vectorized_append` + `vectorized_equal_to`).
    #[test]
    fn test_nullable_primitive_vectorized_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            builder
                .vectorized_append(builder_array, append_rows)
                .unwrap();
        };

        let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };

        test_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Shared scenario for nullable builders; `append` and `equal_to` select
    /// which API flavour (row-at-a-time or vectorized) is exercised.
    fn test_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut PrimitiveGroupValueBuilder<Float32Type, true>, &ArrayRef, &[usize]),
        E: FnMut(
            &PrimitiveGroupValueBuilder<Float32Type, true>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        // Will cover such cases (row index in parentheses):
        //   - exist null, input not null (0)
        //   - exist null, input null; values not equal (1)
        //   - exist null, input null; values equal (2)
        //   - exist not null, input null (4, 6)
        //   - exist not null, input not null; values not equal (7)
        //   - exist not null, input not null; values equal (3, 5)

        // Define PrimitiveGroupValueBuilder
        let mut builder =
            PrimitiveGroupValueBuilder::<Float32Type, true>::new(DataType::Float32);
        let builder_array = Arc::new(Float32Array::from(vec![
            None,
            None,
            None,
            Some(1.0),
            Some(2.0),
            Some(f32::NAN),
            Some(3.0),
            Some(4.0),
        ])) as ArrayRef;
        append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6, 7]);

        // Define input array
        let (_, values, _nulls) = Float32Array::from(vec![
            Some(1.0),
            Some(2.0),
            None,
            Some(1.0),
            None,
            Some(f32::NAN),
            None,
            Some(5.0),
        ])
        .into_parts();

        // explicitly build a null buffer where one of the null values also happens to match
        let mut nulls = NullBufferBuilder::new(8);
        nulls.append_non_null();
        nulls.append_null(); // this sets Some(2) to null above
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_null();
        nulls.append_non_null();
        let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef;

        // Check
        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1, 2, 3, 4, 5, 6, 7],
            &input_array,
            &[0, 1, 2, 3, 4, 5, 6, 7],
            &mut equal_to_results,
        );

        assert!(!equal_to_results[0], "exist null, input not null");
        assert!(equal_to_results[1], "both null, underlying values differ");
        assert!(equal_to_results[2], "both null, underlying values equal");
        assert!(equal_to_results[3], "both not null, values equal");
        assert!(!equal_to_results[4], "exist not null, input null");
        assert!(equal_to_results[5], "both not null, NaN compared with NaN");
        assert!(!equal_to_results[6], "exist not null, input null");
        assert!(!equal_to_results[7], "both not null, values not equal");
    }

    /// Runs the non nullable scenario through the row-at-a-time API
    /// (`append_val` + `equal_to`).
    #[test]
    fn test_not_nullable_primitive_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            for &index in append_rows {
                builder.append_val(builder_array, index).unwrap();
            }
        };

        let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let iter = lhs_rows.iter().zip(rhs_rows.iter());
            for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
            }
        };

        test_not_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Runs the non nullable scenario through the batch API
    /// (`vectorized_append` + `vectorized_equal_to`).
    #[test]
    fn test_not_nullable_primitive_vectorized_equal_to() {
        let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            builder
                .vectorized_append(builder_array, append_rows)
                .unwrap();
        };

        let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };

        test_not_nullable_primitive_equal_to_internal(append, equal_to);
    }

    /// Shared scenario for non nullable builders; `append` and `equal_to`
    /// select which API flavour (row-at-a-time or vectorized) is exercised.
    fn test_not_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, false>, &ArrayRef, &[usize]),
        E: FnMut(
            &PrimitiveGroupValueBuilder<Int64Type, false>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        // Will cover such cases:
        //   - values equal
        //   - values not equal

        // Define PrimitiveGroupValueBuilder
        let mut builder =
            PrimitiveGroupValueBuilder::<Int64Type, false>::new(DataType::Int64);
        let builder_array =
            Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
        append(&mut builder, &builder_array, &[0, 1]);

        // Define input array
        let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;

        // Check
        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1],
            &input_array,
            &[0, 1],
            &mut equal_to_results,
        );

        assert!(equal_to_results[0], "values equal");
        assert!(!equal_to_results[1], "values not equal");
    }

    #[test]
    fn test_nullable_primitive_vectorized_operation_special_case() {
        // Test the special `all nulls` or `not nulls` input array case
        // for vectorized append and equal to

        let mut builder =
            PrimitiveGroupValueBuilder::<Int64Type, true>::new(DataType::Int64);

        // All nulls input array
        let all_nulls_input_array = Arc::new(Int64Array::from(vec![
            Option::<i64>::None,
            None,
            None,
            None,
            None,
        ])) as _;
        builder
            .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();

        let mut equal_to_results = vec![true; all_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[0, 1, 2, 3, 4],
            &all_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );

        assert!(equal_to_results[0]);
        assert!(equal_to_results[1]);
        assert!(equal_to_results[2]);
        assert!(equal_to_results[3]);
        assert!(equal_to_results[4]);

        // All not nulls input array
        let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            Some(4),
            Some(5),
        ])) as _;
        builder
            .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();

        // The second batch was appended after the five null groups, so it
        // lives at builder rows 5..=9
        let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[5, 6, 7, 8, 9],
            &all_not_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );

        assert!(equal_to_results[0]);
        assert!(equal_to_results[1]);
        assert!(equal_to_results[2]);
        assert!(equal_to_results[3]);
        assert!(equal_to_results[4]);
    }
}