datafusion_optimizer/
push_down_filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`PushDownFilter`] applies filters as early as possible
19
20use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31    internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37    conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40    and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However it is not always possible to push filters down. For example, given a
/// plan that finds the top 3 values and then keeps only those that are greater
/// than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`. Filters can be moved below
/// operations that are filter-commutative with respect to them.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg = [sum(b)])`.
/// However, a filter on `sum(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to only push down **part** of a filter expression if it is
/// connected with `AND`s (more formally if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(sum(b) < 5)
///   Aggregate(group_by = [a], agg = [sum(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes handle re-writing filter expressions when they
/// are pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: [a+1 AS "b"]
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
///
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138/// For a given JOIN type, determine whether each input of the join is preserved
139/// for post-join (`WHERE` clause) filters.
140///
141/// It is only correct to push filters below a join for preserved inputs.
142///
143/// # Return Value
144/// A tuple of booleans - (left_preserved, right_preserved).
145///
146/// # "Preserved" input definition
147///
148/// We say a join side is preserved if the join returns all or a subset of the rows from
149/// the relevant side, such that each row of the output table directly maps to a row of
150/// the preserved input table. If a table is not preserved, it can provide extra null rows.
151/// That is, there may be rows in the output table that don't directly map to a row in the
152/// input table.
153///
154/// For example:
155///   - In an inner join, both sides are preserved, because each row of the output
156///     maps directly to a row from each side.
157///
158///   - In a left join, the left side is preserved (we can push predicates) but
159///     the right is not, because there may be rows in the output that don't
160///     directly map to a row in the right input (due to nulls filling where there
161///     is no match on the right).
162pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163    match join_type {
164        JoinType::Inner => (true, true),
165        JoinType::Left => (true, false),
166        JoinType::Right => (false, true),
167        JoinType::Full => (false, false),
168        // No columns from the right side of the join can be referenced in output
169        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
170        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171        // No columns from the left side of the join can be referenced in output
172        // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
173        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174    }
175}
176
177/// For a given JOIN type, determine whether each input of the join is preserved
178/// for the join condition (`ON` clause filters).
179///
180/// It is only correct to push filters below a join for preserved inputs.
181///
182/// # Return Value
183/// A tuple of booleans - (left_preserved, right_preserved).
184///
185/// See [`lr_is_preserved`] for a definition of "preserved".
186pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187    match join_type {
188        JoinType::Inner => (true, true),
189        JoinType::Left => (false, true),
190        JoinType::Right => (true, false),
191        JoinType::Full => (false, false),
192        JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193        JoinType::LeftAnti => (false, true),
194        JoinType::RightAnti => (true, false),
195        JoinType::LeftMark => (false, true),
196        JoinType::RightMark => (true, false),
197    }
198}
199
200/// Evaluates the columns referenced in the given expression to see if they refer
201/// only to the left or right columns
202#[derive(Debug)]
203struct ColumnChecker<'a> {
204    /// schema of left join input
205    left_schema: &'a DFSchema,
206    /// columns in left_schema, computed on demand
207    left_columns: Option<HashSet<Column>>,
208    /// schema of right join input
209    right_schema: &'a DFSchema,
210    /// columns in left_schema, computed on demand
211    right_columns: Option<HashSet<Column>>,
212}
213
214impl<'a> ColumnChecker<'a> {
215    fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216        Self {
217            left_schema,
218            left_columns: None,
219            right_schema,
220            right_columns: None,
221        }
222    }
223
224    /// Return true if the expression references only columns from the left side of the join
225    fn is_left_only(&mut self, predicate: &Expr) -> bool {
226        if self.left_columns.is_none() {
227            self.left_columns = Some(schema_columns(self.left_schema));
228        }
229        has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230    }
231
232    /// Return true if the expression references only columns from the right side of the join
233    fn is_right_only(&mut self, predicate: &Expr) -> bool {
234        if self.right_columns.is_none() {
235            self.right_columns = Some(schema_columns(self.right_schema));
236        }
237        has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238    }
239}
240
241/// Returns all columns in the schema
242fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243    schema
244        .iter()
245        .flat_map(|(qualifier, field)| {
246            [
247                Column::new(qualifier.cloned(), field.name()),
248                // we need to push down filter using unqualified column as well
249                Column::new_unqualified(field.name()),
250            ]
251        })
252        .collect::<HashSet<_>>()
253}
254
255/// Determine whether the predicate can evaluate as the join conditions
256fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
257    let mut is_evaluate = true;
258    predicate.apply(|expr| match expr {
259        Expr::Column(_)
260        | Expr::Literal(_, _)
261        | Expr::Placeholder(_)
262        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
263        Expr::Exists { .. }
264        | Expr::InSubquery(_)
265        | Expr::ScalarSubquery(_)
266        | Expr::OuterReferenceColumn(_, _)
267        | Expr::Unnest(_) => {
268            is_evaluate = false;
269            Ok(TreeNodeRecursion::Stop)
270        }
271        Expr::Alias(_)
272        | Expr::BinaryExpr(_)
273        | Expr::Like(_)
274        | Expr::SimilarTo(_)
275        | Expr::Not(_)
276        | Expr::IsNotNull(_)
277        | Expr::IsNull(_)
278        | Expr::IsTrue(_)
279        | Expr::IsFalse(_)
280        | Expr::IsUnknown(_)
281        | Expr::IsNotTrue(_)
282        | Expr::IsNotFalse(_)
283        | Expr::IsNotUnknown(_)
284        | Expr::Negative(_)
285        | Expr::Between(_)
286        | Expr::Case(_)
287        | Expr::Cast(_)
288        | Expr::TryCast(_)
289        | Expr::InList { .. }
290        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
291        // TODO: remove the next line after `Expr::Wildcard` is removed
292        #[expect(deprecated)]
293        Expr::AggregateFunction(_)
294        | Expr::WindowFunction(_)
295        | Expr::Wildcard { .. }
296        | Expr::GroupingSet(_)
297        | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"),
298    })?;
299    Ok(is_evaluate)
300}
301
302/// examine OR clause to see if any useful clauses can be extracted and push down.
303/// extract at least one qual from each sub clauses of OR clause, then form the quals
304/// to new OR clause as predicate.
305///
306/// # Example
307/// ```text
308/// Filter: (a = c and a < 20) or (b = d and b > 10)
309///     join/crossjoin:
310///          TableScan: projection=[a, b]
311///          TableScan: projection=[c, d]
312/// ```
313///
314/// is optimized to
315///
316/// ```text
317/// Filter: (a = c and a < 20) or (b = d and b > 10)
318///     join/crossjoin:
319///          Filter: (a < 20) or (b > 10)
320///              TableScan: projection=[a, b]
321///          TableScan: projection=[c, d]
322/// ```
323///
324/// In general, predicates of this form:
325///
326/// ```sql
327/// (A AND B) OR (C AND D)
328/// ```
329///
330/// will be transformed to one of:
331///
332/// * `((A AND B) OR (C AND D)) AND (A OR C)`
333/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)`
334/// * do nothing.
335fn extract_or_clauses_for_join<'a>(
336    filters: &'a [Expr],
337    schema: &'a DFSchema,
338) -> impl Iterator<Item = Expr> + 'a {
339    let schema_columns = schema_columns(schema);
340
341    // new formed OR clauses and their column references
342    filters.iter().filter_map(move |expr| {
343        if let Expr::BinaryExpr(BinaryExpr {
344            left,
345            op: Operator::Or,
346            right,
347        }) = expr
348        {
349            let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
350            let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
351
352            // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
353            if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
354                return Some(or(left_expr, right_expr));
355            }
356        }
357        None
358    })
359}
360
361/// extract qual from OR sub-clause.
362///
363/// A qual is extracted if it only contains set of column references in schema_columns.
364///
365/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
366/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
367///
368/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
369/// Otherwise, return None.
370///
371/// For other clause, apply the rule above to extract clause.
372fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
373    let mut predicate = None;
374
375    match expr {
376        Expr::BinaryExpr(BinaryExpr {
377            left: l_expr,
378            op: Operator::Or,
379            right: r_expr,
380        }) => {
381            let l_expr = extract_or_clause(l_expr, schema_columns);
382            let r_expr = extract_or_clause(r_expr, schema_columns);
383
384            if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
385                predicate = Some(or(l_expr, r_expr));
386            }
387        }
388        Expr::BinaryExpr(BinaryExpr {
389            left: l_expr,
390            op: Operator::And,
391            right: r_expr,
392        }) => {
393            let l_expr = extract_or_clause(l_expr, schema_columns);
394            let r_expr = extract_or_clause(r_expr, schema_columns);
395
396            match (l_expr, r_expr) {
397                (Some(l_expr), Some(r_expr)) => {
398                    predicate = Some(and(l_expr, r_expr));
399                }
400                (Some(l_expr), None) => {
401                    predicate = Some(l_expr);
402                }
403                (None, Some(r_expr)) => {
404                    predicate = Some(r_expr);
405                }
406                (None, None) => {
407                    predicate = None;
408                }
409            }
410        }
411        _ => {
412            if has_all_column_refs(expr, schema_columns) {
413                predicate = Some(expr.clone());
414            }
415        }
416    }
417
418    predicate
419}
420
421/// push down join/cross-join
422fn push_down_all_join(
423    predicates: Vec<Expr>,
424    inferred_join_predicates: Vec<Expr>,
425    mut join: Join,
426    on_filter: Vec<Expr>,
427) -> Result<Transformed<LogicalPlan>> {
428    let is_inner_join = join.join_type == JoinType::Inner;
429    // Get pushable predicates from current optimizer state
430    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
431
432    // The predicates can be divided to three categories:
433    // 1) can push through join to its children(left or right)
434    // 2) can be converted to join conditions if the join type is Inner
435    // 3) should be kept as filter conditions
436    let left_schema = join.left.schema();
437    let right_schema = join.right.schema();
438    let mut left_push = vec![];
439    let mut right_push = vec![];
440    let mut keep_predicates = vec![];
441    let mut join_conditions = vec![];
442    let mut checker = ColumnChecker::new(left_schema, right_schema);
443    for predicate in predicates {
444        if left_preserved && checker.is_left_only(&predicate) {
445            left_push.push(predicate);
446        } else if right_preserved && checker.is_right_only(&predicate) {
447            right_push.push(predicate);
448        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
449            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
450            // and convert to the join on condition
451            join_conditions.push(predicate);
452        } else {
453            keep_predicates.push(predicate);
454        }
455    }
456
457    // For infer predicates, if they can not push through join, just drop them
458    for predicate in inferred_join_predicates {
459        if left_preserved && checker.is_left_only(&predicate) {
460            left_push.push(predicate);
461        } else if right_preserved && checker.is_right_only(&predicate) {
462            right_push.push(predicate);
463        }
464    }
465
466    let mut on_filter_join_conditions = vec![];
467    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
468
469    if !on_filter.is_empty() {
470        for on in on_filter {
471            if on_left_preserved && checker.is_left_only(&on) {
472                left_push.push(on)
473            } else if on_right_preserved && checker.is_right_only(&on) {
474                right_push.push(on)
475            } else {
476                on_filter_join_conditions.push(on)
477            }
478        }
479    }
480
481    // Extract from OR clause, generate new predicates for both side of join if possible.
482    // We only track the unpushable predicates above.
483    if left_preserved {
484        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
485        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
486    }
487    if right_preserved {
488        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
489        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
490    }
491
492    // For predicates from join filter, we should check with if a join side is preserved
493    // in term of join filtering.
494    if on_left_preserved {
495        left_push.extend(extract_or_clauses_for_join(
496            &on_filter_join_conditions,
497            left_schema,
498        ));
499    }
500    if on_right_preserved {
501        right_push.extend(extract_or_clauses_for_join(
502            &on_filter_join_conditions,
503            right_schema,
504        ));
505    }
506
507    if let Some(predicate) = conjunction(left_push) {
508        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
509    }
510    if let Some(predicate) = conjunction(right_push) {
511        join.right =
512            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
513    }
514
515    // Add any new join conditions as the non join predicates
516    join_conditions.extend(on_filter_join_conditions);
517    join.filter = conjunction(join_conditions);
518
519    // wrap the join on the filter whose predicates must be kept, if any
520    let plan = LogicalPlan::Join(join);
521    let plan = if let Some(predicate) = conjunction(keep_predicates) {
522        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
523    } else {
524        plan
525    };
526    Ok(Transformed::yes(plan))
527}
528
529fn push_down_join(
530    join: Join,
531    parent_predicate: Option<&Expr>,
532) -> Result<Transformed<LogicalPlan>> {
533    // Split the parent predicate into individual conjunctive parts.
534    let predicates = parent_predicate
535        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536
537    // Extract conjunctions from the JOIN's ON filter, if present.
538    let on_filters = join
539        .filter
540        .as_ref()
541        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
542
543    // Are there any new join predicates that can be inferred from the filter expressions?
544    let inferred_join_predicates =
545        infer_join_predicates(&join, &predicates, &on_filters)?;
546
547    if on_filters.is_empty()
548        && predicates.is_empty()
549        && inferred_join_predicates.is_empty()
550    {
551        return Ok(Transformed::no(LogicalPlan::Join(join)));
552    }
553
554    push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
555}
556
557/// Extracts any equi-join join predicates from the given filter expressions.
558///
559/// Parameters
560/// * `join` the join in question
561///
562/// * `predicates` the pushed down filter expression
563///
564/// * `on_filters` filters from the join ON clause that have not already been
565///   identified as join predicates
566fn infer_join_predicates(
567    join: &Join,
568    predicates: &[Expr],
569    on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571    // Only allow both side key is column.
572    let join_col_keys = join
573        .on
574        .iter()
575        .filter_map(|(l, r)| {
576            let left_col = l.try_as_col()?;
577            let right_col = r.try_as_col()?;
578            Some((left_col, right_col))
579        })
580        .collect::<Vec<_>>();
581
582    let join_type = join.join_type;
583
584    let mut inferred_predicates = InferredPredicates::new(join_type);
585
586    infer_join_predicates_from_predicates(
587        &join_col_keys,
588        predicates,
589        &mut inferred_predicates,
590    )?;
591
592    infer_join_predicates_from_on_filters(
593        &join_col_keys,
594        join_type,
595        on_filters,
596        &mut inferred_predicates,
597    )?;
598
599    Ok(inferred_predicates.predicates)
600}
601
602/// Inferred predicates collector.
603/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly
604/// filter out NULL, otherwise ignore it. e.g.
605/// ```text
606/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL;
607/// ```
608/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to
609/// the left side, resulting in the wrong result.
610struct InferredPredicates {
611    predicates: Vec<Expr>,
612    is_inner_join: bool,
613}
614
615impl InferredPredicates {
616    fn new(join_type: JoinType) -> Self {
617        Self {
618            predicates: vec![],
619            is_inner_join: matches!(join_type, JoinType::Inner),
620        }
621    }
622
623    fn try_build_predicate(
624        &mut self,
625        predicate: Expr,
626        replace_map: &HashMap<&Column, &Column>,
627    ) -> Result<()> {
628        if self.is_inner_join
629            || matches!(
630                is_restrict_null_predicate(
631                    predicate.clone(),
632                    replace_map.keys().cloned()
633                ),
634                Ok(true)
635            )
636        {
637            self.predicates.push(replace_col(predicate, replace_map)?);
638        }
639
640        Ok(())
641    }
642}
643
644/// Infer predicates from the pushed down predicates.
645///
646/// Parameters
647/// * `join_col_keys` column pairs from the join ON clause
648///
649/// * `predicates` the pushed down predicates
650///
651/// * `inferred_predicates` the inferred results
652fn infer_join_predicates_from_predicates(
653    join_col_keys: &[(&Column, &Column)],
654    predicates: &[Expr],
655    inferred_predicates: &mut InferredPredicates,
656) -> Result<()> {
657    infer_join_predicates_impl::<true, true>(
658        join_col_keys,
659        predicates,
660        inferred_predicates,
661    )
662}
663
664/// Infer predicates from the join filter.
665///
666/// Parameters
667/// * `join_col_keys` column pairs from the join ON clause
668///
669/// * `join_type` the JoinType of Join
670///
671/// * `on_filters` filters from the join ON clause that have not already been
672///   identified as join predicates
673///
674/// * `inferred_predicates` the inferred results
675fn infer_join_predicates_from_on_filters(
676    join_col_keys: &[(&Column, &Column)],
677    join_type: JoinType,
678    on_filters: &[Expr],
679    inferred_predicates: &mut InferredPredicates,
680) -> Result<()> {
681    match join_type {
682        JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
683        JoinType::Inner => infer_join_predicates_impl::<true, true>(
684            join_col_keys,
685            on_filters,
686            inferred_predicates,
687        ),
688        JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
689            infer_join_predicates_impl::<true, false>(
690                join_col_keys,
691                on_filters,
692                inferred_predicates,
693            )
694        }
695        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
696            infer_join_predicates_impl::<false, true>(
697                join_col_keys,
698                on_filters,
699                inferred_predicates,
700            )
701        }
702    }
703}
704
705/// Infer predicates from the given predicates.
706///
707/// Parameters
708/// * `join_col_keys` column pairs from the join ON clause
709///
710/// * `input_predicates` the given predicates. It can be the pushed down predicates,
711///   or it can be the filters of the Join
712///
713/// * `inferred_predicates` the inferred results
714///
715/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can
716///   be inferred from the left table related predicate
717///
718/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can
719///   be inferred from the right table related predicate
720fn infer_join_predicates_impl<
721    const ENABLE_LEFT_TO_RIGHT: bool,
722    const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724    join_col_keys: &[(&Column, &Column)],
725    input_predicates: &[Expr],
726    inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728    for predicate in input_predicates {
729        let mut join_cols_to_replace = HashMap::new();
730
731        for &col in &predicate.column_refs() {
732            for (l, r) in join_col_keys.iter() {
733                if ENABLE_LEFT_TO_RIGHT && col == *l {
734                    join_cols_to_replace.insert(col, *r);
735                    break;
736                }
737                if ENABLE_RIGHT_TO_LEFT && col == *r {
738                    join_cols_to_replace.insert(col, *l);
739                    break;
740                }
741            }
742        }
743        if join_cols_to_replace.is_empty() {
744            continue;
745        }
746
747        inferred_predicates
748            .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749    }
750    Ok(())
751}
752
753impl OptimizerRule for PushDownFilter {
754    fn name(&self) -> &str {
755        "push_down_filter"
756    }
757
758    fn apply_order(&self) -> Option<ApplyOrder> {
759        Some(ApplyOrder::TopDown)
760    }
761
762    fn supports_rewrite(&self) -> bool {
763        true
764    }
765
766    fn rewrite(
767        &self,
768        plan: LogicalPlan,
769        _config: &dyn OptimizerConfig,
770    ) -> Result<Transformed<LogicalPlan>> {
771        if let LogicalPlan::Join(join) = plan {
772            return push_down_join(join, None);
773        };
774
775        let plan_schema = Arc::clone(plan.schema());
776
777        let LogicalPlan::Filter(mut filter) = plan else {
778            return Ok(Transformed::no(plan));
779        };
780
781        let predicate = split_conjunction_owned(filter.predicate.clone());
782        let old_predicate_len = predicate.len();
783        let new_predicates = simplify_predicates(predicate)?;
784        if old_predicate_len != new_predicates.len() {
785            let Some(new_predicate) = conjunction(new_predicates) else {
786                // new_predicates is empty - remove the filter entirely
787                // Return the child plan without the filter
788                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
789            };
790            filter.predicate = new_predicate;
791        }
792
793        match Arc::unwrap_or_clone(filter.input) {
794            LogicalPlan::Filter(child_filter) => {
795                let parents_predicates = split_conjunction_owned(filter.predicate);
796
797                // remove duplicated filters
798                let child_predicates = split_conjunction_owned(child_filter.predicate);
799                let new_predicates = parents_predicates
800                    .into_iter()
801                    .chain(child_predicates)
802                    // use IndexSet to remove dupes while preserving predicate order
803                    .collect::<IndexSet<_>>()
804                    .into_iter()
805                    .collect::<Vec<_>>();
806
807                let Some(new_predicate) = conjunction(new_predicates) else {
808                    return plan_err!("at least one expression exists");
809                };
810                let new_filter = LogicalPlan::Filter(Filter::try_new(
811                    new_predicate,
812                    child_filter.input,
813                )?);
814                #[allow(clippy::used_underscore_binding)]
815                self.rewrite(new_filter, _config)
816            }
817            LogicalPlan::Repartition(repartition) => {
818                let new_filter =
819                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
820                        .map(LogicalPlan::Filter)?;
821                insert_below(LogicalPlan::Repartition(repartition), new_filter)
822            }
823            LogicalPlan::Distinct(distinct) => {
824                let new_filter =
825                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
826                        .map(LogicalPlan::Filter)?;
827                insert_below(LogicalPlan::Distinct(distinct), new_filter)
828            }
829            LogicalPlan::Sort(sort) => {
830                let new_filter =
831                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
832                        .map(LogicalPlan::Filter)?;
833                insert_below(LogicalPlan::Sort(sort), new_filter)
834            }
835            LogicalPlan::SubqueryAlias(subquery_alias) => {
836                let mut replace_map = HashMap::new();
837                for (i, (qualifier, field)) in
838                    subquery_alias.input.schema().iter().enumerate()
839                {
840                    let (sub_qualifier, sub_field) =
841                        subquery_alias.schema.qualified_field(i);
842                    replace_map.insert(
843                        qualified_name(sub_qualifier, sub_field.name()),
844                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
845                    );
846                }
847                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;
848
849                let new_filter = LogicalPlan::Filter(Filter::try_new(
850                    new_predicate,
851                    Arc::clone(&subquery_alias.input),
852                )?);
853                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
854            }
855            LogicalPlan::Projection(projection) => {
856                let predicates = split_conjunction_owned(filter.predicate.clone());
857                let (new_projection, keep_predicate) =
858                    rewrite_projection(predicates, projection)?;
859                if new_projection.transformed {
860                    match keep_predicate {
861                        None => Ok(new_projection),
862                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
863                            Filter::try_new(keep_predicate, Arc::new(child_plan))
864                                .map(LogicalPlan::Filter)
865                        }),
866                    }
867                } else {
868                    filter.input = Arc::new(new_projection.data);
869                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
870                }
871            }
872            LogicalPlan::Unnest(mut unnest) => {
873                let predicates = split_conjunction_owned(filter.predicate.clone());
874                let mut non_unnest_predicates = vec![];
875                let mut unnest_predicates = vec![];
876                let mut unnest_struct_columns = vec![];
877
878                for idx in &unnest.struct_type_columns {
879                    let (sub_qualifier, field) =
880                        unnest.input.schema().qualified_field(*idx);
881                    let field_name = field.name().clone();
882
883                    if let DataType::Struct(children) = field.data_type() {
884                        for child in children {
885                            let child_name = child.name().clone();
886                            unnest_struct_columns.push(Column::new(
887                                sub_qualifier.cloned(),
888                                format!("{field_name}.{child_name}"),
889                            ));
890                        }
891                    }
892                }
893
894                for predicate in predicates {
895                    // collect all the Expr::Column in predicate recursively
896                    let mut accum: HashSet<Column> = HashSet::new();
897                    expr_to_columns(&predicate, &mut accum)?;
898
899                    let contains_list_columns =
900                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
901                            accum.contains(&unnest_list.output_column)
902                        });
903                    let contains_struct_columns =
904                        unnest_struct_columns.iter().any(|c| accum.contains(c));
905
906                    if contains_list_columns || contains_struct_columns {
907                        unnest_predicates.push(predicate);
908                    } else {
909                        non_unnest_predicates.push(predicate);
910                    }
911                }
912
913                // Unnest predicates should not be pushed down.
914                // If no non-unnest predicates exist, early return
915                if non_unnest_predicates.is_empty() {
916                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
917                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
918                }
919
920                // Push down non-unnest filter predicate
921                // Unnest
922                //   Unnest Input (Projection)
923                // -> rewritten to
924                // Unnest
925                //   Filter
926                //     Unnest Input (Projection)
927
928                let unnest_input = std::mem::take(&mut unnest.input);
929
930                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
931                    conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
932                    unnest_input,
933                )?);
934
935                // Directly assign new filter plan as the new unnest's input.
936                // The new filter plan will go through another rewrite pass since the rule itself
937                // is applied recursively to all the child from top to down
938                let unnest_plan =
939                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;
940
941                match conjunction(unnest_predicates) {
942                    None => Ok(unnest_plan),
943                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
944                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
945                    ))),
946                }
947            }
948            LogicalPlan::Union(ref union) => {
949                let mut inputs = Vec::with_capacity(union.inputs.len());
950                for input in &union.inputs {
951                    let mut replace_map = HashMap::new();
952                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
953                        let (union_qualifier, union_field) =
954                            union.schema.qualified_field(i);
955                        replace_map.insert(
956                            qualified_name(union_qualifier, union_field.name()),
957                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
958                        );
959                    }
960
961                    let push_predicate =
962                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
963                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
964                        push_predicate,
965                        Arc::clone(input),
966                    )?)))
967                }
968                Ok(Transformed::yes(LogicalPlan::Union(Union {
969                    inputs,
970                    schema: Arc::clone(&plan_schema),
971                })))
972            }
973            LogicalPlan::Aggregate(agg) => {
974                // We can push down Predicate which in groupby_expr.
975                let group_expr_columns = agg
976                    .group_expr
977                    .iter()
978                    .map(|e| {
979                        let (relation, name) = e.qualified_name();
980                        Column::new(relation, name)
981                    })
982                    .collect::<HashSet<_>>();
983
984                let predicates = split_conjunction_owned(filter.predicate);
985
986                let mut keep_predicates = vec![];
987                let mut push_predicates = vec![];
988                for expr in predicates {
989                    let cols = expr.column_refs();
990                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
991                        push_predicates.push(expr);
992                    } else {
993                        keep_predicates.push(expr);
994                    }
995                }
996
997                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
998                // After push, we need to replace `a+b` with Column(a)+Column(b)
999                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
1000                let mut replace_map = HashMap::new();
1001                for expr in &agg.group_expr {
1002                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
1003                }
1004                let replaced_push_predicates = push_predicates
1005                    .into_iter()
1006                    .map(|expr| replace_cols_by_name(expr, &replace_map))
1007                    .collect::<Result<Vec<_>>>()?;
1008
1009                let agg_input = Arc::clone(&agg.input);
1010                Transformed::yes(LogicalPlan::Aggregate(agg))
1011                    .transform_data(|new_plan| {
1012                        // If we have a filter to push, we push it down to the input of the aggregate
1013                        if let Some(predicate) = conjunction(replaced_push_predicates) {
1014                            let new_filter = make_filter(predicate, agg_input)?;
1015                            insert_below(new_plan, new_filter)
1016                        } else {
1017                            Ok(Transformed::no(new_plan))
1018                        }
1019                    })?
1020                    .map_data(|child_plan| {
1021                        // if there are any remaining predicates we can't push, add them
1022                        // back as a filter
1023                        if let Some(predicate) = conjunction(keep_predicates) {
1024                            make_filter(predicate, Arc::new(child_plan))
1025                        } else {
1026                            Ok(child_plan)
1027                        }
1028                    })
1029            }
1030            // Tries to push filters based on the partition key(s) of the window function(s) used.
1031            // Example:
1032            //   Before:
1033            //     Filter: (a > 1) and (b > 1) and (c > 1)
1034            //      Window: func() PARTITION BY [a] ...
1035            //   ---
1036            //   After:
1037            //     Filter: (b > 1) and (c > 1)
1038            //      Window: func() PARTITION BY [a] ...
1039            //        Filter: (a > 1)
1040            LogicalPlan::Window(window) => {
1041                // Retrieve the set of potential partition keys where we can push filters by.
1042                // Unlike aggregations, where there is only one statement per SELECT, there can be
1043                // multiple window functions, each with potentially different partition keys.
1044                // Therefore, we need to ensure that any potential partition key returned is used in
1045                // ALL window functions. Otherwise, filters cannot be pushed by through that column.
1046                let extract_partition_keys = |func: &WindowFunction| {
1047                    func.params
1048                        .partition_by
1049                        .iter()
1050                        .map(|c| {
1051                            let (relation, name) = c.qualified_name();
1052                            Column::new(relation, name)
1053                        })
1054                        .collect::<HashSet<_>>()
1055                };
1056                let potential_partition_keys = window
1057                    .window_expr
1058                    .iter()
1059                    .map(|e| {
1060                        match e {
1061                            Expr::WindowFunction(window_func) => {
1062                                extract_partition_keys(window_func)
1063                            }
1064                            Expr::Alias(alias) => {
1065                                if let Expr::WindowFunction(window_func) =
1066                                    alias.expr.as_ref()
1067                                {
1068                                    extract_partition_keys(window_func)
1069                                } else {
1070                                    // window functions expressions are only Expr::WindowFunction
1071                                    unreachable!()
1072                                }
1073                            }
1074                            _ => {
1075                                // window functions expressions are only Expr::WindowFunction
1076                                unreachable!()
1077                            }
1078                        }
1079                    })
1080                    // performs the set intersection of the partition keys of all window functions,
1081                    // returning only the common ones
1082                    .reduce(|a, b| &a & &b)
1083                    .unwrap_or_default();
1084
1085                let predicates = split_conjunction_owned(filter.predicate);
1086                let mut keep_predicates = vec![];
1087                let mut push_predicates = vec![];
1088                for expr in predicates {
1089                    let cols = expr.column_refs();
1090                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
1091                        push_predicates.push(expr);
1092                    } else {
1093                        keep_predicates.push(expr);
1094                    }
1095                }
1096
1097                // Unlike with aggregations, there are no cases where we have to replace, e.g.,
1098                // `a+b` with Column(a)+Column(b). This is because partition expressions are not
1099                // available as standalone columns to the user. For example, while an aggregation on
1100                // `a+b` becomes Column(a + b), in a window partition it becomes
1101                // `func() PARTITION BY [a + b] ...`. Thus, filters on expressions always remain in
1102                // place, so we can use `push_predicates` directly. This is consistent with other
1103                // optimizers, such as the one used by Postgres.
1104
1105                let window_input = Arc::clone(&window.input);
1106                Transformed::yes(LogicalPlan::Window(window))
1107                    .transform_data(|new_plan| {
1108                        // If we have a filter to push, we push it down to the input of the window
1109                        if let Some(predicate) = conjunction(push_predicates) {
1110                            let new_filter = make_filter(predicate, window_input)?;
1111                            insert_below(new_plan, new_filter)
1112                        } else {
1113                            Ok(Transformed::no(new_plan))
1114                        }
1115                    })?
1116                    .map_data(|child_plan| {
1117                        // if there are any remaining predicates we can't push, add them
1118                        // back as a filter
1119                        if let Some(predicate) = conjunction(keep_predicates) {
1120                            make_filter(predicate, Arc::new(child_plan))
1121                        } else {
1122                            Ok(child_plan)
1123                        }
1124                    })
1125            }
1126            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
1127            LogicalPlan::TableScan(scan) => {
1128                let filter_predicates = split_conjunction(&filter.predicate);
1129
1130                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
1131                    filter_predicates
1132                        .into_iter()
1133                        .partition(|pred| pred.is_volatile());
1134
1135                // Check which non-volatile filters are supported by source
1136                let supported_filters = scan
1137                    .source
1138                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
1139                if non_volatile_filters.len() != supported_filters.len() {
1140                    return internal_err!(
1141                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
1142                        supported_filters.len(),
1143                        non_volatile_filters.len());
1144                }
1145
1146                // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type
1147                let zip = non_volatile_filters.into_iter().zip(supported_filters);
1148
1149                let new_scan_filters = zip
1150                    .clone()
1151                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
1152                    .map(|(pred, _)| pred);
1153
1154                // Add new scan filters
1155                let new_scan_filters: Vec<Expr> = scan
1156                    .filters
1157                    .iter()
1158                    .chain(new_scan_filters)
1159                    .unique()
1160                    .cloned()
1161                    .collect();
1162
1163                // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters
1164                let new_predicate: Vec<Expr> = zip
1165                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
1166                    .map(|(pred, _)| pred)
1167                    .chain(volatile_filters)
1168                    .cloned()
1169                    .collect();
1170
1171                let new_scan = LogicalPlan::TableScan(TableScan {
1172                    filters: new_scan_filters,
1173                    ..scan
1174                });
1175
1176                Transformed::yes(new_scan).transform_data(|new_scan| {
1177                    if let Some(predicate) = conjunction(new_predicate) {
1178                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
1179                    } else {
1180                        Ok(Transformed::no(new_scan))
1181                    }
1182                })
1183            }
1184            LogicalPlan::Extension(extension_plan) => {
1185                // This check prevents the Filter from being removed when the extension node has no children,
1186                // so we return the original Filter unchanged.
1187                if extension_plan.node.inputs().is_empty() {
1188                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1189                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1190                }
1191                let prevent_cols =
1192                    extension_plan.node.prevent_predicate_push_down_columns();
1193
1194                // determine if we can push any predicates down past the extension node
1195
1196                // each element is true for push, false to keep
1197                let predicate_push_or_keep = split_conjunction(&filter.predicate)
1198                    .iter()
1199                    .map(|expr| {
1200                        let cols = expr.column_refs();
1201                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
1202                            Ok(false) // No push (keep)
1203                        } else {
1204                            Ok(true) // push
1205                        }
1206                    })
1207                    .collect::<Result<Vec<_>>>()?;
1208
1209                // all predicates are kept, no changes needed
1210                if predicate_push_or_keep.iter().all(|&x| !x) {
1211                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
1212                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
1213                }
1214
1215                // going to push some predicates down, so split the predicates
1216                let mut keep_predicates = vec![];
1217                let mut push_predicates = vec![];
1218                for (push, expr) in predicate_push_or_keep
1219                    .into_iter()
1220                    .zip(split_conjunction_owned(filter.predicate).into_iter())
1221                {
1222                    if !push {
1223                        keep_predicates.push(expr);
1224                    } else {
1225                        push_predicates.push(expr);
1226                    }
1227                }
1228
1229                let new_children = match conjunction(push_predicates) {
1230                    Some(predicate) => extension_plan
1231                        .node
1232                        .inputs()
1233                        .into_iter()
1234                        .map(|child| {
1235                            Ok(LogicalPlan::Filter(Filter::try_new(
1236                                predicate.clone(),
1237                                Arc::new(child.clone()),
1238                            )?))
1239                        })
1240                        .collect::<Result<Vec<_>>>()?,
1241                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
1242                };
1243                // extension with new inputs.
1244                let child_plan = LogicalPlan::Extension(extension_plan);
1245                let new_extension =
1246                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;
1247
1248                let new_plan = match conjunction(keep_predicates) {
1249                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
1250                        predicate,
1251                        Arc::new(new_extension),
1252                    )?),
1253                    None => new_extension,
1254                };
1255                Ok(Transformed::yes(new_plan))
1256            }
1257            child => {
1258                filter.input = Arc::new(child);
1259                Ok(Transformed::no(LogicalPlan::Filter(filter)))
1260            }
1261        }
1262    }
1263}
1264
1265/// Attempts to push `predicate` into a `FilterExec` below `projection
1266///
1267/// # Returns
1268/// (plan, remaining_predicate)
1269///
1270/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it.
1271/// `remaining_predicate` is any part of the predicate that could not be pushed down
1272///
1273/// # Args
1274/// - predicates: Split predicates like `[foo=5, bar=6]`
1275/// - projection: The target projection plan to push down the predicates
1276///
1277/// # Example
1278///
1279/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this:
1280///
1281/// ```text
1282/// Projection(foo, c+d as bar)
1283/// ```
1284///
1285/// Might result in returning `remaining_predicate` of `bar=6` and a plan like
1286///
1287/// ```text
1288/// Projection(foo, c+d as bar)
1289///  Filter(foo=5)
1290///   ...
1291/// ```
1292fn rewrite_projection(
1293    predicates: Vec<Expr>,
1294    mut projection: Projection,
1295) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1296    // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
1297    // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
1298    // collect projection.
1299    let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1300        .schema
1301        .iter()
1302        .zip(projection.expr.iter())
1303        .map(|((qualifier, field), expr)| {
1304            // strip alias, as they should not be part of filters
1305            let expr = expr.clone().unalias();
1306
1307            (qualified_name(qualifier, field.name()), expr)
1308        })
1309        .partition(|(_, value)| value.is_volatile());
1310
1311    let mut push_predicates = vec![];
1312    let mut keep_predicates = vec![];
1313    for expr in predicates {
1314        if contain(&expr, &volatile_map) {
1315            keep_predicates.push(expr);
1316        } else {
1317            push_predicates.push(expr);
1318        }
1319    }
1320
1321    match conjunction(push_predicates) {
1322        Some(expr) => {
1323            // re-write all filters based on this projection
1324            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
1325            let new_filter = LogicalPlan::Filter(Filter::try_new(
1326                replace_cols_by_name(expr, &non_volatile_map)?,
1327                std::mem::take(&mut projection.input),
1328            )?);
1329
1330            projection.input = Arc::new(new_filter);
1331
1332            Ok((
1333                Transformed::yes(LogicalPlan::Projection(projection)),
1334                conjunction(keep_predicates),
1335            ))
1336        }
1337        None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1338    }
1339}
1340
1341/// Creates a new LogicalPlan::Filter node.
1342pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1343    Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1344}
1345
1346/// Replace the existing child of the single input node with `new_child`.
1347///
1348/// Starting:
1349/// ```text
1350/// plan
1351///   child
1352/// ```
1353///
1354/// Ending:
1355/// ```text
1356/// plan
1357///   new_child
1358/// ```
1359fn insert_below(
1360    plan: LogicalPlan,
1361    new_child: LogicalPlan,
1362) -> Result<Transformed<LogicalPlan>> {
1363    let mut new_child = Some(new_child);
1364    let transformed_plan = plan.map_children(|_child| {
1365        if let Some(new_child) = new_child.take() {
1366            Ok(Transformed::yes(new_child))
1367        } else {
1368            // already took the new child
1369            internal_err!("node had more than one input")
1370        }
1371    })?;
1372
1373    // make sure we did the actual replacement
1374    if new_child.is_some() {
1375        return internal_err!("node had no  inputs");
1376    }
1377
1378    Ok(transformed_plan)
1379}
1380
1381impl PushDownFilter {
1382    #[allow(missing_docs)]
1383    pub fn new() -> Self {
1384        Self {}
1385    }
1386}
1387
1388/// replaces columns by its name on the projection.
1389pub fn replace_cols_by_name(
1390    e: Expr,
1391    replace_map: &HashMap<String, Expr>,
1392) -> Result<Expr> {
1393    e.transform_up_with_lambdas_params(|expr, lambdas_params| {
1394        Ok(match &expr {
1395            Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1396                match replace_map.get(&c.flat_name()) {
1397                    Some(new_c) => Transformed::yes(new_c.clone()),
1398                    None => Transformed::no(expr),
1399                }
1400            }
1401            _ => Transformed::no(expr),
1402        })
1403    })
1404    .data()
1405}
1406
1407/// check whether the expression uses the columns in `check_map`.
1408fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1409    let mut is_contain = false;
1410    e.apply_with_lambdas_params(|expr, lambdas_params| {
1411        Ok(match &expr {
1412            Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1413                match check_map.get(&c.flat_name()) {
1414                    Some(_) => {
1415                        is_contain = true;
1416                        TreeNodeRecursion::Stop
1417                    }
1418                    None => TreeNodeRecursion::Continue,
1419                }
1420            }
1421            _ => TreeNodeRecursion::Continue,
1422        })
1423    })
1424    .unwrap();
1425    is_contain
1426}
1427
1428#[cfg(test)]
1429mod tests {
1430    use std::any::Any;
1431    use std::cmp::Ordering;
1432    use std::fmt::{Debug, Formatter};
1433
1434    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1435    use async_trait::async_trait;
1436
1437    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1438    use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1439    use datafusion_expr::logical_plan::table_scan;
1440    use datafusion_expr::{
1441        col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1442        LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1443        TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1444        WindowFunctionDefinition,
1445    };
1446
1447    use crate::assert_optimized_plan_eq_snapshot;
1448    use crate::optimizer::Optimizer;
1449    use crate::simplify_expressions::SimplifyExpressions;
1450    use crate::test::*;
1451    use crate::OptimizerContext;
1452    use datafusion_expr::test::function_stub::sum;
1453    use insta::assert_snapshot;
1454
1455    use super::*;
1456
1457    fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1458
1459    macro_rules! assert_optimized_plan_equal {
1460        (
1461            $plan:expr,
1462            @ $expected:literal $(,)?
1463        ) => {{
1464            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1465            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1466            assert_optimized_plan_eq_snapshot!(
1467                optimizer_ctx,
1468                rules,
1469                $plan,
1470                @ $expected,
1471            )
1472        }};
1473    }
1474
1475    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1476        (
1477            $plan:expr,
1478            @ $expected:literal $(,)?
1479        ) => {{
1480            let optimizer = Optimizer::with_rules(vec![
1481                Arc::new(SimplifyExpressions::new()),
1482                Arc::new(PushDownFilter::new()),
1483            ]);
1484            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1485            assert_snapshot!(optimized_plan, @ $expected);
1486            Ok::<(), DataFusionError>(())
1487        }};
1488    }
1489
1490    #[test]
1491    fn filter_before_projection() -> Result<()> {
1492        let table_scan = test_table_scan()?;
1493        let plan = LogicalPlanBuilder::from(table_scan)
1494            .project(vec![col("a"), col("b")])?
1495            .filter(col("a").eq(lit(1i64)))?
1496            .build()?;
1497        // filter is before projection
1498        assert_optimized_plan_equal!(
1499            plan,
1500            @r"
1501        Projection: test.a, test.b
1502          TableScan: test, full_filters=[test.a = Int64(1)]
1503        "
1504        )
1505    }
1506
1507    #[test]
1508    fn filter_after_limit() -> Result<()> {
1509        let table_scan = test_table_scan()?;
1510        let plan = LogicalPlanBuilder::from(table_scan)
1511            .project(vec![col("a"), col("b")])?
1512            .limit(0, Some(10))?
1513            .filter(col("a").eq(lit(1i64)))?
1514            .build()?;
1515        // filter is before single projection
1516        assert_optimized_plan_equal!(
1517            plan,
1518            @r"
1519        Filter: test.a = Int64(1)
1520          Limit: skip=0, fetch=10
1521            Projection: test.a, test.b
1522              TableScan: test
1523        "
1524        )
1525    }
1526
1527    #[test]
1528    fn filter_no_columns() -> Result<()> {
1529        let table_scan = test_table_scan()?;
1530        let plan = LogicalPlanBuilder::from(table_scan)
1531            .filter(lit(0i64).eq(lit(1i64)))?
1532            .build()?;
1533        assert_optimized_plan_equal!(
1534            plan,
1535            @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1536        )
1537    }
1538
1539    #[test]
1540    fn filter_jump_2_plans() -> Result<()> {
1541        let table_scan = test_table_scan()?;
1542        let plan = LogicalPlanBuilder::from(table_scan)
1543            .project(vec![col("a"), col("b"), col("c")])?
1544            .project(vec![col("c"), col("b")])?
1545            .filter(col("a").eq(lit(1i64)))?
1546            .build()?;
1547        // filter is before double projection
1548        assert_optimized_plan_equal!(
1549            plan,
1550            @r"
1551        Projection: test.c, test.b
1552          Projection: test.a, test.b, test.c
1553            TableScan: test, full_filters=[test.a = Int64(1)]
1554        "
1555        )
1556    }
1557
1558    #[test]
1559    fn filter_move_agg() -> Result<()> {
1560        let table_scan = test_table_scan()?;
1561        let plan = LogicalPlanBuilder::from(table_scan)
1562            .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1563            .filter(col("a").gt(lit(10i64)))?
1564            .build()?;
1565        // filter of key aggregation is commutative
1566        assert_optimized_plan_equal!(
1567            plan,
1568            @r"
1569        Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1570          TableScan: test, full_filters=[test.a > Int64(10)]
1571        "
1572        )
1573    }
1574
1575    /// verifies that filters with unusual column names are pushed down through aggregate operators
1576    #[test]
1577    fn filter_move_agg_special() -> Result<()> {
1578        let schema = Schema::new(vec![
1579            Field::new("$a", DataType::UInt32, false),
1580            Field::new("$b", DataType::UInt32, false),
1581            Field::new("$c", DataType::UInt32, false),
1582        ]);
1583        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1584
1585        let plan = LogicalPlanBuilder::from(table_scan)
1586            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1587            .filter(col("$a").gt(lit(10i64)))?
1588            .build()?;
1589        // filter of key aggregation is commutative
1590        assert_optimized_plan_equal!(
1591            plan,
1592            @r"
1593        Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1594          TableScan: test, full_filters=[test.$a > Int64(10)]
1595        "
1596        )
1597    }
1598
1599    #[test]
1600    fn filter_complex_group_by() -> Result<()> {
1601        let table_scan = test_table_scan()?;
1602        let plan = LogicalPlanBuilder::from(table_scan)
1603            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1604            .filter(col("b").gt(lit(10i64)))?
1605            .build()?;
1606        assert_optimized_plan_equal!(
1607            plan,
1608            @r"
1609        Filter: test.b > Int64(10)
1610          Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1611            TableScan: test
1612        "
1613        )
1614    }
1615
1616    #[test]
1617    fn push_agg_need_replace_expr() -> Result<()> {
1618        let plan = LogicalPlanBuilder::from(test_table_scan()?)
1619            .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1620            .filter(col("test.b + test.a").gt(lit(10i64)))?
1621            .build()?;
1622        assert_optimized_plan_equal!(
1623            plan,
1624            @r"
1625        Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1626          TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1627        "
1628        )
1629    }
1630
1631    #[test]
1632    fn filter_keep_agg() -> Result<()> {
1633        let table_scan = test_table_scan()?;
1634        let plan = LogicalPlanBuilder::from(table_scan)
1635            .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1636            .filter(col("b").gt(lit(10i64)))?
1637            .build()?;
1638        // filter of aggregate is after aggregation since they are non-commutative
1639        assert_optimized_plan_equal!(
1640            plan,
1641            @r"
1642        Filter: b > Int64(10)
1643          Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1644            TableScan: test
1645        "
1646        )
1647    }
1648
1649    /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed
1650    #[test]
1651    fn filter_move_window() -> Result<()> {
1652        let table_scan = test_table_scan()?;
1653
1654        let window = Expr::from(WindowFunction::new(
1655            WindowFunctionDefinition::WindowUDF(
1656                datafusion_functions_window::rank::rank_udwf(),
1657            ),
1658            vec![],
1659        ))
1660        .partition_by(vec![col("a"), col("b")])
1661        .order_by(vec![col("c").sort(true, true)])
1662        .build()
1663        .unwrap();
1664
1665        let plan = LogicalPlanBuilder::from(table_scan)
1666            .window(vec![window])?
1667            .filter(col("b").gt(lit(10i64)))?
1668            .build()?;
1669
1670        assert_optimized_plan_equal!(
1671            plan,
1672            @r"
1673        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1674          TableScan: test, full_filters=[test.b > Int64(10)]
1675        "
1676        )
1677    }
1678
1679    /// verifies that filters with unusual identifier names are pushed down through window functions
1680    #[test]
1681    fn filter_window_special_identifier() -> Result<()> {
1682        let schema = Schema::new(vec![
1683            Field::new("$a", DataType::UInt32, false),
1684            Field::new("$b", DataType::UInt32, false),
1685            Field::new("$c", DataType::UInt32, false),
1686        ]);
1687        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1688
1689        let window = Expr::from(WindowFunction::new(
1690            WindowFunctionDefinition::WindowUDF(
1691                datafusion_functions_window::rank::rank_udwf(),
1692            ),
1693            vec![],
1694        ))
1695        .partition_by(vec![col("$a"), col("$b")])
1696        .order_by(vec![col("$c").sort(true, true)])
1697        .build()
1698        .unwrap();
1699
1700        let plan = LogicalPlanBuilder::from(table_scan)
1701            .window(vec![window])?
1702            .filter(col("$b").gt(lit(10i64)))?
1703            .build()?;
1704
1705        assert_optimized_plan_equal!(
1706            plan,
1707            @r"
1708        WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1709          TableScan: test, full_filters=[test.$b > Int64(10)]
1710        "
1711        )
1712    }
1713
1714    /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
1715    /// 'b' are pushed
1716    #[test]
1717    fn filter_move_complex_window() -> Result<()> {
1718        let table_scan = test_table_scan()?;
1719
1720        let window = Expr::from(WindowFunction::new(
1721            WindowFunctionDefinition::WindowUDF(
1722                datafusion_functions_window::rank::rank_udwf(),
1723            ),
1724            vec![],
1725        ))
1726        .partition_by(vec![col("a"), col("b")])
1727        .order_by(vec![col("c").sort(true, true)])
1728        .build()
1729        .unwrap();
1730
1731        let plan = LogicalPlanBuilder::from(table_scan)
1732            .window(vec![window])?
1733            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1734            .build()?;
1735
1736        assert_optimized_plan_equal!(
1737            plan,
1738            @r"
1739        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1740          TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1741        "
1742        )
1743    }
1744
1745    /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed
1746    #[test]
1747    fn filter_move_partial_window() -> Result<()> {
1748        let table_scan = test_table_scan()?;
1749
1750        let window = Expr::from(WindowFunction::new(
1751            WindowFunctionDefinition::WindowUDF(
1752                datafusion_functions_window::rank::rank_udwf(),
1753            ),
1754            vec![],
1755        ))
1756        .partition_by(vec![col("a")])
1757        .order_by(vec![col("c").sort(true, true)])
1758        .build()
1759        .unwrap();
1760
1761        let plan = LogicalPlanBuilder::from(table_scan)
1762            .window(vec![window])?
1763            .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1764            .build()?;
1765
1766        assert_optimized_plan_equal!(
1767            plan,
1768            @r"
1769        Filter: test.b = Int64(1)
1770          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1771            TableScan: test, full_filters=[test.a > Int64(10)]
1772        "
1773        )
1774    }
1775
1776    /// verifies that filters on partition expressions are not pushed, as the single expression
1777    /// column is not available to the user, unlike with aggregations
1778    #[test]
1779    fn filter_expression_keep_window() -> Result<()> {
1780        let table_scan = test_table_scan()?;
1781
1782        let window = Expr::from(WindowFunction::new(
1783            WindowFunctionDefinition::WindowUDF(
1784                datafusion_functions_window::rank::rank_udwf(),
1785            ),
1786            vec![],
1787        ))
1788        .partition_by(vec![add(col("a"), col("b"))]) // PARTITION BY a + b
1789        .order_by(vec![col("c").sort(true, true)])
1790        .build()
1791        .unwrap();
1792
1793        let plan = LogicalPlanBuilder::from(table_scan)
1794            .window(vec![window])?
1795            // unlike with aggregations, single partition column "test.a + test.b" is not available
1796            // to the plan, so we use multiple columns when filtering
1797            .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1798            .build()?;
1799
1800        assert_optimized_plan_equal!(
1801            plan,
1802            @r"
1803        Filter: test.a + test.b > Int64(10)
1804          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1805            TableScan: test
1806        "
1807        )
1808    }
1809
1810    /// verifies that filters are not pushed on order by columns (that are not used in partitioning)
1811    #[test]
1812    fn filter_order_keep_window() -> Result<()> {
1813        let table_scan = test_table_scan()?;
1814
1815        let window = Expr::from(WindowFunction::new(
1816            WindowFunctionDefinition::WindowUDF(
1817                datafusion_functions_window::rank::rank_udwf(),
1818            ),
1819            vec![],
1820        ))
1821        .partition_by(vec![col("a")])
1822        .order_by(vec![col("c").sort(true, true)])
1823        .build()
1824        .unwrap();
1825
1826        let plan = LogicalPlanBuilder::from(table_scan)
1827            .window(vec![window])?
1828            .filter(col("c").gt(lit(10i64)))?
1829            .build()?;
1830
1831        assert_optimized_plan_equal!(
1832            plan,
1833            @r"
1834        Filter: test.c > Int64(10)
1835          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1836            TableScan: test
1837        "
1838        )
1839    }
1840
1841    /// verifies that when we use multiple window functions with a common partition key, the filter
1842    /// on that key is pushed
1843    #[test]
1844    fn filter_multiple_windows_common_partitions() -> Result<()> {
1845        let table_scan = test_table_scan()?;
1846
1847        let window1 = Expr::from(WindowFunction::new(
1848            WindowFunctionDefinition::WindowUDF(
1849                datafusion_functions_window::rank::rank_udwf(),
1850            ),
1851            vec![],
1852        ))
1853        .partition_by(vec![col("a")])
1854        .order_by(vec![col("c").sort(true, true)])
1855        .build()
1856        .unwrap();
1857
1858        let window2 = Expr::from(WindowFunction::new(
1859            WindowFunctionDefinition::WindowUDF(
1860                datafusion_functions_window::rank::rank_udwf(),
1861            ),
1862            vec![],
1863        ))
1864        .partition_by(vec![col("b"), col("a")])
1865        .order_by(vec![col("c").sort(true, true)])
1866        .build()
1867        .unwrap();
1868
1869        let plan = LogicalPlanBuilder::from(table_scan)
1870            .window(vec![window1, window2])?
1871            .filter(col("a").gt(lit(10i64)))? // a appears in both window functions
1872            .build()?;
1873
1874        assert_optimized_plan_equal!(
1875            plan,
1876            @r"
1877        WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1878          TableScan: test, full_filters=[test.a > Int64(10)]
1879        "
1880        )
1881    }
1882
1883    /// verifies that when we use multiple window functions with different partitions keys, the
1884    /// filter cannot be pushed
1885    #[test]
1886    fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1887        let table_scan = test_table_scan()?;
1888
1889        let window1 = Expr::from(WindowFunction::new(
1890            WindowFunctionDefinition::WindowUDF(
1891                datafusion_functions_window::rank::rank_udwf(),
1892            ),
1893            vec![],
1894        ))
1895        .partition_by(vec![col("a")])
1896        .order_by(vec![col("c").sort(true, true)])
1897        .build()
1898        .unwrap();
1899
1900        let window2 = Expr::from(WindowFunction::new(
1901            WindowFunctionDefinition::WindowUDF(
1902                datafusion_functions_window::rank::rank_udwf(),
1903            ),
1904            vec![],
1905        ))
1906        .partition_by(vec![col("b"), col("a")])
1907        .order_by(vec![col("c").sort(true, true)])
1908        .build()
1909        .unwrap();
1910
1911        let plan = LogicalPlanBuilder::from(table_scan)
1912            .window(vec![window1, window2])?
1913            .filter(col("b").gt(lit(10i64)))? // b only appears in one window function
1914            .build()?;
1915
1916        assert_optimized_plan_equal!(
1917            plan,
1918            @r"
1919        Filter: test.b > Int64(10)
1920          WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1921            TableScan: test
1922        "
1923        )
1924    }
1925
1926    /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written
1927    #[test]
1928    fn alias() -> Result<()> {
1929        let table_scan = test_table_scan()?;
1930        let plan = LogicalPlanBuilder::from(table_scan)
1931            .project(vec![col("a").alias("b"), col("c")])?
1932            .filter(col("b").eq(lit(1i64)))?
1933            .build()?;
1934        // filter is before projection
1935        assert_optimized_plan_equal!(
1936            plan,
1937            @r"
1938        Projection: test.a AS b, test.c
1939          TableScan: test, full_filters=[test.a = Int64(1)]
1940        "
1941        )
1942    }
1943
1944    fn add(left: Expr, right: Expr) -> Expr {
1945        Expr::BinaryExpr(BinaryExpr::new(
1946            Box::new(left),
1947            Operator::Plus,
1948            Box::new(right),
1949        ))
1950    }
1951
1952    fn multiply(left: Expr, right: Expr) -> Expr {
1953        Expr::BinaryExpr(BinaryExpr::new(
1954            Box::new(left),
1955            Operator::Multiply,
1956            Box::new(right),
1957        ))
1958    }
1959
1960    /// verifies that a filter is pushed to before a projection with a complex expression, the filter expression is correctly re-written
1961    #[test]
1962    fn complex_expression() -> Result<()> {
1963        let table_scan = test_table_scan()?;
1964        let plan = LogicalPlanBuilder::from(table_scan)
1965            .project(vec![
1966                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1967                col("c"),
1968            ])?
1969            .filter(col("b").eq(lit(1i64)))?
1970            .build()?;
1971
1972        // not part of the test, just good to know:
1973        assert_snapshot!(plan,
1974        @r"
1975        Filter: b = Int64(1)
1976          Projection: test.a * Int32(2) + test.c AS b, test.c
1977            TableScan: test
1978        ",
1979        );
1980        // filter is before projection
1981        assert_optimized_plan_equal!(
1982            plan,
1983            @r"
1984        Projection: test.a * Int32(2) + test.c AS b, test.c
1985          TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1986        "
1987        )
1988    }
1989
1990    /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written
1991    #[test]
1992    fn complex_plan() -> Result<()> {
1993        let table_scan = test_table_scan()?;
1994        let plan = LogicalPlanBuilder::from(table_scan)
1995            .project(vec![
1996                add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1997                col("c"),
1998            ])?
1999            // second projection where we rename columns, just to make it difficult
2000            .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2001            .filter(col("a").eq(lit(1i64)))?
2002            .build()?;
2003
2004        // not part of the test, just good to know:
2005        assert_snapshot!(plan,
2006        @r"
2007        Filter: a = Int64(1)
2008          Projection: b * Int32(3) AS a, test.c
2009            Projection: test.a * Int32(2) + test.c AS b, test.c
2010              TableScan: test
2011        ",
2012        );
2013        // filter is before the projections
2014        assert_optimized_plan_equal!(
2015            plan,
2016            @r"
2017        Projection: b * Int32(3) AS a, test.c
2018          Projection: test.a * Int32(2) + test.c AS b, test.c
2019            TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2020        "
2021        )
2022    }
2023
    /// Test-only extension node that wraps zero or more child plans and reports
    /// a fixed schema; used to check filter push-down through user-defined nodes.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct NoopPlan {
        // child plans this node wraps
        input: Vec<LogicalPlan>,
        // schema the node reports (the tests pass the wrapped scan's schema)
        schema: DFSchemaRef,
    }
2029
2030    // Manual implementation needed because of `schema` field. Comparison excludes this field.
2031    impl PartialOrd for NoopPlan {
2032        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2033            self.input
2034                .partial_cmp(&other.input)
2035                // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields
2036                .filter(|cmp| *cmp != Ordering::Equal || self == other)
2037        }
2038    }
2039
2040    impl UserDefinedLogicalNodeCore for NoopPlan {
2041        fn name(&self) -> &str {
2042            "NoopPlan"
2043        }
2044
2045        fn inputs(&self) -> Vec<&LogicalPlan> {
2046            self.input.iter().collect()
2047        }
2048
2049        fn schema(&self) -> &DFSchemaRef {
2050            &self.schema
2051        }
2052
2053        fn expressions(&self) -> Vec<Expr> {
2054            self.input
2055                .iter()
2056                .flat_map(|child| child.expressions())
2057                .collect()
2058        }
2059
2060        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2061            HashSet::from_iter(vec!["c".to_string()])
2062        }
2063
2064        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2065            write!(f, "NoopPlan")
2066        }
2067
2068        fn with_exprs_and_inputs(
2069            &self,
2070            _exprs: Vec<Expr>,
2071            inputs: Vec<LogicalPlan>,
2072        ) -> Result<Self> {
2073            Ok(Self {
2074                input: inputs,
2075                schema: Arc::clone(&self.schema),
2076            })
2077        }
2078
2079        fn supports_limit_pushdown(&self) -> bool {
2080            false // Disallow limit push-down by default
2081        }
2082    }
2083
2084    #[test]
2085    fn user_defined_plan() -> Result<()> {
2086        let table_scan = test_table_scan()?;
2087
2088        let custom_plan = LogicalPlan::Extension(Extension {
2089            node: Arc::new(NoopPlan {
2090                input: vec![table_scan.clone()],
2091                schema: Arc::clone(table_scan.schema()),
2092            }),
2093        });
2094        let plan = LogicalPlanBuilder::from(custom_plan)
2095            .filter(col("a").eq(lit(1i64)))?
2096            .build()?;
2097
2098        // Push filter below NoopPlan
2099        assert_optimized_plan_equal!(
2100            plan,
2101            @r"
2102        NoopPlan
2103          TableScan: test, full_filters=[test.a = Int64(1)]
2104        "
2105        )?;
2106
2107        let custom_plan = LogicalPlan::Extension(Extension {
2108            node: Arc::new(NoopPlan {
2109                input: vec![table_scan.clone()],
2110                schema: Arc::clone(table_scan.schema()),
2111            }),
2112        });
2113        let plan = LogicalPlanBuilder::from(custom_plan)
2114            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2115            .build()?;
2116
2117        // Push only predicate on `a` below NoopPlan
2118        assert_optimized_plan_equal!(
2119            plan,
2120            @r"
2121        Filter: test.c = Int64(2)
2122          NoopPlan
2123            TableScan: test, full_filters=[test.a = Int64(1)]
2124        "
2125        )?;
2126
2127        let custom_plan = LogicalPlan::Extension(Extension {
2128            node: Arc::new(NoopPlan {
2129                input: vec![table_scan.clone(), table_scan.clone()],
2130                schema: Arc::clone(table_scan.schema()),
2131            }),
2132        });
2133        let plan = LogicalPlanBuilder::from(custom_plan)
2134            .filter(col("a").eq(lit(1i64)))?
2135            .build()?;
2136
2137        // Push filter below NoopPlan for each child branch
2138        assert_optimized_plan_equal!(
2139            plan,
2140            @r"
2141        NoopPlan
2142          TableScan: test, full_filters=[test.a = Int64(1)]
2143          TableScan: test, full_filters=[test.a = Int64(1)]
2144        "
2145        )?;
2146
2147        let custom_plan = LogicalPlan::Extension(Extension {
2148            node: Arc::new(NoopPlan {
2149                input: vec![table_scan.clone(), table_scan.clone()],
2150                schema: Arc::clone(table_scan.schema()),
2151            }),
2152        });
2153        let plan = LogicalPlanBuilder::from(custom_plan)
2154            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2155            .build()?;
2156
2157        // Push only predicate on `a` below NoopPlan
2158        assert_optimized_plan_equal!(
2159            plan,
2160            @r"
2161        Filter: test.c = Int64(2)
2162          NoopPlan
2163            TableScan: test, full_filters=[test.a = Int64(1)]
2164            TableScan: test, full_filters=[test.a = Int64(1)]
2165        "
2166        )
2167    }
2168
2169    /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
2170    /// and the other not.
2171    #[test]
2172    fn multi_filter() -> Result<()> {
2173        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2174        let table_scan = test_table_scan()?;
2175        let plan = LogicalPlanBuilder::from(table_scan)
2176            .project(vec![col("a").alias("b"), col("c")])?
2177            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2178            .filter(col("b").gt(lit(10i64)))?
2179            .filter(col("sum(test.c)").gt(lit(10i64)))?
2180            .build()?;
2181
2182        // not part of the test, just good to know:
2183        assert_snapshot!(plan,
2184        @r"
2185        Filter: sum(test.c) > Int64(10)
2186          Filter: b > Int64(10)
2187            Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2188              Projection: test.a AS b, test.c
2189                TableScan: test
2190        ",
2191        );
2192        // filter is before the projections
2193        assert_optimized_plan_equal!(
2194            plan,
2195            @r"
2196        Filter: sum(test.c) > Int64(10)
2197          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2198            Projection: test.a AS b, test.c
2199              TableScan: test, full_filters=[test.a > Int64(10)]
2200        "
2201        )
2202    }
2203
2204    /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed
2205    /// and the other not.
2206    #[test]
2207    fn split_filter() -> Result<()> {
2208        // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c))
2209        let table_scan = test_table_scan()?;
2210        let plan = LogicalPlanBuilder::from(table_scan)
2211            .project(vec![col("a").alias("b"), col("c")])?
2212            .aggregate(vec![col("b")], vec![sum(col("c"))])?
2213            .filter(and(
2214                col("sum(test.c)").gt(lit(10i64)),
2215                and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2216            ))?
2217            .build()?;
2218
2219        // not part of the test, just good to know:
2220        assert_snapshot!(plan,
2221        @r"
2222        Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2223          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2224            Projection: test.a AS b, test.c
2225              TableScan: test
2226        ",
2227        );
2228        // filter is before the projections
2229        assert_optimized_plan_equal!(
2230            plan,
2231            @r"
2232        Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2233          Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2234            Projection: test.a AS b, test.c
2235              TableScan: test, full_filters=[test.a > Int64(10)]
2236        "
2237        )
2238    }
2239
2240    /// verifies that when two limits are in place, we jump neither
2241    #[test]
2242    fn double_limit() -> Result<()> {
2243        let table_scan = test_table_scan()?;
2244        let plan = LogicalPlanBuilder::from(table_scan)
2245            .project(vec![col("a"), col("b")])?
2246            .limit(0, Some(20))?
2247            .limit(0, Some(10))?
2248            .project(vec![col("a"), col("b")])?
2249            .filter(col("a").eq(lit(1i64)))?
2250            .build()?;
2251        // filter does not just any of the limits
2252        assert_optimized_plan_equal!(
2253            plan,
2254            @r"
2255        Projection: test.a, test.b
2256          Filter: test.a = Int64(1)
2257            Limit: skip=0, fetch=10
2258              Limit: skip=0, fetch=20
2259                Projection: test.a, test.b
2260                  TableScan: test
2261        "
2262        )
2263    }
2264
2265    #[test]
2266    fn union_all() -> Result<()> {
2267        let table_scan = test_table_scan()?;
2268        let table_scan2 = test_table_scan_with_name("test2")?;
2269        let plan = LogicalPlanBuilder::from(table_scan)
2270            .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2271            .filter(col("a").eq(lit(1i64)))?
2272            .build()?;
2273        // filter appears below Union
2274        assert_optimized_plan_equal!(
2275            plan,
2276            @r"
2277        Union
2278          TableScan: test, full_filters=[test.a = Int64(1)]
2279          TableScan: test2, full_filters=[test2.a = Int64(1)]
2280        "
2281        )
2282    }
2283
2284    #[test]
2285    fn union_all_on_projection() -> Result<()> {
2286        let table_scan = test_table_scan()?;
2287        let table = LogicalPlanBuilder::from(table_scan)
2288            .project(vec![col("a").alias("b")])?
2289            .alias("test2")?;
2290
2291        let plan = table
2292            .clone()
2293            .union(table.build()?)?
2294            .filter(col("b").eq(lit(1i64)))?
2295            .build()?;
2296
2297        // filter appears below Union
2298        assert_optimized_plan_equal!(
2299            plan,
2300            @r"
2301        Union
2302          SubqueryAlias: test2
2303            Projection: test.a AS b
2304              TableScan: test, full_filters=[test.a = Int64(1)]
2305          SubqueryAlias: test2
2306            Projection: test.a AS b
2307              TableScan: test, full_filters=[test.a = Int64(1)]
2308        "
2309        )
2310    }
2311
2312    #[test]
2313    fn test_union_different_schema() -> Result<()> {
2314        let left = LogicalPlanBuilder::from(test_table_scan()?)
2315            .project(vec![col("a"), col("b"), col("c")])?
2316            .build()?;
2317
2318        let schema = Schema::new(vec![
2319            Field::new("d", DataType::UInt32, false),
2320            Field::new("e", DataType::UInt32, false),
2321            Field::new("f", DataType::UInt32, false),
2322        ]);
2323        let right = table_scan(Some("test1"), &schema, None)?
2324            .project(vec![col("d"), col("e"), col("f")])?
2325            .build()?;
2326        let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2327        let plan = LogicalPlanBuilder::from(left)
2328            .cross_join(right)?
2329            .project(vec![col("test.a"), col("test1.d")])?
2330            .filter(filter)?
2331            .build()?;
2332
2333        assert_optimized_plan_equal!(
2334            plan,
2335            @r"
2336        Projection: test.a, test1.d
2337          Cross Join: 
2338            Projection: test.a, test.b, test.c
2339              TableScan: test, full_filters=[test.a = Int32(1)]
2340            Projection: test1.d, test1.e, test1.f
2341              TableScan: test1, full_filters=[test1.d > Int32(2)]
2342        "
2343        )
2344    }
2345
2346    #[test]
2347    fn test_project_same_name_different_qualifier() -> Result<()> {
2348        let table_scan = test_table_scan()?;
2349        let left = LogicalPlanBuilder::from(table_scan)
2350            .project(vec![col("a"), col("b"), col("c")])?
2351            .build()?;
2352        let right_table_scan = test_table_scan_with_name("test1")?;
2353        let right = LogicalPlanBuilder::from(right_table_scan)
2354            .project(vec![col("a"), col("b"), col("c")])?
2355            .build()?;
2356        let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2357        let plan = LogicalPlanBuilder::from(left)
2358            .cross_join(right)?
2359            .project(vec![col("test.a"), col("test1.a")])?
2360            .filter(filter)?
2361            .build()?;
2362
2363        assert_optimized_plan_equal!(
2364            plan,
2365            @r"
2366        Projection: test.a, test1.a
2367          Cross Join: 
2368            Projection: test.a, test.b, test.c
2369              TableScan: test, full_filters=[test.a = Int32(1)]
2370            Projection: test1.a, test1.b, test1.c
2371              TableScan: test1, full_filters=[test1.a > Int32(2)]
2372        "
2373        )
2374    }
2375
2376    /// verifies that filters with the same columns are correctly placed
2377    #[test]
2378    fn filter_2_breaks_limits() -> Result<()> {
2379        let table_scan = test_table_scan()?;
2380        let plan = LogicalPlanBuilder::from(table_scan)
2381            .project(vec![col("a")])?
2382            .filter(col("a").lt_eq(lit(1i64)))?
2383            .limit(0, Some(1))?
2384            .project(vec![col("a")])?
2385            .filter(col("a").gt_eq(lit(1i64)))?
2386            .build()?;
2387        // Should be able to move both filters below the projections
2388
2389        // not part of the test
2390        assert_snapshot!(plan,
2391        @r"
2392        Filter: test.a >= Int64(1)
2393          Projection: test.a
2394            Limit: skip=0, fetch=1
2395              Filter: test.a <= Int64(1)
2396                Projection: test.a
2397                  TableScan: test
2398        ",
2399        );
2400        assert_optimized_plan_equal!(
2401            plan,
2402            @r"
2403        Projection: test.a
2404          Filter: test.a >= Int64(1)
2405            Limit: skip=0, fetch=1
2406              Projection: test.a
2407                TableScan: test, full_filters=[test.a <= Int64(1)]
2408        "
2409        )
2410    }
2411
2412    /// verifies that filters to be placed on the same depth are ANDed
2413    #[test]
2414    fn two_filters_on_same_depth() -> Result<()> {
2415        let table_scan = test_table_scan()?;
2416        let plan = LogicalPlanBuilder::from(table_scan)
2417            .limit(0, Some(1))?
2418            .filter(col("a").lt_eq(lit(1i64)))?
2419            .filter(col("a").gt_eq(lit(1i64)))?
2420            .project(vec![col("a")])?
2421            .build()?;
2422
2423        // not part of the test
2424        assert_snapshot!(plan,
2425        @r"
2426        Projection: test.a
2427          Filter: test.a >= Int64(1)
2428            Filter: test.a <= Int64(1)
2429              Limit: skip=0, fetch=1
2430                TableScan: test
2431        ",
2432        );
2433        assert_optimized_plan_equal!(
2434            plan,
2435            @r"
2436        Projection: test.a
2437          Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2438            Limit: skip=0, fetch=1
2439              TableScan: test
2440        "
2441        )
2442    }
2443
2444    /// verifies that filters on a plan with user nodes are not lost
2445    /// (ARROW-10547)
2446    #[test]
2447    fn filters_user_defined_node() -> Result<()> {
2448        let table_scan = test_table_scan()?;
2449        let plan = LogicalPlanBuilder::from(table_scan)
2450            .filter(col("a").lt_eq(lit(1i64)))?
2451            .build()?;
2452
2453        let plan = user_defined::new(plan);
2454
2455        // not part of the test
2456        assert_snapshot!(plan,
2457        @r"
2458        TestUserDefined
2459          Filter: test.a <= Int64(1)
2460            TableScan: test
2461        ",
2462        );
2463        assert_optimized_plan_equal!(
2464            plan,
2465            @r"
2466        TestUserDefined
2467          TableScan: test, full_filters=[test.a <= Int64(1)]
2468        "
2469        )
2470    }
2471
2472    /// post-on-join predicates on a column common to both sides is pushed to both sides
2473    #[test]
2474    fn filter_on_join_on_common_independent() -> Result<()> {
2475        let table_scan = test_table_scan()?;
2476        let left = LogicalPlanBuilder::from(table_scan).build()?;
2477        let right_table_scan = test_table_scan_with_name("test2")?;
2478        let right = LogicalPlanBuilder::from(right_table_scan)
2479            .project(vec![col("a")])?
2480            .build()?;
2481        let plan = LogicalPlanBuilder::from(left)
2482            .join(
2483                right,
2484                JoinType::Inner,
2485                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2486                None,
2487            )?
2488            .filter(col("test.a").lt_eq(lit(1i64)))?
2489            .build()?;
2490
2491        // not part of the test, just good to know:
2492        assert_snapshot!(plan,
2493        @r"
2494        Filter: test.a <= Int64(1)
2495          Inner Join: test.a = test2.a
2496            TableScan: test
2497            Projection: test2.a
2498              TableScan: test2
2499        ",
2500        );
2501        // filter sent to side before the join
2502        assert_optimized_plan_equal!(
2503            plan,
2504            @r"
2505        Inner Join: test.a = test2.a
2506          TableScan: test, full_filters=[test.a <= Int64(1)]
2507          Projection: test2.a
2508            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2509        "
2510        )
2511    }
2512
2513    /// post-using-join predicates on a column common to both sides is pushed to both sides
2514    #[test]
2515    fn filter_using_join_on_common_independent() -> Result<()> {
2516        let table_scan = test_table_scan()?;
2517        let left = LogicalPlanBuilder::from(table_scan).build()?;
2518        let right_table_scan = test_table_scan_with_name("test2")?;
2519        let right = LogicalPlanBuilder::from(right_table_scan)
2520            .project(vec![col("a")])?
2521            .build()?;
2522        let plan = LogicalPlanBuilder::from(left)
2523            .join_using(
2524                right,
2525                JoinType::Inner,
2526                vec![Column::from_name("a".to_string())],
2527            )?
2528            .filter(col("a").lt_eq(lit(1i64)))?
2529            .build()?;
2530
2531        // not part of the test, just good to know:
2532        assert_snapshot!(plan,
2533        @r"
2534        Filter: test.a <= Int64(1)
2535          Inner Join: Using test.a = test2.a
2536            TableScan: test
2537            Projection: test2.a
2538              TableScan: test2
2539        ",
2540        );
2541        // filter sent to side before the join
2542        assert_optimized_plan_equal!(
2543            plan,
2544            @r"
2545        Inner Join: Using test.a = test2.a
2546          TableScan: test, full_filters=[test.a <= Int64(1)]
2547          Projection: test2.a
2548            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2549        "
2550        )
2551    }
2552
2553    /// post-join predicates with columns from both sides are converted to join filters
2554    #[test]
2555    fn filter_join_on_common_dependent() -> Result<()> {
2556        let table_scan = test_table_scan()?;
2557        let left = LogicalPlanBuilder::from(table_scan)
2558            .project(vec![col("a"), col("c")])?
2559            .build()?;
2560        let right_table_scan = test_table_scan_with_name("test2")?;
2561        let right = LogicalPlanBuilder::from(right_table_scan)
2562            .project(vec![col("a"), col("b")])?
2563            .build()?;
2564        let plan = LogicalPlanBuilder::from(left)
2565            .join(
2566                right,
2567                JoinType::Inner,
2568                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2569                None,
2570            )?
2571            .filter(col("c").lt_eq(col("b")))?
2572            .build()?;
2573
2574        // not part of the test, just good to know:
2575        assert_snapshot!(plan,
2576        @r"
2577        Filter: test.c <= test2.b
2578          Inner Join: test.a = test2.a
2579            Projection: test.a, test.c
2580              TableScan: test
2581            Projection: test2.a, test2.b
2582              TableScan: test2
2583        ",
2584        );
2585        // Filter is converted to Join Filter
2586        assert_optimized_plan_equal!(
2587            plan,
2588            @r"
2589        Inner Join: test.a = test2.a Filter: test.c <= test2.b
2590          Projection: test.a, test.c
2591            TableScan: test
2592          Projection: test2.a, test2.b
2593            TableScan: test2
2594        "
2595        )
2596    }
2597
2598    /// post-join predicates with columns from one side of a join are pushed only to that side
2599    #[test]
2600    fn filter_join_on_one_side() -> Result<()> {
2601        let table_scan = test_table_scan()?;
2602        let left = LogicalPlanBuilder::from(table_scan)
2603            .project(vec![col("a"), col("b")])?
2604            .build()?;
2605        let table_scan_right = test_table_scan_with_name("test2")?;
2606        let right = LogicalPlanBuilder::from(table_scan_right)
2607            .project(vec![col("a"), col("c")])?
2608            .build()?;
2609
2610        let plan = LogicalPlanBuilder::from(left)
2611            .join(
2612                right,
2613                JoinType::Inner,
2614                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2615                None,
2616            )?
2617            .filter(col("b").lt_eq(lit(1i64)))?
2618            .build()?;
2619
2620        // not part of the test, just good to know:
2621        assert_snapshot!(plan,
2622        @r"
2623        Filter: test.b <= Int64(1)
2624          Inner Join: test.a = test2.a
2625            Projection: test.a, test.b
2626              TableScan: test
2627            Projection: test2.a, test2.c
2628              TableScan: test2
2629        ",
2630        );
2631        assert_optimized_plan_equal!(
2632            plan,
2633            @r"
2634        Inner Join: test.a = test2.a
2635          Projection: test.a, test.b
2636            TableScan: test, full_filters=[test.b <= Int64(1)]
2637          Projection: test2.a, test2.c
2638            TableScan: test2
2639        "
2640        )
2641    }
2642
2643    /// post-join predicates on the right side of a left join are not duplicated
2644    /// TODO: In this case we can sometimes convert the join to an INNER join
2645    #[test]
2646    fn filter_using_left_join() -> Result<()> {
2647        let table_scan = test_table_scan()?;
2648        let left = LogicalPlanBuilder::from(table_scan).build()?;
2649        let right_table_scan = test_table_scan_with_name("test2")?;
2650        let right = LogicalPlanBuilder::from(right_table_scan)
2651            .project(vec![col("a")])?
2652            .build()?;
2653        let plan = LogicalPlanBuilder::from(left)
2654            .join_using(
2655                right,
2656                JoinType::Left,
2657                vec![Column::from_name("a".to_string())],
2658            )?
2659            .filter(col("test2.a").lt_eq(lit(1i64)))?
2660            .build()?;
2661
2662        // not part of the test, just good to know:
2663        assert_snapshot!(plan,
2664        @r"
2665        Filter: test2.a <= Int64(1)
2666          Left Join: Using test.a = test2.a
2667            TableScan: test
2668            Projection: test2.a
2669              TableScan: test2
2670        ",
2671        );
2672        // filter not duplicated nor pushed down - i.e. noop
2673        assert_optimized_plan_equal!(
2674            plan,
2675            @r"
2676        Filter: test2.a <= Int64(1)
2677          Left Join: Using test.a = test2.a
2678            TableScan: test, full_filters=[test.a <= Int64(1)]
2679            Projection: test2.a
2680              TableScan: test2
2681        "
2682        )
2683    }
2684
2685    /// post-join predicates on the left side of a right join are not duplicated
2686    #[test]
2687    fn filter_using_right_join() -> Result<()> {
2688        let table_scan = test_table_scan()?;
2689        let left = LogicalPlanBuilder::from(table_scan).build()?;
2690        let right_table_scan = test_table_scan_with_name("test2")?;
2691        let right = LogicalPlanBuilder::from(right_table_scan)
2692            .project(vec![col("a")])?
2693            .build()?;
2694        let plan = LogicalPlanBuilder::from(left)
2695            .join_using(
2696                right,
2697                JoinType::Right,
2698                vec![Column::from_name("a".to_string())],
2699            )?
2700            .filter(col("test.a").lt_eq(lit(1i64)))?
2701            .build()?;
2702
2703        // not part of the test, just good to know:
2704        assert_snapshot!(plan,
2705        @r"
2706        Filter: test.a <= Int64(1)
2707          Right Join: Using test.a = test2.a
2708            TableScan: test
2709            Projection: test2.a
2710              TableScan: test2
2711        ",
2712        );
2713        // filter not duplicated nor pushed down - i.e. noop
2714        assert_optimized_plan_equal!(
2715            plan,
2716            @r"
2717        Filter: test.a <= Int64(1)
2718          Right Join: Using test.a = test2.a
2719            TableScan: test
2720            Projection: test2.a
2721              TableScan: test2, full_filters=[test2.a <= Int64(1)]
2722        "
2723        )
2724    }
2725
2726    /// post-left-join predicate on a column common to both sides is only pushed to the left side
2727    /// i.e. - not duplicated to the right side
2728    #[test]
2729    fn filter_using_left_join_on_common() -> Result<()> {
2730        let table_scan = test_table_scan()?;
2731        let left = LogicalPlanBuilder::from(table_scan).build()?;
2732        let right_table_scan = test_table_scan_with_name("test2")?;
2733        let right = LogicalPlanBuilder::from(right_table_scan)
2734            .project(vec![col("a")])?
2735            .build()?;
2736        let plan = LogicalPlanBuilder::from(left)
2737            .join_using(
2738                right,
2739                JoinType::Left,
2740                vec![Column::from_name("a".to_string())],
2741            )?
2742            .filter(col("a").lt_eq(lit(1i64)))?
2743            .build()?;
2744
2745        // not part of the test, just good to know:
2746        assert_snapshot!(plan,
2747        @r"
2748        Filter: test.a <= Int64(1)
2749          Left Join: Using test.a = test2.a
2750            TableScan: test
2751            Projection: test2.a
2752              TableScan: test2
2753        ",
2754        );
2755        // filter sent to left side of the join, not the right
2756        assert_optimized_plan_equal!(
2757            plan,
2758            @r"
2759        Left Join: Using test.a = test2.a
2760          TableScan: test, full_filters=[test.a <= Int64(1)]
2761          Projection: test2.a
2762            TableScan: test2
2763        "
2764        )
2765    }
2766
2767    /// post-right-join predicate on a column common to both sides is only pushed to the right side
2768    /// i.e. - not duplicated to the left side.
2769    #[test]
2770    fn filter_using_right_join_on_common() -> Result<()> {
2771        let table_scan = test_table_scan()?;
2772        let left = LogicalPlanBuilder::from(table_scan).build()?;
2773        let right_table_scan = test_table_scan_with_name("test2")?;
2774        let right = LogicalPlanBuilder::from(right_table_scan)
2775            .project(vec![col("a")])?
2776            .build()?;
2777        let plan = LogicalPlanBuilder::from(left)
2778            .join_using(
2779                right,
2780                JoinType::Right,
2781                vec![Column::from_name("a".to_string())],
2782            )?
2783            .filter(col("test2.a").lt_eq(lit(1i64)))?
2784            .build()?;
2785
2786        // not part of the test, just good to know:
2787        assert_snapshot!(plan,
2788        @r"
2789        Filter: test2.a <= Int64(1)
2790          Right Join: Using test.a = test2.a
2791            TableScan: test
2792            Projection: test2.a
2793              TableScan: test2
2794        ",
2795        );
2796        // filter sent to right side of join, not duplicated to the left
2797        assert_optimized_plan_equal!(
2798            plan,
2799            @r"
2800        Right Join: Using test.a = test2.a
2801          TableScan: test
2802          Projection: test2.a
2803            TableScan: test2, full_filters=[test2.a <= Int64(1)]
2804        "
2805        )
2806    }
2807
2808    /// single table predicate parts of ON condition should be pushed to both inputs
2809    #[test]
2810    fn join_on_with_filter() -> Result<()> {
2811        let table_scan = test_table_scan()?;
2812        let left = LogicalPlanBuilder::from(table_scan)
2813            .project(vec![col("a"), col("b"), col("c")])?
2814            .build()?;
2815        let right_table_scan = test_table_scan_with_name("test2")?;
2816        let right = LogicalPlanBuilder::from(right_table_scan)
2817            .project(vec![col("a"), col("b"), col("c")])?
2818            .build()?;
2819        let filter = col("test.c")
2820            .gt(lit(1u32))
2821            .and(col("test.b").lt(col("test2.b")))
2822            .and(col("test2.c").gt(lit(4u32)));
2823        let plan = LogicalPlanBuilder::from(left)
2824            .join(
2825                right,
2826                JoinType::Inner,
2827                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2828                Some(filter),
2829            )?
2830            .build()?;
2831
2832        // not part of the test, just good to know:
2833        assert_snapshot!(plan,
2834        @r"
2835        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2836          Projection: test.a, test.b, test.c
2837            TableScan: test
2838          Projection: test2.a, test2.b, test2.c
2839            TableScan: test2
2840        ",
2841        );
2842        assert_optimized_plan_equal!(
2843            plan,
2844            @r"
2845        Inner Join: test.a = test2.a Filter: test.b < test2.b
2846          Projection: test.a, test.b, test.c
2847            TableScan: test, full_filters=[test.c > UInt32(1)]
2848          Projection: test2.a, test2.b, test2.c
2849            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2850        "
2851        )
2852    }
2853
2854    /// join filter should be completely removed after pushdown
2855    #[test]
2856    fn join_filter_removed() -> Result<()> {
2857        let table_scan = test_table_scan()?;
2858        let left = LogicalPlanBuilder::from(table_scan)
2859            .project(vec![col("a"), col("b"), col("c")])?
2860            .build()?;
2861        let right_table_scan = test_table_scan_with_name("test2")?;
2862        let right = LogicalPlanBuilder::from(right_table_scan)
2863            .project(vec![col("a"), col("b"), col("c")])?
2864            .build()?;
2865        let filter = col("test.b")
2866            .gt(lit(1u32))
2867            .and(col("test2.c").gt(lit(4u32)));
2868        let plan = LogicalPlanBuilder::from(left)
2869            .join(
2870                right,
2871                JoinType::Inner,
2872                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2873                Some(filter),
2874            )?
2875            .build()?;
2876
2877        // not part of the test, just good to know:
2878        assert_snapshot!(plan,
2879        @r"
2880        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
2881          Projection: test.a, test.b, test.c
2882            TableScan: test
2883          Projection: test2.a, test2.b, test2.c
2884            TableScan: test2
2885        ",
2886        );
2887        assert_optimized_plan_equal!(
2888            plan,
2889            @r"
2890        Inner Join: test.a = test2.a
2891          Projection: test.a, test.b, test.c
2892            TableScan: test, full_filters=[test.b > UInt32(1)]
2893          Projection: test2.a, test2.b, test2.c
2894            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2895        "
2896        )
2897    }
2898
2899    /// predicate on join key in filter expression should be pushed down to both inputs
2900    #[test]
2901    fn join_filter_on_common() -> Result<()> {
2902        let table_scan = test_table_scan()?;
2903        let left = LogicalPlanBuilder::from(table_scan)
2904            .project(vec![col("a")])?
2905            .build()?;
2906        let right_table_scan = test_table_scan_with_name("test2")?;
2907        let right = LogicalPlanBuilder::from(right_table_scan)
2908            .project(vec![col("b")])?
2909            .build()?;
2910        let filter = col("test.a").gt(lit(1u32));
2911        let plan = LogicalPlanBuilder::from(left)
2912            .join(
2913                right,
2914                JoinType::Inner,
2915                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
2916                Some(filter),
2917            )?
2918            .build()?;
2919
2920        // not part of the test, just good to know:
2921        assert_snapshot!(plan,
2922        @r"
2923        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
2924          Projection: test.a
2925            TableScan: test
2926          Projection: test2.b
2927            TableScan: test2
2928        ",
2929        );
2930        assert_optimized_plan_equal!(
2931            plan,
2932            @r"
2933        Inner Join: test.a = test2.b
2934          Projection: test.a
2935            TableScan: test, full_filters=[test.a > UInt32(1)]
2936          Projection: test2.b
2937            TableScan: test2, full_filters=[test2.b > UInt32(1)]
2938        "
2939        )
2940    }
2941
2942    /// single table predicate parts of ON condition should be pushed to right input
2943    #[test]
2944    fn left_join_on_with_filter() -> Result<()> {
2945        let table_scan = test_table_scan()?;
2946        let left = LogicalPlanBuilder::from(table_scan)
2947            .project(vec![col("a"), col("b"), col("c")])?
2948            .build()?;
2949        let right_table_scan = test_table_scan_with_name("test2")?;
2950        let right = LogicalPlanBuilder::from(right_table_scan)
2951            .project(vec![col("a"), col("b"), col("c")])?
2952            .build()?;
2953        let filter = col("test.a")
2954            .gt(lit(1u32))
2955            .and(col("test.b").lt(col("test2.b")))
2956            .and(col("test2.c").gt(lit(4u32)));
2957        let plan = LogicalPlanBuilder::from(left)
2958            .join(
2959                right,
2960                JoinType::Left,
2961                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2962                Some(filter),
2963            )?
2964            .build()?;
2965
2966        // not part of the test, just good to know:
2967        assert_snapshot!(plan,
2968        @r"
2969        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
2970          Projection: test.a, test.b, test.c
2971            TableScan: test
2972          Projection: test2.a, test2.b, test2.c
2973            TableScan: test2
2974        ",
2975        );
2976        assert_optimized_plan_equal!(
2977            plan,
2978            @r"
2979        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
2980          Projection: test.a, test.b, test.c
2981            TableScan: test
2982          Projection: test2.a, test2.b, test2.c
2983            TableScan: test2, full_filters=[test2.c > UInt32(4)]
2984        "
2985        )
2986    }
2987
2988    /// single table predicate parts of ON condition should be pushed to left input
2989    #[test]
2990    fn right_join_on_with_filter() -> Result<()> {
2991        let table_scan = test_table_scan()?;
2992        let left = LogicalPlanBuilder::from(table_scan)
2993            .project(vec![col("a"), col("b"), col("c")])?
2994            .build()?;
2995        let right_table_scan = test_table_scan_with_name("test2")?;
2996        let right = LogicalPlanBuilder::from(right_table_scan)
2997            .project(vec![col("a"), col("b"), col("c")])?
2998            .build()?;
2999        let filter = col("test.a")
3000            .gt(lit(1u32))
3001            .and(col("test.b").lt(col("test2.b")))
3002            .and(col("test2.c").gt(lit(4u32)));
3003        let plan = LogicalPlanBuilder::from(left)
3004            .join(
3005                right,
3006                JoinType::Right,
3007                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3008                Some(filter),
3009            )?
3010            .build()?;
3011
3012        // not part of the test, just good to know:
3013        assert_snapshot!(plan,
3014        @r"
3015        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3016          Projection: test.a, test.b, test.c
3017            TableScan: test
3018          Projection: test2.a, test2.b, test2.c
3019            TableScan: test2
3020        ",
3021        );
3022        assert_optimized_plan_equal!(
3023            plan,
3024            @r"
3025        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
3026          Projection: test.a, test.b, test.c
3027            TableScan: test, full_filters=[test.a > UInt32(1)]
3028          Projection: test2.a, test2.b, test2.c
3029            TableScan: test2
3030        "
3031        )
3032    }
3033
3034    /// single table predicate parts of ON condition should not be pushed
3035    #[test]
3036    fn full_join_on_with_filter() -> Result<()> {
3037        let table_scan = test_table_scan()?;
3038        let left = LogicalPlanBuilder::from(table_scan)
3039            .project(vec![col("a"), col("b"), col("c")])?
3040            .build()?;
3041        let right_table_scan = test_table_scan_with_name("test2")?;
3042        let right = LogicalPlanBuilder::from(right_table_scan)
3043            .project(vec![col("a"), col("b"), col("c")])?
3044            .build()?;
3045        let filter = col("test.a")
3046            .gt(lit(1u32))
3047            .and(col("test.b").lt(col("test2.b")))
3048            .and(col("test2.c").gt(lit(4u32)));
3049        let plan = LogicalPlanBuilder::from(left)
3050            .join(
3051                right,
3052                JoinType::Full,
3053                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
3054                Some(filter),
3055            )?
3056            .build()?;
3057
3058        // not part of the test, just good to know:
3059        assert_snapshot!(plan,
3060        @r"
3061        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3062          Projection: test.a, test.b, test.c
3063            TableScan: test
3064          Projection: test2.a, test2.b, test2.c
3065            TableScan: test2
3066        ",
3067        );
3068        assert_optimized_plan_equal!(
3069            plan,
3070            @r"
3071        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
3072          Projection: test.a, test.b, test.c
3073            TableScan: test
3074          Projection: test2.a, test2.b, test2.c
3075            TableScan: test2
3076        "
3077        )
3078    }
3079
3080    struct PushDownProvider {
3081        pub filter_support: TableProviderFilterPushDown,
3082    }
3083
3084    #[async_trait]
3085    impl TableSource for PushDownProvider {
3086        fn schema(&self) -> SchemaRef {
3087            Arc::new(Schema::new(vec![
3088                Field::new("a", DataType::Int32, true),
3089                Field::new("b", DataType::Int32, true),
3090            ]))
3091        }
3092
3093        fn table_type(&self) -> TableType {
3094            TableType::Base
3095        }
3096
3097        fn supports_filters_pushdown(
3098            &self,
3099            filters: &[&Expr],
3100        ) -> Result<Vec<TableProviderFilterPushDown>> {
3101            Ok((0..filters.len())
3102                .map(|_| self.filter_support.clone())
3103                .collect())
3104        }
3105
3106        fn as_any(&self) -> &dyn Any {
3107            self
3108        }
3109    }
3110
3111    fn table_scan_with_pushdown_provider_builder(
3112        filter_support: TableProviderFilterPushDown,
3113        filters: Vec<Expr>,
3114        projection: Option<Vec<usize>>,
3115    ) -> Result<LogicalPlanBuilder> {
3116        let test_provider = PushDownProvider { filter_support };
3117
3118        let table_scan = LogicalPlan::TableScan(TableScan {
3119            table_name: "test".into(),
3120            filters,
3121            projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3122            projection,
3123            source: Arc::new(test_provider),
3124            fetch: None,
3125        });
3126
3127        Ok(LogicalPlanBuilder::from(table_scan))
3128    }
3129
3130    fn table_scan_with_pushdown_provider(
3131        filter_support: TableProviderFilterPushDown,
3132    ) -> Result<LogicalPlan> {
3133        table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3134            .filter(col("a").eq(lit(1i64)))?
3135            .build()
3136    }
3137
3138    #[test]
3139    fn filter_with_table_provider_exact() -> Result<()> {
3140        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
3141
3142        assert_optimized_plan_equal!(
3143            plan,
3144            @"TableScan: test, full_filters=[a = Int64(1)]"
3145        )
3146    }
3147
3148    #[test]
3149    fn filter_with_table_provider_inexact() -> Result<()> {
3150        let plan =
3151            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3152
3153        assert_optimized_plan_equal!(
3154            plan,
3155            @r"
3156        Filter: a = Int64(1)
3157          TableScan: test, partial_filters=[a = Int64(1)]
3158        "
3159        )
3160    }
3161
3162    #[test]
3163    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
3164        let plan =
3165            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;
3166
3167        let optimized_plan = PushDownFilter::new()
3168            .rewrite(plan, &OptimizerContext::new())
3169            .expect("failed to optimize plan")
3170            .data;
3171
3172        // Optimizing the same plan multiple times should produce the same plan
3173        // each time.
3174        assert_optimized_plan_equal!(
3175            optimized_plan,
3176            @r"
3177        Filter: a = Int64(1)
3178          TableScan: test, partial_filters=[a = Int64(1)]
3179        "
3180        )
3181    }
3182
3183    #[test]
3184    fn filter_with_table_provider_unsupported() -> Result<()> {
3185        let plan =
3186            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;
3187
3188        assert_optimized_plan_equal!(
3189            plan,
3190            @r"
3191        Filter: a = Int64(1)
3192          TableScan: test
3193        "
3194        )
3195    }
3196
3197    #[test]
3198    fn multi_combined_filter() -> Result<()> {
3199        let plan = table_scan_with_pushdown_provider_builder(
3200            TableProviderFilterPushDown::Inexact,
3201            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
3202            Some(vec![0]),
3203        )?
3204        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3205        .project(vec![col("a"), col("b")])?
3206        .build()?;
3207
3208        assert_optimized_plan_equal!(
3209            plan,
3210            @r"
3211        Projection: a, b
3212          Filter: a = Int64(10) AND b > Int64(11)
3213            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
3214        "
3215        )
3216    }
3217
3218    #[test]
3219    fn multi_combined_filter_exact() -> Result<()> {
3220        let plan = table_scan_with_pushdown_provider_builder(
3221            TableProviderFilterPushDown::Exact,
3222            vec![],
3223            Some(vec![0]),
3224        )?
3225        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
3226        .project(vec![col("a"), col("b")])?
3227        .build()?;
3228
3229        assert_optimized_plan_equal!(
3230            plan,
3231            @r"
3232        Projection: a, b
3233          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
3234        "
3235        )
3236    }
3237
3238    #[test]
3239    fn test_filter_with_alias() -> Result<()> {
3240        // in table scan the true col name is 'test.a',
3241        // but we rename it as 'b', and use col 'b' in filter
3242        // we need rewrite filter col before push down.
3243        let table_scan = test_table_scan()?;
3244        let plan = LogicalPlanBuilder::from(table_scan)
3245            .project(vec![col("a").alias("b"), col("c")])?
3246            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3247            .build()?;
3248
3249        // filter on col b
3250        assert_snapshot!(plan,
3251        @r"
3252        Filter: b > Int64(10) AND test.c > Int64(10)
3253          Projection: test.a AS b, test.c
3254            TableScan: test
3255        ",
3256        );
3257        // rewrite filter col b to test.a
3258        assert_optimized_plan_equal!(
3259            plan,
3260            @r"
3261        Projection: test.a AS b, test.c
3262          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3263        "
3264        )
3265    }
3266
3267    #[test]
3268    fn test_filter_with_alias_2() -> Result<()> {
3269        // in table scan the true col name is 'test.a',
3270        // but we rename it as 'b', and use col 'b' in filter
3271        // we need rewrite filter col before push down.
3272        let table_scan = test_table_scan()?;
3273        let plan = LogicalPlanBuilder::from(table_scan)
3274            .project(vec![col("a").alias("b"), col("c")])?
3275            .project(vec![col("b"), col("c")])?
3276            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
3277            .build()?;
3278
3279        // filter on col b
3280        assert_snapshot!(plan,
3281        @r"
3282        Filter: b > Int64(10) AND test.c > Int64(10)
3283          Projection: b, test.c
3284            Projection: test.a AS b, test.c
3285              TableScan: test
3286        ",
3287        );
3288        // rewrite filter col b to test.a
3289        assert_optimized_plan_equal!(
3290            plan,
3291            @r"
3292        Projection: b, test.c
3293          Projection: test.a AS b, test.c
3294            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3295        "
3296        )
3297    }
3298
3299    #[test]
3300    fn test_filter_with_multi_alias() -> Result<()> {
3301        let table_scan = test_table_scan()?;
3302        let plan = LogicalPlanBuilder::from(table_scan)
3303            .project(vec![col("a").alias("b"), col("c").alias("d")])?
3304            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
3305            .build()?;
3306
3307        // filter on col b and d
3308        assert_snapshot!(plan,
3309        @r"
3310        Filter: b > Int64(10) AND d > Int64(10)
3311          Projection: test.a AS b, test.c AS d
3312            TableScan: test
3313        ",
3314        );
3315        // rewrite filter col b to test.a, col d to test.c
3316        assert_optimized_plan_equal!(
3317            plan,
3318            @r"
3319        Projection: test.a AS b, test.c AS d
3320          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
3321        "
3322        )
3323    }
3324
3325    /// predicate on join key in filter expression should be pushed down to both inputs
3326    #[test]
3327    fn join_filter_with_alias() -> Result<()> {
3328        let table_scan = test_table_scan()?;
3329        let left = LogicalPlanBuilder::from(table_scan)
3330            .project(vec![col("a").alias("c")])?
3331            .build()?;
3332        let right_table_scan = test_table_scan_with_name("test2")?;
3333        let right = LogicalPlanBuilder::from(right_table_scan)
3334            .project(vec![col("b").alias("d")])?
3335            .build()?;
3336        let filter = col("c").gt(lit(1u32));
3337        let plan = LogicalPlanBuilder::from(left)
3338            .join(
3339                right,
3340                JoinType::Inner,
3341                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
3342                Some(filter),
3343            )?
3344            .build()?;
3345
3346        assert_snapshot!(plan,
3347        @r"
3348        Inner Join: c = d Filter: c > UInt32(1)
3349          Projection: test.a AS c
3350            TableScan: test
3351          Projection: test2.b AS d
3352            TableScan: test2
3353        ",
3354        );
3355        // Change filter on col `c`, 'd' to `test.a`, 'test.b'
3356        assert_optimized_plan_equal!(
3357            plan,
3358            @r"
3359        Inner Join: c = d
3360          Projection: test.a AS c
3361            TableScan: test, full_filters=[test.a > UInt32(1)]
3362          Projection: test2.b AS d
3363            TableScan: test2, full_filters=[test2.b > UInt32(1)]
3364        "
3365        )
3366    }
3367
3368    #[test]
3369    fn test_in_filter_with_alias() -> Result<()> {
3370        // in table scan the true col name is 'test.a',
3371        // but we rename it as 'b', and use col 'b' in filter
3372        // we need rewrite filter col before push down.
3373        let table_scan = test_table_scan()?;
3374        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3375        let plan = LogicalPlanBuilder::from(table_scan)
3376            .project(vec![col("a").alias("b"), col("c")])?
3377            .filter(in_list(col("b"), filter_value, false))?
3378            .build()?;
3379
3380        // filter on col b
3381        assert_snapshot!(plan,
3382        @r"
3383        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3384          Projection: test.a AS b, test.c
3385            TableScan: test
3386        ",
3387        );
3388        // rewrite filter col b to test.a
3389        assert_optimized_plan_equal!(
3390            plan,
3391            @r"
3392        Projection: test.a AS b, test.c
3393          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3394        "
3395        )
3396    }
3397
3398    #[test]
3399    fn test_in_filter_with_alias_2() -> Result<()> {
3400        // in table scan the true col name is 'test.a',
3401        // but we rename it as 'b', and use col 'b' in filter
3402        // we need rewrite filter col before push down.
3403        let table_scan = test_table_scan()?;
3404        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
3405        let plan = LogicalPlanBuilder::from(table_scan)
3406            .project(vec![col("a").alias("b"), col("c")])?
3407            .project(vec![col("b"), col("c")])?
3408            .filter(in_list(col("b"), filter_value, false))?
3409            .build()?;
3410
3411        // filter on col b
3412        assert_snapshot!(plan,
3413        @r"
3414        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
3415          Projection: b, test.c
3416            Projection: test.a AS b, test.c
3417              TableScan: test
3418        ",
3419        );
3420        // rewrite filter col b to test.a
3421        assert_optimized_plan_equal!(
3422            plan,
3423            @r"
3424        Projection: b, test.c
3425          Projection: test.a AS b, test.c
3426            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
3427        "
3428        )
3429    }
3430
3431    #[test]
3432    fn test_in_subquery_with_alias() -> Result<()> {
3433        // in table scan the true col name is 'test.a',
3434        // but we rename it as 'b', and use col 'b' in subquery filter
3435        let table_scan = test_table_scan()?;
3436        let table_scan_sq = test_table_scan_with_name("sq")?;
3437        let subplan = Arc::new(
3438            LogicalPlanBuilder::from(table_scan_sq)
3439                .project(vec![col("c")])?
3440                .build()?,
3441        );
3442        let plan = LogicalPlanBuilder::from(table_scan)
3443            .project(vec![col("a").alias("b"), col("c")])?
3444            .filter(in_subquery(col("b"), subplan))?
3445            .build()?;
3446
3447        // filter on col b in subquery
3448        assert_snapshot!(plan,
3449        @r"
3450        Filter: b IN (<subquery>)
3451          Subquery:
3452            Projection: sq.c
3453              TableScan: sq
3454          Projection: test.a AS b, test.c
3455            TableScan: test
3456        ",
3457        );
3458        // rewrite filter col b to test.a
3459        assert_optimized_plan_equal!(
3460            plan,
3461            @r"
3462        Projection: test.a AS b, test.c
3463          TableScan: test, full_filters=[test.a IN (<subquery>)]
3464            Subquery:
3465              Projection: sq.c
3466                TableScan: sq
3467        "
3468        )
3469    }
3470
3471    #[test]
3472    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
3473        // SELECT a FROM (SELECT 1 AS a) b WHERE b.a = 1
3474        let plan = LogicalPlanBuilder::empty(true)
3475            .project(vec![lit(0i64).alias("a")])?
3476            .alias("b")?
3477            .project(vec![col("b.a")])?
3478            .alias("b")?
3479            .filter(col("b.a").eq(lit(1i64)))?
3480            .project(vec![col("b.a")])?
3481            .build()?;
3482
3483        assert_snapshot!(plan,
3484        @r"
3485        Projection: b.a
3486          Filter: b.a = Int64(1)
3487            SubqueryAlias: b
3488              Projection: b.a
3489                SubqueryAlias: b
3490                  Projection: Int64(0) AS a
3491                    EmptyRelation: rows=1
3492        ",
3493        );
3494        // Ensure that the predicate without any columns (0 = 1) is
3495        // still there.
3496        assert_optimized_plan_equal!(
3497            plan,
3498            @r"
3499        Projection: b.a
3500          SubqueryAlias: b
3501            Projection: b.a
3502              SubqueryAlias: b
3503                Projection: Int64(0) AS a
3504                  Filter: Int64(0) = Int64(1)
3505                    EmptyRelation: rows=1
3506        "
3507        )
3508    }
3509
3510    #[test]
3511    fn test_crossjoin_with_or_clause() -> Result<()> {
3512        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
3513        let table_scan = test_table_scan()?;
3514        let left = LogicalPlanBuilder::from(table_scan)
3515            .project(vec![col("a"), col("b"), col("c")])?
3516            .build()?;
3517        let right_table_scan = test_table_scan_with_name("test1")?;
3518        let right = LogicalPlanBuilder::from(right_table_scan)
3519            .project(vec![col("a").alias("d"), col("a").alias("e")])?
3520            .build()?;
3521        let filter = or(
3522            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
3523            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
3524        );
3525        let plan = LogicalPlanBuilder::from(left)
3526            .cross_join(right)?
3527            .filter(filter)?
3528            .build()?;
3529
3530        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
3531        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3532          Projection: test.a, test.b, test.c
3533            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3534          Projection: test1.a AS d, test1.a AS e
3535            TableScan: test1
3536        ")?;
3537
3538        // Originally global state which can help to avoid duplicate Filters been generated and pushed down.
3539        // Now the global state is removed. Need to double confirm that avoid duplicate Filters.
3540        let optimized_plan = PushDownFilter::new()
3541            .rewrite(plan, &OptimizerContext::new())
3542            .expect("failed to optimize plan")
3543            .data;
3544        assert_optimized_plan_equal!(
3545            optimized_plan,
3546            @r"
3547        Inner Join:  Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
3548          Projection: test.a, test.b, test.c
3549            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
3550          Projection: test1.a AS d, test1.a AS e
3551            TableScan: test1
3552        "
3553        )
3554    }
3555
3556    #[test]
3557    fn left_semi_join() -> Result<()> {
3558        let left = test_table_scan_with_name("test1")?;
3559        let right_table_scan = test_table_scan_with_name("test2")?;
3560        let right = LogicalPlanBuilder::from(right_table_scan)
3561            .project(vec![col("a"), col("b")])?
3562            .build()?;
3563        let plan = LogicalPlanBuilder::from(left)
3564            .join(
3565                right,
3566                JoinType::LeftSemi,
3567                (
3568                    vec![Column::from_qualified_name("test1.a")],
3569                    vec![Column::from_qualified_name("test2.a")],
3570                ),
3571                None,
3572            )?
3573            .filter(col("test2.a").lt_eq(lit(1i64)))?
3574            .build()?;
3575
3576        // not part of the test, just good to know:
3577        assert_snapshot!(plan,
3578        @r"
3579        Filter: test2.a <= Int64(1)
3580          LeftSemi Join: test1.a = test2.a
3581            TableScan: test1
3582            Projection: test2.a, test2.b
3583              TableScan: test2
3584        ",
3585        );
3586        // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side.
3587        assert_optimized_plan_equal!(
3588            plan,
3589            @r"
3590        Filter: test2.a <= Int64(1)
3591          LeftSemi Join: test1.a = test2.a
3592            TableScan: test1, full_filters=[test1.a <= Int64(1)]
3593            Projection: test2.a, test2.b
3594              TableScan: test2
3595        "
3596        )
3597    }
3598
3599    #[test]
3600    fn left_semi_join_with_filters() -> Result<()> {
3601        let left = test_table_scan_with_name("test1")?;
3602        let right_table_scan = test_table_scan_with_name("test2")?;
3603        let right = LogicalPlanBuilder::from(right_table_scan)
3604            .project(vec![col("a"), col("b")])?
3605            .build()?;
3606        let plan = LogicalPlanBuilder::from(left)
3607            .join(
3608                right,
3609                JoinType::LeftSemi,
3610                (
3611                    vec![Column::from_qualified_name("test1.a")],
3612                    vec![Column::from_qualified_name("test2.a")],
3613                ),
3614                Some(
3615                    col("test1.b")
3616                        .gt(lit(1u32))
3617                        .and(col("test2.b").gt(lit(2u32))),
3618                ),
3619            )?
3620            .build()?;
3621
3622        // not part of the test, just good to know:
3623        assert_snapshot!(plan,
3624        @r"
3625        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3626          TableScan: test1
3627          Projection: test2.a, test2.b
3628            TableScan: test2
3629        ",
3630        );
3631        // Both side will be pushed down.
3632        assert_optimized_plan_equal!(
3633            plan,
3634            @r"
3635        LeftSemi Join: test1.a = test2.a
3636          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3637          Projection: test2.a, test2.b
3638            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3639        "
3640        )
3641    }
3642
3643    #[test]
3644    fn right_semi_join() -> Result<()> {
3645        let left = test_table_scan_with_name("test1")?;
3646        let right_table_scan = test_table_scan_with_name("test2")?;
3647        let right = LogicalPlanBuilder::from(right_table_scan)
3648            .project(vec![col("a"), col("b")])?
3649            .build()?;
3650        let plan = LogicalPlanBuilder::from(left)
3651            .join(
3652                right,
3653                JoinType::RightSemi,
3654                (
3655                    vec![Column::from_qualified_name("test1.a")],
3656                    vec![Column::from_qualified_name("test2.a")],
3657                ),
3658                None,
3659            )?
3660            .filter(col("test1.a").lt_eq(lit(1i64)))?
3661            .build()?;
3662
3663        // not part of the test, just good to know:
3664        assert_snapshot!(plan,
3665        @r"
3666        Filter: test1.a <= Int64(1)
3667          RightSemi Join: test1.a = test2.a
3668            TableScan: test1
3669            Projection: test2.a, test2.b
3670              TableScan: test2
3671        ",
3672        );
3673        // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side.
3674        assert_optimized_plan_equal!(
3675            plan,
3676            @r"
3677        Filter: test1.a <= Int64(1)
3678          RightSemi Join: test1.a = test2.a
3679            TableScan: test1
3680            Projection: test2.a, test2.b
3681              TableScan: test2, full_filters=[test2.a <= Int64(1)]
3682        "
3683        )
3684    }
3685
3686    #[test]
3687    fn right_semi_join_with_filters() -> Result<()> {
3688        let left = test_table_scan_with_name("test1")?;
3689        let right_table_scan = test_table_scan_with_name("test2")?;
3690        let right = LogicalPlanBuilder::from(right_table_scan)
3691            .project(vec![col("a"), col("b")])?
3692            .build()?;
3693        let plan = LogicalPlanBuilder::from(left)
3694            .join(
3695                right,
3696                JoinType::RightSemi,
3697                (
3698                    vec![Column::from_qualified_name("test1.a")],
3699                    vec![Column::from_qualified_name("test2.a")],
3700                ),
3701                Some(
3702                    col("test1.b")
3703                        .gt(lit(1u32))
3704                        .and(col("test2.b").gt(lit(2u32))),
3705                ),
3706            )?
3707            .build()?;
3708
3709        // not part of the test, just good to know:
3710        assert_snapshot!(plan,
3711        @r"
3712        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3713          TableScan: test1
3714          Projection: test2.a, test2.b
3715            TableScan: test2
3716        ",
3717        );
3718        // Both side will be pushed down.
3719        assert_optimized_plan_equal!(
3720            plan,
3721            @r"
3722        RightSemi Join: test1.a = test2.a
3723          TableScan: test1, full_filters=[test1.b > UInt32(1)]
3724          Projection: test2.a, test2.b
3725            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3726        "
3727        )
3728    }
3729
3730    #[test]
3731    fn left_anti_join() -> Result<()> {
3732        let table_scan = test_table_scan_with_name("test1")?;
3733        let left = LogicalPlanBuilder::from(table_scan)
3734            .project(vec![col("a"), col("b")])?
3735            .build()?;
3736        let right_table_scan = test_table_scan_with_name("test2")?;
3737        let right = LogicalPlanBuilder::from(right_table_scan)
3738            .project(vec![col("a"), col("b")])?
3739            .build()?;
3740        let plan = LogicalPlanBuilder::from(left)
3741            .join(
3742                right,
3743                JoinType::LeftAnti,
3744                (
3745                    vec![Column::from_qualified_name("test1.a")],
3746                    vec![Column::from_qualified_name("test2.a")],
3747                ),
3748                None,
3749            )?
3750            .filter(col("test2.a").gt(lit(2u32)))?
3751            .build()?;
3752
3753        // not part of the test, just good to know:
3754        assert_snapshot!(plan,
3755        @r"
3756        Filter: test2.a > UInt32(2)
3757          LeftAnti Join: test1.a = test2.a
3758            Projection: test1.a, test1.b
3759              TableScan: test1
3760            Projection: test2.a, test2.b
3761              TableScan: test2
3762        ",
3763        );
3764        // For left anti, filter of the right side filter can be pushed down.
3765        assert_optimized_plan_equal!(
3766            plan,
3767            @r"
3768        Filter: test2.a > UInt32(2)
3769          LeftAnti Join: test1.a = test2.a
3770            Projection: test1.a, test1.b
3771              TableScan: test1, full_filters=[test1.a > UInt32(2)]
3772            Projection: test2.a, test2.b
3773              TableScan: test2
3774        "
3775        )
3776    }
3777
3778    #[test]
3779    fn left_anti_join_with_filters() -> Result<()> {
3780        let table_scan = test_table_scan_with_name("test1")?;
3781        let left = LogicalPlanBuilder::from(table_scan)
3782            .project(vec![col("a"), col("b")])?
3783            .build()?;
3784        let right_table_scan = test_table_scan_with_name("test2")?;
3785        let right = LogicalPlanBuilder::from(right_table_scan)
3786            .project(vec![col("a"), col("b")])?
3787            .build()?;
3788        let plan = LogicalPlanBuilder::from(left)
3789            .join(
3790                right,
3791                JoinType::LeftAnti,
3792                (
3793                    vec![Column::from_qualified_name("test1.a")],
3794                    vec![Column::from_qualified_name("test2.a")],
3795                ),
3796                Some(
3797                    col("test1.b")
3798                        .gt(lit(1u32))
3799                        .and(col("test2.b").gt(lit(2u32))),
3800                ),
3801            )?
3802            .build()?;
3803
3804        // not part of the test, just good to know:
3805        assert_snapshot!(plan,
3806        @r"
3807        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3808          Projection: test1.a, test1.b
3809            TableScan: test1
3810          Projection: test2.a, test2.b
3811            TableScan: test2
3812        ",
3813        );
3814        // For left anti, filter of the right side filter can be pushed down.
3815        assert_optimized_plan_equal!(
3816            plan,
3817            @r"
3818        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
3819          Projection: test1.a, test1.b
3820            TableScan: test1
3821          Projection: test2.a, test2.b
3822            TableScan: test2, full_filters=[test2.b > UInt32(2)]
3823        "
3824        )
3825    }
3826
3827    #[test]
3828    fn right_anti_join() -> Result<()> {
3829        let table_scan = test_table_scan_with_name("test1")?;
3830        let left = LogicalPlanBuilder::from(table_scan)
3831            .project(vec![col("a"), col("b")])?
3832            .build()?;
3833        let right_table_scan = test_table_scan_with_name("test2")?;
3834        let right = LogicalPlanBuilder::from(right_table_scan)
3835            .project(vec![col("a"), col("b")])?
3836            .build()?;
3837        let plan = LogicalPlanBuilder::from(left)
3838            .join(
3839                right,
3840                JoinType::RightAnti,
3841                (
3842                    vec![Column::from_qualified_name("test1.a")],
3843                    vec![Column::from_qualified_name("test2.a")],
3844                ),
3845                None,
3846            )?
3847            .filter(col("test1.a").gt(lit(2u32)))?
3848            .build()?;
3849
3850        // not part of the test, just good to know:
3851        assert_snapshot!(plan,
3852        @r"
3853        Filter: test1.a > UInt32(2)
3854          RightAnti Join: test1.a = test2.a
3855            Projection: test1.a, test1.b
3856              TableScan: test1
3857            Projection: test2.a, test2.b
3858              TableScan: test2
3859        ",
3860        );
3861        // For right anti, filter of the left side can be pushed down.
3862        assert_optimized_plan_equal!(
3863            plan,
3864            @r"
3865        Filter: test1.a > UInt32(2)
3866          RightAnti Join: test1.a = test2.a
3867            Projection: test1.a, test1.b
3868              TableScan: test1
3869            Projection: test2.a, test2.b
3870              TableScan: test2, full_filters=[test2.a > UInt32(2)]
3871        "
3872        )
3873    }
3874
3875    #[test]
3876    fn right_anti_join_with_filters() -> Result<()> {
3877        let table_scan = test_table_scan_with_name("test1")?;
3878        let left = LogicalPlanBuilder::from(table_scan)
3879            .project(vec![col("a"), col("b")])?
3880            .build()?;
3881        let right_table_scan = test_table_scan_with_name("test2")?;
3882        let right = LogicalPlanBuilder::from(right_table_scan)
3883            .project(vec![col("a"), col("b")])?
3884            .build()?;
3885        let plan = LogicalPlanBuilder::from(left)
3886            .join(
3887                right,
3888                JoinType::RightAnti,
3889                (
3890                    vec![Column::from_qualified_name("test1.a")],
3891                    vec![Column::from_qualified_name("test2.a")],
3892                ),
3893                Some(
3894                    col("test1.b")
3895                        .gt(lit(1u32))
3896                        .and(col("test2.b").gt(lit(2u32))),
3897                ),
3898            )?
3899            .build()?;
3900
3901        // not part of the test, just good to know:
3902        assert_snapshot!(plan,
3903        @r"
3904        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
3905          Projection: test1.a, test1.b
3906            TableScan: test1
3907          Projection: test2.a, test2.b
3908            TableScan: test2
3909        ",
3910        );
3911        // For right anti, filter of the left side can be pushed down.
3912        assert_optimized_plan_equal!(
3913            plan,
3914            @r"
3915        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
3916          Projection: test1.a, test1.b
3917            TableScan: test1, full_filters=[test1.b > UInt32(1)]
3918          Projection: test2.a, test2.b
3919            TableScan: test2
3920        "
3921        )
3922    }
3923
3924    #[derive(Debug, PartialEq, Eq, Hash)]
3925    struct TestScalarUDF {
3926        signature: Signature,
3927    }
3928
3929    impl ScalarUDFImpl for TestScalarUDF {
3930        fn as_any(&self) -> &dyn Any {
3931            self
3932        }
3933        fn name(&self) -> &str {
3934            "TestScalarUDF"
3935        }
3936
3937        fn signature(&self) -> &Signature {
3938            &self.signature
3939        }
3940
3941        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3942            Ok(DataType::Int32)
3943        }
3944
3945        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3946            Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3947        }
3948    }
3949
3950    #[test]
3951    fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
3952        // SELECT t.a, t.r FROM (SELECT a, sum(b),  TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
3953        let table_scan = test_table_scan_with_name("test1")?;
3954        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3955            signature: Signature::exact(vec![], Volatility::Volatile),
3956        });
3957        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3958
3959        let plan = LogicalPlanBuilder::from(table_scan)
3960            .aggregate(vec![col("a")], vec![sum(col("b"))])?
3961            .project(vec![col("a"), sum(col("b")), add(expr, lit(1)).alias("r")])?
3962            .alias("t")?
3963            .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
3964            .project(vec![col("t.a"), col("t.r")])?
3965            .build()?;
3966
3967        assert_snapshot!(plan,
3968        @r"
3969        Projection: t.a, t.r
3970          Filter: t.a > Int32(5) AND t.r > Float64(0.5)
3971            SubqueryAlias: t
3972              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3973                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3974                  TableScan: test1
3975        ",
3976        );
3977        assert_optimized_plan_equal!(
3978            plan,
3979            @r"
3980        Projection: t.a, t.r
3981          SubqueryAlias: t
3982            Filter: r > Float64(0.5)
3983              Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
3984                Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
3985                  TableScan: test1, full_filters=[test1.a > Int32(5)]
3986        "
3987        )
3988    }
3989
3990    #[test]
3991    fn test_push_down_volatile_function_in_join() -> Result<()> {
3992        // SELECT t.a, t.r FROM (SELECT test1.a AS a, TestScalarUDF() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
3993        let table_scan = test_table_scan_with_name("test1")?;
3994        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
3995            signature: Signature::exact(vec![], Volatility::Volatile),
3996        });
3997        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
3998        let left = LogicalPlanBuilder::from(table_scan).build()?;
3999        let right_table_scan = test_table_scan_with_name("test2")?;
4000        let right = LogicalPlanBuilder::from(right_table_scan).build()?;
4001        let plan = LogicalPlanBuilder::from(left)
4002            .join(
4003                right,
4004                JoinType::Inner,
4005                (
4006                    vec![Column::from_qualified_name("test1.a")],
4007                    vec![Column::from_qualified_name("test2.a")],
4008                ),
4009                None,
4010            )?
4011            .project(vec![col("test1.a").alias("a"), expr.alias("r")])?
4012            .alias("t")?
4013            .filter(col("t.r").gt(lit(0.8)))?
4014            .project(vec![col("t.a"), col("t.r")])?
4015            .build()?;
4016
4017        assert_snapshot!(plan,
4018        @r"
4019        Projection: t.a, t.r
4020          Filter: t.r > Float64(0.8)
4021            SubqueryAlias: t
4022              Projection: test1.a AS a, TestScalarUDF() AS r
4023                Inner Join: test1.a = test2.a
4024                  TableScan: test1
4025                  TableScan: test2
4026        ",
4027        );
4028        assert_optimized_plan_equal!(
4029            plan,
4030            @r"
4031        Projection: t.a, t.r
4032          SubqueryAlias: t
4033            Filter: r > Float64(0.8)
4034              Projection: test1.a AS a, TestScalarUDF() AS r
4035                Inner Join: test1.a = test2.a
4036                  TableScan: test1
4037                  TableScan: test2
4038        "
4039        )
4040    }
4041
4042    #[test]
4043    fn test_push_down_volatile_table_scan() -> Result<()> {
4044        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1;
4045        let table_scan = test_table_scan()?;
4046        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4047            signature: Signature::exact(vec![], Volatility::Volatile),
4048        });
4049        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4050        let plan = LogicalPlanBuilder::from(table_scan)
4051            .project(vec![col("a"), col("b")])?
4052            .filter(expr.gt(lit(0.1)))?
4053            .build()?;
4054
4055        assert_snapshot!(plan,
4056        @r"
4057        Filter: TestScalarUDF() > Float64(0.1)
4058          Projection: test.a, test.b
4059            TableScan: test
4060        ",
4061        );
4062        assert_optimized_plan_equal!(
4063            plan,
4064            @r"
4065        Projection: test.a, test.b
4066          Filter: TestScalarUDF() > Float64(0.1)
4067            TableScan: test
4068        "
4069        )
4070    }
4071
4072    #[test]
4073    fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
4074        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4075        let table_scan = test_table_scan()?;
4076        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4077            signature: Signature::exact(vec![], Volatility::Volatile),
4078        });
4079        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4080        let plan = LogicalPlanBuilder::from(table_scan)
4081            .project(vec![col("a"), col("b")])?
4082            .filter(
4083                expr.gt(lit(0.1))
4084                    .and(col("t.a").gt(lit(5)))
4085                    .and(col("t.b").gt(lit(10))),
4086            )?
4087            .build()?;
4088
4089        assert_snapshot!(plan,
4090        @r"
4091        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4092          Projection: test.a, test.b
4093            TableScan: test
4094        ",
4095        );
4096        assert_optimized_plan_equal!(
4097            plan,
4098            @r"
4099        Projection: test.a, test.b
4100          Filter: TestScalarUDF() > Float64(0.1)
4101            TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]
4102        "
4103        )
4104    }
4105
4106    #[test]
4107    fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
4108        // SELECT test.a, test.b FROM test as t WHERE TestScalarUDF() > 0.1 and test.a > 5 and test.b > 10;
4109        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
4110            signature: Signature::exact(vec![], Volatility::Volatile),
4111        });
4112        let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
4113        let plan = table_scan_with_pushdown_provider_builder(
4114            TableProviderFilterPushDown::Unsupported,
4115            vec![],
4116            None,
4117        )?
4118        .project(vec![col("a"), col("b")])?
4119        .filter(
4120            expr.gt(lit(0.1))
4121                .and(col("t.a").gt(lit(5)))
4122                .and(col("t.b").gt(lit(10))),
4123        )?
4124        .build()?;
4125
4126        assert_snapshot!(plan,
4127        @r"
4128        Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)
4129          Projection: a, b
4130            TableScan: test
4131        ",
4132        );
4133        assert_optimized_plan_equal!(
4134            plan,
4135            @r"
4136        Projection: a, b
4137          Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)
4138            TableScan: test
4139        "
4140        )
4141    }
4142
4143    #[test]
4144    fn test_push_down_filter_to_user_defined_node() -> Result<()> {
4145        // Define a custom user-defined logical node
4146        #[derive(Debug, Hash, Eq, PartialEq)]
4147        struct TestUserNode {
4148            schema: DFSchemaRef,
4149        }
4150
4151        impl PartialOrd for TestUserNode {
4152            fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
4153                None
4154            }
4155        }
4156
4157        impl TestUserNode {
4158            fn new() -> Self {
4159                let schema = Arc::new(
4160                    DFSchema::new_with_metadata(
4161                        vec![(None, Field::new("a", DataType::Int64, false).into())],
4162                        Default::default(),
4163                    )
4164                    .unwrap(),
4165                );
4166
4167                Self { schema }
4168            }
4169        }
4170
4171        impl UserDefinedLogicalNodeCore for TestUserNode {
4172            fn name(&self) -> &str {
4173                "test_node"
4174            }
4175
4176            fn inputs(&self) -> Vec<&LogicalPlan> {
4177                vec![]
4178            }
4179
4180            fn schema(&self) -> &DFSchemaRef {
4181                &self.schema
4182            }
4183
4184            fn expressions(&self) -> Vec<Expr> {
4185                vec![]
4186            }
4187
4188            fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
4189                write!(f, "TestUserNode")
4190            }
4191
4192            fn with_exprs_and_inputs(
4193                &self,
4194                exprs: Vec<Expr>,
4195                inputs: Vec<LogicalPlan>,
4196            ) -> Result<Self> {
4197                assert!(exprs.is_empty());
4198                assert!(inputs.is_empty());
4199                Ok(Self {
4200                    schema: Arc::clone(&self.schema),
4201                })
4202            }
4203        }
4204
4205        // Create a node and build a plan with a filter
4206        let node = LogicalPlan::Extension(Extension {
4207            node: Arc::new(TestUserNode::new()),
4208        });
4209
4210        let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;
4211
4212        // Check the original plan format (not part of the test assertions)
4213        assert_snapshot!(plan,
4214        @r"
4215        Filter: Boolean(false)
4216          TestUserNode
4217        ",
4218        );
4219        // Check that the filter is pushed down to the user-defined node
4220        assert_optimized_plan_equal!(
4221            plan,
4222            @r"
4223        Filter: Boolean(false)
4224          TestUserNode
4225        "
4226        )
4227    }
4228}