datafusion_physical_optimizer/
combine_partial_final_agg.rs1use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::aggregates::{
25 AggregateExec, AggregateMode, PhysicalGroupBy,
26};
27use datafusion_physical_plan::ExecutionPlan;
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr};
34
/// Physical optimizer rule that merges a two-stage aggregation into one.
///
/// When a `Final` / `FinalPartitioned` [`AggregateExec`] sits directly on top
/// of a `Partial` [`AggregateExec`] and both stages use matching group-by,
/// aggregate and filter expressions, the pair is replaced by a single
/// `Single` / `SinglePartitioned` aggregate reading from the partial stage's
/// input.
#[derive(Default, Debug)]
pub struct CombinePartialFinalAggregate {}

impl CombinePartialFinalAggregate {
    /// Creates a new instance of the rule.
    ///
    /// The rule carries no state, so this is equivalent to
    /// `CombinePartialFinalAggregate::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
48
49impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
50 fn optimize(
51 &self,
52 plan: Arc<dyn ExecutionPlan>,
53 _config: &ConfigOptions,
54 ) -> Result<Arc<dyn ExecutionPlan>> {
55 plan.transform_down(|plan| {
56 let Some(agg_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
58 return Ok(Transformed::no(plan));
59 };
60
61 if !matches!(
62 agg_exec.mode(),
63 AggregateMode::Final | AggregateMode::FinalPartitioned
64 ) {
65 return Ok(Transformed::no(plan));
66 }
67
68 let Some(input_agg_exec) =
70 agg_exec.input().as_any().downcast_ref::<AggregateExec>()
71 else {
72 return Ok(Transformed::no(plan));
73 };
74
75 let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial)
76 && can_combine(
77 (
78 agg_exec.group_expr(),
79 agg_exec.aggr_expr(),
80 agg_exec.filter_expr(),
81 ),
82 (
83 input_agg_exec.group_expr(),
84 input_agg_exec.aggr_expr(),
85 input_agg_exec.filter_expr(),
86 ),
87 ) {
88 let mode = if agg_exec.mode() == &AggregateMode::Final {
89 AggregateMode::Single
90 } else {
91 AggregateMode::SinglePartitioned
92 };
93 AggregateExec::try_new(
94 mode,
95 input_agg_exec.group_expr().clone(),
96 input_agg_exec.aggr_expr().to_vec(),
97 input_agg_exec.filter_expr().to_vec(),
98 Arc::clone(input_agg_exec.input()),
99 input_agg_exec.input_schema(),
100 )
101 .map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
102 .ok()
103 .map(Arc::new)
104 } else {
105 None
106 };
107 Ok(if let Some(transformed) = transformed {
108 Transformed::yes(transformed)
109 } else {
110 Transformed::no(plan)
111 })
112 })
113 .data()
114 }
115
116 fn name(&self) -> &str {
117 "CombinePartialFinalAggregate"
118 }
119
120 fn schema_check(&self) -> bool {
121 true
122 }
123}
124
125type GroupExprsRef<'a> = (
126 &'a PhysicalGroupBy,
127 &'a [Arc<AggregateFunctionExpr>],
128 &'a [Option<Arc<dyn PhysicalExpr>>],
129);
130
131fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
132 let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
133 let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
134
135 physical_exprs_equal(
137 &input_group_by.output_exprs(),
138 &final_group_by.input_exprs(),
139 ) && input_group_by.groups() == final_group_by.groups()
140 && input_group_by.null_expr().len() == final_group_by.null_expr().len()
141 && input_group_by
142 .null_expr()
143 .iter()
144 .zip(final_group_by.null_expr().iter())
145 .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
146 lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
147 })
148 && final_aggr_expr.len() == input_aggr_expr.len()
149 && final_aggr_expr
150 .iter()
151 .zip(input_aggr_expr.iter())
152 .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
153 && final_filter_expr.len() == input_filter_expr.len()
154 && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
155 |(final_expr, partial_expr)| match (final_expr, partial_expr) {
156 (Some(l), Some(r)) => l.eq(r),
157 (None, None) => true,
158 _ => false,
159 },
160 )
161}
162
163