1use crate::aggregates::group_values::multi_group_by::{
19 nulls_equal_to, GroupColumn, Nulls,
20};
21use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder;
22use arrow::array::ArrowNativeTypeOp;
23use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray};
24use arrow::buffer::ScalarBuffer;
25use arrow::datatypes::DataType;
26use datafusion_common::Result;
27use datafusion_execution::memory_pool::proxy::VecAllocExt;
28use itertools::izip;
29use std::iter;
30use std::sync::Arc;
31
32#[derive(Debug)]
41pub struct PrimitiveGroupValueBuilder<T: ArrowPrimitiveType, const NULLABLE: bool> {
42 data_type: DataType,
43 group_values: Vec<T::Native>,
44 nulls: MaybeNullBufferBuilder,
45}
46
47impl<T, const NULLABLE: bool> PrimitiveGroupValueBuilder<T, NULLABLE>
48where
49 T: ArrowPrimitiveType,
50{
51 pub fn new(data_type: DataType) -> Self {
53 Self {
54 data_type,
55 group_values: vec![],
56 nulls: MaybeNullBufferBuilder::new(),
57 }
58 }
59}
60
61impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
62 for PrimitiveGroupValueBuilder<T, NULLABLE>
63{
64 fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
65 if NULLABLE {
67 let exist_null = self.nulls.is_null(lhs_row);
68 let input_null = array.is_null(rhs_row);
69 if let Some(result) = nulls_equal_to(exist_null, input_null) {
70 return result;
71 }
72 }
74
75 self.group_values[lhs_row].is_eq(array.as_primitive::<T>().value(rhs_row))
76 }
77
78 fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> {
79 if NULLABLE {
81 if array.is_null(row) {
82 self.nulls.append(true);
83 self.group_values.push(T::default_value());
84 } else {
85 self.nulls.append(false);
86 self.group_values.push(array.as_primitive::<T>().value(row));
87 }
88 } else {
89 self.group_values.push(array.as_primitive::<T>().value(row));
90 }
91
92 Ok(())
93 }
94
95 fn vectorized_equal_to(
96 &self,
97 lhs_rows: &[usize],
98 array: &ArrayRef,
99 rhs_rows: &[usize],
100 equal_to_results: &mut [bool],
101 ) {
102 let array = array.as_primitive::<T>();
103
104 let iter = izip!(
105 lhs_rows.iter(),
106 rhs_rows.iter(),
107 equal_to_results.iter_mut(),
108 );
109
110 for (&lhs_row, &rhs_row, equal_to_result) in iter {
111 if !*equal_to_result {
113 continue;
114 }
115
116 if NULLABLE {
118 let exist_null = self.nulls.is_null(lhs_row);
119 let input_null = array.is_null(rhs_row);
120 if let Some(result) = nulls_equal_to(exist_null, input_null) {
121 *equal_to_result = result;
122 continue;
123 }
124 }
126
127 *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row));
128 }
129 }
130
131 fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> {
132 let arr = array.as_primitive::<T>();
133
134 let null_count = array.null_count();
135 let num_rows = array.len();
136 let all_null_or_non_null = if null_count == 0 {
137 Nulls::None
138 } else if null_count == num_rows {
139 Nulls::All
140 } else {
141 Nulls::Some
142 };
143
144 match (NULLABLE, all_null_or_non_null) {
145 (true, Nulls::Some) => {
146 for &row in rows {
147 if array.is_null(row) {
148 self.nulls.append(true);
149 self.group_values.push(T::default_value());
150 } else {
151 self.nulls.append(false);
152 self.group_values.push(arr.value(row));
153 }
154 }
155 }
156
157 (true, Nulls::None) => {
158 self.nulls.append_n(rows.len(), false);
159 for &row in rows {
160 self.group_values.push(arr.value(row));
161 }
162 }
163
164 (true, Nulls::All) => {
165 self.nulls.append_n(rows.len(), true);
166 self.group_values
167 .extend(iter::repeat_n(T::default_value(), rows.len()));
168 }
169
170 (false, _) => {
171 for &row in rows {
172 self.group_values.push(arr.value(row));
173 }
174 }
175 }
176
177 Ok(())
178 }
179
180 fn len(&self) -> usize {
181 self.group_values.len()
182 }
183
184 fn size(&self) -> usize {
185 self.group_values.allocated_size() + self.nulls.allocated_size()
186 }
187
188 fn build(self: Box<Self>) -> ArrayRef {
189 let Self {
190 data_type,
191 group_values,
192 nulls,
193 } = *self;
194
195 let nulls = nulls.build();
196 if !NULLABLE {
197 assert!(nulls.is_none(), "unexpected nulls in non nullable input");
198 }
199
200 let arr = PrimitiveArray::<T>::new(ScalarBuffer::from(group_values), nulls);
201 Arc::new(arr.with_data_type(data_type))
203 }
204
205 fn take_n(&mut self, n: usize) -> ArrayRef {
206 let first_n = self.group_values.drain(0..n).collect::<Vec<_>>();
207
208 let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None };
209
210 Arc::new(
211 PrimitiveArray::<T>::new(ScalarBuffer::from(first_n), first_n_nulls)
212 .with_data_type(self.data_type.clone()),
213 )
214 }
215}
216
217#[cfg(test)]
218mod tests {
219 use std::sync::Arc;
220
221 use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder;
222 use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder};
223 use arrow::datatypes::{DataType, Float32Type, Int64Type};
224
225 use super::GroupColumn;
226
/// Runs the shared nullable scenario through the row-at-a-time API
/// (`append_val` / `equal_to`).
#[test]
fn test_nullable_primitive_equal_to() {
    let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                  builder_array: &ArrayRef,
                  append_rows: &[usize]| {
        append_rows
            .iter()
            .for_each(|&row| builder.append_val(builder_array, row).unwrap());
    };

    let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                    lhs_rows: &[usize],
                    input_array: &ArrayRef,
                    rhs_rows: &[usize],
                    equal_to_results: &mut Vec<bool>| {
        let row_pairs = lhs_rows.iter().copied().zip(rhs_rows.iter().copied());
        for (idx, (lhs_row, rhs_row)) in row_pairs.enumerate() {
            equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
        }
    };

    test_nullable_primitive_equal_to_internal(append, equal_to);
}
250
/// Runs the shared nullable scenario through the batch API
/// (`vectorized_append` / `vectorized_equal_to`).
#[test]
fn test_nullable_primitive_vectorized_equal_to() {
    let append = |builder: &mut PrimitiveGroupValueBuilder<Float32Type, true>,
                  builder_array: &ArrayRef,
                  append_rows: &[usize]| {
        GroupColumn::vectorized_append(builder, builder_array, append_rows).unwrap();
    };

    let equal_to = |builder: &PrimitiveGroupValueBuilder<Float32Type, true>,
                    lhs_rows: &[usize],
                    input_array: &ArrayRef,
                    rhs_rows: &[usize],
                    equal_to_results: &mut Vec<bool>| {
        GroupColumn::vectorized_equal_to(
            builder,
            lhs_rows,
            input_array,
            rhs_rows,
            equal_to_results.as_mut_slice(),
        );
    };

    test_nullable_primitive_equal_to_internal(append, equal_to);
}
276
277 fn test_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
278 where
279 A: FnMut(&mut PrimitiveGroupValueBuilder<Float32Type, true>, &ArrayRef, &[usize]),
280 E: FnMut(
281 &PrimitiveGroupValueBuilder<Float32Type, true>,
282 &[usize],
283 &ArrayRef,
284 &[usize],
285 &mut Vec<bool>,
286 ),
287 {
288 let mut builder =
298 PrimitiveGroupValueBuilder::<Float32Type, true>::new(DataType::Float32);
299 let builder_array = Arc::new(Float32Array::from(vec![
300 None,
301 None,
302 None,
303 Some(1.0),
304 Some(2.0),
305 Some(f32::NAN),
306 Some(3.0),
307 ])) as ArrayRef;
308 append(&mut builder, &builder_array, &[0, 1, 2, 3, 4, 5, 6]);
309
310 let (_, values, _nulls) = Float32Array::from(vec![
312 Some(1.0),
313 Some(2.0),
314 None,
315 Some(1.0),
316 None,
317 Some(f32::NAN),
318 None,
319 ])
320 .into_parts();
321
322 let mut nulls = NullBufferBuilder::new(6);
324 nulls.append_non_null();
325 nulls.append_null(); nulls.append_null();
327 nulls.append_non_null();
328 nulls.append_null();
329 nulls.append_non_null();
330 nulls.append_null();
331 let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef;
332
333 let mut equal_to_results = vec![true; builder.len()];
335 equal_to(
336 &builder,
337 &[0, 1, 2, 3, 4, 5, 6],
338 &input_array,
339 &[0, 1, 2, 3, 4, 5, 6],
340 &mut equal_to_results,
341 );
342
343 assert!(!equal_to_results[0]);
344 assert!(equal_to_results[1]);
345 assert!(equal_to_results[2]);
346 assert!(equal_to_results[3]);
347 assert!(!equal_to_results[4]);
348 assert!(equal_to_results[5]);
349 assert!(!equal_to_results[6]);
350 }
351
/// Runs the shared non-nullable scenario through the row-at-a-time API
/// (`append_val` / `equal_to`).
#[test]
fn test_not_nullable_primitive_equal_to() {
    let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                  builder_array: &ArrayRef,
                  append_rows: &[usize]| {
        append_rows
            .iter()
            .for_each(|&row| builder.append_val(builder_array, row).unwrap());
    };

    let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                    lhs_rows: &[usize],
                    input_array: &ArrayRef,
                    rhs_rows: &[usize],
                    equal_to_results: &mut Vec<bool>| {
        let row_pairs = lhs_rows.iter().copied().zip(rhs_rows.iter().copied());
        for (idx, (lhs_row, rhs_row)) in row_pairs.enumerate() {
            equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row);
        }
    };

    test_not_nullable_primitive_equal_to_internal(append, equal_to);
}
375
/// Runs the shared non-nullable scenario through the batch API
/// (`vectorized_append` / `vectorized_equal_to`).
#[test]
fn test_not_nullable_primitive_vectorized_equal_to() {
    let append = |builder: &mut PrimitiveGroupValueBuilder<Int64Type, false>,
                  builder_array: &ArrayRef,
                  append_rows: &[usize]| {
        GroupColumn::vectorized_append(builder, builder_array, append_rows).unwrap();
    };

    let equal_to = |builder: &PrimitiveGroupValueBuilder<Int64Type, false>,
                    lhs_rows: &[usize],
                    input_array: &ArrayRef,
                    rhs_rows: &[usize],
                    equal_to_results: &mut Vec<bool>| {
        GroupColumn::vectorized_equal_to(
            builder,
            lhs_rows,
            input_array,
            rhs_rows,
            equal_to_results.as_mut_slice(),
        );
    };

    test_not_nullable_primitive_equal_to_internal(append, equal_to);
}
401
402 fn test_not_nullable_primitive_equal_to_internal<A, E>(mut append: A, mut equal_to: E)
403 where
404 A: FnMut(&mut PrimitiveGroupValueBuilder<Int64Type, false>, &ArrayRef, &[usize]),
405 E: FnMut(
406 &PrimitiveGroupValueBuilder<Int64Type, false>,
407 &[usize],
408 &ArrayRef,
409 &[usize],
410 &mut Vec<bool>,
411 ),
412 {
413 let mut builder =
419 PrimitiveGroupValueBuilder::<Int64Type, false>::new(DataType::Int64);
420 let builder_array =
421 Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef;
422 append(&mut builder, &builder_array, &[0, 1]);
423
424 let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef;
426
427 let mut equal_to_results = vec![true; builder.len()];
429 equal_to(
430 &builder,
431 &[0, 1],
432 &input_array,
433 &[0, 1],
434 &mut equal_to_results,
435 );
436
437 assert!(equal_to_results[0]);
438 assert!(!equal_to_results[1]);
439 }
440
/// Exercises the two bulk cases of `vectorized_append` on a nullable builder:
/// an input array that is entirely null, then one without any nulls. After
/// each append the newly stored rows must compare equal to the rows they
/// were copied from.
#[test]
fn test_nullable_primitive_vectorized_operation_special_case() {
    let mut builder = PrimitiveGroupValueBuilder::<Int64Type, true>::new(DataType::Int64);
    let input_rows: [usize; 5] = [0, 1, 2, 3, 4];

    // Case 1: every input row is null; these become stored rows 0..=4.
    let all_nulls_input_array: ArrayRef = Arc::new(Int64Array::from(vec![None::<i64>; 5]));
    builder
        .vectorized_append(&all_nulls_input_array, &input_rows)
        .unwrap();

    let mut equal_to_results = vec![true; all_nulls_input_array.len()];
    builder.vectorized_equal_to(
        &input_rows,
        &all_nulls_input_array,
        &input_rows,
        &mut equal_to_results,
    );
    assert!(equal_to_results.iter().all(|&is_equal| is_equal));

    // Case 2: no input row is null; these become stored rows 5..=9.
    let all_not_nulls_input_array: ArrayRef = Arc::new(Int64Array::from(vec![
        Some(1),
        Some(2),
        Some(3),
        Some(4),
        Some(5),
    ]));
    builder
        .vectorized_append(&all_not_nulls_input_array, &input_rows)
        .unwrap();

    let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
    builder.vectorized_equal_to(
        &[5, 6, 7, 8, 9],
        &all_not_nulls_input_array,
        &input_rows,
        &mut equal_to_results,
    );
    assert!(equal_to_results.iter().all(|&is_equal| is_equal));
}
501}