1use arrow::datatypes::i256;
21use datafusion_common::{internal_err, Result, ScalarValue};
22use datafusion_expr::{
23 expr::{Between, BinaryExpr, InList},
24 expr_fn::{and, bitwise_and, bitwise_or, or},
25 Case, Expr, Like, Operator,
26};
27
/// Lookup table of powers of ten: entry `i` holds `10^i` for `i` in `0..38`.
///
/// The table is filled in at compile time so that the values cannot drift
/// from the intended sequence through a mistyped literal.
pub static POWS_OF_TEN: [i128; 38] = {
    let mut table = [1_i128; 38];
    let mut exponent = 1;
    // `for` loops are not available in const evaluation, hence the `while`.
    while exponent < 38 {
        table[exponent] = table[exponent - 1] * 10;
        exponent += 1;
    }
    table
};
68
69fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
72 match expr {
73 Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
74 expr_contains_inner(left, needle, search_op)
75 || expr_contains_inner(right, needle, search_op)
76 }
77 _ => expr == needle,
78 }
79}
80
81pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
83 expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
84}
85
86pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
89 fn recursive_delete_xor_in_expr(
91 expr: &Expr,
92 needle: &Expr,
93 xor_counter: &mut i32,
94 ) -> Expr {
95 match expr {
96 Expr::BinaryExpr(BinaryExpr { left, op, right })
97 if *op == Operator::BitwiseXor =>
98 {
99 let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter);
100 let right_expr = recursive_delete_xor_in_expr(right, needle, xor_counter);
101 if left_expr == *needle {
102 *xor_counter += 1;
103 return right_expr;
104 } else if right_expr == *needle {
105 *xor_counter += 1;
106 return left_expr;
107 }
108
109 Expr::BinaryExpr(BinaryExpr::new(
110 Box::new(left_expr),
111 *op,
112 Box::new(right_expr),
113 ))
114 }
115 _ => expr.clone(),
116 }
117 }
118
119 let mut xor_counter: i32 = 0;
120 let result_expr = recursive_delete_xor_in_expr(expr, needle, &mut xor_counter);
121 if result_expr == *needle {
122 return needle.clone();
123 } else if xor_counter % 2 == 0 {
124 if is_left {
125 return Expr::BinaryExpr(BinaryExpr::new(
126 Box::new(needle.clone()),
127 Operator::BitwiseXor,
128 Box::new(result_expr),
129 ));
130 } else {
131 return Expr::BinaryExpr(BinaryExpr::new(
132 Box::new(result_expr),
133 Operator::BitwiseXor,
134 Box::new(needle.clone()),
135 ));
136 }
137 }
138 result_expr
139}
140
141pub fn is_zero(s: &Expr) -> bool {
142 match s {
143 Expr::Literal(ScalarValue::Int8(Some(0)), _)
144 | Expr::Literal(ScalarValue::Int16(Some(0)), _)
145 | Expr::Literal(ScalarValue::Int32(Some(0)), _)
146 | Expr::Literal(ScalarValue::Int64(Some(0)), _)
147 | Expr::Literal(ScalarValue::UInt8(Some(0)), _)
148 | Expr::Literal(ScalarValue::UInt16(Some(0)), _)
149 | Expr::Literal(ScalarValue::UInt32(Some(0)), _)
150 | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true,
151 Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true,
152 Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true,
153 Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true,
154 Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _)
155 if *v == i256::ZERO =>
156 {
157 true
158 }
159 _ => false,
160 }
161}
162
163pub fn is_one(s: &Expr) -> bool {
164 match s {
165 Expr::Literal(ScalarValue::Int8(Some(1)), _)
166 | Expr::Literal(ScalarValue::Int16(Some(1)), _)
167 | Expr::Literal(ScalarValue::Int32(Some(1)), _)
168 | Expr::Literal(ScalarValue::Int64(Some(1)), _)
169 | Expr::Literal(ScalarValue::UInt8(Some(1)), _)
170 | Expr::Literal(ScalarValue::UInt16(Some(1)), _)
171 | Expr::Literal(ScalarValue::UInt32(Some(1)), _)
172 | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true,
173 Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true,
174 Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true,
175 Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => {
176 *s >= 0
177 && POWS_OF_TEN
178 .get(*s as usize)
179 .map(|x| x == v)
180 .unwrap_or_default()
181 }
182 Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => {
183 *s >= 0
184 && match i256::from(10).checked_pow(*s as u32) {
185 Some(res) => res == *v,
186 None => false,
187 }
188 }
189 _ => false,
190 }
191}
192
193pub fn is_true(expr: &Expr) -> bool {
194 match expr {
195 Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v,
196 _ => false,
197 }
198}
199
200pub fn is_bool_lit(expr: &Expr) -> bool {
203 matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _))
204}
205
206pub fn lit_bool_null() -> Expr {
208 Expr::Literal(ScalarValue::Boolean(None), None)
209}
210
211pub fn is_null(expr: &Expr) -> bool {
212 match expr {
213 Expr::Literal(v, _) => v.is_null(),
214 _ => false,
215 }
216}
217
218pub fn is_false(expr: &Expr) -> bool {
219 match expr {
220 Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v),
221 _ => false,
222 }
223}
224
225pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
227 matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile())
228}
229
230pub fn can_reduce_to_equal_statement(haystack: &Expr, needle: &Expr) -> bool {
231 match (haystack, needle) {
232 (
234 Expr::BinaryExpr(BinaryExpr {
235 left,
236 op: Operator::GtEq,
237 right,
238 }),
239 Expr::BinaryExpr(BinaryExpr {
240 left: n_left,
241 op: Operator::LtEq,
242 right: n_right,
243 }),
244 ) if left == n_left && right == n_right => true,
245 _ => false,
246 }
247}
248
249pub fn is_not_of(not_expr: &Expr, expr: &Expr) -> bool {
251 matches!(not_expr, Expr::Not(inner) if expr == inner.as_ref())
252}
253
254pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool {
256 matches!(not_expr, Expr::Negative(inner) if expr == inner.as_ref())
257}
258
259pub fn as_bool_lit(expr: &Expr) -> Result<Option<bool>> {
262 match expr {
263 Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v),
264 _ => internal_err!("Expected boolean literal, got {expr:?}"),
265 }
266}
267
268pub fn is_case_with_literal_outputs(expr: &Expr) -> bool {
269 match expr {
270 Expr::Case(Case {
271 expr: None,
272 when_then_expr,
273 else_expr,
274 }) => {
275 when_then_expr.iter().all(|(_, then)| is_lit(then))
276 && else_expr.as_deref().is_none_or(is_lit)
277 }
278 _ => false,
279 }
280}
281
282pub fn into_case(expr: Expr) -> Result<Case> {
283 match expr {
284 Expr::Case(case) => Ok(case),
285 _ => internal_err!("Expected case, got {expr:?}"),
286 }
287}
288
289pub fn is_lit(expr: &Expr) -> bool {
290 matches!(expr, Expr::Literal(_, _))
291}
292
293pub fn negate_clause(expr: Expr) -> Expr {
308 match expr {
309 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
310 if let Some(negated_op) = op.negate() {
311 return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right));
312 }
313 match op {
314 Operator::And => {
316 let left = negate_clause(*left);
317 let right = negate_clause(*right);
318
319 or(left, right)
320 }
321 Operator::Or => {
323 let left = negate_clause(*left);
324 let right = negate_clause(*right);
325
326 and(left, right)
327 }
328 _ => Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr::new(
330 left, op, right,
331 )))),
332 }
333 }
334 Expr::Not(expr) => *expr,
336 Expr::IsNotNull(expr) => expr.is_null(),
338 Expr::IsNull(expr) => expr.is_not_null(),
340 Expr::InList(InList {
343 expr,
344 list,
345 negated,
346 }) => expr.in_list(list, !negated),
347 Expr::Between(between) => Expr::Between(Between::new(
350 between.expr,
351 !between.negated,
352 between.low,
353 between.high,
354 )),
355 Expr::Like(like) => Expr::Like(Like::new(
357 !like.negated,
358 like.expr,
359 like.pattern,
360 like.escape_char,
361 like.case_insensitive,
362 )),
363 _ => Expr::Not(Box::new(expr)),
365 }
366}
367
368pub fn distribute_negation(expr: Expr) -> Expr {
377 match expr {
378 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
379 match op {
380 Operator::BitwiseAnd => {
382 let left = distribute_negation(*left);
383 let right = distribute_negation(*right);
384
385 bitwise_or(left, right)
386 }
387 Operator::BitwiseOr => {
389 let left = distribute_negation(*left);
390 let right = distribute_negation(*right);
391
392 bitwise_and(left, right)
393 }
394 _ => Expr::Negative(Box::new(Expr::BinaryExpr(BinaryExpr::new(
396 left, op, right,
397 )))),
398 }
399 }
400 Expr::Negative(expr) => *expr,
402 _ => Expr::Negative(Box::new(expr)),
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::{is_one, is_zero};
    use arrow::datatypes::i256;
    use datafusion_common::ScalarValue;
    use datafusion_expr::lit;

    /// Zero must be recognised for integers, floats, and decimals of either
    /// width regardless of scale.
    #[test]
    fn test_is_zero() {
        let zeros = [
            ScalarValue::Int8(Some(0)),
            ScalarValue::Float32(Some(0.0)),
            ScalarValue::Decimal128(Some(0_i128), 9, 0),
            ScalarValue::Decimal128(Some(0_i128), 9, 5),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 0),
            ScalarValue::Decimal256(Some(i256::ZERO), 9, 5),
        ];
        for zero in zeros {
            assert!(
                is_zero(&lit(zero.clone())),
                "{zero:?} should be treated as zero"
            );
        }
    }

    /// One must be recognised for integers, floats, and decimals whose
    /// unscaled value equals `10^scale`; a negative scale is rejected.
    #[test]
    fn test_is_one() {
        let ones = [
            ScalarValue::Int8(Some(1)),
            ScalarValue::Float32(Some(1.0)),
            ScalarValue::Decimal128(Some(1_i128), 9, 0),
            ScalarValue::Decimal128(Some(10_i128), 9, 1),
            ScalarValue::Decimal128(Some(100_i128), 9, 2),
            ScalarValue::Decimal256(Some(i256::from(1)), 9, 0),
            ScalarValue::Decimal256(Some(i256::from(10)), 9, 1),
            ScalarValue::Decimal256(Some(i256::from(100)), 9, 2),
        ];
        for one in ones {
            assert!(
                is_one(&lit(one.clone())),
                "{one:?} should be treated as one"
            );
        }

        let negative_scale = ScalarValue::Decimal256(Some(i256::from(100)), 9, -1);
        assert!(!is_one(&lit(negative_scale)));
    }
}