datafusion_physical_plan/aggregates/
row_hash.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Hash aggregation
19
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::vec;
23
24use super::order::GroupOrdering;
25use super::AggregateExec;
26use crate::aggregates::group_values::{new_group_values, GroupByMetrics, GroupValues};
27use crate::aggregates::order::GroupOrderingFull;
28use crate::aggregates::{
29    create_schema, evaluate_group_by, evaluate_many, evaluate_optional, AggregateMode,
30    PhysicalGroupBy,
31};
32use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
33use crate::sorts::sort::sort_batch;
34use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
35use crate::spill::spill_manager::SpillManager;
36use crate::stream::RecordBatchStreamAdapter;
37use crate::{aggregates, metrics, PhysicalExpr};
38use crate::{RecordBatchStream, SendableRecordBatchStream};
39
40use arrow::array::*;
41use arrow::datatypes::SchemaRef;
42use datafusion_common::{internal_err, DataFusionError, Result};
43use datafusion_execution::memory_pool::proxy::VecAllocExt;
44use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
45use datafusion_execution::TaskContext;
46use datafusion_expr::{EmitTo, GroupsAccumulator};
47use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
48use datafusion_physical_expr::expressions::Column;
49use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
50use datafusion_physical_expr_common::sort_expr::LexOrdering;
51
52use datafusion_common::instant::Instant;
53use futures::ready;
54use futures::stream::{Stream, StreamExt};
55use log::debug;
56
#[derive(Debug, Clone)]
/// Tracks which phase of the aggregation the stream is currently in
/// (reading input or producing output).
pub(crate) enum ExecutionState {
    /// Batches are being pulled from the input stream and aggregated
    ReadingInput,
    /// When producing output, the remaining rows to output are stored
    /// here and are sliced off as needed in batch_size chunks
    ProducingOutput(RecordBatch),
    /// Produce intermediate aggregate state for each input row without
    /// aggregation.
    ///
    /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
    SkippingAggregation,
    /// All input has been consumed and all groups have been emitted
    Done,
}
72
/// Encapsulates the spilling state: the constant configuration needed to
/// write and re-merge spilled batches, plus the spill files produced so far.
struct SpillState {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// Sorting expression for spilling batches
    spill_expr: LexOrdering,

    /// Schema for spilling batches
    spill_schema: SchemaRef,

    /// aggregate_arguments for merging spilled data
    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,

    /// GROUP BY expressions for merging spilled data
    merging_group_by: PhysicalGroupBy,

    /// Manages the process of spilling and reading back intermediate data
    spill_manager: SpillManager,

    // ========================================================================
    // STATES:
    // Fields that change during execution. These can be buffers or state flags
    // that influence the execution in the parent `GroupedHashAggregateStream`
    // ========================================================================
    /// If data has previously been spilled, the locations of the
    /// spill files (in Arrow IPC format)
    spills: Vec<SortedSpillFile>,

    /// true when streaming merge is in progress
    is_stream_merging: bool,

    // ========================================================================
    // METRICS:
    // ========================================================================
    /// Peak memory used for buffered data.
    /// Calculated as sum of peak memory values across partitions
    peak_mem_used: metrics::Gauge,
    // Metrics related to spilling are managed inside `spill_manager`
}
115
/// Tracks whether the aggregate should skip partial aggregation
///
/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
struct SkipAggregationProbe {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// The aggregation ratio check is performed once the number of input rows
    /// reaches this threshold (from `SessionConfig`)
    probe_rows_threshold: usize,
    /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation
    /// (from `SessionConfig`). If the ratio reaches or exceeds this value,
    /// aggregation is skipped and input rows are directly converted to output
    probe_ratio_threshold: f64,

    // ========================================================================
    // STATES:
    // Fields that change during execution. These can be buffers or state flags
    // that influence the execution in the parent `GroupedHashAggregateStream`
    // ========================================================================
    /// Number of processed input rows (updated during probing)
    input_rows: usize,
    /// Number of total group values for `input_rows` (updated during probing)
    num_groups: usize,

    /// Flag indicating further data aggregation may be skipped (decision made
    /// when probing completes)
    should_skip: bool,
    /// Flag indicating that further updates of the `SkipAggregationProbe`
    /// state will have no effect (set when probing completes)
    is_locked: bool,

    // ========================================================================
    // METRICS:
    // ========================================================================
    /// Number of rows where state was output without aggregation.
    ///
    /// * If 0, all input rows were aggregated (should_skip was always false)
    ///
    /// * If greater than zero, the number of rows which were output directly
    ///   without aggregation
    skipped_aggregation_rows: metrics::Count,
}
161
162impl SkipAggregationProbe {
163    fn new(
164        probe_rows_threshold: usize,
165        probe_ratio_threshold: f64,
166        skipped_aggregation_rows: metrics::Count,
167    ) -> Self {
168        Self {
169            input_rows: 0,
170            num_groups: 0,
171            probe_rows_threshold,
172            probe_ratio_threshold,
173            should_skip: false,
174            is_locked: false,
175            skipped_aggregation_rows,
176        }
177    }
178
179    /// Updates `SkipAggregationProbe` state:
180    /// - increments the number of input rows
181    /// - replaces the number of groups with the new value
182    /// - on `probe_rows_threshold` exceeded calculates
183    ///   aggregation ratio and sets `should_skip` flag
184    /// - if `should_skip` is set, locks further state updates
185    fn update_state(&mut self, input_rows: usize, num_groups: usize) {
186        if self.is_locked {
187            return;
188        }
189        self.input_rows += input_rows;
190        self.num_groups = num_groups;
191        if self.input_rows >= self.probe_rows_threshold {
192            self.should_skip = self.num_groups as f64 / self.input_rows as f64
193                >= self.probe_ratio_threshold;
194            self.is_locked = true;
195        }
196    }
197
198    fn should_skip(&self) -> bool {
199        self.should_skip
200    }
201
202    /// Record the number of rows that were output directly without aggregation
203    fn record_skipped(&mut self, batch: &RecordBatch) {
204        self.skipped_aggregation_rows.add(batch.num_rows());
205    }
206}
207
208/// HashTable based Grouping Aggregator
209///
210/// # Design Goals
211///
212/// This structure is designed so that updating the aggregates can be
213/// vectorized (done in a tight loop) without allocations. The
214/// accumulator state is *not* managed by this operator (e.g in the
215/// hash table) and instead is delegated to the individual
216/// accumulators which have type specialized inner loops that perform
217/// the aggregation.
218///
219/// # Architecture
220///
221/// ```text
222///
223///     Assigns a consecutive group           internally stores aggregate values
224///     index for each unique set                     for all groups
225///         of group values
226///
227///         ┌────────────┐              ┌──────────────┐       ┌──────────────┐
228///         │ ┌────────┐ │              │┌────────────┐│       │┌────────────┐│
229///         │ │  "A"   │ │              ││accumulator ││       ││accumulator ││
230///         │ ├────────┤ │              ││     0      ││       ││     N      ││
231///         │ │  "Z"   │ │              ││ ┌────────┐ ││       ││ ┌────────┐ ││
232///         │ └────────┘ │              ││ │ state  │ ││       ││ │ state  │ ││
233///         │            │              ││ │┌─────┐ │ ││  ...  ││ │┌─────┐ │ ││
234///         │    ...     │              ││ │├─────┤ │ ││       ││ │├─────┤ │ ││
235///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
236///         │            │              ││ │        │ ││       ││ │        │ ││
237///         │ ┌────────┐ │              ││ │  ...   │ ││       ││ │  ...   │ ││
238///         │ │  "Q"   │ │              ││ │        │ ││       ││ │        │ ││
239///         │ └────────┘ │              ││ │┌─────┐ │ ││       ││ │┌─────┐ │ ││
240///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
241///         └────────────┘              ││ └────────┘ ││       ││ └────────┘ ││
242///                                     │└────────────┘│       │└────────────┘│
243///                                     └──────────────┘       └──────────────┘
244///
245///         group_values                             accumulators
246///
247///  ```
248///
249/// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`,
250/// [`group_values`] will store the distinct values of `z`. There will
251/// be one accumulator for `COUNT(x)`, specialized for the data type
252/// of `x` and one accumulator for `SUM(y)`, specialized for the data
253/// type of `y`.
254///
255/// # Discussion
256///
257/// [`group_values`] does not store any aggregate state inline. It only
258/// assigns "group indices", one for each (distinct) group value. The
259/// accumulators manage the in-progress aggregate state for each
/// group, while the group values themselves are stored in
261/// [`group_values`] at the corresponding group index.
262///
263/// The accumulator state (e.g partial sums) is managed by and stored
264/// by a [`GroupsAccumulator`] accumulator. There is one accumulator
265/// per aggregate expression (COUNT, AVG, etc) in the
266/// stream. Internally, each `GroupsAccumulator` manages the state for
/// multiple groups, and is passed `group_indexes` during update. Note
/// that the accumulator state is not managed by this operator (e.g. in
/// the hash table).
270///
271/// [`group_values`]: Self::group_values
272///
273/// # Partial Aggregate and multi-phase grouping
274///
/// As described on [`Accumulator::state`], this operator is used in the context
/// of "multi-phase" grouping when the mode is [`AggregateMode::Partial`].
277///
278/// An important optimization for multi-phase partial aggregation is to skip
279/// partial aggregation when it is not effective enough to warrant the memory or
/// CPU cost, as is often the case for queries with many distinct groups (high
281/// cardinality group by). Memory is particularly important because each Partial
282/// aggregator must store the intermediate state for each group.
283///
284/// If the ratio of the number of groups to the number of input rows exceeds a
285/// threshold, and [`GroupsAccumulator::supports_convert_to_state`] is
286/// supported, this operator will stop applying Partial aggregation and directly
287/// pass the input rows to the next aggregation phase.
288///
289/// [`Accumulator::state`]: datafusion_expr::Accumulator::state
290///
291/// # Spilling (to disk)
292///
293/// The sizes of group values and accumulators can become large. Before that causes out of memory,
294/// this hash aggregator outputs partial states early for partial aggregation or spills to local
295/// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory
296/// manager checks whether the new input size meets the memory configuration. If not, outputting or
297/// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling,
298/// later stream-merge sort on reading back the spilled data does re-grouping. Note the rows cannot
299/// be grouped once spilled onto disk, the read back data needs to be re-grouped again. In addition,
300/// re-grouping may cause out of memory again. Thus, re-grouping has to be a sort based aggregation.
301/// ```text
302/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
303///
304///  INPUTS        PARTIALLY AGGREGATED (UPDATE BATCH)   OUTPUTS
305/// ┌─────────┐    ┌─────────────────┐                  ┌─────────────────┐
306/// │ a │ b   │    │ a │    AVG(b)   │                  │ a │    AVG(b)   │
307/// │---│-----│    │   │[count]│[sum]│                  │   │[count]│[sum]│
308/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│                  │---│-------│-----│
309/// │ 2 │ 2.0 │    │ 2 │ 1     │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1     │ 2.0 │
310/// └─────────┘    │ 3 │ 2     │ 7.0 │               │  │ 3 │ 2     │ 7.0 │
311/// ┌─────────┐ ─▶ │ 4 │ 1     │ 8.0 │               │  └─────────────────┘
312/// │ 3 │ 4.0 │    └─────────────────┘               └▶ ┌─────────────────┐
313/// │ 4 │ 8.0 │    ┌─────────────────┐                  │ 4 │ 1     │ 8.0 │
314/// └─────────┘    │ a │    AVG(b)   │               ┌▶ │ 1 │ 1     │ 1.0 │
315/// ┌─────────┐    │---│-------│-----│               │  └─────────────────┘
316/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1     │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
317/// │ 3 │ 2.0 │    │ 3 │ 1     │ 2.0 │                  │ 3 │ 1     │ 2.0 │
318/// └─────────┘    └─────────────────┘                  └─────────────────┘
319///
320///
321/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
322///
323/// PARTIALLY INPUTS       FINAL AGGREGATION (MERGE BATCH)       RE-GROUPED (SORTED)
324/// ┌─────────────────┐    [keep using the partial schema]       [Real final aggregation
325/// │ a │    AVG(b)   │    ┌─────────────────┐                    output]
326/// │   │[count]│[sum]│    │ a │    AVG(b)   │                   ┌────────────┐
327/// │---│-------│-----│ ─▶ │   │[count]│[sum]│                   │ a │ AVG(b) │
328/// │ 3 │ 3     │ 3.0 │    │---│-------│-----│ ─▶ spill ─┐       │---│--------│
329/// │ 2 │ 2     │ 1.0 │    │ 2 │ 2     │ 1.0 │           │       │ 1 │    4.0 │
330/// └─────────────────┘    │ 3 │ 4     │ 8.0 │           ▼       │ 2 │    1.0 │
331/// ┌─────────────────┐ ─▶ │ 4 │ 1     │ 7.0 │     Streaming  ─▶ └────────────┘
332/// │ 3 │ 1     │ 5.0 │    └─────────────────┘     merge sort ─▶ ┌────────────┐
333/// │ 4 │ 1     │ 7.0 │    ┌─────────────────┐            ▲      │ a │ AVG(b) │
334/// └─────────────────┘    │ a │    AVG(b)   │            │      │---│--------│
335/// ┌─────────────────┐    │---│-------│-----│ ─▶ memory ─┘      │ 3 │    2.0 │
336/// │ 1 │ 2     │ 8.0 │ ─▶ │ 1 │ 2     │ 8.0 │                   │ 4 │    7.0 │
337/// │ 2 │ 2     │ 3.0 │    │ 2 │ 2     │ 3.0 │                   └────────────┘
338/// └─────────────────┘    └─────────────────┘
339/// ```
pub(crate) struct GroupedHashAggregateStream {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// Schema of the owning `AggregateExec` (cloned from it at construction)
    schema: SchemaRef,
    /// Stream of input batches, obtained by executing the child plan for
    /// this stream's partition
    input: SendableRecordBatchStream,
    /// Aggregation mode of the owning `AggregateExec`
    mode: AggregateMode,

    /// Arguments to pass to each accumulator.
    ///
    /// `accumulators[i]` is passed `aggregate_arguments[i]`
    ///
    /// The argument to each accumulator is itself a `Vec` because
    /// some aggregates such as `CORR` can accept more than one
    /// argument.
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,

    /// Optional filter expression to evaluate, one for each
    /// accumulator. If present, only those rows for which the filter
    /// evaluates to true should be included in the aggregate results.
    ///
    /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`,
    /// the filter expression is `x >= 100`.
    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,

    /// GROUP BY expressions
    group_by: PhysicalGroupBy,

    /// max rows in output RecordBatches
    batch_size: usize,

    /// Optional soft limit on the number of `group_values` in a batch
    /// If the number of `group_values` in a single batch exceeds this value,
    /// the `GroupedHashAggregateStream` operation immediately switches to
    /// output mode and emits all groups.
    group_values_soft_limit: Option<usize>,

    // ========================================================================
    // STATE FLAGS:
    // These fields will be updated during the execution. And control the flow of
    // the execution.
    // ========================================================================
    /// Tracks if this stream is generating input or output
    exec_state: ExecutionState,

    /// Have we seen the end of the input
    input_done: bool,

    // ========================================================================
    // STATE BUFFERS:
    // These fields will accumulate intermediate results during the execution.
    // ========================================================================
    /// An interning store of group keys
    group_values: Box<dyn GroupValues>,

    /// scratch space for the current input [`RecordBatch`] being
    /// processed. Reused across batches here to avoid reallocations
    current_group_indices: Vec<usize>,

    /// Accumulators, one for each `AggregateFunctionExpr` in the query
    ///
    /// For example, if the query has aggregates, `SUM(x)`,
    /// `COUNT(y)`, there will be two accumulators, each one
    /// specialized for that particular aggregate and its input types
    accumulators: Vec<Box<dyn GroupsAccumulator>>,

    // ========================================================================
    // TASK-SPECIFIC STATES:
    // Inner state objects, each grouping together the properties and states
    // for a specific task.
    // ========================================================================
    /// Optional ordering information, that might allow groups to be
    /// emitted from the hash table prior to seeing the end of the
    /// input
    group_ordering: GroupOrdering,

    /// The spill state object
    spill_state: SpillState,

    /// Optional probe for skipping data aggregation, if supported by
    /// current stream.
    skip_aggregation_probe: Option<SkipAggregationProbe>,

    // ========================================================================
    // EXECUTION RESOURCES:
    // Fields related to managing execution resources and monitoring performance.
    // ========================================================================
    /// The memory reservation for this grouping
    reservation: MemoryReservation,

    /// Execution metrics
    baseline_metrics: BaselineMetrics,

    /// Aggregation-specific metrics
    group_by_metrics: GroupByMetrics,

    /// Reduction factor metric, calculated as `output_rows/input_rows` (only for partial aggregation)
    reduction_factor: Option<metrics::RatioMetrics>,
}
440
impl GroupedHashAggregateStream {
    /// Create a new GroupedHashAggregateStream
    ///
    /// Executes the input of `agg` for `partition`, instantiates one
    /// accumulator per aggregate expression, and prepares the spill state,
    /// memory reservation, metrics and (when eligible) the probe used to skip
    /// partial aggregation. The stream starts in
    /// [`ExecutionState::ReadingInput`].
    ///
    /// # Arguments
    ///
    /// * `agg` - the aggregate plan node this stream executes
    /// * `context` - task context supplying the session config, memory pool
    ///   and runtime environment
    /// * `partition` - the input partition to execute; also used to label
    ///   the metrics and the memory consumer
    ///
    /// # Errors
    ///
    /// Propagates any error from executing the input, building the aggregate
    /// arguments, creating the accumulators, deriving the group / partial
    /// aggregation schemas, or creating the group ordering and group values.
    /// Returns an internal error if the spill sort expression is empty.
    pub fn new(
        agg: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
    ) -> Result<Self> {
        debug!("Creating GroupedHashAggregateStream");
        let agg_schema = Arc::clone(&agg.schema);
        let agg_group_by = agg.group_by.clone();
        let agg_filter_expr = agg.filter_expr.clone();

        let batch_size = context.session_config().batch_size();
        let input = agg.input.execute(partition, Arc::clone(&context))?;
        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
        let group_by_metrics = GroupByMetrics::new(&agg.metrics, partition);

        // Setup work from here until `timer.done()` below is accounted to
        // the `elapsed_compute` metric
        let timer = baseline_metrics.elapsed_compute().timer();

        let aggregate_exprs = agg.aggr_expr.clone();

        // arguments for each aggregate, one vec of expressions per
        // aggregate
        let aggregate_arguments = aggregates::aggregate_expressions(
            &agg.aggr_expr,
            &agg.mode,
            agg_group_by.num_group_exprs(),
        )?;
        // arguments for aggregating spilled data is the same as the one for final aggregation
        let merging_aggregate_arguments = aggregates::aggregate_expressions(
            &agg.aggr_expr,
            &AggregateMode::Final,
            agg_group_by.num_group_exprs(),
        )?;

        // Only the modes that consume raw input rows keep the filter
        // expressions; the final modes get `None` for every aggregate.
        // NOTE(review): presumably because the filters were already applied
        // when the partial state was produced -- confirm.
        let filter_expressions = match agg.mode {
            AggregateMode::Partial
            | AggregateMode::Single
            | AggregateMode::SinglePartitioned => agg_filter_expr,
            AggregateMode::Final | AggregateMode::FinalPartitioned => {
                vec![None; agg.aggr_expr.len()]
            }
        };

        // Instantiate the accumulators
        let accumulators: Vec<_> = aggregate_exprs
            .iter()
            .map(create_group_accumulator)
            .collect::<Result<_>>()?;

        let group_schema = agg_group_by.group_schema(&agg.input().schema())?;

        // fix https://github.com/apache/datafusion/issues/13949
        // Builds a **partial aggregation** schema by combining the group columns and
        // the accumulator state columns produced by each aggregate expression.
        //
        // # Why Partial Aggregation Schema Is Needed
        //
        // In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
        // operator produces *intermediate* states (e.g., partial sums, counts) rather
        // than final scalar values. These extra columns do **not** exist in the original
        // input schema (which may be something like `[colA, colB, ...]`). Instead,
        // each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
        //
        // Therefore, when we spill these intermediate states or pass them to another
        // aggregation operator, we must use a schema that includes both the group
        // columns **and** the partial-state columns.
        let partial_agg_schema = create_schema(
            &agg.input().schema(),
            &agg_group_by,
            &aggregate_exprs,
            AggregateMode::Partial,
        )?;

        // Need to update the GROUP BY expressions to point to the correct column after schema change
        // (each group expression becomes a `Column` reference at its own
        // position, keeping its original name)
        let merging_group_by_expr = agg_group_by
            .expr
            .iter()
            .enumerate()
            .map(|(idx, (_, name))| {
                (Arc::new(Column::new(name.as_str(), idx)) as _, name.clone())
            })
            .collect();

        let partial_agg_schema = Arc::new(partial_agg_schema);

        // Spilled batches are sorted on every group column, in group schema
        // order, using the default sort options
        let spill_expr =
            group_schema
                .fields
                .into_iter()
                .enumerate()
                .map(|(idx, field)| {
                    PhysicalSortExpr::new_default(Arc::new(Column::new(
                        field.name().as_str(),
                        idx,
                    )) as _)
                });
        let Some(spill_expr) = LexOrdering::new(spill_expr) else {
            return internal_err!("Spill expression is empty");
        };

        // Name the memory consumer after the partition and the aggregate
        // functions so it can be identified in memory pool reports
        let agg_fn_names = aggregate_exprs
            .iter()
            .map(|expr| expr.human_display())
            .collect::<Vec<_>>()
            .join(", ");
        let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})");
        let reservation = MemoryConsumer::new(name)
            .with_can_spill(true)
            .register(context.memory_pool());
        let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?;
        let group_values = new_group_values(group_schema, &group_ordering)?;
        timer.done();

        let exec_state = ExecutionState::ReadingInput;

        // Spill files are written with the partial aggregation schema and the
        // spill compression configured for the session
        let spill_manager = SpillManager::new(
            context.runtime_env(),
            metrics::SpillMetrics::new(&agg.metrics, partition),
            Arc::clone(&partial_agg_schema),
        )
        .with_compression_type(context.session_config().spill_compression());

        let spill_state = SpillState {
            spills: vec![],
            spill_expr,
            spill_schema: partial_agg_schema,
            is_stream_merging: false,
            merging_aggregate_arguments,
            merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr),
            peak_mem_used: MetricBuilder::new(&agg.metrics)
                .gauge("peak_mem_used", partition),
            spill_manager,
        };

        // Skip aggregation is supported if:
        // - aggregation mode is Partial
        // - input is not ordered by GROUP BY expressions,
        //   since Final mode expects unique group values as its input
        // - all accumulators support input batch to intermediate
        //   aggregate state conversion
        // - there is only one GROUP BY expressions set
        let skip_aggregation_probe = if agg.mode == AggregateMode::Partial
            && matches!(group_ordering, GroupOrdering::None)
            && accumulators
                .iter()
                .all(|acc| acc.supports_convert_to_state())
            && agg_group_by.is_single()
        {
            let options = &context.session_config().options().execution;
            let probe_rows_threshold =
                options.skip_partial_aggregation_probe_rows_threshold;
            let probe_ratio_threshold =
                options.skip_partial_aggregation_probe_ratio_threshold;
            let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics)
                .counter("skipped_aggregation_rows", partition);
            Some(SkipAggregationProbe::new(
                probe_rows_threshold,
                probe_ratio_threshold,
                skipped_aggregation_rows,
            ))
        } else {
            None
        };

        // The reduction factor metric is only registered for partial
        // aggregation
        let reduction_factor = if agg.mode == AggregateMode::Partial {
            Some(
                MetricBuilder::new(&agg.metrics)
                    .with_type(metrics::MetricType::SUMMARY)
                    .ratio_metrics("reduction_factor", partition),
            )
        } else {
            None
        };

        Ok(GroupedHashAggregateStream {
            schema: agg_schema,
            input,
            mode: agg.mode,
            accumulators,
            aggregate_arguments,
            filter_expressions,
            group_by: agg_group_by,
            reservation,
            group_values,
            current_group_indices: Default::default(),
            exec_state,
            baseline_metrics,
            group_by_metrics,
            batch_size,
            group_ordering,
            input_done: false,
            spill_state,
            group_values_soft_limit: agg.limit,
            skip_aggregation_probe,
            reduction_factor,
        })
    }
}
640
641/// Create an accumulator for `agg_expr` -- a [`GroupsAccumulator`] if
642/// that is supported by the aggregate, or a
643/// [`GroupsAccumulatorAdapter`] if not.
644pub(crate) fn create_group_accumulator(
645    agg_expr: &Arc<AggregateFunctionExpr>,
646) -> Result<Box<dyn GroupsAccumulator>> {
647    if agg_expr.groups_accumulator_supported() {
648        agg_expr.create_groups_accumulator()
649    } else {
650        // Note in the log when the slow path is used
651        debug!(
652            "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}",
653            agg_expr.name()
654        );
655        let agg_expr_captured = Arc::clone(agg_expr);
656        let factory = move || agg_expr_captured.create_accumulator();
657        Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
658    }
659}
660
661impl Stream for GroupedHashAggregateStream {
662    type Item = Result<RecordBatch>;
663
664    fn poll_next(
665        mut self: std::pin::Pin<&mut Self>,
666        cx: &mut Context<'_>,
667    ) -> Poll<Option<Self::Item>> {
668        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
669
670        loop {
671            match &self.exec_state {
672                ExecutionState::ReadingInput => 'reading_input: {
673                    match ready!(self.input.poll_next_unpin(cx)) {
674                        // New batch to aggregate in partial aggregation operator
675                        Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
676                            let timer = elapsed_compute.timer();
677                            let input_rows = batch.num_rows();
678
679                            if let Some(reduction_factor) = self.reduction_factor.as_ref()
680                            {
681                                reduction_factor.add_total(input_rows);
682                            }
683
684                            // Do the grouping
685                            self.group_aggregate_batch(batch)?;
686
687                            self.update_skip_aggregation_probe(input_rows);
688
689                            // If we can begin emitting rows, do so,
690                            // otherwise keep consuming input
691                            assert!(!self.input_done);
692
693                            // If the number of group values equals or exceeds the soft limit,
694                            // emit all groups and switch to producing output
695                            if self.hit_soft_group_limit() {
696                                timer.done();
697                                self.set_input_done_and_produce_output()?;
698                                // make sure the exec_state just set is not overwritten below
699                                break 'reading_input;
700                            }
701
702                            if let Some(to_emit) = self.group_ordering.emit_to() {
703                                timer.done();
704                                if let Some(batch) = self.emit(to_emit, false)? {
705                                    self.exec_state =
706                                        ExecutionState::ProducingOutput(batch);
707                                };
708                                // make sure the exec_state just set is not overwritten below
709                                break 'reading_input;
710                            }
711
712                            self.emit_early_if_necessary()?;
713
714                            self.switch_to_skip_aggregation()?;
715
716                            timer.done();
717                        }
718
719                        // New batch to aggregate in terminal aggregation operator
720                        // (Final/FinalPartitioned/Single/SinglePartitioned)
721                        Some(Ok(batch)) => {
722                            let timer = elapsed_compute.timer();
723
724                            // Make sure we have enough capacity for `batch`, otherwise spill
725                            self.spill_previous_if_necessary(&batch)?;
726
727                            // Do the grouping
728                            self.group_aggregate_batch(batch)?;
729
730                            // If we can begin emitting rows, do so,
731                            // otherwise keep consuming input
732                            assert!(!self.input_done);
733
734                            // If the number of group values equals or exceeds the soft limit,
735                            // emit all groups and switch to producing output
736                            if self.hit_soft_group_limit() {
737                                timer.done();
738                                self.set_input_done_and_produce_output()?;
739                                // make sure the exec_state just set is not overwritten below
740                                break 'reading_input;
741                            }
742
743                            if let Some(to_emit) = self.group_ordering.emit_to() {
744                                timer.done();
745                                if let Some(batch) = self.emit(to_emit, false)? {
746                                    self.exec_state =
747                                        ExecutionState::ProducingOutput(batch);
748                                };
749                                // make sure the exec_state just set is not overwritten below
750                                break 'reading_input;
751                            }
752
753                            timer.done();
754                        }
755
756                        // Found error from input stream
757                        Some(Err(e)) => {
758                            // inner had error, return to caller
759                            return Poll::Ready(Some(Err(e)));
760                        }
761
762                        // Found end from input stream
763                        None => {
764                            // inner is done, emit all rows and switch to producing output
765                            self.set_input_done_and_produce_output()?;
766                        }
767                    }
768                }
769
770                ExecutionState::SkippingAggregation => {
771                    match ready!(self.input.poll_next_unpin(cx)) {
772                        Some(Ok(batch)) => {
773                            let _timer = elapsed_compute.timer();
774                            if let Some(probe) = self.skip_aggregation_probe.as_mut() {
775                                probe.record_skipped(&batch);
776                            }
777                            let states = self.transform_to_states(batch)?;
778                            return Poll::Ready(Some(Ok(
779                                states.record_output(&self.baseline_metrics)
780                            )));
781                        }
782                        Some(Err(e)) => {
783                            // inner had error, return to caller
784                            return Poll::Ready(Some(Err(e)));
785                        }
786                        None => {
787                            // inner is done, switching to `Done` state
788                            self.exec_state = ExecutionState::Done;
789                        }
790                    }
791                }
792
793                ExecutionState::ProducingOutput(batch) => {
794                    // slice off a part of the batch, if needed
795                    let output_batch;
796                    let size = self.batch_size;
797                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
798                        (
799                            if self.input_done {
800                                ExecutionState::Done
801                            }
802                            // In Partial aggregation, we also need to check
803                            // if we should trigger partial skipping
804                            else if self.mode == AggregateMode::Partial
805                                && self.should_skip_aggregation()
806                            {
807                                ExecutionState::SkippingAggregation
808                            } else {
809                                ExecutionState::ReadingInput
810                            },
811                            batch.clone(),
812                        )
813                    } else {
814                        // output first batch_size rows
815                        let size = self.batch_size;
816                        let num_remaining = batch.num_rows() - size;
817                        let remaining = batch.slice(size, num_remaining);
818                        let output = batch.slice(0, size);
819                        (ExecutionState::ProducingOutput(remaining), output)
820                    };
821
822                    if let Some(reduction_factor) = self.reduction_factor.as_ref() {
823                        reduction_factor.add_part(output_batch.num_rows());
824                    }
825
826                    // Empty record batches should not be emitted.
827                    // They need to be treated as  [`Option<RecordBatch>`]es and handled separately
828                    debug_assert!(output_batch.num_rows() > 0);
829                    return Poll::Ready(Some(Ok(
830                        output_batch.record_output(&self.baseline_metrics)
831                    )));
832                }
833
834                ExecutionState::Done => {
835                    // release the memory reservation since sending back output batch itself needs
836                    // some memory reservation, so make some room for it.
837                    self.clear_all();
838                    let _ = self.update_memory_reservation();
839                    return Poll::Ready(None);
840                }
841            }
842        }
843    }
844}
845
846impl RecordBatchStream for GroupedHashAggregateStream {
847    fn schema(&self) -> SchemaRef {
848        Arc::clone(&self.schema)
849    }
850}
851
852impl GroupedHashAggregateStream {
853    /// Perform group-by aggregation for the given [`RecordBatch`].
854    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
855        // Evaluate the grouping expressions
856        let group_by_values = if self.spill_state.is_stream_merging {
857            evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
858        } else {
859            evaluate_group_by(&self.group_by, &batch)?
860        };
861
862        // Only create the timer if there are actual aggregate arguments to evaluate
863        let timer = match (
864            self.spill_state.is_stream_merging,
865            self.spill_state.merging_aggregate_arguments.is_empty(),
866            self.aggregate_arguments.is_empty(),
867        ) {
868            (true, false, _) | (false, _, false) => {
869                Some(self.group_by_metrics.aggregate_arguments_time.timer())
870            }
871            _ => None,
872        };
873
874        // Evaluate the aggregation expressions.
875        let input_values = if self.spill_state.is_stream_merging {
876            evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
877        } else {
878            evaluate_many(&self.aggregate_arguments, &batch)?
879        };
880        drop(timer);
881
882        // Evaluate the filter expressions, if any, against the inputs
883        let filter_values = if self.spill_state.is_stream_merging {
884            let filter_expressions = vec![None; self.accumulators.len()];
885            evaluate_optional(&filter_expressions, &batch)?
886        } else {
887            evaluate_optional(&self.filter_expressions, &batch)?
888        };
889
890        for group_values in &group_by_values {
891            let groups_start_time = Instant::now();
892
893            // calculate the group indices for each input row
894            let starting_num_groups = self.group_values.len();
895            self.group_values
896                .intern(group_values, &mut self.current_group_indices)?;
897            let group_indices = &self.current_group_indices;
898
899            // Update ordering information if necessary
900            let total_num_groups = self.group_values.len();
901            if total_num_groups > starting_num_groups {
902                self.group_ordering.new_groups(
903                    group_values,
904                    group_indices,
905                    total_num_groups,
906                )?;
907            }
908
909            // Use this instant for both measurements to save a syscall
910            let agg_start_time = Instant::now();
911            self.group_by_metrics
912                .time_calculating_group_ids
913                .add_duration(agg_start_time - groups_start_time);
914
915            // Gather the inputs to call the actual accumulator
916            let t = self
917                .accumulators
918                .iter_mut()
919                .zip(input_values.iter())
920                .zip(filter_values.iter());
921
922            for ((acc, values), opt_filter) in t {
923                let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
924
925                // Call the appropriate method on each aggregator with
926                // the entire input row and the relevant group indexes
927                match self.mode {
928                    AggregateMode::Partial
929                    | AggregateMode::Single
930                    | AggregateMode::SinglePartitioned
931                        if !self.spill_state.is_stream_merging =>
932                    {
933                        acc.update_batch(
934                            values,
935                            group_indices,
936                            opt_filter,
937                            total_num_groups,
938                        )?;
939                    }
940                    _ => {
941                        if opt_filter.is_some() {
942                            return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
943                        }
944
945                        // if aggregation is over intermediate states,
946                        // use merge
947                        acc.merge_batch(values, group_indices, None, total_num_groups)?;
948                    }
949                }
950                self.group_by_metrics
951                    .aggregation_time
952                    .add_elapsed(agg_start_time);
953            }
954        }
955
956        match self.update_memory_reservation() {
957            // Here we can ignore `insufficient_capacity_err` because we will spill later,
958            // but at least one batch should fit in the memory
959            Err(DataFusionError::ResourcesExhausted(_))
960                if self.group_values.len() >= self.batch_size =>
961            {
962                Ok(())
963            }
964            other => other,
965        }
966    }
967
968    fn update_memory_reservation(&mut self) -> Result<()> {
969        let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
970        let reservation_result = self.reservation.try_resize(
971            acc + self.group_values.size()
972                + self.group_ordering.size()
973                + self.current_group_indices.allocated_size(),
974        );
975
976        if reservation_result.is_ok() {
977            self.spill_state
978                .peak_mem_used
979                .set_max(self.reservation.size());
980        }
981
982        reservation_result
983    }
984
985    /// Create an output RecordBatch with the group keys and
986    /// accumulator states/values specified in emit_to
987    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
988        let schema = if spilling {
989            Arc::clone(&self.spill_state.spill_schema)
990        } else {
991            self.schema()
992        };
993        if self.group_values.is_empty() {
994            return Ok(None);
995        }
996
997        let timer = self.group_by_metrics.emitting_time.timer();
998        let mut output = self.group_values.emit(emit_to)?;
999        if let EmitTo::First(n) = emit_to {
1000            self.group_ordering.remove_groups(n);
1001        }
1002
1003        // Next output each aggregate value
1004        for acc in self.accumulators.iter_mut() {
1005            match self.mode {
1006                AggregateMode::Partial => output.extend(acc.state(emit_to)?),
1007                _ if spilling => {
1008                    // If spilling, output partial state because the spilled data will be
1009                    // merged and re-evaluated later.
1010                    output.extend(acc.state(emit_to)?)
1011                }
1012                AggregateMode::Final
1013                | AggregateMode::FinalPartitioned
1014                | AggregateMode::Single
1015                | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?),
1016            }
1017        }
1018        drop(timer);
1019
1020        // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
1021        // over the target memory size after emission, we can emit again rather than returning Err.
1022        let _ = self.update_memory_reservation();
1023        let batch = RecordBatch::try_new(schema, output)?;
1024        debug_assert!(batch.num_rows() > 0);
1025
1026        Ok(Some(batch))
1027    }
1028
1029    /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
1030    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
1031    /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
1032    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
1033        // TODO: support group_ordering for spilling
1034        if !self.group_values.is_empty()
1035            && batch.num_rows() > 0
1036            && matches!(self.group_ordering, GroupOrdering::None)
1037            && !self.spill_state.is_stream_merging
1038            && self.update_memory_reservation().is_err()
1039        {
1040            assert_ne!(self.mode, AggregateMode::Partial);
1041            self.spill()?;
1042            self.clear_shrink(batch);
1043        }
1044        Ok(())
1045    }
1046
1047    /// Emit all intermediate aggregation states, sort them, and store them on disk.
1048    /// This process helps in reducing memory pressure by allowing the data to be
1049    /// read back with streaming merge.
1050    fn spill(&mut self) -> Result<()> {
1051        // Emit and sort intermediate aggregation state
1052        let Some(emit) = self.emit(EmitTo::All, true)? else {
1053            return Ok(());
1054        };
1055        let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
1056
1057        // Spill sorted state to disk
1058        let spillfile = self
1059            .spill_state
1060            .spill_manager
1061            .spill_record_batch_by_size_and_return_max_batch_memory(
1062                &sorted,
1063                "HashAggSpill",
1064                self.batch_size,
1065            )?;
1066        match spillfile {
1067            Some((spillfile, max_record_batch_memory)) => {
1068                self.spill_state.spills.push(SortedSpillFile {
1069                    file: spillfile,
1070                    max_record_batch_memory,
1071                })
1072            }
1073            None => {
1074                return internal_err!(
1075                    "Calling spill with no intermediate batch to spill"
1076                );
1077            }
1078        }
1079
1080        Ok(())
1081    }
1082
1083    /// Clear memory and shirk capacities to the size of the batch.
1084    fn clear_shrink(&mut self, batch: &RecordBatch) {
1085        self.group_values.clear_shrink(batch);
1086        self.current_group_indices.clear();
1087        self.current_group_indices.shrink_to(batch.num_rows());
1088    }
1089
1090    /// Clear memory and shirk capacities to zero.
1091    fn clear_all(&mut self) {
1092        let s = self.schema();
1093        self.clear_shrink(&RecordBatch::new_empty(s));
1094    }
1095
1096    /// Emit if the used memory exceeds the target for partial aggregation.
1097    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
1098    /// TODO: support group_ordering for early emitting
1099    fn emit_early_if_necessary(&mut self) -> Result<()> {
1100        if self.group_values.len() >= self.batch_size
1101            && matches!(self.group_ordering, GroupOrdering::None)
1102            && self.update_memory_reservation().is_err()
1103        {
1104            assert_eq!(self.mode, AggregateMode::Partial);
1105            let n = self.group_values.len() / self.batch_size * self.batch_size;
1106            if let Some(batch) = self.emit(EmitTo::First(n), false)? {
1107                self.exec_state = ExecutionState::ProducingOutput(batch);
1108            };
1109        }
1110        Ok(())
1111    }
1112
1113    /// At this point, all the inputs are read and there are some spills.
1114    /// Emit the remaining rows and create a batch.
1115    /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
1116    /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
1117    fn update_merged_stream(&mut self) -> Result<()> {
1118        let Some(batch) = self.emit(EmitTo::All, true)? else {
1119            return Ok(());
1120        };
1121        // clear up memory for streaming_merge
1122        self.clear_all();
1123        self.update_memory_reservation()?;
1124        let mut streams: Vec<SendableRecordBatchStream> = vec![];
1125        let expr = self.spill_state.spill_expr.clone();
1126        let schema = batch.schema();
1127        streams.push(Box::pin(RecordBatchStreamAdapter::new(
1128            Arc::clone(&schema),
1129            futures::stream::once(futures::future::lazy(move |_| {
1130                sort_batch(&batch, &expr, None)
1131            })),
1132        )));
1133
1134        self.spill_state.is_stream_merging = true;
1135        self.input = StreamingMergeBuilder::new()
1136            .with_streams(streams)
1137            .with_schema(schema)
1138            .with_spill_manager(self.spill_state.spill_manager.clone())
1139            .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills))
1140            .with_expressions(&self.spill_state.spill_expr)
1141            .with_metrics(self.baseline_metrics.clone())
1142            .with_batch_size(self.batch_size)
1143            .with_reservation(self.reservation.new_empty())
1144            .build()?;
1145        self.input_done = false;
1146        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
1147        Ok(())
1148    }
1149
1150    /// returns true if there is a soft groups limit and the number of distinct
1151    /// groups we have seen is over that limit
1152    fn hit_soft_group_limit(&self) -> bool {
1153        let Some(group_values_soft_limit) = self.group_values_soft_limit else {
1154            return false;
1155        };
1156        group_values_soft_limit <= self.group_values.len()
1157    }
1158
1159    /// common function for signalling end of processing of the input stream
1160    fn set_input_done_and_produce_output(&mut self) -> Result<()> {
1161        self.input_done = true;
1162        self.group_ordering.input_done();
1163        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
1164        let timer = elapsed_compute.timer();
1165        self.exec_state = if self.spill_state.spills.is_empty() {
1166            let batch = self.emit(EmitTo::All, false)?;
1167            batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput)
1168        } else {
1169            // If spill files exist, stream-merge them.
1170            self.update_merged_stream()?;
1171            ExecutionState::ReadingInput
1172        };
1173        timer.done();
1174        Ok(())
1175    }
1176
1177    /// Updates skip aggregation probe state.
1178    ///
1179    /// Notice: It should only be called in Partial aggregation
1180    fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
1181        if let Some(probe) = self.skip_aggregation_probe.as_mut() {
1182            // Skip aggregation probe is not supported if stream has any spills,
1183            // currently spilling is not supported for Partial aggregation
1184            assert!(self.spill_state.spills.is_empty());
1185            probe.update_state(input_rows, self.group_values.len());
1186        };
1187    }
1188
1189    /// In case the probe indicates that aggregation may be
1190    /// skipped, forces stream to produce currently accumulated output.
1191    ///
1192    /// Notice: It should only be called in Partial aggregation
1193    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
1194        if let Some(probe) = self.skip_aggregation_probe.as_mut() {
1195            if probe.should_skip() {
1196                if let Some(batch) = self.emit(EmitTo::All, false)? {
1197                    self.exec_state = ExecutionState::ProducingOutput(batch);
1198                };
1199            }
1200        }
1201
1202        Ok(())
1203    }
1204
1205    /// Returns true if the aggregation probe indicates that aggregation
1206    /// should be skipped.
1207    ///
1208    /// Notice: It should only be called in Partial aggregation
1209    fn should_skip_aggregation(&self) -> bool {
1210        self.skip_aggregation_probe
1211            .as_ref()
1212            .is_some_and(|probe| probe.should_skip())
1213    }
1214
1215    /// Transforms input batch to intermediate aggregate state, without grouping it
1216    fn transform_to_states(&self, batch: RecordBatch) -> Result<RecordBatch> {
1217        let mut group_values = evaluate_group_by(&self.group_by, &batch)?;
1218        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
1219        let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
1220
1221        if group_values.len() != 1 {
1222            return internal_err!("group_values expected to have single element");
1223        }
1224        let mut output = group_values.swap_remove(0);
1225
1226        let iter = self
1227            .accumulators
1228            .iter()
1229            .zip(input_values.iter())
1230            .zip(filter_values.iter());
1231
1232        for ((acc, values), opt_filter) in iter {
1233            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
1234            output.extend(acc.convert_to_state(values, opt_filter)?);
1235        }
1236
1237        let states_batch = RecordBatch::try_new(self.schema(), output)?;
1238
1239        Ok(states_batch)
1240    }
1241}