datafusion_physical_plan/aggregates/group_values/single_group_by/primitive.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
use arrow::array::{
    cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder,
    PrimitiveArray,
};
use arrow::datatypes::{i256, DataType};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_expr::EmitTo;
use half::f16;
use hashbrown::hash_table::HashTable;
use std::mem::size_of;
use std::sync::Arc;

/// A trait to allow hashing of primitive values, including floating point
/// numbers (which do not implement `std::hash::Hash`)
pub(crate) trait HashValue {
    fn hash(&self, state: &RandomState) -> u64;
}

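// With the `force_hash_collisions` feature enabled, every value hashes to 0 so
// that hash-collision handling paths are exercised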
macro_rules! hash_integer {
    ($($t:ty),+) => {
        $(impl HashValue for $t {
            #[cfg(not(feature = "force_hash_collisions"))]
            fn hash(&self, state: &RandomState) -> u64 {
                state.hash_one(self)
            }

            #[cfg(feature = "force_hash_collisions")]
            fn hash(&self, _state: &RandomState) -> u64 {
                0
            }
        })+
    };
}
hash_integer!(i8, i16, i32, i64, i128, i256);
hash_integer!(u8, u16, u32, u64);
hash_integer!(IntervalDayTime, IntervalMonthDayNano);

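// Floats are hashed by their bit pattern since `f16`/`f32`/`f64` do not
// implement `std::hash::Hash`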
macro_rules! hash_float {
    ($($t:ty),+) => {
        $(impl HashValue for $t {
            #[cfg(not(feature = "force_hash_collisions"))]
            fn hash(&self, state: &RandomState) -> u64 {
                state.hash_one(self.to_bits())
            }

            #[cfg(feature = "force_hash_collisions")]
            fn hash(&self, _state: &RandomState) -> u64 {
                0
            }
        })+
    };
}

hash_float!(f16, f32, f64);

/// A [`GroupValues`] storing a single column of primitive values
///
/// This specialization is significantly faster than using the more general
/// purpose `Row`s format
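///
/// # Example (illustrative)
///
/// A minimal sketch, not a compiled doctest, of how the grouping code might
/// drive this type for an `Int64` group by column; the expected `groups`
/// contents below are shown for illustration only.
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int64Array};
/// use arrow::datatypes::{DataType, Int64Type};
/// use datafusion_expr::EmitTo;
///
/// let mut group_values = GroupValuesPrimitive::<Int64Type>::new(DataType::Int64);
/// let col: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(1), Some(2)]));
///
/// // Map each input row to a group index, creating new groups as needed
/// let mut groups = vec![];
/// group_values.intern(&[col], &mut groups)?;
/// assert_eq!(groups, vec![0, 1, 0, 2]); // the null row gets its own group
///
/// // Emit all accumulated groups as a single-column result
/// let emitted = group_values.emit(EmitTo::All)?;
/// ```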
pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
    /// The data type of the output array
    data_type: DataType,
    /// Stores `(group_index, hash)` entries, keyed by the hash of the
    /// corresponding value in `values`
    ///
    /// The `hash` is stored alongside the group index to avoid recomputing it
    /// when the table grows and must be rehashed; that cost is significant for
    /// high cardinality group by queries.
    /// For more details see:
    /// <https://github.com/apache/datafusion/issues/15961>
    map: HashTable<(usize, u64)>,
    /// The group index of the null value if any
    null_group: Option<usize>,
    /// The values for each group index
    values: Vec<T::Native>,
    /// The random state used to generate hashes
    random_state: RandomState,
}

impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
    pub fn new(data_type: DataType) -> Self {
        assert!(PrimitiveArray::<T>::is_compatible(&data_type));
        Self {
            data_type,
            map: HashTable::with_capacity(128),
            values: Vec::with_capacity(128),
            null_group: None,
            random_state: crate::aggregates::AGGREGATION_HASH_SEED,
        }
    }
}

impl<T: ArrowPrimitiveType> GroupValues for GroupValuesPrimitive<T>
where
    T::Native: HashValue,
{
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
        assert_eq!(cols.len(), 1);
        groups.clear();

        for v in cols[0].as_primitive::<T>() {
            let group_id = match v {
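                // All null input values map to a single dedicated group; its
                // slot in `values` holds a placeholder default value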
                None => *self.null_group.get_or_insert_with(|| {
                    let group_id = self.values.len();
                    self.values.push(Default::default());
                    group_id
                }),
                Some(key) => {
                    let state = &self.random_state;
                    let hash = key.hash(state);
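                    // Look up (or insert) the group for this key:
                    // - the equality closure compares the candidate group's
                    //   stored value against `key`
                    // - the hasher closure reuses the stored hash so entries
                    //   need not be rehashed when the table grows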
                    let insert = self.map.entry(
                        hash,
                        |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) },
                        |&(_, h)| h,
                    );

                    match insert {
                        hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
                        hashbrown::hash_table::Entry::Vacant(v) => {
                            let g = self.values.len();
                            v.insert((g, hash));
                            self.values.push(key);
                            g
                        }
                    }
                }
            };
            groups.push(group_id)
        }
        Ok(())
    }

    fn size(&self) -> usize {
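        // Approximate memory usage: the hash table's bucket payloads plus the
        // allocation backing `values`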
        self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
    }

    fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    fn len(&self) -> usize {
        self.values.len()
    }

    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        fn build_primitive<T: ArrowPrimitiveType>(
            values: Vec<T::Native>,
            null_idx: Option<usize>,
        ) -> PrimitiveArray<T> {
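            // Build a validity buffer only when there is a null group: that
            // group's slot is marked null and every other slot is valid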
            let nulls = null_idx.map(|null_idx| {
                let mut buffer = NullBufferBuilder::new(values.len());
                buffer.append_n_non_nulls(null_idx);
                buffer.append_null();
                buffer.append_n_non_nulls(values.len() - null_idx - 1);
                // NOTE: `finish()` returns `Some` because at least one null was appended above
                buffer.finish().unwrap()
            });
            PrimitiveArray::<T>::new(values.into(), nulls)
        }

        let array: PrimitiveArray<T> = match emit_to {
            EmitTo::All => {
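                // Emit every group: clear the hash table and hand back the
                // accumulated values, leaving empty state for the next batch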
                self.map.clear();
                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
            }
            EmitTo::First(n) => {
                self.map.retain(|entry| {
                    // Decrement group index by n
                    let group_idx = entry.0;
                    match group_idx.checked_sub(n) {
                        // Group index was >= n, shift value down
                        Some(sub) => {
                            entry.0 = sub;
                            true
                        }
                        // Group index was < n, so remove from table
                        None => false,
                    }
                });
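                // If the null group is not among the first `n` groups being
                // emitted, shift its index down and keep it; otherwise emit it
                // with this batch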
                let null_group = match &mut self.null_group {
                    Some(v) if *v >= n => {
                        *v -= n;
                        None
                    }
                    Some(_) => self.null_group.take(),
                    None => None,
                };
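                // `split_off(n)` leaves the first `n` values in `self.values`
                // and returns the rest; after the swap, `split` holds the
                // emitted first `n` group values and `self.values` keeps the
                // remainder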
                let mut split = self.values.split_off(n);
                std::mem::swap(&mut self.values, &mut split);
                build_primitive(split, null_group)
            }
        };

        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
    }

    fn clear_shrink(&mut self, batch: &RecordBatch) {
        let count = batch.num_rows();
        self.values.clear();
        self.values.shrink_to(count);
        self.map.clear();
        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
    }
}