1use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::DataType;
21use arrow::error::ArrowError;
22use datafusion_common::{DataFusionError, Result, ScalarValue};
23use datafusion_expr::function::Hint;
24use datafusion_expr::ColumnarValue;
25use std::sync::Arc;
26
/// Generates a `pub(crate)` function named `$FUNC` that maps the data type of
/// a string-like argument to a return type.
///
/// * `$largeUtf8Type` is returned for `LargeUtf8` / `LargeBinary` inputs.
/// * `$utf8Type` is returned for `Utf8` / `Binary` inputs and, at the top
///   level only, for `Utf8View` / `BinaryView` inputs.
/// * `Null` maps to `Null`.
/// * A `Dictionary` is resolved through its value type.
///
/// Every other type yields an execution error that names the (upper-cased)
/// function and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            // Look through one level of dictionary encoding. View types are
            // accepted as plain arguments but not as dictionary values.
            let (string_type, views_allowed) = match arg_type {
                DataType::Dictionary(_, value_type) => (value_type.as_ref(), false),
                plain => (plain, true),
            };

            let resolved = match string_type {
                DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                DataType::Utf8 | DataType::Binary => $utf8Type,
                DataType::Utf8View | DataType::BinaryView if views_allowed => $utf8Type,
                DataType::Null => DataType::Null,
                unsupported => {
                    return datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        unsupported
                    );
                }
            };
            Ok(resolved)
        }
    };
}
70
// `utf8_to_str_type`: yields `LargeUtf8` for large string/binary inputs and
// `Utf8` for the regular-sized (and view) ones.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: yields `Int64` for large string/binary inputs and
// `Int32` for the regular-sized (and view) ones.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
76
77pub fn make_scalar_function<F>(
81 inner: F,
82 hints: Vec<Hint>,
83) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
84where
85 F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
86{
87 move |args: &[ColumnarValue]| {
88 let len = args
91 .iter()
92 .fold(Option::<usize>::None, |acc, arg| match arg {
93 ColumnarValue::Scalar(_) => acc,
94 ColumnarValue::Array(a) => Some(a.len()),
95 });
96
97 let is_scalar = len.is_none();
98
99 let inferred_length = len.unwrap_or(1);
100 let args = args
101 .iter()
102 .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
103 .map(|(arg, hint)| {
104 let expansion_len = match hint {
107 Hint::AcceptsSingular => 1,
108 Hint::Pad => inferred_length,
109 };
110 arg.to_array(expansion_len)
111 })
112 .collect::<Result<Vec<_>>>()?;
113
114 let result = (inner)(&args);
115 if is_scalar {
116 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
118 result.map(ColumnarValue::Scalar)
119 } else {
120 result.map(ColumnarValue::Array)
121 }
122 }
123}
124
125pub fn calculate_binary_math<L, R, O, F>(
132 left: &dyn Array,
133 right: &ColumnarValue,
134 fun: F,
135) -> Result<Arc<PrimitiveArray<O>>>
136where
137 R: ArrowPrimitiveType,
138 L: ArrowPrimitiveType,
139 O: ArrowPrimitiveType,
140 F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
141 R::Native: TryFrom<ScalarValue>,
142{
143 Ok(match right {
144 ColumnarValue::Scalar(scalar) => {
145 let right_value: R::Native =
146 R::Native::try_from(scalar.clone()).map_err(|_| {
147 DataFusionError::NotImplemented(format!(
148 "Cannot convert scalar value {} to {}",
149 &scalar,
150 R::DATA_TYPE
151 ))
152 })?;
153 let left_array = left.as_primitive::<L>();
154 let result =
156 left_array.try_unary::<_, O, _>(|lvalue| fun(lvalue, right_value))?;
157 Arc::new(result) as _
158 }
159 ColumnarValue::Array(right) => {
160 let right_casted = arrow::compute::cast(&right, &R::DATA_TYPE)?;
161 let right_array = right_casted.as_primitive::<R>();
162
163 let result = if PrimitiveArray::<L>::is_compatible(&L::DATA_TYPE) {
165 let left_array = left.as_primitive::<L>();
166 try_binary::<_, _, _, O>(left_array, right_array, &fun)?
167 } else {
168 let left_casted = arrow::compute::cast(left, &L::DATA_TYPE)?;
169 let left_array = left_casted.as_primitive::<L>();
170 try_binary::<_, _, _, O>(left_array, right_array, &fun)?
171 };
172 Arc::new(result) as _
173 }
174 })
175}
176
177pub fn decimal128_to_i128(value: i128, scale: i8) -> Result<i128, ArrowError> {
179 if scale < 0 {
180 Err(ArrowError::ComputeError(
181 "Negative scale is not supported".into(),
182 ))
183 } else if scale == 0 {
184 Ok(value)
185 } else {
186 match i128::from(10).checked_pow(scale as u32) {
187 Some(divisor) => Ok(value / divisor),
188 None => Err(ArrowError::ComputeError(format!(
189 "Cannot get a power of {scale}"
190 ))),
191 }
192 }
193}
194
195#[cfg(test)]
196pub mod test {
/// Invokes the scalar function `$FUNC` with `$ARGS` and checks the outcome
/// against `$EXPECTED`.
///
/// * `$FUNC` - the function under test; `return_field_from_args` and
///   `invoke_with_args` are called on it.
/// * `$ARGS` - the `ColumnarValue` arguments.
/// * `$EXPECTED` - a `Result<Option<$EXPECTED_TYPE>>`: `Ok(Some(v))` expects
///   value `v` in row 0, `Ok(None)` expects a null in row 0, and `Err(e)`
///   expects the call to fail with an error related to `e`.
/// * `$EXPECTED_DATA_TYPE` - the data type expected for both the return field
///   and the produced array.
/// * `$ARRAY_TYPE` - the concrete array type the result is downcast to.
/// * `$CONFIG_OPTIONS` - optional; the second arm supplies default
///   `ConfigOptions` when it is omitted.
macro_rules! test_function {
    ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident, $CONFIG_OPTIONS:expr) => {
        let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
        let func = $FUNC;

        let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
        // Row count: length of the last array argument, or 1 if all are scalars.
        let cardinality = $ARGS
            .iter()
            .fold(Option::<usize>::None, |acc, arg| match arg {
                ColumnarValue::Scalar(_) => acc,
                ColumnarValue::Array(a) => Some(a.len()),
            })
            .unwrap_or(1);

        // Scalar arguments are passed to return-field resolution by reference;
        // array arguments are represented as `None`.
        let scalar_arguments = $ARGS.iter().map(|arg| match arg {
            ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
            ColumnarValue::Array(_) => None,
        }).collect::<Vec<_>>();
        let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

        // An argument counts as nullable when it is a null scalar or an array
        // that contains at least one null.
        let nullables = $ARGS.iter().map(|arg| match arg {
            ColumnarValue::Scalar(scalar) => scalar.is_null(),
            ColumnarValue::Array(a) => a.null_count() > 0,
        }).collect::<Vec<_>>();

        let field_array = data_array.into_iter().zip(nullables).enumerate()
            .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable))
            .map(std::sync::Arc::new)
            .collect::<Vec<_>>();

        let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
            arg_fields: &field_array,
            scalar_arguments: &scalar_arguments_refs,
            lambdas: &vec![false; scalar_arguments_refs.len()],
        });
        // Fields handed to the invocation itself; these are always nullable.
        let arg_fields = $ARGS.iter()
            .enumerate()
            .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into())
            .collect::<Vec<_>>();

        match expected {
            Ok(expected) => {
                // The return field must resolve and carry the expected type.
                assert_eq!(return_field.is_ok(), true);
                let return_field = return_field.unwrap();
                let return_type = return_field.data_type();
                assert_eq!(return_type, &$EXPECTED_DATA_TYPE);

                let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{
                    args: $ARGS,
                    arg_fields,
                    number_rows: cardinality,
                    return_field,
                    lambdas: None,
                    config_options: $CONFIG_OPTIONS
                });
                assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());

                let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                // Only row 0 of the output is compared with the expectation.
                match expected {
                    Some(v) => assert_eq!(result.value(0), v),
                    None => assert!(result.is_null(0)),
                };
            }
            Err(expected_error) => {
                if let Ok(return_field) = return_field {
                    // The return field resolved, so the invocation must fail.
                    match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs {
                        args: $ARGS,
                        arg_fields,
                        number_rows: cardinality,
                        return_field,
                        lambdas: None,
                        config_options: $CONFIG_OPTIONS,
                    }) {
                        Ok(_) => assert!(false, "expected error"),
                        Err(error) => {
                            // The expected message must start with the actual
                            // message (both with backtraces stripped).
                            assert!(expected_error
                                .strip_backtrace()
                                .starts_with(&error.strip_backtrace()));
                        }
                    }
                } else if let Err(error) = return_field {
                    // Return-field resolution itself failed; compare that
                    // error with the expected one instead.
                    datafusion_common::assert_contains!(
                        expected_error.strip_backtrace(),
                        error.strip_backtrace()
                    );
                }
            }
        };
    };

    // Same as above, using default `ConfigOptions`.
    ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
        test_function!(
            $FUNC,
            $ARGS,
            $EXPECTED,
            $EXPECTED_TYPE,
            $EXPECTED_DATA_TYPE,
            $ARRAY_TYPE,
            std::sync::Arc::new(datafusion_common::config::ConfigOptions::default())
        )
    };
}
311
312 use arrow::datatypes::DataType;
313 #[allow(unused_imports)]
314 pub(crate) use test_function;
315
316 use super::*;
317
#[test]
fn string_to_int_type() {
    // (input string type, expected integer type)
    let cases = [
        (DataType::Utf8, DataType::Int32),
        (DataType::Utf8View, DataType::Int32),
        (DataType::LargeUtf8, DataType::Int64),
    ];

    for (input, expected) in cases {
        let actual = utf8_to_int_type(&input, "test").unwrap();
        assert_eq!(actual, expected);
    }
}
329
#[test]
fn test_decimal128_to_i128() {
    // (value, scale, expected); `None` marks inputs that must be rejected.
    let cases: [(i128, i8, Option<i128>); 8] = [
        (123, 0, Some(123)),
        (1230, 1, Some(123)),
        (123000, 3, Some(123)),
        (1, 0, Some(1)),
        (123, -3, None),
        (123, i8::MAX, None),
        (i128::MAX, 0, Some(i128::MAX)),
        (i128::MAX, 3, Some(i128::MAX / 1000)),
    ];

    for (value, scale, expected) in cases {
        if let Ok(actual) = decimal128_to_i128(value, scale) {
            assert_eq!(
                actual,
                expected.expect("Got value but expected none"),
                "{value} and {scale} vs {expected:?}"
            );
        } else {
            assert!(expected.is_none());
        }
    }
}
356}