datafusion_expr/udf.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDF`]: Scalar User Defined Functions
19
20use crate::async_udf::AsyncScalarUDF;
21use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda};
22use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
23use crate::sort_properties::{ExprProperties, SortProperties};
24use crate::udf_eq::UdfEq;
25use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature};
26use arrow::array::{ArrayRef, RecordBatch};
27use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema};
28use datafusion_common::alias::AliasGenerator;
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::TreeNodeRecursion;
31use datafusion_common::{
32 exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue,
33};
34use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
35use datafusion_expr_common::interval_arithmetic::Interval;
36use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
37use indexmap::IndexMap;
38use std::any::Any;
39use std::borrow::Cow;
40use std::cmp::Ordering;
41use std::collections::HashMap;
42use std::fmt::Debug;
43use std::hash::{Hash, Hasher};
44use std::sync::{Arc, LazyLock};
45
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
/// struct contains the information DataFusion needs to plan and invoke
/// functions you supply such as name, type signature, return type, and actual
/// implementation.
///
/// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]).
///
/// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API
/// access (examples in [`advanced_udf.rs`]).
///
/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments.
///
/// # API Note
///
/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards
/// compatibility with the older API.
///
/// [`create_udf`]: crate::expr_fn::create_udf
/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Debug, Clone)]
pub struct ScalarUDF {
    /// The wrapped implementation. It is held behind an `Arc`, so the derived
    /// `Clone` only bumps a reference count. Equality and hashing of
    /// `ScalarUDF` delegate to it.
    inner: Arc<dyn ScalarUDFImpl>,
}
71}
72
73impl PartialEq for ScalarUDF {
74 fn eq(&self, other: &Self) -> bool {
75 self.inner.dyn_eq(other.inner.as_any())
76 }
77}
78
79impl PartialOrd for ScalarUDF {
80 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81 let mut cmp = self.name().cmp(other.name());
82 if cmp == Ordering::Equal {
83 cmp = self.signature().partial_cmp(other.signature())?;
84 }
85 if cmp == Ordering::Equal {
86 cmp = self.aliases().partial_cmp(other.aliases())?;
87 }
88 // Contract for PartialOrd and PartialEq consistency requires that
89 // a == b if and only if partial_cmp(a, b) == Some(Equal).
90 if cmp == Ordering::Equal && self != other {
91 // Functions may have other properties besides name and signature
92 // that differentiate two instances (e.g. type, or arbitrary parameters).
93 // We cannot return Some(Equal) in such case.
94 return None;
95 }
96 debug_assert!(
97 cmp == Ordering::Equal || self != other,
98 "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
99 The functions compare as equal, but they are not equal based on general properties that \
100 the PartialOrd implementation observes,",
101 self.name(), other.name()
102 );
103 Some(cmp)
104 }
105}
106
// `PartialEq` delegates to the wrapped implementation's `dyn_eq`; marking the
// type `Eq` assumes that comparison is reflexive for every implementation.
impl Eq for ScalarUDF {}
108
109impl Hash for ScalarUDF {
110 fn hash<H: Hasher>(&self, state: &mut H) {
111 self.inner.dyn_hash(state)
112 }
113}
114
115impl ScalarUDF {
116 /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
117 ///
118 /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
119 pub fn new_from_impl<F>(fun: F) -> ScalarUDF
120 where
121 F: ScalarUDFImpl + 'static,
122 {
123 Self::new_from_shared_impl(Arc::new(fun))
124 }
125
126 /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
127 pub fn new_from_shared_impl(fun: Arc<dyn ScalarUDFImpl>) -> ScalarUDF {
128 Self { inner: fun }
129 }
130
131 /// Return the underlying [`ScalarUDFImpl`] trait object for this function
132 pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
133 &self.inner
134 }
135
136 /// Adds additional names that can be used to invoke this function, in
137 /// addition to `name`
138 ///
139 /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
140 pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
141 Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases))
142 }
143
144 /// Returns a [`Expr`] logical expression to call this UDF with specified
145 /// arguments.
146 ///
147 /// This utility allows easily calling UDFs
148 ///
149 /// # Example
150 /// ```no_run
151 /// use datafusion_expr::{col, lit, ScalarUDF};
152 /// # fn my_udf() -> ScalarUDF { unimplemented!() }
153 /// let my_func: ScalarUDF = my_udf();
154 /// // Create an expr for `my_func(a, 12.3)`
155 /// let expr = my_func.call(vec![col("a"), lit(12.3)]);
156 /// ```
157 pub fn call(&self, args: Vec<Expr>) -> Expr {
158 Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
159 Arc::new(self.clone()),
160 args,
161 ))
162 }
163
164 /// Returns this function's name.
165 ///
166 /// See [`ScalarUDFImpl::name`] for more details.
167 pub fn name(&self) -> &str {
168 self.inner.name()
169 }
170
171 /// Returns this function's display_name.
172 ///
173 /// See [`ScalarUDFImpl::display_name`] for more details
174 #[deprecated(
175 since = "50.0.0",
176 note = "This method is unused and will be removed in a future release"
177 )]
178 pub fn display_name(&self, args: &[Expr]) -> Result<String> {
179 #[expect(deprecated)]
180 self.inner.display_name(args)
181 }
182
183 /// Returns this function's schema_name.
184 ///
185 /// See [`ScalarUDFImpl::schema_name`] for more details
186 pub fn schema_name(&self, args: &[Expr]) -> Result<String> {
187 self.inner.schema_name(args)
188 }
189
190 /// Returns the aliases for this function.
191 ///
192 /// See [`ScalarUDF::with_aliases`] for more details
193 pub fn aliases(&self) -> &[String] {
194 self.inner.aliases()
195 }
196
197 /// Returns this function's [`Signature`] (what input types are accepted).
198 ///
199 /// See [`ScalarUDFImpl::signature`] for more details.
200 pub fn signature(&self) -> &Signature {
201 self.inner.signature()
202 }
203
204 /// The datatype this function returns given the input argument types.
205 /// This function is used when the input arguments are [`DataType`]s.
206 ///
207 /// # Notes
208 ///
209 /// If a function implement [`ScalarUDFImpl::return_field_from_args`],
210 /// its [`ScalarUDFImpl::return_type`] should raise an error.
211 ///
212 /// See [`ScalarUDFImpl::return_type`] for more details.
213 pub fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
214 self.inner.return_type(arg_types)
215 }
216
217 /// Return the datatype this function returns given the input argument types.
218 ///
219 /// See [`ScalarUDFImpl::return_field_from_args`] for more details.
220 pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
221 self.inner.return_field_from_args(args)
222 }
223
224 /// Do the function rewrite
225 ///
226 /// See [`ScalarUDFImpl::simplify`] for more details.
227 pub fn simplify(
228 &self,
229 args: Vec<Expr>,
230 info: &dyn SimplifyInfo,
231 ) -> Result<ExprSimplifyResult> {
232 self.inner.simplify(args, info)
233 }
234
235 #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")]
236 pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
237 #[allow(deprecated)]
238 self.inner.is_nullable(args, schema)
239 }
240
241 /// Invoke the function on `args`, returning the appropriate result.
242 ///
243 /// See [`ScalarUDFImpl::invoke_with_args`] for details.
244 pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
245 #[cfg(debug_assertions)]
246 let return_field = Arc::clone(&args.return_field);
247 let result = self.inner.invoke_with_args(args)?;
248 // Maybe this could be enabled always?
249 // This doesn't use debug_assert!, but it's meant to run anywhere except on production. It's same in spirit, thus conditioning on debug_assertions.
250 #[cfg(debug_assertions)]
251 {
252 if &result.data_type() != return_field.data_type() {
253 return datafusion_common::internal_err!("Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'",
254 self.name(),
255 result.data_type(),
256 return_field.data_type()
257 );
258 }
259 // TODO verify return data is non-null when it was promised to be?
260 }
261 Ok(result)
262 }
263
264 /// Determines which of the arguments passed to this function are evaluated eagerly
265 /// and which may be evaluated lazily.
266 ///
267 /// See [ScalarUDFImpl::conditional_arguments] for more information.
268 pub fn conditional_arguments<'a>(
269 &self,
270 args: &'a [Expr],
271 ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
272 self.inner.conditional_arguments(args)
273 }
274
275 /// Returns true if some of this `exprs` subexpressions may not be evaluated
276 /// and thus any side effects (like divide by zero) may not be encountered.
277 ///
278 /// See [ScalarUDFImpl::short_circuits] for more information.
279 pub fn short_circuits(&self) -> bool {
280 self.inner.short_circuits()
281 }
282
283 /// Computes the output interval for a [`ScalarUDF`], given the input
284 /// intervals.
285 ///
286 /// # Parameters
287 ///
288 /// * `inputs` are the intervals for the inputs (children) of this function.
289 ///
290 /// # Example
291 ///
292 /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
293 /// then the output interval would be `[0, 3]`.
294 pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
295 self.inner.evaluate_bounds(inputs)
296 }
297
298 /// Updates bounds for child expressions, given a known interval for this
299 /// function. This is used to propagate constraints down through an expression
300 /// tree.
301 ///
302 /// # Parameters
303 ///
304 /// * `interval` is the currently known interval for this function.
305 /// * `inputs` are the current intervals for the inputs (children) of this function.
306 ///
307 /// # Returns
308 ///
309 /// A `Vec` of new intervals for the children, in order.
310 ///
311 /// If constraint propagation reveals an infeasibility for any child, returns
312 /// [`None`]. If none of the children intervals change as a result of
313 /// propagation, may return an empty vector instead of cloning `children`.
314 /// This is the default (and conservative) return value.
315 ///
316 /// # Example
317 ///
318 /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
319 /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
320 pub fn propagate_constraints(
321 &self,
322 interval: &Interval,
323 inputs: &[&Interval],
324 ) -> Result<Option<Vec<Interval>>> {
325 self.inner.propagate_constraints(interval, inputs)
326 }
327
328 /// Calculates the [`SortProperties`] of this function based on its
329 /// children's properties.
330 pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
331 self.inner.output_ordering(inputs)
332 }
333
334 pub fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
335 self.inner.preserves_lex_ordering(inputs)
336 }
337
338 /// See [`ScalarUDFImpl::coerce_types`] for more details.
339 pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
340 self.inner.coerce_types(arg_types)
341 }
342
343 /// Returns the documentation for this Scalar UDF.
344 ///
345 /// Documentation can be accessed programmatically as well as
346 /// generating publicly facing documentation.
347 pub fn documentation(&self) -> Option<&Documentation> {
348 self.inner.documentation()
349 }
350
351 /// Return true if this function is an async function
352 pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
353 self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
354 }
355
356 /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead
357 pub(crate) fn arguments_expr_schema<'a>(
358 &self,
359 args: &[Expr],
360 schema: &'a dyn ExprSchema,
361 ) -> Result<Vec<impl ExprSchema + 'a>> {
362 self.arguments_scope_with(
363 &lambda_parameters(args, schema)?,
364 ExtendableExprSchema::new(schema),
365 )
366 }
367
368 /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead,
369 pub fn arguments_arrow_schema<'a>(
370 &self,
371 args: &[ValueOrLambdaParameter],
372 schema: &'a Schema,
373 ) -> Result<Vec<Cow<'a, Schema>>> {
374 self.arguments_scope_with(args, Cow::Borrowed(schema))
375 }
376
377 pub fn arguments_schema_from_logical_args<'a>(
378 &self,
379 args: &[Expr],
380 schema: &'a DFSchema,
381 ) -> Result<Vec<Cow<'a, DFSchema>>> {
382 self.arguments_scope_with(
383 &lambda_parameters(args, schema)?,
384 Cow::Borrowed(schema),
385 )
386 }
387
388 /// Scalar function supports lambdas as arguments, which will be evaluated with
389 /// a different schema that of the function itself. This functions returns a vec
390 /// with the correspoding schema that each argument will run
391 ///
392 /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope,
393 /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters
394 /// updated with names from metadata
395 fn arguments_scope_with<T: ExtendSchema + Clone>(
396 &self,
397 args: &[ValueOrLambdaParameter],
398 schema: T,
399 ) -> Result<Vec<T>> {
400 let parameters = self.inner().lambdas_parameters(args)?;
401
402 if parameters.len() != args.len() {
403 return exec_err!(
404 "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}",
405 self.name(),
406 args.len(),
407 parameters.len()
408 );
409 }
410
411 std::iter::zip(args, parameters)
412 .enumerate()
413 .map(|(i, (arg, parameters))| match (arg, parameters) {
414 (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()),
415 (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i),
416 (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i),
417 (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => {
418 if names.len() > args.len() {
419 return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len())
420 }
421
422 let fields = std::iter::zip(*names, args)
423 .map(|(name, arg)| arg.with_name(name))
424 .collect::<Fields>();
425
426 if *captures {
427 schema.extend(fields)
428 } else {
429 T::from_fields(fields)
430 }
431 }
432 })
433 .collect()
434 }
435}
436
437pub trait ExtendSchema: Sized {
438 fn from_fields(params: Fields) -> Result<Self>;
439 fn extend(&self, params: Fields) -> Result<Self>;
440}
441
442impl ExtendSchema for DFSchema {
443 fn from_fields(params: Fields) -> Result<Self> {
444 DFSchema::from_unqualified_fields(params, Default::default())
445 }
446
447 fn extend(&self, params: Fields) -> Result<Self> {
448 let qualified_fields = self
449 .iter()
450 .map(|(qualifier, field)| {
451 if params.find(field.name().as_str()).is_none() {
452 return (qualifier.cloned(), Arc::clone(field));
453 }
454
455 let alias_gen = AliasGenerator::new();
456
457 loop {
458 let alias = alias_gen.next(field.name().as_str());
459
460 if params.find(&alias).is_none()
461 && !self.has_column_with_unqualified_name(&alias)
462 {
463 return (
464 qualifier.cloned(),
465 Arc::new(Field::new(
466 alias,
467 field.data_type().clone(),
468 field.is_nullable(),
469 )),
470 );
471 }
472 }
473 })
474 .collect();
475
476 let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?;
477 let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?;
478
479 schema.merge(&fields_schema);
480
481 assert_eq!(
482 schema.fields().len(),
483 self.fields().len() + fields_schema.fields().len()
484 );
485
486 Ok(schema)
487 }
488}
489
490impl ExtendSchema for Schema {
491 fn from_fields(params: Fields) -> Result<Self> {
492 Ok(Schema::new(params))
493 }
494
495 fn extend(&self, params: Fields) -> Result<Self> {
496 let mut params2 = params.iter()
497 .map(|f| (f.name().as_str(), Some(Arc::clone(f))))
498 .collect::<IndexMap<_, _>>();
499
500 let mut fields = self.fields()
501 .iter()
502 .map(|field| {
503 match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) {
504 Some(param) => param,
505 None => Arc::clone(field),
506 }
507 })
508 .collect::<Vec<_>>();
509
510 fields.extend(params2.into_values().flatten());
511
512 let fields = self
513 .fields()
514 .iter()
515 .map(|field| {
516 if params.find(field.name().as_str()).is_none() {
517 return Arc::clone(field);
518 }
519
520 let alias_gen = AliasGenerator::new();
521
522 loop {
523 let alias = alias_gen.next(field.name().as_str());
524
525 if params.find(&alias).is_none()
526 && self.column_with_name(&alias).is_none()
527 {
528 return Arc::new(Field::new(
529 alias,
530 field.data_type().clone(),
531 field.is_nullable(),
532 ));
533 }
534 }
535 })
536 .chain(params.iter().cloned())
537 .collect::<Fields>();
538
539 assert_eq!(fields.len(), self.fields().len() + params.len());
540
541 Ok(Schema::new_with_metadata(fields, self.metadata.clone()))
542 }
543}
544
545impl<T: ExtendSchema + Clone> ExtendSchema for Cow<'_, T> {
546 fn from_fields(params: Fields) -> Result<Self> {
547 Ok(Cow::Owned(T::from_fields(params)?))
548 }
549
550 fn extend(&self, params: Fields) -> Result<Self> {
551 Ok(Cow::Owned(self.as_ref().extend(params)?))
552 }
553}
554
555impl<T: ExtendSchema> ExtendSchema for Arc<T> {
556 fn from_fields(params: Fields) -> Result<Self> {
557 Ok(Arc::new(T::from_fields(params)?))
558 }
559
560 fn extend(&self, params: Fields) -> Result<Self> {
561 Ok(Arc::new(self.as_ref().extend(params)?))
562 }
563}
564
565impl ExtendSchema for ExtendableExprSchema<'_> {
566 fn from_fields(params: Fields) -> Result<Self> {
567 static EMPTY_DFSCHEMA: LazyLock<DFSchema> = LazyLock::new(DFSchema::empty);
568
569 Ok(ExtendableExprSchema {
570 fields_chain: vec![params],
571 outer_schema: &*EMPTY_DFSCHEMA,
572 })
573 }
574
575 fn extend(&self, params: Fields) -> Result<Self> {
576 Ok(ExtendableExprSchema {
577 fields_chain: std::iter::once(params)
578 .chain(self.fields_chain.iter().cloned())
579 .collect(),
580 outer_schema: self.outer_schema,
581 })
582 }
583}
584
/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda
///
/// Lookups of unqualified columns check the lambda parameters in
/// `fields_chain` first and fall back to `outer_schema` otherwise.
#[derive(Clone, Debug)]
struct ExtendableExprSchema<'a> {
    /// Parameters of the enclosing lambdas, innermost lambda first, so inner
    /// parameters shadow outer ones with the same name
    fields_chain: Vec<Fields>,
    /// Schema consulted for qualified columns and for names that are not
    /// lambda parameters
    outer_schema: &'a dyn ExprSchema,
}
591
592impl<'a> ExtendableExprSchema<'a> {
593 fn new(schema: &'a dyn ExprSchema) -> Self {
594 Self {
595 fields_chain: vec![],
596 outer_schema: schema,
597 }
598 }
599}
600
601impl ExprSchema for ExtendableExprSchema<'_> {
602 fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> {
603 if col.relation.is_none() {
604 for fields in &self.fields_chain {
605 if let Some((_index, lambda_param)) = fields.find(&col.name) {
606 return Ok(lambda_param);
607 }
608 }
609 }
610
611 self.outer_schema.field_from_column(col)
612 }
613}
614
/// Classification of a function argument used to compute the schema each
/// argument is evaluated against: either a plain value or a lambda.
#[derive(Clone, Debug)]
pub enum ValueOrLambdaParameter<'a> {
    /// A columnar value with the given field
    Value(FieldRef),
    /// A lambda with the given parameter names and a flag indicating whether
    /// it captures any columns from the enclosing schema
    Lambda(&'a [String], bool),
}
622
623impl<F> From<F> for ScalarUDF
624where
625 F: ScalarUDFImpl + 'static,
626{
627 fn from(fun: F) -> Self {
628 Self::new_from_impl(fun)
629 }
630}
631
/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
/// scalar function.
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs {
    /// The evaluated arguments to the function
    /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)`
    pub args: Vec<ColumnarValue>,
    /// Field associated with each arg, if it exists
    pub arg_fields: Vec<FieldRef>,
    /// The number of rows in record batch being evaluated
    pub number_rows: usize,
    /// The return field of the scalar function returned (from `return_type`
    /// or `return_field_from_args`) when creating the physical expression
    /// from the logical expression
    pub return_field: FieldRef,
    /// The config options at execution time
    pub config_options: Arc<ConfigOptions>,
    /// The lambdas passed to the function, positionally paired with `args`
    ///
    /// `None` when the call has no lambda information at all. Otherwise,
    /// entry `i` is `Some` when argument `i` is a lambda and `None` when it
    /// is a plain value (see [`ScalarFunctionArgs::to_lambda_args`]).
    pub lambdas: Option<Vec<Option<ScalarFunctionLambdaArg>>>,
}
653
/// A lambda argument to a ScalarFunction, as carried in
/// [`ScalarFunctionArgs::lambdas`]
#[derive(Clone, Debug)]
pub struct ScalarFunctionLambdaArg {
    /// The parameters defined in this lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be `vec![Field::new("v", DataType::Int32, true)]`
    pub params: Vec<FieldRef>,
    /// The body of the lambda
    ///
    /// For example, for `array_transform([2], v -> -v)`,
    /// this will be the physical expression of `-v`
    pub body: Arc<dyn PhysicalExpr>,
    /// A RecordBatch containing at least the columns captured inside this
    /// lambda body, if any.
    /// It may contain additional, unspecified columns; that is an
    /// implementation detail.
    ///
    /// For example, for `array_transform([2], v -> v + a + b)`,
    /// this will be a `RecordBatch` with two columns, `a` and `b`
    pub captures: Option<RecordBatch>,
}
674
675impl ScalarFunctionArgs {
676 /// The return type of the function. See [`Self::return_field`] for more
677 /// details.
678 pub fn return_type(&self) -> &DataType {
679 self.return_field.data_type()
680 }
681
682 pub fn to_lambda_args(&self) -> Vec<ValueOrLambda<'_>> {
683 match &self.lambdas {
684 Some(lambdas) => std::iter::zip(&self.args, lambdas)
685 .map(|(arg, lambda)| match lambda {
686 Some(lambda) => ValueOrLambda::Lambda(lambda),
687 None => ValueOrLambda::Value(arg),
688 })
689 .collect(),
690 None => self.args.iter().map(ValueOrLambda::Value).collect(),
691 }
692 }
693}
694
/// An argument to a ScalarUDF that supports lambdas
///
/// Produced by [`ScalarFunctionArgs::to_lambda_args`].
#[derive(Debug)]
pub enum ValueOrLambda<'a> {
    /// A plain (non-lambda) evaluated argument
    Value(&'a ColumnarValue),
    /// A lambda argument
    Lambda(&'a ScalarFunctionLambdaArg),
}
701
/// Information about arguments passed to the function
///
/// This structure contains metadata about how the function was called
/// such as the type of the arguments, any scalar arguments and if the
/// arguments can (ever) be null
///
/// See [`ScalarUDFImpl::return_field_from_args`] for more information
#[derive(Debug)]
pub struct ReturnFieldArgs<'a> {
    /// The fields (data type, nullability, metadata) of the arguments to the function
    ///
    /// If argument `i` to the function is a lambda, it will be the field returned by the
    /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters`
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]`
    pub arg_fields: &'a [FieldRef],
    /// Is argument `i` to the function a scalar (constant)?
    ///
    /// If the argument `i` is not a scalar, it will be None
    ///
    /// For example, if a function is called like `my_function(column_a, 5)`
    /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]`
    pub scalar_arguments: &'a [Option<&'a ScalarValue>],
    /// Is argument `i` to the function a lambda?
    ///
    /// Expected to be positionally aligned with `arg_fields` (see
    /// [`ReturnFieldArgs::to_lambda_args`], which zips the two).
    ///
    /// For example, with `array_transform([1], v -> v == 5)`
    /// this field will be `[false, true]`
    pub lambdas: &'a [bool],
}
732
/// A tagged Field indicating whether it corresponds to a value or a lambda argument
#[derive(Debug)]
pub enum ValueOrLambdaField<'a> {
    /// The Field of a ColumnarValue argument
    Value(&'a FieldRef),
    /// The Field of the return of the lambda body when evaluated with the
    /// parameters from `ScalarUDFImpl::lambdas_parameters`
    Lambda(&'a FieldRef),
}
741
742impl<'a> ReturnFieldArgs<'a> {
743 /// Based on self.lambdas, encodes self.arg_fields to tagged enums
744 /// indicating whether it correspond to a value or a lambda argument
745 pub fn to_lambda_args(&self) -> Vec<ValueOrLambdaField<'a>> {
746 std::iter::zip(self.arg_fields, self.lambdas)
747 .map(|(field, is_lambda)| {
748 if *is_lambda {
749 ValueOrLambdaField::Lambda(field)
750 } else {
751 ValueOrLambdaField::Value(field)
752 }
753 })
754 .collect()
755 }
756}
757
758/// Trait for implementing user defined scalar functions.
759///
760/// This trait exposes the full API for implementing user defined functions and
761/// can be used to implement any function.
762///
763/// See [`advanced_udf.rs`] for a full example with complete implementation and
764/// [`ScalarUDF`] for other available options.
765///
766/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
767///
768/// # Basic Example
769/// ```
770/// # use std::any::Any;
771/// # use std::sync::LazyLock;
772/// # use arrow::datatypes::DataType;
773/// # use datafusion_common::{DataFusionError, plan_err, Result};
774/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
775/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
776/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
777/// /// This struct for a simple UDF that adds one to an int32
778/// #[derive(Debug, PartialEq, Eq, Hash)]
779/// struct AddOne {
780/// signature: Signature,
781/// }
782///
783/// impl AddOne {
784/// fn new() -> Self {
785/// Self {
786/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable),
787/// }
788/// }
789/// }
790///
791/// static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
792/// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)")
793/// .with_argument("arg1", "The int32 number to add one to")
794/// .build()
795/// });
796///
797/// fn get_doc() -> &'static Documentation {
798/// &DOCUMENTATION
799/// }
800///
801/// /// Implement the ScalarUDFImpl trait for AddOne
802/// impl ScalarUDFImpl for AddOne {
803/// fn as_any(&self) -> &dyn Any { self }
804/// fn name(&self) -> &str { "add_one" }
805/// fn signature(&self) -> &Signature { &self.signature }
806/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
807/// if !matches!(args.get(0), Some(&DataType::Int32)) {
808/// return plan_err!("add_one only accepts Int32 arguments");
809/// }
810/// Ok(DataType::Int32)
811/// }
812/// // The actual implementation would add one to the argument
813/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
814/// unimplemented!()
815/// }
816/// fn documentation(&self) -> Option<&Documentation> {
817/// Some(get_doc())
818/// }
819/// }
820///
821/// // Create a new ScalarUDF from the implementation
822/// let add_one = ScalarUDF::from(AddOne::new());
823///
824/// // Call the function `add_one(col)`
825/// let expr = add_one.call(vec![col("a")]);
826/// ```
827pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
828 /// Returns this object as an [`Any`] trait object
829 fn as_any(&self) -> &dyn Any;
830
831 /// Returns this function's name
832 fn name(&self) -> &str;
833
834 /// Returns any aliases (alternate names) for this function.
835 ///
836 /// Aliases can be used to invoke the same function using different names.
837 /// For example in some databases `now()` and `current_timestamp()` are
838 /// aliases for the same function. This behavior can be obtained by
839 /// returning `current_timestamp` as an alias for the `now` function.
840 ///
841 /// Note: `aliases` should only include names other than [`Self::name`].
842 /// Defaults to `[]` (no aliases)
843 fn aliases(&self) -> &[String] {
844 &[]
845 }
846
847 /// Returns the user-defined display name of function, given the arguments
848 ///
849 /// This can be used to customize the output column name generated by this
850 /// function.
851 ///
852 /// Defaults to `name(args[0], args[1], ...)`
853 #[deprecated(
854 since = "50.0.0",
855 note = "This method is unused and will be removed in a future release"
856 )]
857 fn display_name(&self, args: &[Expr]) -> Result<String> {
858 let names: Vec<String> = args.iter().map(ToString::to_string).collect();
859 // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
860 Ok(format!("{}({})", self.name(), names.join(",")))
861 }
862
863 /// Returns the name of the column this expression would create
864 ///
865 /// See [`Expr::schema_name`] for details
866 fn schema_name(&self, args: &[Expr]) -> Result<String> {
867 Ok(format!(
868 "{}({})",
869 self.name(),
870 schema_name_from_exprs_comma_separated_without_space(args)?
871 ))
872 }
873
874 /// Returns a [`Signature`] describing the argument types for which this
875 /// function has an implementation, and the function's [`Volatility`].
876 ///
877 /// See [`Signature`] for more details on argument type handling
878 /// and [`Self::return_type`] for computing the return type.
879 ///
880 /// [`Volatility`]: datafusion_expr_common::signature::Volatility
881 fn signature(&self) -> &Signature;
882
883 /// [`DataType`] returned by this function, given the types of the
884 /// arguments.
885 ///
886 /// # Arguments
887 ///
888 /// `arg_types` Data types of the arguments. The implementation of
889 /// `return_type` can assume that some other part of the code has coerced
890 /// the actual argument types to match [`Self::signature`].
891 ///
892 /// # Notes
893 ///
894 /// If you provide an implementation for [`Self::return_field_from_args`],
895 /// DataFusion will not call `return_type` (this function). While it is
896 /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is
897 /// recommended to return [`DataFusionError::Internal`] instead, which
898 /// reduces the severity of symptoms if bugs occur (an error rather than a
899 /// panic).
900 ///
901 /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
902 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
903
    /// Create a new instance of this function with updated configuration.
    ///
    /// This method is called when configuration options change at runtime
    /// (e.g., via `SET` statements) to allow functions that depend on
    /// configuration to update themselves accordingly.
    ///
    /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so
    /// this API is not needed for functions where only the computed values
    /// depend on the current options.
    ///
    /// This API is useful for functions where the return
    /// **type** depends on the configuration options, such as the `now()` function
    /// which depends on the current timezone.
    ///
    /// # Arguments
    ///
    /// * `config` - The updated configuration options
    ///
    /// # Returns
    ///
    /// * `Some(ScalarUDF)` - A new instance of this function configured with the new settings
    /// * `None` - If this function does not change with new configuration settings (the default)
    fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
        None
    }
929
930 /// What type will be returned by this function, given the arguments?
931 ///
932 /// By default, this function calls [`Self::return_type`] with the
933 /// types of each argument.
934 ///
935 /// # Notes
936 ///
937 /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
938 /// as the result type is typically a deterministic function of the input types
939 /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
940 /// is generally unnecessary unless the return type depends on runtime values.
941 ///
942 /// This function can be used for more advanced cases such as:
943 ///
944 /// 1. specifying nullability
945 /// 2. return types based on the **values** of the arguments (rather than
946 /// their **types**.
947 ///
948 /// # Example creating `Field`
949 ///
950 /// Note the name of the [`Field`] is ignored, except for structured types such as
951 /// `DataType::Struct`.
952 ///
953 /// ```rust
954 /// # use std::sync::Arc;
955 /// # use arrow::datatypes::{DataType, Field, FieldRef};
956 /// # use datafusion_common::Result;
957 /// # use datafusion_expr::ReturnFieldArgs;
958 /// # struct Example{}
959 /// # impl Example {
960 /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
961 /// // report output is only nullable if any one of the arguments are nullable
962 /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
963 /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
964 /// Ok(field)
965 /// }
966 /// # }
967 /// ```
968 ///
969 /// # Output Type based on Values
970 ///
971 /// For example, the following two function calls get the same argument
972 /// types (something and a `Utf8` string) but return different types based
973 /// on the value of the second argument:
974 ///
975 /// * `arrow_cast(x, 'Int16')` --> `Int16`
976 /// * `arrow_cast(x, 'Float32')` --> `Float32`
977 ///
978 /// # Requirements
979 ///
980 /// This function **must** consistently return the same type for the same
981 /// logical input even if the input is simplified (e.g. it must return the same
982 /// value for `('foo' | 'bar')` as it does for ('foobar').
983 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
984 let data_types = args
985 .arg_fields
986 .iter()
987 .map(|f| f.data_type())
988 .cloned()
989 .collect::<Vec<_>>();
990 let return_type = self.return_type(&data_types)?;
991 Ok(Arc::new(Field::new(self.name(), return_type, true)))
992 }
993
    /// Deprecated way of reporting whether this function's output may be null.
    ///
    /// The default implementation ignores its arguments and always returns
    /// `true`. Report nullability through the [`Field`] returned from
    /// [`Self::return_field_from_args`] instead.
    #[deprecated(
        since = "45.0.0",
        note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error"
    )]
    fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool {
        true
    }
1001
    /// Invoke the function returning the appropriate result.
    ///
    /// This is a required method: there is no default implementation.
    ///
    /// # Performance
    ///
    /// For the best performance, the implementations should handle the common case
    /// when one or more of their arguments are constant values (aka
    /// [`ColumnarValue::Scalar`]).
    ///
    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
    /// to arrays, which will likely be simpler code, but be slower.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue>;
1013
    /// Optionally apply per-UDF simplification / rewrite rules.
    ///
    /// This can be used to apply function specific simplification rules during
    /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
    /// implementation does nothing.
    ///
    /// Note that DataFusion handles simplifying arguments and "constant
    /// folding" (replacing a function call with constant arguments such as
    /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
    /// optimizations manually for specific UDFs.
    ///
    /// # Arguments
    /// * `args`: The arguments of the function
    /// * `info`: The necessary information for simplification
    ///
    /// # Returns
    /// [`ExprSimplifyResult`] indicating the result of the simplification.
    ///
    /// NOTE: if the function cannot be simplified, the arguments *MUST* be
    /// returned unmodified (as the default does, via
    /// `ExprSimplifyResult::Original(args)`).
    ///
    /// # Notes
    ///
    /// The returned expression must have the same schema as the original
    /// expression, including both the data type and nullability. For example,
    /// if the original expression is nullable, the returned expression must
    /// also be nullable, otherwise it may lead to schema verification errors
    /// later in query planning.
    fn simplify(
        &self,
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        Ok(ExprSimplifyResult::Original(args))
    }
1048
    /// Returns true if some of this function's argument sub-expressions may not
    /// be evaluated, and thus any side effects (like divide by zero) may not be
    /// encountered.
    ///
    /// Setting this to true prevents certain optimizations such as common
    /// subexpression elimination.
    ///
    /// When overriding this function to return `true`, [ScalarUDFImpl::conditional_arguments] can also be
    /// overridden to report more accurately which arguments are eagerly evaluated and which ones
    /// lazily.
    ///
    /// The default implementation returns `false`.
    fn short_circuits(&self) -> bool {
        false
    }
1061
1062 /// Determines which of the arguments passed to this function are evaluated eagerly
1063 /// and which may be evaluated lazily.
1064 ///
1065 /// If this function returns `None`, all arguments are eagerly evaluated.
1066 /// Returning `None` is a micro optimization that saves a needless `Vec`
1067 /// allocation.
1068 ///
1069 /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
1070 /// are the arguments that are always evaluated, and `lazy` are the
1071 /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
1072 /// in some cases).
1073 ///
1074 /// Implementations must ensure that the two returned `Vec`s are disjunct,
1075 /// and that each argument from `args` is present in one the two `Vec`s.
1076 ///
1077 /// When overriding this function, [ScalarUDFImpl::short_circuits] must
1078 /// be overridden to return `true`.
1079 fn conditional_arguments<'a>(
1080 &self,
1081 args: &'a [Expr],
1082 ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1083 if self.short_circuits() {
1084 Some((vec![], args.iter().collect()))
1085 } else {
1086 None
1087 }
1088 }
1089
    /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input
    /// intervals.
    ///
    /// # Parameters
    ///
    /// * `input` are the intervals for the children (inputs) of this function.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`,
    /// then the output interval would be `[0, 3]`.
    ///
    /// The default implementation ignores `input` and returns an unbounded
    /// interval of type [`DataType::Null`].
    fn evaluate_bounds(&self, _input: &[&Interval]) -> Result<Interval> {
        // We cannot assume the input datatype is the same as the output type,
        // so the default does not derive a bound from the inputs.
        Interval::make_unbounded(&DataType::Null)
    }
1105
    /// Updates bounds for child expressions, given a known [`Interval`] for this
    /// function.
    ///
    /// This function is used to propagate constraints down through an
    /// expression tree.
    ///
    /// # Parameters
    ///
    /// * `interval` is the currently known interval for this function.
    /// * `inputs` are the current intervals for the inputs (children) of this function.
    ///
    /// # Returns
    ///
    /// A `Vec` of new intervals for the children, in order.
    ///
    /// If constraint propagation reveals an infeasibility for any child, returns
    /// [`None`]. If none of the children intervals change as a result of
    /// propagation, may return an empty vector instead of cloning `inputs`.
    /// This (an empty vector) is the default (and conservative) return value.
    ///
    /// # Example
    ///
    /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the
    /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`.
    fn propagate_constraints(
        &self,
        _interval: &Interval,
        _inputs: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        Ok(Some(vec![]))
    }
1137
1138 /// Calculates the [`SortProperties`] of this function based on its children's properties.
1139 fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1140 if !self.preserves_lex_ordering(inputs)? {
1141 return Ok(SortProperties::Unordered);
1142 }
1143
1144 let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else {
1145 return Ok(SortProperties::Singleton);
1146 };
1147
1148 if inputs
1149 .iter()
1150 .skip(1)
1151 .all(|input| &input.sort_properties == first_order)
1152 {
1153 Ok(*first_order)
1154 } else {
1155 Ok(SortProperties::Unordered)
1156 }
1157 }
1158
    /// Returns true if the function preserves lexicographical ordering based on
    /// the input ordering.
    ///
    /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not.
    ///
    /// The default implementation returns `false`; with it, the default
    /// [`ScalarUDFImpl::output_ordering`] reports the output as unordered.
    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
        Ok(false)
    }
1166
    /// Coerce arguments of a function call to types that the function can evaluate.
    ///
    /// This function is only called if [`ScalarUDFImpl::signature`] returns
    /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of
    /// the other variants of [`TypeSignature`] which handle common cases.
    ///
    /// See the [type coercion module](crate::type_coercion)
    /// documentation for more details on type coercion
    ///
    /// [`TypeSignature`]: crate::TypeSignature
    ///
    /// For example, if your function requires floating point arguments, but the user calls
    /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]`
    /// to ensure the argument is converted to `1::double`
    ///
    /// # Parameters
    /// * `arg_types`: The types of the arguments this function is called with
    ///
    /// # Return value
    /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
    /// arguments to these specific types.
    ///
    /// The default implementation returns a "not implemented" error.
    fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
        not_impl_err!("Function {} does not implement coerce_types", self.name())
    }
1191
    /// Returns the documentation for this Scalar UDF.
    ///
    /// Documentation can be accessed programmatically as well as generating
    /// publicly facing documentation.
    ///
    /// The default implementation returns `None` (no documentation).
    fn documentation(&self) -> Option<&Documentation> {
        None
    }
1199
1200 /// Returns the parameters that any lambda supports
1201 fn lambdas_parameters(
1202 &self,
1203 args: &[ValueOrLambdaParameter],
1204 ) -> Result<Vec<Option<Vec<Field>>>> {
1205 Ok(vec![None; args.len()])
1206 }
1207}
1208
/// ScalarUDF that adds an alias to the underlying function. It is better to
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug, PartialEq, Eq, Hash)]
struct AliasedScalarUDFImpl {
    /// The wrapped function, to which most trait methods delegate.
    inner: UdfEq<Arc<dyn ScalarUDFImpl>>,
    /// Aliases reported by this wrapper: the inner function's own aliases
    /// followed by the extra aliases supplied at construction.
    aliases: Vec<String>,
}
1216
1217impl AliasedScalarUDFImpl {
1218 pub fn new(
1219 inner: Arc<dyn ScalarUDFImpl>,
1220 new_aliases: impl IntoIterator<Item = &'static str>,
1221 ) -> Self {
1222 let mut aliases = inner.aliases().to_vec();
1223 aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
1224 Self {
1225 inner: inner.into(),
1226 aliases,
1227 }
1228 }
1229}
1230
1231#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method
1232impl ScalarUDFImpl for AliasedScalarUDFImpl {
1233 fn as_any(&self) -> &dyn Any {
1234 self
1235 }
1236
1237 fn name(&self) -> &str {
1238 self.inner.name()
1239 }
1240
1241 fn display_name(&self, args: &[Expr]) -> Result<String> {
1242 #[expect(deprecated)]
1243 self.inner.display_name(args)
1244 }
1245
1246 fn schema_name(&self, args: &[Expr]) -> Result<String> {
1247 self.inner.schema_name(args)
1248 }
1249
1250 fn signature(&self) -> &Signature {
1251 self.inner.signature()
1252 }
1253
1254 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
1255 self.inner.return_type(arg_types)
1256 }
1257
1258 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
1259 self.inner.return_field_from_args(args)
1260 }
1261
1262 fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
1263 #[allow(deprecated)]
1264 self.inner.is_nullable(args, schema)
1265 }
1266
1267 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1268 self.inner.invoke_with_args(args)
1269 }
1270
1271 fn with_updated_config(&self, _config: &ConfigOptions) -> Option<ScalarUDF> {
1272 None
1273 }
1274
1275 fn aliases(&self) -> &[String] {
1276 &self.aliases
1277 }
1278
1279 fn simplify(
1280 &self,
1281 args: Vec<Expr>,
1282 info: &dyn SimplifyInfo,
1283 ) -> Result<ExprSimplifyResult> {
1284 self.inner.simplify(args, info)
1285 }
1286
1287 fn conditional_arguments<'a>(
1288 &self,
1289 args: &'a [Expr],
1290 ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
1291 self.inner.conditional_arguments(args)
1292 }
1293
1294 fn short_circuits(&self) -> bool {
1295 self.inner.short_circuits()
1296 }
1297
1298 fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
1299 self.inner.evaluate_bounds(input)
1300 }
1301
1302 fn propagate_constraints(
1303 &self,
1304 interval: &Interval,
1305 inputs: &[&Interval],
1306 ) -> Result<Option<Vec<Interval>>> {
1307 self.inner.propagate_constraints(interval, inputs)
1308 }
1309
1310 fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
1311 self.inner.output_ordering(inputs)
1312 }
1313
1314 fn preserves_lex_ordering(&self, inputs: &[ExprProperties]) -> Result<bool> {
1315 self.inner.preserves_lex_ordering(inputs)
1316 }
1317
1318 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
1319 self.inner.coerce_types(arg_types)
1320 }
1321
1322 fn documentation(&self) -> Option<&Documentation> {
1323 self.inner.documentation()
1324 }
1325
1326 fn lambdas_parameters(
1327 &self,
1328 args: &[ValueOrLambdaParameter],
1329 ) -> Result<Vec<Option<Vec<Field>>>> {
1330 self.inner.lambdas_parameters(args)
1331 }
1332}
1333
1334fn lambda_parameters<'a>(
1335 args: &'a [Expr],
1336 schema: &dyn ExprSchema,
1337) -> Result<Vec<ValueOrLambdaParameter<'a>>> {
1338 args.iter()
1339 .map(|e| match e {
1340 Expr::Lambda(Lambda { params, body: _ }) => {
1341 let mut captures = false;
1342
1343 e.apply_with_lambdas_params(|expr, lambdas_params| match expr {
1344 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1345 captures = true;
1346
1347 Ok(TreeNodeRecursion::Stop)
1348 }
1349 _ => Ok(TreeNodeRecursion::Continue),
1350 })
1351 .unwrap();
1352
1353 Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures))
1354 }
1355 _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)),
1356 })
1357 .collect()
1358}
1359
1360/// Merge the lambda body captured columns with it's arguments
1361/// Datafusion relies on an unspecified field ordering implemented in this function
1362/// As such, this is the only correct way to merge the captured values with the arguments
1363/// The number of args should not be lower than the number of params
1364///
1365/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily
1366/// computes only the necessary arguments to match the number of params
1367pub fn merge_captures_with_args(
1368 captures: Option<&RecordBatch>,
1369 params: &[FieldRef],
1370 args: &[ArrayRef],
1371) -> Result<RecordBatch> {
1372 if args.len() < params.len() {
1373 return exec_err!(
1374 "merge_captures_with_args called with {} params but with {} args",
1375 params.len(),
1376 args.len()
1377 );
1378 }
1379
1380 // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants
1381 let (fields, columns) = match captures {
1382 Some(captures) => {
1383 let fields = captures
1384 .schema()
1385 .fields()
1386 .iter()
1387 .chain(params)
1388 .cloned()
1389 .collect::<Vec<_>>();
1390
1391 let columns = [captures.columns(), args].concat();
1392
1393 (fields, columns)
1394 }
1395 None => (params.to_vec(), args.to_vec()),
1396 };
1397
1398 Ok(RecordBatch::try_new(
1399 Arc::new(Schema::new(fields)),
1400 columns,
1401 )?)
1402}
1403
1404/// Lazy version of merge_captures_with_args that receives closures to compute the arguments,
1405/// and calls only the necessary to match the number of params
1406pub fn merge_captures_with_lazy_args(
1407 captures: Option<&RecordBatch>,
1408 params: &[FieldRef],
1409 args: &[&dyn Fn() -> Result<ArrayRef>],
1410) -> Result<RecordBatch> {
1411 merge_captures_with_args(
1412 captures,
1413 params,
1414 &args
1415 .iter()
1416 .take(params.len())
1417 .map(|arg| arg())
1418 .collect::<Result<Vec<_>>>()?,
1419 )
1420}
1421
1422/// Variation of merge_captures_with_lazy_args that take boxed closures
1423pub fn merge_captures_with_boxed_lazy_args(
1424 captures: Option<&RecordBatch>,
1425 params: &[FieldRef],
1426 args: &[Box<dyn Fn() -> Result<ArrayRef>>],
1427) -> Result<RecordBatch> {
1428 merge_captures_with_args(
1429 captures,
1430 params,
1431 &args
1432 .iter()
1433 .take(params.len())
1434 .map(|arg| arg())
1435 .collect::<Result<Vec<_>>>()?,
1436 )
1437}
1438
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;
    use std::hash::DefaultHasher;

    use arrow::datatypes::Fields;
    use datafusion_expr_common::signature::Volatility;

    use crate::{
        tree_node::tests::{args, array_transform_udf, list_int, list_list_int},
        udf::{lambda_parameters, ExtendableExprSchema},
    };

    /// Minimal UDF whose identity (equality/hash/ordering) depends on `name`
    /// and the `field` parameter; it is never invoked.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct TestScalarUDFImpl {
        name: &'static str,
        field: &'static str,
        signature: Signature,
    }
    impl ScalarUDFImpl for TestScalarUDFImpl {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            self.name
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            unimplemented!()
        }

        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            unimplemented!()
        }
    }

    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
    // must be consistent, so they are tested together.
    #[test]
    fn test_partial_eq_hash_and_partial_ord() {
        // A parameterized function
        let f = test_func("foo", "a");

        // Same like `f`, different instance
        let f2 = test_func("foo", "a");
        assert_eq!(f, f2);
        assert_eq!(hash(&f), hash(&f2));
        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));

        // Different parameter
        let b = test_func("foo", "b");
        assert_ne!(f, b);
        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&b), None);

        // Different name
        let o = test_func("other", "a");
        assert_ne!(f, o);
        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));

        // Different name and parameter
        assert_ne!(b, o);
        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
    }

    /// Builds a [`ScalarUDF`] from a [`TestScalarUDFImpl`] with the given name
    /// and parameter.
    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
        ScalarUDF::from(TestScalarUDFImpl {
            name,
            field: parameter,
            signature: Signature::any(1, Volatility::Immutable),
        })
    }

    /// Hashes `value` with the standard library's default hasher.
    fn hash<T: Hash>(value: &T) -> u64 {
        let hasher = &mut DefaultHasher::new();
        value.hash(hasher);
        hasher.finish()
    }

    #[test]
    fn test_arguments_expr_schema() {
        let args = args();
        let schema = list_list_int();

        let schemas = array_transform_udf()
            .arguments_expr_schema(&args, &schema)
            .unwrap()
            .into_iter()
            .map(|s| format!("{s:?}"))
            .collect::<Vec<_>>();

        let mut lambdas_parameters = array_transform_udf()
            .inner()
            .lambdas_parameters(&lambda_parameters(&args, &schema).unwrap())
            .unwrap();

        assert_eq!(
            schemas,
            &[
                format!("{}", &list_list_int()),
                format!(
                    "{:?}",
                    ExtendableExprSchema {
                        fields_chain: vec![Fields::from(
                            lambdas_parameters[0].take().unwrap()
                        )],
                        outer_schema: &list_list_int()
                    }
                ),
            ]
        )
    }

    #[test]
    fn test_arguments_arrow_schema() {
        let list_int = list_int();
        let list_list_int = list_list_int();

        let schemas = array_transform_udf()
            .arguments_arrow_schema(
                &lambda_parameters(&args(), &list_list_int).unwrap(),
                list_list_int.as_arrow(),
            )
            .unwrap();

        assert_eq!(
            schemas,
            &[
                Cow::Borrowed(list_list_int.as_arrow()),
                Cow::Owned(list_int.as_arrow().clone())
            ]
        )
    }

    #[test]
    fn test_arguments_schema_from_logical_args() {
        let list_list_int = list_list_int();

        let schemas = array_transform_udf()
            .arguments_schema_from_logical_args(&args(), &list_list_int)
            .unwrap();

        assert_eq!(
            schemas,
            &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())]
        )
    }
}