datafusion_physical_plan/aggregates/
topk_stream.rs1use crate::aggregates::group_values::GroupByMetrics;
21use crate::aggregates::topk::priority_map::PriorityMap;
22use crate::aggregates::{
23 aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec,
24 PhysicalGroupBy,
25};
26use crate::metrics::BaselineMetrics;
27use crate::{RecordBatchStream, SendableRecordBatchStream};
28use arrow::array::{Array, ArrayRef, RecordBatch};
29use arrow::datatypes::SchemaRef;
30use arrow::util::pretty::print_batches;
31use datafusion_common::internal_datafusion_err;
32use datafusion_common::Result;
33use datafusion_execution::TaskContext;
34use datafusion_physical_expr::PhysicalExpr;
35use futures::stream::{Stream, StreamExt};
36use log::{trace, Level};
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40
41pub struct GroupedTopKAggregateStream {
42 partition: usize,
43 row_count: usize,
44 started: bool,
45 schema: SchemaRef,
46 input: SendableRecordBatchStream,
47 baseline_metrics: BaselineMetrics,
48 group_by_metrics: GroupByMetrics,
49 aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
50 group_by: PhysicalGroupBy,
51 priority_map: PriorityMap,
52}
53
54impl GroupedTopKAggregateStream {
55 pub fn new(
56 aggr: &AggregateExec,
57 context: Arc<TaskContext>,
58 partition: usize,
59 limit: usize,
60 ) -> Result<Self> {
61 let agg_schema = Arc::clone(&aggr.schema);
62 let group_by = aggr.group_by.clone();
63 let input = aggr.input.execute(partition, Arc::clone(&context))?;
64 let baseline_metrics = BaselineMetrics::new(&aggr.metrics, partition);
65 let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition);
66 let aggregate_arguments =
67 aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?;
68 let (val_field, desc) = aggr
69 .get_minmax_desc()
70 .ok_or_else(|| internal_datafusion_err!("Min/max required"))?;
71
72 let (expr, _) = &aggr.group_expr().expr()[0];
73 let kt = expr.data_type(&aggr.input().schema())?;
74 let vt = val_field.data_type().clone();
75
76 let priority_map = PriorityMap::new(kt, vt, limit, desc)?;
77
78 Ok(GroupedTopKAggregateStream {
79 partition,
80 started: false,
81 row_count: 0,
82 schema: agg_schema,
83 input,
84 baseline_metrics,
85 group_by_metrics,
86 aggregate_arguments,
87 group_by,
88 priority_map,
89 })
90 }
91}
92
93impl RecordBatchStream for GroupedTopKAggregateStream {
94 fn schema(&self) -> SchemaRef {
95 Arc::clone(&self.schema)
96 }
97}
98
99impl GroupedTopKAggregateStream {
100 fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> {
101 let _timer = self.group_by_metrics.time_calculating_group_ids.timer();
102
103 let len = ids.len();
104 self.priority_map.set_batch(ids, Arc::clone(&vals));
105
106 let has_nulls = vals.null_count() > 0;
107 for row_idx in 0..len {
108 if has_nulls && vals.is_null(row_idx) {
109 continue;
110 }
111 self.priority_map.insert(row_idx)?;
112 }
113 Ok(())
114 }
115}
116
117impl Stream for GroupedTopKAggregateStream {
118 type Item = Result<RecordBatch>;
119
120 fn poll_next(
121 mut self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 ) -> Poll<Option<Self::Item>> {
124 let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
125 let emitting_time = self.group_by_metrics.emitting_time.clone();
126 while let Poll::Ready(res) = self.input.poll_next_unpin(cx) {
127 let _timer = elapsed_compute.timer();
128 match res {
129 Some(Ok(batch)) => {
131 self.started = true;
132 trace!(
133 "partition {} has {} rows and got batch with {} rows",
134 self.partition,
135 self.row_count,
136 batch.num_rows()
137 );
138 if log::log_enabled!(Level::Trace) && batch.num_rows() < 20 {
139 print_batches(std::slice::from_ref(&batch))?;
140 }
141 self.row_count += batch.num_rows();
142 let batches = &[batch];
143 let group_by_values =
144 evaluate_group_by(&self.group_by, batches.first().unwrap())?;
145 assert_eq!(
146 group_by_values.len(),
147 1,
148 "Exactly 1 group value required"
149 );
150 assert_eq!(
151 group_by_values[0].len(),
152 1,
153 "Exactly 1 group value required"
154 );
155 let group_by_values = Arc::clone(&group_by_values[0][0]);
156 let input_values = {
157 let _timer = (!self.aggregate_arguments.is_empty()).then(|| {
158 self.group_by_metrics.aggregate_arguments_time.timer()
159 });
160 evaluate_many(
161 &self.aggregate_arguments,
162 batches.first().unwrap(),
163 )?
164 };
165 assert_eq!(input_values.len(), 1, "Exactly 1 input required");
166 assert_eq!(input_values[0].len(), 1, "Exactly 1 input required");
167 let input_values = Arc::clone(&input_values[0][0]);
168
169 (*self).intern(group_by_values, input_values)?;
171 }
172 None => {
174 if self.priority_map.is_empty() {
175 trace!("partition {} emit None", self.partition);
176 return Poll::Ready(None);
177 }
178 let batch = {
179 let _timer = emitting_time.timer();
180 let cols = self.priority_map.emit()?;
181 RecordBatch::try_new(Arc::clone(&self.schema), cols)?
182 };
183 trace!(
184 "partition {} emit batch with {} rows",
185 self.partition,
186 batch.num_rows()
187 );
188 if log::log_enabled!(Level::Trace) {
189 print_batches(std::slice::from_ref(&batch))?;
190 }
191 return Poll::Ready(Some(Ok(batch)));
192 }
193 Some(Err(e)) => {
195 return Poll::Ready(Some(Err(e)));
196 }
197 }
198 }
199 Poll::Pending
200 }
201}