datafusion_optimizer/
common_subexpr_eliminate.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions
19
20use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35    Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
39const CSE_PREFIX: &str = "__common_expr";
40
/// Performs Common Sub-expression Elimination optimization.
///
/// This optimization improves query performance by computing expressions that
/// appear more than once and reusing those results rather than re-computing the
/// same value.
///
/// Currently only common sub-expressions within a single `LogicalPlan` are
/// eliminated.
///
/// # Example
///
/// Given a projection that computes the same expensive expression
/// multiple times, such as parsing a string as a date with `to_date` twice:
///
/// ```text
/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))])
/// ```
///
/// This optimization will rewrite the plan to compute the common expression once
/// using a new `ProjectionExec` and then rewrite the original expressions to
/// refer to that new column.
///
/// ```text
/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
///   ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
/// ```
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71    pub fn new() -> Self {
72        Self {}
73    }
74
75    fn try_optimize_proj(
76        &self,
77        projection: Projection,
78        config: &dyn OptimizerConfig,
79    ) -> Result<Transformed<LogicalPlan>> {
80        let Projection {
81            expr,
82            input,
83            schema,
84            ..
85        } = projection;
86        let input = Arc::unwrap_or_clone(input);
87        self.try_unary_plan(expr, input, config)?
88            .map_data(|(new_expr, new_input)| {
89                Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90                    .map(LogicalPlan::Projection)
91            })
92    }
93
94    fn try_optimize_sort(
95        &self,
96        sort: Sort,
97        config: &dyn OptimizerConfig,
98    ) -> Result<Transformed<LogicalPlan>> {
99        let Sort { expr, input, fetch } = sort;
100        let input = Arc::unwrap_or_clone(input);
101        let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102            .into_iter()
103            .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104            .unzip();
105        let new_sort = self
106            .try_unary_plan(sort_expressions, input, config)?
107            .update_data(|(new_expr, new_input)| {
108                LogicalPlan::Sort(Sort {
109                    expr: new_expr
110                        .into_iter()
111                        .zip(sort_params)
112                        .map(|(expr, (asc, nulls_first))| SortExpr {
113                            expr,
114                            asc,
115                            nulls_first,
116                        })
117                        .collect(),
118                    input: Arc::new(new_input),
119                    fetch,
120                })
121            });
122        Ok(new_sort)
123    }
124
125    fn try_optimize_filter(
126        &self,
127        filter: Filter,
128        config: &dyn OptimizerConfig,
129    ) -> Result<Transformed<LogicalPlan>> {
130        let Filter {
131            predicate, input, ..
132        } = filter;
133        let input = Arc::unwrap_or_clone(input);
134        let expr = vec![predicate];
135        self.try_unary_plan(expr, input, config)?
136            .map_data(|(mut new_expr, new_input)| {
137                assert_eq!(new_expr.len(), 1); // passed in vec![predicate]
138                let new_predicate = new_expr.pop().unwrap();
139                Filter::try_new(new_predicate, Arc::new(new_input))
140                    .map(LogicalPlan::Filter)
141            })
142    }
143
144    fn try_optimize_window(
145        &self,
146        window: Window,
147        config: &dyn OptimizerConfig,
148    ) -> Result<Transformed<LogicalPlan>> {
149        // Collects window expressions from consecutive `LogicalPlan::Window` nodes into
150        // a list.
151        let (window_expr_list, window_schemas, input) =
152            get_consecutive_window_exprs(window);
153
154        // Extract common sub-expressions from the list.
155
156        match CSE::new(ExprCSEController::new(
157            config.alias_generator().as_ref(),
158            ExprMask::Normal,
159        ))
160        .extract_common_nodes(window_expr_list)?
161        {
162            // If there are common sub-expressions, then the insert a projection node
163            // with the common expressions between the new window nodes and the
164            // original input.
165            FoundCommonNodes::Yes {
166                common_nodes: common_exprs,
167                new_nodes_list: new_exprs_list,
168                original_nodes_list: original_exprs_list,
169            } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
170                Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
171            }),
172            FoundCommonNodes::No {
173                original_nodes_list: original_exprs_list,
174            } => Ok(Transformed::no((original_exprs_list, input, None))),
175        }?
176        // Recurse into the new input.
177        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
178        .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
179            self.rewrite(new_input, config)?.map_data(|new_input| {
180                Ok((new_window_expr_list, new_input, window_expr_list))
181            })
182        })?
183        // Rebuild the consecutive window nodes.
184        .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
185            // If there were common expressions extracted, then we need to make sure
186            // we restore the original column names.
187            // TODO: Although `find_common_exprs()` inserts aliases around extracted
188            //  common expressions this doesn't mean that the original column names
189            //  (schema) are preserved due to the inserted aliases are not always at
190            //  the top of the expression.
191            //  Let's consider improving `find_common_exprs()` to always keep column
192            //  names and get rid of additional name preserving logic here.
193            if let Some(window_expr_list) = window_expr_list {
194                let name_preserver = NamePreserver::new_for_projection();
195                let saved_names = window_expr_list
196                    .iter()
197                    .map(|exprs| {
198                        exprs
199                            .iter()
200                            .map(|expr| name_preserver.save(expr))
201                            .collect::<Vec<_>>()
202                    })
203                    .collect::<Vec<_>>();
204                new_window_expr_list.into_iter().zip(saved_names).try_rfold(
205                    new_input,
206                    |plan, (new_window_expr, saved_names)| {
207                        let new_window_expr = new_window_expr
208                            .into_iter()
209                            .zip(saved_names)
210                            .map(|(new_window_expr, saved_name)| {
211                                saved_name.restore(new_window_expr)
212                            })
213                            .collect::<Vec<_>>();
214                        Window::try_new(new_window_expr, Arc::new(plan))
215                            .map(LogicalPlan::Window)
216                    },
217                )
218            } else {
219                new_window_expr_list
220                    .into_iter()
221                    .zip(window_schemas)
222                    .try_rfold(new_input, |plan, (new_window_expr, schema)| {
223                        Window::try_new_with_schema(
224                            new_window_expr,
225                            Arc::new(plan),
226                            schema,
227                        )
228                        .map(LogicalPlan::Window)
229                    })
230            }
231        })
232    }
233
234    fn try_optimize_aggregate(
235        &self,
236        aggregate: Aggregate,
237        config: &dyn OptimizerConfig,
238    ) -> Result<Transformed<LogicalPlan>> {
239        let Aggregate {
240            group_expr,
241            aggr_expr,
242            input,
243            schema,
244            ..
245        } = aggregate;
246        let input = Arc::unwrap_or_clone(input);
247        // Extract common sub-expressions from the aggregate and grouping expressions.
248        match CSE::new(ExprCSEController::new(
249            config.alias_generator().as_ref(),
250            ExprMask::Normal,
251        ))
252        .extract_common_nodes(vec![group_expr, aggr_expr])?
253        {
254            // If there are common sub-expressions, then insert a projection node
255            // with the common expressions between the new aggregate node and the
256            // original input.
257            FoundCommonNodes::Yes {
258                common_nodes: common_exprs,
259                new_nodes_list: mut new_exprs_list,
260                original_nodes_list: mut original_exprs_list,
261            } => {
262                let new_aggr_expr = new_exprs_list.pop().unwrap();
263                let new_group_expr = new_exprs_list.pop().unwrap();
264
265                build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266                    let aggr_expr = original_exprs_list.pop().unwrap();
267                    Transformed::yes((
268                        new_aggr_expr,
269                        new_group_expr,
270                        new_input,
271                        Some(aggr_expr),
272                    ))
273                })
274            }
275
276            FoundCommonNodes::No {
277                original_nodes_list: mut original_exprs_list,
278            } => {
279                let new_aggr_expr = original_exprs_list.pop().unwrap();
280                let new_group_expr = original_exprs_list.pop().unwrap();
281
282                Ok(Transformed::no((
283                    new_aggr_expr,
284                    new_group_expr,
285                    input,
286                    None,
287                )))
288            }
289        }?
290        // Recurse into the new input.
291        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
292        .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293            self.rewrite(new_input, config)?.map_data(|new_input| {
294                Ok((
295                    new_aggr_expr,
296                    new_group_expr,
297                    aggr_expr,
298                    Arc::new(new_input),
299                ))
300            })
301        })?
302        // Try extracting common aggregate expressions and rebuild the aggregate node.
303        .transform_data(
304            |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305                // Extract common aggregate sub-expressions from the aggregate expressions.
306                match CSE::new(ExprCSEController::new(
307                    config.alias_generator().as_ref(),
308                    ExprMask::NormalAndAggregates,
309                ))
310                .extract_common_nodes(vec![new_aggr_expr])?
311                {
312                    FoundCommonNodes::Yes {
313                        common_nodes: common_exprs,
314                        new_nodes_list: mut new_exprs_list,
315                        original_nodes_list: mut original_exprs_list,
316                    } => {
317                        let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318                        let new_aggr_expr = original_exprs_list.pop().unwrap();
319                        let saved_names = if let Some(aggr_expr) = aggr_expr {
320                            let name_preserver = NamePreserver::new_for_projection();
321                            aggr_expr
322                                .iter()
323                                .map(|expr| Some(name_preserver.save(expr)))
324                                .collect::<Vec<_>>()
325                        } else {
326                            new_aggr_expr
327                                .clone()
328                                .into_iter()
329                                .map(|_| None)
330                                .collect::<Vec<_>>()
331                        };
332
333                        let mut agg_exprs = common_exprs
334                            .into_iter()
335                            .map(|(expr, expr_alias)| expr.alias(expr_alias))
336                            .collect::<Vec<_>>();
337
338                        let mut proj_exprs = vec![];
339                        for expr in &new_group_expr {
340                            extract_expressions(expr, &mut proj_exprs)
341                        }
342                        for ((expr_rewritten, expr_orig), saved_name) in
343                            rewritten_aggr_expr
344                                .into_iter()
345                                .zip(new_aggr_expr)
346                                .zip(saved_names)
347                        {
348                            if expr_rewritten == expr_orig {
349                                let expr_rewritten = if let Some(saved_name) = saved_name
350                                {
351                                    saved_name.restore(expr_rewritten)
352                                } else {
353                                    expr_rewritten
354                                };
355                                if let Expr::Alias(Alias { expr, name, .. }) =
356                                    expr_rewritten
357                                {
358                                    agg_exprs.push(expr.alias(&name));
359                                    proj_exprs
360                                        .push(Expr::Column(Column::from_name(name)));
361                                } else {
362                                    let expr_alias =
363                                        config.alias_generator().next(CSE_PREFIX);
364                                    let (qualifier, field_name) =
365                                        expr_rewritten.qualified_name();
366                                    let out_name =
367                                        qualified_name(qualifier.as_ref(), &field_name);
368
369                                    agg_exprs.push(expr_rewritten.alias(&expr_alias));
370                                    proj_exprs.push(
371                                        Expr::Column(Column::from_name(expr_alias))
372                                            .alias(out_name),
373                                    );
374                                }
375                            } else {
376                                proj_exprs.push(expr_rewritten);
377                            }
378                        }
379
380                        let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381                            new_input,
382                            new_group_expr,
383                            agg_exprs,
384                        )?);
385                        Projection::try_new(proj_exprs, Arc::new(agg))
386                            .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387                    }
388
389                    // If there aren't any common aggregate sub-expressions, then just
390                    // rebuild the aggregate node.
391                    FoundCommonNodes::No {
392                        original_nodes_list: mut original_exprs_list,
393                    } => {
394                        let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396                        // If there were common expressions extracted, then we need to
397                        // make sure we restore the original column names.
398                        // TODO: Although `find_common_exprs()` inserts aliases around
399                        //  extracted common expressions this doesn't mean that the
400                        //  original column names (schema) are preserved due to the
401                        //  inserted aliases are not always at the top of the
402                        //  expression.
403                        //  Let's consider improving `find_common_exprs()` to always
404                        //  keep column names and get rid of additional name
405                        //  preserving logic here.
406                        if let Some(aggr_expr) = aggr_expr {
407                            let name_preserver = NamePreserver::new_for_projection();
408                            let saved_names = aggr_expr
409                                .iter()
410                                .map(|expr| name_preserver.save(expr))
411                                .collect::<Vec<_>>();
412                            let new_aggr_expr = rewritten_aggr_expr
413                                .into_iter()
414                                .zip(saved_names)
415                                .map(|(new_expr, saved_name)| {
416                                    saved_name.restore(new_expr)
417                                })
418                                .collect::<Vec<Expr>>();
419
420                            // Since `group_expr` may have changed, schema may also.
421                            // Use `try_new()` method.
422                            Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423                                .map(LogicalPlan::Aggregate)
424                                .map(Transformed::no)
425                        } else {
426                            Aggregate::try_new_with_schema(
427                                new_input,
428                                new_group_expr,
429                                rewritten_aggr_expr,
430                                schema,
431                            )
432                            .map(LogicalPlan::Aggregate)
433                            .map(Transformed::no)
434                        }
435                    }
436                }
437            },
438        )
439    }
440
441    /// Rewrites the expr list and input to remove common subexpressions
442    ///
443    /// # Parameters
444    ///
445    /// * `exprs`: List of expressions in the node
446    /// * `input`: input plan (that produces the columns referred to in `exprs`)
447    ///
448    /// # Return value
449    ///
450    ///  Returns `(rewritten_exprs, new_input)`. `new_input` is either:
451    ///
452    /// 1. The original `input` of no common subexpressions were extracted
453    /// 2. A newly added projection on top of the original input
454    ///    that computes the common subexpressions
455    fn try_unary_plan(
456        &self,
457        exprs: Vec<Expr>,
458        input: LogicalPlan,
459        config: &dyn OptimizerConfig,
460    ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461        // Extract common sub-expressions from the expressions.
462        match CSE::new(ExprCSEController::new(
463            config.alias_generator().as_ref(),
464            ExprMask::Normal,
465        ))
466        .extract_common_nodes(vec![exprs])?
467        {
468            FoundCommonNodes::Yes {
469                common_nodes: common_exprs,
470                new_nodes_list: mut new_exprs_list,
471                original_nodes_list: _,
472            } => {
473                let new_exprs = new_exprs_list.pop().unwrap();
474                build_common_expr_project_plan(input, common_exprs)
475                    .map(|new_input| Transformed::yes((new_exprs, new_input)))
476            }
477            FoundCommonNodes::No {
478                original_nodes_list: mut original_exprs_list,
479            } => {
480                let new_exprs = original_exprs_list.pop().unwrap();
481                Ok(Transformed::no((new_exprs, input)))
482            }
483        }?
484        // Recurse into the new input.
485        // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
486        .transform_data(|(new_exprs, new_input)| {
487            self.rewrite(new_input, config)?
488                .map_data(|new_input| Ok((new_exprs, new_input)))
489        })
490    }
491}
492
493/// Get all window expressions inside the consecutive window operators.
494///
495/// Returns the window expressions, and the input to the deepest child
496/// LogicalPlan.
497///
498/// For example, if the input window looks like
499///
500/// ```text
501///   LogicalPlan::Window(exprs=[a, b, c])
502///     LogicalPlan::Window(exprs=[d])
503///       InputPlan
504/// ```
505///
506/// Returns:
507/// *  `window_exprs`: `[[a, b, c], [d]]`
508/// * InputPlan
509///
510/// Consecutive window expressions may refer to same complex expression.
511///
512/// If same complex expression is referred more than once by subsequent
513/// `WindowAggr`s, we can cache complex expression by evaluating it with a
514/// projection before the first WindowAggr.
515///
516/// This enables us to cache complex expression "c3+c4" for following plan:
517///
518/// ```text
519/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
520/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
521/// ```
522///
523/// where, it is referred once by each `WindowAggr` (total of 2) in the plan.
524fn get_consecutive_window_exprs(
525    window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527    let mut window_expr_list = vec![];
528    let mut window_schemas = vec![];
529    let mut plan = LogicalPlan::Window(window);
530    while let LogicalPlan::Window(Window {
531        input,
532        window_expr,
533        schema,
534    }) = plan
535    {
536        window_expr_list.push(window_expr);
537        window_schemas.push(schema);
538
539        plan = Arc::unwrap_or_clone(input);
540    }
541    (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545    fn supports_rewrite(&self) -> bool {
546        true
547    }
548
549    fn apply_order(&self) -> Option<ApplyOrder> {
550        // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner.
551        // This is because in some cases adjacent nodes are collected (e.g. `Window`) and
552        // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule.
553        None
554    }
555
556    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557    fn rewrite(
558        &self,
559        plan: LogicalPlan,
560        config: &dyn OptimizerConfig,
561    ) -> Result<Transformed<LogicalPlan>> {
562        let original_schema = Arc::clone(plan.schema());
563
564        let optimized_plan = match plan {
565            LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566            LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567            LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568            LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569            LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570            LogicalPlan::Join(_)
571            | LogicalPlan::Repartition(_)
572            | LogicalPlan::Union(_)
573            | LogicalPlan::TableScan(_)
574            | LogicalPlan::Values(_)
575            | LogicalPlan::EmptyRelation(_)
576            | LogicalPlan::Subquery(_)
577            | LogicalPlan::SubqueryAlias(_)
578            | LogicalPlan::Limit(_)
579            | LogicalPlan::Ddl(_)
580            | LogicalPlan::Explain(_)
581            | LogicalPlan::Analyze(_)
582            | LogicalPlan::Statement(_)
583            | LogicalPlan::DescribeTable(_)
584            | LogicalPlan::Distinct(_)
585            | LogicalPlan::Extension(_)
586            | LogicalPlan::Dml(_)
587            | LogicalPlan::Copy(_)
588            | LogicalPlan::Unnest(_)
589            | LogicalPlan::RecursiveQuery(_) => {
590                // This rule handles recursion itself in a `ApplyOrder::TopDown` like
591                // manner.
592                plan.map_children(|c| self.rewrite(c, config))?
593            }
594        };
595
596        // If we rewrote the plan, ensure the schema stays the same
597        if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598        {
599            optimized_plan.map_data(|optimized_plan| {
600                build_recover_project_plan(&original_schema, optimized_plan)
601            })
602        } else {
603            Ok(optimized_plan)
604        }
605    }
606
607    fn name(&self) -> &str {
608        "common_sub_expression_eliminate"
609    }
610}
611
/// Which type of [expressions](Expr) should be considered for rewriting?
///
/// The mask is consulted by `ExprCSEController::is_ignored` to decide which
/// expression nodes are skipped during common sub-expression extraction.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores:
    ///
    /// - [`Literal`](Expr::Literal)
    /// - [`Columns`](Expr::Column)
    /// - [`ScalarVariable`](Expr::ScalarVariable)
    /// - [`Alias`](Expr::Alias)
    /// - [`Wildcard`](Expr::Wildcard)
    /// - [`AggregateFunction`](Expr::AggregateFunction)
    Normal,

    /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction),
    /// i.e. aggregate functions are not ignored.
    NormalAndAggregates,
}
628
/// [`CSEController`] implementation for [`Expr`] trees.
struct ExprCSEController<'a> {
    // generator for the aliases given to extracted common expressions
    alias_generator: &'a AliasGenerator,
    // which kinds of expression nodes are ignored (see `ExprMask`)
    mask: ExprMask,

    // how many aliases have we seen so far
    alias_counter: usize,
    // parameter names of the lambdas entered but not yet left during the visit
    lambdas_params: HashSet<String>,
}
637
638impl<'a> ExprCSEController<'a> {
639    fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
640        Self {
641            alias_generator,
642            mask,
643            alias_counter: 0,
644            lambdas_params: HashSet::new(),
645        }
646    }
647}
648
649impl CSEController for ExprCSEController<'_> {
650    type Node = Expr;
651
652    fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653        match node {
654            // In case of `ScalarFunction`s we don't know which children are surely
655            // executed so start visiting all children conditionally and stop the
656            // recursion with `TreeNodeRecursion::Jump`.
657            Expr::ScalarFunction(ScalarFunction { func, args }) => {
658                func.conditional_arguments(args)
659            }
660
661            // In case of `And` and `Or` the first child is surely executed, but we
662            // account subexpressions as conditional in the second.
663            Expr::BinaryExpr(BinaryExpr {
664                left,
665                op: Operator::And | Operator::Or,
666                right,
667            }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
668
669            // In case of `Case` the optional base expression and the first when
670            // expressions are surely executed, but we account subexpressions as
671            // conditional in the others.
672            Expr::Case(Case {
673                expr,
674                when_then_expr,
675                else_expr,
676            }) => Some((
677                expr.iter()
678                    .map(|e| e.as_ref())
679                    .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
680                    .collect(),
681                when_then_expr
682                    .iter()
683                    .take(1)
684                    .map(|(_, then)| then.as_ref())
685                    .chain(
686                        when_then_expr
687                            .iter()
688                            .skip(1)
689                            .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
690                    )
691                    .chain(else_expr.iter().map(|e| e.as_ref()))
692                    .collect(),
693            )),
694            _ => None,
695        }
696    }
697
698    fn visit_f_down(&mut self, node: &Expr) {
699        if let Expr::Lambda(lambda) = node {
700            self.lambdas_params
701                .extend(lambda.params.iter().cloned());
702        }
703    }
704
705    fn visit_f_up(&mut self, node: &Expr) {
706        if let Expr::Lambda(lambda) = node {
707            for param in &lambda.params {
708                self.lambdas_params.remove(param);
709            }
710        }
711    }
712
713    fn is_valid(node: &Expr) -> bool {
714        !node.is_volatile_node()
715    }
716
717    fn is_ignored(&self, node: &Expr) -> bool {
718        if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) {
719            return true
720        }
721
722        // TODO: remove the next line after `Expr::Wildcard` is removed
723        #[expect(deprecated)]
724        let is_normal_minus_aggregates = matches!(
725            node,
726            Expr::Literal(..)
727                | Expr::Column(..)
728                | Expr::ScalarVariable(..)
729                | Expr::Alias(..)
730                | Expr::Wildcard { .. }
731        );
732
733        let is_aggr = matches!(node, Expr::AggregateFunction(..));
734
735        match self.mask {
736            ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
737            ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
738        }
739    }
740
741    fn generate_alias(&self) -> String {
742        self.alias_generator.next(CSE_PREFIX)
743    }
744
745    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
746        // alias the expressions without an `Alias` ancestor node
747        if self.alias_counter > 0 {
748            col(alias)
749        } else {
750            self.alias_counter += 1;
751            col(alias).alias(node.schema_name().to_string())
752        }
753    }
754
755    fn rewrite_f_down(&mut self, node: &Expr) {
756        if matches!(node, Expr::Alias(_)) {
757            self.alias_counter += 1;
758        }
759    }
760    fn rewrite_f_up(&mut self, node: &Expr) {
761        if matches!(node, Expr::Alias(_)) {
762            self.alias_counter -= 1
763        }
764    }
765}
766
767impl Default for CommonSubexprEliminate {
768    fn default() -> Self {
769        Self::new()
770    }
771}
772
773/// Build the "intermediate" projection plan that evaluates the extracted common
774/// expressions.
775///
776/// # Arguments
777/// input: the input plan
778///
779/// common_exprs: which common subexpressions were used (and thus are added to
780/// intermediate projection)
781///
782/// expr_stats: the set of common subexpressions
783fn build_common_expr_project_plan(
784    input: LogicalPlan,
785    common_exprs: Vec<(Expr, String)>,
786) -> Result<LogicalPlan> {
787    let mut fields_set = BTreeSet::new();
788    let mut project_exprs = common_exprs
789        .into_iter()
790        .map(|(expr, expr_alias)| {
791            fields_set.insert(expr_alias.clone());
792            Ok(expr.alias(expr_alias))
793        })
794        .collect::<Result<Vec<_>>>()?;
795
796    for (qualifier, field) in input.schema().iter() {
797        if fields_set.insert(qualified_name(qualifier, field.name())) {
798            project_exprs.push(Expr::from((qualifier, field)));
799        }
800    }
801
802    Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
803}
804
805/// Build the projection plan to eliminate unnecessary columns produced by
806/// the "intermediate" projection plan built in [build_common_expr_project_plan].
807///
808/// This is required to keep the schema the same for plans that pass the input
809/// on to the output, such as `Filter` or `Sort`.
810fn build_recover_project_plan(
811    schema: &DFSchema,
812    input: LogicalPlan,
813) -> Result<LogicalPlan> {
814    let col_exprs = schema.iter().map(Expr::from).collect();
815    Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
816}
817
818fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
819    if let Expr::GroupingSet(groupings) = expr {
820        for e in groupings.distinct_expr() {
821            let (qualifier, field_name) = e.qualified_name();
822            let col = Column::new(qualifier, field_name);
823            result.push(Expr::Column(col))
824        }
825    } else {
826        let (qualifier, field_name) = expr.qualified_name();
827        let col = Column::new(qualifier, field_name);
828        result.push(Expr::Column(col));
829    }
830}
831
832#[cfg(test)]
833mod test {
834    use std::any::Any;
835    use std::iter;
836
837    use arrow::datatypes::{DataType, Field, Schema};
838    use datafusion_expr::logical_plan::{table_scan, JoinType};
839    use datafusion_expr::{
840        grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
841        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
842        SimpleAggregateUDF, Volatility,
843    };
844    use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
845
846    use super::*;
847    use crate::assert_optimized_plan_eq_snapshot;
848    use crate::optimizer::OptimizerContext;
849    use crate::test::*;
850    use datafusion_expr::test::function_stub::{avg, sum};
851
852    macro_rules! assert_optimized_plan_equal {
853        (
854            $config:expr,
855            $plan:expr,
856            @ $expected:literal $(,)?
857        ) => {{
858            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
859            assert_optimized_plan_eq_snapshot!(
860                $config,
861                rules,
862                $plan,
863                @ $expected,
864            )
865        }};
866
867        (
868            $plan:expr,
869            @ $expected:literal $(,)?
870        ) => {{
871            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
872            let optimizer_ctx = OptimizerContext::new();
873            assert_optimized_plan_eq_snapshot!(
874                optimizer_ctx,
875                rules,
876                $plan,
877                @ $expected,
878            )
879        }};
880    }
881
882    #[test]
883    fn tpch_q1_simplified() -> Result<()> {
884        // SQL:
885        //  select
886        //      sum(a * (1 - b)),
887        //      sum(a * (1 - b) * (1 + c))
888        //  from T;
889        //
890        // The manual assembled logical plan don't contains the outermost `Projection`.
891
892        let table_scan = test_table_scan()?;
893
894        let plan = LogicalPlanBuilder::from(table_scan)
895            .aggregate(
896                iter::empty::<Expr>(),
897                vec![
898                    sum(col("a") * (lit(1) - col("b"))),
899                    sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
900                ],
901            )?
902            .build()?;
903
904        assert_optimized_plan_equal!(
905            plan,
906            @ r"
907        Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
908          Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
909            TableScan: test
910        "
911        )
912    }
913
914    #[test]
915    fn nested_aliases() -> Result<()> {
916        let table_scan = test_table_scan()?;
917
918        let plan = LogicalPlanBuilder::from(table_scan)
919            .project(vec![
920                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
921                col("a") + col("b"),
922            ])?
923            .build()?;
924
925        assert_optimized_plan_equal!(
926            plan,
927            @ r"
928        Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
929          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
930            TableScan: test
931        "
932        )
933    }
934
935    #[test]
936    fn aggregate() -> Result<()> {
937        let table_scan = test_table_scan()?;
938
939        let return_type = DataType::UInt32;
940        let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
941        let udf_agg = |inner: Expr| {
942            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
943                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
944                    "my_agg",
945                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
946                    return_type.clone(),
947                    Arc::clone(&accumulator),
948                    vec![Field::new("value", DataType::UInt32, true).into()],
949                ))),
950                vec![inner],
951                false,
952                None,
953                vec![],
954                None,
955            ))
956        };
957
958        // test: common aggregates
959        let plan = LogicalPlanBuilder::from(table_scan.clone())
960            .aggregate(
961                iter::empty::<Expr>(),
962                vec![
963                    // common: avg(col("a"))
964                    avg(col("a")).alias("col1"),
965                    avg(col("a")).alias("col2"),
966                    // no common
967                    avg(col("b")).alias("col3"),
968                    avg(col("c")),
969                    // common: udf_agg(col("a"))
970                    udf_agg(col("a")).alias("col4"),
971                    udf_agg(col("a")).alias("col5"),
972                    // no common
973                    udf_agg(col("b")).alias("col6"),
974                    udf_agg(col("c")),
975                ],
976            )?
977            .build()?;
978
979        assert_optimized_plan_equal!(
980            plan,
981            @ r"
982        Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
983          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
984            TableScan: test
985        "
986        )?;
987
988        // test: trafo after aggregate
989        let plan = LogicalPlanBuilder::from(table_scan.clone())
990            .aggregate(
991                iter::empty::<Expr>(),
992                vec![
993                    lit(1) + avg(col("a")),
994                    lit(1) - avg(col("a")),
995                    lit(1) + udf_agg(col("a")),
996                    lit(1) - udf_agg(col("a")),
997                ],
998            )?
999            .build()?;
1000
1001        assert_optimized_plan_equal!(
1002            plan,
1003            @ r"
1004        Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
1005          Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
1006            TableScan: test
1007        "
1008        )?;
1009
1010        // test: transformation before aggregate
1011        let plan = LogicalPlanBuilder::from(table_scan.clone())
1012            .aggregate(
1013                iter::empty::<Expr>(),
1014                vec![
1015                    avg(lit(1u32) + col("a")).alias("col1"),
1016                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1017                ],
1018            )?
1019            .build()?;
1020
1021        assert_optimized_plan_equal!(
1022            plan,
1023            @ r"
1024        Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1025          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1026            TableScan: test
1027        "
1028        )?;
1029
1030        // test: common between agg and group
1031        let plan = LogicalPlanBuilder::from(table_scan.clone())
1032            .aggregate(
1033                vec![lit(1u32) + col("a")],
1034                vec![
1035                    avg(lit(1u32) + col("a")).alias("col1"),
1036                    udf_agg(lit(1u32) + col("a")).alias("col2"),
1037                ],
1038            )?
1039            .build()?;
1040
1041        assert_optimized_plan_equal!(
1042            plan,
1043            @ r"
1044        Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
1045          Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1046            TableScan: test
1047        "
1048        )?;
1049
1050        // test: all mixed
1051        let plan = LogicalPlanBuilder::from(table_scan)
1052            .aggregate(
1053                vec![lit(1u32) + col("a")],
1054                vec![
1055                    (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
1056                    (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
1057                    avg(lit(1u32) + col("a")),
1058                    (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
1059                    (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
1060                    udf_agg(lit(1u32) + col("a")),
1061                ],
1062            )?
1063            .build()?;
1064
1065        assert_optimized_plan_equal!(
1066            plan,
1067            @ r"
1068        Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
1069          Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
1070            Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1071              TableScan: test
1072        "
1073        )
1074    }
1075
1076    #[test]
1077    fn aggregate_with_relations_and_dots() -> Result<()> {
1078        let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
1079        let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;
1080
1081        let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));
1082
1083        let plan = LogicalPlanBuilder::from(table_scan)
1084            .aggregate(
1085                vec![col_a.clone()],
1086                vec![
1087                    (lit(1u32) + avg(lit(1u32) + col_a.clone())),
1088                    avg(lit(1u32) + col_a),
1089                ],
1090            )?
1091            .build()?;
1092
1093        assert_optimized_plan_equal!(
1094            plan,
1095            @ r"
1096        Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
1097          Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
1098            Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
1099              TableScan: table.test
1100        "
1101        )
1102    }
1103
1104    #[test]
1105    fn subexpr_in_same_order() -> Result<()> {
1106        let table_scan = test_table_scan()?;
1107
1108        let plan = LogicalPlanBuilder::from(table_scan)
1109            .project(vec![
1110                (lit(1) + col("a")).alias("first"),
1111                (lit(1) + col("a")).alias("second"),
1112            ])?
1113            .build()?;
1114
1115        assert_optimized_plan_equal!(
1116            plan,
1117            @ r"
1118        Projection: __common_expr_1 AS first, __common_expr_1 AS second
1119          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1120            TableScan: test
1121        "
1122        )
1123    }
1124
1125    #[test]
1126    fn subexpr_in_different_order() -> Result<()> {
1127        let table_scan = test_table_scan()?;
1128
1129        let plan = LogicalPlanBuilder::from(table_scan)
1130            .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
1131            .build()?;
1132
1133        assert_optimized_plan_equal!(
1134            plan,
1135            @ r"
1136        Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
1137          Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1138            TableScan: test
1139        "
1140        )
1141    }
1142
1143    #[test]
1144    fn cross_plans_subexpr() -> Result<()> {
1145        let table_scan = test_table_scan()?;
1146
1147        let plan = LogicalPlanBuilder::from(table_scan)
1148            .project(vec![lit(1) + col("a"), col("a")])?
1149            .project(vec![lit(1) + col("a")])?
1150            .build()?;
1151
1152        assert_optimized_plan_equal!(
1153            plan,
1154            @ r"
1155        Projection: Int32(1) + test.a
1156          Projection: Int32(1) + test.a, test.a
1157            TableScan: test
1158        "
1159        )
1160    }
1161
1162    #[test]
1163    fn redundant_project_fields() {
1164        let table_scan = test_table_scan().unwrap();
1165        let c_plus_a = col("c") + col("a");
1166        let b_plus_a = col("b") + col("a");
1167        let common_exprs_1 = vec![
1168            (c_plus_a, format!("{CSE_PREFIX}_1")),
1169            (b_plus_a, format!("{CSE_PREFIX}_2")),
1170        ];
1171        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1172        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1173        let common_exprs_2 = vec![
1174            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1175            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1176        ];
1177        let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap();
1178        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1179
1180        let mut field_set = BTreeSet::new();
1181        for name in project_2.schema().field_names() {
1182            assert!(field_set.insert(name));
1183        }
1184    }
1185
1186    #[test]
1187    fn redundant_project_fields_join_input() {
1188        let table_scan_1 = test_table_scan_with_name("test1").unwrap();
1189        let table_scan_2 = test_table_scan_with_name("test2").unwrap();
1190        let join = LogicalPlanBuilder::from(table_scan_1)
1191            .join(table_scan_2, JoinType::Inner, (vec!["a"], vec!["a"]), None)
1192            .unwrap()
1193            .build()
1194            .unwrap();
1195        let c_plus_a = col("test1.c") + col("test1.a");
1196        let b_plus_a = col("test1.b") + col("test1.a");
1197        let common_exprs_1 = vec![
1198            (c_plus_a, format!("{CSE_PREFIX}_1")),
1199            (b_plus_a, format!("{CSE_PREFIX}_2")),
1200        ];
1201        let c_plus_a_2 = col(format!("{CSE_PREFIX}_1"));
1202        let b_plus_a_2 = col(format!("{CSE_PREFIX}_2"));
1203        let common_exprs_2 = vec![
1204            (c_plus_a_2, format!("{CSE_PREFIX}_3")),
1205            (b_plus_a_2, format!("{CSE_PREFIX}_4")),
1206        ];
1207        let project = build_common_expr_project_plan(join, common_exprs_1).unwrap();
1208        let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap();
1209
1210        let mut field_set = BTreeSet::new();
1211        for name in project_2.schema().field_names() {
1212            assert!(field_set.insert(name));
1213        }
1214    }
1215
1216    #[test]
1217    fn eliminated_subexpr_datatype() {
1218        use datafusion_expr::cast;
1219
1220        let schema = Schema::new(vec![
1221            Field::new("a", DataType::UInt64, false),
1222            Field::new("b", DataType::UInt64, false),
1223            Field::new("c", DataType::UInt64, false),
1224        ]);
1225
1226        let plan = table_scan(Some("table"), &schema, None)
1227            .unwrap()
1228            .filter(
1229                cast(col("a"), DataType::Int64)
1230                    .lt(lit(1_i64))
1231                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
1232            )
1233            .unwrap()
1234            .build()
1235            .unwrap();
1236        let rule = CommonSubexprEliminate::new();
1237        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
1238        assert!(optimized_plan.transformed);
1239        let optimized_plan = optimized_plan.data;
1240
1241        let schema = optimized_plan.schema();
1242        let fields_with_datatypes: Vec<_> = schema
1243            .fields()
1244            .iter()
1245            .map(|field| (field.name(), field.data_type()))
1246            .collect();
1247        let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
1248        let expected = r#"[
1249    (
1250        "a",
1251        UInt64,
1252    ),
1253    (
1254        "b",
1255        UInt64,
1256    ),
1257    (
1258        "c",
1259        UInt64,
1260    ),
1261]"#;
1262        assert_eq!(expected, formatted_fields_with_datatype);
1263    }
1264
1265    #[test]
1266    fn filter_schema_changed() -> Result<()> {
1267        let table_scan = test_table_scan()?;
1268
1269        let plan = LogicalPlanBuilder::from(table_scan)
1270            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
1271            .build()?;
1272
1273        assert_optimized_plan_equal!(
1274            plan,
1275            @ r"
1276        Projection: test.a, test.b, test.c
1277          Filter: __common_expr_1 - Int32(10) > __common_expr_1
1278            Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
1279              TableScan: test
1280        "
1281        )
1282    }
1283
1284    #[test]
1285    fn test_extract_expressions_from_grouping_set() -> Result<()> {
1286        let mut result = Vec::with_capacity(3);
1287        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
1288        extract_expressions(&grouping, &mut result);
1289
1290        assert!(result.len() == 3);
1291        Ok(())
1292    }
1293
1294    #[test]
1295    fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
1296        let mut result = Vec::with_capacity(2);
1297        let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
1298        extract_expressions(&grouping, &mut result);
1299        assert!(result.len() == 2);
1300        Ok(())
1301    }
1302
1303    #[test]
1304    fn test_alias_collision() -> Result<()> {
1305        let table_scan = test_table_scan()?;
1306
1307        let config = OptimizerContext::new();
1308        let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1309        let plan = LogicalPlanBuilder::from(table_scan.clone())
1310            .project(vec![
1311                (col("a") + col("b")).alias(common_expr_1.clone()),
1312                col("c"),
1313            ])?
1314            .project(vec![
1315                col(common_expr_1.clone()).alias("c1"),
1316                col(common_expr_1).alias("c2"),
1317                (col("c") + lit(2)).alias("c3"),
1318                (col("c") + lit(2)).alias("c4"),
1319            ])?
1320            .build()?;
1321
1322        assert_optimized_plan_equal!(
1323            config,
1324            plan,
1325            @ r"
1326        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
1327          Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
1328            Projection: test.a + test.b AS __common_expr_1, test.c
1329              TableScan: test
1330        "
1331        )?;
1332
1333        let config = OptimizerContext::new();
1334        let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
1335        let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
1336        let plan = LogicalPlanBuilder::from(table_scan)
1337            .project(vec![
1338                (col("a") + col("b")).alias(common_expr_2.clone()),
1339                col("c"),
1340            ])?
1341            .project(vec![
1342                col(common_expr_2.clone()).alias("c1"),
1343                col(common_expr_2).alias("c2"),
1344                (col("c") + lit(2)).alias("c3"),
1345                (col("c") + lit(2)).alias("c4"),
1346            ])?
1347            .build()?;
1348
1349        assert_optimized_plan_equal!(
1350            config,
1351            plan,
1352            @ r"
1353        Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
1354          Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
1355            Projection: test.a + test.b AS __common_expr_2, test.c
1356              TableScan: test
1357        "
1358        )?;
1359
1360        Ok(())
1361    }
1362
1363    #[test]
1364    fn test_extract_expressions_from_col() -> Result<()> {
1365        let mut result = Vec::with_capacity(1);
1366        extract_expressions(&col("a"), &mut result);
1367        assert!(result.len() == 1);
1368        Ok(())
1369    }
1370
1371    #[test]
1372    fn test_short_circuits() -> Result<()> {
1373        let table_scan = test_table_scan()?;
1374
1375        let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
1376        let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0));
1377        let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0));
1378        let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0));
1379        let plan = LogicalPlanBuilder::from(table_scan)
1380            .project(vec![
1381                extracted_short_circuit.clone().alias("c1"),
1382                extracted_short_circuit.alias("c2"),
1383                extracted_short_circuit_leg_1
1384                    .clone()
1385                    .or(not_extracted_short_circuit_leg_2.clone())
1386                    .alias("c3"),
1387                extracted_short_circuit_leg_1
1388                    .and(not_extracted_short_circuit_leg_2)
1389                    .alias("c4"),
1390                extracted_short_circuit_leg_3
1391                    .clone()
1392                    .or(extracted_short_circuit_leg_3)
1393                    .alias("c5"),
1394            ])?
1395            .build()?;
1396
1397        assert_optimized_plan_equal!(
1398            plan,
1399            @ r"
1400        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
1401          Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
1402            TableScan: test
1403        "
1404        )
1405    }
1406
1407    #[test]
1408    fn test_volatile() -> Result<()> {
1409        let table_scan = test_table_scan()?;
1410
1411        let extracted_child = col("a") + col("b");
1412        let rand = rand_func().call(vec![]);
1413        let not_extracted_volatile = extracted_child + rand;
1414        let plan = LogicalPlanBuilder::from(table_scan)
1415            .project(vec![
1416                not_extracted_volatile.clone().alias("c1"),
1417                not_extracted_volatile.alias("c2"),
1418            ])?
1419            .build()?;
1420
1421        assert_optimized_plan_equal!(
1422            plan,
1423            @ r"
1424        Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
1425          Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1426            TableScan: test
1427        "
1428        )
1429    }
1430
1431    #[test]
1432    fn test_volatile_short_circuits() -> Result<()> {
1433        let table_scan = test_table_scan()?;
1434
1435        let rand = rand_func().call(vec![]);
1436        let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
1437        let not_extracted_volatile_short_circuit_1 =
1438            extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
1439        let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
1440        let not_extracted_volatile_short_circuit_2 =
1441            rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
1442        let plan = LogicalPlanBuilder::from(table_scan)
1443            .project(vec![
1444                not_extracted_volatile_short_circuit_1.clone().alias("c1"),
1445                not_extracted_volatile_short_circuit_1.alias("c2"),
1446                not_extracted_volatile_short_circuit_2.clone().alias("c3"),
1447                not_extracted_volatile_short_circuit_2.alias("c4"),
1448            ])?
1449            .build()?;
1450
1451        assert_optimized_plan_equal!(
1452            plan,
1453            @ r"
1454        Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
1455          Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
1456            TableScan: test
1457        "
1458        )
1459    }
1460
1461    #[test]
1462    fn test_non_top_level_common_expression() -> Result<()> {
1463        let table_scan = test_table_scan()?;
1464
1465        let common_expr = col("a") + col("b");
1466        let plan = LogicalPlanBuilder::from(table_scan)
1467            .project(vec![
1468                common_expr.clone().alias("c1"),
1469                common_expr.alias("c2"),
1470            ])?
1471            .project(vec![col("c1"), col("c2")])?
1472            .build()?;
1473
1474        assert_optimized_plan_equal!(
1475            plan,
1476            @ r"
1477        Projection: c1, c2
1478          Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1479            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1480              TableScan: test
1481        "
1482        )
1483    }
1484
1485    #[test]
1486    fn test_nested_common_expression() -> Result<()> {
1487        let table_scan = test_table_scan()?;
1488
1489        let nested_common_expr = col("a") + col("b");
1490        let common_expr = nested_common_expr.clone() * nested_common_expr;
1491        let plan = LogicalPlanBuilder::from(table_scan)
1492            .project(vec![
1493                common_expr.clone().alias("c1"),
1494                common_expr.alias("c2"),
1495            ])?
1496            .build()?;
1497
1498        assert_optimized_plan_equal!(
1499            plan,
1500            @ r"
1501        Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
1502          Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
1503            Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
1504              TableScan: test
1505        "
1506        )
1507    }
1508
1509    #[test]
1510    fn test_normalize_add_expression() -> Result<()> {
1511        // a + b <=> b + a
1512        let table_scan = test_table_scan()?;
1513        let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30));
1514        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1515
1516        assert_optimized_plan_equal!(
1517            plan,
1518            @ r"
1519        Projection: test.a, test.b, test.c
1520          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1521            Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
1522              TableScan: test
1523        "
1524        )
1525    }
1526
1527    #[test]
1528    fn test_normalize_multi_expression() -> Result<()> {
1529        // a * b <=> b * a
1530        let table_scan = test_table_scan()?;
1531        let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30));
1532        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1533
1534        assert_optimized_plan_equal!(
1535            plan,
1536            @ r"
1537        Projection: test.a, test.b, test.c
1538          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1539            Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
1540              TableScan: test
1541        "
1542        )
1543    }
1544
1545    #[test]
1546    fn test_normalize_bitset_and_expression() -> Result<()> {
1547        // a & b <=> b & a
1548        let table_scan = test_table_scan()?;
1549        let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30));
1550        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1551
1552        assert_optimized_plan_equal!(
1553            plan,
1554            @ r"
1555        Projection: test.a, test.b, test.c
1556          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1557            Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
1558              TableScan: test
1559        "
1560        )
1561    }
1562
1563    #[test]
1564    fn test_normalize_bitset_or_expression() -> Result<()> {
1565        // a | b <=> b | a
1566        let table_scan = test_table_scan()?;
1567        let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30));
1568        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1569
1570        assert_optimized_plan_equal!(
1571            plan,
1572            @ r"
1573        Projection: test.a, test.b, test.c
1574          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1575            Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
1576              TableScan: test
1577        "
1578        )
1579    }
1580
1581    #[test]
1582    fn test_normalize_bitset_xor_expression() -> Result<()> {
1583        // a # b <=> b # a
1584        let table_scan = test_table_scan()?;
1585        let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30));
1586        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1587
1588        assert_optimized_plan_equal!(
1589            plan,
1590            @ r"
1591        Projection: test.a, test.b, test.c
1592          Filter: __common_expr_1 + __common_expr_1 = Int32(30)
1593            Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
1594              TableScan: test
1595        "
1596        )
1597    }
1598
1599    #[test]
1600    fn test_normalize_eq_expression() -> Result<()> {
1601        // a = b <=> b = a
1602        let table_scan = test_table_scan()?;
1603        let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a")));
1604        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1605
1606        assert_optimized_plan_equal!(
1607            plan,
1608            @ r"
1609        Projection: test.a, test.b, test.c
1610          Filter: __common_expr_1 AND __common_expr_1
1611            Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1612              TableScan: test
1613        "
1614        )
1615    }
1616
1617    #[test]
1618    fn test_normalize_ne_expression() -> Result<()> {
1619        // a != b <=> b != a
1620        let table_scan = test_table_scan()?;
1621        let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a")));
1622        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1623
1624        assert_optimized_plan_equal!(
1625            plan,
1626            @ r"
1627        Projection: test.a, test.b, test.c
1628          Filter: __common_expr_1 AND __common_expr_1
1629            Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
1630              TableScan: test
1631        "
1632        )
1633    }
1634
1635    #[test]
1636    fn test_normalize_complex_expression() -> Result<()> {
1637        // case1: a + b * c <=> b * c + a
1638        let table_scan = test_table_scan()?;
1639        let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a")))
1640            .eq(lit(30));
1641        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1642
1643        assert_optimized_plan_equal!(
1644            plan,
1645            @ r"
1646        Projection: test.a, test.b, test.c
1647          Filter: __common_expr_1 - __common_expr_1 = Int32(30)
1648            Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
1649              TableScan: test
1650        "
1651        )?;
1652
1653        // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1))
1654        let table_scan = test_table_scan()?;
1655        let expr = (((col("a") + col("b") / col("c")) * col("c"))
1656            / (col("c") * (col("b") / col("c") + col("a")))
1657            + col("a"))
1658        .eq(lit(30));
1659        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1660
1661        assert_optimized_plan_equal!(
1662            plan,
1663            @ r"
1664        Projection: test.a, test.b, test.c
1665          Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
1666            Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
1667              TableScan: test
1668        "
1669        )?;
1670
1671        // c2 / (c1 + c3) <=> c2 / (c3 + c1)
1672        let table_scan = test_table_scan()?;
1673        let expr = ((col("b") / (col("a") + col("c")))
1674            * (col("b") / (col("c") + col("a"))))
1675        .eq(lit(30));
1676        let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?;
1677        assert_optimized_plan_equal!(
1678            plan,
1679            @ r"
1680        Projection: test.a, test.b, test.c
1681          Filter: __common_expr_1 * __common_expr_1 = Int32(30)
1682            Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
1683              TableScan: test
1684        "
1685        )?;
1686
1687        Ok(())
1688    }
1689
    /// Scalar UDF stub used by the tests in this module; its only state is the
    /// [`Signature`] it reports through its `ScalarUDFImpl` implementation.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestUdf {
        /// Signature handed back by `ScalarUDFImpl::signature`.
        signature: Signature,
    }
1694
1695    impl TestUdf {
1696        pub fn new() -> Self {
1697            Self {
1698                signature: Signature::numeric(1, Volatility::Immutable),
1699            }
1700        }
1701    }
1702
1703    impl ScalarUDFImpl for TestUdf {
1704        fn as_any(&self) -> &dyn Any {
1705            self
1706        }
1707        fn name(&self) -> &str {
1708            "my_udf"
1709        }
1710
1711        fn signature(&self) -> &Signature {
1712            &self.signature
1713        }
1714
1715        fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1716            Ok(DataType::Int32)
1717        }
1718
1719        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1720            panic!("not implemented")
1721        }
1722    }
1723
1724    #[test]
1725    fn test_normalize_inner_binary_expression() -> Result<()> {
1726        // Not(a == b) <=> Not(b == a)
1727        let table_scan = test_table_scan()?;
1728        let expr1 = not(col("a").eq(col("b")));
1729        let expr2 = not(col("b").eq(col("a")));
1730        let plan = LogicalPlanBuilder::from(table_scan)
1731            .project(vec![expr1, expr2])?
1732            .build()?;
1733        assert_optimized_plan_equal!(
1734            plan,
1735            @ r"
1736        Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
1737          Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
1738            TableScan: test
1739        "
1740        )?;
1741
1742        // is_null(a == b) <=> is_null(b == a)
1743        let table_scan = test_table_scan()?;
1744        let expr1 = is_null(col("a").eq(col("b")));
1745        let expr2 = is_null(col("b").eq(col("a")));
1746        let plan = LogicalPlanBuilder::from(table_scan)
1747            .project(vec![expr1, expr2])?
1748            .build()?;
1749        assert_optimized_plan_equal!(
1750            plan,
1751            @ r"
1752        Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
1753          Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
1754            TableScan: test
1755        "
1756        )?;
1757
1758        // a + b between 0 and 10 <=> b + a between 0 and 10
1759        let table_scan = test_table_scan()?;
1760        let expr1 = (col("a") + col("b")).between(lit(0), lit(10));
1761        let expr2 = (col("b") + col("a")).between(lit(0), lit(10));
1762        let plan = LogicalPlanBuilder::from(table_scan)
1763            .project(vec![expr1, expr2])?
1764            .build()?;
1765        assert_optimized_plan_equal!(
1766            plan,
1767            @ r"
1768        Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
1769          Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1770            TableScan: test
1771        "
1772        )?;
1773
1774        // c between a + b and 10 <=> c between b + a and 10
1775        let table_scan = test_table_scan()?;
1776        let expr1 = col("c").between(col("a") + col("b"), lit(10));
1777        let expr2 = col("c").between(col("b") + col("a"), lit(10));
1778        let plan = LogicalPlanBuilder::from(table_scan)
1779            .project(vec![expr1, expr2])?
1780            .build()?;
1781        assert_optimized_plan_equal!(
1782            plan,
1783            @ r"
1784        Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
1785          Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
1786            TableScan: test
1787        "
1788        )?;
1789
1790        // function call with argument <=> function call with argument
1791        let udf = ScalarUDF::from(TestUdf::new());
1792        let table_scan = test_table_scan()?;
1793        let expr1 = udf.call(vec![col("a") + col("b")]);
1794        let expr2 = udf.call(vec![col("b") + col("a")]);
1795        let plan = LogicalPlanBuilder::from(table_scan)
1796            .project(vec![expr1, expr2])?
1797            .build()?;
1798        assert_optimized_plan_equal!(
1799            plan,
1800            @ r"
1801        Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
1802          Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
1803            TableScan: test
1804        "
1805        )
1806    }
1807
1808    /// returns a "random" function that is marked volatile (aka each invocation
1809    /// returns a different value)
1810    ///
1811    /// Does not use datafusion_functions::rand to avoid introducing a
1812    /// dependency on that crate.
1813    fn rand_func() -> ScalarUDF {
1814        ScalarUDF::new_from_impl(RandomStub::new())
1815    }
1816
    /// Backing implementation for [`rand_func`]: a scalar UDF stub whose only
    /// state is the [`Signature`] it reports.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct RandomStub {
        /// Signature handed back by `ScalarUDFImpl::signature`.
        signature: Signature,
    }
1821
1822    impl RandomStub {
1823        fn new() -> Self {
1824            Self {
1825                signature: Signature::exact(vec![], Volatility::Volatile),
1826            }
1827        }
1828    }
1829    impl ScalarUDFImpl for RandomStub {
1830        fn as_any(&self) -> &dyn Any {
1831            self
1832        }
1833
1834        fn name(&self) -> &str {
1835            "random"
1836        }
1837
1838        fn signature(&self) -> &Signature {
1839            &self.signature
1840        }
1841
1842        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1843            Ok(DataType::Float64)
1844        }
1845
1846        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1847            panic!("dummy - not implemented")
1848        }
1849    }
1850}