datafusion_expr/
udf.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda};
22use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
23use crate::sort_properties::{ExprProperties, SortProperties};
24use crate::udf_eq::UdfEq;
25use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature};
26use arrow::array::{ArrayRef, RecordBatch};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema};
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::TreeNodeRecursion;
31use datafusion_common::{
32    exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue,
33};
34use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
35use datafusion_expr_common::interval_arithmetic::Interval;
36use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
37use indexmap::IndexMap;
38use std::any::Any;
39use std::borrow::Cow;
40use std::cmp::Ordering;
41use std::collections::HashMap;
42use std::fmt::Debug;
43use std::hash::{Hash, Hasher};
44use std::sync::{Arc, LazyLock};
45
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
/// struct contains the information DataFusion needs to plan and invoke
/// functions you supply such as name, type signature, return type, and actual
/// implementation.
///
/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
///
/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
///    access (examples in [`advanced_udf.rs`]).
///
/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
///
/// # API Note
///
/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
/// compatibility with the older API.
///
/// [`create_udf`]: crate::expr_fn::create_udf
/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Debug, Clone)]
pub struct ScalarUDF {
    /// The shared implementation that the methods of this struct delegate to.
    inner: Arc<dyn ScalarUDFImpl>,
}
72
// Equality is decided by the wrapped implementation: `dyn_eq` is called on
// `self`'s implementation with `other`'s implementation passed as `&dyn Any`.
impl PartialEq for ScalarUDF {
    fn eq(&self, other: &Self) -> bool {
        self.inner.dyn_eq(other.inner.as_any())
    }
}
78
79impl PartialOrd for ScalarUDF {
80    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81        let mut cmp = self.name().cmp(other.name());
82        if cmp == Ordering::Equal {
83            cmp = self.signature().partial_cmp(other.signature())?;
84        }
85        if cmp == Ordering::Equal {
86            cmp = self.aliases().partial_cmp(other.aliases())?;
87        }
88        // Contract for PartialOrd and PartialEq consistency requires that
89        // a == b if and only if partial_cmp(a, b) == Some(Equal).
90        if cmp == Ordering::Equal && self != other {
91            // Functions may have other properties besides name and signature
92            // that differentiate two instances (e.g. type, or arbitrary parameters).
93            // We cannot return Some(Equal) in such case.
94            return None;
95        }
96        debug_assert!(
97            cmp == Ordering::Equal || self != other,
98            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
99            The functions compare as equal, but they are not equal based on general properties that \
100            the PartialOrd implementation observes,",
101            self.name(), other.name()
102        );
103        Some(cmp)
104    }
105}
106
// Marker impl: the comparison itself is provided by `PartialEq for ScalarUDF`.
impl Eq for ScalarUDF {}
108
// Hashing is delegated to the wrapped implementation's `dyn_hash`.
impl Hash for ScalarUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.inner.dyn_hash(state)
    }
}
114
115impl ScalarUDF {
    /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] value.
    ///
    /// The value is wrapped in an [`Arc`] and passed to
    /// [`Self::new_from_shared_impl`].
    ///
    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
    where
        F: ScalarUDFImpl + 'static,
    {
        Self::new_from_shared_impl(Arc::new(fun))
    }
125
    /// Create a new `ScalarUDF` from an already shared [`ScalarUDFImpl`]
    /// trait object.
    ///
    /// The `Arc` is stored as is; no new allocation is made for the
    /// implementation.
    pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
        Self { inner: fun }
    }
130
    /// Return the underlying [`ScalarUDFImpl`] trait object for this function.
    ///
    /// The returned reference borrows the `Arc`; clone it with `Arc::clone`
    /// to obtain an owned handle.
    pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
        &self.inner
    }
135
    /// Adds additional names that can be used to invoke this function, in
    /// addition to `name`
    ///
    /// Consumes `self` and returns a new `ScalarUDF` whose implementation is
    /// an `AliasedScalarUDFImpl` built from the current implementation and
    /// `aliases`.
    ///
    /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
    pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
        Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
    }
143
    /// Returns an [`Expr`] logical expression to call this UDF with specified
    /// arguments.
    ///
    /// This utility allows easily calling UDFs. The returned expression is an
    /// [`Expr::ScalarFunction`] that holds a clone of this `ScalarUDF` in a
    /// new [`Arc`].
    ///
    /// # Example
    /// ```no_run
    /// use datafusion_expr::{col, lit, ScalarUDF};
    /// # fn my_udf() -> ScalarUDF { unimplemented!() }
    /// let my_func: ScalarUDF = my_udf();
    /// // Create an expr for `my_func(a, 12.3)`
    /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
    /// ```
    pub fn call(&self, args: Vec<Expr>) -> Expr {
        Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
            Arc::new(self.clone()),
            args,
        ))
    }
163
    /// Returns this function's name, as reported by the wrapped
    /// implementation.
    ///
    /// See [`ScalarUDFImpl::name`] for more details.
    pub fn name(&self) -> &str {
        self.inner.name()
    }
170
    /// Returns this function's display_name for the given arguments.
    ///
    /// See [`ScalarUDFImpl::display_name`] for more details
    #[deprecated(
        since = "50.0.0",
        note = "This method is unused and will be removed in a future release"
    )]
    pub fn display_name(&self, args: &[Expr]) -> Result<String> {
        // The trait method is deprecated as well; calling it is intentional
        // because this wrapper only forwards to it.
        #[expect(deprecated)]
        self.inner.display_name(args)
    }
182
    /// Returns this function's schema_name for the given arguments.
    ///
    /// See [`ScalarUDFImpl::schema_name`] for more details
    pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
        self.inner.schema_name(args)
    }
189
    /// Returns the aliases for this function, as reported by the wrapped
    /// implementation.
    ///
    /// See [`ScalarUDF::with_aliases`] for more details
    pub fn aliases(&self) -> &[String] {
        self.inner.aliases()
    }
196
    /// Returns this function's [`Signature`] (what input types are accepted),
    /// as reported by the wrapped implementation.
    ///
    /// See [`ScalarUDFImpl::signature`] for more details.
    pub fn signature(&self) -> &Signature {
        self.inner.signature()
    }
203
    /// The datatype this function returns given the input argument types.
    /// This function is used when the input arguments are [`DataType`]s.
    ///
    /// # Notes
    ///
    /// If a function implements [`ScalarUDFImpl::return_field_from_args`],
    /// its [`ScalarUDFImpl::return_type`] should raise an error.
    ///
    /// See [`ScalarUDFImpl::return_type`] for more details.
    pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        self.inner.return_type(arg_types)
    }
216
    /// Return the field ([`FieldRef`]) this function returns given the
    /// information about its arguments in `args`.
    ///
    /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
    pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        self.inner.return_field_from_args(args)
    }
223
    /// Do the function rewrite
    ///
    /// Takes ownership of the argument expressions and forwards them,
    /// together with `info`, to the wrapped implementation.
    ///
    /// See [`ScalarUDFImpl::simplify`] for more details.
    pub fn simplify(
        &self,
        args: Vec<Expr>,
        info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        self.inner.simplify(args, info)
    }
234
    /// Forwards to the wrapped implementation's `is_nullable` for the given
    /// arguments and schema.
    #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
    pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
        // Silences the deprecation lint for the forwarded call; this wrapper
        // is itself deprecated.
        #[allow(deprecated)]
        self.inner.is_nullable(args, schema)
    }
240
241    /// Invoke the function on `args`, returning the appropriate result.
242    ///
243    /// See [`ScalarUDFImpl::invoke_with_args`] for details.
244    pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
245        #[cfg(debug_assertions)]
246        let return_field = Arc::clone(&args.return_field);
247        let result = self.inner.invoke_with_args(args)?;
248        // Maybe this could be enabled always?
249        // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions.
250        #[cfg(debug_assertions)]
251        {
252            if &result.data_type() != return_field.data_type() {
253                return datafusion_common::internal_err!("Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
254                        self.name(),
255                        result.data_type(),
256                        return_field.data_type()
257                    );
258            }
259            // TODO verify return data is non-null when it was promised to be?
260        }
261        Ok(result)
262    }
263
    /// Determines which of the arguments passed to this function are evaluated eagerly
    /// and which may be evaluated lazily.
    ///
    /// The returned references borrow from `args`.
    ///
    /// See [`ScalarUDFImpl::conditional_arguments`] for more information,
    /// including the meaning of the two vectors and of `None`.
    pub fn conditional_arguments<'a>(
        &self,
        args: &'a [Expr],
    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
        self.inner.conditional_arguments(args)
    }
274
    /// Returns true if some of this expression's subexpressions may not be
    /// evaluated and thus any side effects (like divide by zero) may not be
    /// encountered.
    ///
    /// See [`ScalarUDFImpl::short_circuits`] for more information.
    pub fn short_circuits(&self) -> bool {
        self.inner.short_circuits()
    }
282
    /// Computes the output interval for a [`ScalarUDF`], given the input
    /// intervals.
    ///
    /// The computation is performed by the wrapped implementation.
    ///
    /// # Parameters
    ///
    /// * `inputs` are the intervals for the inputs (children) of this function.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
    /// then the output interval would be `[0, 3]`.
    pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
        self.inner.evaluate_bounds(inputs)
    }
297
    /// Updates bounds for child expressions, given a known interval for this
    /// function. This is used to propagate constraints down through an expression
    /// tree.
    ///
    /// # Parameters
    ///
    /// * `interval` is the currently known interval for this function.
    /// * `inputs` are the current intervals for the inputs (children) of this function.
    ///
    /// # Returns
    ///
    /// A `Vec` of new intervals for the children, in order.
    ///
    /// If constraint propagation reveals an infeasibility for any child, returns
    /// [`None`]. If none of the children intervals change as a result of
    /// propagation, may return an empty vector instead of cloning `inputs`.
    /// This is the default (and conservative) return value.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
    pub fn propagate_constraints(
        &self,
        interval: &Interval,
        inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        self.inner.propagate_constraints(interval, inputs)
    }
327
    /// Calculates the [`SortProperties`] of this function based on its
    /// children's properties.
    ///
    /// The calculation is performed by the wrapped implementation.
    pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
        self.inner.output_ordering(inputs)
    }
333
    /// Forwards to the wrapped implementation's `preserves_lex_ordering`
    /// with the properties of this function's children.
    ///
    /// NOTE(review): presumably reports whether the function preserves the
    /// lexicographical ordering of its inputs; confirm against the trait
    /// method's documentation.
    pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
        self.inner.preserves_lex_ordering(inputs)
    }
337
    /// Returns the types the arguments should be coerced to, given the
    /// argument types `arg_types`, as decided by the wrapped implementation.
    ///
    /// See [`ScalarUDFImpl::coerce_types`] for more details.
    pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }
342
    /// Returns the documentation for this Scalar UDF, if the wrapped
    /// implementation provides any.
    ///
    /// Documentation can be accessed programmatically as well as
    /// generating publicly facing documentation.
    pub fn documentation(&self) -> Option<&Documentation> {
        self.inner.documentation()
    }
350
    /// Returns the wrapped implementation as an [`AsyncScalarUDF`] if that is
    /// its concrete type, and `None` otherwise.
    pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
        self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
    }
355
    /// Returns, for each argument in `args`, the [`ExprSchema`] that the
    /// argument should be resolved against.
    ///
    /// The arguments are first described with `lambda_parameters(args, schema)`;
    /// the per-argument schemas are then derived from `schema` by
    /// `arguments_scope_with`, using an `ExtendableExprSchema` wrapper so
    /// that lambda parameters can be layered on top of `schema`.
    pub(crate) fn arguments_expr_schema<'a>(
        &self,
        args: &[Expr],
        schema: &'a dyn ExprSchema,
    ) -> Result<Vec<impl ExprSchema + 'a>> {
        self.arguments_scope_with(
            &lambda_parameters(args, schema)?,
            ExtendableExprSchema::new(schema),
        )
    }
367
    /// Variation of [`Self::arguments_schema_from_logical_args`] that works
    /// with an arrow [`Schema`] and already computed
    /// [`ValueOrLambdaParameter`]s instead.
    ///
    /// Returns one schema per argument in `args`. The starting point for
    /// every entry is `schema`, passed as a borrowed [`Cow`].
    pub fn arguments_arrow_schema<'a>(
        &self,
        args: &[ValueOrLambdaParameter],
        schema: &'a Schema,
    ) -> Result<Vec<Cow<'a, Schema>>> {
        self.arguments_scope_with(args, Cow::Borrowed(schema))
    }
376
    /// Returns, for each logical argument in `args`, the [`DFSchema`] that
    /// the argument should be evaluated against.
    ///
    /// The arguments are first described with `lambda_parameters(args, schema)`.
    /// The starting point for every entry is `schema`, passed as a borrowed
    /// [`Cow`].
    pub fn arguments_schema_from_logical_args<'a>(
        &self,
        args: &[Expr],
        schema: &'a DFSchema,
    ) -> Result<Vec<Cow<'a, DFSchema>>> {
        self.arguments_scope_with(
            &lambda_parameters(args, schema)?,
            Cow::Borrowed(schema),
        )
    }
387
388    /// Scalar function supports lambdas as arguments, which will be evaluated with
389    /// a different schema that of the function itself. This functions returns a vec
390    /// with the correspoding schema that each argument will run
391    ///
392    /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope,
393    /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters
394    /// updated with names from metadata
395    fn arguments_scope_with<T: ExtendSchema + Clone>(
396        &self,
397        args: &[ValueOrLambdaParameter],
398        schema: T,
399    ) -> Result<Vec<T>> {
400        let parameters = self.inner().lambdas_parameters(args)?;
401
402        if parameters.len() != args.len() {
403            return exec_err!(
404                "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}",
405                self.name(),
406                args.len(),
407                parameters.len()
408            );
409        }
410
411        std::iter::zip(args, parameters)
412            .enumerate()
413            .map(|(i, (arg, parameters))| match (arg, parameters) {
414                (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()),
415                (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i),
416                (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i),
417                (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => {
418                    if names.len() > args.len() {
419                        return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len())
420                    }
421
422                    let fields = std::iter::zip(*names, args)
423                        .map(|(name, arg)| arg.with_name(name))
424                        .collect::<Fields>();
425
426                    if *captures {
427                        schema.extend(fields)
428                    } else {
429                        T::from_fields(fields)
430                    }
431                }
432            })
433            .collect()
434    }
435}
436
/// A schema that can be built from, or extended with, a set of fields.
///
/// `ScalarUDF::arguments_scope_with` uses this to build the schema a lambda
/// argument is evaluated against: `extend` when the lambda captures columns,
/// `from_fields` when it does not.
pub trait ExtendSchema: Sized {
    /// Builds a schema from `params` alone.
    fn from_fields(params: Fields) -> Result<Self>;
    /// Builds a new schema from the contents of `self` together with
    /// `params`; `self` is left untouched.
    fn extend(&self, params: Fields) -> Result<Self>;
}
441
442impl ExtendSchema for DFSchema {
443    fn from_fields(params: Fields) -> Result<Self> {
444        DFSchema::from_unqualified_fields(params, Default::default())
445    }
446
447    fn extend(&self, params: Fields) -> Result<Self> {
448        let qualified_fields = self
449            .iter()
450            .map(|(qualifier, field)| {
451                if params.find(field.name().as_str()).is_none() {
452                    return (qualifier.cloned(), Arc::clone(field));
453                }
454
455                let alias_gen = AliasGenerator::new();
456
457                loop {
458                    let alias = alias_gen.next(field.name().as_str());
459
460                    if params.find(&alias).is_none()
461                        && !self.has_column_with_unqualified_name(&alias)
462                    {
463                        return (
464                            qualifier.cloned(),
465                            Arc::new(Field::new(
466                                alias,
467                                field.data_type().clone(),
468                                field.is_nullable(),
469                            )),
470                        );
471                    }
472                }
473            })
474            .collect();
475
476        let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?;
477        let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?;
478
479        schema.merge(&fields_schema);
480
481        assert_eq!(
482            schema.fields().len(),
483            self.fields().len() + fields_schema.fields().len()
484        );
485
486        Ok(schema)
487    }
488}
489
490impl ExtendSchema for Schema {
491    fn from_fields(params: Fields) -> Result<Self> {
492        Ok(Schema::new(params))
493    }
494
495    fn extend(&self, params: Fields) -> Result<Self> {
496        let mut params2 = params.iter()
497            .map(|f| (f.name().as_str(), Some(Arc::clone(f))))
498            .collect::<IndexMap<_, _>>();
499
500        let mut fields = self.fields()
501            .iter()
502            .map(|field| {
503                match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) {
504                    Some(param) => param,
505                    None => Arc::clone(field),
506                }
507            })
508            .collect::<Vec<_>>();
509
510        fields.extend(params2.into_values().flatten());
511
512        let fields = self
513            .fields()
514            .iter()
515            .map(|field| {
516                if params.find(field.name().as_str()).is_none() {
517                    return Arc::clone(field);
518                }
519
520                let alias_gen = AliasGenerator::new();
521
522                loop {
523                    let alias = alias_gen.next(field.name().as_str());
524
525                    if params.find(&alias).is_none()
526                        && self.column_with_name(&alias).is_none()
527                    {
528                        return Arc::new(Field::new(
529                            alias,
530                            field.data_type().clone(),
531                            field.is_nullable(),
532                        ));
533                    }
534                }
535            })
536            .chain(params.iter().cloned())
537            .collect::<Fields>();
538
539        assert_eq!(fields.len(), self.fields().len() + params.len());
540
541        Ok(Schema::new_with_metadata(fields, self.metadata.clone()))
542    }
543}
544
545impl<T: ExtendSchema + Clone> ExtendSchema for Cow<'_, T> {
546    fn from_fields(params: Fields) -> Result<Self> {
547        Ok(Cow::Owned(T::from_fields(params)?))
548    }
549
550    fn extend(&self, params: Fields) -> Result<Self> {
551        Ok(Cow::Owned(self.as_ref().extend(params)?))
552    }
553}
554
555impl<T: ExtendSchema> ExtendSchema for Arc<T> {
556    fn from_fields(params: Fields) -> Result<Self> {
557        Ok(Arc::new(T::from_fields(params)?))
558    }
559
560    fn extend(&self, params: Fields) -> Result<Self> {
561        Ok(Arc::new(self.as_ref().extend(params)?))
562    }
563}
564
565impl ExtendSchema for ExtendableExprSchema<'_> {
566    fn from_fields(params: Fields) -> Result<Self> {
567        static EMPTY_DFSCHEMA: LazyLock<DFSchema> = LazyLock::new(DFSchema::empty);
568
569        Ok(ExtendableExprSchema {
570            fields_chain: vec![params],
571            outer_schema: &*EMPTY_DFSCHEMA,
572        })
573    }
574
575    fn extend(&self, params: Fields) -> Result<Self> {
576        Ok(ExtendableExprSchema {
577            fields_chain: std::iter::once(params)
578                .chain(self.fields_chain.iter().cloned())
579                .collect(),
580            outer_schema: self.outer_schema,
581        })
582    }
583}
584
/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda
#[derive(Clone, Debug)]
struct ExtendableExprSchema<'a> {
    /// Sets of lambda parameters, searched in order before `outer_schema`
    /// when resolving an unqualified column.
    fields_chain: Vec<Fields>,
    /// The wrapped schema, used for every column that is not found in
    /// `fields_chain`.
    outer_schema: &'a dyn ExprSchema,
}
591
592impl<'a> ExtendableExprSchema<'a> {
593    fn new(schema: &'a dyn ExprSchema) -> Self {
594        Self {
595            fields_chain: vec![],
596            outer_schema: schema,
597        }
598    }
599}
600
601impl ExprSchema for ExtendableExprSchema<'_> {
602    fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> {
603        if col.relation.is_none() {
604            for fields in &self.fields_chain {
605                if let Some((_index, lambda_param)) = fields.find(&col.name) {
606                    return Ok(lambda_param);
607                }
608            }
609        }
610
611        self.outer_schema.field_from_column(col)
612    }
613}
614
/// Describes one argument of a scalar function call as either a value or a
/// lambda.
#[derive(Clone, Debug)]
pub enum ValueOrLambdaParameter<'a> {
    /// A columnar value with the given field
    Value(FieldRef),
    /// A lambda with the given parameter names and a flag indicating whether it captures any columns
    Lambda(&'a [String], bool),
}
622
// Any `ScalarUDFImpl` value converts into a `ScalarUDF`; this is the same as
// calling `ScalarUDF::new_from_impl`.
impl<F> From<F> for ScalarUDF
where
    F: ScalarUDFImpl + 'static,
{
    fn from(fun: F) -> Self {
        Self::new_from_impl(fun)
    }
}
631
/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
/// scalar function.
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs {
    /// The evaluated arguments to the function
    /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)`
    pub args: Vec<ColumnarValue>,
    /// Field associated with each arg, if it exists
    pub arg_fields: Vec<FieldRef>,
    /// The number of rows in record batch being evaluated
    pub number_rows: usize,
    /// The return field of the scalar function returned (from `return_type`
    /// or `return_field_from_args`) when creating the physical expression
    /// from the logical expression
    pub return_field: FieldRef,
    /// The config options at execution time
    pub config_options: Arc<ConfigOptions>,
    /// The lambdas passed to the function, paired by position with `args`:
    /// an entry is `Some` if the argument is a lambda and `None` if it is
    /// not.
    ///
    /// When the whole field is `None`, [`Self::to_lambda_args`] treats every
    /// argument as a value.
    pub lambdas: Option<Vec<Option<ScalarFunctionLambdaArg>>>,
}
653
/// A lambda argument to a ScalarFunction
#[derive(Clone, Debug)]
pub struct ScalarFunctionLambdaArg {
    /// The parameters defined in this lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be `vec![Field::new("v", DataType::Int32, true)]`
    pub params: Vec<FieldRef>,
    /// The body of the lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be the physical expression of `-v`
    pub body: Arc<dyn PhysicalExpr>,
    /// A RecordBatch containing at least the columns captured inside this lambda body, if any.
    /// Note that it may contain additional, non-specified columns, but that's an implementation detail
    ///
    /// For example, for `array_transform([2], v -> v + a + b)`,
    /// this will be a `RecordBatch` with two columns, `a` and `b`
    pub captures: Option<RecordBatch>,
}
674
675impl ScalarFunctionArgs {
676    /// The return type of the function. See [`Self::return_field`] for more
677    /// details.
678    pub fn return_type(&self) -> &DataType {
679        self.return_field.data_type()
680    }
681
682    pub fn to_lambda_args(&self) -> Vec<ValueOrLambda<'_>> {
683        match &self.lambdas {
684            Some(lambdas) => std::iter::zip(&self.args, lambdas)
685                .map(|(arg, lambda)| match lambda {
686                    Some(lambda) => ValueOrLambda::Lambda(lambda),
687                    None => ValueOrLambda::Value(arg),
688                })
689                .collect(),
690            None => self.args.iter().map(ValueOrLambda::Value).collect(),
691        }
692    }
693}
694
/// An argument to a ScalarUDF that supports lambdas
#[derive(Debug)]
pub enum ValueOrLambda<'a> {
    /// An evaluated, non-lambda argument
    Value(&'a ColumnarValue),
    /// A lambda argument
    Lambda(&'a ScalarFunctionLambdaArg),
}
701
/// Information about arguments passed to the function
///
/// This structure contains metadata about how the function was called
/// such as the type of the arguments, any scalar arguments and if the
/// arguments can (ever) be null
///
/// See [`ScalarUDFImpl::return_field_from_args`] for more information
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
    /// The fields (data types) of the arguments to the function
    ///
    /// If argument `i` to the function is a lambda, it will be the field returned by the
    /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters`
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]`
    pub arg_fields: &'a [FieldRef],
    /// Is argument `i` to the function a scalar (constant)?
    ///
    /// If the argument `i` is not a scalar, it will be None
    ///
    /// For example, if a function is called like `my_function(column_a, 5)`
    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
    /// Is argument `i` to the function a lambda?
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[false, true]`
    pub lambdas: &'a [bool],
}
732
/// A tagged Field indicating whether it corresponds to a value or a lambda argument
#[derive(Debug)]
pub enum ValueOrLambdaField<'a> {
    /// The Field of a ColumnarValue argument
    Value(&'a FieldRef),
    /// The Field of the return of the lambda body when evaluated with the parameters from `ScalarUDFImpl::lambdas_parameters`
    Lambda(&'a FieldRef),
}
741
742impl<'a> ReturnFieldArgs<'a> {
743    /// Based on self.lambdas, encodes self.arg_fields to tagged enums
744    /// indicating whether it correspond to a value or a lambda argument
745    pub fn to_lambda_args(&self) -> Vec<ValueOrLambdaField<'a>> {
746        std::iter::zip(self.arg_fields, self.lambdas)
747            .map(|(field, is_lambda)| {
748                if *is_lambda {
749                    ValueOrLambdaField::Lambda(field)
750                } else {
751                    ValueOrLambdaField::Value(field)
752                }
753            })
754            .collect()
755    }
756}
757
758/// Trait for implementing user defined scalar functions.
759///
760/// This trait exposes the full API for implementing user defined functions and
761/// can be used to implement any function.
762///
763/// See [`advanced_udf.rs`] for a full example with complete implementation and
764/// [`ScalarUDF`] for other available options.
765///
766/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
767///
768/// # Basic Example
769/// ```
770/// # use std::any::Any;
771/// # use std::sync::LazyLock;
772/// # use arrow::datatypes::DataType;
773/// # use datafusion_common::{DataFusionError, plan_err, Result};
774/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
775/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
776/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
777/// /// This struct for a simple UDF that adds one to an int32
778/// #[derive(Debug, PartialEq, Eq, Hash)]
779/// struct AddOne {
780///   signature: Signature,
781/// }
782///
783/// impl AddOne {
784///   fn new() -> Self {
785///     Self {
786///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
787///      }
788///   }
789/// }
790///
791/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
792///         Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
793///             .with_argument("arg1", "The int32 number to add one to")
794///             .build()
795///     });
796///
797/// fn get_doc() -> &'static Documentation {
798///     &DOCUMENTATION
799/// }
800///
801/// /// Implement the ScalarUDFImpl trait for AddOne
802/// impl ScalarUDFImpl for AddOne {
803///    fn as_any(&self) -> &dyn Any { self }
804///    fn name(&self) -> &str { "add_one" }
805///    fn signature(&self) -> &Signature { &self.signature }
806///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
807///      if !matches!(args.get(0), Some(&DataType::Int32)) {
808///        return plan_err!("add_one only accepts Int32 arguments");
809///      }
810///      Ok(DataType::Int32)
811///    }
812///    // The actual implementation would add one to the argument
813///    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
814///         unimplemented!()
815///    }
816///    fn documentation(&self) -> Option<&Documentation> {
817///         Some(get_doc())
818///     }
819/// }
820///
821/// // Create a new ScalarUDF from the implementation
822/// let add_one = ScalarUDF::from(AddOne::new());
823///
824/// // Call the function `add_one(col)`
825/// let expr = add_one.call(vec![col("a")]);
826/// ```
827pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
    /// Returns this object as an [`Any`] trait object.
    ///
    /// Exposing the implementation as `&dyn Any` lets callers that know the
    /// concrete type recover it by downcasting.
    fn as_any(&self) -> &dyn Any;
830
    /// Returns this function's name.
    ///
    /// The default implementations of [`Self::schema_name`],
    /// [`Self::return_field_from_args`] and [`Self::coerce_types`] use this
    /// name when building their output or error message.
    fn name(&self) -> &str;
833
    /// Returns any aliases (alternate names) for this function.
    ///
    /// Aliases can be used to invoke the same function using different names.
    /// For example, in some databases `now()` and `current_timestamp()` are
    /// aliases for the same function. This behavior can be obtained by
    /// returning `current_timestamp` as an alias for the `now` function.
    ///
    /// Note: `aliases` should only include names other than [`Self::name`].
    ///
    /// Defaults to `[]` (no aliases).
    fn aliases(&self) -> &[String] {
        &[]
    }
846
847    /// Returns the user-defined display name of function, given the arguments
848    ///
849    /// This can be used to customize the output column name generated by this
850    /// function.
851    ///
852    /// Defaults to `name(args[0], args[1], ...)`
853    #[deprecated(
854        since = "50.0.0",
855        note = "This method is unused and will be removed in a future release"
856    )]
857    fn display_name(&self, args: &[Expr]) -> Result<String> {
858        let names: Vec<String> = args.iter().map(ToString::to_string).collect();
859        // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
860        Ok(format!("{}({})", self.name(), names.join(",")))
861    }
862
863    /// Returns the name of the column this expression would create
864    ///
865    /// See [`Expr::schema_name`] for details
866    fn schema_name(&self, args: &[Expr]) -> Result<String> {
867        Ok(format!(
868            "{}({})",
869            self.name(),
870            schema_name_from_exprs_comma_separated_without_space(args)?
871        ))
872    }
873
    /// Returns a [`Signature`] describing the argument types for which this
    /// function has an implementation, and the function's [`Volatility`].
    ///
    /// See [`Signature`] for more details on argument type handling
    /// and [`Self::return_type`] for computing the return type.
    ///
    /// The signature also decides whether [`Self::coerce_types`] is consulted:
    /// that method is only called for `TypeSignature::UserDefined` signatures.
    ///
    /// [`Volatility`]: datafusion_expr_common::signature::Volatility
    fn signature(&self) -> &Signature;
882
    /// [`DataType`] returned by this function, given the types of the
    /// arguments.
    ///
    /// # Arguments
    ///
    /// `arg_types` Data types of the arguments. The implementation of
    /// `return_type` can assume that some other part of the code has coerced
    /// the actual argument types to match [`Self::signature`].
    ///
    /// # Notes
    ///
    /// If you provide an implementation for [`Self::return_field_from_args`],
    /// DataFusion will not call `return_type` (this function). While it is
    /// valid to put [`unimplemented!()`] or [`unreachable!()`] here, it is
    /// recommended to return [`DataFusionError::Internal`] instead, which
    /// reduces the severity of symptoms if bugs occur (an error rather than a
    /// panic).
    ///
    /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
903
    /// Create a new instance of this function with updated configuration.
    ///
    /// This method is called when configuration options change at runtime
    /// (e.g., via `SET` statements) to allow functions that depend on
    /// configuration to update themselves accordingly.
    ///
    /// Note the current [`ConfigOptions`] are also passed to
    /// [`Self::invoke_with_args`], so this API is not needed for functions
    /// where only the produced **values** depend on the current options.
    ///
    /// This API is useful for functions where the return
    /// **type** depends on the configuration options, such as the `now()` function
    /// which depends on the current timezone.
    ///
    /// # Arguments
    ///
    /// * `config` - The updated configuration options
    ///
    /// # Returns
    ///
    /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings
    /// * `None` - If this function does not change with new configuration settings (the default)
    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
        None
    }
929
930    /// What type will be returned by this function, given the arguments?
931    ///
932    /// By default, this function calls [`Self::return_type`] with the
933    /// types of each argument.
934    ///
935    /// # Notes
936    ///
937    /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
938    /// as the result type is typically a deterministic function of the input types
939    /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
940    /// is generally unnecessary unless the return type depends on runtime values.
941    ///
942    /// This function can be used for more advanced cases such as:
943    ///
944    /// 1. specifying nullability
945    /// 2. return types based on the **values** of the arguments (rather than
946    ///    their **types**.
947    ///
948    /// # Example creating `Field`
949    ///
950    /// Note the name of the [`Field`] is ignored, except for structured types such as
951    /// `DataType::Struct`.
952    ///
953    /// ```rust
954    /// # use std::sync::Arc;
955    /// # use arrow::datatypes::{DataType, Field, FieldRef};
956    /// # use datafusion_common::Result;
957    /// # use datafusion_expr::ReturnFieldArgs;
958    /// # struct Example{}
959    /// # impl Example {
960    /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
961    ///     // report output is only nullable if any one of the arguments are nullable
962    ///     let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
963    ///     let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
964    ///     Ok(field)
965    /// }
966    /// # }
967    /// ```
968    ///
969    /// # Output Type based on Values
970    ///
971    /// For example, the following two function calls get the same argument
972    /// types (something and a `Utf8` string) but return different types based
973    /// on the value of the second argument:
974    ///
975    /// * `arrow_cast(x, 'Int16')` --> `Int16`
976    /// * `arrow_cast(x, 'Float32')` --> `Float32`
977    ///
978    /// # Requirements
979    ///
980    /// This function **must** consistently return the same type for the same
981    /// logical input even if the input is simplified (e.g. it must return the same
982    /// value for `('foo' | 'bar')` as it does for ('foobar').
983    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
984        let data_types = args
985            .arg_fields
986            .iter()
987            .map(|f| f.data_type())
988            .cloned()
989            .collect::<Vec<_>>();
990        let return_type = self.return_type(&data_types)?;
991        Ok(Arc::new(Field::new(self.name(), return_type, true)))
992    }
993
    /// Returns whether the result of this function may be null.
    ///
    /// Deprecated in favor of [`Self::return_field_from_args`], which reports
    /// nullability as part of the returned field. The default implementation
    /// ignores its arguments and always returns `true`.
    #[deprecated(
        since = "45.0.0",
        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
    )]
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
        true
    }
1001
    /// Invoke the function on `args`, returning the appropriate result.
    ///
    /// # Performance
    ///
    /// For the best performance, the implementations should handle the common case
    /// when one or more of their arguments are constant values (aka
    /// [`ColumnarValue::Scalar`]).
    ///
    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
    /// to arrays, which will likely result in simpler code, but be slower.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
1013
    /// Optionally apply per-UDF simplification / rewrite rules.
    ///
    /// This can be used to apply function specific simplification rules during
    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
    /// implementation does nothing and returns the arguments unchanged as
    /// [`ExprSimplifyResult::Original`].
    ///
    /// Note that DataFusion handles simplifying arguments and "constant
    /// folding" (replacing a function call with constant arguments such as
    /// `my_add(1,2) --> 3`). Thus, there is no need to implement such
    /// optimizations manually for specific UDFs.
    ///
    /// # Arguments
    /// * `args`: The arguments of the function
    /// * `info`: The necessary information for simplification
    ///
    /// # Returns
    /// [`ExprSimplifyResult`] indicating the result of the simplification.
    ///
    /// NOTE: if the function cannot be simplified, the arguments *MUST* be
    /// returned unmodified.
    ///
    /// # Notes
    ///
    /// The returned expression must have the same schema as the original
    /// expression, including both the data type and nullability. For example,
    /// if the original expression is nullable, the returned expression must
    /// also be nullable, otherwise it may lead to schema verification errors
    /// later in query planning.
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        Ok(ExprSimplifyResult::Original(args))
    }
1048
    /// Returns true if some of this function's argument expressions may not
    /// be evaluated, and thus any side effects (like divide by zero) may not
    /// be encountered.
    ///
    /// Setting this to true prevents certain optimizations such as common
    /// subexpression elimination.
    ///
    /// When overriding this function to return `true`,
    /// [`Self::conditional_arguments`] can also be overridden to report more
    /// accurately which arguments are eagerly evaluated and which ones lazily.
    ///
    /// Defaults to `false`.
    fn short_circuits(&self) -> bool {
        false
    }
1061
1062    /// Determines which of the arguments passed to this function are evaluated eagerly
1063    /// and which may be evaluated lazily.
1064    ///
1065    /// If this function returns `None`, all arguments are eagerly evaluated.
1066    /// Returning `None` is a micro optimization that saves a needless `Vec`
1067    /// allocation.
1068    ///
1069    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
1070    /// are the arguments that are always evaluated, and `lazy` are the
1071    /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
1072    /// in some cases).
1073    ///
1074    /// Implementations must ensure that the two returned `Vec`s are disjunct,
1075    /// and that each argument from `args` is present in one the two `Vec`s.
1076    ///
1077    /// When overriding this function, [ScalarUDFImpl::short_circuits] must
1078    /// be overridden to return `true`.
1079    fn conditional_arguments<'a>(
1080        &self,
1081        args: &'a [Expr],
1082    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1083        if self.short_circuits() {
1084            Some((vec![], args.iter().collect()))
1085        } else {
1086            None
1087        }
1088    }
1089
    /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
    /// intervals.
    ///
    /// # Parameters
    ///
    /// * `input` are the intervals for the children (inputs) of this function.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
    /// then the output interval would be `[0, 3]`.
    ///
    /// # Default
    ///
    /// The default implementation ignores `input` and returns an unbounded
    /// interval created for [`DataType::Null`].
    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
        // We cannot assume the input datatype is the same as the output type.
        Interval::make_unbounded(&DataType::Null)
    }
1105
    /// Updates bounds for child expressions, given a known [`Interval`] for this
    /// function.
    ///
    /// This function is used to propagate constraints down through an
    /// expression tree.
    ///
    /// # Parameters
    ///
    /// * `interval` is the currently known interval for this function.
    /// * `inputs` are the current intervals for the inputs (children) of this function.
    ///
    /// # Returns
    ///
    /// A `Vec` of new intervals for the children, in order.
    ///
    /// If constraint propagation reveals an infeasibility for any child, returns
    /// [`None`]. If none of the children intervals change as a result of
    /// propagation, may return an empty vector instead of cloning `inputs`.
    /// This is the default (and conservative) return value.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(vec![]))
    }
1137
1138    /// Calculates the [`SortProperties`] of this function based on its children's properties.
1139    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1140        if !self.preserves_lex_ordering(inputs)? {
1141            return Ok(SortProperties::Unordered);
1142        }
1143
1144        let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
1145            return Ok(SortProperties::Singleton);
1146        };
1147
1148        if inputs
1149            .iter()
1150            .skip(1)
1151            .all(|input| &input.sort_properties == first_order)
1152        {
1153            Ok(*first_order)
1154        } else {
1155            Ok(SortProperties::Unordered)
1156        }
1157    }
1158
    /// Returns true if the function preserves lexicographical ordering based on
    /// the input ordering.
    ///
    /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
    ///
    /// Defaults to `false`, in which case the default
    /// [`Self::output_ordering`] reports [`SortProperties::Unordered`].
    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
        Ok(false)
    }
1166
    /// Coerce arguments of a function call to types that the function can evaluate.
    ///
    /// This function is only called if [`ScalarUDFImpl::signature`] returns
    /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
    /// the other variants of [`TypeSignature`] which handle common cases.
    ///
    /// See the [type coercion module](crate::type_coercion)
    /// documentation for more details on type coercion
    ///
    /// [`TypeSignature`]: crate::TypeSignature
    ///
    /// For example, if your function requires a floating point argument, but the user calls
    /// it like `my_func(1::int)` (i.e. with `1` as an integer), `coerce_types` can return
    /// `[DataType::Float64]` to ensure the argument is converted to `1::double`
    ///
    /// # Parameters
    /// * `arg_types`: The types of the arguments this function was called with
    ///
    /// # Return value
    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
    /// arguments to these specific types.
    ///
    /// # Errors
    /// The default implementation always returns a "not implemented" error
    /// that mentions [`Self::name`].
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }
1191
    /// Returns the documentation for this Scalar UDF.
    ///
    /// Documentation can be accessed programmatically as well as generating
    /// publicly facing documentation.
    ///
    /// Defaults to `None` (no documentation).
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
1199
1200    /// Returns the parameters that any lambda supports
1201    fn lambdas_parameters(
1202        &self,
1203        args: &[ValueOrLambdaParameter],
1204    ) -> Result<Vec<Option<Vec<Field>>>> {
1205        Ok(vec![None; args.len()])
1206    }
1207}
1208
/// ScalarUDF that adds an alias to the underlying function. It is better to
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedScalarUDFImpl {
    /// The wrapped function that almost every trait method is forwarded to.
    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
    /// The wrapped function's own aliases followed by the extra ones supplied
    /// at construction time.
    aliases: Vec<String>,
}
1216
1217impl AliasedScalarUDFImpl {
1218    pub fn new(
1219        inner: Arc<dyn ScalarUDFImpl>,
1220        new_aliases: impl IntoIterator<Item = &'static str>,
1221    ) -> Self {
1222        let mut aliases = inner.aliases().to_vec();
1223        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1224        Self {
1225            inner: inner.into(),
1226            aliases,
1227        }
1228    }
1229}
1230
1231#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1232impl ScalarUDFImpl for AliasedScalarUDFImpl {
1233    fn as_any(&self) -> &dyn Any {
1234        self
1235    }
1236
1237    fn name(&self) -> &str {
1238        self.inner.name()
1239    }
1240
1241    fn display_name(&self, args: &[Expr]) -> Result<String> {
1242        #[expect(deprecated)]
1243        self.inner.display_name(args)
1244    }
1245
1246    fn schema_name(&self, args: &[Expr]) -> Result<String> {
1247        self.inner.schema_name(args)
1248    }
1249
1250    fn signature(&self) -> &Signature {
1251        self.inner.signature()
1252    }
1253
1254    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1255        self.inner.return_type(arg_types)
1256    }
1257
1258    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
1259        self.inner.return_field_from_args(args)
1260    }
1261
1262    fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
1263        #[allow(deprecated)]
1264        self.inner.is_nullable(args, schema)
1265    }
1266
1267    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1268        self.inner.invoke_with_args(args)
1269    }
1270
1271    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
1272        None
1273    }
1274
1275    fn aliases(&self) -> &[String] {
1276        &self.aliases
1277    }
1278
1279    fn simplify(
1280        &self,
1281        args: Vec<Expr>,
1282        info: &dyn SimplifyInfo,
1283    ) -> Result<ExprSimplifyResult> {
1284        self.inner.simplify(args, info)
1285    }
1286
1287    fn conditional_arguments<'a>(
1288        &self,
1289        args: &'a [Expr],
1290    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1291        self.inner.conditional_arguments(args)
1292    }
1293
1294    fn short_circuits(&self) -> bool {
1295        self.inner.short_circuits()
1296    }
1297
1298    fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
1299        self.inner.evaluate_bounds(input)
1300    }
1301
1302    fn propagate_constraints(
1303        &self,
1304        interval: &Interval,
1305        inputs: &[&Interval],
1306    ) -> Result<Option<Vec<Interval>>> {
1307        self.inner.propagate_constraints(interval, inputs)
1308    }
1309
1310    fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1311        self.inner.output_ordering(inputs)
1312    }
1313
1314    fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
1315        self.inner.preserves_lex_ordering(inputs)
1316    }
1317
1318    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1319        self.inner.coerce_types(arg_types)
1320    }
1321
1322    fn documentation(&self) -> Option<&Documentation> {
1323        self.inner.documentation()
1324    }
1325
1326    fn lambdas_parameters(
1327        &self,
1328        args: &[ValueOrLambdaParameter],
1329    ) -> Result<Vec<Option<Vec<Field>>>> {
1330        self.inner.lambdas_parameters(args)
1331    }
1332}
1333
1334fn lambda_parameters<'a>(
1335    args: &'a [Expr],
1336    schema: &dyn ExprSchema,
1337) -> Result<Vec<ValueOrLambdaParameter<'a>>> {
1338    args.iter()
1339        .map(|e| match e {
1340            Expr::Lambda(Lambda { params, body: _ }) => {
1341                let mut captures = false;
1342
1343                e.apply_with_lambdas_params(|expr, lambdas_params| match expr {
1344                    Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1345                        captures = true;
1346
1347                        Ok(TreeNodeRecursion::Stop)
1348                    }
1349                    _ => Ok(TreeNodeRecursion::Continue),
1350                })
1351                .unwrap();
1352
1353                Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures))
1354            }
1355            _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)),
1356        })
1357        .collect()
1358}
1359
1360/// Merge the lambda body captured columns with it's arguments
1361/// Datafusion relies on an unspecified field ordering implemented in this function
1362/// As such, this is the only correct way to merge the captured values with the arguments
1363/// The number of args should not be lower than the number of params
1364///
1365/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily
1366/// computes only the necessary arguments to match the number of params
1367pub fn merge_captures_with_args(
1368    captures: Option<&RecordBatch>,
1369    params: &[FieldRef],
1370    args: &[ArrayRef],
1371) -> Result<RecordBatch> {
1372    if args.len() < params.len() {
1373        return exec_err!(
1374            "merge_captures_with_args called with {} params but with {} args",
1375            params.len(),
1376            args.len()
1377        );
1378    }
1379
1380    // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants
1381    let (fields, columns) = match captures {
1382        Some(captures) => {
1383            let fields = captures
1384                .schema()
1385                .fields()
1386                .iter()
1387                .chain(params)
1388                .cloned()
1389                .collect::<Vec<_>>();
1390
1391            let columns = [captures.columns(), args].concat();
1392
1393            (fields, columns)
1394        }
1395        None => (params.to_vec(), args.to_vec()),
1396    };
1397
1398    Ok(RecordBatch::try_new(
1399        Arc::new(Schema::new(fields)),
1400        columns,
1401    )?)
1402}
1403
1404/// Lazy version of merge_captures_with_args that receives closures to compute the arguments,
1405/// and calls only the necessary to match the number of params
1406pub fn merge_captures_with_lazy_args(
1407    captures: Option<&RecordBatch>,
1408    params: &[FieldRef],
1409    args: &[&dyn Fn() -> Result<ArrayRef>],
1410) -> Result<RecordBatch> {
1411    merge_captures_with_args(
1412        captures,
1413        params,
1414        &args
1415            .iter()
1416            .take(params.len())
1417            .map(|arg| arg())
1418            .collect::<Result<Vec<_>>>()?,
1419    )
1420}
1421
1422/// Variation of merge_captures_with_lazy_args that take boxed closures
1423pub fn merge_captures_with_boxed_lazy_args(
1424    captures: Option<&RecordBatch>,
1425    params: &[FieldRef],
1426    args: &[Box<dyn Fn() -> Result<ArrayRef>>],
1427) -> Result<RecordBatch> {
1428    merge_captures_with_args(
1429        captures,
1430        params,
1431        &args
1432            .iter()
1433            .take(params.len())
1434            .map(|arg| arg())
1435            .collect::<Result<Vec<_>>>()?,
1436    )
1437}
1438
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr_common::signature::Volatility;
    use std::hash::DefaultHasher;

    /// Minimal UDF used to exercise equality, hashing and ordering of
    /// [`ScalarUDF`]. `field` only exists to make two instances with the same
    /// name compare as different.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }
    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        // Never called by the tests in this module.
        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        // Never called by the tests in this module.
        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
    // must be consistent, so they are tested together.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // A parameterized function
        let f = test_func("foo", "a");

        // Same as `f`, different instance
        let f2 = test_func("foo", "a");
        assert_eq!(f, f2);
        assert_eq!(hash(&f), hash(&f2));
        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));

        // Different parameter
        let b = test_func("foo", "b");
        assert_ne!(f, b);
        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&b), None);

        // Different name
        let o = test_func("other", "a");
        assert_ne!(f, o);
        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));

        // Different name and parameter
        assert_ne!(b, o);
        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
    }

    /// Builds a [`ScalarUDF`] named `name` whose implementation additionally
    /// carries `parameter`, so two functions can share a name yet differ.
    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
        ScalarUDF::from(TestScalarUDFImpl {
            name,
            field: parameter,
            signature: Signature::any(1, Volatility::Immutable),
        })
    }

    /// Hashes `value` with a fresh [`DefaultHasher`].
    fn hash<T: Hash>(value: &T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }

    use std::borrow::Cow;

    use arrow::datatypes::Fields;

    use crate::{
        tree_node::tests::{args, list_int, list_list_int, array_transform_udf},
        udf::{lambda_parameters, ExtendableExprSchema},
    };

    // Compares the debug rendering of the per-argument expression schemas with
    // the expected rendering built from the test fixtures.
    #[test]
    fn test_arguments_expr_schema() {
        let args = args();
        let schema = list_list_int();

        let schemas = array_transform_udf()
            .arguments_expr_schema(&args, &schema)
            .unwrap()
            .into_iter()
            .map(|s| format!("{s:?}"))
            .collect::<Vec<_>>();

        let mut lambdas_parameters = array_transform_udf()
            .inner()
            .lambdas_parameters(&lambda_parameters(&args, &schema).unwrap())
            .unwrap();

        assert_eq!(
            schemas,
            &[
                format!("{}", &list_list_int()),
                format!(
                    "{:?}",
                    ExtendableExprSchema {
                        fields_chain: vec![Fields::from(
                            lambdas_parameters[0].take().unwrap()
                        )],
                        outer_schema: &list_list_int()
                    }
                ),
            ]
        )
    }

    // Checks the per-argument Arrow schemas: the first is expected to borrow
    // the outer schema and the second to be an owned schema.
    #[test]
    fn test_arguments_arrow_schema() {
        let list_int = list_int();
        let list_list_int = list_list_int();

        let schemas = array_transform_udf()
            .arguments_arrow_schema(
                &lambda_parameters(&args(), &list_list_int).unwrap(),
                list_list_int.as_arrow(),
            )
            .unwrap();

        assert_eq!(
            schemas,
            &[
                Cow::Borrowed(list_list_int.as_arrow()),
                Cow::Owned(list_int.as_arrow().clone())
            ]
        )
    }

    // Same expectation as `test_arguments_arrow_schema`, but starting from the
    // logical arguments and comparing the schemas returned for them.
    #[test]
    fn test_arguments_schema_from_logical_args() {
        let list_list_int = list_list_int();

        let schemas = array_transform_udf()
            .arguments_schema_from_logical_args(&args(), &list_list_int)
            .unwrap();

        assert_eq!(
            schemas,
            &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())]
        )
    }
}