datafusion_physical_plan/aggregates/topk_stream.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! A memory-conscious aggregation implementation that limits group buckets to a fixed number
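//!
//! This stream handles the special case of a single group expression paired
//! with a single `min`/`max` aggregate and a row limit, e.g. a query shaped
//! like the following (table and column names are illustrative):
//!
//! ```sql
//! SELECT trace_id, MAX(timestamp)
//! FROM traces
//! GROUP BY trace_id
//! ORDER BY MAX(timestamp) DESC
//! LIMIT 10;
//! ```
//!
//! Only the current top `limit` groups are retained, so memory use stays
//! bounded regardless of input cardinality.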

use crate::aggregates::group_values::GroupByMetrics;
use crate::aggregates::topk::priority_map::PriorityMap;
use crate::aggregates::{
    aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
    PhysicalGroupBy,
};
use crate::metrics::BaselineMetrics;
use crate::{RecordBatchStream, SendableRecordBatchStream};
use arrow::array::{Array, ArrayRef, RecordBatch};
use arrow::datatypes::SchemaRef;
use arrow::util::pretty::print_batches;
use datafusion_common::internal_datafusion_err;
use datafusion_common::Result;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
use futures::stream::{Stream, StreamExt};
use log::{trace, Level};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

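/// An aggregation stream that evaluates one group column and one `min`/`max`
/// input column per batch, keeping at most `limit` groups in a [`PriorityMap`].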
pub struct GroupedTopKAggregateStream {
    partition: usize,
    row_count: usize,
    started: bool,
    schema: SchemaRef,
    input: SendableRecordBatchStream,
    baseline_metrics: BaselineMetrics,
    group_by_metrics: GroupByMetrics,
    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
    group_by: PhysicalGroupBy,
    priority_map: PriorityMap,
}

impl GroupedTopKAggregateStream {
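    /// Creates the stream, failing if the aggregation is not a single
    /// `min`/`max` expression, and sizes the [`PriorityMap`] to `limit`.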
    pub fn new(
        aggr: &AggregateExec,
        context: Arc<TaskContext>,
        partition: usize,
        limit: usize,
    ) -> Result<Self> {
        let agg_schema = Arc::clone(&aggr.schema);
        let group_by = aggr.group_by.clone();
        let input = aggr.input.execute(partition, Arc::clone(&context))?;
        let baseline_metrics = BaselineMetrics::new(&aggr.metrics, partition);
        let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition);
        let aggregate_arguments =
            aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
        // top-k aggregation requires a single min/max expression;
        // `desc` selects the sort direction of the priority map
        let (val_field, desc) = aggr
            .get_minmax_desc()
            .ok_or_else(|| internal_datafusion_err!("Min/max required"))?;

        // the key type comes from the (single) group expression, the value
        // type from the min/max output field
        let (expr, _) = &aggr.group_expr().expr()[0];
        let kt = expr.data_type(&aggr.input().schema())?;
        let vt = val_field.data_type().clone();

        let priority_map = PriorityMap::new(kt, vt, limit, desc)?;

        Ok(GroupedTopKAggregateStream {
            partition,
            started: false,
            row_count: 0,
            schema: agg_schema,
            input,
            baseline_metrics,
            group_by_metrics,
            aggregate_arguments,
            group_by,
            priority_map,
        })
    }
}

impl RecordBatchStream for GroupedTopKAggregateStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

impl GroupedTopKAggregateStream {
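    /// Inserts one batch of group keys (`ids`) and aggregate values (`vals`)
    /// into the priority map, skipping rows whose value is null.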
    fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
        let _timer = self.group_by_metrics.time_calculating_group_ids.timer();

        let len = ids.len();
        // make the batch's arrays available to the map for row-wise inserts
        self.priority_map.set_batch(ids, Arc::clone(&vals));

        let has_nulls = vals.null_count() > 0;
        for row_idx in 0..len {
            if has_nulls && vals.is_null(row_idx) {
                continue;
            }
            self.priority_map.insert(row_idx)?;
        }
        Ok(())
    }
}

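/// Consumes the entire input before emitting: each ready batch is folded into
/// the priority map, and once the input is exhausted a single batch holding
/// the surviving top `limit` groups is emitted.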
impl Stream for GroupedTopKAggregateStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
        let emitting_time = self.group_by_metrics.emitting_time.clone();
        while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
            let _timer = elapsed_compute.timer();
            match res {
                // got a batch, fold it into the priority map
                Some(Ok(batch)) => {
                    self.started = true;
                    trace!(
                        "partition {} has {} rows and got batch with {} rows",
                        self.partition,
                        self.row_count,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
                        print_batches(std::slice::from_ref(&batch))?;
                    }
                    self.row_count += batch.num_rows();
                    let batches = &[batch];
                    let group_by_values =
                        evaluate_group_by(&self.group_by, batches.first().unwrap())?;
                    assert_eq!(
                        group_by_values.len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    assert_eq!(
                        group_by_values[0].len(),
                        1,
                        "Exactly 1 group value required"
                    );
                    let group_by_values = Arc::clone(&group_by_values[0][0]);
                    let input_values = {
                        let _timer = (!self.aggregate_arguments.is_empty()).then(|| {
                            self.group_by_metrics.aggregate_arguments_time.timer()
                        });
                        evaluate_many(
                            &self.aggregate_arguments,
                            batches.first().unwrap(),
                        )?
                    };
                    assert_eq!(input_values.len(), 1, "Exactly 1 input required");
                    assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
                    let input_values = Arc::clone(&input_values[0][0]);

                    // insert the batch's single key column and value column
                    (*self).intern(group_by_values, input_values)?;
                }
                // inner is done, emit all rows and switch to producing output
                None => {
                    if self.priority_map.is_empty() {
                        trace!("partition {} emit None", self.partition);
                        return Poll::Ready(None);
                    }
                    let batch = {
                        let _timer = emitting_time.timer();
                        let cols = self.priority_map.emit()?;
                        RecordBatch::try_new(Arc::clone(&self.schema), cols)?
                    };
                    trace!(
                        "partition {} emit batch with {} rows",
                        self.partition,
                        batch.num_rows()
                    );
                    if log::log_enabled!(Level::Trace) {
                        print_batches(std::slice::from_ref(&batch))?;
                    }
                    return Poll::Ready(Some(Ok(batch)));
                }
                // inner had error, return to caller
                Some(Err(e)) => {
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
        Poll::Pending
    }
}