1use std::any::Any;
33use std::borrow::Cow;
34use std::fmt::{self, Debug, Formatter};
35use std::hash::{Hash, Hasher};
36use std::sync::Arc;
37
38use crate::expressions::{Column, LambdaExpr, Literal};
39use crate::PhysicalExpr;
40
41use arrow::array::{Array, NullArray, RecordBatch};
42use arrow::datatypes::{DataType, Field, FieldRef, Schema};
43use datafusion_common::config::{ConfigEntry, ConfigOptions};
44use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
45use datafusion_common::{internal_err, HashSet, Result, ScalarValue};
46use datafusion_expr::interval_arithmetic::Interval;
47use datafusion_expr::sort_properties::ExprProperties;
48use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
49use datafusion_expr::{
50 expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs,
51 ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility,
52};
53
54pub struct ScalarFunctionExpr {
56 fun: Arc<ScalarUDF>,
57 name: String,
58 args: Vec<Arc<dyn PhysicalExpr>>,
59 return_field: FieldRef,
60 config_options: Arc<ConfigOptions>,
61}
62
63impl Debug for ScalarFunctionExpr {
64 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
65 f.debug_struct("ScalarFunctionExpr")
66 .field("fun", &"<FUNC>")
67 .field("name", &self.name)
68 .field("args", &self.args)
69 .field("return_field", &self.return_field)
70 .finish()
71 }
72}
73
74impl ScalarFunctionExpr {
75 pub fn new(
77 name: &str,
78 fun: Arc<ScalarUDF>,
79 args: Vec<Arc<dyn PhysicalExpr>>,
80 return_field: FieldRef,
81 config_options: Arc<ConfigOptions>,
82 ) -> Self {
83 Self {
84 fun,
85 name: name.to_owned(),
86 args,
87 return_field,
88 config_options,
89 }
90 }
91
92 pub fn try_new(
94 fun: Arc<ScalarUDF>,
95 args: Vec<Arc<dyn PhysicalExpr>>,
96 schema: &Schema,
97 config_options: Arc<ConfigOptions>,
98 ) -> Result<Self> {
99 let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?;
100
101 let arg_fields = std::iter::zip(&args, lambdas_schemas)
102 .map(|(e, schema)| {
103 if let Some(lambda) = e.as_any().downcast_ref::<LambdaExpr>() {
104 lambda.body().return_field(&schema)
105 } else {
106 e.return_field(&schema)
107 }
108 })
109 .collect::<Result<Vec<_>>>()?;
110
111 let arg_types = arg_fields
113 .iter()
114 .map(|f| f.data_type().clone())
115 .collect::<Vec<_>>();
116
117 data_types_with_scalar_udf(&arg_types, &fun)?;
118
119 let arguments = args
120 .iter()
121 .map(|e| {
122 e.as_any()
123 .downcast_ref::<Literal>()
124 .map(|literal| literal.value())
125 })
126 .collect::<Vec<_>>();
127
128 let lambdas = args
129 .iter()
130 .map(|e| e.as_any().is::<LambdaExpr>())
131 .collect::<Vec<_>>();
132
133 let ret_args = ReturnFieldArgs {
134 arg_fields: &arg_fields,
135 scalar_arguments: &arguments,
136 lambdas: &lambdas,
137 };
138
139 let return_field = fun.return_field_from_args(ret_args)?;
140 let name = fun.name().to_string();
141
142 Ok(Self {
143 fun,
144 name,
145 args,
146 return_field,
147 config_options,
148 })
149 }
150
151 pub fn fun(&self) -> &ScalarUDF {
153 &self.fun
154 }
155
156 pub fn name(&self) -> &str {
158 &self.name
159 }
160
161 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
163 &self.args
164 }
165
166 pub fn return_type(&self) -> &DataType {
168 self.return_field.data_type()
169 }
170
171 pub fn with_nullable(mut self, nullable: bool) -> Self {
172 self.return_field = self
173 .return_field
174 .as_ref()
175 .clone()
176 .with_nullable(nullable)
177 .into();
178 self
179 }
180
181 pub fn nullable(&self) -> bool {
182 self.return_field.is_nullable()
183 }
184
185 pub fn config_options(&self) -> &ConfigOptions {
186 &self.config_options
187 }
188
189 pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
194 where
195 T: 'static,
196 {
197 match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
198 Some(scalar_expr)
199 if scalar_expr
200 .fun()
201 .inner()
202 .as_any()
203 .downcast_ref::<T>()
204 .is_some() =>
205 {
206 Some(scalar_expr)
207 }
208 _ => None,
209 }
210 }
211}
212
213impl fmt::Display for ScalarFunctionExpr {
214 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
215 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
216 }
217}
218
219impl PartialEq for ScalarFunctionExpr {
220 fn eq(&self, o: &Self) -> bool {
221 if std::ptr::eq(self, o) {
222 return true;
224 }
225 let Self {
226 fun,
227 name,
228 args,
229 return_field,
230 config_options,
231 } = self;
232 fun.eq(&o.fun)
233 && name.eq(&o.name)
234 && args.eq(&o.args)
235 && return_field.eq(&o.return_field)
236 && (Arc::ptr_eq(config_options, &o.config_options)
237 || sorted_config_entries(config_options)
238 == sorted_config_entries(&o.config_options))
239 }
240}
241impl Eq for ScalarFunctionExpr {}
242impl Hash for ScalarFunctionExpr {
243 fn hash<H: Hasher>(&self, state: &mut H) {
244 let Self {
245 fun,
246 name,
247 args,
248 return_field,
249 config_options: _, } = self;
251 fun.hash(state);
252 name.hash(state);
253 args.hash(state);
254 return_field.hash(state);
255 }
256}
257
258fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
259 let mut entries = config_options.entries();
260 entries.sort_by(|l, r| l.key.cmp(&r.key));
261 entries
262}
263
264impl PhysicalExpr for ScalarFunctionExpr {
265 fn as_any(&self) -> &dyn Any {
267 self
268 }
269
270 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
271 Ok(self.return_field.data_type().clone())
272 }
273
274 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
275 Ok(self.return_field.is_nullable())
276 }
277
278 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
279 let args = self
280 .args
281 .iter()
282 .map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
283 Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
284 None => Ok(e.evaluate(batch)?),
285 })
286 .collect::<Result<Vec<_>>>()?;
287
288 let arg_fields = self
289 .args
290 .iter()
291 .map(|e| e.return_field(batch.schema_ref()))
292 .collect::<Result<Vec<_>>>()?;
293
294 let input_empty = args.is_empty();
295 let input_all_scalar = args
296 .iter()
297 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
298
299 let lambdas = if self.args.iter().any(|arg| arg.as_any().is::<LambdaExpr>()) {
300 let args_metadata = std::iter::zip(&self.args, &arg_fields)
301 .map(
302 |(expr, field)| match expr.as_any().downcast_ref::<LambdaExpr>() {
303 Some(lambda) => {
304 let mut captures = false;
305
306 expr.apply_with_lambdas_params(|expr, lambdas_params| {
307 match expr.as_any().downcast_ref::<Column>() {
308 Some(col) if !lambdas_params.contains(col.name()) => {
309 captures = true;
310
311 Ok(TreeNodeRecursion::Stop)
312 }
313 _ => Ok(TreeNodeRecursion::Continue),
314 }
315 })
316 .unwrap();
317
318 ValueOrLambdaParameter::Lambda(lambda.params(), captures)
319 }
320 None => ValueOrLambdaParameter::Value(Arc::clone(field)),
321 },
322 )
323 .collect::<Vec<_>>();
324
325 let params = self.fun().inner().lambdas_parameters(&args_metadata)?;
326
327 let lambdas = std::iter::zip(&self.args, params)
328 .map(|(arg, lambda_params)| {
329 arg.as_any()
330 .downcast_ref::<LambdaExpr>()
331 .map(|lambda| {
332 let mut indices = HashSet::new();
333
334 arg.apply_with_lambdas_params(|expr, lambdas_params| {
335 if let Some(column) =
336 expr.as_any().downcast_ref::<Column>()
337 {
338 if !lambdas_params.contains(column.name()) {
339 indices.insert(
340 column.index(), );
344 }
345 }
346
347 Ok(TreeNodeRecursion::Continue)
348 })?;
349
350 let params =
355 std::iter::zip(lambda.params(), lambda_params.unwrap())
356 .map(|(name, param)| Arc::new(param.with_name(name)))
357 .collect();
358
359 let captures = if !indices.is_empty() {
360 let (fields, columns): (Vec<_>, _) = std::iter::zip(
361 batch.schema_ref().fields(),
362 batch.columns(),
363 )
364 .enumerate()
365 .map(|(column_index, (field, column))| {
366 if indices.contains(&column_index) {
367 (Arc::clone(field), Arc::clone(column))
368 } else {
369 (
370 Arc::new(Field::new(
371 field.name(),
372 DataType::Null,
373 false,
374 )),
375 Arc::new(NullArray::new(column.len())) as _,
376 )
377 }
378 })
379 .unzip();
380
381 let schema = Arc::new(Schema::new(fields));
382
383 Some(RecordBatch::try_new(schema, columns)?)
384 } else {
386 None
387 };
388
389 Ok(ScalarFunctionLambdaArg {
390 params,
391 body: Arc::clone(lambda.body()),
392 captures,
393 })
394 })
395 .transpose()
396 })
397 .collect::<Result<Vec<_>>>()?;
398
399 Some(lambdas)
400 } else {
401 None
402 };
403
404 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
406 args,
407 arg_fields,
408 number_rows: batch.num_rows(),
409 return_field: Arc::clone(&self.return_field),
410 config_options: Arc::clone(&self.config_options),
411 lambdas,
412 })?;
413
414 if let ColumnarValue::Array(array) = &output {
415 if array.len() != batch.num_rows() {
416 let preserve_scalar =
419 array.len() == 1 && !input_empty && input_all_scalar;
420 return if preserve_scalar {
421 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
422 } else {
423 internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
424 self.name, batch.num_rows(), array.len())
425 };
426 }
427 }
428 Ok(output)
429 }
430
431 fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
432 Ok(Arc::clone(&self.return_field))
433 }
434
435 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
436 self.args.iter().collect()
437 }
438
439 fn with_new_children(
440 self: Arc<Self>,
441 children: Vec<Arc<dyn PhysicalExpr>>,
442 ) -> Result<Arc<dyn PhysicalExpr>> {
443 Ok(Arc::new(ScalarFunctionExpr::new(
444 &self.name,
445 Arc::clone(&self.fun),
446 children,
447 Arc::clone(&self.return_field),
448 Arc::clone(&self.config_options),
449 )))
450 }
451
452 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
453 self.fun.evaluate_bounds(children)
454 }
455
456 fn propagate_constraints(
457 &self,
458 interval: &Interval,
459 children: &[&Interval],
460 ) -> Result<Option<Vec<Interval>>> {
461 self.fun.propagate_constraints(interval, children)
462 }
463
464 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
465 let sort_properties = self.fun.output_ordering(children)?;
466 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
467 let children_range = children
468 .iter()
469 .map(|props| &props.range)
470 .collect::<Vec<_>>();
471 let range = self.fun().evaluate_bounds(&children_range)?;
472
473 Ok(ExprProperties {
474 sort_properties,
475 range,
476 preserves_lex_ordering,
477 })
478 }
479
480 fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
481 write!(f, "{}(", self.name)?;
482 for (i, expr) in self.args.iter().enumerate() {
483 if i > 0 {
484 write!(f, ", ")?;
485 }
486 expr.fmt_sql(f)?;
487 }
488 write!(f, ")")
489 }
490
491 fn is_volatile_node(&self) -> bool {
492 self.fun.signature().volatility == Volatility::Volatile
493 }
494}
495
496pub fn lambdas_schemas_from_args<'a>(
497 fun: &ScalarUDF,
498 args: &[Arc<dyn PhysicalExpr>],
499 schema: &'a Schema,
500) -> Result<Vec<Cow<'a, Schema>>> {
501 let args_metadata = args
502 .iter()
503 .map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
504 Some(lambda) => {
505 let mut captures = false;
506
507 e.apply_with_lambdas_params(|expr, lambdas_params| {
508 match expr.as_any().downcast_ref::<Column>() {
509 Some(col) if !lambdas_params.contains(col.name()) => {
510 captures = true;
511
512 Ok(TreeNodeRecursion::Stop)
513 }
514 _ => Ok(TreeNodeRecursion::Continue),
515 }
516 })
517 .unwrap();
518
519 Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures))
520 }
521 None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)),
522 })
523 .collect::<Result<Vec<_>>>()?;
524
525 fun.arguments_arrow_schema(&args_metadata, schema)
550}
551
552pub trait PhysicalExprExt: Sized {
553 fn apply_with_lambdas_params<
554 'n,
555 F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
556 >(
557 &'n self,
558 f: F,
559 ) -> Result<TreeNodeRecursion>;
560
561 fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>>(
562 &'n self,
563 schema: &Schema,
564 f: F,
565 ) -> Result<TreeNodeRecursion>;
566
567 fn apply_children_with_schema<
568 'n,
569 F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>,
570 >(
571 &'n self,
572 schema: &Schema,
573 f: F,
574 ) -> Result<TreeNodeRecursion>;
575
576 fn transform_down_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
577 self,
578 schema: &Schema,
579 f: F,
580 ) -> Result<Transformed<Self>>;
581
582 fn transform_up_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
583 self,
584 schema: &Schema,
585 f: F,
586 ) -> Result<Transformed<Self>>;
587
588 fn transform_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
589 self,
590 schema: &Schema,
591 f: F,
592 ) -> Result<Transformed<Self>> {
593 self.transform_up_with_schema(schema, f)
594 }
595
596 fn transform_down_with_lambdas_params(
597 self,
598 f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
599 ) -> Result<Transformed<Self>>;
600
601 fn transform_up_with_lambdas_params(
602 self,
603 f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
604 ) -> Result<Transformed<Self>>;
605
606 fn transform_with_lambdas_params(
607 self,
608 f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
609 ) -> Result<Transformed<Self>> {
610 self.transform_up_with_lambdas_params(f)
611 }
612}
613
614impl PhysicalExprExt for Arc<dyn PhysicalExpr> {
615 fn apply_with_lambdas_params<
616 'n,
617 F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
618 >(
619 &'n self,
620 mut f: F,
621 ) -> Result<TreeNodeRecursion> {
622 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
623 fn apply_with_lambdas_params_impl<
624 'n,
625 F: FnMut(
626 &'n Arc<dyn PhysicalExpr>,
627 &HashSet<&'n str>,
628 ) -> Result<TreeNodeRecursion>,
629 >(
630 node: &'n Arc<dyn PhysicalExpr>,
631 args: &HashSet<&'n str>,
632 f: &mut F,
633 ) -> Result<TreeNodeRecursion> {
634 match node.as_any().downcast_ref::<LambdaExpr>() {
635 Some(lambda) => {
636 let mut args = args.clone();
637
638 args.extend(lambda.params().iter().map(|v| v.as_str()));
639
640 f(node, &args)?.visit_children(|| {
641 node.apply_children(|c| {
642 apply_with_lambdas_params_impl(c, &args, f)
643 })
644 })
645 }
646 _ => f(node, args)?.visit_children(|| {
647 node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f))
648 }),
649 }
650 }
651
652 apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
653 }
654
655 fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>>(
656 &'n self,
657 schema: &Schema,
658 mut f: F,
659 ) -> Result<TreeNodeRecursion> {
660 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
661 fn apply_with_lambdas_impl<
662 'n,
663 F: FnMut(&'n Arc<dyn PhysicalExpr>, &Schema) -> Result<TreeNodeRecursion>,
664 >(
665 node: &'n Arc<dyn PhysicalExpr>,
666 schema: &Schema,
667 f: &mut F,
668 ) -> Result<TreeNodeRecursion> {
669 f(node, schema)?.visit_children(|| {
670 node.apply_children_with_schema(schema, |c, schema| {
671 apply_with_lambdas_impl(c, schema, f)
672 })
673 })
674 }
675
676 apply_with_lambdas_impl(self, schema, &mut f)
677 }
678
679 fn apply_children_with_schema<
680 'n,
681 F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>,
682 >(
683 &'n self,
684 schema: &Schema,
685 mut f: F,
686 ) -> Result<TreeNodeRecursion> {
687 match self.as_any().downcast_ref::<ScalarFunctionExpr>() {
688 Some(scalar_function)
689 if scalar_function
690 .args()
691 .iter()
692 .any(|arg| arg.as_any().is::<LambdaExpr>()) =>
693 {
694 let mut lambdas_schemas = lambdas_schemas_from_args(
695 scalar_function.fun(),
696 scalar_function.args(),
697 schema,
698 )?
699 .into_iter();
700
701 self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap()))
702 }
703 _ => self.apply_children(|e| f(e, schema)),
704 }
705 }
706
707 fn transform_down_with_schema<
708 F: FnMut(Self, &Schema) -> Result<Transformed<Self>>,
709 >(
710 self,
711 schema: &Schema,
712 mut f: F,
713 ) -> Result<Transformed<Self>> {
714 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
715 fn transform_down_with_schema_impl<
716 F: FnMut(
717 Arc<dyn PhysicalExpr>,
718 &Schema,
719 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
720 >(
721 node: Arc<dyn PhysicalExpr>,
722 schema: &Schema,
723 f: &mut F,
724 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
725 f(node, schema)?.transform_children(|node| {
726 map_children_with_schema(node, schema, |n, schema| {
727 transform_down_with_schema_impl(n, schema, f)
728 })
729 })
730 }
731
732 transform_down_with_schema_impl(self, schema, &mut f)
733 }
734
735 fn transform_up_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
736 self,
737 schema: &Schema,
738 mut f: F,
739 ) -> Result<Transformed<Self>> {
740 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
741 fn transform_up_with_schema_impl<
742 F: FnMut(
743 Arc<dyn PhysicalExpr>,
744 &Schema,
745 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
746 >(
747 node: Arc<dyn PhysicalExpr>,
748 schema: &Schema,
749 f: &mut F,
750 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
751 map_children_with_schema(node, schema, |n, schema| {
752 transform_up_with_schema_impl(n, schema, f)
753 })?
754 .transform_parent(|n| f(n, schema))
755 }
756
757 transform_up_with_schema_impl(self, schema, &mut f)
758 }
759
760 fn transform_up_with_lambdas_params(
761 self,
762 mut f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
763 ) -> Result<Transformed<Self>> {
764 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
765 fn transform_up_with_lambdas_params_impl<
766 F: FnMut(
767 Arc<dyn PhysicalExpr>,
768 &HashSet<String>,
769 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
770 >(
771 node: Arc<dyn PhysicalExpr>,
772 params: &HashSet<String>,
773 f: &mut F,
774 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
775 map_children_with_lambdas_params(node, params, |n, params| {
776 transform_up_with_lambdas_params_impl(n, params, f)
777 })?
778 .transform_parent(|n| f(n, params))
779 }
780
781 transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
782 }
783
784 fn transform_down_with_lambdas_params(
785 self,
786 mut f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
787 ) -> Result<Transformed<Self>> {
788 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
789 fn transform_down_with_lambdas_params_impl<
790 F: FnMut(
791 Arc<dyn PhysicalExpr>,
792 &HashSet<String>,
793 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
794 >(
795 node: Arc<dyn PhysicalExpr>,
796 params: &HashSet<String>,
797 f: &mut F,
798 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
799 f(node, params)?.transform_children(|node| {
800 map_children_with_lambdas_params(node, params, |node, args| {
801 transform_down_with_lambdas_params_impl(node, args, f)
802 })
803 })
804 }
805
806 transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
807 }
808}
809
810fn map_children_with_schema(
811 node: Arc<dyn PhysicalExpr>,
812 schema: &Schema,
813 mut f: impl FnMut(
814 Arc<dyn PhysicalExpr>,
815 &Schema,
816 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
817) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
818 match node.as_any().downcast_ref::<ScalarFunctionExpr>() {
819 Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::<LambdaExpr>()) => {
820 let mut args_schemas =
821 lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter();
822
823 node.map_children(|node| f(node, &args_schemas.next().unwrap()))
824 }
825 _ => node.map_children(|node| f(node, schema)),
826 }
827}
828
829fn map_children_with_lambdas_params(
830 node: Arc<dyn PhysicalExpr>,
831 params: &HashSet<String>,
832 mut f: impl FnMut(
833 Arc<dyn PhysicalExpr>,
834 &HashSet<String>,
835 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
836) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
837 match node.as_any().downcast_ref::<LambdaExpr>() {
838 Some(lambda) => {
839 let mut params = params.clone();
840
841 params.extend(lambda.params().iter().cloned());
842
843 node.map_children(|node| f(node, ¶ms))
844 }
845 None => node.map_children(|node| f(node, params)),
846 }
847}
848
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::{borrow::Cow, sync::Arc};

    use super::*;
    use super::{lambdas_schemas_from_args, PhysicalExprExt};
    use crate::expressions::Column;
    use crate::{create_physical_expr, ScalarFunctionExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result};
    use datafusion_expr::{
        col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility,
    };
    use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature};
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_physical_expr_common::physical_expr::is_volatile;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    /// UDF with a configurable signature that always returns `Int32(42)`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    #[test]
    fn test_scalar_function_volatile_node() {
        let volatile_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(
                1,
                vec![DataType::Float32],
                Volatility::Volatile,
            ),
        }));

        let stable_udf = Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
        }));

        let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let args = vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>];
        let config_options = Arc::new(ConfigOptions::new());

        let volatile_expr = ScalarFunctionExpr::try_new(
            volatile_udf,
            args.clone(),
            &schema,
            Arc::clone(&config_options),
        )
        .unwrap();

        assert!(volatile_expr.is_volatile_node());
        let volatile_arc: Arc<dyn PhysicalExpr> = Arc::new(volatile_expr);
        assert!(is_volatile(&volatile_arc));

        let stable_expr =
            ScalarFunctionExpr::try_new(stable_udf, args, &schema, config_options)
                .unwrap();

        assert!(!stable_expr.is_volatile_node());
        let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
        assert!(!is_volatile(&stable_arc));
    }

    fn list_list_int() -> Schema {
        Schema::new(vec![Field::new(
            "v",
            DataType::new_list(DataType::new_list(DataType::Int32, false), false),
            false,
        )])
    }

    fn list_int() -> Schema {
        Schema::new(vec![Field::new(
            "v",
            DataType::new_list(DataType::Int32, false),
            false,
        )])
    }

    fn int() -> Schema {
        Schema::new(vec![Field::new("v", DataType::Int32, false)])
    }

    fn array_transform_udf() -> ScalarUDF {
        ScalarUDF::new_from_impl(ArrayTransformFunc::new())
    }

    fn args() -> Vec<Expr> {
        vec![
            col("v"),
            Expr::Lambda(Lambda::new(
                vec!["v".into()],
                array_transform_udf().call(vec![
                    col("v"),
                    Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))),
                ]),
            )),
        ]
    }

    /// `array_transform(v, (v) -> array_transform(v, (v) -> -v))` over `list_list_int`.
    fn array_transform() -> Arc<dyn PhysicalExpr> {
        let e = array_transform_udf().call(args());

        create_physical_expr(
            &e,
            &DFSchema::try_from(list_list_int()).unwrap(),
            &Default::default(),
        )
        .unwrap()
    }

    #[derive(Debug, PartialEq, Eq, Hash)]
    struct ArrayTransformFunc {
        signature: Signature,
    }

    impl ArrayTransformFunc {
        pub fn new() -> Self {
            Self {
                signature: Signature::any(2, Volatility::Immutable),
            }
        }
    }

    impl ScalarUDFImpl for ArrayTransformFunc {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        fn name(&self) -> &str {
            "array_transform"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            Ok(arg_types[0].clone())
        }

        fn lambdas_parameters(
            &self,
            args: &[ValueOrLambdaParameter],
        ) -> Result<Vec<Option<Vec<Field>>>> {
            let ValueOrLambdaParameter::Value(value_field) = &args[0] else {
                unimplemented!()
            };
            let DataType::List(field) = value_field.data_type() else {
                unimplemented!()
            };

            Ok(vec![
                None,
                Some(vec![Field::new(
                    "",
                    field.data_type().clone(),
                    field.is_nullable(),
                )]),
            ])
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    #[test]
    fn test_lambdas_schemas_from_args() {
        let schema = list_list_int();
        let expr = array_transform();

        let args = expr
            .as_any()
            .downcast_ref::<ScalarFunctionExpr>()
            .unwrap()
            .args();

        let schemas =
            lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap();

        assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]);
    }

    #[test]
    fn test_apply_with_schema() {
        let mut steps = vec![];

        array_transform()
            .apply_with_schema(&list_list_int(), |node, schema| {
                steps.push((node.to_string(), schema.clone()));

                Ok(TreeNodeRecursion::Continue)
            })
            .unwrap();

        // Pre-order: every function argument is visited before the next sibling,
        // with the schema reported for its position.
        let expected = [
            (
                "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))",
                list_list_int(),
            ),
            ("v@0", list_list_int()),
            ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()),
            ("array_transform(v@0, (v) -> (- v@0))", list_int()),
            ("v@0", list_int()),
            ("(v) -> (- v@0)", int()),
            ("(- v@0)", int()),
            ("v@0", int()),
        ]
        .map(|(a, b)| (String::from(a), b));

        assert_eq!(steps, expected);
    }

    #[test]
    fn test_apply_with_lambdas_params() {
        let array_transform = array_transform();
        let mut steps = vec![];

        array_transform
            .apply_with_lambdas_params(|node, params| {
                steps.push((node.to_string(), params.clone()));

                Ok(TreeNodeRecursion::Continue)
            })
            .unwrap();

        // Nodes outside any lambda see no parameters; a lambda and its body see `v`.
        let expected = [
            (
                "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))",
                HashSet::new(),
            ),
            ("v@0", HashSet::new()),
            (
                "(v) -> array_transform(v@0, (v) -> (- v@0))",
                HashSet::from(["v"]),
            ),
            ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])),
            ("v@0", HashSet::from(["v"])),
            ("(v) -> (- v@0)", HashSet::from(["v"])),
            ("(- v@0)", HashSet::from(["v"])),
            ("v@0", HashSet::from(["v"])),
        ]
        .map(|(a, b)| (String::from(a), b));

        assert_eq!(steps, expected);
    }
}