datafusion_physical_optimizer/
projection_pushdown.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
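//! This file implements the `ProjectionPushdown` physical optimization rule.
//! The function [`remove_unnecessary_projections`] tries to push down all
//! projections one by one if the operator below is amenable to this. If a
//! projection reaches a source, it can even disappear from the plan entirely.
//!
//! Before that, the rule extracts sub-expressions of nested loop join filters
//! that reference only one join input and computes them in a projection below
//! the join. They are then evaluated once per input row instead of once per
//! candidate pair of rows.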
22
23use crate::PhysicalOptimizerRule;
24use arrow::datatypes::{Fields, Schema, SchemaRef};
25use datafusion_common::alias::AliasGenerator;
26use datafusion_physical_expr::PhysicalExprExt;
27use std::collections::HashSet;
28use std::sync::Arc;
29
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::{
32    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
33};
34use datafusion_common::{JoinSide, JoinType, Result};
35use datafusion_physical_expr::expressions::Column;
36use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
37use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
38use datafusion_physical_plan::joins::NestedLoopJoinExec;
39use datafusion_physical_plan::projection::{
40    remove_unnecessary_projections, ProjectionExec,
41};
42use datafusion_physical_plan::ExecutionPlan;
43
/// This rule inspects `ProjectionExec` nodes in the given physical plan and tries to
/// remove them or swap them with their child.
///
/// Furthermore, tries to push down projections from nested loop join filters that only depend on
/// one side of the join. By pushing these projections down, functions that only depend on one side
/// of the join no longer need to be evaluated for the cartesian product of the two sides.
#[derive(Default, Debug)]
pub struct ProjectionPushdown {}

impl ProjectionPushdown {
    /// Creates a new [`ProjectionPushdown`] rule; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self {}
    }
}
59
60impl PhysicalOptimizerRule for ProjectionPushdown {
61    fn optimize(
62        &self,
63        plan: Arc<dyn ExecutionPlan>,
64        _config: &ConfigOptions,
65    ) -> Result<Arc<dyn ExecutionPlan>> {
66        let alias_generator = AliasGenerator::new();
67        let plan = plan
68            .transform_up(|plan| {
69                match plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
70                    None => Ok(Transformed::no(plan)),
71                    Some(hash_join) => try_push_down_join_filter(
72                        Arc::clone(&plan),
73                        hash_join,
74                        &alias_generator,
75                    ),
76                }
77            })
78            .map(|t| t.data)?;
79
80        plan.transform_down(remove_unnecessary_projections).data()
81    }
82
83    fn name(&self) -> &str {
84        "ProjectionPushdown"
85    }
86
87    fn schema_check(&self) -> bool {
88        true
89    }
90}
91
92/// Tries to push down parts of the filter.
93///
94/// See [JoinFilterRewriter] for details.
95fn try_push_down_join_filter(
96    original_plan: Arc<dyn ExecutionPlan>,
97    join: &NestedLoopJoinExec,
98    alias_generator: &AliasGenerator,
99) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
100    // Mark joins are currently not supported.
101    if matches!(join.join_type(), JoinType::LeftMark | JoinType::RightMark) {
102        return Ok(Transformed::no(original_plan));
103    }
104
105    let projections = join.projection();
106    let Some(filter) = join.filter() else {
107        return Ok(Transformed::no(original_plan));
108    };
109
110    let original_lhs_length = join.left().schema().fields().len();
111    let original_rhs_length = join.right().schema().fields().len();
112
113    let lhs_rewrite = try_push_down_projection(
114        Arc::clone(&join.right().schema()),
115        Arc::clone(join.left()),
116        JoinSide::Left,
117        filter.clone(),
118        alias_generator,
119    )?;
120    let rhs_rewrite = try_push_down_projection(
121        Arc::clone(&lhs_rewrite.data.0.schema()),
122        Arc::clone(join.right()),
123        JoinSide::Right,
124        lhs_rewrite.data.1,
125        alias_generator,
126    )?;
127    if !lhs_rewrite.transformed && !rhs_rewrite.transformed {
128        return Ok(Transformed::no(original_plan));
129    }
130
131    let join_filter = minimize_join_filter(
132        Arc::clone(rhs_rewrite.data.1.expression()),
133        rhs_rewrite.data.1.column_indices().to_vec(),
134        lhs_rewrite.data.0.schema().as_ref(),
135        rhs_rewrite.data.0.schema().as_ref(),
136    );
137
138    let new_lhs_length = lhs_rewrite.data.0.schema().fields.len();
139    let projections = match projections {
140        None => match join.join_type() {
141            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
142                // Build projections that ignore the newly projected columns.
143                let mut projections = Vec::new();
144                projections.extend(0..original_lhs_length);
145                projections.extend(new_lhs_length..new_lhs_length + original_rhs_length);
146                projections
147            }
148            JoinType::LeftSemi | JoinType::LeftAnti => {
149                // Only return original left columns
150                let mut projections = Vec::new();
151                projections.extend(0..original_lhs_length);
152                projections
153            }
154            JoinType::RightSemi | JoinType::RightAnti => {
155                // Only return original right columns
156                let mut projections = Vec::new();
157                projections.extend(0..original_rhs_length);
158                projections
159            }
160            _ => unreachable!("Unsupported join type"),
161        },
162        Some(projections) => {
163            let rhs_offset = new_lhs_length - original_lhs_length;
164            projections
165                .iter()
166                .map(|idx| {
167                    if *idx >= original_lhs_length {
168                        idx + rhs_offset
169                    } else {
170                        *idx
171                    }
172                })
173                .collect()
174        }
175    };
176
177    Ok(Transformed::yes(Arc::new(NestedLoopJoinExec::try_new(
178        lhs_rewrite.data.0,
179        rhs_rewrite.data.0,
180        Some(join_filter),
181        join.join_type(),
182        Some(projections),
183    )?)))
184}
185
186/// Tries to push down parts of `expr` into the `join_side`.
187fn try_push_down_projection(
188    other_schema: SchemaRef,
189    plan: Arc<dyn ExecutionPlan>,
190    join_side: JoinSide,
191    join_filter: JoinFilter,
192    alias_generator: &AliasGenerator,
193) -> Result<Transformed<(Arc<dyn ExecutionPlan>, JoinFilter)>> {
194    let expr = Arc::clone(join_filter.expression());
195    let original_plan_schema = plan.schema();
196    let mut rewriter = JoinFilterRewriter::new(
197        join_side,
198        original_plan_schema.as_ref(),
199        join_filter.column_indices().to_vec(),
200        alias_generator,
201    );
202    let new_expr = rewriter.rewrite(expr)?;
203
204    if new_expr.transformed {
205        let new_join_side =
206            ProjectionExec::try_new(rewriter.join_side_projections, plan)?;
207        let new_schema = Arc::clone(&new_join_side.schema());
208
209        let (lhs_schema, rhs_schema) = match join_side {
210            JoinSide::Left => (new_schema, other_schema),
211            JoinSide::Right => (other_schema, new_schema),
212            JoinSide::None => unreachable!("Mark join not supported"),
213        };
214        let intermediate_schema = rewriter
215            .intermediate_column_indices
216            .iter()
217            .map(|ci| match ci.side {
218                JoinSide::Left => Arc::clone(&lhs_schema.fields[ci.index]),
219                JoinSide::Right => Arc::clone(&rhs_schema.fields[ci.index]),
220                JoinSide::None => unreachable!("Mark join not supported"),
221            })
222            .collect::<Fields>();
223
224        let join_filter = JoinFilter::new(
225            new_expr.data,
226            rewriter.intermediate_column_indices,
227            Arc::new(Schema::new(intermediate_schema)),
228        );
229        Ok(Transformed::yes((Arc::new(new_join_side), join_filter)))
230    } else {
231        Ok(Transformed::no((plan, join_filter)))
232    }
233}
234
235/// Creates a new [JoinFilter] and tries to minimize the internal schema.
236///
237/// This could eliminate some columns that were only part of a computation that has been pushed
238/// down. As this computation is now materialized on one side of the join, the original input
239/// columns are not needed anymore.
240fn minimize_join_filter(
241    expr: Arc<dyn PhysicalExpr>,
242    old_column_indices: Vec<ColumnIndex>,
243    lhs_schema: &Schema,
244    rhs_schema: &Schema,
245) -> JoinFilter {
246    let mut used_columns = HashSet::new();
247    expr.apply_with_lambdas_params(|expr, lambdas_params| {
248        if let Some(col) = expr.as_any().downcast_ref::<Column>() {
249            if !lambdas_params.contains(col.name()) {
250                used_columns.insert(col.index());
251            }
252        }
253        Ok(TreeNodeRecursion::Continue)
254    })
255    .expect("Closure cannot fail");
256
257    let new_column_indices = old_column_indices
258        .iter()
259        .enumerate()
260        .filter(|(idx, _)| used_columns.contains(idx))
261        .map(|(_, ci)| ci.clone())
262        .collect::<Vec<_>>();
263    let fields = new_column_indices
264        .iter()
265        .map(|ci| match ci.side {
266            JoinSide::Left => lhs_schema.field(ci.index).clone(),
267            JoinSide::Right => rhs_schema.field(ci.index).clone(),
268            JoinSide::None => unreachable!("Mark join not supported"),
269        })
270        .collect::<Fields>();
271
272    let final_expr = expr
273        .transform_up_with_lambdas_params(|expr, lambdas_params| {
274            match expr.as_any().downcast_ref::<Column>() {
275                Some(column) if !lambdas_params.contains(column.name()) => {
276                    let new_idx = used_columns
277                        .iter()
278                        .filter(|idx| **idx < column.index())
279                        .count();
280                    let new_column = Column::new(column.name(), new_idx);
281                    Ok(Transformed::yes(
282                        Arc::new(new_column) as Arc<dyn PhysicalExpr>
283                    ))
284                }
285                _ => Ok(Transformed::no(expr)),
286            }
287        })
288        .expect("Closure cannot fail");
289
290    JoinFilter::new(
291        final_expr.data,
292        new_column_indices,
293        Arc::new(Schema::new(fields)),
294    )
295}
296
297/// Implements the push-down machinery.
298///
299/// The rewriter starts at the top of the filter expression and traverses the expression tree. For
300/// each (sub-)expression, the rewriter checks whether it only refers to one side of the join. If
301/// this is never the case, no subexpressions of the filter can be pushed down. If there is a
302/// subexpression that can be computed using only one side of the join, the entire subexpression is
303/// pushed down to the join side.
304struct JoinFilterRewriter<'a> {
305    join_side: JoinSide,
306    join_side_schema: &'a Schema,
307    join_side_projections: Vec<(Arc<dyn PhysicalExpr>, String)>,
308    intermediate_column_indices: Vec<ColumnIndex>,
309    alias_generator: &'a AliasGenerator,
310}
311
312impl<'a> JoinFilterRewriter<'a> {
313    /// Creates a new [JoinFilterRewriter].
314    fn new(
315        join_side: JoinSide,
316        join_side_schema: &'a Schema,
317        column_indices: Vec<ColumnIndex>,
318        alias_generator: &'a AliasGenerator,
319    ) -> Self {
320        let projections = join_side_schema
321            .fields()
322            .iter()
323            .enumerate()
324            .map(|(idx, field)| {
325                (
326                    Arc::new(Column::new(field.name(), idx)) as Arc<dyn PhysicalExpr>,
327                    field.name().to_string(),
328                )
329            })
330            .collect();
331
332        Self {
333            join_side,
334            join_side_schema,
335            join_side_projections: projections,
336            intermediate_column_indices: column_indices,
337            alias_generator,
338        }
339    }
340
341    /// Executes the push-down machinery on `expr`.
342    ///
343    /// See the [JoinFilterRewriter] for further information.
344    fn rewrite(
345        &mut self,
346        expr: Arc<dyn PhysicalExpr>,
347    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
348        let depends_on_this_side = self.depends_on_join_side(&expr, self.join_side)?;
349        // We don't push down things that do not depend on this side (other side or no side).
350        if !depends_on_this_side {
351            return Ok(Transformed::no(expr));
352        }
353
354        // Recurse if there is a dependency to both sides or if the entire expression is volatile.
355        let depends_on_other_side =
356            self.depends_on_join_side(&expr, self.join_side.negate())?;
357        let is_volatile = is_volatile_expression_tree(expr.as_ref());
358        if depends_on_other_side || is_volatile {
359            return expr.map_children(|expr| self.rewrite(expr));
360        }
361
362        // There is only a dependency on this side.
363
364        // If this expression has no children, we do not push down, as it should already be a column
365        // reference.
366        if expr.children().is_empty() {
367            return Ok(Transformed::no(expr));
368        }
369
370        // Otherwise, we push down a projection.
371        let alias = self.alias_generator.next("join_proj_push_down");
372        let idx = self.create_new_column(alias.clone(), expr)?;
373
374        Ok(Transformed::yes(
375            Arc::new(Column::new(&alias, idx)) as Arc<dyn PhysicalExpr>
376        ))
377    }
378
379    /// Creates a new column in the current join side.
380    fn create_new_column(
381        &mut self,
382        name: String,
383        expr: Arc<dyn PhysicalExpr>,
384    ) -> Result<usize> {
385        // First, add a new projection. The expression must be rewritten, as it is no longer
386        // executed against the filter schema.
387        let new_idx = self.join_side_projections.len();
388        let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
389            Ok(match expr.as_any().downcast_ref::<Column>() {
390                Some(column) if !lambdas_params.contains(column.name()) => {
391                    let intermediate_column =
392                        &self.intermediate_column_indices[column.index()];
393                    assert_eq!(intermediate_column.side, self.join_side);
394
395                    let join_side_index = intermediate_column.index;
396                    let field = self.join_side_schema.field(join_side_index);
397                    let new_column = Column::new(field.name(), join_side_index);
398                    Transformed::yes(Arc::new(new_column) as Arc<dyn PhysicalExpr>)
399                }
400                _ => Transformed::no(expr),
401            })
402        })?;
403        self.join_side_projections.push((rewritten_expr.data, name));
404
405        // Then, update the column indices
406        let new_intermediate_idx = self.intermediate_column_indices.len();
407        let idx = ColumnIndex {
408            index: new_idx,
409            side: self.join_side,
410        };
411        self.intermediate_column_indices.push(idx);
412
413        Ok(new_intermediate_idx)
414    }
415
416    /// Checks whether the entire expression depends on the given `join_side`.
417    fn depends_on_join_side(
418        &mut self,
419        expr: &Arc<dyn PhysicalExpr>,
420        join_side: JoinSide,
421    ) -> Result<bool> {
422        let mut result = false;
423        expr.apply_with_lambdas_params(|expr, lambdas_params| {
424            match expr.as_any().downcast_ref::<Column>() {
425                Some(c) if !lambdas_params.contains(c.name()) => {
426                    let column_index = &self.intermediate_column_indices[c.index()];
427                    if column_index.side == join_side {
428                        result = true;
429                        return Ok(TreeNodeRecursion::Stop);
430                    }
431                    Ok(TreeNodeRecursion::Continue)
432                }
433                _ => Ok(TreeNodeRecursion::Continue),
434            }
435        })?;
436
437        Ok(result)
438    }
439}
440
441fn is_volatile_expression_tree(expr: &dyn PhysicalExpr) -> bool {
442    if expr.is_volatile_node() {
443        return true;
444    }
445
446    expr.children()
447        .iter()
448        .map(|expr| is_volatile_expression_tree(expr.as_ref()))
449        .reduce(|lhs, rhs| lhs || rhs)
450        .unwrap_or(false)
451}
452
#[cfg(test)]
mod test {
    use super::*;
    use arrow::datatypes::{DataType, Field, FieldRef, Schema};
    use datafusion_expr_common::operator::Operator;
    use datafusion_functions::math::random;
    use datafusion_physical_expr::expressions::{binary, lit};
    use datafusion_physical_expr::ScalarFunctionExpr;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use datafusion_physical_plan::displayable;
    use datafusion_physical_plan::empty::EmptyExec;
    use insta::assert_snapshot;
    use std::sync::Arc;

    // A filter without any computation (`a > x`) leaves the plan untouched.
    #[tokio::test]
    async fn no_computation_does_not_project() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let plan = run_test(lhs, rhs, a_x(), None, a_greater_than_x, JoinType::Inner)?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=a@0 > x@1
          EmptyExec
          EmptyExec
        ");
        Ok(())
    }

    // `a + 1` and `x + 1` are each computed below the join on their own side.
    #[tokio::test]
    async fn simple_push_down() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let plan = run_test(
            lhs,
            rhs,
            a_x(),
            None,
            a_plus_one_greater_than_x_plus_one,
            JoinType::Inner,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2]
          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    // NOTE(review): despite the test name, the snapshot shows that `a + 1` and `x + 1` ARE
    // pushed down below the short-circuiting `false AND ...`; confirm whether the name or the
    // expected plan reflects the intended behavior.
    #[tokio::test]
    async fn does_not_push_down_short_circuiting_expressions() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let false_and_comparison = |schema: &Schema| {
            binary(
                lit(false),
                Operator::And,
                a_plus_one_greater_than_x_plus_one(schema)?,
                schema,
            )
        };
        let plan = run_test(
            lhs,
            rhs,
            a_x(),
            None,
            false_and_comparison,
            JoinType::Inner,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=false AND join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2]
          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    // `a + rand()` is volatile and must stay in the join filter.
    #[tokio::test]
    async fn does_not_push_down_volatile_functions() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let plan = run_test(
            lhs,
            rhs,
            a_x(),
            None,
            a_plus_rand_greater_than_x,
            JoinType::Inner,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=a@0 + rand() > x@1
          EmptyExec
          EmptyExec
        ");
        Ok(())
    }

    // Multi-column inputs: helper columns go after the original ones and are projected away.
    #[tokio::test]
    async fn complex_schema_push_down() -> Result<()> {
        let (lhs, rhs) = create_complex_schemas();
        let plan = run_test(
            lhs,
            rhs,
            a_b_x_z(),
            None,
            a_plus_b_greater_than_x_plus_z,
            JoinType::Inner,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, b@1, c@2, x@4, y@5, z@6]
          ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    // An existing join projection is remapped past the helper column added on the left.
    #[tokio::test]
    async fn push_down_with_existing_projections() -> Result<()> {
        let (lhs, rhs) = create_complex_schemas();
        let b_x_z = vec![1, 3, 5];
        let plan = run_test(
            lhs,
            rhs,
            a_b_x_z(),
            Some(b_x_z),
            a_plus_b_greater_than_x_plus_z,
            JoinType::Inner,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[b@1, x@4, z@6]
          ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    // A left semi join only outputs the original left columns.
    #[tokio::test]
    async fn left_semi_join_projection() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let plan = run_test(
            lhs,
            rhs,
            a_x(),
            None,
            a_plus_one_greater_than_x_plus_one,
            JoinType::LeftSemi,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=LeftSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0]
          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    // A right semi join only outputs the original right columns.
    #[tokio::test]
    async fn right_semi_join_projection() -> Result<()> {
        let (lhs, rhs) = create_simple_schemas();
        let plan = run_test(
            lhs,
            rhs,
            a_x(),
            None,
            a_plus_one_greater_than_x_plus_one,
            JoinType::RightSemi,
        )?;

        assert_snapshot!(plan, @r"
        NestedLoopJoinExec: join_type=RightSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[x@0]
          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
            EmptyExec
          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
            EmptyExec
        ");
        Ok(())
    }

    /// Builds a `NestedLoopJoinExec` over two `EmptyExec` inputs with the filter produced by
    /// `filter_expr_builder`, runs [`ProjectionPushdown`] on it and renders the optimized plan.
    fn run_test(
        left_schema: Schema,
        right_schema: Schema,
        column_indices: Vec<ColumnIndex>,
        existing_projections: Option<Vec<usize>>,
        filter_expr_builder: impl FnOnce(&Schema) -> Result<Arc<dyn PhysicalExpr>>,
        join_type: JoinType,
    ) -> Result<String> {
        let left_schema = Arc::new(left_schema);
        let right_schema = Arc::new(right_schema);

        let filter_fields = column_indices
            .iter()
            .map(|ci| match ci.side {
                JoinSide::Left => left_schema.field(ci.index).clone(),
                JoinSide::Right => right_schema.field(ci.index).clone(),
                JoinSide::None => unreachable!(),
            })
            .collect::<Vec<_>>();
        let filter_schema = Arc::new(Schema::new(filter_fields));
        let filter_expr = filter_expr_builder(filter_schema.as_ref())?;
        let filter = JoinFilter::new(filter_expr, column_indices, filter_schema);

        let join = NestedLoopJoinExec::try_new(
            Arc::new(EmptyExec::new(left_schema)),
            Arc::new(EmptyExec::new(right_schema)),
            Some(filter),
            &join_type,
            existing_projections,
        )?;

        let optimized =
            ProjectionPushdown::new().optimize(Arc::new(join), &Default::default())?;
        Ok(displayable(optimized.as_ref()).indent(false).to_string())
    }

    /// Schema of non-nullable `Int32` columns with the given names.
    fn int32_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|name| Field::new(*name, DataType::Int32, false))
                .collect::<Vec<_>>(),
        )
    }

    fn create_simple_schemas() -> (Schema, Schema) {
        (int32_schema(&["a"]), int32_schema(&["x"]))
    }

    fn create_complex_schemas() -> (Schema, Schema) {
        (int32_schema(&["a", "b", "c"]), int32_schema(&["x", "y", "z"]))
    }

    fn side_index(side: JoinSide, index: usize) -> ColumnIndex {
        ColumnIndex { index, side }
    }

    fn a_x() -> Vec<ColumnIndex> {
        vec![side_index(JoinSide::Left, 0), side_index(JoinSide::Right, 0)]
    }

    fn a_b_x_z() -> Vec<ColumnIndex> {
        vec![
            side_index(JoinSide::Left, 0),
            side_index(JoinSide::Left, 1),
            side_index(JoinSide::Right, 0),
            side_index(JoinSide::Right, 2),
        ]
    }

    /// Column reference into the filter's intermediate schema.
    fn col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
        Arc::new(Column::new(name, index))
    }

    fn a_plus_one_greater_than_x_plus_one(
        join_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let a_plus_one = binary(col("a", 0), Operator::Plus, lit(1), join_schema)?;
        let x_plus_one = binary(col("x", 1), Operator::Plus, lit(1), join_schema)?;
        binary(a_plus_one, Operator::Gt, x_plus_one, join_schema)
    }

    fn a_plus_rand_greater_than_x(join_schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
        let rand_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
            "rand",
            random(),
            vec![],
            FieldRef::new(Field::new("out", DataType::Float64, false)),
            Arc::new(ConfigOptions::default()),
        ));
        let a_plus_rand = binary(col("a", 0), Operator::Plus, rand_expr, join_schema)?;
        binary(a_plus_rand, Operator::Gt, col("x", 1), join_schema)
    }

    fn a_greater_than_x(join_schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
        binary(col("a", 0), Operator::Gt, col("x", 1), join_schema)
    }

    fn a_plus_b_greater_than_x_plus_z(
        join_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        let a_plus_b = binary(col("a", 0), Operator::Plus, col("b", 1), join_schema)?;
        let x_plus_z = binary(col("x", 2), Operator::Plus, col("z", 3), join_schema)?;
        binary(a_plus_b, Operator::Gt, x_plus_z, join_schema)
    }
}