datafusion_physical_expr/
scalar_function.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Declaration of built-in (scalar) functions.
19//! This module contains built-in functions' enumeration and metadata.
20//!
21//! Generally, a function has:
22//! * a signature
23//! * a return type, that is a function of the incoming argument's types
24//! * the computation, that must accept each valid signature
25//!
26//! * Signature: see `Signature`
27//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
28//!
29//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
30//! to a function that supports f64, it is coerced to f64.
31
32use std::any::Any;
33use std::borrow::Cow;
34use std::fmt::{self, Debug, Formatter};
35use std::hash::{Hash, Hasher};
36use std::sync::Arc;
37
38use crate::expressions::{Column, LambdaExpr, Literal};
39use crate::PhysicalExpr;
40
41use arrow::array::{Array, NullArray, RecordBatch};
42use arrow::datatypes::{DataType, Field, FieldRef, Schema};
43use datafusion_common::config::{ConfigEntry, ConfigOptions};
44use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
45use datafusion_common::{internal_err, HashSet, Result, ScalarValue};
46use datafusion_expr::interval_arithmetic::Interval;
47use datafusion_expr::sort_properties::ExprProperties;
48use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
49use datafusion_expr::{
50    expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs,
51    ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility,
52};
53
54/// Physical expression of a scalar function
55pub struct ScalarFunctionExpr {
56    fun: Arc<ScalarUDF>,
57    name: String,
58    args: Vec<Arc<dyn PhysicalExpr>>,
59    return_field: FieldRef,
60    config_options: Arc<ConfigOptions>,
61}
62
63impl Debug for ScalarFunctionExpr {
64    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
65        f.debug_struct("ScalarFunctionExpr")
66            .field("fun", &"<FUNC>")
67            .field("name", &self.name)
68            .field("args", &self.args)
69            .field("return_field", &self.return_field)
70            .finish()
71    }
72}
73
74impl ScalarFunctionExpr {
75    /// Create a new Scalar function
76    pub fn new(
77        name: &str,
78        fun: Arc<ScalarUDF>,
79        args: Vec<Arc<dyn PhysicalExpr>>,
80        return_field: FieldRef,
81        config_options: Arc<ConfigOptions>,
82    ) -> Self {
83        Self {
84            fun,
85            name: name.to_owned(),
86            args,
87            return_field,
88            config_options,
89        }
90    }
91
92    /// Create a new Scalar function
93    pub fn try_new(
94        fun: Arc<ScalarUDF>,
95        args: Vec<Arc<dyn PhysicalExpr>>,
96        schema: &Schema,
97        config_options: Arc<ConfigOptions>,
98    ) -> Result<Self> {
99        let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?;
100
101        let arg_fields = std::iter::zip(&args, lambdas_schemas)
102            .map(|(e, schema)| {
103                if let Some(lambda) = e.as_any().downcast_ref::<LambdaExpr>() {
104                    lambda.body().return_field(&schema)
105                } else {
106                    e.return_field(&schema)
107                }
108            })
109            .collect::<Result<Vec<_>>>()?;
110
111        // verify that input data types is consistent with function's `TypeSignature`
112        let arg_types = arg_fields
113            .iter()
114            .map(|f| f.data_type().clone())
115            .collect::<Vec<_>>();
116
117        data_types_with_scalar_udf(&arg_types, &fun)?;
118
119        let arguments = args
120            .iter()
121            .map(|e| {
122                e.as_any()
123                    .downcast_ref::<Literal>()
124                    .map(|literal| literal.value())
125            })
126            .collect::<Vec<_>>();
127
128        let lambdas = args
129            .iter()
130            .map(|e| e.as_any().is::<LambdaExpr>())
131            .collect::<Vec<_>>();
132
133        let ret_args = ReturnFieldArgs {
134            arg_fields: &arg_fields,
135            scalar_arguments: &arguments,
136            lambdas: &lambdas,
137        };
138
139        let return_field = fun.return_field_from_args(ret_args)?;
140        let name = fun.name().to_string();
141
142        Ok(Self {
143            fun,
144            name,
145            args,
146            return_field,
147            config_options,
148        })
149    }
150
151    /// Get the scalar function implementation
152    pub fn fun(&self) -> &ScalarUDF {
153        &self.fun
154    }
155
156    /// The name for this expression
157    pub fn name(&self) -> &str {
158        &self.name
159    }
160
161    /// Input arguments
162    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
163        &self.args
164    }
165
166    /// Data type produced by this expression
167    pub fn return_type(&self) -> &DataType {
168        self.return_field.data_type()
169    }
170
171    pub fn with_nullable(mut self, nullable: bool) -> Self {
172        self.return_field = self
173            .return_field
174            .as_ref()
175            .clone()
176            .with_nullable(nullable)
177            .into();
178        self
179    }
180
181    pub fn nullable(&self) -> bool {
182        self.return_field.is_nullable()
183    }
184
185    pub fn config_options(&self) -> &ConfigOptions {
186        &self.config_options
187    }
188
189    /// Given an arbitrary PhysicalExpr attempt to downcast it to a ScalarFunctionExpr
190    /// and verify that its inner function is of type T.
191    /// If the downcast fails, or the function is not of type T, returns `None`.
192    /// Otherwise returns `Some(ScalarFunctionExpr)`.
193    pub fn try_downcast_func<T>(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr>
194    where
195        T: 'static,
196    {
197        match expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
198            Some(scalar_expr)
199                if scalar_expr
200                    .fun()
201                    .inner()
202                    .as_any()
203                    .downcast_ref::<T>()
204                    .is_some() =>
205            {
206                Some(scalar_expr)
207            }
208            _ => None,
209        }
210    }
211}
212
213impl fmt::Display for ScalarFunctionExpr {
214    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
215        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
216    }
217}
218
219impl PartialEq for ScalarFunctionExpr {
220    fn eq(&self, o: &Self) -> bool {
221        if std::ptr::eq(self, o) {
222            // The equality implementation is somewhat expensive, so let's short-circuit when possible.
223            return true;
224        }
225        let Self {
226            fun,
227            name,
228            args,
229            return_field,
230            config_options,
231        } = self;
232        fun.eq(&o.fun)
233            && name.eq(&o.name)
234            && args.eq(&o.args)
235            && return_field.eq(&o.return_field)
236            && (Arc::ptr_eq(config_options, &o.config_options)
237                || sorted_config_entries(config_options)
238                    == sorted_config_entries(&o.config_options))
239    }
240}
241impl Eq for ScalarFunctionExpr {}
242impl Hash for ScalarFunctionExpr {
243    fn hash<H: Hasher>(&self, state: &mut H) {
244        let Self {
245            fun,
246            name,
247            args,
248            return_field,
249            config_options: _, // expensive to hash, and often equal
250        } = self;
251        fun.hash(state);
252        name.hash(state);
253        args.hash(state);
254        return_field.hash(state);
255    }
256}
257
258fn sorted_config_entries(config_options: &ConfigOptions) -> Vec<ConfigEntry> {
259    let mut entries = config_options.entries();
260    entries.sort_by(|l, r| l.key.cmp(&r.key));
261    entries
262}
263
264impl PhysicalExpr for ScalarFunctionExpr {
265    /// Return a reference to Any that can be used for downcasting
266    fn as_any(&self) -> &dyn Any {
267        self
268    }
269
270    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
271        Ok(self.return_field.data_type().clone())
272    }
273
274    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
275        Ok(self.return_field.is_nullable())
276    }
277
278    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
279        let args = self
280            .args
281            .iter()
282            .map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
283                Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)),
284                None => Ok(e.evaluate(batch)?),
285            })
286            .collect::<Result<Vec<_>>>()?;
287
288        let arg_fields = self
289            .args
290            .iter()
291            .map(|e| e.return_field(batch.schema_ref()))
292            .collect::<Result<Vec<_>>>()?;
293
294        let input_empty = args.is_empty();
295        let input_all_scalar = args
296            .iter()
297            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
298
299        let lambdas = if self.args.iter().any(|arg| arg.as_any().is::<LambdaExpr>()) {
300            let args_metadata = std::iter::zip(&self.args, &arg_fields)
301                .map(
302                    |(expr, field)| match expr.as_any().downcast_ref::<LambdaExpr>() {
303                        Some(lambda) => {
304                            let mut captures = false;
305
306                            expr.apply_with_lambdas_params(|expr, lambdas_params| {
307                                match expr.as_any().downcast_ref::<Column>() {
308                                    Some(col) if !lambdas_params.contains(col.name()) => {
309                                        captures = true;
310
311                                        Ok(TreeNodeRecursion::Stop)
312                                    }
313                                    _ => Ok(TreeNodeRecursion::Continue),
314                                }
315                            })
316                            .unwrap();
317
318                            ValueOrLambdaParameter::Lambda(lambda.params(), captures)
319                        }
320                        None => ValueOrLambdaParameter::Value(Arc::clone(field)),
321                    },
322                )
323                .collect::<Vec<_>>();
324
325            let params = self.fun().inner().lambdas_parameters(&args_metadata)?;
326
327            let lambdas = std::iter::zip(&self.args, params)
328                .map(|(arg, lambda_params)| {
329                    arg.as_any()
330                        .downcast_ref::<LambdaExpr>()
331                        .map(|lambda| {
332                            let mut indices = HashSet::new();
333
334                            arg.apply_with_lambdas_params(|expr, lambdas_params| {
335                                if let Some(column) =
336                                    expr.as_any().downcast_ref::<Column>()
337                                {
338                                    if !lambdas_params.contains(column.name()) {
339                                        indices.insert(
340                                            column.index(), //batch
341                                                            //    .schema_ref()
342                                                            //    .index_of(column.name())?,
343                                        );
344                                    }
345                                }
346
347                                Ok(TreeNodeRecursion::Continue)
348                            })?;
349
350                            //let mut indices = indices.into_iter().collect::<Vec<_>>();
351
352                            //indices.sort_unstable();
353
354                            let params =
355                                std::iter::zip(lambda.params(), lambda_params.unwrap())
356                                    .map(|(name, param)| Arc::new(param.with_name(name)))
357                                    .collect();
358
359                            let captures = if !indices.is_empty() {
360                                let (fields, columns): (Vec<_>, _) = std::iter::zip(
361                                    batch.schema_ref().fields(),
362                                    batch.columns(),
363                                )
364                                .enumerate()
365                                .map(|(column_index, (field, column))| {
366                                    if indices.contains(&column_index) {
367                                        (Arc::clone(field), Arc::clone(column))
368                                    } else {
369                                        (
370                                            Arc::new(Field::new(
371                                                field.name(),
372                                                DataType::Null,
373                                                false,
374                                            )),
375                                            Arc::new(NullArray::new(column.len())) as _,
376                                        )
377                                    }
378                                })
379                                .unzip();
380
381                                let schema = Arc::new(Schema::new(fields));
382
383                                Some(RecordBatch::try_new(schema, columns)?)
384                                //Some(batch.project(&indices)?)
385                            } else {
386                                None
387                            };
388
389                            Ok(ScalarFunctionLambdaArg {
390                                params,
391                                body: Arc::clone(lambda.body()),
392                                captures,
393                            })
394                        })
395                        .transpose()
396                })
397                .collect::<Result<Vec<_>>>()?;
398
399            Some(lambdas)
400        } else {
401            None
402        };
403
404        // evaluate the function
405        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
406            args,
407            arg_fields,
408            number_rows: batch.num_rows(),
409            return_field: Arc::clone(&self.return_field),
410            config_options: Arc::clone(&self.config_options),
411            lambdas,
412        })?;
413
414        if let ColumnarValue::Array(array) = &output {
415            if array.len() != batch.num_rows() {
416                // If the arguments are a non-empty slice of scalar values, we can assume that
417                // returning a one-element array is equivalent to returning a scalar.
418                let preserve_scalar =
419                    array.len() == 1 && !input_empty && input_all_scalar;
420                return if preserve_scalar {
421                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
422                } else {
423                    internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
424                            self.name, batch.num_rows(), array.len())
425                };
426            }
427        }
428        Ok(output)
429    }
430
431    fn return_field(&self, _input_schema: &Schema) -> Result<FieldRef> {
432        Ok(Arc::clone(&self.return_field))
433    }
434
435    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
436        self.args.iter().collect()
437    }
438
439    fn with_new_children(
440        self: Arc<Self>,
441        children: Vec<Arc<dyn PhysicalExpr>>,
442    ) -> Result<Arc<dyn PhysicalExpr>> {
443        Ok(Arc::new(ScalarFunctionExpr::new(
444            &self.name,
445            Arc::clone(&self.fun),
446            children,
447            Arc::clone(&self.return_field),
448            Arc::clone(&self.config_options),
449        )))
450    }
451
452    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
453        self.fun.evaluate_bounds(children)
454    }
455
456    fn propagate_constraints(
457        &self,
458        interval: &Interval,
459        children: &[&Interval],
460    ) -> Result<Option<Vec<Interval>>> {
461        self.fun.propagate_constraints(interval, children)
462    }
463
464    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
465        let sort_properties = self.fun.output_ordering(children)?;
466        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
467        let children_range = children
468            .iter()
469            .map(|props| &props.range)
470            .collect::<Vec<_>>();
471        let range = self.fun().evaluate_bounds(&children_range)?;
472
473        Ok(ExprProperties {
474            sort_properties,
475            range,
476            preserves_lex_ordering,
477        })
478    }
479
480    fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result {
481        write!(f, "{}(", self.name)?;
482        for (i, expr) in self.args.iter().enumerate() {
483            if i > 0 {
484                write!(f, ", ")?;
485            }
486            expr.fmt_sql(f)?;
487        }
488        write!(f, ")")
489    }
490
491    fn is_volatile_node(&self) -> bool {
492        self.fun.signature().volatility == Volatility::Volatile
493    }
494}
495
496pub fn lambdas_schemas_from_args<'a>(
497    fun: &ScalarUDF,
498    args: &[Arc<dyn PhysicalExpr>],
499    schema: &'a Schema,
500) -> Result<Vec<Cow<'a, Schema>>> {
501    let args_metadata = args
502        .iter()
503        .map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
504            Some(lambda) => {
505                let mut captures = false;
506
507                e.apply_with_lambdas_params(|expr, lambdas_params| {
508                    match expr.as_any().downcast_ref::<Column>() {
509                        Some(col) if !lambdas_params.contains(col.name()) => {
510                            captures = true;
511
512                            Ok(TreeNodeRecursion::Stop)
513                        }
514                        _ => Ok(TreeNodeRecursion::Continue),
515                    }
516                })
517                .unwrap();
518
519                Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures))
520            }
521            None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)),
522        })
523        .collect::<Result<Vec<_>>>()?;
524
525    /*let captures = args
526    .iter()
527    .map(|arg| {
528        if arg.as_any().is::<LambdaExpr>() {
529            let mut columns = HashSet::new();
530
531            arg.apply_with_lambdas_params(|n, lambdas_params| {
532                if let Some(column) = n.as_any().downcast_ref::<Column>() {
533                    if !lambdas_params.contains(column.name()) {
534                        columns.insert(schema.index_of(column.name())?);
535                    }
536                    // columns.insert(column.index());
537                }
538
539                Ok(TreeNodeRecursion::Continue)
540            })?;
541
542            Ok(columns)
543        } else {
544            Ok(HashSet::new())
545        }
546    })
547    .collect::<Result<Vec<_>>>()?; */
548
549    fun.arguments_arrow_schema(&args_metadata, schema)
550}
551
552pub trait PhysicalExprExt: Sized {
553    fn apply_with_lambdas_params<
554        'n,
555        F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
556    >(
557        &'n self,
558        f: F,
559    ) -> Result<TreeNodeRecursion>;
560
561    fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>>(
562        &'n self,
563        schema: &Schema,
564        f: F,
565    ) -> Result<TreeNodeRecursion>;
566
567    fn apply_children_with_schema<
568        'n,
569        F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>,
570    >(
571        &'n self,
572        schema: &Schema,
573        f: F,
574    ) -> Result<TreeNodeRecursion>;
575
576    fn transform_down_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
577        self,
578        schema: &Schema,
579        f: F,
580    ) -> Result<Transformed<Self>>;
581
582    fn transform_up_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
583        self,
584        schema: &Schema,
585        f: F,
586    ) -> Result<Transformed<Self>>;
587
588    fn transform_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
589        self,
590        schema: &Schema,
591        f: F,
592    ) -> Result<Transformed<Self>> {
593        self.transform_up_with_schema(schema, f)
594    }
595
596    fn transform_down_with_lambdas_params(
597        self,
598        f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
599    ) -> Result<Transformed<Self>>;
600
601    fn transform_up_with_lambdas_params(
602        self,
603        f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
604    ) -> Result<Transformed<Self>>;
605
606    fn transform_with_lambdas_params(
607        self,
608        f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
609    ) -> Result<Transformed<Self>> {
610        self.transform_up_with_lambdas_params(f)
611    }
612}
613
614impl PhysicalExprExt for Arc<dyn PhysicalExpr> {
615    fn apply_with_lambdas_params<
616        'n,
617        F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
618    >(
619        &'n self,
620        mut f: F,
621    ) -> Result<TreeNodeRecursion> {
622        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
623        fn apply_with_lambdas_params_impl<
624            'n,
625            F: FnMut(
626                &'n Arc<dyn PhysicalExpr>,
627                &HashSet<&'n str>,
628            ) -> Result<TreeNodeRecursion>,
629        >(
630            node: &'n Arc<dyn PhysicalExpr>,
631            args: &HashSet<&'n str>,
632            f: &mut F,
633        ) -> Result<TreeNodeRecursion> {
634            match node.as_any().downcast_ref::<LambdaExpr>() {
635                Some(lambda) => {
636                    let mut args = args.clone();
637
638                    args.extend(lambda.params().iter().map(|v| v.as_str()));
639
640                    f(node, &args)?.visit_children(|| {
641                        node.apply_children(|c| {
642                            apply_with_lambdas_params_impl(c, &args, f)
643                        })
644                    })
645                }
646                _ => f(node, args)?.visit_children(|| {
647                    node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f))
648                }),
649            }
650        }
651
652        apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
653    }
654
655    fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>>(
656        &'n self,
657        schema: &Schema,
658        mut f: F,
659    ) -> Result<TreeNodeRecursion> {
660        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
661        fn apply_with_lambdas_impl<
662            'n,
663            F: FnMut(&'n Arc<dyn PhysicalExpr>, &Schema) -> Result<TreeNodeRecursion>,
664        >(
665            node: &'n Arc<dyn PhysicalExpr>,
666            schema: &Schema,
667            f: &mut F,
668        ) -> Result<TreeNodeRecursion> {
669            f(node, schema)?.visit_children(|| {
670                node.apply_children_with_schema(schema, |c, schema| {
671                    apply_with_lambdas_impl(c, schema, f)
672                })
673            })
674        }
675
676        apply_with_lambdas_impl(self, schema, &mut f)
677    }
678
679    fn apply_children_with_schema<
680        'n,
681        F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>,
682    >(
683        &'n self,
684        schema: &Schema,
685        mut f: F,
686    ) -> Result<TreeNodeRecursion> {
687        match self.as_any().downcast_ref::<ScalarFunctionExpr>() {
688            Some(scalar_function)
689                if scalar_function
690                    .args()
691                    .iter()
692                    .any(|arg| arg.as_any().is::<LambdaExpr>()) =>
693            {
694                let mut lambdas_schemas = lambdas_schemas_from_args(
695                    scalar_function.fun(),
696                    scalar_function.args(),
697                    schema,
698                )?
699                .into_iter();
700
701                self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap()))
702            }
703            _ => self.apply_children(|e| f(e, schema)),
704        }
705    }
706
707    fn transform_down_with_schema<
708        F: FnMut(Self, &Schema) -> Result<Transformed<Self>>,
709    >(
710        self,
711        schema: &Schema,
712        mut f: F,
713    ) -> Result<Transformed<Self>> {
714        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
715        fn transform_down_with_schema_impl<
716            F: FnMut(
717                Arc<dyn PhysicalExpr>,
718                &Schema,
719            ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
720        >(
721            node: Arc<dyn PhysicalExpr>,
722            schema: &Schema,
723            f: &mut F,
724        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
725            f(node, schema)?.transform_children(|node| {
726                map_children_with_schema(node, schema, |n, schema| {
727                    transform_down_with_schema_impl(n, schema, f)
728                })
729            })
730        }
731
732        transform_down_with_schema_impl(self, schema, &mut f)
733    }
734
735    fn transform_up_with_schema<F: FnMut(Self, &Schema) -> Result<Transformed<Self>>>(
736        self,
737        schema: &Schema,
738        mut f: F,
739    ) -> Result<Transformed<Self>> {
740        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
741        fn transform_up_with_schema_impl<
742            F: FnMut(
743                Arc<dyn PhysicalExpr>,
744                &Schema,
745            ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
746        >(
747            node: Arc<dyn PhysicalExpr>,
748            schema: &Schema,
749            f: &mut F,
750        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
751            map_children_with_schema(node, schema, |n, schema| {
752                transform_up_with_schema_impl(n, schema, f)
753            })?
754            .transform_parent(|n| f(n, schema))
755        }
756
757        transform_up_with_schema_impl(self, schema, &mut f)
758    }
759
760    fn transform_up_with_lambdas_params(
761        self,
762        mut f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
763    ) -> Result<Transformed<Self>> {
764        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
765        fn transform_up_with_lambdas_params_impl<
766            F: FnMut(
767                Arc<dyn PhysicalExpr>,
768                &HashSet<String>,
769            ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
770        >(
771            node: Arc<dyn PhysicalExpr>,
772            params: &HashSet<String>,
773            f: &mut F,
774        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
775            map_children_with_lambdas_params(node, params, |n, params| {
776                transform_up_with_lambdas_params_impl(n, params, f)
777            })?
778            .transform_parent(|n| f(n, params))
779        }
780
781        transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
782    }
783
784    fn transform_down_with_lambdas_params(
785        self,
786        mut f: impl FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>,
787    ) -> Result<Transformed<Self>> {
788        #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
789        fn transform_down_with_lambdas_params_impl<
790            F: FnMut(
791                Arc<dyn PhysicalExpr>,
792                &HashSet<String>,
793            ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
794        >(
795            node: Arc<dyn PhysicalExpr>,
796            params: &HashSet<String>,
797            f: &mut F,
798        ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
799            f(node, params)?.transform_children(|node| {
800                map_children_with_lambdas_params(node, params, |node, args| {
801                    transform_down_with_lambdas_params_impl(node, args, f)
802                })
803            })
804        }
805
806        transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f)
807    }
808}
809
810fn map_children_with_schema(
811    node: Arc<dyn PhysicalExpr>,
812    schema: &Schema,
813    mut f: impl FnMut(
814        Arc<dyn PhysicalExpr>,
815        &Schema,
816    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
817) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
818    match node.as_any().downcast_ref::<ScalarFunctionExpr>() {
819        Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::<LambdaExpr>()) => {
820            let mut args_schemas =
821                lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter();
822
823            node.map_children(|node| f(node, &args_schemas.next().unwrap()))
824        }
825        _ => node.map_children(|node| f(node, schema)),
826    }
827}
828
829fn map_children_with_lambdas_params(
830    node: Arc<dyn PhysicalExpr>,
831    params: &HashSet<String>,
832    mut f: impl FnMut(
833        Arc<dyn PhysicalExpr>,
834        &HashSet<String>,
835    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>,
836) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
837    match node.as_any().downcast_ref::<LambdaExpr>() {
838        Some(lambda) => {
839            let mut params = params.clone();
840
841            params.extend(lambda.params().iter().cloned());
842
843            node.map_children(|node| f(node, &params))
844        }
845        None => node.map_children(|node| f(node, params)),
846    }
847}
848
#[cfg(test)]
mod tests {
    use std::{any::Any, borrow::Cow, sync::Arc};

    use super::*;
    use super::{lambdas_schemas_from_args, PhysicalExprExt};
    use crate::expressions::Column;
    use crate::{create_physical_expr, ScalarFunctionExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result};
    use datafusion_expr::{
        col, expr::Lambda, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
        ValueOrLambdaParameter, Volatility,
    };
    use datafusion_expr_common::columnar_value::ColumnarValue;
    use datafusion_physical_expr_common::physical_expr::{is_volatile, PhysicalExpr};

    /// Minimal UDF whose only configurable aspect is its [`Signature`]
    /// (and therefore its volatility).
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockScalarUDF {
        signature: Signature,
    }

    impl ScalarUDFImpl for MockScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "mock_function"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// Always reports `Int32`, regardless of the argument types.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Int32)
        }

        /// Ignores its arguments and yields the constant scalar `42`.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
        }
    }

    /// Wraps a [`MockScalarUDF`] taking a single `Float32` argument with the
    /// requested volatility.
    fn mock_udf(volatility: Volatility) -> Arc<ScalarUDF> {
        Arc::new(ScalarUDF::from(MockScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], volatility),
        }))
    }

    /// A function expression built from a volatile UDF must report itself as
    /// volatile, and one built from a stable UDF must not.
    #[test]
    fn test_scalar_function_volatile_node() {
        let input_schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
        let column_a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let options = Arc::new(ConfigOptions::new());

        // Volatile UDF.
        let volatile_fn = ScalarFunctionExpr::try_new(
            mock_udf(Volatility::Volatile),
            vec![Arc::clone(&column_a)],
            &input_schema,
            Arc::clone(&options),
        )
        .unwrap();

        assert!(volatile_fn.is_volatile_node());
        let volatile_fn: Arc<dyn PhysicalExpr> = Arc::new(volatile_fn);
        assert!(is_volatile(&volatile_fn));

        // Non-volatile UDF.
        let stable_fn = ScalarFunctionExpr::try_new(
            mock_udf(Volatility::Stable),
            vec![column_a],
            &input_schema,
            options,
        )
        .unwrap();

        assert!(!stable_fn.is_volatile_node());
        let stable_fn: Arc<dyn PhysicalExpr> = Arc::new(stable_fn);
        assert!(!is_volatile(&stable_fn));
    }

    /// Schema with a single non-nullable column `v` of the given type.
    fn schema_of_v(data_type: DataType) -> Schema {
        Schema::new(vec![Field::new("v", data_type, false)])
    }

    /// `v: List<List<Int32>>`
    fn list_list_int() -> Schema {
        let inner = DataType::new_list(DataType::Int32, false);
        schema_of_v(DataType::new_list(inner, false))
    }

    /// `v: List<Int32>`
    fn list_int() -> Schema {
        schema_of_v(DataType::new_list(DataType::Int32, false))
    }

    /// `v: Int32`
    fn int() -> Schema {
        schema_of_v(DataType::Int32)
    }

    fn array_transform_udf() -> ScalarUDF {
        ScalarUDF::new_from_impl(ArrayTransformFunc::new())
    }

    /// Logical arguments `v, |v| -> array_transform(v, |v| -> -v)`.
    fn args() -> Vec<Expr> {
        let negate = Expr::Lambda(Lambda::new(vec!["v".into()], -col("v")));
        let inner_call = array_transform_udf().call(vec![col("v"), negate]);
        let outer_lambda = Expr::Lambda(Lambda::new(vec!["v".into()], inner_call));

        vec![col("v"), outer_lambda]
    }

    /// Physical form of
    /// `array_transform(v, |v| -> array_transform(v, |v| -> -v))`,
    /// planned against [`list_list_int`].
    fn array_transform() -> Arc<dyn PhysicalExpr> {
        let logical = array_transform_udf().call(args());
        let df_schema = DFSchema::try_from(list_list_int()).unwrap();

        create_physical_expr(&logical, &df_schema, &Default::default()).unwrap()
    }

    /// Two-argument test UDF named `array_transform`: a value followed by a
    /// lambda. It only provides planning metadata and cannot be invoked.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct ArrayTransformFunc {
        signature: Signature,
    }

    impl ArrayTransformFunc {
        pub fn new() -> Self {
            Self {
                signature: Signature::any(2, Volatility::Immutable),
            }
        }
    }

    impl ScalarUDFImpl for ArrayTransformFunc {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "array_transform"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        /// The result has the same type as the first argument.
        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            Ok(arg_types[0].clone())
        }

        /// Declares no lambda parameters for the first argument and a single
        /// unnamed parameter for the second, typed after the element field of
        /// the first argument's list type.
        ///
        /// Panics (`unimplemented!`) if the first argument is not a value or
        /// is not a `List`.
        fn lambdas_parameters(
            &self,
            args: &[ValueOrLambdaParameter],
        ) -> Result<Vec<Option<Vec<Field>>>> {
            let field = match &args[0] {
                ValueOrLambdaParameter::Value(value_field) => {
                    match value_field.data_type() {
                        DataType::List(field) => field,
                        _ => unimplemented!(),
                    }
                }
                _ => unimplemented!(),
            };

            let element = Field::new("", field.data_type().clone(), field.is_nullable());

            Ok(vec![None, Some(vec![element])])
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    /// The value argument must borrow the outer schema while the lambda
    /// argument must get an owned schema equal to [`list_int`].
    #[test]
    fn test_lambdas_schemas_from_args() {
        let outer_schema = list_list_int();
        let physical = array_transform();

        let call = physical
            .as_any()
            .downcast_ref::<ScalarFunctionExpr>()
            .unwrap();

        let per_arg_schemas =
            lambdas_schemas_from_args(&array_transform_udf(), call.args(), &outer_schema)
                .unwrap();

        assert_eq!(
            per_arg_schemas,
            &[Cow::Borrowed(&outer_schema), Cow::Owned(list_int())]
        );
    }

    /// Records every `(node display, schema)` pair handed to the callback by
    /// `apply_with_schema` and compares the sequence with the expected table.
    #[test]
    fn test_apply_with_schema() {
        let mut visited = vec![];

        array_transform()
            .apply_with_schema(&list_list_int(), |node, visible_schema| {
                visited.push((node.to_string(), visible_schema.clone()));

                Ok(TreeNodeRecursion::Continue)
            })
            .unwrap();

        let expected = [
            (
                "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))",
                list_list_int(),
            ),
            ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()),
            ("array_transform(v@0, (v) -> (- v@0))", list_int()),
            ("(v) -> (- v@0)", int()),
            ("(- v@0)", int()),
            ("v@0", int()),
            ("v@0", int()),
            ("v@0", int()),
        ]
        .map(|(display, schema)| (display.to_string(), schema));

        assert_eq!(visited, expected);
    }

    /// Records every `(node display, lambda parameters)` pair handed to the
    /// callback by `apply_with_lambdas_params` and compares the sequence with
    /// the expected table.
    ///
    /// NOTE(review): the table expects `v` to be in scope already at the root
    /// node; `apply_with_lambdas_params` is defined outside this module, so
    /// confirm that this matches its intended semantics.
    #[test]
    fn test_apply_with_lambdas_params() {
        let physical = array_transform();
        let mut visited = vec![];

        physical
            .apply_with_lambdas_params(|node, params| {
                visited.push((node.to_string(), params.clone()));

                Ok(TreeNodeRecursion::Continue)
            })
            .unwrap();

        let expected = [
            (
                "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))",
                HashSet::from(["v"]),
            ),
            (
                "(v) -> array_transform(v@0, (v) -> (- v@0))",
                HashSet::from(["v"]),
            ),
            ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])),
            ("(v) -> (- v@0)", HashSet::from(["v"])),
            ("(- v@0)", HashSet::from(["v"])),
            ("v@0", HashSet::from(["v"])),
            ("v@0", HashSet::from(["v"])),
            ("v@0", HashSet::from(["v"])),
        ]
        .map(|(display, params)| (display.to_string(), params));

        assert_eq!(visited, expected);
    }
}