1use crate::{
21 expr::{
22 AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case,
23 Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction,
24 TryCast, Unnest, WindowFunction, WindowFunctionParams,
25 },
26 Expr,
27};
28use datafusion_common::{
29 tree_node::{
30 Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer,
31 },
32 DFSchema, HashSet, Result,
33};
34
35impl TreeNode for Expr {
41 fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
46 &'n self,
47 f: F,
48 ) -> Result<TreeNodeRecursion> {
49 match self {
50 Expr::Alias(Alias { expr, .. })
51 | Expr::Unnest(Unnest { expr })
52 | Expr::Not(expr)
53 | Expr::IsNotNull(expr)
54 | Expr::IsTrue(expr)
55 | Expr::IsFalse(expr)
56 | Expr::IsUnknown(expr)
57 | Expr::IsNotTrue(expr)
58 | Expr::IsNotFalse(expr)
59 | Expr::IsNotUnknown(expr)
60 | Expr::IsNull(expr)
61 | Expr::Negative(expr)
62 | Expr::Cast(Cast { expr, .. })
63 | Expr::TryCast(TryCast { expr, .. })
64 | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f),
65 Expr::GroupingSet(GroupingSet::Rollup(exprs))
66 | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f),
67 Expr::ScalarFunction(ScalarFunction { args, .. }) => {
68 args.apply_elements(f)
69 }
70 Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
71 lists_of_exprs.apply_elements(f)
72 }
73 #[expect(deprecated)]
75 Expr::Column(_)
76 | Expr::OuterReferenceColumn(_, _)
78 | Expr::ScalarVariable(_, _)
79 | Expr::Literal(_, _)
80 | Expr::Exists { .. }
81 | Expr::ScalarSubquery(_)
82 | Expr::Wildcard { .. }
83 | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue),
84 Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
85 (left, right).apply_ref_elements(f)
86 }
87 Expr::Like(Like { expr, pattern, .. })
88 | Expr::SimilarTo(Like { expr, pattern, .. }) => {
89 (expr, pattern).apply_ref_elements(f)
90 }
91 Expr::Between(Between {
92 expr, low, high, ..
93 }) => (expr, low, high).apply_ref_elements(f),
94 Expr::Case(Case { expr, when_then_expr, else_expr }) =>
95 (expr, when_then_expr, else_expr).apply_ref_elements(f),
96 Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
97 (args, filter, order_by).apply_ref_elements(f),
98 Expr::WindowFunction(window_fun) => {
99 let WindowFunctionParams {
100 args,
101 partition_by,
102 order_by,
103 filter,
104 ..
105 } = &window_fun.as_ref().params;
106 (args, partition_by, order_by, filter).apply_ref_elements(f)
107 }
108
109 Expr::InList(InList { expr, list, .. }) => {
110 (expr, list).apply_ref_elements(f)
111 }
112 Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f)
113 }
114 }
115
116 fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
121 self,
122 mut f: F,
123 ) -> Result<Transformed<Self>> {
124 Ok(match self {
125 #[expect(deprecated)]
127 Expr::Column(_)
128 | Expr::Wildcard { .. }
129 | Expr::Placeholder(Placeholder { .. })
130 | Expr::OuterReferenceColumn(_, _)
131 | Expr::Exists { .. }
132 | Expr::ScalarSubquery(_)
133 | Expr::ScalarVariable(_, _)
134 | Expr::Literal(_, _) => Transformed::no(self),
135 Expr::Unnest(Unnest { expr, .. }) => expr
136 .map_elements(f)?
137 .update_data(|expr| Expr::Unnest(Unnest { expr })),
138 Expr::Alias(Alias {
139 expr,
140 relation,
141 name,
142 metadata,
143 }) => f(*expr)?.update_data(|e| {
144 e.alias_qualified_with_metadata(relation, name, metadata)
145 }),
146 Expr::InSubquery(InSubquery {
147 expr,
148 subquery,
149 negated,
150 }) => expr.map_elements(f)?.update_data(|be| {
151 Expr::InSubquery(InSubquery::new(be, subquery, negated))
152 }),
153 Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right)
154 .map_elements(f)?
155 .update_data(|(new_left, new_right)| {
156 Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
157 }),
158 Expr::Like(Like {
159 negated,
160 expr,
161 pattern,
162 escape_char,
163 case_insensitive,
164 }) => {
165 (expr, pattern)
166 .map_elements(f)?
167 .update_data(|(new_expr, new_pattern)| {
168 Expr::Like(Like::new(
169 negated,
170 new_expr,
171 new_pattern,
172 escape_char,
173 case_insensitive,
174 ))
175 })
176 }
177 Expr::SimilarTo(Like {
178 negated,
179 expr,
180 pattern,
181 escape_char,
182 case_insensitive,
183 }) => {
184 (expr, pattern)
185 .map_elements(f)?
186 .update_data(|(new_expr, new_pattern)| {
187 Expr::SimilarTo(Like::new(
188 negated,
189 new_expr,
190 new_pattern,
191 escape_char,
192 case_insensitive,
193 ))
194 })
195 }
196 Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not),
197 Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull),
198 Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull),
199 Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue),
200 Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse),
201 Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown),
202 Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue),
203 Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse),
204 Expr::IsNotUnknown(expr) => {
205 expr.map_elements(f)?.update_data(Expr::IsNotUnknown)
206 }
207 Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative),
208 Expr::Between(Between {
209 expr,
210 negated,
211 low,
212 high,
213 }) => (expr, low, high).map_elements(f)?.update_data(
214 |(new_expr, new_low, new_high)| {
215 Expr::Between(Between::new(new_expr, negated, new_low, new_high))
216 },
217 ),
218 Expr::Case(Case {
219 expr,
220 when_then_expr,
221 else_expr,
222 }) => (expr, when_then_expr, else_expr)
223 .map_elements(f)?
224 .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
225 Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
226 }),
227 Expr::Cast(Cast { expr, data_type }) => expr
228 .map_elements(f)?
229 .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
230 Expr::TryCast(TryCast { expr, data_type }) => expr
231 .map_elements(f)?
232 .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
233 Expr::ScalarFunction(ScalarFunction { func, args }) => {
234 args.map_elements(f)?.map_data(|new_args| {
235 Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
236 func, new_args,
237 )))
238 })?
239 }
240 Expr::WindowFunction(window_fun) => {
241 let WindowFunction {
242 fun,
243 params:
244 WindowFunctionParams {
245 args,
246 partition_by,
247 order_by,
248 window_frame,
249 filter,
250 null_treatment,
251 distinct,
252 },
253 } = *window_fun;
254
255 (args, partition_by, order_by, filter)
256 .map_elements(f)?
257 .map_data(
258 |(new_args, new_partition_by, new_order_by, new_filter)| {
259 Ok(Expr::from(WindowFunction {
260 fun,
261 params: WindowFunctionParams {
262 args: new_args,
263 partition_by: new_partition_by,
264 order_by: new_order_by,
265 window_frame,
266 filter: new_filter,
267 null_treatment,
268 distinct,
269 },
270 }))
271 },
272 )?
273 }
274 Expr::AggregateFunction(AggregateFunction {
275 func,
276 params:
277 AggregateFunctionParams {
278 args,
279 distinct,
280 filter,
281 order_by,
282 null_treatment,
283 },
284 }) => (args, filter, order_by).map_elements(f)?.map_data(
285 |(new_args, new_filter, new_order_by)| {
286 Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
287 func,
288 new_args,
289 distinct,
290 new_filter,
291 new_order_by,
292 null_treatment,
293 )))
294 },
295 )?,
296 Expr::GroupingSet(grouping_set) => match grouping_set {
297 GroupingSet::Rollup(exprs) => exprs
298 .map_elements(f)?
299 .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
300 GroupingSet::Cube(exprs) => exprs
301 .map_elements(f)?
302 .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))),
303 GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs
304 .map_elements(f)?
305 .update_data(|new_lists_of_exprs| {
306 Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs))
307 }),
308 },
309 Expr::InList(InList {
310 expr,
311 list,
312 negated,
313 }) => (expr, list)
314 .map_elements(f)?
315 .update_data(|(new_expr, new_list)| {
316 Expr::InList(InList::new(new_expr, new_list, negated))
317 }),
318 Expr::Lambda(Lambda { params, body }) => body
319 .map_elements(f)?
320 .update_data(|body| Expr::Lambda(Lambda { params, body })),
321 })
322 }
323}
324
325impl Expr {
326 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
329 pub fn rewrite_with_schema<
330 R: for<'a> TreeNodeRewriterWithPayload<Node = Expr, Payload<'a> = &'a DFSchema>,
331 >(
332 self,
333 schema: &DFSchema,
334 rewriter: &mut R,
335 ) -> Result<Transformed<Self>> {
336 rewriter
337 .f_down(self, schema)?
338 .transform_children(|n| match &n {
339 Expr::ScalarFunction(ScalarFunction { func, args })
340 if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) =>
341 {
342 let mut lambdas_schemas = func
343 .arguments_schema_from_logical_args(args, schema)?
344 .into_iter();
345
346 n.map_children(|n| {
347 n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter)
348 })
349 }
350 _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)),
351 })?
352 .transform_parent(|n| rewriter.f_up(n, schema))
353 }
354
355 pub fn rewrite_with_lambdas_params<
358 R: for<'a> TreeNodeRewriterWithPayload<
359 Node = Expr,
360 Payload<'a> = &'a HashSet<String>,
361 >,
362 >(
363 self,
364 rewriter: &mut R,
365 ) -> Result<Transformed<Self>> {
366 self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter)
367 }
368
369 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
370 fn rewrite_with_lambdas_params_impl<
371 R: for<'a> TreeNodeRewriterWithPayload<
372 Node = Expr,
373 Payload<'a> = &'a HashSet<String>,
374 >,
375 >(
376 self,
377 args: &HashSet<String>,
378 rewriter: &mut R,
379 ) -> Result<Transformed<Self>> {
380 rewriter
381 .f_down(self, args)?
382 .transform_children(|n| match n {
383 Expr::Lambda(Lambda {
384 ref params,
385 body: _,
386 }) => {
387 let mut args = args.clone();
388
389 args.extend(params.iter().cloned());
390
391 n.map_children(|n| {
392 n.rewrite_with_lambdas_params_impl(&args, rewriter)
393 })
394 }
395 _ => {
396 n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter))
397 }
398 })?
399 .transform_parent(|n| rewriter.f_up(n, args))
400 }
401
402 pub fn map_children_with_lambdas_params<
407 F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
408 >(
409 self,
410 args: &HashSet<String>,
411 mut f: F,
412 ) -> Result<Transformed<Self>> {
413 match &self {
414 Expr::Lambda(Lambda { params, body: _ }) => {
415 let mut args = args.clone();
416
417 args.extend(params.iter().cloned());
418
419 self.map_children(|expr| f(expr, &args))
420 }
421 _ => self.map_children(|expr| f(expr, args)),
422 }
423 }
424
425 pub fn transform_up_with_lambdas_params<
428 F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
429 >(
430 self,
431 mut f: F,
432 ) -> Result<Transformed<Self>> {
433 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
434 fn transform_up_with_lambdas_params_impl<
435 F: FnMut(Expr, &HashSet<String>) -> Result<Transformed<Expr>>,
436 >(
437 node: Expr,
438 args: &HashSet<String>,
439 f: &mut F,
440 ) -> Result<Transformed<Expr>> {
441 node.map_children_with_lambdas_params(args, |node, args| {
442 transform_up_with_lambdas_params_impl(node, args, f)
443 })?
444 .transform_parent(|node| f(node, args))
445 }
461
462 transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
463 }
464
465 pub fn transform_down_with_lambdas_params<
468 F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
469 >(
470 self,
471 mut f: F,
472 ) -> Result<Transformed<Self>> {
473 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
474 fn transform_down_with_lambdas_params_impl<
475 F: FnMut(Expr, &HashSet<String>) -> Result<Transformed<Expr>>,
476 >(
477 node: Expr,
478 args: &HashSet<String>,
479 f: &mut F,
480 ) -> Result<Transformed<Expr>> {
481 f(node, args)?.transform_children(|node| {
482 node.map_children_with_lambdas_params(args, |node, args| {
483 transform_down_with_lambdas_params_impl(node, args, f)
484 })
485 })
486 }
487
488 transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
489 }
490
491 pub fn apply_with_lambdas_params<
492 'n,
493 F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
494 >(
495 &'n self,
496 mut f: F,
497 ) -> Result<TreeNodeRecursion> {
498 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
499 fn apply_with_lambdas_params_impl<
500 'n,
501 F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
502 >(
503 node: &'n Expr,
504 args: &HashSet<&'n str>,
505 f: &mut F,
506 ) -> Result<TreeNodeRecursion> {
507 match node {
508 Expr::Lambda(Lambda { params, body: _ }) => {
509 let mut args = args.clone();
510
511 args.extend(params.iter().map(|v| v.as_str()));
512
513 f(node, &args)?.visit_children(|| {
514 node.apply_children(|c| {
515 apply_with_lambdas_params_impl(c, &args, f)
516 })
517 })
518 }
519 _ => f(node, args)?.visit_children(|| {
520 node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f))
521 }),
522 }
523 }
524
525 apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
526 }
527
528 pub fn transform_with_schema<
531 F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
532 >(
533 self,
534 schema: &DFSchema,
535 f: F,
536 ) -> Result<Transformed<Self>> {
537 self.transform_up_with_schema(schema, f)
538 }
539
540 pub fn transform_up_with_schema<
543 F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
544 >(
545 self,
546 schema: &DFSchema,
547 mut f: F,
548 ) -> Result<Transformed<Self>> {
549 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
550 fn transform_up_with_schema_impl<
551 F: FnMut(Expr, &DFSchema) -> Result<Transformed<Expr>>,
552 >(
553 node: Expr,
554 schema: &DFSchema,
555 f: &mut F,
556 ) -> Result<Transformed<Expr>> {
557 node.map_children_with_schema(schema, |n, schema| {
558 transform_up_with_schema_impl(n, schema, f)
559 })?
560 .transform_parent(|n| f(n, schema))
561 }
562
563 transform_up_with_schema_impl(self, schema, &mut f)
564 }
565
566 pub fn map_children_with_schema<
567 F: FnMut(Self, &DFSchema) -> Result<Transformed<Self>>,
568 >(
569 self,
570 schema: &DFSchema,
571 mut f: F,
572 ) -> Result<Transformed<Self>> {
573 match self {
574 Expr::ScalarFunction(ref fun)
575 if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) =>
576 {
577 let mut args_schemas = fun
578 .func
579 .arguments_schema_from_logical_args(&fun.args, schema)?
580 .into_iter();
581
582 self.map_children(|expr| f(expr, &args_schemas.next().unwrap()))
583 }
584 _ => self.map_children(|expr| f(expr, schema)),
585 }
586 }
587
588 pub fn exists_with_lambdas_params<F: FnMut(&Self, &HashSet<&str>) -> Result<bool>>(
589 &self,
590 mut f: F,
591 ) -> Result<bool> {
592 let mut found = false;
593
594 self.apply_with_lambdas_params(|n, lambdas_params| {
595 if f(n, lambdas_params)? {
596 found = true;
597 Ok(TreeNodeRecursion::Stop)
598 } else {
599 Ok(TreeNodeRecursion::Continue)
600 }
601 })?;
602
603 Ok(found)
604 }
605}
606
607pub trait ExprWithLambdasRewriter2: Sized {
608 fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result<Transformed<Expr>> {
611 Ok(Transformed::no(node))
612 }
613
614 fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result<Transformed<Expr>> {
617 Ok(Transformed::no(node))
618 }
619}
620pub trait TreeNodeRewriterWithPayload: Sized {
621 type Node;
622 type Payload<'a>;
623
624 fn f_down<'a>(
627 &mut self,
628 node: Self::Node,
629 _payload: Self::Payload<'a>,
630 ) -> Result<Transformed<Self::Node>> {
631 Ok(Transformed::no(node))
632 }
633
634 fn f_up<'a>(
637 &mut self,
638 node: Self::Node,
639 _payload: Self::Payload<'a>,
640 ) -> Result<Transformed<Self::Node>> {
641 Ok(Transformed::no(node))
642 }
643}
644
#[cfg(test)]
pub(crate) mod tests {
    use super::TreeNodeRewriterWithPayload;
    use crate::{
        col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{
        tree_node::{Transformed, TreeNodeRecursion},
        DFSchema, HashSet, Result,
    };
    use datafusion_expr_common::signature::{Signature, Volatility};

    /// Schema with one non-nullable column `v: List<List<Int32>>`.
    pub(crate) fn list_list_int() -> DFSchema {
        DFSchema::try_from(Schema::new(vec![Field::new(
            "v",
            DataType::new_list(DataType::new_list(DataType::Int32, false), false),
            false,
        )]))
        .unwrap()
    }

    /// Schema with one non-nullable column `v: List<Int32>`.
    pub(crate) fn list_int() -> DFSchema {
        DFSchema::try_from(Schema::new(vec![Field::new(
            "v",
            DataType::new_list(DataType::Int32, false),
            false,
        )]))
        .unwrap()
    }

    /// Schema with one non-nullable column `v: Int32`.
    fn int() -> DFSchema {
        DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)]))
            .unwrap()
    }

    /// The test-only `array_transform` UDF wrapped in a [`ScalarUDF`].
    pub(crate) fn array_transform_udf() -> ScalarUDF {
        ScalarUDF::new_from_impl(ArrayTransformFunc::new())
    }

    /// Arguments `v, (v) -> array_transform(v, (v) -> (- v))`; every lambda
    /// parameter is named `v`, shadowing the column of the same name.
    pub(crate) fn args() -> Vec<Expr> {
        vec![
            col("v"),
            Expr::Lambda(Lambda::new(
                vec!["v".into()],
                array_transform_udf().call(vec![
                    col("v"),
                    Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))),
                ]),
            )),
        ]
    }

    /// `array_transform(v, (v) -> array_transform(v, (v) -> (- v)))`.
    fn array_transform() -> Expr {
        array_transform_udf().call(args())
    }

    /// Minimal two-argument UDF whose second argument is a lambda over the
    /// elements of the first (list-typed) argument. It is only planned in
    /// these tests, never invoked.
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub(crate) struct ArrayTransformFunc {
        signature: Signature,
    }

    impl ArrayTransformFunc {
        pub fn new() -> Self {
            Self {
                signature: Signature::any(2, Volatility::Immutable),
            }
        }
    }

    impl ScalarUDFImpl for ArrayTransformFunc {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "array_transform"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            Ok(arg_types[0].clone())
        }

        /// The first argument is a plain value; the second is a lambda with a
        /// single parameter typed as the list's element.
        fn lambdas_parameters(
            &self,
            args: &[ValueOrLambdaParameter],
        ) -> Result<Vec<Option<Vec<Field>>>> {
            let ValueOrLambdaParameter::Value(value_field) = &args[0] else {
                unreachable!()
            };

            let DataType::List(field) = value_field.data_type() else {
                unreachable!()
            };

            Ok(vec![
                None,
                Some(vec![Field::new(
                    "",
                    field.data_type().clone(),
                    field.is_nullable(),
                )]),
            ])
        }

        fn invoke_with_args(
            &self,
            _args: crate::ScalarFunctionArgs,
        ) -> Result<datafusion_expr_common::columnar_value::ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_rewrite_with_schema() {
        let schema = list_list_int();
        let array_transform = array_transform();

        let mut rewriter = OkRewriter::default();

        array_transform
            .rewrite_with_schema(&schema, &mut rewriter)
            .unwrap();

        let expected = [
            (
                "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                list_list_int(),
            ),
            ("f_down v", list_list_int()),
            ("f_up v", list_list_int()),
            ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()),
            ("f_down array_transform(v, (v) -> (- v))", list_int()),
            ("f_down v", list_int()),
            ("f_up v", list_int()),
            ("f_down (v) -> (- v)", int()),
            ("f_down (- v)", int()),
            ("f_down v", int()),
            ("f_up v", int()),
            ("f_up (- v)", int()),
            ("f_up (v) -> (- v)", int()),
            ("f_up array_transform(v, (v) -> (- v))", list_int()),
            ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()),
            (
                "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                list_list_int(),
            ),
        ]
        .map(|(a, b)| (String::from(a), b));

        assert_eq!(rewriter.steps, expected)
    }

    /// Records every `f_down`/`f_up` call together with the schema it got.
    #[derive(Default)]
    struct OkRewriter {
        steps: Vec<(String, DFSchema)>,
    }

    impl TreeNodeRewriterWithPayload for OkRewriter {
        type Node = Expr;
        type Payload<'a> = &'a DFSchema;

        fn f_down(
            &mut self,
            node: Expr,
            schema: &DFSchema,
        ) -> Result<Transformed<Expr>> {
            self.steps.push((format!("f_down {node}"), schema.clone()));

            Ok(Transformed::no(node))
        }

        fn f_up(
            &mut self,
            node: Expr,
            schema: &DFSchema,
        ) -> Result<Transformed<Expr>> {
            self.steps.push((format!("f_up {node}"), schema.clone()));

            Ok(Transformed::no(node))
        }
    }

    /// Post-order: a node's own params are only visible to its children, so
    /// the root, its first `v` argument and the outer lambda see `{}`.
    #[test]
    fn test_transform_up_with_lambdas_params() {
        let mut steps = vec![];

        array_transform()
            .transform_up_with_lambdas_params(|node, params| {
                steps.push((node.to_string(), params.clone()));

                Ok(Transformed::no(node))
            })
            .unwrap();

        let outside: HashSet<String> = HashSet::new();
        let inside: HashSet<String> = HashSet::from([String::from("v")]);

        let expected = [
            ("v", &outside),
            ("v", &inside),
            ("v", &inside),
            ("(- v)", &inside),
            ("(v) -> (- v)", &inside),
            ("array_transform(v, (v) -> (- v))", &inside),
            ("(v) -> array_transform(v, (v) -> (- v))", &outside),
            (
                "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                &outside,
            ),
        ]
        .map(|(a, b)| (String::from(a), b.clone()));

        assert_eq!(steps, expected);
    }

    /// Pre-order counterpart of `test_transform_up_with_lambdas_params`.
    #[test]
    fn test_transform_down_with_lambdas_params() {
        let mut steps = vec![];

        array_transform()
            .transform_down_with_lambdas_params(|node, params| {
                steps.push((node.to_string(), params.clone()));

                Ok(Transformed::no(node))
            })
            .unwrap();

        let outside: HashSet<String> = HashSet::new();
        let inside: HashSet<String> = HashSet::from([String::from("v")]);

        let expected = [
            (
                "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                &outside,
            ),
            ("v", &outside),
            ("(v) -> array_transform(v, (v) -> (- v))", &outside),
            ("array_transform(v, (v) -> (- v))", &inside),
            ("v", &inside),
            ("(v) -> (- v)", &inside),
            ("(- v)", &inside),
            ("v", &inside),
        ]
        .map(|(a, b)| (String::from(a), b.clone()));

        assert_eq!(steps, expected);
    }

    /// `apply_with_lambdas_params` is pre-order, and a lambda node already
    /// sees its own parameters.
    #[test]
    fn test_apply_with_lambdas_params() {
        let array_transform = array_transform();
        let mut steps = vec![];

        array_transform
            .apply_with_lambdas_params(|node, params| {
                steps.push((node.to_string(), params.clone()));

                Ok(TreeNodeRecursion::Continue)
            })
            .unwrap();

        let outside: HashSet<&str> = HashSet::new();
        let inside: HashSet<&str> = HashSet::from(["v"]);

        let expected = [
            (
                "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))",
                &outside,
            ),
            ("v", &outside),
            ("(v) -> array_transform(v, (v) -> (- v))", &inside),
            ("array_transform(v, (v) -> (- v))", &inside),
            ("v", &inside),
            ("(v) -> (- v)", &inside),
            ("(- v)", &inside),
            ("v", &inside),
        ]
        .map(|(a, b)| (String::from(a), b.clone()));

        assert_eq!(steps, expected);
    }
}