1use crate::aggregates::group_values::multi_group_by::{
19 nulls_equal_to, GroupColumn, Nulls,
20};
21use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
22use arrow::array::{
23 types::GenericStringType, Array, ArrayRef, AsArray, BufferBuilder,
24 GenericBinaryArray, GenericByteArray, GenericStringArray, OffsetSizeTrait,
25};
26use arrow::buffer::{OffsetBuffer, ScalarBuffer};
27use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType};
28use datafusion_common::utils::proxy::VecAllocExt;
29use datafusion_common::{exec_datafusion_err, Result};
30use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};
31use itertools::izip;
32use std::mem::size_of;
33use std::sync::Arc;
34use std::vec;
35
36pub struct ByteGroupValueBuilder<O>
44where
45 O: OffsetSizeTrait,
46{
47 output_type: OutputType,
48 buffer: BufferBuilder<u8>,
49 offsets: Vec<O>,
54 nulls: MaybeNullBufferBuilder,
56 max_buffer_size: usize,
58}
59
60impl<O> ByteGroupValueBuilder<O>
61where
62 O: OffsetSizeTrait,
63{
64 pub fn new(output_type: OutputType) -> Self {
65 Self {
66 output_type,
67 buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY),
68 offsets: vec![O::default()],
69 nulls: MaybeNullBufferBuilder::new(),
70 max_buffer_size: if O::IS_LARGE {
71 i64::MAX as usize
72 } else {
73 i32::MAX as usize
74 },
75 }
76 }
77
78 fn equal_to_inner<B>(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool
79 where
80 B: ByteArrayType,
81 {
82 let array = array.as_bytes::<B>();
83 self.do_equal_to_inner(lhs_row, array, rhs_row)
84 }
85
86 fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize) -> Result<()>
87 where
88 B: ByteArrayType,
89 {
90 let arr = array.as_bytes::<B>();
91 if arr.is_null(row) {
92 self.nulls.append(true);
93 let offset = self.buffer.len();
95 self.offsets.push(O::usize_as(offset));
96 } else {
97 self.nulls.append(false);
98 self.do_append_val_inner(arr, row)?;
99 }
100
101 Ok(())
102 }
103
104 fn vectorized_equal_to_inner<B>(
105 &self,
106 lhs_rows: &[usize],
107 array: &ArrayRef,
108 rhs_rows: &[usize],
109 equal_to_results: &mut [bool],
110 ) where
111 B: ByteArrayType,
112 {
113 let array = array.as_bytes::<B>();
114
115 let iter = izip!(
116 lhs_rows.iter(),
117 rhs_rows.iter(),
118 equal_to_results.iter_mut(),
119 );
120
121 for (&lhs_row, &rhs_row, equal_to_result) in iter {
122 if !*equal_to_result {
124 continue;
125 }
126
127 *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row);
128 }
129 }
130
131 fn vectorized_append_inner<B>(
132 &mut self,
133 array: &ArrayRef,
134 rows: &[usize],
135 ) -> Result<()>
136 where
137 B: ByteArrayType,
138 {
139 let arr = array.as_bytes::<B>();
140 let null_count = array.null_count();
141 let num_rows = array.len();
142 let all_null_or_non_null = if null_count == 0 {
143 Nulls::None
144 } else if null_count == num_rows {
145 Nulls::All
146 } else {
147 Nulls::Some
148 };
149
150 match all_null_or_non_null {
151 Nulls::Some => {
152 for &row in rows {
153 self.append_val_inner::<B>(array, row)?
154 }
155 }
156
157 Nulls::None => {
158 self.nulls.append_n(rows.len(), false);
159 for &row in rows {
160 self.do_append_val_inner(arr, row)?;
161 }
162 }
163
164 Nulls::All => {
165 self.nulls.append_n(rows.len(), true);
166
167 let new_len = self.offsets.len() + rows.len();
168 let offset = self.buffer.len();
169 self.offsets.resize(new_len, O::usize_as(offset));
170 }
171 }
172
173 Ok(())
174 }
175
176 fn do_equal_to_inner<B>(
177 &self,
178 lhs_row: usize,
179 array: &GenericByteArray<B>,
180 rhs_row: usize,
181 ) -> bool
182 where
183 B: ByteArrayType,
184 {
185 let exist_null = self.nulls.is_null(lhs_row);
186 let input_null = array.is_null(rhs_row);
187 if let Some(result) = nulls_equal_to(exist_null, input_null) {
188 return result;
189 }
190 self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
192 }
193
194 fn do_append_val_inner<B>(
195 &mut self,
196 array: &GenericByteArray<B>,
197 row: usize,
198 ) -> Result<()>
199 where
200 B: ByteArrayType,
201 {
202 let value: &[u8] = array.value(row).as_ref();
203 self.buffer.append_slice(value);
204
205 if self.buffer.len() > self.max_buffer_size {
206 return Err(exec_datafusion_err!(
207 "offset overflow, buffer size > {}",
208 self.max_buffer_size
209 ));
210 }
211
212 self.offsets.push(O::usize_as(self.buffer.len()));
213 Ok(())
214 }
215
216 pub fn value(&self, row: usize) -> &[u8] {
218 let l = self.offsets[row].as_usize();
219 let r = self.offsets[row + 1].as_usize();
220 unsafe { self.buffer.as_slice().get_unchecked(l..r) }
222 }
223}
224
225impl<O> GroupColumn for ByteGroupValueBuilder<O>
226where
227 O: OffsetSizeTrait,
228{
229 fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool {
230 match self.output_type {
232 OutputType::Binary => {
233 debug_assert!(matches!(
234 column.data_type(),
235 DataType::Binary | DataType::LargeBinary
236 ));
237 self.equal_to_inner::<GenericBinaryType<O>>(lhs_row, column, rhs_row)
238 }
239 OutputType::Utf8 => {
240 debug_assert!(matches!(
241 column.data_type(),
242 DataType::Utf8 | DataType::LargeUtf8
243 ));
244 self.equal_to_inner::<GenericStringType<O>>(lhs_row, column, rhs_row)
245 }
246 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
247 }
248 }
249
250 fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> {
251 match self.output_type {
253 OutputType::Binary => {
254 debug_assert!(matches!(
255 column.data_type(),
256 DataType::Binary | DataType::LargeBinary
257 ));
258 self.append_val_inner::<GenericBinaryType<O>>(column, row)?
259 }
260 OutputType::Utf8 => {
261 debug_assert!(matches!(
262 column.data_type(),
263 DataType::Utf8 | DataType::LargeUtf8
264 ));
265 self.append_val_inner::<GenericStringType<O>>(column, row)?
266 }
267 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
268 };
269
270 Ok(())
271 }
272
273 fn vectorized_equal_to(
274 &self,
275 lhs_rows: &[usize],
276 array: &ArrayRef,
277 rhs_rows: &[usize],
278 equal_to_results: &mut [bool],
279 ) {
280 match self.output_type {
282 OutputType::Binary => {
283 debug_assert!(matches!(
284 array.data_type(),
285 DataType::Binary | DataType::LargeBinary
286 ));
287 self.vectorized_equal_to_inner::<GenericBinaryType<O>>(
288 lhs_rows,
289 array,
290 rhs_rows,
291 equal_to_results,
292 );
293 }
294 OutputType::Utf8 => {
295 debug_assert!(matches!(
296 array.data_type(),
297 DataType::Utf8 | DataType::LargeUtf8
298 ));
299 self.vectorized_equal_to_inner::<GenericStringType<O>>(
300 lhs_rows,
301 array,
302 rhs_rows,
303 equal_to_results,
304 );
305 }
306 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
307 }
308 }
309
310 fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> {
311 match self.output_type {
312 OutputType::Binary => {
313 debug_assert!(matches!(
314 column.data_type(),
315 DataType::Binary | DataType::LargeBinary
316 ));
317 self.vectorized_append_inner::<GenericBinaryType<O>>(column, rows)?
318 }
319 OutputType::Utf8 => {
320 debug_assert!(matches!(
321 column.data_type(),
322 DataType::Utf8 | DataType::LargeUtf8
323 ));
324 self.vectorized_append_inner::<GenericStringType<O>>(column, rows)?
325 }
326 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
327 };
328
329 Ok(())
330 }
331
332 fn len(&self) -> usize {
333 self.offsets.len() - 1
334 }
335
336 fn size(&self) -> usize {
337 self.buffer.capacity() * size_of::<u8>()
338 + self.offsets.allocated_size()
339 + self.nulls.allocated_size()
340 }
341
342 fn build(self: Box<Self>) -> ArrayRef {
343 let Self {
344 output_type,
345 mut buffer,
346 offsets,
347 nulls,
348 ..
349 } = *self;
350
351 let null_buffer = nulls.build();
352
353 let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) };
356 let values = buffer.finish();
357 match output_type {
358 OutputType::Binary => {
359 Arc::new(unsafe {
361 GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
362 })
363 }
364 OutputType::Utf8 => {
365 Arc::new(unsafe {
372 GenericStringArray::new_unchecked(offsets, values, null_buffer)
373 })
374 }
375 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
376 }
377 }
378
379 fn take_n(&mut self, n: usize) -> ArrayRef {
380 debug_assert!(self.len() >= n);
381 let null_buffer = self.nulls.take_n(n);
382 let first_remaining_offset = O::as_usize(self.offsets[n]);
383
384 let mut first_n_offsets = self.offsets.drain(0..n).collect::<Vec<_>>();
388 let offset_n = *self.offsets.first().unwrap();
389 self.offsets
390 .iter_mut()
391 .for_each(|offset| *offset = offset.sub(offset_n));
392 first_n_offsets.push(offset_n);
393
394 let offsets =
397 unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) };
398
399 let mut remaining_buffer =
400 BufferBuilder::new(self.buffer.len() - first_remaining_offset);
401 remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]);
404 self.buffer.truncate(first_remaining_offset);
405 let values = self.buffer.finish();
406 self.buffer = remaining_buffer;
407
408 match self.output_type {
409 OutputType::Binary => {
410 Arc::new(unsafe {
412 GenericBinaryArray::new_unchecked(offsets, values, null_buffer)
413 })
414 }
415 OutputType::Utf8 => {
416 Arc::new(unsafe {
423 GenericStringArray::new_unchecked(offsets, values, null_buffer)
424 })
425 }
426 _ => unreachable!("View types should use `ArrowBytesViewMap`"),
427 }
428 }
429}
430
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder;
    use arrow::array::{ArrayRef, NullBufferBuilder, StringArray};
    use datafusion_common::DataFusionError;
    use datafusion_physical_expr::binary_map::OutputType;

    use super::GroupColumn;

    /// Appending past `i32::MAX` bytes must fail with an "offset overflow"
    /// error while leaving the previously stored values readable.
    #[test]
    fn test_byte_group_value_builder_overflow() {
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);

        // One value of exactly 1 MiB.
        let large_string = "a".repeat(1024 * 1024);
        let array: ArrayRef =
            Arc::new(StringArray::from(vec![Some(large_string.as_str())]));

        // 2047 MiB still fit below `i32::MAX` bytes ...
        for _ in 0..2047 {
            builder.append_val(&array, 0).unwrap();
        }

        // ... but one more MiB does not.
        let overflow = builder.append_val(&array, 0);
        assert!(matches!(
            overflow,
            Err(DataFusionError::Execution(e)) if e.contains("offset overflow")
        ));

        // The last successfully appended value is intact.
        assert_eq!(builder.value(2046), large_string.as_bytes());
    }

    /// `take_n` returns the first rows and keeps the rest usable.
    #[test]
    fn test_byte_take_n() {
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None]));

        // builder: a, null, null
        for row in [0, 1, 1] {
            builder.append_val(&input, row).unwrap();
        }

        // taken: (a, null); remaining: null
        let taken = builder.take_n(2);
        assert_eq!(&taken, &input);

        // builder: null, a, null, a
        for row in [0, 1, 0] {
            builder.append_val(&input, row).unwrap();
        }

        // taken: (null, a); remaining: (null, a)
        let taken = builder.take_n(2);
        let expected: ArrayRef = Arc::new(StringArray::from(vec![None, Some("a")]));
        assert_eq!(&taken, &expected);

        let input: ArrayRef = Arc::new(StringArray::from(vec![
            Some("a"),
            None,
            Some("longstringfortest"),
        ]));

        // builder: null, a, longstringfortest, null, null
        for row in [2, 1, 1] {
            builder.append_val(&input, row).unwrap();
        }

        // taken: (null, a, longstringfortest, null); remaining: null
        let taken = builder.take_n(4);
        let expected: ArrayRef = Arc::new(StringArray::from(vec![
            None,
            Some("a"),
            Some("longstringfortest"),
            None,
        ]));
        assert_eq!(&taken, &expected);
    }

    /// Runs the shared equality scenario through the row-by-row API.
    #[test]
    fn test_byte_equal_to() {
        let append = |builder: &mut ByteGroupValueBuilder<i32>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            append_rows
                .iter()
                .for_each(|&row| builder.append_val(builder_array, row).unwrap());
        };

        let equal_to = |builder: &ByteGroupValueBuilder<i32>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            let row_pairs = lhs_rows.iter().zip(rhs_rows);
            for (idx, (&lhs_row, &rhs_row)) in row_pairs.enumerate() {
                let is_equal = builder.equal_to(lhs_row, input_array, rhs_row);
                equal_to_results[idx] = is_equal;
            }
        };

        test_byte_equal_to_internal(append, equal_to);
    }

    /// Runs the shared equality scenario through the vectorized API.
    #[test]
    fn test_byte_vectorized_equal_to() {
        let append = |builder: &mut ByteGroupValueBuilder<i32>,
                      builder_array: &ArrayRef,
                      append_rows: &[usize]| {
            let outcome = builder.vectorized_append(builder_array, append_rows);
            outcome.unwrap();
        };

        let equal_to = |builder: &ByteGroupValueBuilder<i32>,
                        lhs_rows: &[usize],
                        input_array: &ArrayRef,
                        rhs_rows: &[usize],
                        equal_to_results: &mut Vec<bool>| {
            builder.vectorized_equal_to(
                lhs_rows,
                input_array,
                rhs_rows,
                equal_to_results,
            );
        };

        test_byte_equal_to_internal(append, equal_to);
    }

    /// Covers the bulk paths of the vectorized append: an input with only
    /// nulls and an input without any nulls.
    #[test]
    fn test_byte_vectorized_operation_special_case() {
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
        let input_rows = [0, 1, 2, 3, 4];

        // All nulls: stored as builder rows 0..5.
        let all_nulls_input_array: ArrayRef =
            Arc::new(StringArray::from(vec![None::<&str>; 5]));
        builder
            .vectorized_append(&all_nulls_input_array, &input_rows)
            .unwrap();

        let mut equal_to_results = vec![true; all_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[0, 1, 2, 3, 4],
            &all_nulls_input_array,
            &input_rows,
            &mut equal_to_results,
        );
        assert_eq!(equal_to_results, vec![true; 5]);

        // No nulls: stored as builder rows 5..10.
        let all_not_nulls_input_array: ArrayRef = Arc::new(StringArray::from(vec![
            Some("string1"),
            Some("string2"),
            Some("string3"),
            Some("string4"),
            Some("string5"),
        ]));
        builder
            .vectorized_append(&all_not_nulls_input_array, &input_rows)
            .unwrap();

        let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
        builder.vectorized_equal_to(
            &[5, 6, 7, 8, 9],
            &all_not_nulls_input_array,
            &input_rows,
            &mut equal_to_results,
        );
        assert_eq!(equal_to_results, vec![true; 5]);
    }

    /// Shared scenario: fill a builder through `append`, compare it against
    /// an input array through `equal_to`, and check every combination of
    /// null and non-null operands.
    fn test_byte_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
    where
        A: FnMut(&mut ByteGroupValueBuilder<i32>, &ArrayRef, &[usize]),
        E: FnMut(
            &ByteGroupValueBuilder<i32>,
            &[usize],
            &ArrayRef,
            &[usize],
            &mut Vec<bool>,
        ),
    {
        // Builder rows: null, null, null, "foo", "bar", "baz".
        let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
        let builder_array: ArrayRef = Arc::new(StringArray::from(vec![
            None,
            None,
            None,
            Some("foo"),
            Some("bar"),
            Some("baz"),
        ]));
        append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5]);

        // Take the offsets and bytes of this array, but replace its
        // validity below.
        let input = StringArray::from(vec![
            Some("foo"),
            Some("bar"),
            None,
            None,
            Some("foo"),
            Some("baz"),
        ]);
        let (offsets, buffer, _nulls) = input.into_parts();

        // Hand-built validity: row 1 is marked null although its slot still
        // holds the bytes of "bar".
        let mut nulls = NullBufferBuilder::new(6);
        nulls.append_non_null();
        nulls.append_null();
        nulls.append_null();
        nulls.append_null();
        nulls.append_non_null();
        nulls.append_non_null();
        let input_array: ArrayRef =
            Arc::new(StringArray::new(offsets, buffer, nulls.finish()));

        let mut equal_to_results = vec![true; builder.len()];
        equal_to(
            &builder,
            &[0, 1, 2, 3, 4, 5],
            &input_array,
            &[0, 1, 2, 3, 4, 5],
            &mut equal_to_results,
        );

        // null/"foo", null/null, null/null, "foo"/null, "bar"/"foo", "baz"/"baz"
        assert_eq!(
            equal_to_results,
            vec![false, true, true, false, false, true]
        );
    }
}