datafusion_functions/core/
nvl2.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::datatypes::{DataType, Field, FieldRef};
19use datafusion_common::{internal_err, utils::take_function_args, Result};
20use datafusion_expr::{
21    conditional_expressions::CaseBuilder,
22    simplify::{ExprSimplifyResult, SimplifyInfo},
23    type_coercion::binary::comparison_coercion,
24    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
25    ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_macros::user_doc;
28
29#[user_doc(
30    doc_section(label = "Conditional Functions"),
31    description = "Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.",
32    syntax_example = "nvl2(expression1, expression2, expression3)",
33    sql_example = r#"```sql
34> select nvl2(null, 'a', 'b');
35+--------------------------------+
36| nvl2(NULL,Utf8("a"),Utf8("b")) |
37+--------------------------------+
38| b                              |
39+--------------------------------+
40> select nvl2('data', 'a', 'b');
41+----------------------------------------+
42| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) |
43+----------------------------------------+
44| a                                      |
45+----------------------------------------+
46```
47"#,
48    argument(
49        name = "expression1",
50        description = "Expression to test for null. Can be a constant, column, or function, and any combination of operators."
51    ),
52    argument(
53        name = "expression2",
54        description = "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators."
55    ),
56    argument(
57        name = "expression3",
58        description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
59    )
60)]
61#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct NVL2Func {
63    signature: Signature,
64}
65
66impl Default for NVL2Func {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl NVL2Func {
73    pub fn new() -> Self {
74        Self {
75            signature: Signature::user_defined(Volatility::Immutable),
76        }
77    }
78}
79
80impl ScalarUDFImpl for NVL2Func {
81    fn as_any(&self) -> &dyn std::any::Any {
82        self
83    }
84
85    fn name(&self) -> &str {
86        "nvl2"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94        Ok(arg_types[1].clone())
95    }
96
97    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
98        let nullable =
99            args.arg_fields[1].is_nullable() || args.arg_fields[2].is_nullable();
100        let return_type = args.arg_fields[1].data_type().clone();
101        Ok(Field::new(self.name(), return_type, nullable).into())
102    }
103
104    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105        internal_err!("nvl2 should have been simplified to case")
106    }
107
108    fn simplify(
109        &self,
110        args: Vec<Expr>,
111        _info: &dyn SimplifyInfo,
112    ) -> Result<ExprSimplifyResult> {
113        let [test, if_non_null, if_null] = take_function_args(self.name(), args)?;
114
115        let expr = CaseBuilder::new(
116            None,
117            vec![test.is_not_null()],
118            vec![if_non_null],
119            Some(Box::new(if_null)),
120        )
121        .end()?;
122
123        Ok(ExprSimplifyResult::Simplified(expr))
124    }
125
126    fn short_circuits(&self) -> bool {
127        true
128    }
129
130    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131        let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
132        let new_type =
133            [if_non_null, if_null]
134                .iter()
135                .try_fold(tested.clone(), |acc, x| {
136                    // The coerced types found by `comparison_coercion` are not guaranteed to be
137                    // coercible for the arguments. `comparison_coercion` returns more loose
138                    // types that can be coerced to both `acc` and `x` for comparison purpose.
139                    // See `maybe_data_types` for the actual coercion.
140                    let coerced_type = comparison_coercion(&acc, x);
141                    if let Some(coerced_type) = coerced_type {
142                        Ok(coerced_type)
143                    } else {
144                        internal_err!("Coercion from {acc} to {x} failed.")
145                    }
146                })?;
147        Ok(vec![new_type; arg_types.len()])
148    }
149
150    fn documentation(&self) -> Option<&Documentation> {
151        self.doc()
152    }
153}