datafusion_functions_nested/
array_transform.rs1use arrow::{
21 array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray},
22 compute::take_record_batch,
23 datatypes::{DataType, Field},
24};
25use datafusion_common::{
26 HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args}
27};
28use datafusion_expr::{
29 ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args
30};
31use datafusion_macros::user_doc;
32use std::{any::Any, sync::Arc};
33
34make_udf_expr_and_func!(
35 ArrayTransform,
36 array_transform,
37 array lambda,
38 "transforms the values of a array",
39 array_transform_udf
40);
41
42#[user_doc(
43 doc_section(label = "Array Functions"),
44 description = "transforms the values of a array",
45 syntax_example = "array_transform(array, x -> x*2)",
46 sql_example = r#"```sql
47> select array_transform([1, 2, 3, 4, 5], x -> x*2);
48+-------------------------------------------+
49| array_transform([1, 2, 3, 4, 5], x -> x*2) |
50+-------------------------------------------+
51| [2, 4, 6, 8, 10] |
52+-------------------------------------------+
53```"#,
54 argument(
55 name = "array",
56 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
57 ),
58 argument(name = "lambda", description = "Lambda")
59)]
60#[derive(Debug, PartialEq, Eq, Hash)]
61pub struct ArrayTransform {
62 signature: Signature,
63 aliases: Vec<String>,
64}
65
66impl Default for ArrayTransform {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl ArrayTransform {
73 pub fn new() -> Self {
74 Self {
75 signature: Signature::any(2, Volatility::Immutable),
76 aliases: vec![String::from("list_transform")],
77 }
78 }
79}
80
81impl ScalarUDFImpl for ArrayTransform {
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "array_transform"
88 }
89
90 fn aliases(&self) -> &[String] {
91 &self.aliases
92 }
93
94 fn signature(&self) -> &Signature {
95 &self.signature
96 }
97
98 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
99 internal_err!("return_type called instead of return_field_from_args")
100 }
101
102 fn return_field_from_args(
103 &self,
104 args: datafusion_expr::ReturnFieldArgs,
105 ) -> Result<Arc<Field>> {
106 let args = args.to_lambda_args();
107
108 let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] =
109 take_function_args(self.name(), &args)?
110 else {
111 return exec_err!(
112 "{} expects a value follewed by a lambda, got {:?}",
113 self.name(),
114 args
115 );
116 };
117
118 let field = Arc::new(Field::new(
123 Field::LIST_FIELD_DEFAULT_NAME,
124 lambda.data_type().clone(),
125 lambda.is_nullable(),
126 ));
127
128 let return_type = match list.data_type() {
129 DataType::List(_) => DataType::List(field),
130 DataType::LargeList(_) => DataType::LargeList(field),
131 DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size),
132 _ => unreachable!(),
133 };
134
135 Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
136 }
137
138 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
139 let lambda_args = args.to_lambda_args();
141 let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?;
142
143 let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) =
144 (list_value, lambda)
145 else {
146 return exec_err!(
147 "{} expects a value followed by a lambda, got {:?}",
148 self.name(),
149 &lambda_args
150 );
151 };
152
153 let list_array = list_value.to_array(args.number_rows)?;
154
155 let adjusted_captures = lambda
159 .captures
160 .as_ref()
161 .map(|captures| take_record_batch(captures, &list_indices(&list_array)?))
162 .transpose()?;
163
164 let values_param = || Ok(Arc::clone(list_values(&list_array)?));
167 let indices_param = || elements_indices(&list_array);
168
169 let lambda_batch = merge_captures_with_lazy_args(
173 adjusted_captures.as_ref(),
174 &lambda.params, &[&values_param, &indices_param],
176 )?;
177
178 let transformed_values = lambda
180 .body
181 .evaluate(&lambda_batch)?
182 .into_array(lambda_batch.num_rows())?;
183
184 let field = match args.return_field.data_type() {
185 DataType::List(field)
186 | DataType::LargeList(field)
187 | DataType::FixedSizeList(field, _) => Arc::clone(field),
188 _ => {
189 return exec_err!(
190 "{} expected ScalarFunctionArgs.return_field to be a list, got {}",
191 self.name(),
192 args.return_field
193 )
194 }
195 };
196
197 let transformed_list = match list_array.data_type() {
198 DataType::List(_) => {
199 let list = list_array.as_list();
200
201 Arc::new(ListArray::new(
202 field,
203 list.offsets().clone(),
204 transformed_values,
205 list.nulls().cloned(),
206 )) as ArrayRef
207 }
208 DataType::LargeList(_) => {
209 let large_list = list_array.as_list();
210
211 Arc::new(LargeListArray::new(
212 field,
213 large_list.offsets().clone(),
214 transformed_values,
215 large_list.nulls().cloned(),
216 ))
217 }
218 DataType::FixedSizeList(_, value_length) => {
219 Arc::new(FixedSizeListArray::new(
220 field,
221 *value_length,
222 transformed_values,
223 list_array.as_fixed_size_list().nulls().cloned(),
224 ))
225 }
226 other => exec_err!("expected list, got {other}")?,
227 };
228
229 Ok(ColumnarValue::Array(transformed_list))
230 }
231
232 fn lambdas_parameters(
233 &self,
234 args: &[ValueOrLambdaParameter],
235 ) -> Result<Vec<Option<Vec<Field>>>> {
236 let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] =
237 args
238 else {
239 return exec_err!(
240 "{} expects a value follewed by a lambda, got {:?}",
241 self.name(),
242 args
243 );
244 };
245
246 let (field, index_type) = match list.data_type() {
247 DataType::List(field) => (field, DataType::Int32),
248 DataType::LargeList(field) => (field, DataType::Int64),
249 DataType::FixedSizeList(field, _) => (field, DataType::Int32),
250 _ => return exec_err!("expected list, got {list}"),
251 };
252
253 let value = Field::new("value", field.data_type().clone(), field.is_nullable())
257 .with_metadata(field.metadata().clone());
258 let index = Field::new("index", index_type, false);
259
260 Ok(vec![None, Some(vec![value, index])])
261 }
262
263 fn documentation(&self) -> Option<&Documentation> {
264 self.doc()
265 }
266}