datafusion_physical_plan/aggregates/group_values/row.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::row::{RowConverter, Rows, SortField};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::hash_table::HashTable;
use log::debug;
use std::mem::size_of;
use std::sync::Arc;

/// A [`GroupValues`] making use of [`Rows`]
///
/// This is a general implementation of [`GroupValues`] that works for any
/// combination of data types and number of columns, including nested types
/// such as structs and lists.
///
/// It uses the arrow-rs [`Rows`] to store the group values, which is a
/// row-wise representation.
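///
/// # Example
///
/// A minimal sketch of how this type is driven (the schema, column values,
/// and variable names here are illustrative, not taken from the actual
/// aggregation operator; the example is marked `ignore` for that reason):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{ArrayRef, Int32Array};
/// use arrow::datatypes::{DataType, Field, Schema};
///
/// // Group by a single Int32 column named "a" (hypothetical schema)
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let mut group_values = GroupValuesRows::try_new(schema)?;
///
/// // Each distinct value of "a" becomes a group; `groups` is filled with
/// // the group index assigned to each input row.
/// let cols: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1, 2, 1]))];
/// let mut groups = vec![];
/// group_values.intern(&cols, &mut groups)?;
/// assert_eq!(groups, vec![0, 1, 0]); // rows 0 and 2 share group 0
/// ```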
pub struct GroupValuesRows {
    /// The output schema
    schema: SchemaRef,

    /// Converter for the group values
    row_converter: RowConverter,

    /// Logically maps group values to a group_index in
    /// [`Self::group_values`] and in each accumulator
    ///
    /// Uses the raw API of hashbrown to avoid actually storing the
    /// keys (group values) in the table
    ///
    /// keys: u64 hashes of the GroupValue
    /// values: (hash, group_index)
    map: HashTable<(u64, usize)>,

    /// The size of `map` in bytes
    map_size: usize,

    /// The actual group by values, stored in arrow [`Row`] format.
    /// `group_values[i]` holds the group value for group_index `i`.
    ///
    /// The row format is used to compare group keys quickly and store
    /// them efficiently in memory. Quick comparison is especially
    /// important for multi-column group keys.
    ///
    /// [`Row`]: arrow::row::Row
    group_values: Option<Rows>,

    /// reused buffer to store hashes
    hashes_buffer: Vec<u64>,

    /// reused buffer to store rows
    rows_buffer: Rows,

    /// Random state for creating hashes
    random_state: RandomState,
}

impl GroupValuesRows {
    pub fn try_new(schema: SchemaRef) -> Result<Self> {
        // Print a debugging message, so it is clear when the (slower) fallback
        // GroupValuesRows is used.
        debug!("Creating GroupValuesRows for schema: {schema}");
        let row_converter = RowConverter::new(
            schema
                .fields()
                .iter()
                .map(|f| SortField::new(f.data_type().clone()))
                .collect(),
        )?;

        let map = HashTable::with_capacity(0);

        let starting_rows_capacity = 1000;
        let starting_data_capacity = 64 * starting_rows_capacity;
        let rows_buffer =
            row_converter.empty_rows(starting_rows_capacity, starting_data_capacity);
        Ok(Self {
            schema,
            row_converter,
            map,
            map_size: 0,
            group_values: None,
            hashes_buffer: Default::default(),
            rows_buffer,
            random_state: crate::aggregates::AGGREGATION_HASH_SEED,
        })
    }
}

impl GroupValues for GroupValuesRows {
    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
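        // High-level flow: (1) convert the input columns into the row
        // format, (2) hash each input row, (3) probe the hash table for an
        // existing group whose hash *and* row bytes match, and (4) if none
        // exists, append the row to `group_values` and record the new
        // (hash, group_index) pair in the map.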
        // Convert the group keys into the row format
        let group_rows = &mut self.rows_buffer;
        group_rows.clear();
        self.row_converter.append(group_rows, cols)?;
        let n_rows = group_rows.num_rows();

        let mut group_values = match self.group_values.take() {
            Some(group_values) => group_values,
            None => self.row_converter.empty_rows(0, 0),
        };

        // tracks to which group each of the input rows belongs
        groups.clear();

        // 1.1 Calculate the group keys for the group values
        let batch_hashes = &mut self.hashes_buffer;
        batch_hashes.clear();
        batch_hashes.resize(n_rows, 0);
        create_hashes(cols, &self.random_state, batch_hashes)?;

        for (row, &target_hash) in batch_hashes.iter().enumerate() {
            let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| {
                // Somewhat surprisingly, this closure can be called even if
                // the hash doesn't match, so check the hash first with a
                // cheap integer comparison to avoid the more expensive
                // comparison with the group value.
                // https://github.com/apache/datafusion/pull/11718
                target_hash == *exist_hash
                    // verify that the group we are inserting with this hash
                    // is actually the same key value as the group stored at
                    // `group_idx` (i.e. `group_values.row(group_idx)`)
                    && group_rows.row(row) == group_values.row(*group_idx)
            });

            let group_idx = match entry {
                // Existing group_index for this group value
                Some((_hash, group_idx)) => *group_idx,
                // 1.2 Need to create new entry for the group
                None => {
                    // Add new entry to aggr_state and save newly created index
                    let group_idx = group_values.num_rows();
                    group_values.push(group_rows.row(row));

                    // for hasher function, use precomputed hash value
                    self.map.insert_accounted(
                        (target_hash, group_idx),
                        |(hash, _group_index)| *hash,
                        &mut self.map_size,
                    );
                    group_idx
                }
            };
            groups.push(group_idx);
        }

        self.group_values = Some(group_values);

        Ok(())
    }

    fn size(&self) -> usize {
        let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0);
        self.row_converter.size()
            + group_values_size
            + self.map_size
            + self.rows_buffer.size()
            + self.hashes_buffer.allocated_size()
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    fn len(&self) -> usize {
        self.group_values
            .as_ref()
            .map(|group_values| group_values.num_rows())
            .unwrap_or(0)
    }

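    /// Converts the accumulated group values back into columnar arrays.
    ///
    /// `EmitTo::All` emits every group and clears the state, while
    /// `EmitTo::First(n)` emits only the first `n` groups and renumbers the
    /// remaining group indexes down by `n` so they stay consistent with the
    /// indexes used by the accumulators.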
    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        let mut group_values = self
            .group_values
            .take()
            .expect("Can not emit from empty rows");

        let mut output = match emit_to {
            EmitTo::All => {
                let output = self.row_converter.convert_rows(&group_values)?;
                group_values.clear();
                self.map.clear();
                output
            }
            EmitTo::First(n) => {
                let groups_rows = group_values.iter().take(n);
                let output = self.row_converter.convert_rows(groups_rows)?;
                // Clear out first n group keys by copying them to a new Rows.
                // TODO file some ticket in arrow-rs to make this more efficient?
                let mut new_group_values = self.row_converter.empty_rows(0, 0);
                for row in group_values.iter().skip(n) {
                    new_group_values.push(row);
                }
                std::mem::swap(&mut new_group_values, &mut group_values);

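                // For example, with n = 2 and group indexes {0, 1, 2, 3} in
                // the map: groups 0 and 1 were just emitted so their entries
                // are dropped, while groups 2 and 3 are renumbered to 0 and 1
                // to match their new positions in `group_values`.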
                self.map.retain(|(_exists_hash, group_idx)| {
                    // Decrement group index by n
                    match group_idx.checked_sub(n) {
                        // Group index was >= n, shift value down
                        Some(sub) => {
                            *group_idx = sub;
                            true
                        }
                        // Group index was < n, so remove from table
                        None => false,
                    }
                });
                output
            }
        };

        // TODO: Materialize dictionaries in group keys
        // https://github.com/apache/datafusion/issues/7647
        for (field, array) in self.schema.fields.iter().zip(&mut output) {
            let expected = field.data_type();
            *array =
                dictionary_encode_if_necessary(Arc::<dyn Array>::clone(array), expected)?;
        }

        self.group_values = Some(group_values);
        Ok(output)
    }

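    /// Clears the group state while keeping allocations sized for roughly
    /// `batch.num_rows()` rows, so this instance can be reused for the next
    /// aggregation without reallocating from scratch.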
    fn clear_shrink(&mut self, batch: &RecordBatch) {
        let count = batch.num_rows();
        self.group_values = self.group_values.take().map(|mut rows| {
            rows.clear();
            rows
        });
        self.map.clear();
        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
        self.map_size = self.map.capacity() * size_of::<(u64, usize)>();
        self.hashes_buffer.clear();
        self.hashes_buffer.shrink_to(count);
    }
}

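/// Re-encodes `array` to match the `expected` data type where necessary.
///
/// Group values are stored with dictionaries materialized (see the TODO in
/// `emit`), so when the output schema expects a dictionary type this casts
/// the emitted array back to a dictionary, recursing into struct and list
/// children to fix up nested fields.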
fn dictionary_encode_if_necessary(
    array: ArrayRef,
    expected: &DataType,
) -> Result<ArrayRef> {
    match (expected, array.data_type()) {
        (DataType::Struct(expected_fields), _) => {
            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let arrays = expected_fields
                .iter()
                .zip(struct_array.columns())
                .map(|(expected_field, column)| {
                    dictionary_encode_if_necessary(
                        Arc::<dyn Array>::clone(column),
                        expected_field.data_type(),
                    )
                })
                .collect::<Result<Vec<_>>>()?;

            Ok(Arc::new(StructArray::try_new(
                expected_fields.clone(),
                arrays,
                struct_array.nulls().cloned(),
            )?))
        }
        (DataType::List(expected_field), &DataType::List(_)) => {
            let list = array.as_any().downcast_ref::<ListArray>().unwrap();

            Ok(Arc::new(ListArray::try_new(
                Arc::<arrow::datatypes::Field>::clone(expected_field),
                list.offsets().clone(),
                dictionary_encode_if_necessary(
                    Arc::<dyn Array>::clone(list.values()),
                    expected_field.data_type(),
                )?,
                list.nulls().cloned(),
            )?))
        }
        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
        (_, _) => Ok(Arc::<dyn Array>::clone(&array)),
    }
}