datafusion_optimizer/simplify_expressions/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions for expression simplification
19
20use arrow::datatypes::i256;
21use datafusion_common::{internal_err, Result, ScalarValue};
22use datafusion_expr::{
23    expr::{Between, BinaryExpr, InList},
24    expr_fn::{and, bitwise_and, bitwise_or, or},
25    Case, Expr, Like, Operator,
26};
27
/// Lookup table of powers of ten: entry `i` holds `10^i` as an `i128`,
/// for `i` in `0..38`.
///
/// The table is filled in at compile time; `10^37` is the last entry.
pub static POWS_OF_TEN: [i128; 38] = {
    let mut table = [1_i128; 38];
    let mut exp = 1;
    // Each entry is ten times the previous one; entry 0 is already 1.
    while exp < table.len() {
        table[exp] = table[exp - 1] * 10;
        exp += 1;
    }
    table
};
68
69/// returns true if `needle` is found in a chain of search_op
70/// expressions. Such as: (A AND B) AND C
71fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
72    match expr {
73        Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
74            expr_contains_inner(left, needle, search_op)
75                || expr_contains_inner(right, needle, search_op)
76        }
77        _ => expr == needle,
78    }
79}
80
81/// check volatile calls and return if expr contains needle
82pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
83    expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
84}
85
86/// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor
87/// expressions. Such as: A ^ (A ^ (B ^ A))
88pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
89    /// Deletes recursively 'needles' in a chain of xor expressions
90    fn recursive_delete_xor_in_expr(
91        expr: &Expr,
92        needle: &Expr,
93        xor_counter: &mut i32,
94    ) -> Expr {
95        match expr {
96            Expr::BinaryExpr(BinaryExpr { left, op, right })
97                if *op == Operator::BitwiseXor =>
98            {
99                let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter);
100                let right_expr = recursive_delete_xor_in_expr(right, needle, xor_counter);
101                if left_expr == *needle {
102                    *xor_counter += 1;
103                    return right_expr;
104                } else if right_expr == *needle {
105                    *xor_counter += 1;
106                    return left_expr;
107                }
108
109                Expr::BinaryExpr(BinaryExpr::new(
110                    Box::new(left_expr),
111                    *op,
112                    Box::new(right_expr),
113                ))
114            }
115            _ => expr.clone(),
116        }
117    }
118
119    let mut xor_counter: i32 = 0;
120    let result_expr = recursive_delete_xor_in_expr(expr, needle, &mut xor_counter);
121    if result_expr == *needle {
122        return needle.clone();
123    } else if xor_counter % 2 == 0 {
124        if is_left {
125            return Expr::BinaryExpr(BinaryExpr::new(
126                Box::new(needle.clone()),
127                Operator::BitwiseXor,
128                Box::new(result_expr),
129            ));
130        } else {
131            return Expr::BinaryExpr(BinaryExpr::new(
132                Box::new(result_expr),
133                Operator::BitwiseXor,
134                Box::new(needle.clone()),
135            ));
136        }
137    }
138    result_expr
139}
140
141pub fn is_zero(s: &Expr) -> bool {
142    match s {
143        Expr::Literal(ScalarValue::Int8(Some(0)), _)
144        | Expr::Literal(ScalarValue::Int16(Some(0)), _)
145        | Expr::Literal(ScalarValue::Int32(Some(0)), _)
146        | Expr::Literal(ScalarValue::Int64(Some(0)), _)
147        | Expr::Literal(ScalarValue::UInt8(Some(0)), _)
148        | Expr::Literal(ScalarValue::UInt16(Some(0)), _)
149        | Expr::Literal(ScalarValue::UInt32(Some(0)), _)
150        | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true,
151        Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true,
152        Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true,
153        Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true,
154        Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _)
155            if *v == i256::ZERO =>
156        {
157            true
158        }
159        _ => false,
160    }
161}
162
163pub fn is_one(s: &Expr) -> bool {
164    match s {
165        Expr::Literal(ScalarValue::Int8(Some(1)), _)
166        | Expr::Literal(ScalarValue::Int16(Some(1)), _)
167        | Expr::Literal(ScalarValue::Int32(Some(1)), _)
168        | Expr::Literal(ScalarValue::Int64(Some(1)), _)
169        | Expr::Literal(ScalarValue::UInt8(Some(1)), _)
170        | Expr::Literal(ScalarValue::UInt16(Some(1)), _)
171        | Expr::Literal(ScalarValue::UInt32(Some(1)), _)
172        | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true,
173        Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true,
174        Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true,
175        Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => {
176            *s >= 0
177                && POWS_OF_TEN
178                    .get(*s as usize)
179                    .map(|x| x == v)
180                    .unwrap_or_default()
181        }
182        Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => {
183            *s >= 0
184                && match i256::from(10).checked_pow(*s as u32) {
185                    Some(res) => res == *v,
186                    None => false,
187                }
188        }
189        _ => false,
190    }
191}
192
193pub fn is_true(expr: &Expr) -> bool {
194    match expr {
195        Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v,
196        _ => false,
197    }
198}
199
200/// returns true if expr is a
201/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise
202pub fn is_bool_lit(expr: &Expr) -> bool {
203    matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _))
204}
205
206/// Return a literal NULL value of Boolean data type
207pub fn lit_bool_null() -> Expr {
208    Expr::Literal(ScalarValue::Boolean(None), None)
209}
210
211pub fn is_null(expr: &Expr) -> bool {
212    match expr {
213        Expr::Literal(v, _) => v.is_null(),
214        _ => false,
215    }
216}
217
218pub fn is_false(expr: &Expr) -> bool {
219    match expr {
220        Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v),
221        _ => false,
222    }
223}
224
225/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
226pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
227    matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile())
228}
229
230pub fn can_reduce_to_equal_statement(haystack: &Expr, needle: &Expr) -> bool {
231    match (haystack, needle) {
232        // a >= constant and constant <= a => a = constant
233        (
234            Expr::BinaryExpr(BinaryExpr {
235                left,
236                op: Operator::GtEq,
237                right,
238            }),
239            Expr::BinaryExpr(BinaryExpr {
240                left: n_left,
241                op: Operator::LtEq,
242                right: n_right,
243            }),
244        ) if left == n_left && right == n_right => true,
245        _ => false,
246    }
247}
248
249/// returns true if `not_expr` is !`expr` (not)
250pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
251    matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref())
252}
253
254/// returns true if `not_expr` is !`expr` (bitwise not)
255pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool {
256    matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref())
257}
258
259/// returns the contained boolean value in `expr` as
260/// `Expr::Literal(ScalarValue::Boolean(v))`.
261pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
262    match expr {
263        Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v),
264        _ => internal_err!("Expected boolean literal, got {expr:?}"),
265    }
266}
267
268pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
269    match expr {
270        Expr::Case(Case {
271            expr: None,
272            when_then_expr,
273            else_expr,
274        }) => {
275            when_then_expr.iter().all(|(_, then)| is_lit(then))
276                && else_expr.as_deref().is_none_or(is_lit)
277        }
278        _ => false,
279    }
280}
281
282pub fn into_case(expr: Expr) -> Result<Case> {
283    match expr {
284        Expr::Case(case) => Ok(case),
285        _ => internal_err!("Expected case, got {expr:?}"),
286    }
287}
288
289pub fn is_lit(expr: &Expr) -> bool {
290    matches!(expr, Expr::Literal(_, _))
291}
292
293/// negate a Not clause
294/// input is the clause to be negated.(args of Not clause)
295/// For BinaryExpr, use the negation of op instead.
296///    not ( A > B) ===> (A <= B)
297/// For BoolExpr, not (A and B) ===> (not A) or (not B)
298///     not (A or B) ===> (not A) and (not B)
299///     not (not A) ===> A
300/// For NullExpr, not (A is not null) ===> A is null
301///     not (A is null) ===> A is not null
302/// For InList, not (A not in (..)) ===> A in (..)
303///     not (A in (..)) ===> A not in (..)
304/// For Between, not (A between B and C) ===> (A not between B and C)
305///     not (A not between B and C) ===> (A between B and C)
306/// For others, use Not clause
307pub fn negate_clause(expr: Expr) -> Expr {
308    match expr {
309        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
310            if let Some(negated_op) = op.negate() {
311                return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right));
312            }
313            match op {
314                // not (A and B) ===> (not A) or (not B)
315                Operator::And => {
316                    let left = negate_clause(*left);
317                    let right = negate_clause(*right);
318
319                    or(left, right)
320                }
321                // not (A or B) ===> (not A) and (not B)
322                Operator::Or => {
323                    let left = negate_clause(*left);
324                    let right = negate_clause(*right);
325
326                    and(left, right)
327                }
328                // use not clause
329                _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new(
330                    left, op, right,
331                )))),
332            }
333        }
334        // not (not A) ===> A
335        Expr::Not(expr) => *expr,
336        // not (A is not null) ===> A is null
337        Expr::IsNotNull(expr) => expr.is_null(),
338        // not (A is null) ===> A is not null
339        Expr::IsNull(expr) => expr.is_not_null(),
340        // not (A not in (..)) ===> A in (..)
341        // not (A in (..)) ===> A not in (..)
342        Expr::InList(InList {
343            expr,
344            list,
345            negated,
346        }) => expr.in_list(list, !negated),
347        // not (A between B and C) ===> (A not between B and C)
348        // not (A not between B and C) ===> (A between B and C)
349        Expr::Between(between) => Expr::Between(Between::new(
350            between.expr,
351            !between.negated,
352            between.low,
353            between.high,
354        )),
355        // not (A like B) ===> A not like B
356        Expr::Like(like) => Expr::Like(Like::new(
357            !like.negated,
358            like.expr,
359            like.pattern,
360            like.escape_char,
361            like.case_insensitive,
362        )),
363        // use not clause
364        _ => Expr::Not(Box::new(expr)),
365    }
366}
367
368/// bitwise negate a Negative clause
369/// input is the clause to be bitwise negated.(args for Negative clause)
370/// For BinaryExpr:
371///    ~(A & B) ===> ~A | ~B
372///    ~(A | B) ===> ~A & ~B
373/// For Negative:
374///    ~(~A) ===> A
375/// For others, use Negative clause
376pub fn distribute_negation(expr: Expr) -> Expr {
377    match expr {
378        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
379            match op {
380                // ~(A & B) ===> ~A | ~B
381                Operator::BitwiseAnd => {
382                    let left = distribute_negation(*left);
383                    let right = distribute_negation(*right);
384
385                    bitwise_or(left, right)
386                }
387                // ~(A | B) ===> ~A & ~B
388                Operator::BitwiseOr => {
389                    let left = distribute_negation(*left);
390                    let right = distribute_negation(*right);
391
392                    bitwise_and(left, right)
393                }
394                // use negative clause
395                _ => Expr::Negative(Box::new(Expr::BinaryExpr(BinaryExpr::new(
396                    left, op, right,
397                )))),
398            }
399        }
400        // ~(~A) ===> A
401        Expr::Negative(expr) => *expr,
402        // use negative clause
403        _ => Expr::Negative(Box::new(expr)),
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::{is_one, is_zero};
    use arrow::datatypes::i256;
    use datafusion_common::ScalarValue;
    use datafusion_expr::lit;

    /// Zero literals are recognised whatever the numeric type or decimal scale.
    #[test]
    fn test_is_zero() {
        let zeros = [
            ScalarValue::Int8(Some(0)),
            ScalarValue::Float32(Some(0.0)),
            ScalarValue::Decimal128(Some(0), 9, 0),
            ScalarValue::Decimal128(Some(0), 9, 5),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 0),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 5),
        ];
        for value in zeros {
            let expr = lit(value);
            assert!(is_zero(&expr));
        }
    }

    /// One is recognised for plain numerics and for decimals whose stored
    /// integer equals `10^scale`; a negative scale never counts as one.
    #[test]
    fn test_is_one() {
        let ones = [
            ScalarValue::Int8(Some(1)),
            ScalarValue::Float32(Some(1.0)),
            ScalarValue::Decimal128(Some(1), 9, 0),
            ScalarValue::Decimal128(Some(10), 9, 1),
            ScalarValue::Decimal128(Some(100), 9, 2),
            ScalarValue::Decimal256(Some(i256::from(1)), 9, 0),
            ScalarValue::Decimal256(Some(i256::from(10)), 9, 1),
            ScalarValue::Decimal256(Some(i256::from(100)), 9, 2),
        ];
        for value in ones {
            let expr = lit(value);
            assert!(is_one(&expr));
        }

        let negative_scale = lit(ScalarValue::Decimal256(Some(i256::from(100)), 9, -1));
        assert!(!is_one(&negative_scale));
    }
}