datafusion_physical_plan/aggregates/no_grouping.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Aggregate without grouping columns

use crate::aggregates::{
    aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
    AggregateMode,
};
use crate::metrics::{BaselineMetrics, RecordOutput};
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::BoxStream;
use std::borrow::Cow;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::filter::batch_filter;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use futures::stream::{Stream, StreamExt};

use super::AggregateExec;

/// stream struct for aggregation without grouping columns
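///
/// The stream yields a single output batch containing the final aggregate
/// values (or an error) and then terminates.
///
/// A hedged usage sketch (the `agg`, `task_ctx`, and partition values are
/// assumed to exist elsewhere and are not defined in this module):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut stream = AggregateStream::new(&agg, task_ctx, /* partition */ 0)?;
/// while let Some(batch) = stream.next().await {
///     let batch = batch?;
///     // a single batch with the final aggregate values is expected
/// }
/// ```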
pub(crate) struct AggregateStream {
    stream: BoxStream<'static, Result<RecordBatch>>,
    schema: SchemaRef,
}

/// Actual implementation of [`AggregateStream`].
///
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
/// [`futures::stream::unfold`].
///
/// The latter requires a state object, which is [`AggregateStreamInner`].
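///
/// A minimal sketch of the `unfold` pattern used here (the integer counter is
/// purely illustrative and not part of this module): the closure receives the
/// state, may `.await`, and returns `Some((item, next_state))` to emit an item
/// or `None` to end the stream.
///
/// ```ignore
/// let stream = futures::stream::unfold(0u32, |count| async move {
///     if count < 3 {
///         Some((count, count + 1))
///     } else {
///         None
///     }
/// });
/// // `stream` yields 0, 1, 2 and then ends.
/// ```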
struct AggregateStreamInner {
    schema: SchemaRef,
    mode: AggregateMode,
    input: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
    accumulators: Vec<AccumulatorItem>,
    reservation: MemoryReservation,
    finished: bool,
}

impl AggregateStream {
    /// Create a new AggregateStream
    pub fn new(
        agg: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&agg.schema);
        let agg_filter_expr = agg.filter_expr.clone();

        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
        let input = agg.input.execute(partition, Arc::clone(&context))?;

        let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?;
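        // Filter expressions are only applied while aggregating raw input rows
        // (`Partial`, `Single`, `SinglePartitioned`). In the final modes the input
        // already carries partially aggregated state, so no filters are applied here.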
        let filter_expressions = match agg.mode {
            AggregateMode::Partial
            | AggregateMode::Single
            | AggregateMode::SinglePartitioned => agg_filter_expr,
            AggregateMode::Final | AggregateMode::FinalPartitioned => {
                vec![None; agg.aggr_expr.len()]
            }
        };
        let accumulators = create_accumulators(&agg.aggr_expr)?;

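        // Register this stream's memory consumption with the task's memory pool;
        // accumulator growth is later tracked (and may fail) via `reservation.try_grow`.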
        let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
            .register(context.memory_pool());

        let inner = AggregateStreamInner {
            schema: Arc::clone(&agg.schema),
            mode: agg.mode,
            input,
            baseline_metrics,
            aggregate_expressions,
            filter_expressions,
            accumulators,
            reservation,
            finished: false,
        };
        let stream = futures::stream::unfold(inner, |mut this| async move {
            if this.finished {
                return None;
            }

            let elapsed_compute = this.baseline_metrics.elapsed_compute();

            loop {
                let result = match this.input.next().await {
                    Some(Ok(batch)) => {
                        let timer = elapsed_compute.timer();
                        let result = aggregate_batch(
                            &this.mode,
                            batch,
                            &mut this.accumulators,
                            &this.aggregate_expressions,
                            &this.filter_expressions,
                        );

                        timer.done();

                        // Account for the memory the accumulators now hold.
                        // This happens AFTER the memory was actually used, which keeps the
                        // accounting simple even though it briefly overshoots the reservation.
                        // It also means each batch's allocation is either accounted for in
                        // full or not at all.
                        match result
                            .and_then(|allocated| this.reservation.try_grow(allocated))
                        {
                            Ok(_) => continue,
                            Err(e) => Err(e),
                        }
                    }
                    Some(Err(e)) => Err(e),
                    None => {
                        this.finished = true;
                        let timer = this.baseline_metrics.elapsed_compute().timer();
                        let result =
                            finalize_aggregation(&mut this.accumulators, &this.mode)
                                .and_then(|columns| {
                                    RecordBatch::try_new(
                                        Arc::clone(&this.schema),
                                        columns,
                                    )
                                    .map_err(Into::into)
                                })
                                .record_output(&this.baseline_metrics);

                        timer.done();

                        result
                    }
                };

                this.finished = true;
                return Some((result, this));
            }
        });

        // Some consumers may poll this stream even after it has returned `None`, so fuse it to make that safe.
        let stream = stream.fuse();
        let stream = Box::pin(stream);

        Ok(Self {
            schema: agg_schema,
            stream,
        })
    }
}

impl Stream for AggregateStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        this.stream.poll_next_unpin(cx)
    }
}

impl RecordBatchStream for AggregateStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// Perform aggregation (without grouping columns) for the given [`RecordBatch`].
///
/// If successful, this returns the additional number of bytes that were allocated during this process.
///
/// TODO: Make this a member function
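///
/// A hedged sketch of how this function is driven (it mirrors the call site in
/// this file; the bindings used below are assumed to be in scope):
///
/// ```ignore
/// let allocated = aggregate_batch(
///     &mode,
///     batch,
///     &mut accumulators,
///     &aggregate_expressions,
///     &filter_expressions,
/// )?;
/// // Grow the memory reservation by the bytes the accumulators now retain.
/// reservation.try_grow(allocated)?;
/// ```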
fn aggregate_batch(
    mode: &AggregateMode,
    batch: RecordBatch,
    accumulators: &mut [AccumulatorItem],
    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
    filters: &[Option<Arc<dyn PhysicalExpr>>],
) -> Result<usize> {
    let mut allocated = 0usize;

    // 1.1 iterate accumulators and respective expressions together
    // 1.2 filter the batch if necessary
    // 1.3 evaluate expressions
    // 1.4 update / merge accumulators with the expressions' values

    // 1.1
    accumulators
        .iter_mut()
        .zip(expressions)
        .zip(filters)
        .try_for_each(|((accum, expr), filter)| {
            // 1.2
            let batch = match filter {
                Some(filter) => Cow::Owned(batch_filter(&batch, filter)?),
                None => Cow::Borrowed(&batch),
            };

            let n_rows = batch.num_rows();

            // 1.3
            let values = expr
                .iter()
                .map(|e| e.evaluate(&batch).and_then(|v| v.into_array(n_rows)))
                .collect::<Result<Vec<_>>>()?;

            // 1.4
            let size_pre = accum.size();
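            // In `Partial`, `Single`, and `SinglePartitioned` modes the accumulator
            // consumes raw input values; in the final modes it merges intermediate
            // states produced by upstream partial aggregation.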
            let res = match mode {
                AggregateMode::Partial
                | AggregateMode::Single
                | AggregateMode::SinglePartitioned => accum.update_batch(&values),
                AggregateMode::Final | AggregateMode::FinalPartitioned => {
                    accum.merge_batch(&values)
                }
            };
            let size_post = accum.size();
            allocated += size_post.saturating_sub(size_pre);
            res
        })?;

    Ok(allocated)
}