datafusion_physical_plan/aggregates/row_hash.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Hash aggregation
19
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::vec;
23
24use super::order::GroupOrdering;
25use super::AggregateExec;
26use crate::aggregates::group_values::{new_group_values, GroupByMetrics, GroupValues};
27use crate::aggregates::order::GroupOrderingFull;
28use crate::aggregates::{
29 create_schema, evaluate_group_by, evaluate_many, evaluate_optional, AggregateMode,
30 PhysicalGroupBy,
31};
32use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
33use crate::sorts::sort::sort_batch;
34use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
35use crate::spill::spill_manager::SpillManager;
36use crate::stream::RecordBatchStreamAdapter;
37use crate::{aggregates, metrics, PhysicalExpr};
38use crate::{RecordBatchStream, SendableRecordBatchStream};
39
40use arrow::array::*;
41use arrow::datatypes::SchemaRef;
42use datafusion_common::{internal_err, DataFusionError, Result};
43use datafusion_execution::memory_pool::proxy::VecAllocExt;
44use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
45use datafusion_execution::TaskContext;
46use datafusion_expr::{EmitTo, GroupsAccumulator};
47use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
48use datafusion_physical_expr::expressions::Column;
49use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
50use datafusion_physical_expr_common::sort_expr::LexOrdering;
51
52use datafusion_common::instant::Instant;
53use futures::ready;
54use futures::stream::{Stream, StreamExt};
55use log::debug;
56
#[derive(Debug, Clone)]
/// Tracks which phase the aggregation stream is currently in
/// (reading input, producing output, skipping aggregation, or finished).
pub(crate) enum ExecutionState {
    /// Input batches are still being polled from the input stream
    ReadingInput,
    /// When producing output, the remaining rows to output are stored
    /// here and are sliced off as needed in batch_size chunks
    ProducingOutput(RecordBatch),
    /// Produce intermediate aggregate state for each input row without
    /// aggregation.
    ///
    /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
    SkippingAggregation,
    /// All input has been consumed and all groups have been emitted
    Done,
}
72
/// Encapsulates everything the stream needs in order to spill intermediate
/// aggregate state and later merge it back.
struct SpillState {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// Sorting expression for spilling batches
    spill_expr: LexOrdering,

    /// Schema for spilling batches
    spill_schema: SchemaRef,

    /// aggregate_arguments for merging spilled data
    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,

    /// GROUP BY expressions for merging spilled data
    merging_group_by: PhysicalGroupBy,

    /// Manages the process of spilling and reading back intermediate data
    spill_manager: SpillManager,

    // ========================================================================
    // STATES:
    // Fields that change during execution. They can be buffers or state flags
    // that influence the execution in the parent `GroupedHashAggregateStream`
    // ========================================================================
    /// If data has previously been spilled, the locations of the
    /// spill files (in Arrow IPC format)
    spills: Vec<SortedSpillFile>,

    /// true when streaming merge is in progress
    is_stream_merging: bool,

    // ========================================================================
    // METRICS:
    // ========================================================================
    /// Peak memory used for buffered data.
    /// Calculated as sum of peak memory values across partitions
    peak_mem_used: metrics::Gauge,
    // Metrics related to spilling are managed inside `spill_manager`
}
115
/// Tracks if the aggregate should skip partial aggregations
///
/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
struct SkipAggregationProbe {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// The aggregation ratio check is performed once the number of input rows
    /// reaches or exceeds this threshold (from `SessionConfig`)
    probe_rows_threshold: usize,
    /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation
    /// (from `SessionConfig`). If the ratio reaches or exceeds this value,
    /// aggregation is skipped and input rows are directly converted to output
    probe_ratio_threshold: f64,

    // ========================================================================
    // STATES:
    // Fields that change during execution. They can be buffers or state flags
    // that influence the execution in the parent `GroupedHashAggregateStream`
    // ========================================================================
    /// Number of processed input rows (updated during probing)
    input_rows: usize,
    /// Number of total group values for `input_rows` (updated during probing)
    num_groups: usize,

    /// Flag indicating further data aggregation may be skipped (decision made
    /// when probing complete)
    should_skip: bool,
    /// Flag indicating that further updates of the `SkipAggregationProbe`
    /// state have no effect (set when probing completes)
    is_locked: bool,

    // ========================================================================
    // METRICS:
    // ========================================================================
    /// Number of rows where state was output without aggregation.
    ///
    /// * If 0, all input rows were aggregated (should_skip was always false)
    ///
    /// * if greater than zero, the number of rows which were output directly
    ///   without aggregation
    skipped_aggregation_rows: metrics::Count,
}
161
162impl SkipAggregationProbe {
163 fn new(
164 probe_rows_threshold: usize,
165 probe_ratio_threshold: f64,
166 skipped_aggregation_rows: metrics::Count,
167 ) -> Self {
168 Self {
169 input_rows: 0,
170 num_groups: 0,
171 probe_rows_threshold,
172 probe_ratio_threshold,
173 should_skip: false,
174 is_locked: false,
175 skipped_aggregation_rows,
176 }
177 }
178
179 /// Updates `SkipAggregationProbe` state:
180 /// - increments the number of input rows
181 /// - replaces the number of groups with the new value
182 /// - on `probe_rows_threshold` exceeded calculates
183 /// aggregation ratio and sets `should_skip` flag
184 /// - if `should_skip` is set, locks further state updates
185 fn update_state(&mut self, input_rows: usize, num_groups: usize) {
186 if self.is_locked {
187 return;
188 }
189 self.input_rows += input_rows;
190 self.num_groups = num_groups;
191 if self.input_rows >= self.probe_rows_threshold {
192 self.should_skip = self.num_groups as f64 / self.input_rows as f64
193 >= self.probe_ratio_threshold;
194 self.is_locked = true;
195 }
196 }
197
198 fn should_skip(&self) -> bool {
199 self.should_skip
200 }
201
202 /// Record the number of rows that were output directly without aggregation
203 fn record_skipped(&mut self, batch: &RecordBatch) {
204 self.skipped_aggregation_rows.add(batch.num_rows());
205 }
206}
207
208/// HashTable based Grouping Aggregator
209///
210/// # Design Goals
211///
212/// This structure is designed so that updating the aggregates can be
213/// vectorized (done in a tight loop) without allocations. The
214/// accumulator state is *not* managed by this operator (e.g in the
215/// hash table) and instead is delegated to the individual
216/// accumulators which have type specialized inner loops that perform
217/// the aggregation.
218///
219/// # Architecture
220///
/// ```text
///
///     Assigns a consecutive group           internally stores aggregate values
///     index for each unique set                     for all groups
///         of group values
///
///         ┌────────────┐              ┌──────────────┐       ┌──────────────┐
///         │ ┌────────┐ │              │┌────────────┐│       │┌────────────┐│
///         │ │  "A"   │ │              ││accumulator ││       ││accumulator ││
///         │ ├────────┤ │              ││     0      ││       ││     N      ││
///         │ │  "Z"   │ │              ││ ┌────────┐ ││       ││ ┌────────┐ ││
///         │ └────────┘ │              ││ │ state  │ ││       ││ │ state  │ ││
///         │            │              ││ │┌─────┐ │ ││  ...  ││ │┌─────┐ │ ││
///         │    ...     │              ││ │├─────┤ │ ││       ││ │├─────┤ │ ││
///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
///         │            │              ││ │        │ ││       ││ │        │ ││
///         │ ┌────────┐ │              ││ │  ...   │ ││       ││ │  ...   │ ││
///         │ │  "Q"   │ │              ││ │        │ ││       ││ │        │ ││
///         │ └────────┘ │              ││ │┌─────┐ │ ││       ││ │┌─────┐ │ ││
///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
///         └────────────┘              ││ └────────┘ ││       ││ └────────┘ ││
///                                     │└────────────┘│       │└────────────┘│
///                                     └──────────────┘       └──────────────┘
///
///         group_values                             accumulators
///
/// ```
248///
249/// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`,
250/// [`group_values`] will store the distinct values of `z`. There will
251/// be one accumulator for `COUNT(x)`, specialized for the data type
252/// of `x` and one accumulator for `SUM(y)`, specialized for the data
253/// type of `y`.
254///
255/// # Discussion
256///
257/// [`group_values`] does not store any aggregate state inline. It only
258/// assigns "group indices", one for each (distinct) group value. The
259/// accumulators manage the in-progress aggregate state for each
/// group, while the group values themselves are stored in
261/// [`group_values`] at the corresponding group index.
262///
263/// The accumulator state (e.g partial sums) is managed by and stored
264/// by a [`GroupsAccumulator`] accumulator. There is one accumulator
265/// per aggregate expression (COUNT, AVG, etc) in the
266/// stream. Internally, each `GroupsAccumulator` manages the state for
/// multiple groups, and is passed `group_indexes` during update. Note that
/// the accumulator state is not managed by this operator (e.g. in the
269/// hash table).
270///
271/// [`group_values`]: Self::group_values
272///
273/// # Partial Aggregate and multi-phase grouping
274///
275/// As described on [`Accumulator::state`], this operator is used in the context
276/// "multi-phase" grouping when the mode is [`AggregateMode::Partial`].
277///
278/// An important optimization for multi-phase partial aggregation is to skip
279/// partial aggregation when it is not effective enough to warrant the memory or
/// CPU cost, as is often the case for queries with many distinct groups (high
281/// cardinality group by). Memory is particularly important because each Partial
282/// aggregator must store the intermediate state for each group.
283///
284/// If the ratio of the number of groups to the number of input rows exceeds a
285/// threshold, and [`GroupsAccumulator::supports_convert_to_state`] is
286/// supported, this operator will stop applying Partial aggregation and directly
287/// pass the input rows to the next aggregation phase.
288///
289/// [`Accumulator::state`]: datafusion_expr::Accumulator::state
290///
291/// # Spilling (to disk)
292///
293/// The sizes of group values and accumulators can become large. Before that causes out of memory,
294/// this hash aggregator outputs partial states early for partial aggregation or spills to local
295/// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory
296/// manager checks whether the new input size meets the memory configuration. If not, outputting or
297/// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling,
298/// later stream-merge sort on reading back the spilled data does re-grouping. Note the rows cannot
299/// be grouped once spilled onto disk, the read back data needs to be re-grouped again. In addition,
300/// re-grouping may cause out of memory again. Thus, re-grouping has to be a sort based aggregation.
/// ```text
/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
///
///  INPUTS        PARTIALLY AGGREGATED (UPDATE BATCH)   OUTPUTS
/// ┌─────────┐    ┌─────────────────┐                  ┌─────────────────┐
/// │ a │ b   │    │ a │    AVG(b)   │                  │ a │    AVG(b)   │
/// │---│-----│    │   │[count]│[sum]│                  │   │[count]│[sum]│
/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│                  │---│-------│-----│
/// │ 2 │ 2.0 │    │ 2 │ 1     │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1     │ 2.0 │
/// └─────────┘    │ 3 │ 2     │ 7.0 │               │  │ 3 │ 2     │ 7.0 │
/// ┌─────────┐ ─▶ │ 4 │ 1     │ 8.0 │               │  └─────────────────┘
/// │ 3 │ 4.0 │    └─────────────────┘               └▶ ┌─────────────────┐
/// │ 4 │ 8.0 │    ┌─────────────────┐                  │ 4 │ 1     │ 8.0 │
/// └─────────┘    │ a │    AVG(b)   │               ┌▶ │ 1 │ 1     │ 1.0 │
/// ┌─────────┐    │---│-------│-----│               │  └─────────────────┘
/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1     │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
/// │ 3 │ 2.0 │    │ 3 │ 1     │ 2.0 │                  │ 3 │ 1     │ 2.0 │
/// └─────────┘    └─────────────────┘                  └─────────────────┘
///
///
/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
///
/// PARTIALLY INPUTS       FINAL AGGREGATION (MERGE BATCH)       RE-GROUPED (SORTED)
/// ┌─────────────────┐    [keep using the partial schema]       [Real final aggregation
/// │ a │    AVG(b)   │    ┌─────────────────┐                    output]
/// │   │[count]│[sum]│    │ a │    AVG(b)   │                   ┌────────────┐
/// │---│-------│-----│ ─▶ │   │[count]│[sum]│                   │ a │ AVG(b) │
/// │ 3 │ 3     │ 3.0 │    │---│-------│-----│ ─▶ spill ─┐       │---│--------│
/// │ 2 │ 2     │ 1.0 │    │ 2 │ 2     │ 1.0 │           │       │ 1 │    4.0 │
/// └─────────────────┘    │ 3 │ 4     │ 8.0 │           ▼       │ 2 │    1.0 │
/// ┌─────────────────┐ ─▶ │ 4 │ 1     │ 7.0 │     Streaming  ─▶ └────────────┘
/// │ 3 │ 1     │ 5.0 │    └─────────────────┘     merge sort ─▶ ┌────────────┐
/// │ 4 │ 1     │ 7.0 │    ┌─────────────────┐            ▲      │ a │ AVG(b) │
/// └─────────────────┘    │ a │    AVG(b)   │            │      │---│--------│
/// ┌─────────────────┐    │---│-------│-----│ ─▶ memory ─┘      │ 3 │    2.0 │
/// │ 1 │ 2     │ 8.0 │ ─▶ │ 1 │ 2     │ 8.0 │                   │ 4 │    7.0 │
/// │ 2 │ 2     │ 3.0 │    │ 2 │ 2     │ 3.0 │                   └────────────┘
/// └─────────────────┘    └─────────────────┘
/// ```
pub(crate) struct GroupedHashAggregateStream {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// Schema reported by this stream's `RecordBatchStream::schema`
    schema: SchemaRef,
    /// Input stream whose batches are polled and aggregated
    input: SendableRecordBatchStream,
    /// Aggregation mode, copied from the `AggregateExec` this stream was
    /// created for
    mode: AggregateMode,

    /// Arguments to pass to each accumulator.
    ///
    /// The arguments in `accumulator[i]` is passed `aggregate_arguments[i]`
    ///
    /// The argument to each accumulator is itself a `Vec` because
    /// some aggregates such as `CORR` can accept more than one
    /// argument.
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,

    /// Optional filter expression to evaluate, one for each
    /// accumulator. If present, only those rows for which the filter
    /// evaluates to true should be included in the aggregate results.
    ///
    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,
    /// the filter expression is `x >= 100`.
    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,

    /// GROUP BY expressions
    group_by: PhysicalGroupBy,

    /// max rows in output RecordBatches
    batch_size: usize,

    /// Optional soft limit on the number of `group_values` in a batch
    /// If the number of `group_values` in a single batch exceeds this value,
    /// the `GroupedHashAggregateStream` operation immediately switches to
    /// output mode and emits all groups.
    group_values_soft_limit: Option<usize>,

    // ========================================================================
    // STATE FLAGS:
    // These fields will be updated during the execution. And control the flow of
    // the execution.
    // ========================================================================
    /// Tracks if this stream is generating input or output
    exec_state: ExecutionState,

    /// Have we seen the end of the input
    input_done: bool,

    // ========================================================================
    // STATE BUFFERS:
    // These fields will accumulate intermediate results during the execution.
    // ========================================================================
    /// An interning store of group keys
    group_values: Box<dyn GroupValues>,

    /// scratch space for the current input [`RecordBatch`] being
    /// processed. Reused across batches here to avoid reallocations
    current_group_indices: Vec<usize>,

    /// Accumulators, one for each `AggregateFunctionExpr` in the query
    ///
    /// For example, if the query has aggregates, `SUM(x)`,
    /// `COUNT(y)`, there will be two accumulators, each one
    /// specialized for that particular aggregate and its input types
    accumulators: Vec<Box<dyn GroupsAccumulator>>,

    // ========================================================================
    // TASK-SPECIFIC STATES:
    // Inner state objects, each grouping the properties and state needed for
    // one specific task.
    // ========================================================================
    /// Optional ordering information, that might allow groups to be
    /// emitted from the hash table prior to seeing the end of the
    /// input
    group_ordering: GroupOrdering,

    /// The spill state object
    spill_state: SpillState,

    /// Optional probe for skipping data aggregation, if supported by
    /// current stream.
    skip_aggregation_probe: Option<SkipAggregationProbe>,

    // ========================================================================
    // EXECUTION RESOURCES:
    // Fields related to managing execution resources and monitoring performance.
    // ========================================================================
    /// The memory reservation for this grouping
    reservation: MemoryReservation,

    /// Execution metrics
    baseline_metrics: BaselineMetrics,

    /// Aggregation-specific metrics
    group_by_metrics: GroupByMetrics,

    /// Reduction factor metric, calculated as `output_rows/input_rows` (only for partial aggregation)
    reduction_factor: Option<metrics::RatioMetrics>,
}
440
441impl GroupedHashAggregateStream {
442 /// Create a new GroupedHashAggregateStream
443 pub fn new(
444 agg: &AggregateExec,
445 context: Arc<TaskContext>,
446 partition: usize,
447 ) -> Result<Self> {
448 debug!("Creating GroupedHashAggregateStream");
449 let agg_schema = Arc::clone(&agg.schema);
450 let agg_group_by = agg.group_by.clone();
451 let agg_filter_expr = agg.filter_expr.clone();
452
453 let batch_size = context.session_config().batch_size();
454 let input = agg.input.execute(partition, Arc::clone(&context))?;
455 let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
456 let group_by_metrics = GroupByMetrics::new(&agg.metrics, partition);
457
458 let timer = baseline_metrics.elapsed_compute().timer();
459
460 let aggregate_exprs = agg.aggr_expr.clone();
461
462 // arguments for each aggregate, one vec of expressions per
463 // aggregate
464 let aggregate_arguments = aggregates::aggregate_expressions(
465 &agg.aggr_expr,
466 &agg.mode,
467 agg_group_by.num_group_exprs(),
468 )?;
469 // arguments for aggregating spilled data is the same as the one for final aggregation
470 let merging_aggregate_arguments = aggregates::aggregate_expressions(
471 &agg.aggr_expr,
472 &AggregateMode::Final,
473 agg_group_by.num_group_exprs(),
474 )?;
475
476 let filter_expressions = match agg.mode {
477 AggregateMode::Partial
478 | AggregateMode::Single
479 | AggregateMode::SinglePartitioned => agg_filter_expr,
480 AggregateMode::Final | AggregateMode::FinalPartitioned => {
481 vec![None; agg.aggr_expr.len()]
482 }
483 };
484
485 // Instantiate the accumulators
486 let accumulators: Vec<_> = aggregate_exprs
487 .iter()
488 .map(create_group_accumulator)
489 .collect::<Result<_>>()?;
490
491 let group_schema = agg_group_by.group_schema(&agg.input().schema())?;
492
493 // fix https://github.com/apache/datafusion/issues/13949
494 // Builds a **partial aggregation** schema by combining the group columns and
495 // the accumulator state columns produced by each aggregate expression.
496 //
497 // # Why Partial Aggregation Schema Is Needed
498 //
499 // In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
500 // operator produces *intermediate* states (e.g., partial sums, counts) rather
501 // than final scalar values. These extra columns do **not** exist in the original
502 // input schema (which may be something like `[colA, colB, ...]`). Instead,
503 // each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
504 //
505 // Therefore, when we spill these intermediate states or pass them to another
506 // aggregation operator, we must use a schema that includes both the group
507 // columns **and** the partial-state columns.
508 let partial_agg_schema = create_schema(
509 &agg.input().schema(),
510 &agg_group_by,
511 &aggregate_exprs,
512 AggregateMode::Partial,
513 )?;
514
515 // Need to update the GROUP BY expressions to point to the correct column after schema change
516 let merging_group_by_expr = agg_group_by
517 .expr
518 .iter()
519 .enumerate()
520 .map(|(idx, (_, name))| {
521 (Arc::new(Column::new(name.as_str(), idx)) as _, name.clone())
522 })
523 .collect();
524
525 let partial_agg_schema = Arc::new(partial_agg_schema);
526
527 let spill_expr =
528 group_schema
529 .fields
530 .into_iter()
531 .enumerate()
532 .map(|(idx, field)| {
533 PhysicalSortExpr::new_default(Arc::new(Column::new(
534 field.name().as_str(),
535 idx,
536 )) as _)
537 });
538 let Some(spill_expr) = LexOrdering::new(spill_expr) else {
539 return internal_err!("Spill expression is empty");
540 };
541
542 let agg_fn_names = aggregate_exprs
543 .iter()
544 .map(|expr| expr.human_display())
545 .collect::<Vec<_>>()
546 .join(", ");
547 let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})");
548 let reservation = MemoryConsumer::new(name)
549 .with_can_spill(true)
550 .register(context.memory_pool());
551 let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
552 let group_values = new_group_values(group_schema, &group_ordering)?;
553 timer.done();
554
555 let exec_state = ExecutionState::ReadingInput;
556
557 let spill_manager = SpillManager::new(
558 context.runtime_env(),
559 metrics::SpillMetrics::new(&agg.metrics, partition),
560 Arc::clone(&partial_agg_schema),
561 )
562 .with_compression_type(context.session_config().spill_compression());
563
564 let spill_state = SpillState {
565 spills: vec![],
566 spill_expr,
567 spill_schema: partial_agg_schema,
568 is_stream_merging: false,
569 merging_aggregate_arguments,
570 merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr),
571 peak_mem_used: MetricBuilder::new(&agg.metrics)
572 .gauge("peak_mem_used", partition),
573 spill_manager,
574 };
575
576 // Skip aggregation is supported if:
577 // - aggregation mode is Partial
578 // - input is not ordered by GROUP BY expressions,
579 // since Final mode expects unique group values as its input
580 // - all accumulators support input batch to intermediate
581 // aggregate state conversion
582 // - there is only one GROUP BY expressions set
583 let skip_aggregation_probe = if agg.mode == AggregateMode::Partial
584 && matches!(group_ordering, GroupOrdering::None)
585 && accumulators
586 .iter()
587 .all(|acc| acc.supports_convert_to_state())
588 && agg_group_by.is_single()
589 {
590 let options = &context.session_config().options().execution;
591 let probe_rows_threshold =
592 options.skip_partial_aggregation_probe_rows_threshold;
593 let probe_ratio_threshold =
594 options.skip_partial_aggregation_probe_ratio_threshold;
595 let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
596 .counter("skipped_aggregation_rows", partition);
597 Some(SkipAggregationProbe::new(
598 probe_rows_threshold,
599 probe_ratio_threshold,
600 skipped_aggregation_rows,
601 ))
602 } else {
603 None
604 };
605
606 let reduction_factor = if agg.mode == AggregateMode::Partial {
607 Some(
608 MetricBuilder::new(&agg.metrics)
609 .with_type(metrics::MetricType::SUMMARY)
610 .ratio_metrics("reduction_factor", partition),
611 )
612 } else {
613 None
614 };
615
616 Ok(GroupedHashAggregateStream {
617 schema: agg_schema,
618 input,
619 mode: agg.mode,
620 accumulators,
621 aggregate_arguments,
622 filter_expressions,
623 group_by: agg_group_by,
624 reservation,
625 group_values,
626 current_group_indices: Default::default(),
627 exec_state,
628 baseline_metrics,
629 group_by_metrics,
630 batch_size,
631 group_ordering,
632 input_done: false,
633 spill_state,
634 group_values_soft_limit: agg.limit,
635 skip_aggregation_probe,
636 reduction_factor,
637 })
638 }
639}
640
641/// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if
642/// that is supported by the aggregate, or a
643/// [`GroupsAccumulatorAdapter`] if not.
644pub(crate) fn create_group_accumulator(
645 agg_expr: &Arc<AggregateFunctionExpr>,
646) -> Result<Box<dyn GroupsAccumulator>> {
647 if agg_expr.groups_accumulator_supported() {
648 agg_expr.create_groups_accumulator()
649 } else {
650 // Note in the log when the slow path is used
651 debug!(
652 "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}",
653 agg_expr.name()
654 );
655 let agg_expr_captured = Arc::clone(agg_expr);
656 let factory = move || agg_expr_captured.create_accumulator();
657 Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
658 }
659}
660
impl Stream for GroupedHashAggregateStream {
    type Item = Result<RecordBatch>;

    /// Drives the aggregation state machine stored in `exec_state`.
    ///
    /// The loop keeps transitioning between states until it can return:
    /// - `Poll::Pending` when the input stream is not ready (via `ready!`)
    /// - `Poll::Ready(Some(Ok(batch)))` with an output batch
    /// - `Poll::Ready(Some(Err(e)))` when the input or a processing step fails
    /// - `Poll::Ready(None)` once the `Done` state is reached
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Cloned up front so timers can be started while `self` is borrowed below
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();

        loop {
            match &self.exec_state {
                // The labeled block lets a branch leave this arm early once it
                // has decided the next state
                ExecutionState::ReadingInput => 'reading_input: {
                    match ready!(self.input.poll_next_unpin(cx)) {
                        // New batch to aggregate in partial aggregation operator
                        Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
                            let timer = elapsed_compute.timer();
                            let input_rows = batch.num_rows();

                            // Input rows form the "total" side of the reduction factor
                            if let Some(reduction_factor) = self.reduction_factor.as_ref()
                            {
                                reduction_factor.add_total(input_rows);
                            }

                            // Do the grouping
                            self.group_aggregate_batch(batch)?;

                            self.update_skip_aggregation_probe(input_rows);

                            // If we can begin emitting rows, do so,
                            // otherwise keep consuming input
                            assert!(!self.input_done);

                            // If the number of group values equals or exceeds the soft limit,
                            // emit all groups and switch to producing output
                            if self.hit_soft_group_limit() {
                                timer.done();
                                self.set_input_done_and_produce_output()?;
                                // make sure the exec_state just set is not overwritten below
                                break 'reading_input;
                            }

                            // The group ordering may allow some groups to be emitted already
                            if let Some(to_emit) = self.group_ordering.emit_to() {
                                timer.done();
                                if let Some(batch) = self.emit(to_emit, false)? {
                                    self.exec_state =
                                        ExecutionState::ProducingOutput(batch);
                                };
                                // make sure the exec_state just set is not overwritten below
                                break 'reading_input;
                            }

                            // NOTE(review): both helpers are defined outside this view;
                            // judging by their names they may change `exec_state`
                            // (early emit / skip aggregation) — confirm
                            self.emit_early_if_necessary()?;

                            self.switch_to_skip_aggregation()?;

                            timer.done();
                        }

                        // New batch to aggregate in terminal aggregation operator
                        // (Final/FinalPartitioned/Single/SinglePartitioned)
                        Some(Ok(batch)) => {
                            let timer = elapsed_compute.timer();

                            // Make sure we have enough capacity for `batch`, otherwise spill
                            self.spill_previous_if_necessary(&batch)?;

                            // Do the grouping
                            self.group_aggregate_batch(batch)?;

                            // If we can begin emitting rows, do so,
                            // otherwise keep consuming input
                            assert!(!self.input_done);

                            // If the number of group values equals or exceeds the soft limit,
                            // emit all groups and switch to producing output
                            if self.hit_soft_group_limit() {
                                timer.done();
                                self.set_input_done_and_produce_output()?;
                                // make sure the exec_state just set is not overwritten below
                                break 'reading_input;
                            }

                            // The group ordering may allow some groups to be emitted already
                            if let Some(to_emit) = self.group_ordering.emit_to() {
                                timer.done();
                                if let Some(batch) = self.emit(to_emit, false)? {
                                    self.exec_state =
                                        ExecutionState::ProducingOutput(batch);
                                };
                                // make sure the exec_state just set is not overwritten below
                                break 'reading_input;
                            }

                            timer.done();
                        }

                        // Found error from input stream
                        Some(Err(e)) => {
                            // inner had error, return to caller
                            return Poll::Ready(Some(Err(e)));
                        }

                        // Found end from input stream
                        None => {
                            // inner is done, emit all rows and switch to producing output
                            self.set_input_done_and_produce_output()?;
                        }
                    }
                }

                // Aggregation is bypassed: each input batch is converted to
                // intermediate state and returned right away
                ExecutionState::SkippingAggregation => {
                    match ready!(self.input.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            // Dropped at the end of this arm, i.e. after the conversion
                            let _timer = elapsed_compute.timer();
                            if let Some(probe) = self.skip_aggregation_probe.as_mut() {
                                probe.record_skipped(&batch);
                            }
                            let states = self.transform_to_states(batch)?;
                            return Poll::Ready(Some(Ok(
                                states.record_output(&self.baseline_metrics)
                            )));
                        }
                        Some(Err(e)) => {
                            // inner had error, return to caller
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // inner is done, switching to `Done` state
                            self.exec_state = ExecutionState::Done;
                        }
                    }
                }

                ExecutionState::ProducingOutput(batch) => {
                    // slice off a part of the batch, if needed
                    let output_batch;
                    let size = self.batch_size;
                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
                        // The whole remainder fits in one output batch, so this is
                        // also where the next state is chosen
                        (
                            if self.input_done {
                                ExecutionState::Done
                            }
                            // In Partial aggregation, we also need to check
                            // if we should trigger partial skipping
                            else if self.mode == AggregateMode::Partial
                                && self.should_skip_aggregation()
                            {
                                ExecutionState::SkippingAggregation
                            } else {
                                ExecutionState::ReadingInput
                            },
                            batch.clone(),
                        )
                    } else {
                        // output first batch_size rows
                        let size = self.batch_size;
                        let num_remaining = batch.num_rows() - size;
                        let remaining = batch.slice(size, num_remaining);
                        let output = batch.slice(0, size);
                        (ExecutionState::ProducingOutput(remaining), output)
                    };

                    // Output rows form the "part" side of the reduction factor
                    if let Some(reduction_factor) = self.reduction_factor.as_ref() {
                        reduction_factor.add_part(output_batch.num_rows());
                    }

                    // Empty record batches should not be emitted.
                    // They need to be treated as [`Option<RecordBatch>`]es and handled separately
                    debug_assert!(output_batch.num_rows() > 0);
                    return Poll::Ready(Some(Ok(
                        output_batch.record_output(&self.baseline_metrics)
                    )));
                }

                ExecutionState::Done => {
                    // release the memory reservation since sending back output batch itself needs
                    // some memory reservation, so make some room for it.
                    self.clear_all();
                    // The result of updating the reservation is deliberately ignored here
                    let _ = self.update_memory_reservation();
                    return Poll::Ready(None);
                }
            }
        }
    }
}
845
846impl RecordBatchStream for GroupedHashAggregateStream {
847 fn schema(&self) -> SchemaRef {
848 Arc::clone(&self.schema)
849 }
850}
851
852impl GroupedHashAggregateStream {
853 /// Perform group-by aggregation for the given [`RecordBatch`].
854 fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
855 // Evaluate the grouping expressions
856 let group_by_values = if self.spill_state.is_stream_merging {
857 evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
858 } else {
859 evaluate_group_by(&self.group_by, &batch)?
860 };
861
862 // Only create the timer if there are actual aggregate arguments to evaluate
863 let timer = match (
864 self.spill_state.is_stream_merging,
865 self.spill_state.merging_aggregate_arguments.is_empty(),
866 self.aggregate_arguments.is_empty(),
867 ) {
868 (true, false, _) | (false, _, false) => {
869 Some(self.group_by_metrics.aggregate_arguments_time.timer())
870 }
871 _ => None,
872 };
873
874 // Evaluate the aggregation expressions.
875 let input_values = if self.spill_state.is_stream_merging {
876 evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
877 } else {
878 evaluate_many(&self.aggregate_arguments, &batch)?
879 };
880 drop(timer);
881
882 // Evaluate the filter expressions, if any, against the inputs
883 let filter_values = if self.spill_state.is_stream_merging {
884 let filter_expressions = vec![None; self.accumulators.len()];
885 evaluate_optional(&filter_expressions, &batch)?
886 } else {
887 evaluate_optional(&self.filter_expressions, &batch)?
888 };
889
890 for group_values in &group_by_values {
891 let groups_start_time = Instant::now();
892
893 // calculate the group indices for each input row
894 let starting_num_groups = self.group_values.len();
895 self.group_values
896 .intern(group_values, &mut self.current_group_indices)?;
897 let group_indices = &self.current_group_indices;
898
899 // Update ordering information if necessary
900 let total_num_groups = self.group_values.len();
901 if total_num_groups > starting_num_groups {
902 self.group_ordering.new_groups(
903 group_values,
904 group_indices,
905 total_num_groups,
906 )?;
907 }
908
909 // Use this instant for both measurements to save a syscall
910 let agg_start_time = Instant::now();
911 self.group_by_metrics
912 .time_calculating_group_ids
913 .add_duration(agg_start_time - groups_start_time);
914
915 // Gather the inputs to call the actual accumulator
916 let t = self
917 .accumulators
918 .iter_mut()
919 .zip(input_values.iter())
920 .zip(filter_values.iter());
921
922 for ((acc, values), opt_filter) in t {
923 let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
924
925 // Call the appropriate method on each aggregator with
926 // the entire input row and the relevant group indexes
927 match self.mode {
928 AggregateMode::Partial
929 | AggregateMode::Single
930 | AggregateMode::SinglePartitioned
931 if !self.spill_state.is_stream_merging =>
932 {
933 acc.update_batch(
934 values,
935 group_indices,
936 opt_filter,
937 total_num_groups,
938 )?;
939 }
940 _ => {
941 if opt_filter.is_some() {
942 return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
943 }
944
945 // if aggregation is over intermediate states,
946 // use merge
947 acc.merge_batch(values, group_indices, None, total_num_groups)?;
948 }
949 }
950 self.group_by_metrics
951 .aggregation_time
952 .add_elapsed(agg_start_time);
953 }
954 }
955
956 match self.update_memory_reservation() {
957 // Here we can ignore `insufficient_capacity_err` because we will spill later,
958 // but at least one batch should fit in the memory
959 Err(DataFusionError::ResourcesExhausted(_))
960 if self.group_values.len() >= self.batch_size =>
961 {
962 Ok(())
963 }
964 other => other,
965 }
966 }
967
968 fn update_memory_reservation(&mut self) -> Result<()> {
969 let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
970 let reservation_result = self.reservation.try_resize(
971 acc + self.group_values.size()
972 + self.group_ordering.size()
973 + self.current_group_indices.allocated_size(),
974 );
975
976 if reservation_result.is_ok() {
977 self.spill_state
978 .peak_mem_used
979 .set_max(self.reservation.size());
980 }
981
982 reservation_result
983 }
984
985 /// Create an output RecordBatch with the group keys and
986 /// accumulator states/values specified in emit_to
987 fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
988 let schema = if spilling {
989 Arc::clone(&self.spill_state.spill_schema)
990 } else {
991 self.schema()
992 };
993 if self.group_values.is_empty() {
994 return Ok(None);
995 }
996
997 let timer = self.group_by_metrics.emitting_time.timer();
998 let mut output = self.group_values.emit(emit_to)?;
999 if let EmitTo::First(n) = emit_to {
1000 self.group_ordering.remove_groups(n);
1001 }
1002
1003 // Next output each aggregate value
1004 for acc in self.accumulators.iter_mut() {
1005 match self.mode {
1006 AggregateMode::Partial => output.extend(acc.state(emit_to)?),
1007 _ if spilling => {
1008 // If spilling, output partial state because the spilled data will be
1009 // merged and re-evaluated later.
1010 output.extend(acc.state(emit_to)?)
1011 }
1012 AggregateMode::Final
1013 | AggregateMode::FinalPartitioned
1014 | AggregateMode::Single
1015 | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?),
1016 }
1017 }
1018 drop(timer);
1019
1020 // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
1021 // over the target memory size after emission, we can emit again rather than returning Err.
1022 let _ = self.update_memory_reservation();
1023 let batch = RecordBatch::try_new(schema, output)?;
1024 debug_assert!(batch.num_rows() > 0);
1025
1026 Ok(Some(batch))
1027 }
1028
1029 /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
1030 /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
1031 /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
1032 fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
1033 // TODO: support group_ordering for spilling
1034 if !self.group_values.is_empty()
1035 && batch.num_rows() > 0
1036 && matches!(self.group_ordering, GroupOrdering::None)
1037 && !self.spill_state.is_stream_merging
1038 && self.update_memory_reservation().is_err()
1039 {
1040 assert_ne!(self.mode, AggregateMode::Partial);
1041 self.spill()?;
1042 self.clear_shrink(batch);
1043 }
1044 Ok(())
1045 }
1046
1047 /// Emit all intermediate aggregation states, sort them, and store them on disk.
1048 /// This process helps in reducing memory pressure by allowing the data to be
1049 /// read back with streaming merge.
1050 fn spill(&mut self) -> Result<()> {
1051 // Emit and sort intermediate aggregation state
1052 let Some(emit) = self.emit(EmitTo::All, true)? else {
1053 return Ok(());
1054 };
1055 let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
1056
1057 // Spill sorted state to disk
1058 let spillfile = self
1059 .spill_state
1060 .spill_manager
1061 .spill_record_batch_by_size_and_return_max_batch_memory(
1062 &sorted,
1063 "HashAggSpill",
1064 self.batch_size,
1065 )?;
1066 match spillfile {
1067 Some((spillfile, max_record_batch_memory)) => {
1068 self.spill_state.spills.push(SortedSpillFile {
1069 file: spillfile,
1070 max_record_batch_memory,
1071 })
1072 }
1073 None => {
1074 return internal_err!(
1075 "Calling spill with no intermediate batch to spill"
1076 );
1077 }
1078 }
1079
1080 Ok(())
1081 }
1082
1083 /// Clear memory and shirk capacities to the size of the batch.
1084 fn clear_shrink(&mut self, batch: &RecordBatch) {
1085 self.group_values.clear_shrink(batch);
1086 self.current_group_indices.clear();
1087 self.current_group_indices.shrink_to(batch.num_rows());
1088 }
1089
1090 /// Clear memory and shirk capacities to zero.
1091 fn clear_all(&mut self) {
1092 let s = self.schema();
1093 self.clear_shrink(&RecordBatch::new_empty(s));
1094 }
1095
1096 /// Emit if the used memory exceeds the target for partial aggregation.
1097 /// Currently only [`GroupOrdering::None`] is supported for early emitting.
1098 /// TODO: support group_ordering for early emitting
1099 fn emit_early_if_necessary(&mut self) -> Result<()> {
1100 if self.group_values.len() >= self.batch_size
1101 && matches!(self.group_ordering, GroupOrdering::None)
1102 && self.update_memory_reservation().is_err()
1103 {
1104 assert_eq!(self.mode, AggregateMode::Partial);
1105 let n = self.group_values.len() / self.batch_size * self.batch_size;
1106 if let Some(batch) = self.emit(EmitTo::First(n), false)? {
1107 self.exec_state = ExecutionState::ProducingOutput(batch);
1108 };
1109 }
1110 Ok(())
1111 }
1112
1113 /// At this point, all the inputs are read and there are some spills.
1114 /// Emit the remaining rows and create a batch.
1115 /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
1116 /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
1117 fn update_merged_stream(&mut self) -> Result<()> {
1118 let Some(batch) = self.emit(EmitTo::All, true)? else {
1119 return Ok(());
1120 };
1121 // clear up memory for streaming_merge
1122 self.clear_all();
1123 self.update_memory_reservation()?;
1124 let mut streams: Vec<SendableRecordBatchStream> = vec![];
1125 let expr = self.spill_state.spill_expr.clone();
1126 let schema = batch.schema();
1127 streams.push(Box::pin(RecordBatchStreamAdapter::new(
1128 Arc::clone(&schema),
1129 futures::stream::once(futures::future::lazy(move |_| {
1130 sort_batch(&batch, &expr, None)
1131 })),
1132 )));
1133
1134 self.spill_state.is_stream_merging = true;
1135 self.input = StreamingMergeBuilder::new()
1136 .with_streams(streams)
1137 .with_schema(schema)
1138 .with_spill_manager(self.spill_state.spill_manager.clone())
1139 .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills))
1140 .with_expressions(&self.spill_state.spill_expr)
1141 .with_metrics(self.baseline_metrics.clone())
1142 .with_batch_size(self.batch_size)
1143 .with_reservation(self.reservation.new_empty())
1144 .build()?;
1145 self.input_done = false;
1146 self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
1147 Ok(())
1148 }
1149
1150 /// returns true if there is a soft groups limit and the number of distinct
1151 /// groups we have seen is over that limit
1152 fn hit_soft_group_limit(&self) -> bool {
1153 let Some(group_values_soft_limit) = self.group_values_soft_limit else {
1154 return false;
1155 };
1156 group_values_soft_limit <= self.group_values.len()
1157 }
1158
1159 /// common function for signalling end of processing of the input stream
1160 fn set_input_done_and_produce_output(&mut self) -> Result<()> {
1161 self.input_done = true;
1162 self.group_ordering.input_done();
1163 let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1164 let timer = elapsed_compute.timer();
1165 self.exec_state = if self.spill_state.spills.is_empty() {
1166 let batch = self.emit(EmitTo::All, false)?;
1167 batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
1168 } else {
1169 // If spill files exist, stream-merge them.
1170 self.update_merged_stream()?;
1171 ExecutionState::ReadingInput
1172 };
1173 timer.done();
1174 Ok(())
1175 }
1176
1177 /// Updates skip aggregation probe state.
1178 ///
1179 /// Notice: It should only be called in Partial aggregation
1180 fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
1181 if let Some(probe) = self.skip_aggregation_probe.as_mut() {
1182 // Skip aggregation probe is not supported if stream has any spills,
1183 // currently spilling is not supported for Partial aggregation
1184 assert!(self.spill_state.spills.is_empty());
1185 probe.update_state(input_rows, self.group_values.len());
1186 };
1187 }
1188
1189 /// In case the probe indicates that aggregation may be
1190 /// skipped, forces stream to produce currently accumulated output.
1191 ///
1192 /// Notice: It should only be called in Partial aggregation
1193 fn switch_to_skip_aggregation(&mut self) -> Result<()> {
1194 if let Some(probe) = self.skip_aggregation_probe.as_mut() {
1195 if probe.should_skip() {
1196 if let Some(batch) = self.emit(EmitTo::All, false)? {
1197 self.exec_state = ExecutionState::ProducingOutput(batch);
1198 };
1199 }
1200 }
1201
1202 Ok(())
1203 }
1204
1205 /// Returns true if the aggregation probe indicates that aggregation
1206 /// should be skipped.
1207 ///
1208 /// Notice: It should only be called in Partial aggregation
1209 fn should_skip_aggregation(&self) -> bool {
1210 self.skip_aggregation_probe
1211 .as_ref()
1212 .is_some_and(|probe| probe.should_skip())
1213 }
1214
1215 /// Transforms input batch to intermediate aggregate state, without grouping it
1216 fn transform_to_states(&self, batch: RecordBatch) -> Result<RecordBatch> {
1217 let mut group_values = evaluate_group_by(&self.group_by, &batch)?;
1218 let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
1219 let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
1220
1221 if group_values.len() != 1 {
1222 return internal_err!("group_values expected to have single element");
1223 }
1224 let mut output = group_values.swap_remove(0);
1225
1226 let iter = self
1227 .accumulators
1228 .iter()
1229 .zip(input_values.iter())
1230 .zip(filter_values.iter());
1231
1232 for ((acc, values), opt_filter) in iter {
1233 let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
1234 output.extend(acc.convert_to_state(values, opt_filter)?);
1235 }
1236
1237 let states_batch = RecordBatch::try_new(self.schema(), output)?;
1238
1239 Ok(states_batch)
1240 }
1241}