datafusion_functions_nested/
array_transform.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_transform function.
19
20use arrow::{
21    array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray},
22    compute::take_record_batch,
23    datatypes::{DataType, Field},
24};
25use datafusion_common::{
26    HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args}
27};
28use datafusion_expr::{
29    ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args
30};
31use datafusion_macros::user_doc;
32use std::{any::Any, sync::Arc};
33
34make_udf_expr_and_func!(
35    ArrayTransform,
36    array_transform,
37    array lambda,
38    "transforms the values of a array",
39    array_transform_udf
40);
41
42#[user_doc(
43    doc_section(label = "Array Functions"),
44    description = "transforms the values of a array",
45    syntax_example = "array_transform(array, x -> x*2)",
46    sql_example = r#"```sql
47> select array_transform([1, 2, 3, 4, 5], x -> x*2);
48+-------------------------------------------+
49| array_transform([1, 2, 3, 4, 5], x -> x*2)       |
50+-------------------------------------------+
51| [2, 4, 6, 8, 10]                          |
52+-------------------------------------------+
53```"#,
54    argument(
55        name = "array",
56        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
57    ),
58    argument(name = "lambda", description = "Lambda")
59)]
60#[derive(Debug, PartialEq, Eq, Hash)]
61pub struct ArrayTransform {
62    signature: Signature,
63    aliases: Vec<String>,
64}
65
66impl Default for ArrayTransform {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl ArrayTransform {
73    pub fn new() -> Self {
74        Self {
75            signature: Signature::any(2, Volatility::Immutable),
76            aliases: vec![String::from("list_transform")],
77        }
78    }
79}
80
81impl ScalarUDFImpl for ArrayTransform {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "array_transform"
88    }
89
90    fn aliases(&self) -> &[String] {
91        &self.aliases
92    }
93
94    fn signature(&self) -> &Signature {
95        &self.signature
96    }
97
98    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99        internal_err!("return_type called instead of return_field_from_args")
100    }
101
102    fn return_field_from_args(
103        &self,
104        args: datafusion_expr::ReturnFieldArgs,
105    ) -> Result<Arc<Field>> {
106        let args = args.to_lambda_args();
107
108        let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] =
109            take_function_args(self.name(), &args)?
110        else {
111            return exec_err!(
112                "{} expects a value follewed by a lambda, got {:?}",
113                self.name(),
114                args
115            );
116        };
117
118        //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed?
119
120        // lambda is the resulting field of executing the lambda body
121        // with the parameters returned in lambdas_parameters
122        let field = Arc::new(Field::new(
123            Field::LIST_FIELD_DEFAULT_NAME,
124            lambda.data_type().clone(),
125            lambda.is_nullable(),
126        ));
127
128        let return_type = match list.data_type() {
129            DataType::List(_) => DataType::List(field),
130            DataType::LargeList(_) => DataType::LargeList(field),
131            DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size),
132            _ => unreachable!(),
133        };
134
135        Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
136    }
137
138    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
139        // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas
140        let lambda_args = args.to_lambda_args();
141        let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?;
142
143        let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) =
144            (list_value, lambda)
145        else {
146            return exec_err!(
147                "{} expects a value followed by a lambda, got {:?}",
148                self.name(),
149                &lambda_args
150            );
151        };
152
153        let list_array = list_value.to_array(args.number_rows)?;
154
155        // if any column got captured, we need to adjust it to the values arrays,
156        // duplicating values of list with mulitple values and removing values of empty lists
157        // list_indices is not cheap so is important to avoid it when no column is captured
158        let adjusted_captures = lambda
159            .captures
160            .as_ref()
161            .map(|captures| take_record_batch(captures, &list_indices(&list_array)?))
162            .transpose()?;
163
164        // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments
165        // avoiding unnecessary computations
166        let values_param = || Ok(Arc::clone(list_values(&list_array)?));
167        let indices_param = || elements_indices(&list_array);
168
169        // the order of the merged schema is an unspecified implementation detail that may change in the future,
170        // using this function is the correct way to merge as it return the correct ordering and will change in sync
171        // the implementation without the need for fixes. It also computes only the parameters requested
172        let lambda_batch = merge_captures_with_lazy_args(
173            adjusted_captures.as_ref(),
174            &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameters names definied in the lambda, so we don't need to
175            &[&values_param, &indices_param],
176        )?;
177
178        // call the transforming expression with the record batch composed of the list values merged with captured columns
179        let transformed_values = lambda
180            .body
181            .evaluate(&lambda_batch)?
182            .into_array(lambda_batch.num_rows())?;
183
184        let field = match args.return_field.data_type() {
185            DataType::List(field)
186            | DataType::LargeList(field)
187            | DataType::FixedSizeList(field, _) => Arc::clone(field),
188            _ => {
189                return exec_err!(
190                    "{} expected ScalarFunctionArgs.return_field to be a list, got {}",
191                    self.name(),
192                    args.return_field
193                )
194            }
195        };
196
197        let transformed_list = match list_array.data_type() {
198            DataType::List(_) => {
199                let list = list_array.as_list();
200
201                Arc::new(ListArray::new(
202                    field,
203                    list.offsets().clone(),
204                    transformed_values,
205                    list.nulls().cloned(),
206                )) as ArrayRef
207            }
208            DataType::LargeList(_) => {
209                let large_list = list_array.as_list();
210
211                Arc::new(LargeListArray::new(
212                    field,
213                    large_list.offsets().clone(),
214                    transformed_values,
215                    large_list.nulls().cloned(),
216                ))
217            }
218            DataType::FixedSizeList(_, value_length) => {
219                Arc::new(FixedSizeListArray::new(
220                    field,
221                    *value_length,
222                    transformed_values,
223                    list_array.as_fixed_size_list().nulls().cloned(),
224                ))
225            }
226            other => exec_err!("expected list, got {other}")?,
227        };
228
229        Ok(ColumnarValue::Array(transformed_list))
230    }
231
232    fn lambdas_parameters(
233        &self,
234        args: &[ValueOrLambdaParameter],
235    ) -> Result<Vec<Option<Vec<Field>>>> {
236        let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] =
237            args
238        else {
239            return exec_err!(
240                "{} expects a value follewed by a lambda, got {:?}",
241                self.name(),
242                args
243            );
244        };
245
246        let (field, index_type) = match list.data_type() {
247            DataType::List(field) => (field, DataType::Int32),
248            DataType::LargeList(field) => (field, DataType::Int64),
249            DataType::FixedSizeList(field, _) => (field, DataType::Int32),
250            _ => return exec_err!("expected list, got {list}"),
251        };
252
253        // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2),
254        // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j),
255        // as datafusion will do that for us
256        let value = Field::new("value", field.data_type().clone(), field.is_nullable())
257            .with_metadata(field.metadata().clone());
258        let index = Field::new("index", index_type, false);
259
260        Ok(vec![None, Some(vec![value, index])])
261    }
262
263    fn documentation(&self) -> Option<&Documentation> {
264        self.doc()
265    }
266}