datafusion_functions_nested/
map_values.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for map_values function.
19
20use crate::utils::{get_map_entry_field, make_scalar_function};
21use arrow::array::{Array, ArrayRef, ListArray};
22use arrow::datatypes::{DataType, Field, FieldRef};
23use datafusion_common::utils::take_function_args;
24use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result};
25use datafusion_expr::{
26    ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
27    TypeSignature, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::any::Any;
31use std::ops::Deref;
32use std::sync::Arc;
33
// Registers the public entry points for `MapValuesFunc`. Presumably (confirm
// against the macro definition in this crate) this expands to a
// `map_values(map: Expr) -> Expr` expression builder documented with the string
// below, plus a `map_values_udf()` accessor for the shared UDF instance.
make_udf_expr_and_func!(
    MapValuesFunc,
    map_values,
    map,
    "Return a list of all values in the map.",
    map_values_udf
);
41
#[user_doc(
    doc_section(label = "Map Functions"),
    description = "Returns a list of all values in the map.",
    syntax_example = "map_values(map)",
    sql_example = r#"```sql
SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3});
----
[1, , 3]

SELECT map_values(map([100, 5], [42, 43]));
----
[42, 43]
```"#,
    argument(
        name = "map",
        description = "Map expression. Can be a constant, column, or function, and any combination of map operators."
    )
)]
/// Scalar UDF behind the SQL `map_values(map)` function: returns the values of
/// a `Map` argument as a `List`, one list per input row.
///
/// A `NULL` map produces a `NULL` list, and `NULL` values stored inside a map
/// are kept as `NULL` list items (see the `[1, , 3]` example above).
#[derive(Debug, PartialEq, Eq, Hash)]
pub(crate) struct MapValuesFunc {
    /// Accepts a single map argument; the function is `Immutable`.
    signature: Signature,
}
64
65impl Default for MapValuesFunc {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl MapValuesFunc {
72    pub fn new() -> Self {
73        Self {
74            signature: Signature::new(
75                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
76                Volatility::Immutable,
77            ),
78        }
79    }
80}
81
82impl ScalarUDFImpl for MapValuesFunc {
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn name(&self) -> &str {
88        "map_values"
89    }
90
91    fn signature(&self) -> &Signature {
92        &self.signature
93    }
94
95    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
96        internal_err!("return_field_from_args should be used instead")
97    }
98
99    fn return_field_from_args(
100        &self,
101        args: datafusion_expr::ReturnFieldArgs,
102    ) -> Result<FieldRef> {
103        let [map_type] = take_function_args(self.name(), args.arg_fields)?;
104
105        Ok(Field::new(
106            self.name(),
107            DataType::List(get_map_values_field_as_list_field(map_type.data_type())?),
108            // Nullable if the map is nullable
109            args.arg_fields.iter().any(|x| x.is_nullable()),
110        )
111        .into())
112    }
113
114    fn invoke_with_args(
115        &self,
116        args: datafusion_expr::ScalarFunctionArgs,
117    ) -> Result<ColumnarValue> {
118        make_scalar_function(map_values_inner)(&args.args)
119    }
120
121    fn documentation(&self) -> Option<&Documentation> {
122        self.doc()
123    }
124}
125
126fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
127    let [map_arg] = take_function_args("map_values", args)?;
128
129    let map_array = match map_arg.data_type() {
130        DataType::Map(_, _) => as_map_array(&map_arg)?,
131        _ => return exec_err!("Argument for map_values should be a map"),
132    };
133
134    Ok(Arc::new(ListArray::new(
135        get_map_values_field_as_list_field(map_arg.data_type())?,
136        map_array.offsets().clone(),
137        Arc::clone(map_array.values()),
138        map_array.nulls().cloned(),
139    )))
140}
141
142fn get_map_values_field_as_list_field(map_type: &DataType) -> Result<FieldRef> {
143    let map_fields = get_map_entry_field(map_type)?;
144
145    let values_field = map_fields
146        .last()
147        .unwrap()
148        .deref()
149        .clone()
150        .with_name(Field::LIST_FIELD_DEFAULT_NAME);
151
152    Ok(Arc::new(values_field))
153}
154
#[cfg(test)]
mod tests {
    use crate::map_values::MapValuesFunc;
    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::ScalarValue;
    use datafusion_expr::ScalarUDFImpl;
    use std::sync::Arc;

    /// `Map<Utf8, LargeUtf8>` argument field with the requested nullability of
    /// the map itself, its keys and its values (keys not sorted).
    fn map_field(map_nullable: bool, keys_nullable: bool, values_nullable: bool) -> FieldRef {
        let keys = Arc::new(Field::new("keys", DataType::Utf8, keys_nullable));
        let values = Arc::new(Field::new("values", DataType::LargeUtf8, values_nullable));
        Arc::new(Field::new_map(
            "something",
            "entries",
            keys,
            values,
            false,
            map_nullable,
        ))
    }

    /// Expected output field: a `List` named `name` whose item field uses the
    /// default list item name.
    fn list_field(
        name: &str,
        list_nullable: bool,
        item_type: DataType,
        items_nullable: bool,
    ) -> FieldRef {
        let item = Arc::new(Field::new_list_field(item_type, items_nullable));
        Arc::new(Field::new_list(name, item, list_nullable))
    }

    /// Runs `return_field_from_args` of a fresh `MapValuesFunc` on one argument.
    fn return_field(arg: FieldRef) -> FieldRef {
        let udf = MapValuesFunc::new();
        udf.return_field_from_args(datafusion_expr::ReturnFieldArgs {
            arg_fields: &[arg],
            scalar_arguments: &[None::<&ScalarValue>],
            lambdas: &[false],
        })
        .unwrap()
    }

    #[test]
    fn return_type_field() {
        // Test cases:
        //
        // |                      Input Map                         ||                   Expected Output                     |
        // | ------------------------------------------------------ || ----------------------------------------------------- |
        // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable |
        // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- |
        // | false        | false             | false               || false                  | false                        |
        // | false        | false             | true                || false                  | true                         |
        // | false        | true              | false               || false                  | false                        |
        // | false        | true              | true                || false                  | true                         |
        // | true         | false             | false               || true                   | false                        |
        // | true         | false             | true                || true                   | true                         |
        // | true         | true              | false               || true                   | false                        |
        // | true         | true              | true                || true                   | true                         |
        //
        // ---------------
        // Key nullability is varied to show that it affects neither the list
        // nor the list items: the list follows the map, the items follow the
        // map values.
        for map_nullable in [false, true] {
            for keys_nullable in [false, true] {
                for values_nullable in [false, true] {
                    let actual =
                        return_field(map_field(map_nullable, keys_nullable, values_nullable));
                    let expected = list_field(
                        "map_values",
                        map_nullable,
                        DataType::LargeUtf8,
                        values_nullable,
                    );
                    assert_eq!(
                        actual, expected,
                        "map nullable: {map_nullable}, keys nullable: {keys_nullable}, values nullable: {values_nullable}"
                    );
                }
            }
        }
    }
}