datafusion_physical_optimizer/
combine_partial_final_agg.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! The `CombinePartialFinalAggregate` optimizer rule looks for adjacent Partial and Final
//! `AggregateExec`s and tries to combine them into a single aggregation when possible.
20
21use std::sync::Arc;
22
23use datafusion_common::error::Result;
24use datafusion_physical_plan::aggregates::{
25    AggregateExec, AggregateMode, PhysicalGroupBy,
26};
27use datafusion_physical_plan::ExecutionPlan;
28
29use crate::PhysicalOptimizerRule;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
32use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
33use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr};
34
/// Physical optimizer rule that merges an adjacent Partial / Final pair of
/// `AggregateExec`s into one Single `AggregateExec` when their grouping,
/// aggregate and filter expressions are equal.
///
/// This rule should be applied after the EnforceDistribution and EnforceSorting rules.
#[derive(Debug, Default)]
pub struct CombinePartialFinalAggregate {}
41
42impl CombinePartialFinalAggregate {
43    #[allow(missing_docs)]
44    pub fn new() -> Self {
45        Self {}
46    }
47}
48
49impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
50    fn optimize(
51        &self,
52        plan: Arc<dyn ExecutionPlan>,
53        _config: &ConfigOptions,
54    ) -> Result<Arc<dyn ExecutionPlan>> {
55        plan.transform_down(|plan| {
56            // Check if the plan is AggregateExec
57            let Some(agg_exec) = plan.as_any().downcast_ref::<AggregateExec>() else {
58                return Ok(Transformed::no(plan));
59            };
60
61            if !matches!(
62                agg_exec.mode(),
63                AggregateMode::Final | AggregateMode::FinalPartitioned
64            ) {
65                return Ok(Transformed::no(plan));
66            }
67
68            // Check if the input is AggregateExec
69            let Some(input_agg_exec) =
70                agg_exec.input().as_any().downcast_ref::<AggregateExec>()
71            else {
72                return Ok(Transformed::no(plan));
73            };
74
75            let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial)
76                && can_combine(
77                    (
78                        agg_exec.group_expr(),
79                        agg_exec.aggr_expr(),
80                        agg_exec.filter_expr(),
81                    ),
82                    (
83                        input_agg_exec.group_expr(),
84                        input_agg_exec.aggr_expr(),
85                        input_agg_exec.filter_expr(),
86                    ),
87                ) {
88                let mode = if agg_exec.mode() == &AggregateMode::Final {
89                    AggregateMode::Single
90                } else {
91                    AggregateMode::SinglePartitioned
92                };
93                AggregateExec::try_new(
94                    mode,
95                    input_agg_exec.group_expr().clone(),
96                    input_agg_exec.aggr_expr().to_vec(),
97                    input_agg_exec.filter_expr().to_vec(),
98                    Arc::clone(input_agg_exec.input()),
99                    input_agg_exec.input_schema(),
100                )
101                .map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
102                .ok()
103                .map(Arc::new)
104            } else {
105                None
106            };
107            Ok(if let Some(transformed) = transformed {
108                Transformed::yes(transformed)
109            } else {
110                Transformed::no(plan)
111            })
112        })
113        .data()
114    }
115
116    fn name(&self) -> &str {
117        "CombinePartialFinalAggregate"
118    }
119
120    fn schema_check(&self) -> bool {
121        true
122    }
123}
124
125type GroupExprsRef<'a> = (
126    &'a PhysicalGroupBy,
127    &'a [Arc<AggregateFunctionExpr>],
128    &'a [Option<Arc<dyn PhysicalExpr>>],
129);
130
131fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
132    let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
133    let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
134
135    // Compare output expressions of the partial, and input expressions of the final operator.
136    physical_exprs_equal(
137        &input_group_by.output_exprs(),
138        &final_group_by.input_exprs(),
139    ) && input_group_by.groups() == final_group_by.groups()
140        && input_group_by.null_expr().len() == final_group_by.null_expr().len()
141        && input_group_by
142            .null_expr()
143            .iter()
144            .zip(final_group_by.null_expr().iter())
145            .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
146                lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
147            })
148        && final_aggr_expr.len() == input_aggr_expr.len()
149        && final_aggr_expr
150            .iter()
151            .zip(input_aggr_expr.iter())
152            .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr))
153        && final_filter_expr.len() == input_filter_expr.len()
154        && final_filter_expr.iter().zip(input_filter_expr.iter()).all(
155            |(final_expr, partial_expr)| match (final_expr, partial_expr) {
156                (Some(l), Some(r)) => l.eq(r),
157                (None, None) => true,
158                _ => false,
159            },
160        )
161}
162
163// See tests in datafusion/core/tests/physical_optimizer