datafusion_optimizer/
scalar_subquery_to_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s
19
20use std::collections::{BTreeSet, HashMap};
21use std::sync::Arc;
22
23use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR};
24use crate::optimizer::ApplyOrder;
25use crate::utils::{evaluates_to_null, replace_qualified_name};
26use crate::{OptimizerConfig, OptimizerRule};
27
28use crate::analyzer::type_coercion::TypeCoercionRewriter;
29use datafusion_common::alias::AliasGenerator;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
32};
33use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue};
34use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
35use datafusion_expr::logical_plan::{JoinType, Subquery};
36use datafusion_expr::utils::conjunction;
37use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
38
/// Optimizer rule that rewrites scalar subqueries into `JOIN`s.
///
/// Scalar subqueries found in `Filter` predicates and in `Projection`
/// expressions are joined against the node's input (normally as a left
/// join). For filters, an additional projection is placed on top of the
/// rewritten filter so that the original output schema is preserved.
#[derive(Debug, Default)]
pub struct ScalarSubqueryToJoin {}
44
45impl ScalarSubqueryToJoin {
46    #[allow(missing_docs)]
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Finds expressions that have a scalar subquery in them (and recurses when found)
52    ///
53    /// # Arguments
54    /// * `predicate` - A conjunction to split and search
55    ///
56    /// Returns a tuple (subqueries, alias)
57    fn extract_subquery_exprs(
58        &self,
59        predicate: &Expr,
60        alias_gen: &Arc<AliasGenerator>,
61    ) -> Result<(Vec<(Subquery, String)>, Expr)> {
62        let mut extract = ExtractScalarSubQuery {
63            sub_query_info: vec![],
64            alias_gen,
65        };
66        predicate
67            .clone()
68            .rewrite(&mut extract)
69            .data()
70            .map(|new_expr| (extract.sub_query_info, new_expr))
71    }
72}
73
74impl OptimizerRule for ScalarSubqueryToJoin {
75    fn supports_rewrite(&self) -> bool {
76        true
77    }
78
79    fn rewrite(
80        &self,
81        plan: LogicalPlan,
82        config: &dyn OptimizerConfig,
83    ) -> Result<Transformed<LogicalPlan>> {
84        match plan {
85            LogicalPlan::Filter(filter) => {
86                // Optimization: skip the rest of the rule and its copies if
87                // there are no scalar subqueries
88                if !contains_scalar_subquery(&filter.predicate) {
89                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
90                }
91
92                let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs(
93                    &filter.predicate,
94                    config.alias_generator(),
95                )?;
96
97                if subqueries.is_empty() {
98                    return internal_err!("Expected subqueries not found in filter");
99                }
100
101                // iterate through all subqueries in predicate, turning each into a left join
102                let mut cur_input = filter.input.as_ref().clone();
103                for (subquery, alias) in subqueries {
104                    if let Some((optimized_subquery, expr_check_map)) =
105                        build_join(&subquery, &cur_input, &alias)?
106                    {
107                        if !expr_check_map.is_empty() {
108                            rewrite_expr = rewrite_expr
109                                .transform_up_with_lambdas_params(
110                                    |expr, lambdas_params| {
111                                        // replace column references with entry in map, if it exists
112                                        if let Some(map_expr) = expr
113                                            .try_as_col()
114                                            .filter(|c| {
115                                                !c.is_lambda_parameter(lambdas_params)
116                                            })
117                                            .and_then(|col| expr_check_map.get(&col.name))
118                                        {
119                                            Ok(Transformed::yes(map_expr.clone()))
120                                        } else {
121                                            Ok(Transformed::no(expr))
122                                        }
123                                    },
124                                )
125                                .data()?;
126                        }
127                        cur_input = optimized_subquery;
128                    } else {
129                        // if we can't handle all of the subqueries then bail for now
130                        return Ok(Transformed::no(LogicalPlan::Filter(filter)));
131                    }
132                }
133
134                // Preserve original schema as new Join might have more fields than what Filter & parents expect.
135                let projection =
136                    filter.input.schema().columns().into_iter().map(Expr::from);
137                let new_plan = LogicalPlanBuilder::from(cur_input)
138                    .filter(rewrite_expr)?
139                    .project(projection)?
140                    .build()?;
141                Ok(Transformed::yes(new_plan))
142            }
143            LogicalPlan::Projection(projection) => {
144                // Optimization: skip the rest of the rule and its copies if
145                // there are no scalar subqueries
146                if !projection.expr.iter().any(contains_scalar_subquery) {
147                    return Ok(Transformed::no(LogicalPlan::Projection(projection)));
148                }
149
150                let mut all_subqueries = vec![];
151                let mut expr_to_rewrite_expr_map = HashMap::new();
152                let mut subquery_to_expr_map = HashMap::new();
153                for expr in projection.expr.iter() {
154                    let (subqueries, rewrite_exprs) =
155                        self.extract_subquery_exprs(expr, config.alias_generator())?;
156                    for (subquery, _) in &subqueries {
157                        subquery_to_expr_map.insert(subquery.clone(), expr.clone());
158                    }
159                    all_subqueries.extend(subqueries);
160                    expr_to_rewrite_expr_map.insert(expr, rewrite_exprs);
161                }
162                if all_subqueries.is_empty() {
163                    return internal_err!("Expected subqueries not found in projection");
164                }
165                // iterate through all subqueries in predicate, turning each into a left join
166                let mut cur_input = projection.input.as_ref().clone();
167                for (subquery, alias) in all_subqueries {
168                    if let Some((optimized_subquery, expr_check_map)) =
169                        build_join(&subquery, &cur_input, &alias)?
170                    {
171                        cur_input = optimized_subquery;
172                        if !expr_check_map.is_empty() {
173                            if let Some(expr) = subquery_to_expr_map.get(&subquery) {
174                                if let Some(rewrite_expr) =
175                                    expr_to_rewrite_expr_map.get(expr)
176                                {
177                                    let new_expr = rewrite_expr
178                                        .clone()
179                                        .transform_up_with_lambdas_params(
180                                            |expr, lambdas_params| {
181                                                // replace column references with entry in map, if it exists
182                                                if let Some(map_expr) = expr
183                                                    .try_as_col()
184                                                    .filter(|c| {
185                                                        !c.is_lambda_parameter(
186                                                            lambdas_params,
187                                                        )
188                                                    })
189                                                    .and_then(|col| {
190                                                        expr_check_map.get(&col.name)
191                                                    })
192                                                {
193                                                    Ok(Transformed::yes(map_expr.clone()))
194                                                } else {
195                                                    Ok(Transformed::no(expr))
196                                                }
197                                            },
198                                        )
199                                        .data()?;
200                                    expr_to_rewrite_expr_map.insert(expr, new_expr);
201                                }
202                            }
203                        }
204                    } else {
205                        // if we can't handle all of the subqueries then bail for now
206                        return Ok(Transformed::no(LogicalPlan::Projection(projection)));
207                    }
208                }
209
210                let mut proj_exprs = vec![];
211                for expr in projection.expr.iter() {
212                    let old_expr_name = expr.schema_name().to_string();
213                    let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap();
214                    let new_expr_name = new_expr.schema_name().to_string();
215                    if new_expr_name != old_expr_name {
216                        proj_exprs.push(new_expr.clone().alias(old_expr_name))
217                    } else {
218                        proj_exprs.push(new_expr.clone());
219                    }
220                }
221                let new_plan = LogicalPlanBuilder::from(cur_input)
222                    .project(proj_exprs)?
223                    .build()?;
224                Ok(Transformed::yes(new_plan))
225            }
226
227            plan => Ok(Transformed::no(plan)),
228        }
229    }
230
231    fn name(&self) -> &str {
232        "scalar_subquery_to_join"
233    }
234
235    fn apply_order(&self) -> Option<ApplyOrder> {
236        Some(ApplyOrder::TopDown)
237    }
238}
239
240/// Returns true if the expression has a scalar subquery somewhere in it
241/// false otherwise
242fn contains_scalar_subquery(expr: &Expr) -> bool {
243    expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_))))
244        .expect("Inner is always Ok")
245}
246
247struct ExtractScalarSubQuery<'a> {
248    sub_query_info: Vec<(Subquery, String)>,
249    alias_gen: &'a Arc<AliasGenerator>,
250}
251
252impl TreeNodeRewriter for ExtractScalarSubQuery<'_> {
253    type Node = Expr;
254
255    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
256        match expr {
257            Expr::ScalarSubquery(subquery) => {
258                let subqry_alias = self.alias_gen.next("__scalar_sq");
259                self.sub_query_info
260                    .push((subquery.clone(), subqry_alias.clone()));
261                let scalar_expr = subquery
262                    .subquery
263                    .head_output_expr()?
264                    .map_or(plan_err!("single expression required."), Ok)?;
265                Ok(Transformed::new(
266                    Expr::Column(create_col_from_scalar_expr(
267                        &scalar_expr,
268                        subqry_alias,
269                    )?),
270                    true,
271                    TreeNodeRecursion::Jump,
272                ))
273            }
274            _ => Ok(Transformed::no(expr)),
275        }
276    }
277}
278
279/// Takes a query like:
280///
281/// ```text
282/// select id from customers where balance >
283///     (select avg(total) from orders where orders.c_id = customers.id)
284/// ```
285///
286/// and optimizes it into:
287///
288/// ```text
289/// select c.id from customers c
290/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id
291/// where c.balance > o.val
292/// ```
293///
294/// Or a query like:
295///
296/// ```text
297/// select id from customers where balance >
298///     (select avg(total) from orders)
299/// ```
300///
301/// and optimizes it into:
302///
303/// ```text
304/// select c.id from customers c
305/// left join (select avg(total) as val from orders) a
306/// where c.balance > a.val
307/// ```
308///
309/// # Arguments
310///
311/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders)
312/// * `filter_input` - The non-subquery portion (from customers)
313/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y)
314/// * `subquery_alias` - Subquery aliases
315fn build_join(
316    subquery: &Subquery,
317    filter_input: &LogicalPlan,
318    subquery_alias: &str,
319) -> Result<Option<(LogicalPlan, HashMap<String, Expr>)>> {
320    let subquery_plan = subquery.subquery.as_ref();
321    let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true);
322    let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?;
323    if !pull_up.can_pull_up {
324        return Ok(None);
325    }
326
327    let collected_count_expr_map =
328        pull_up.collected_count_expr_map.get(&new_plan).cloned();
329    let sub_query_alias = LogicalPlanBuilder::from(new_plan)
330        .alias(subquery_alias.to_string())?
331        .build()?;
332
333    let mut all_correlated_cols = BTreeSet::new();
334    pull_up
335        .correlated_subquery_cols_map
336        .values()
337        .for_each(|cols| all_correlated_cols.extend(cols.clone()));
338
339    // alias the join filter
340    let join_filter_opt =
341        conjunction(pull_up.join_filters).map_or(Ok(None), |filter| {
342            replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some)
343        })?;
344
345    // join our sub query into the main plan
346    let new_plan = if join_filter_opt.is_none() {
347        match filter_input {
348            LogicalPlan::EmptyRelation(EmptyRelation {
349                produce_one_row: true,
350                schema: _,
351            }) => sub_query_alias,
352            _ => {
353                // if not correlated, group down to 1 row and left join on that (preserving row count)
354                LogicalPlanBuilder::from(filter_input.clone())
355                    .join_on(
356                        sub_query_alias,
357                        JoinType::Left,
358                        vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)],
359                    )?
360                    .build()?
361            }
362        }
363    } else {
364        // left join if correlated, grouping by the join keys so we don't change row count
365        LogicalPlanBuilder::from(filter_input.clone())
366            .join_on(sub_query_alias, JoinType::Left, join_filter_opt)?
367            .build()?
368    };
369    let mut computation_project_expr = HashMap::new();
370    if let Some(expr_map) = collected_count_expr_map {
371        for (name, result) in expr_map {
372            if evaluates_to_null(result.clone(), result.column_refs())? {
373                // If expr always returns null when column is null, skip processing
374                continue;
375            }
376            let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr {
377                Expr::Case(expr::Case {
378                    expr: None,
379                    when_then_expr: vec![
380                        (
381                            Box::new(Expr::IsNull(Box::new(Expr::Column(
382                                Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
383                            )))),
384                            Box::new(result),
385                        ),
386                        (
387                            Box::new(Expr::Not(Box::new(filter.clone()))),
388                            Box::new(Expr::Literal(ScalarValue::Null, None)),
389                        ),
390                    ],
391                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
392                        name.clone(),
393                    )))),
394                })
395            } else {
396                Expr::Case(expr::Case {
397                    expr: None,
398                    when_then_expr: vec![(
399                        Box::new(Expr::IsNull(Box::new(Expr::Column(
400                            Column::new_unqualified(UN_MATCHED_ROW_INDICATOR),
401                        )))),
402                        Box::new(result),
403                    )],
404                    else_expr: Some(Box::new(Expr::Column(Column::new_unqualified(
405                        name.clone(),
406                    )))),
407                })
408            };
409            let mut expr_rewrite = TypeCoercionRewriter {
410                schema: new_plan.schema(),
411            };
412            computation_project_expr.insert(
413                name,
414                computer_expr
415                    .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite)
416                    .data()?,
417            );
418        }
419    }
420
421    Ok(Some((new_plan, computation_project_expr)))
422}
423
424#[cfg(test)]
425mod tests {
426    use std::ops::Add;
427
428    use super::*;
429    use crate::test::*;
430
431    use arrow::datatypes::DataType;
432    use datafusion_expr::test::function_stub::sum;
433
434    use crate::assert_optimized_plan_eq_display_indent_snapshot;
435    use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between};
436    use datafusion_functions_aggregate::min_max::{max, min};
437
438    macro_rules! assert_optimized_plan_equal {
439        (
440            $plan:expr,
441            @ $expected:literal $(,)?
442        ) => {{
443            let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new());
444            assert_optimized_plan_eq_display_indent_snapshot!(
445                rule,
446                $plan,
447                @ $expected,
448            )
449        }};
450    }
451
452    /// Test multiple correlated subqueries
453    #[test]
454    fn multiple_subqueries() -> Result<()> {
455        let orders = Arc::new(
456            LogicalPlanBuilder::from(scan_tpch_table("orders"))
457                .filter(
458                    col("orders.o_custkey")
459                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
460                )?
461                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
462                .project(vec![max(col("orders.o_custkey"))])?
463                .build()?,
464        );
465
466        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
467            .filter(
468                lit(1)
469                    .lt(scalar_subquery(Arc::clone(&orders)))
470                    .and(lit(1).lt(scalar_subquery(orders))),
471            )?
472            .project(vec![col("customer.c_custkey")])?
473            .build()?;
474
475        assert_optimized_plan_equal!(
476            plan,
477            @r"
478        Projection: customer.c_custkey [c_custkey:Int64]
479          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
480            Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
481              Left Join:  Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
482                Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
483                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
484                  SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
485                    Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
486                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
487                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
488                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
489                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
490                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
491                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
492        "
493        )
494    }
495
496    /// Test recursive correlated subqueries
497    #[test]
498    fn recursive_subqueries() -> Result<()> {
499        let lineitem = Arc::new(
500            LogicalPlanBuilder::from(scan_tpch_table("lineitem"))
501                .filter(
502                    col("lineitem.l_orderkey")
503                        .eq(out_ref_col(DataType::Int64, "orders.o_orderkey")),
504                )?
505                .aggregate(
506                    Vec::<Expr>::new(),
507                    vec![sum(col("lineitem.l_extendedprice"))],
508                )?
509                .project(vec![sum(col("lineitem.l_extendedprice"))])?
510                .build()?,
511        );
512
513        let orders = Arc::new(
514            LogicalPlanBuilder::from(scan_tpch_table("orders"))
515                .filter(
516                    col("orders.o_custkey")
517                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey"))
518                        .and(col("orders.o_totalprice").lt(scalar_subquery(lineitem))),
519                )?
520                .aggregate(Vec::<Expr>::new(), vec![sum(col("orders.o_totalprice"))])?
521                .project(vec![sum(col("orders.o_totalprice"))])?
522                .build()?,
523        );
524
525        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
526            .filter(col("customer.c_acctbal").lt(scalar_subquery(orders)))?
527            .project(vec![col("customer.c_custkey")])?
528            .build()?;
529
530        assert_optimized_plan_equal!(
531            plan,
532            @r"
533        Projection: customer.c_custkey [c_custkey:Int64]
534          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
535            Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
536              Left Join:  Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]
537                TableScan: customer [c_custkey:Int64, c_name:Utf8]
538                SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
539                  Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]
540                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]
541                      Projection: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
542                        Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
543                          Left Join:  Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]
544                            TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
545                            SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
546                              Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]
547                                Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]
548                                  TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]
549        "
550        )
551    }
552
553    /// Test for correlated scalar subquery filter with additional subquery filters
554    #[test]
555    fn scalar_subquery_with_subquery_filters() -> Result<()> {
556        let sq = Arc::new(
557            LogicalPlanBuilder::from(scan_tpch_table("orders"))
558                .filter(
559                    out_ref_col(DataType::Int64, "customer.c_custkey")
560                        .eq(col("orders.o_custkey"))
561                        .and(col("o_orderkey").eq(lit(1))),
562                )?
563                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
564                .project(vec![max(col("orders.o_custkey"))])?
565                .build()?,
566        );
567
568        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
569            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
570            .project(vec![col("customer.c_custkey")])?
571            .build()?;
572
573        assert_optimized_plan_equal!(
574            plan,
575            @r"
576        Projection: customer.c_custkey [c_custkey:Int64]
577          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
578            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
579              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
580                TableScan: customer [c_custkey:Int64, c_name:Utf8]
581                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
582                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
583                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
584                      Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
585                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
586        "
587        )
588    }
589
590    /// Test for correlated scalar subquery with no columns in schema
591    #[test]
592    fn scalar_subquery_no_cols() -> Result<()> {
593        let sq = Arc::new(
594            LogicalPlanBuilder::from(scan_tpch_table("orders"))
595                .filter(
596                    out_ref_col(DataType::Int64, "customer.c_custkey")
597                        .eq(out_ref_col(DataType::Int64, "customer.c_custkey")),
598                )?
599                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
600                .project(vec![max(col("orders.o_custkey"))])?
601                .build()?,
602        );
603
604        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
605            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
606            .project(vec![col("customer.c_custkey")])?
607            .build()?;
608
609        // it will optimize, but fail for the same reason the unoptimized query would
610        assert_optimized_plan_equal!(
611            plan,
612            @r"
613        Projection: customer.c_custkey [c_custkey:Int64]
614          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
615            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
616              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
617                TableScan: customer [c_custkey:Int64, c_name:Utf8]
618                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
619                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
620                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
621                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
622        "
623        )
624    }
625
626    /// Test for scalar subquery with both columns in schema
627    #[test]
628    fn scalar_subquery_with_no_correlated_cols() -> Result<()> {
629        let sq = Arc::new(
630            LogicalPlanBuilder::from(scan_tpch_table("orders"))
631                .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))?
632                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
633                .project(vec![max(col("orders.o_custkey"))])?
634                .build()?,
635        );
636
637        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
638            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
639            .project(vec![col("customer.c_custkey")])?
640            .build()?;
641
642        assert_optimized_plan_equal!(
643            plan,
644            @r"
645        Projection: customer.c_custkey [c_custkey:Int64]
646          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
647            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
648              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
649                TableScan: customer [c_custkey:Int64, c_name:Utf8]
650                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
651                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
652                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
653                      Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
654                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
655        "
656        )
657    }
658
659    /// Test for correlated scalar subquery not equal
660    #[test]
661    fn scalar_subquery_where_not_eq() -> Result<()> {
662        let sq = Arc::new(
663            LogicalPlanBuilder::from(scan_tpch_table("orders"))
664                .filter(
665                    out_ref_col(DataType::Int64, "customer.c_custkey")
666                        .not_eq(col("orders.o_custkey")),
667                )?
668                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
669                .project(vec![max(col("orders.o_custkey"))])?
670                .build()?,
671        );
672
673        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
674            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
675            .project(vec![col("customer.c_custkey")])?
676            .build()?;
677
678        // Unsupported predicate, subquery should not be decorrelated
679        assert_optimized_plan_equal!(
680            plan,
681            @r"
682        Projection: customer.c_custkey [c_custkey:Int64]
683          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
684            Subquery: [max(orders.o_custkey):Int64;N]
685              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
686                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
687                  Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
688                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
689            TableScan: customer [c_custkey:Int64, c_name:Utf8]
690        "
691        )
692    }
693
694    /// Test for correlated scalar subquery less than
695    #[test]
696    fn scalar_subquery_where_less_than() -> Result<()> {
697        let sq = Arc::new(
698            LogicalPlanBuilder::from(scan_tpch_table("orders"))
699                .filter(
700                    out_ref_col(DataType::Int64, "customer.c_custkey")
701                        .lt(col("orders.o_custkey")),
702                )?
703                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
704                .project(vec![max(col("orders.o_custkey"))])?
705                .build()?,
706        );
707
708        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
709            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
710            .project(vec![col("customer.c_custkey")])?
711            .build()?;
712
713        // Unsupported predicate, subquery should not be decorrelated
714        assert_optimized_plan_equal!(
715            plan,
716            @r"
717        Projection: customer.c_custkey [c_custkey:Int64]
718          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
719            Subquery: [max(orders.o_custkey):Int64;N]
720              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
721                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
722                  Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
723                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
724            TableScan: customer [c_custkey:Int64, c_name:Utf8]
725        "
726        )
727    }
728
729    /// Test for correlated scalar subquery filter with subquery disjunction
730    #[test]
731    fn scalar_subquery_with_subquery_disjunction() -> Result<()> {
732        let sq = Arc::new(
733            LogicalPlanBuilder::from(scan_tpch_table("orders"))
734                .filter(
735                    out_ref_col(DataType::Int64, "customer.c_custkey")
736                        .eq(col("orders.o_custkey"))
737                        .or(col("o_orderkey").eq(lit(1))),
738                )?
739                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
740                .project(vec![max(col("orders.o_custkey"))])?
741                .build()?,
742        );
743
744        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
745            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
746            .project(vec![col("customer.c_custkey")])?
747            .build()?;
748
749        // Unsupported predicate, subquery should not be decorrelated
750        assert_optimized_plan_equal!(
751            plan,
752            @r"
753        Projection: customer.c_custkey [c_custkey:Int64]
754          Filter: customer.c_custkey = (<subquery>) [c_custkey:Int64, c_name:Utf8]
755            Subquery: [max(orders.o_custkey):Int64;N]
756              Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
757                Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
758                  Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
759                    TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
760            TableScan: customer [c_custkey:Int64, c_name:Utf8]
761        "
762        )
763    }
764
765    /// Test for correlated scalar without projection
766    #[test]
767    fn scalar_subquery_no_projection() -> Result<()> {
768        let sq = Arc::new(
769            LogicalPlanBuilder::from(scan_tpch_table("orders"))
770                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
771                .build()?,
772        );
773
774        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
775            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
776            .project(vec![col("customer.c_custkey")])?
777            .build()?;
778
779        let expected = "Error during planning: Scalar subquery should only return one column, but found 4: orders.o_orderkey, orders.o_custkey, orders.o_orderstatus, orders.o_totalprice";
780        assert_analyzer_check_err(vec![], plan, expected);
781        Ok(())
782    }
783
784    /// Test for correlated scalar expressions
785    #[test]
786    fn scalar_subquery_project_expr() -> Result<()> {
787        let sq = Arc::new(
788            LogicalPlanBuilder::from(scan_tpch_table("orders"))
789                .filter(
790                    out_ref_col(DataType::Int64, "customer.c_custkey")
791                        .eq(col("orders.o_custkey")),
792                )?
793                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
794                .project(vec![col("max(orders.o_custkey)").add(lit(1))])?
795                .build()?,
796        );
797
798        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
799            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
800            .project(vec![col("customer.c_custkey")])?
801            .build()?;
802
803        assert_optimized_plan_equal!(
804            plan,
805            @r"
806        Projection: customer.c_custkey [c_custkey:Int64]
807          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
808            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
809              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
810                TableScan: customer [c_custkey:Int64, c_name:Utf8]
811                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
812                  Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]
813                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
814                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
815        "
816        )
817    }
818
819    /// Test for correlated scalar subquery with non-strong project
820    #[test]
821    fn scalar_subquery_with_non_strong_project() -> Result<()> {
822        let case = Expr::Case(expr::Case {
823            expr: None,
824            when_then_expr: vec![(
825                Box::new(col("max(orders.o_totalprice)")),
826                Box::new(lit("a")),
827            )],
828            else_expr: Some(Box::new(lit("b"))),
829        });
830
831        let sq = Arc::new(
832            LogicalPlanBuilder::from(scan_tpch_table("orders"))
833                .filter(
834                    out_ref_col(DataType::Int64, "customer.c_custkey")
835                        .eq(col("orders.o_custkey")),
836                )?
837                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_totalprice"))])?
838                .project(vec![case])?
839                .build()?,
840        );
841
842        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
843            .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])?
844            .build()?;
845
846        assert_optimized_plan_equal!(
847            plan,
848            @r#"
849        Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N]
850          Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]
851            TableScan: customer [c_custkey:Int64, c_name:Utf8]
852            SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
853              Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean]
854                Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]
855                  TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
856        "#
857        )
858    }
859
860    /// Test for correlated scalar subquery multiple projected columns
861    #[test]
862    fn scalar_subquery_multi_col() -> Result<()> {
863        let sq = Arc::new(
864            LogicalPlanBuilder::from(scan_tpch_table("orders"))
865                .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))?
866                .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])?
867                .build()?,
868        );
869
870        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
871            .filter(
872                col("customer.c_custkey")
873                    .eq(scalar_subquery(sq))
874                    .and(col("c_custkey").eq(lit(1))),
875            )?
876            .project(vec![col("customer.c_custkey")])?
877            .build()?;
878
879        let expected = "Error during planning: Scalar subquery should only return one column, but found 2: orders.o_custkey, orders.o_orderkey";
880        assert_analyzer_check_err(vec![], plan, expected);
881        Ok(())
882    }
883
884    /// Test for correlated scalar subquery filter with additional filters
885    #[test]
886    fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> {
887        let sq = Arc::new(
888            LogicalPlanBuilder::from(scan_tpch_table("orders"))
889                .filter(
890                    out_ref_col(DataType::Int64, "customer.c_custkey")
891                        .eq(col("orders.o_custkey")),
892                )?
893                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
894                .project(vec![max(col("orders.o_custkey"))])?
895                .build()?,
896        );
897
898        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
899            .filter(
900                col("customer.c_custkey")
901                    .gt_eq(scalar_subquery(sq))
902                    .and(col("c_custkey").eq(lit(1))),
903            )?
904            .project(vec![col("customer.c_custkey")])?
905            .build()?;
906
907        assert_optimized_plan_equal!(
908            plan,
909            @r"
910        Projection: customer.c_custkey [c_custkey:Int64]
911          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
912            Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
913              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
914                TableScan: customer [c_custkey:Int64, c_name:Utf8]
915                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
916                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
917                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
918                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
919        "
920        )
921    }
922
923    #[test]
924    fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> {
925        let sq = Arc::new(
926            LogicalPlanBuilder::from(scan_tpch_table("orders"))
927                .filter(
928                    out_ref_col(DataType::Int64, "customer.c_custkey")
929                        .eq(col("orders.o_custkey")),
930                )?
931                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
932                .project(vec![max(col("orders.o_custkey"))])?
933                .build()?,
934        );
935
936        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
937            .filter(
938                col("customer.c_custkey")
939                    .eq(scalar_subquery(sq))
940                    .and(col("c_custkey").eq(lit(1))),
941            )?
942            .project(vec![col("customer.c_custkey")])?
943            .build()?;
944
945        assert_optimized_plan_equal!(
946            plan,
947            @r"
948        Projection: customer.c_custkey [c_custkey:Int64]
949          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
950            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
951              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
952                TableScan: customer [c_custkey:Int64, c_name:Utf8]
953                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
954                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
955                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
956                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
957        "
958        )
959    }
960
961    /// Test for correlated scalar subquery filter with disjunctions
962    #[test]
963    fn scalar_subquery_disjunction() -> Result<()> {
964        let sq = Arc::new(
965            LogicalPlanBuilder::from(scan_tpch_table("orders"))
966                .filter(
967                    out_ref_col(DataType::Int64, "customer.c_custkey")
968                        .eq(col("orders.o_custkey")),
969                )?
970                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
971                .project(vec![max(col("orders.o_custkey"))])?
972                .build()?,
973        );
974
975        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
976            .filter(
977                col("customer.c_custkey")
978                    .eq(scalar_subquery(sq))
979                    .or(col("customer.c_custkey").eq(lit(1))),
980            )?
981            .project(vec![col("customer.c_custkey")])?
982            .build()?;
983
984        assert_optimized_plan_equal!(
985            plan,
986            @r"
987        Projection: customer.c_custkey [c_custkey:Int64]
988          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
989            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
990              Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
991                TableScan: customer [c_custkey:Int64, c_name:Utf8]
992                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
993                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
994                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
995                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
996        "
997        )
998    }
999
1000    /// Test for correlated scalar subquery filter
1001    #[test]
1002    fn exists_subquery_correlated() -> Result<()> {
1003        let sq = Arc::new(
1004            LogicalPlanBuilder::from(test_table_scan_with_name("sq")?)
1005                .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))?
1006                .aggregate(Vec::<Expr>::new(), vec![min(col("c"))])?
1007                .project(vec![min(col("c"))])?
1008                .build()?,
1009        );
1010
1011        let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test")?)
1012            .filter(col("test.c").lt(scalar_subquery(sq)))?
1013            .project(vec![col("test.c")])?
1014            .build()?;
1015
1016        assert_optimized_plan_equal!(
1017            plan,
1018            @r"
1019        Projection: test.c [c:UInt32]
1020          Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]
1021            Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1022              Left Join:  Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]
1023                TableScan: test [a:UInt32, b:UInt32, c:UInt32]
1024                SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1025                  Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]
1026                    Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]
1027                      TableScan: sq [a:UInt32, b:UInt32, c:UInt32]
1028        "
1029        )
1030    }
1031
1032    /// Test for non-correlated scalar subquery with no filters
1033    #[test]
1034    fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> {
1035        let sq = Arc::new(
1036            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1037                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1038                .project(vec![max(col("orders.o_custkey"))])?
1039                .build()?,
1040        );
1041
1042        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1043            .filter(col("customer.c_custkey").lt(scalar_subquery(sq)))?
1044            .project(vec![col("customer.c_custkey")])?
1045            .build()?;
1046
1047        assert_optimized_plan_equal!(
1048            plan,
1049            @r"
1050        Projection: customer.c_custkey [c_custkey:Int64]
1051          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1052            Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1053              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1054                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1055                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1056                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1057                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1058                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1059        "
1060        )
1061    }
1062
1063    #[test]
1064    fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> {
1065        let sq = Arc::new(
1066            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1067                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1068                .project(vec![max(col("orders.o_custkey"))])?
1069                .build()?,
1070        );
1071
1072        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1073            .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))?
1074            .project(vec![col("customer.c_custkey")])?
1075            .build()?;
1076
1077        assert_optimized_plan_equal!(
1078            plan,
1079            @r"
1080        Projection: customer.c_custkey [c_custkey:Int64]
1081          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1082            Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1083              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]
1084                TableScan: customer [c_custkey:Int64, c_name:Utf8]
1085                SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]
1086                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1087                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1088                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1089        "
1090        )
1091    }
1092
1093    #[test]
1094    fn correlated_scalar_subquery_in_between_clause() -> Result<()> {
1095        let sq1 = Arc::new(
1096            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1097                .filter(
1098                    out_ref_col(DataType::Int64, "customer.c_custkey")
1099                        .eq(col("orders.o_custkey")),
1100                )?
1101                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1102                .project(vec![min(col("orders.o_custkey"))])?
1103                .build()?,
1104        );
1105        let sq2 = Arc::new(
1106            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1107                .filter(
1108                    out_ref_col(DataType::Int64, "customer.c_custkey")
1109                        .eq(col("orders.o_custkey")),
1110                )?
1111                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1112                .project(vec![max(col("orders.o_custkey"))])?
1113                .build()?,
1114        );
1115
1116        let between_expr = Expr::Between(Between {
1117            expr: Box::new(col("customer.c_custkey")),
1118            negated: false,
1119            low: Box::new(scalar_subquery(sq1)),
1120            high: Box::new(scalar_subquery(sq2)),
1121        });
1122
1123        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1124            .filter(between_expr)?
1125            .project(vec![col("customer.c_custkey")])?
1126            .build()?;
1127
1128        assert_optimized_plan_equal!(
1129            plan,
1130            @r"
1131        Projection: customer.c_custkey [c_custkey:Int64]
1132          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1133            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1134              Left Join:  Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1135                Left Join:  Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]
1136                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1137                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1138                    Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1139                      Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]
1140                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1141                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1142                  Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]
1143                    Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]
1144                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1145        "
1146        )
1147    }
1148
1149    #[test]
1150    fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> {
1151        let sq1 = Arc::new(
1152            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1153                .aggregate(Vec::<Expr>::new(), vec![min(col("orders.o_custkey"))])?
1154                .project(vec![min(col("orders.o_custkey"))])?
1155                .build()?,
1156        );
1157        let sq2 = Arc::new(
1158            LogicalPlanBuilder::from(scan_tpch_table("orders"))
1159                .aggregate(Vec::<Expr>::new(), vec![max(col("orders.o_custkey"))])?
1160                .project(vec![max(col("orders.o_custkey"))])?
1161                .build()?,
1162        );
1163
1164        let between_expr = Expr::Between(Between {
1165            expr: Box::new(col("customer.c_custkey")),
1166            negated: false,
1167            low: Box::new(scalar_subquery(sq1)),
1168            high: Box::new(scalar_subquery(sq2)),
1169        });
1170
1171        let plan = LogicalPlanBuilder::from(scan_tpch_table("customer"))
1172            .filter(between_expr)?
1173            .project(vec![col("customer.c_custkey")])?
1174            .build()?;
1175
1176        assert_optimized_plan_equal!(
1177            plan,
1178            @r"
1179        Projection: customer.c_custkey [c_custkey:Int64]
1180          Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8]
1181            Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1182              Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]
1183                Left Join:  Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]
1184                  TableScan: customer [c_custkey:Int64, c_name:Utf8]
1185                  SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]
1186                    Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]
1187                      Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]
1188                        TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1189                SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]
1190                  Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]
1191                    Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]
1192                      TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]
1193        "
1194        )
1195    }
1196}