datafusion_physical_plan/aggregates/group_values/multi_group_by/bytes.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::aggregates::group_values::multi_group_by::{
19    nulls_equal_to, GroupColumn, Nulls,
20};
21use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
22use arrow::array::{
23    types::GenericStringType, Array, ArrayRef, AsArray, BufferBuilder,
24    GenericBinaryArray, GenericByteArray, GenericStringArray, OffsetSizeTrait,
25};
26use arrow::buffer::{OffsetBuffer, ScalarBuffer};
27use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType};
28use datafusion_common::utils::proxy::VecAllocExt;
29use datafusion_common::{exec_datafusion_err, Result};
30use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};
31use itertools::izip;
32use std::mem::size_of;
33use std::sync::Arc;
34use std::vec;
35
36/// An implementation of [`GroupColumn`] for binary and utf8 types.
37///
38/// Stores a collection of binary or utf8 group values in a single buffer
39/// in a way that allows:
40///
41/// 1. Efficient comparison of incoming rows to existing rows
42/// 2. Efficient construction of the final output array
43pub struct ByteGroupValueBuilder<O>
44where
45    O: OffsetSizeTrait,
46{
47    output_type: OutputType,
48    buffer: BufferBuilder<u8>,
49    /// Offsets into `buffer` for each distinct value. These offsets as used
50    /// directly to create the final `GenericBinaryArray`. The `i`th string is
51    /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values
52    /// are stored as a zero length string.
53    offsets: Vec<O>,
54    /// Nulls
55    nulls: MaybeNullBufferBuilder,
56    /// The maximum size of the buffer for `0`
57    max_buffer_size: usize,
58}
59
60impl<O> ByteGroupValueBuilder<O>
61where
62    O: OffsetSizeTrait,
63{
64    pub fn new(output_type: OutputType) -> Self {
65        Self {
66            output_type,
67            buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
68            offsets: vec![O::default()],
69            nulls: MaybeNullBufferBuilder::new(),
70            max_buffer_size: if O::IS_LARGE {
71                i64::MAX as usize
72            } else {
73                i32::MAX as usize
74            },
75        }
76    }
77
78    fn equal_to_inner<B>(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool
79    where
80        B: ByteArrayType,
81    {
82        let array = array.as_bytes::<B>();
83        self.do_equal_to_inner(lhs_row, array, rhs_row)
84    }
85
86    fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize) -> Result<()>
87    where
88        B: ByteArrayType,
89    {
90        let arr = array.as_bytes::<B>();
91        if arr.is_null(row) {
92            self.nulls.append(true);
93            // nulls need a zero length in the offset buffer
94            let offset = self.buffer.len();
95            self.offsets.push(O::usize_as(offset));
96        } else {
97            self.nulls.append(false);
98            self.do_append_val_inner(arr, row)?;
99        }
100
101        Ok(())
102    }
103
104    fn vectorized_equal_to_inner<B>(
105        &self,
106        lhs_rows: &[usize],
107        array: &ArrayRef,
108        rhs_rows: &[usize],
109        equal_to_results: &mut [bool],
110    ) where
111        B: ByteArrayType,
112    {
113        let array = array.as_bytes::<B>();
114
115        let iter = izip!(
116            lhs_rows.iter(),
117            rhs_rows.iter(),
118            equal_to_results.iter_mut(),
119        );
120
121        for (&lhs_row, &rhs_row, equal_to_result) in iter {
122            // Has found not equal to, don't need to check
123            if !*equal_to_result {
124                continue;
125            }
126
127            *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row);
128        }
129    }
130
131    fn vectorized_append_inner<B>(
132        &mut self,
133        array: &ArrayRef,
134        rows: &[usize],
135    ) -> Result<()>
136    where
137        B: ByteArrayType,
138    {
139        let arr = array.as_bytes::<B>();
140        let null_count = array.null_count();
141        let num_rows = array.len();
142        let all_null_or_non_null = if null_count == 0 {
143            Nulls::None
144        } else if null_count == num_rows {
145            Nulls::All
146        } else {
147            Nulls::Some
148        };
149
150        match all_null_or_non_null {
151            Nulls::Some => {
152                for &row in rows {
153                    self.append_val_inner::<B>(array, row)?
154                }
155            }
156
157            Nulls::None => {
158                self.nulls.append_n(rows.len(), false);
159                for &row in rows {
160                    self.do_append_val_inner(arr, row)?;
161                }
162            }
163
164            Nulls::All => {
165                self.nulls.append_n(rows.len(), true);
166
167                let new_len = self.offsets.len() + rows.len();
168                let offset = self.buffer.len();
169                self.offsets.resize(new_len, O::usize_as(offset));
170            }
171        }
172
173        Ok(())
174    }
175
176    fn do_equal_to_inner<B>(
177        &self,
178        lhs_row: usize,
179        array: &GenericByteArray<B>,
180        rhs_row: usize,
181    ) -> bool
182    where
183        B: ByteArrayType,
184    {
185        let exist_null = self.nulls.is_null(lhs_row);
186        let input_null = array.is_null(rhs_row);
187        if let Some(result) = nulls_equal_to(exist_null, input_null) {
188            return result;
189        }
190        // Otherwise, we need to check their values
191        self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
192    }
193
194    fn do_append_val_inner<B>(
195        &mut self,
196        array: &GenericByteArray<B>,
197        row: usize,
198    ) -> Result<()>
199    where
200        B: ByteArrayType,
201    {
202        let value: &[u8] = array.value(row).as_ref();
203        self.buffer.append_slice(value);
204
205        if self.buffer.len() > self.max_buffer_size {
206            return Err(exec_datafusion_err!(
207                "offset overflow, buffer size > {}",
208                self.max_buffer_size
209            ));
210        }
211
212        self.offsets.push(O::usize_as(self.buffer.len()));
213        Ok(())
214    }
215
216    /// return the current value of the specified row irrespective of null
217    pub fn value(&self, row: usize) -> &[u8] {
218        let l = self.offsets[row].as_usize();
219        let r = self.offsets[row + 1].as_usize();
220        // Safety: the offsets are constructed correctly and never decrease
221        unsafe { self.buffer.as_slice().get_unchecked(l..r) }
222    }
223}
224
225impl<O> GroupColumn for ByteGroupValueBuilder<O>
226where
227    O: OffsetSizeTrait,
228{
229    fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool {
230        // Sanity array type
231        match self.output_type {
232            OutputType::Binary => {
233                debug_assert!(matches!(
234                    column.data_type(),
235                    DataType::Binary | DataType::LargeBinary
236                ));
237                self.equal_to_inner::<GenericBinaryType<O>>(lhs_row, column, rhs_row)
238            }
239            OutputType::Utf8 => {
240                debug_assert!(matches!(
241                    column.data_type(),
242                    DataType::Utf8 | DataType::LargeUtf8
243                ));
244                self.equal_to_inner::<GenericStringType<O>>(lhs_row, column, rhs_row)
245            }
246            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
247        }
248    }
249
250    fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> {
251        // Sanity array type
252        match self.output_type {
253            OutputType::Binary => {
254                debug_assert!(matches!(
255                    column.data_type(),
256                    DataType::Binary | DataType::LargeBinary
257                ));
258                self.append_val_inner::<GenericBinaryType<O>>(column, row)?
259            }
260            OutputType::Utf8 => {
261                debug_assert!(matches!(
262                    column.data_type(),
263                    DataType::Utf8 | DataType::LargeUtf8
264                ));
265                self.append_val_inner::<GenericStringType<O>>(column, row)?
266            }
267            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
268        };
269
270        Ok(())
271    }
272
273    fn vectorized_equal_to(
274        &self,
275        lhs_rows: &[usize],
276        array: &ArrayRef,
277        rhs_rows: &[usize],
278        equal_to_results: &mut [bool],
279    ) {
280        // Sanity array type
281        match self.output_type {
282            OutputType::Binary => {
283                debug_assert!(matches!(
284                    array.data_type(),
285                    DataType::Binary | DataType::LargeBinary
286                ));
287                self.vectorized_equal_to_inner::<GenericBinaryType<O>>(
288                    lhs_rows,
289                    array,
290                    rhs_rows,
291                    equal_to_results,
292                );
293            }
294            OutputType::Utf8 => {
295                debug_assert!(matches!(
296                    array.data_type(),
297                    DataType::Utf8 | DataType::LargeUtf8
298                ));
299                self.vectorized_equal_to_inner::<GenericStringType<O>>(
300                    lhs_rows,
301                    array,
302                    rhs_rows,
303                    equal_to_results,
304                );
305            }
306            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
307        }
308    }
309
310    fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> {
311        match self.output_type {
312            OutputType::Binary => {
313                debug_assert!(matches!(
314                    column.data_type(),
315                    DataType::Binary | DataType::LargeBinary
316                ));
317                self.vectorized_append_inner::<GenericBinaryType<O>>(column, rows)?
318            }
319            OutputType::Utf8 => {
320                debug_assert!(matches!(
321                    column.data_type(),
322                    DataType::Utf8 | DataType::LargeUtf8
323                ));
324                self.vectorized_append_inner::<GenericStringType<O>>(column, rows)?
325            }
326            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
327        };
328
329        Ok(())
330    }
331
332    fn len(&self) -> usize {
333        self.offsets.len() - 1
334    }
335
336    fn size(&self) -> usize {
337        self.buffer.capacity() * size_of::<u8>()
338            + self.offsets.allocated_size()
339            + self.nulls.allocated_size()
340    }
341
342    fn build(self: Box<Self>) -> ArrayRef {
343        let Self {
344            output_type,
345            mut buffer,
346            offsets,
347            nulls,
348            ..
349        } = *self;
350
351        let null_buffer = nulls.build();
352
353        // SAFETY: the offsets were constructed correctly in `insert_if_new` --
354        // monotonically increasing, overflows were checked.
355        let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
356        let values = buffer.finish();
357        match output_type {
358            OutputType::Binary => {
359                // SAFETY: the offsets were constructed correctly
360                Arc::new(unsafe {
361                    GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
362                })
363            }
364            OutputType::Utf8 => {
365                // SAFETY:
366                // 1. the offsets were constructed safely
367                //
368                // 2. the input arrays were all the correct type and thus since
369                // all the values that went in were valid (e.g. utf8) so are all
370                // the values that come out
371                Arc::new(unsafe {
372                    GenericStringArray::new_unchecked(offsets, values, null_buffer)
373                })
374            }
375            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
376        }
377    }
378
379    fn take_n(&mut self, n: usize) -> ArrayRef {
380        debug_assert!(self.len() >= n);
381        let null_buffer = self.nulls.take_n(n);
382        let first_remaining_offset = O::as_usize(self.offsets[n]);
383
384        // Given offsets like [0, 2, 4, 5] and n = 1, we expect to get
385        // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5].
386        // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3].
387        let mut first_n_offsets = self.offsets.drain(0..n).collect::<Vec<_>>();
388        let offset_n = *self.offsets.first().unwrap();
389        self.offsets
390            .iter_mut()
391            .for_each(|offset| *offset = offset.sub(offset_n));
392        first_n_offsets.push(offset_n);
393
394        // SAFETY: the offsets were constructed correctly in `insert_if_new` --
395        // monotonically increasing, overflows were checked.
396        let offsets =
397            unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) };
398
399        let mut remaining_buffer =
400            BufferBuilder::new(self.buffer.len() - first_remaining_offset);
401        // TODO: Current approach copy the remaining and truncate the original one
402        // Find out a way to avoid copying buffer but split the original one into two.
403        remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]);
404        self.buffer.truncate(first_remaining_offset);
405        let values = self.buffer.finish();
406        self.buffer = remaining_buffer;
407
408        match self.output_type {
409            OutputType::Binary => {
410                // SAFETY: the offsets were constructed correctly
411                Arc::new(unsafe {
412                    GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
413                })
414            }
415            OutputType::Utf8 => {
416                // SAFETY:
417                // 1. the offsets were constructed safely
418                //
419                // 2. we asserted the input arrays were all the correct type and
420                // thus since all the values that went in were valid (e.g. utf8)
421                // so are all the values that come out
422                Arc::new(unsafe {
423                    GenericStringArray::new_unchecked(offsets, values, null_buffer)
424                })
425            }
426            _ => unreachable!("View types should use `ArrowBytesViewMap`"),
427        }
428    }
429}
430
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder;
    use arrow::array::{ArrayRef, LargeBinaryArray, NullBufferBuilder, StringArray};
    use datafusion_common::DataFusionError;
    use datafusion_physical_expr::binary_map::OutputType;

    use super::GroupColumn;

    #[test]
    fn test_byte_group_value_builder_overflow() {
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);

        let large_string = "a".repeat(1024 * 1024);

        let array =
            Arc::new(StringArray::from(vec![Some(large_string.as_str())])) as ArrayRef;

        // Append 2047 values of 1 MiB each: 2047 MiB is still below i32::MAX
        // bytes, but one more MiB would exceed it
        for _ in 0..2047 {
            builder.append_val(&array, 0).unwrap();
        }

        assert!(matches!(
            builder.append_val(&array, 0),
            Err(DataFusionError::Execution(e)) if e.contains("offset overflow")
        ));

        // Values appended before the failure are still readable
        assert_eq!(builder.value(2046), large_string.as_bytes());
    }

    #[test]
    fn test_byte_take_n() {
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
        let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef;
        // a, null, null
        builder.append_val(&array, 0).unwrap();
        builder.append_val(&array, 1).unwrap();
        builder.append_val(&array, 1).unwrap();

        // (a, null) remaining: null
        let output = builder.take_n(2);
        assert_eq!(&output, &array);

        // null, a, null, a
        builder.append_val(&array, 0).unwrap();
        builder.append_val(&array, 1).unwrap();
        builder.append_val(&array, 0).unwrap();

        // (null, a) remaining: (null, a)
        let output = builder.take_n(2);
        let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef;
        assert_eq!(&output, &array);

        let array = Arc::new(StringArray::from(vec![
            Some("a"),
            None,
            Some("longstringfortest"),
        ])) as ArrayRef;

        // null, a, longstringfortest, null, null
        builder.append_val(&array, 2).unwrap();
        builder.append_val(&array, 1).unwrap();
        builder.append_val(&array, 1).unwrap();

        // (null, a, longstringfortest, null) remaining: (null)
        let output = builder.take_n(4);
        let array = Arc::new(StringArray::from(vec![
            None,
            Some("a"),
            Some("longstringfortest"),
            None,
        ])) as ArrayRef;
        assert_eq!(&output, &array);
    }

    #[test]
    fn test_large_binary_vectorized_append_and_build() {
        // Exercises the `i64` offsets + `Binary` output combination
        let mut builder = ByteGroupValueBuilder::<i64>::new(OutputType::Binary);
        let array = Arc::new(LargeBinaryArray::from(vec![
            Some(&b"foo"[..]),
            None,
            Some(&b""[..]),
            Some(&b"bar"[..]),
        ])) as ArrayRef;

        builder.vectorized_append(&array, &[0, 1, 2, 3]).unwrap();
        assert_eq!(builder.len(), 4);
        assert_eq!(builder.value(0), &b"foo"[..]);
        assert_eq!(builder.value(1), &b""[..]);
        assert_eq!(builder.value(3), &b"bar"[..]);

        let output = Box::new(builder).build();
        assert_eq!(&output, &array);
    }

    #[test]
    fn test_byte_equal_to() {
        let append = |builder: &mut ByteGroupValueBuilder<i32>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            for &index in append_rows {
                builder.append_val(builder_array, index).unwrap();
            }
        };

        let equal_to = |builder: &ByteGroupValueBuilder<i32>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let iter = lhs_rows.iter().zip(rhs_rows.iter());
            for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() {
                equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
            }
        };

        test_byte_equal_to_internal(append, equal_to);
    }

    #[test]
    fn test_byte_vectorized_equal_to() {
        let append = |builder: &mut ByteGroupValueBuilder<i32>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            builder
                .vectorized_append(builder_array, append_rows)
                .unwrap();
        };

        let equal_to = |builder: &ByteGroupValueBuilder<i32>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };

        test_byte_equal_to_internal(append, equal_to);
    }

    #[test]
    fn test_byte_vectorized_operation_special_case() {
        // Test the special `all nulls` or `not nulls` input array case
        // for vectorized append and equal to

        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);

        // All nulls input array
        let all_nulls_input_array =
            Arc::new(StringArray::from(vec![Option::<&str>::None; 5])) as ArrayRef;
        builder
            .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();

        let mut equal_to_results = vec![true; all_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[0, 1, 2, 3, 4],
            &all_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );
        assert_eq!(equal_to_results, vec![true; 5]);

        // All not nulls input array
        let all_not_nulls_input_array = Arc::new(StringArray::from(vec![
            Some("string1"),
            Some("string2"),
            Some("string3"),
            Some("string4"),
            Some("string5"),
        ])) as ArrayRef;
        builder
            .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4])
            .unwrap();

        let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[5, 6, 7, 8, 9],
            &all_not_nulls_input_array,
            &[0, 1, 2, 3, 4],
            &mut equal_to_results,
        );
        assert_eq!(equal_to_results, vec![true; 5]);
    }

    fn test_byte_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut ByteGroupValueBuilder<i32>, &ArrayRef, &[usize]),
        E: FnMut(
            &ByteGroupValueBuilder<i32>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        // Will cover such cases:
        //   - exist null, input not null
        //   - exist null, input null; values not equal
        //   - exist null, input null; values equal
        //   - exist not null, input null
        //   - exist not null, input not null; values not equal
        //   - exist not null, input not null; values equal

        // Define ByteGroupValueBuilder
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
        let builder_array = Arc::new(StringArray::from(vec![
            None,
            None,
            None,
            Some("foo"),
            Some("bar"),
            Some("baz"),
        ])) as ArrayRef;
        append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]);

        // Define input array
        let (offsets, buffer, _nulls) = StringArray::from(vec![
            Some("foo"),
            Some("bar"),
            None,
            None,
            Some("foo"),
            Some("baz"),
        ])
        .into_parts();

        // explicitly build a boolean buffer where one of the null values also happens to match
        let mut nulls = NullBufferBuilder::new(6);
        nulls.append_non_null();
        nulls.append_null(); // this sets Some("bar") to null above
        nulls.append_null();
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_non_null();
        let input_array =
            Arc::new(StringArray::new(offsets, buffer, nulls.finish())) as ArrayRef;

        // Check
        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1, 2, 3, 4, 5],
            &input_array,
            &[0, 1, 2, 3, 4, 5],
            &mut equal_to_results,
        );

        assert_eq!(
            equal_to_results,
            vec![false, true, true, false, false, true]
        );
    }
}