datafusion_physical_optimizer/
aggregate_statistics.rs1use datafusion_common::config::ConfigOptions;
20use datafusion_common::scalar::ScalarValue;
21use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
22use datafusion_common::Result;
23use datafusion_physical_plan::aggregates::AggregateExec;
24use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
25use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr};
26use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
27use datafusion_physical_plan::{expressions, ExecutionPlan};
28use std::sync::Arc;
29
30use crate::PhysicalOptimizerRule;
31
/// Physical optimizer rule that answers aggregate expressions directly from
/// plan statistics instead of executing the aggregation.
///
/// The rule carries no state; all of its behavior lives in its
/// `PhysicalOptimizerRule` implementation.
#[derive(Debug, Default)]
pub struct AggregateStatistics {}
35
36impl AggregateStatistics {
37 #[allow(missing_docs)]
38 pub fn new() -> Self {
39 Self {}
40 }
41}
42
43impl PhysicalOptimizerRule for AggregateStatistics {
44 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
45 #[allow(clippy::only_used_in_recursion)] fn optimize(
47 &self,
48 plan: Arc<dyn ExecutionPlan>,
49 config: &ConfigOptions,
50 ) -> Result<Arc<dyn ExecutionPlan>> {
51 if let Some(partial_agg_exec) = take_optimizable(&*plan) {
52 let partial_agg_exec = partial_agg_exec
53 .as_any()
54 .downcast_ref::<AggregateExec>()
55 .expect("take_optimizable() ensures that this is a AggregateExec");
56 let stats = partial_agg_exec.input().partition_statistics(None)?;
57 let mut projections = vec![];
58 for expr in partial_agg_exec.aggr_expr() {
59 let field = expr.field();
60 let args = expr.expressions();
61 let statistics_args = StatisticsArgs {
62 statistics: &stats,
63 return_type: field.data_type(),
64 is_distinct: expr.is_distinct(),
65 exprs: args.as_slice(),
66 };
67 if let Some((optimizable_statistic, name)) =
68 take_optimizable_value_from_statistics(&statistics_args, expr)
69 {
70 projections.push(ProjectionExpr {
71 expr: expressions::lit(optimizable_statistic),
72 alias: name.to_owned(),
73 });
74 } else {
75 break;
77 }
78 }
79
80 if projections.len() == partial_agg_exec.aggr_expr().len() {
82 Ok(Arc::new(ProjectionExec::try_new(
84 projections,
85 Arc::new(PlaceholderRowExec::new(plan.schema())),
86 )?))
87 } else {
88 plan.map_children(|child| {
89 self.optimize(child, config).map(Transformed::yes)
90 })
91 .data()
92 }
93 } else {
94 plan.map_children(|child| self.optimize(child, config).map(Transformed::yes))
95 .data()
96 }
97 }
98
99 fn name(&self) -> &str {
100 "aggregate_statistics"
101 }
102
103 fn schema_check(&self) -> bool {
105 false
106 }
107}
108
109fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> {
117 if let Some(final_agg_exec) = node.as_any().downcast_ref::<AggregateExec>() {
118 if !final_agg_exec.mode().is_first_stage()
119 && final_agg_exec.group_expr().is_empty()
120 {
121 let mut child = Arc::clone(final_agg_exec.input());
122 loop {
123 if let Some(partial_agg_exec) =
124 child.as_any().downcast_ref::<AggregateExec>()
125 {
126 if partial_agg_exec.mode().is_first_stage()
127 && partial_agg_exec.group_expr().is_empty()
128 && partial_agg_exec.filter_expr().iter().all(|e| e.is_none())
129 {
130 return Some(child);
131 }
132 }
133 if let [childrens_child] = child.children().as_slice() {
134 child = Arc::clone(childrens_child);
135 } else {
136 break;
137 }
138 }
139 }
140 }
141 None
142}
143
144fn take_optimizable_value_from_statistics(
146 statistics_args: &StatisticsArgs,
147 agg_expr: &AggregateFunctionExpr,
148) -> Option<(ScalarValue, String)> {
149 let value = agg_expr.fun().value_from_stats(statistics_args);
150 value.map(|val| (val, agg_expr.name().to_string()))
151}
152
153