datafusion_expr/
tree_node.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Tree node implementation for Logical Expressions
19
20use crate::{
21    expr::{
22        AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case,
23        Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction,
24        TryCast, Unnest, WindowFunction, WindowFunctionParams,
25    },
26    Expr,
27};
28use datafusion_common::{
29    tree_node::{
30        Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
31    },
32    DFSchema, HashSet, Result,
33};
34
35/// Implementation of the [`TreeNode`] trait
36///
37/// This allows logical expressions (`Expr`) to be traversed and transformed
38/// Facilitates tasks such as optimization and rewriting during query
39/// planning.
40impl TreeNode for Expr {
41    /// Applies a function `f` to each child expression of `self`.
42    ///
43    /// The function `f` determines whether to continue traversing the tree or to stop.
44    /// This method collects all child expressions and applies `f` to each.
45    fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
46        &'n self,
47        f: F,
48    ) -> Result<TreeNodeRecursion> {
49        match self {
50            Expr::Alias(Alias { expr, .. })
51            | Expr::Unnest(Unnest { expr })
52            | Expr::Not(expr)
53            | Expr::IsNotNull(expr)
54            | Expr::IsTrue(expr)
55            | Expr::IsFalse(expr)
56            | Expr::IsUnknown(expr)
57            | Expr::IsNotTrue(expr)
58            | Expr::IsNotFalse(expr)
59            | Expr::IsNotUnknown(expr)
60            | Expr::IsNull(expr)
61            | Expr::Negative(expr)
62            | Expr::Cast(Cast { expr, .. })
63            | Expr::TryCast(TryCast { expr, .. })
64            | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
65            Expr::GroupingSet(GroupingSet::Rollup(exprs))
66            | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
67            Expr::ScalarFunction(ScalarFunction { args, .. }) => {
68                args.apply_elements(f)
69            }
70            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
71                lists_of_exprs.apply_elements(f)
72            }
73            // TODO: remove the next line after `Expr::Wildcard` is removed
74            #[expect(deprecated)]
75            Expr::Column(_)
76            // Treat OuterReferenceColumn as a leaf expression
77            | Expr::OuterReferenceColumn(_, _)
78            | Expr::ScalarVariable(_, _)
79            | Expr::Literal(_, _)
80            | Expr::Exists { .. }
81            | Expr::ScalarSubquery(_)
82            | Expr::Wildcard { .. }
83            | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
84            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
85                (left, right).apply_ref_elements(f)
86            }
87            Expr::Like(Like { expr, pattern, .. })
88            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
89                (expr, pattern).apply_ref_elements(f)
90            }
91            Expr::Between(Between {
92                              expr, low, high, ..
93                          }) => (expr, low, high).apply_ref_elements(f),
94            Expr::Case(Case { expr, when_then_expr, else_expr }) =>
95                (expr, when_then_expr, else_expr).apply_ref_elements(f),
96            Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
97                (args, filter, order_by).apply_ref_elements(f),
98            Expr::WindowFunction(window_fun) => {
99                let WindowFunctionParams {
100                    args,
101                    partition_by,
102                    order_by,
103                    filter,
104                    ..
105                } = &window_fun.as_ref().params;
106                (args, partition_by, order_by, filter).apply_ref_elements(f)
107            }
108
109            Expr::InList(InList { expr, list, .. }) => {
110                (expr, list).apply_ref_elements(f)
111            }
112            Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f)
113        }
114    }
115
116    /// Maps each child of `self` using the provided closure `f`.
117    ///
118    /// The closure `f` takes ownership of an expression and returns a `Transformed` result,
119    /// indicating whether the expression was transformed or left unchanged.
120    fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
121        self,
122        mut f: F,
123    ) -> Result<Transformed<Self>> {
124        Ok(match self {
125            // TODO: remove the next line after `Expr::Wildcard` is removed
126            #[expect(deprecated)]
127            Expr::Column(_)
128            | Expr::Wildcard { .. }
129            | Expr::Placeholder(Placeholder { .. })
130            | Expr::OuterReferenceColumn(_, _)
131            | Expr::Exists { .. }
132            | Expr::ScalarSubquery(_)
133            | Expr::ScalarVariable(_, _)
134            | Expr::Literal(_, _) => Transformed::no(self),
135            Expr::Unnest(Unnest { expr, .. }) => expr
136                .map_elements(f)?
137                .update_data(|expr| Expr::Unnest(Unnest { expr })),
138            Expr::Alias(Alias {
139                expr,
140                relation,
141                name,
142                metadata,
143            }) => f(*expr)?.update_data(|e| {
144                e.alias_qualified_with_metadata(relation, name, metadata)
145            }),
146            Expr::InSubquery(InSubquery {
147                expr,
148                subquery,
149                negated,
150            }) => expr.map_elements(f)?.update_data(|be| {
151                Expr::InSubquery(InSubquery::new(be, subquery, negated))
152            }),
153            Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
154                .map_elements(f)?
155                .update_data(|(new_left, new_right)| {
156                    Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
157                }),
158            Expr::Like(Like {
159                negated,
160                expr,
161                pattern,
162                escape_char,
163                case_insensitive,
164            }) => {
165                (expr, pattern)
166                    .map_elements(f)?
167                    .update_data(|(new_expr, new_pattern)| {
168                        Expr::Like(Like::new(
169                            negated,
170                            new_expr,
171                            new_pattern,
172                            escape_char,
173                            case_insensitive,
174                        ))
175                    })
176            }
177            Expr::SimilarTo(Like {
178                negated,
179                expr,
180                pattern,
181                escape_char,
182                case_insensitive,
183            }) => {
184                (expr, pattern)
185                    .map_elements(f)?
186                    .update_data(|(new_expr, new_pattern)| {
187                        Expr::SimilarTo(Like::new(
188                            negated,
189                            new_expr,
190                            new_pattern,
191                            escape_char,
192                            case_insensitive,
193                        ))
194                    })
195            }
196            Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
197            Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
198            Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
199            Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
200            Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
201            Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
202            Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
203            Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
204            Expr::IsNotUnknown(expr) => {
205                expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
206            }
207            Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
208            Expr::Between(Between {
209                expr,
210                negated,
211                low,
212                high,
213            }) => (expr, low, high).map_elements(f)?.update_data(
214                |(new_expr, new_low, new_high)| {
215                    Expr::Between(Between::new(new_expr, negated, new_low, new_high))
216                },
217            ),
218            Expr::Case(Case {
219                expr,
220                when_then_expr,
221                else_expr,
222            }) => (expr, when_then_expr, else_expr)
223                .map_elements(f)?
224                .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
225                    Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
226                }),
227            Expr::Cast(Cast { expr, data_type }) => expr
228                .map_elements(f)?
229                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
230            Expr::TryCast(TryCast { expr, data_type }) => expr
231                .map_elements(f)?
232                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
233            Expr::ScalarFunction(ScalarFunction { func, args }) => {
234                args.map_elements(f)?.map_data(|new_args| {
235                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
236                        func, new_args,
237                    )))
238                })?
239            }
240            Expr::WindowFunction(window_fun) => {
241                let WindowFunction {
242                    fun,
243                    params:
244                        WindowFunctionParams {
245                            args,
246                            partition_by,
247                            order_by,
248                            window_frame,
249                            filter,
250                            null_treatment,
251                            distinct,
252                        },
253                } = *window_fun;
254
255                (args, partition_by, order_by, filter)
256                    .map_elements(f)?
257                    .map_data(
258                        |(new_args, new_partition_by, new_order_by, new_filter)| {
259                            Ok(Expr::from(WindowFunction {
260                                fun,
261                                params: WindowFunctionParams {
262                                    args: new_args,
263                                    partition_by: new_partition_by,
264                                    order_by: new_order_by,
265                                    window_frame,
266                                    filter: new_filter,
267                                    null_treatment,
268                                    distinct,
269                                },
270                            }))
271                        },
272                    )?
273            }
274            Expr::AggregateFunction(AggregateFunction {
275                func,
276                params:
277                    AggregateFunctionParams {
278                        args,
279                        distinct,
280                        filter,
281                        order_by,
282                        null_treatment,
283                    },
284            }) => (args, filter, order_by).map_elements(f)?.map_data(
285                |(new_args, new_filter, new_order_by)| {
286                    Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
287                        func,
288                        new_args,
289                        distinct,
290                        new_filter,
291                        new_order_by,
292                        null_treatment,
293                    )))
294                },
295            )?,
296            Expr::GroupingSet(grouping_set) => match grouping_set {
297                GroupingSet::Rollup(exprs) => exprs
298                    .map_elements(f)?
299                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
300                GroupingSet::Cube(exprs) => exprs
301                    .map_elements(f)?
302                    .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
303                GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
304                    .map_elements(f)?
305                    .update_data(|new_lists_of_exprs| {
306                        Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
307                    }),
308            },
309            Expr::InList(InList {
310                expr,
311                list,
312                negated,
313            }) => (expr, list)
314                .map_elements(f)?
315                .update_data(|(new_expr, new_list)| {
316                    Expr::InList(InList::new(new_expr, new_list, negated))
317                }),
318            Expr::Lambda(Lambda { params, body }) => body
319                .map_elements(f)?
320                .update_data(|body| Expr::Lambda(Lambda { params, body })),
321        })
322    }
323}
324
325impl Expr {
326    /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`,
327    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
328    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
329    pub fn rewrite_with_schema<
330        R: for<'a> TreeNodeRewriterWithPayload<Node = Expr, Payload<'a> = &'a DFSchema>,
331    >(
332        self,
333        schema: &DFSchema,
334        rewriter: &mut R,
335    ) -> Result<Transformed<Self>> {
336        rewriter
337            .f_down(self, schema)?
338            .transform_children(|n| match &n {
339                Expr::ScalarFunction(ScalarFunction { func, args })
340                    if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) =>
341                {
342                    let mut lambdas_schemas = func
343                        .arguments_schema_from_logical_args(args, schema)?
344                        .into_iter();
345
346                    n.map_children(|n| {
347                        n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter)
348                    })
349                }
350                _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)),
351            })?
352            .transform_parent(|n| rewriter.f_up(n, schema))
353    }
354
355    /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`,
356    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
357    pub fn rewrite_with_lambdas_params<
358        R: for<'a> TreeNodeRewriterWithPayload<
359            Node = Expr,
360            Payload<'a> = &'a HashSet<String>,
361        >,
362    >(
363        self,
364        rewriter: &mut R,
365    ) -> Result<Transformed<Self>> {
366        self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter)
367    }
368
369    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
370    fn rewrite_with_lambdas_params_impl<
371        R: for<'a> TreeNodeRewriterWithPayload<
372            Node = Expr,
373            Payload<'a> = &'a HashSet<String>,
374        >,
375    >(
376        self,
377        args: &HashSet<String>,
378        rewriter: &mut R,
379    ) -> Result<Transformed<Self>> {
380        rewriter
381            .f_down(self, args)?
382            .transform_children(|n| match n {
383                Expr::Lambda(Lambda {
384                    ref params,
385                    body: _,
386                }) => {
387                    let mut args = args.clone();
388
389                    args.extend(params.iter().cloned());
390
391                    n.map_children(|n| {
392                        n.rewrite_with_lambdas_params_impl(&args, rewriter)
393                    })
394                }
395                _ => {
396                    n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter))
397                }
398            })?
399            .transform_parent(|n| rewriter.f_up(n, args))
400    }
401
402    /// Similarly to [`Self::map_children`], rewrites all lambdas that may
403    /// appear in expressions such as `array_transform([1, 2], v -> v*2)`.
404    ///
405    /// Returns the current node.
406    pub fn map_children_with_lambdas_params<
407        F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
408    >(
409        self,
410        args: &HashSet<String>,
411        mut f: F,
412    ) -> Result<Transformed<Self>> {
413        match &self {
414            Expr::Lambda(Lambda { params, body: _ }) => {
415                let mut args = args.clone();
416
417                args.extend(params.iter().cloned());
418
419                self.map_children(|expr| f(expr, &args))
420            }
421            _ => self.map_children(|expr| f(expr, args)),
422        }
423    }
424
425    /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`,
426    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
427    pub fn transform_up_with_lambdas_params<
428        F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
429    >(
430        self,
431        mut f: F,
432    ) -> Result<Transformed<Self>> {
433        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
434        fn transform_up_with_lambdas_params_impl<
435            F: FnMut(Expr, &HashSet<String>) -> Result<Transformed<Expr>>,
436        >(
437            node: Expr,
438            args: &HashSet<String>,
439            f: &mut F,
440        ) -> Result<Transformed<Expr>> {
441            node.map_children_with_lambdas_params(args, |node, args| {
442                transform_up_with_lambdas_params_impl(node, args, f)
443            })?
444            .transform_parent(|node| f(node, args))
445            /*match &node {
446                Expr::Lambda(Lambda { params, body: _ }) => {
447                    let mut args = args.clone();
448
449                    args.extend(params.iter().cloned());
450
451                    node.map_children(|n| {
452                        transform_up_with_lambdas_params_impl(n, &args, f)
453                    })?
454                    .transform_parent(|n| f(n, &args))
455                }
456                _ => node
457                    .map_children(|n| transform_up_with_lambdas_params_impl(n, args, f))?
458                    .transform_parent(|n| f(n, args)),
459            }*/
460        }
461
462        transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
463    }
464
465    /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`,
466    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
467    pub fn transform_down_with_lambdas_params<
468        F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
469    >(
470        self,
471        mut f: F,
472    ) -> Result<Transformed<Self>> {
473        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
474        fn transform_down_with_lambdas_params_impl<
475            F: FnMut(Expr, &HashSet<String>) -> Result<Transformed<Expr>>,
476        >(
477            node: Expr,
478            args: &HashSet<String>,
479            f: &mut F,
480        ) -> Result<Transformed<Expr>> {
481            f(node, args)?.transform_children(|node| {
482                node.map_children_with_lambdas_params(args, |node, args| {
483                    transform_down_with_lambdas_params_impl(node, args, f)
484                })
485            })
486        }
487
488        transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
489    }
490
491    pub fn apply_with_lambdas_params<
492        'n,
493        F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
494    >(
495        &'n self,
496        mut f: F,
497    ) -> Result<TreeNodeRecursion> {
498        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
499        fn apply_with_lambdas_params_impl<
500            'n,
501            F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
502        >(
503            node: &'n Expr,
504            args: &HashSet<&'n str>,
505            f: &mut F,
506        ) -> Result<TreeNodeRecursion> {
507            match node {
508                Expr::Lambda(Lambda { params, body: _ }) => {
509                    let mut args = args.clone();
510
511                    args.extend(params.iter().map(|v| v.as_str()));
512
513                    f(node, &args)?.visit_children(|| {
514                        node.apply_children(|c| {
515                            apply_with_lambdas_params_impl(c, &args, f)
516                        })
517                    })
518                }
519                _ => f(node, args)?.visit_children(|| {
520                    node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f))
521                }),
522            }
523        }
524
525        apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
526    }
527
528    /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`,
529    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
530    pub fn transform_with_schema<
531        F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
532    >(
533        self,
534        schema: &DFSchema,
535        f: F,
536    ) -> Result<Transformed<Self>> {
537        self.transform_up_with_schema(schema, f)
538    }
539
540    /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`,
541    /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
542    pub fn transform_up_with_schema<
543        F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
544    >(
545        self,
546        schema: &DFSchema,
547        mut f: F,
548    ) -> Result<Transformed<Self>> {
549        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
550        fn transform_up_with_schema_impl<
551            F: FnMut(Expr, &DFSchema) -> Result<Transformed<Expr>>,
552        >(
553            node: Expr,
554            schema: &DFSchema,
555            f: &mut F,
556        ) -> Result<Transformed<Expr>> {
557            node.map_children_with_schema(schema, |n, schema| {
558                transform_up_with_schema_impl(n, schema, f)
559            })?
560            .transform_parent(|n| f(n, schema))
561        }
562
563        transform_up_with_schema_impl(self, schema, &mut f)
564    }
565
566    pub fn map_children_with_schema<
567        F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
568    >(
569        self,
570        schema: &DFSchema,
571        mut f: F,
572    ) -> Result<Transformed<Self>> {
573        match self {
574            Expr::ScalarFunction(ref fun)
575                if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) =>
576            {
577                let mut args_schemas = fun
578                    .func
579                    .arguments_schema_from_logical_args(&fun.args, schema)?
580                    .into_iter();
581
582                self.map_children(|expr| f(expr, &args_schemas.next().unwrap()))
583            }
584            _ => self.map_children(|expr| f(expr, schema)),
585        }
586    }
587
588    pub fn exists_with_lambdas_params<F: FnMut(&Self, &HashSet<&str>) -> Result<bool>>(
589        &self,
590        mut f: F,
591    ) -> Result<bool> {
592        let mut found = false;
593
594        self.apply_with_lambdas_params(|n, lambdas_params| {
595            if f(n, lambdas_params)? {
596                found = true;
597                Ok(TreeNodeRecursion::Stop)
598            } else {
599                Ok(TreeNodeRecursion::Continue)
600            }
601        })?;
602
603        Ok(found)
604    }
605}
606
607pub trait ExprWithLambdasRewriter2: Sized {
608    /// Invoked while traversing down the tree before any children are rewritten.
609    /// Default implementation returns the node as is and continues recursion.
610    fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result<Transformed<Expr>> {
611        Ok(Transformed::no(node))
612    }
613
614    /// Invoked while traversing up the tree after all children have been rewritten.
615    /// Default implementation returns the node as is and continues recursion.
616    fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result<Transformed<Expr>> {
617        Ok(Transformed::no(node))
618    }
619}
620pub trait TreeNodeRewriterWithPayload: Sized {
621    type Node;
622    type Payload<'a>;
623
624    /// Invoked while traversing down the tree before any children are rewritten.
625    /// Default implementation returns the node as is and continues recursion.
626    fn f_down<'a>(
627        &mut self,
628        node: Self::Node,
629        _payload: Self::Payload<'a>,
630    ) -> Result<Transformed<Self::Node>> {
631        Ok(Transformed::no(node))
632    }
633
634    /// Invoked while traversing up the tree after all children have been rewritten.
635    /// Default implementation returns the node as is and continues recursion.
636    fn f_up<'a>(
637        &mut self,
638        node: Self::Node,
639        _payload: Self::Payload<'a>,
640    ) -> Result<Transformed<Self::Node>> {
641        Ok(Transformed::no(node))
642    }
643}
644
645/*
646struct LambdaColumnNormalizer<'a> {
647    existing_qualifiers: HashSet<&'a str>,
648    alias_generator: AliasGenerator,
649    lambdas_columns: HashMap<String, Vec<TableReference>>,
650}
651
652impl<'a> LambdaColumnNormalizer<'a> {
653    fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self {
654        let mut existing_qualifiers: HashSet<&'a str> = dfschema
655            .field_qualifiers()
656            .iter()
657            .flatten()
658            .map(|tbl| tbl.table())
659            .filter(|table| table.starts_with("lambda_"))
660            .collect();
661
662        expr.apply(|node| {
663            if let Expr::Lambda(lambda) = node {
664                if let Some(qualifier) = &lambda.qualifier {
665                    existing_qualifiers.insert(qualifier);
666                }
667            }
668
669            Ok(TreeNodeRecursion::Continue)
670        })
671        .unwrap();
672
673        Self {
674            existing_qualifiers,
675            alias_generator: AliasGenerator::new(),
676            lambdas_columns: HashMap::new(),
677        }
678    }
679}
680
681impl TreeNodeRewriter for LambdaColumnNormalizer<'_> {
682    type Node = Expr;
683
684    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
685        match node {
686            Expr::Lambda(mut lambda) => {
687                let tbl = lambda.qualifier.as_ref().map_or_else(
688                    || loop {
689                        let table = self.alias_generator.next("lambda");
690
691                        if !self.existing_qualifiers.contains(table.as_str()) {
692                            break TableReference::bare(table);
693                        }
694                    },
695                    |qualifier| TableReference::bare(qualifier.as_str()),
696                );
697
698                for param in &lambda.params {
699                    self.lambdas_columns
700                        .entry_ref(param)
701                        .or_default()
702                        .push(tbl.clone());
703                }
704
705                if lambda.qualifier.is_none() {
706                    lambda.qualifier = Some(tbl.table().to_owned());
707
708                    Ok(Transformed::yes(Expr::Lambda(lambda)))
709                } else {
710                    Ok(Transformed::no(Expr::Lambda(lambda)))
711                }
712            }
713            Expr::Column(c) if c.relation.is_none() => {
714                if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) {
715                    Ok(Transformed::yes(Expr::Column(
716                        c.with_relation(lambda_qualifier.last().unwrap().clone()),
717                    )))
718                } else {
719                    Ok(Transformed::no(Expr::Column(c)))
720                }
721            }
722            _ => Ok(Transformed::no(node))
723        }
724    }
725
726    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
727        if let Expr::Lambda(lambda) = &node {
728            for param in &lambda.params {
729                match self.lambdas_columns.entry_ref(param) {
730                    EntryRef::Occupied(mut entry) => {
731                        let chain = entry.get_mut();
732
733                        chain.pop();
734
735                        if chain.is_empty() {
736                            entry.remove();
737                        }
738                    }
739                    EntryRef::Vacant(_) => unreachable!(),
740                }
741            }
742        }
743
744        Ok(Transformed::no(node))
745    }
746}
747*/
748
749// helpers used in udf.rs
#[cfg(test)]
pub(crate) mod tests {
    use super::TreeNodeRewriterWithPayload;
    use crate::{
        col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{
        tree_node::{Transformed, TreeNodeRecursion},
        DFSchema, HashSet, Result,
    };
    use datafusion_expr_common::signature::{Signature, Volatility};

    /// Builds a schema with a single non-nullable column `v` of type `data_type`.
    fn single_v_column(data_type: DataType) -> DFSchema {
        let arrow_schema = Schema::new(vec![Field::new("v", data_type, false)]);
        DFSchema::try_from(arrow_schema).unwrap()
    }

    /// Schema with `v: List<List<Int32>>`, every level non-nullable.
    pub(crate) fn list_list_int() -> DFSchema {
        let inner_list = DataType::new_list(DataType::Int32, false);
        single_v_column(DataType::new_list(inner_list, false))
    }

    /// Schema with `v: List<Int32>`, elements non-nullable.
    pub(crate) fn list_int() -> DFSchema {
        single_v_column(DataType::new_list(DataType::Int32, false))
    }

    /// Schema with `v: Int32`.
    fn int() -> DFSchema {
        single_v_column(DataType::Int32)
    }

    /// Wraps `ArrayTransformFunc` in a `ScalarUDF`.
    pub(crate) fn array_transform_udf() -> ScalarUDF {
        ScalarUDF::new_from_impl(ArrayTransformFunc::new())
    }

    /// Arguments of the outer call built by `array_transform()`:
    /// `[v, (v) -> array_transform(v, (v) -> (- v))]`.
    pub(crate) fn args() -> Vec<Expr> {
        // innermost lambda: (v) -> (- v)
        let negate = Expr::Lambda(Lambda::new(vec!["v".into()], -col("v")));
        // body of the outer lambda: array_transform(v, (v) -> (- v))
        let inner_call = array_transform_udf().call(vec![col("v"), negate]);
        let outer_lambda = Expr::Lambda(Lambda::new(vec!["v".into()], inner_call));

        vec![col("v"), outer_lambda]
    }

    /// `array_transform(v, (v) -> array_transform(v, (v) -> (- v)))`, where every
    /// lambda parameter is named `v` and so shadows the column of the same name.
    fn array_transform() -> Expr {
        let udf = array_transform_udf();
        udf.call(args())
    }

    /// Minimal two-argument `array_transform(list, lambda)` UDF used to exercise
    /// lambda-aware traversal. Its `invoke_with_args` is `unimplemented!()`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub(crate) struct ArrayTransformFunc {
        signature: Signature,
    }

    impl ArrayTransformFunc {
        pub fn new() -> Self {
            let signature = Signature::any(2, Volatility::Immutable);
            Self { signature }
        }
    }

    impl ScalarUDFImpl for ArrayTransformFunc {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "array_transform"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// The result has the same type as the first (list) argument.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let list_type = &arg_types[0];
            Ok(list_type.clone())
        }

        /// The first argument gets no lambda parameters. The second (the lambda)
        /// gets one unnamed parameter with the list element's type and nullability.
        /// Panics if the first argument is not a `List` value.
        fn lambdas_parameters(
            &self,
            args: &[ValueOrLambdaParameter],
        ) -> Result<Vec<Option<Vec<Field>>>> {
            let ValueOrLambdaParameter::Value(list_arg) = &args[0] else {
                unreachable!()
            };
            let DataType::List(element) = list_arg.data_type() else {
                unreachable!()
            };

            let lambda_param =
                Field::new("", element.data_type().clone(), element.is_nullable());

            Ok(vec![None, Some(vec![lambda_param])])
        }

        fn invoke_with_args(
            &self,
            _args: crate::ScalarFunctionArgs,
        ) -> Result<datafusion_expr_common::columnar_value::ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_rewrite_with_schema() {
        let schema = list_list_int();
        let expr = array_transform();
        let mut recorder = OkRewriter::default();

        expr.rewrite_with_schema(&schema, &mut recorder).unwrap();

        // Schema each visited node is expected to see: the outer call sees the
        // input schema, and each lambda body sees the element type of its list.
        let (list_list, list, scalar) = (list_list_int(), list_int(), int());
        let expected = [
            (
                "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                &list_list,
            ),
            ("f_down v", &list_list),
            ("f_up v", &list_list),
            ("f_down (v) -> array_transform(v, (v) -> (- v))", &list),
            ("f_down array_transform(v, (v) -> (- v))", &list),
            ("f_down v", &list),
            ("f_up v", &list),
            ("f_down (v) -> (- v)", &scalar),
            ("f_down (- v)", &scalar),
            ("f_down v", &scalar),
            ("f_up v", &scalar),
            ("f_up (- v)", &scalar),
            ("f_up (v) -> (- v)", &scalar),
            ("f_up array_transform(v, (v) -> (- v))", &list),
            ("f_up (v) -> array_transform(v, (v) -> (- v))", &list),
            (
                "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                &list_list,
            ),
        ]
        .map(|(step, step_schema)| (step.to_string(), step_schema.clone()));

        assert_eq!(recorder.steps, expected)
    }

    /// Rewriter that leaves every node unchanged and records each visit as
    /// `"<phase> <node>"` together with the schema it was handed.
    #[derive(Default)]
    struct OkRewriter {
        steps: Vec<(String, DFSchema)>,
    }

    impl OkRewriter {
        fn record(&mut self, phase: &str, node: &Expr, schema: &DFSchema) {
            self.steps.push((format!("{phase} {node}"), schema.clone()));
        }
    }

    impl TreeNodeRewriterWithPayload for OkRewriter {
        type Node = Expr;
        type Payload<'a> = &'a DFSchema;

        fn f_down(
            &mut self,
            node: Expr,
            schema: &DFSchema,
        ) -> Result<Transformed<Expr>> {
            self.record("f_down", &node, schema);
            Ok(Transformed::no(node))
        }

        fn f_up(
            &mut self,
            node: Expr,
            schema: &DFSchema,
        ) -> Result<Transformed<Expr>> {
            self.record("f_up", &node, schema);
            Ok(Transformed::no(node))
        }
    }

    #[test]
    fn test_transform_up_with_lambdas_params() {
        let mut visited = Vec::new();

        array_transform()
            .transform_up_with_lambdas_params(|node, params| {
                visited.push((node.to_string(), params.clone()));
                Ok(Transformed::no(node))
            })
            .unwrap();

        // Every visit is expected to report the same parameter set `{v}`.
        let only_v = HashSet::from([String::from("v")]);
        let expected = [
            "v",
            "v",
            "v",
            "(- v)",
            "(v) -> (- v)",
            "array_transform(v, (v) -> (- v))",
            "(v) -> array_transform(v, (v) -> (- v))",
            "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
        ]
        .map(|node_text| (String::from(node_text), only_v.clone()));

        assert_eq!(visited, expected);
    }

    #[test]
    fn test_apply_with_lambdas_params() {
        // The expression stays bound to a local: the recorded parameter sets hold
        // `&str`s, presumably borrowed from it, so it must outlive `visited`.
        let expr = array_transform();
        let mut visited = Vec::new();

        expr.apply_with_lambdas_params(|node, params| {
            visited.push((node.to_string(), params.clone()));
            Ok(TreeNodeRecursion::Continue)
        })
        .unwrap();

        let expected = [
            "v",
            "v",
            "v",
            "(- v)",
            "(v) -> (- v)",
            "array_transform(v, (v) -> (- v))",
            "(v) -> array_transform(v, (v) -> (- v))",
            "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
        ]
        .map(|node_text| (String::from(node_text), HashSet::from(["v"])));

        assert_eq!(visited, expected);
    }
}