datafusion_functions/core/
nvl2.rs1use arrow::datatypes::{DataType, Field, FieldRef};
19use datafusion_common::{internal_err, utils::take_function_args, Result};
20use datafusion_expr::{
21 conditional_expressions::CaseBuilder,
22 simplify::{ExprSimplifyResult, SimplifyInfo},
23 type_coercion::binary::comparison_coercion,
24 ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs,
25 ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_macros::user_doc;
28
29#[user_doc(
30 doc_section(label = "Conditional Functions"),
31 description = "Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.",
32 syntax_example = "nvl2(expression1, expression2, expression3)",
33 sql_example = r#"```sql
34> select nvl2(null, 'a', 'b');
35+--------------------------------+
36| nvl2(NULL,Utf8("a"),Utf8("b")) |
37+--------------------------------+
38| b |
39+--------------------------------+
40> select nvl2('data', 'a', 'b');
41+----------------------------------------+
42| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) |
43+----------------------------------------+
44| a |
45+----------------------------------------+
46```
47"#,
48 argument(
49 name = "expression1",
50 description = "Expression to test for null. Can be a constant, column, or function, and any combination of operators."
51 ),
52 argument(
53 name = "expression2",
54 description = "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators."
55 ),
56 argument(
57 name = "expression3",
58 description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
59 )
60)]
61#[derive(Debug, PartialEq, Eq, Hash)]
62pub struct NVL2Func {
63 signature: Signature,
64}
65
66impl Default for NVL2Func {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl NVL2Func {
73 pub fn new() -> Self {
74 Self {
75 signature: Signature::user_defined(Volatility::Immutable),
76 }
77 }
78}
79
80impl ScalarUDFImpl for NVL2Func {
81 fn as_any(&self) -> &dyn std::any::Any {
82 self
83 }
84
85 fn name(&self) -> &str {
86 "nvl2"
87 }
88
89 fn signature(&self) -> &Signature {
90 &self.signature
91 }
92
93 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94 Ok(arg_types[1].clone())
95 }
96
97 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
98 let nullable =
99 args.arg_fields[1].is_nullable() || args.arg_fields[2].is_nullable();
100 let return_type = args.arg_fields[1].data_type().clone();
101 Ok(Field::new(self.name(), return_type, nullable).into())
102 }
103
104 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105 internal_err!("nvl2 should have been simplified to case")
106 }
107
108 fn simplify(
109 &self,
110 args: Vec<Expr>,
111 _info: &dyn SimplifyInfo,
112 ) -> Result<ExprSimplifyResult> {
113 let [test, if_non_null, if_null] = take_function_args(self.name(), args)?;
114
115 let expr = CaseBuilder::new(
116 None,
117 vec![test.is_not_null()],
118 vec![if_non_null],
119 Some(Box::new(if_null)),
120 )
121 .end()?;
122
123 Ok(ExprSimplifyResult::Simplified(expr))
124 }
125
126 fn short_circuits(&self) -> bool {
127 true
128 }
129
130 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
131 let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
132 let new_type =
133 [if_non_null, if_null]
134 .iter()
135 .try_fold(tested.clone(), |acc, x| {
136 let coerced_type = comparison_coercion(&acc, x);
141 if let Some(coerced_type) = coerced_type {
142 Ok(coerced_type)
143 } else {
144 internal_err!("Coercion from {acc} to {x} failed.")
145 }
146 })?;
147 Ok(vec![new_type; arg_types.len()])
148 }
149
150 fn documentation(&self) -> Option<&Documentation> {
151 self.doc()
152 }
153}