datafusion_optimizer/simplify_expressions/
unwrap_cast.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Unwrap casts in binary comparisons
19//!
20//! The functions in this module attempt to remove casts from
21//! comparisons to literals ([`ScalarValue`]s) by applying the casts
22//! to the literals if possible. It is inspired by the optimizer rule
23//! `UnwrapCastInBinaryComparison` of Spark.
24//!
25//! Removing casts often improves performance because:
26//! 1. The cast is done once (to the literal) rather than to every value
27//! 2. Can enable other optimizations such as predicate pushdown that
28//!    don't support casting
29//!
30//! The rule is applied to expressions of the following forms:
31//!
32//! 1. `cast(left_expr as data_type) comparison_op literal_expr`
33//! 2. `literal_expr comparison_op cast(left_expr as data_type)`
//! 3. `cast(left_expr as data_type) IN (literal_expr1, literal_expr2, ...)`
35//! 4. `literal_expr IN (cast(expr1) , cast(expr2), ...)`
36//!
//! If the expression matches one of the forms above, the rule checks
//! that the value of `literal` is within the range (min, max) of the
//! inner expression's data type. If it is, the literal is cast to that
//! data type and the cast is removed from the other side.
42//!
43//! # Example
44//!
//! If the DataType of c1 is INT32, given the filter
//!
//! ```text
//! cast(c1 as INT64) > INT64(10)
49//! ```
50//!
51//! This rule will remove the cast and rewrite the expression to:
52//!
53//! ```text
54//! c1 > INT32(10)
55//! ```
56
57use arrow::datatypes::DataType;
58use datafusion_common::{internal_err, tree_node::Transformed};
59use datafusion_common::{Result, ScalarValue};
60use datafusion_expr::{lit, BinaryExpr};
61use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast};
62use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type};
63
64pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
65    info: &S,
66    cast_expr: Expr,
67    literal: Expr,
68    op: Operator,
69) -> Result<Transformed<Expr>> {
70    match (cast_expr, literal) {
71        (
72            Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }),
73            Expr::Literal(lit_value, _),
74        ) => {
75            let Ok(expr_type) = info.get_data_type(&expr) else {
76                return internal_err!("Can't get the data type of the expr {:?}", &expr);
77            };
78
79            if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
80            {
81                return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
82                    left: expr,
83                    op,
84                    right: Box::new(lit(value)),
85                })));
86            };
87
88            // if the lit_value can be casted to the type of internal_left_expr
89            // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
90            let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
91                return internal_err!(
92                    "Can't cast the literal expr {:?} to type {}",
93                    &lit_value,
94                    &expr_type
95                );
96            };
97            Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
98                left: expr,
99                op,
100                right: Box::new(lit(value)),
101            })))
102        }
103        _ => internal_err!("Expect cast expr and literal"),
104    }
105}
106
107pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
108    S: SimplifyInfo,
109>(
110    info: &S,
111    expr: &Expr,
112    op: Operator,
113    literal: &Expr,
114) -> bool {
115    match (expr, literal) {
116        (
117            Expr::TryCast(TryCast {
118                expr: left_expr, ..
119            })
120            | Expr::Cast(Cast {
121                expr: left_expr, ..
122            }),
123            Expr::Literal(lit_val, _),
124        ) => {
125            let Ok(expr_type) = info.get_data_type(left_expr) else {
126                return false;
127            };
128
129            let Ok(lit_type) = info.get_data_type(literal) else {
130                return false;
131            };
132
133            if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
134                return true;
135            }
136
137            try_cast_literal_to_type(lit_val, &expr_type).is_some()
138                && is_supported_type(&expr_type)
139                && is_supported_type(&lit_type)
140        }
141        _ => false,
142    }
143}
144
145pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
146    S: SimplifyInfo,
147>(
148    info: &S,
149    expr: &Expr,
150    list: &[Expr],
151) -> bool {
152    let (Expr::TryCast(TryCast {
153        expr: left_expr, ..
154    })
155    | Expr::Cast(Cast {
156        expr: left_expr, ..
157    })) = expr
158    else {
159        return false;
160    };
161
162    let Ok(expr_type) = info.get_data_type(left_expr) else {
163        return false;
164    };
165
166    if !is_supported_type(&expr_type) {
167        return false;
168    }
169
170    for right in list {
171        let Ok(right_type) = info.get_data_type(right) else {
172            return false;
173        };
174
175        if !is_supported_type(&right_type) {
176            return false;
177        }
178
179        match right {
180            Expr::Literal(lit_val, _)
181                if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {}
182            _ => return false,
183        }
184    }
185
186    true
187}
188
189///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./
190///
191/// Specifically, rewrites
192/// ```sql
193/// cast(col) <op> <literal>
194/// ```
195///
196/// To
197///
198/// ```sql
199/// col <op> cast(<literal>)
200/// col <op> <casted_literal>
201/// ```
202fn cast_literal_to_type_with_op(
203    lit_value: &ScalarValue,
204    target_type: &DataType,
205    op: Operator,
206) -> Option<ScalarValue> {
207    match (op, lit_value) {
208        (
209            Operator::Eq | Operator::NotEq,
210            ScalarValue::Utf8(Some(_))
211            | ScalarValue::Utf8View(Some(_))
212            | ScalarValue::LargeUtf8(Some(_)),
213        ) => {
214            // Only try for integer types (TODO can we do this for other types
215            // like timestamps)?
216            use DataType::*;
217            if matches!(
218                target_type,
219                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
220            ) {
221                let casted = lit_value.cast_to(target_type).ok()?;
222                let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
223                if lit_value != &round_tripped {
224                    return None;
225                }
226                Some(casted)
227            } else {
228                None
229            }
230        }
231        _ => None,
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use crate::simplify_expressions::ExprSimplifier;
    use arrow::datatypes::{Field, TimeUnit};
    use datafusion_common::{DFSchema, DFSchemaRef};
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::simplify::SimplifyContext;
    use datafusion_expr::{cast, col, in_list, try_cast};

    #[test]
    fn test_not_unwrap_cast_comparison() {
        let schema = expr_test_schema();
        // cast(INT32(c1), INT64) > INT64(c2)
        let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2"));
        assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2);

        // INT32(c1) < INT32(16), the type is same
        let expr_lt = col("c1").lt(lit(16i32));
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
        let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
        let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // cast(c1, UTF8) = '0123': cast(cast('0123', Int32), UTF8) is '123', not '0123', so
        // '0123' should not be cast. This must use `eq`; with `lt` the expression is never
        // rewritten (see above), and the round-trip check would not be exercised.
        let expr_eq = cast(col("c1"), DataType::Utf8).eq(lit("0123"));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);

        // cast(c1, UTF8) = 'not a number', should not be able to cast to column type
        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number"));
        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);

        // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
        // not be optimized to integer comparison
        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999"));
        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
    }

    #[test]
    fn test_unwrap_cast_comparison() {
        let schema = expr_test_schema();
        // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16))
        // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16)
        let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64));
        let expected = col("c1").lt(lit(16i32));
        assert_eq!(optimize_test(expr_lt, &schema), expected);
        let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64));
        let expected = col("c1").lt(lit(16i32));
        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16)
        let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32));
        let expected = col("c2").eq(lit(16i64));
        assert_eq!(optimize_test(c2_eq_lit, &schema), expected);

        // cast(c1, INT64) < INT64(NULL) => NULL
        let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64());
        let expected = null_bool();
        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

        // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL)
        let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
        let expected = null_bool();
        assert_eq!(optimize_test(lit_lt_lit, &schema), expected);

        // cast(c1, UTF8) = '123' => c1 = 123
        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
        let expected = col("c1").eq(lit(123i32));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // cast(c1, UTF8) != '123' => c1 != 123
        let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
        let expected = col("c1").not_eq(lit(123i32));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // cast(c1, UTF8) = NULL => NULL
        let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None)));
        let expected = null_bool();
        assert_eq!(optimize_test(expr_input, &schema), expected);
    }

    #[test]
    fn test_unwrap_cast_comparison_unsigned() {
        // cast(c6, UINT64) = 0u64 => c6 = 0u32
        let schema = expr_test_schema();
        let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
        let expected = col("c6").eq(lit(0u32));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // cast(c6, UTF8) = "123" => c6 = 123
        let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
        let expected = col("c6").eq(lit(123u32));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // cast(c6, UTF8) != "123" => c6 != 123
        let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
        let expected = col("c6").not_eq(lit(123u32));
        assert_eq!(optimize_test(expr_input, &schema), expected);
    }

    #[test]
    fn test_unwrap_cast_comparison_string() {
        let schema = expr_test_schema();
        let dict = ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::from("value")),
        );

        // cast(str1 as Dictionary<Int32, Utf8>) = arrow_cast('value', 'Dictionary<Int32, Utf8>') => str1 = Utf8('value')
        let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone()));
        let expected = col("str1").eq(lit("value"));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary<Int32, Utf8>')
        let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value"));
        let expected = col("tag").eq(lit(dict.clone()));
        assert_eq!(optimize_test(expr_input, &schema), expected);

        // Verify reversed argument order
        // arrow_cast('value', 'Dictionary<Int32, Utf8>') = cast(str1 as Dictionary<Int32, Utf8>) => str1 = Utf8('value')
        let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type()));
        let expected = col("str1").eq(lit("value"));
        assert_eq!(optimize_test(expr_input, &schema), expected);
    }

    #[test]
    fn test_unwrap_cast_comparison_large_string() {
        let schema = expr_test_schema();
        // cast(largestr as Dictionary<Int32, LargeUtf8>) = arrow_cast('value', 'Dictionary<Int32, LargeUtf8>') => largestr = LargeUtf8('value')
        let dict = ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))),
        );
        let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict));
        let expected =
            col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned()))));
        assert_eq!(optimize_test(expr_input, &schema), expected);
    }

    #[test]
    fn test_not_unwrap_cast_with_decimal_comparison() {
        let schema = expr_test_schema();
        // integer to decimal: value is out of the bounds of the decimal
        // cast(c3, INT64) = INT64(100000000000000000)
        let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);

        // cast(c4, INT64) = INT64(1000) will overflow the i128
        let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);

        // decimal to decimal: value will lose the scale when convert to the target data type
        // c3 = DECIMAL(12340,20,4)
        let expr_eq =
            cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);

        // decimal to integer
        // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type
        let expr_eq =
            cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);

        // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type
        let expr_eq =
            cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2));
        assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq);
    }

    #[test]
    fn test_unwrap_cast_with_decimal_lit_comparison() {
        let schema = expr_test_schema();
        // integer to decimal
        // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2));
        let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64));
        let expected = col("c3").lt(lit_decimal(1600, 18, 2));
        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // c3 < INT64(NULL)
        let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64());
        let expected = null_bool();
        assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);

        // decimal to decimal
        // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2)
        let expr_lt =
            cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0));
        let expected = col("c3").lt(lit_decimal(12300, 18, 2));
        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2)
        let expr_lt =
            cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3));
        let expected = col("c3").lt(lit_decimal(123, 18, 2));
        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // decimal to integer
        // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123)
        let expr_lt =
            cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2));
        let expected = col("c1").lt(lit(123i32));
        assert_eq!(optimize_test(expr_lt, &schema), expected);
    }

    #[test]
    fn test_not_unwrap_list_cast_lit_comparison() {
        let schema = expr_test_schema();
        // internal left type is not supported
        // FLOAT32(C5) in ...
        let expr_lt =
            cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false);
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12))
        let expr_lt = cast(col("c1"), DataType::Float32)
            .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false);
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // cast(INT32(C1), INT64) in (INT64(12), INT64(99999999999)):
        // 99999999999 does not fit into INT32
        let expr_lt = cast(col("c1"), DataType::Int64)
            .in_list(vec![lit(12i64), lit(99999999999i64)], false);
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);

        // cast(DECIMAL(18,2)(C3), DECIMAL(12,3)) in (DECIMAL(12,12,3), DECIMAL(12,12,3), DECIMAL(128,12,3)):
        // 0.012 and 0.128 cannot be represented exactly at scale 2
        let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
            vec![
                lit_decimal(12, 12, 3),
                lit_decimal(12, 12, 3),
                lit_decimal(128, 12, 3),
            ],
            false,
        );
        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
    }

    #[test]
    fn test_unwrap_list_cast_comparison() {
        let schema = expr_test_schema();
        // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) ->
        // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78))
        let expr_lt = cast(col("c1"), DataType::Int64).in_list(
            vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)],
            false,
        );
        let expected = col("c1").in_list(
            vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)],
            false,
        );
        assert_eq!(optimize_test(expr_lt, &schema), expected);
        // cast(INT64(C2), INT32) IN (INT32(NULL),INT32(24),INT64(34),INT64(56),INT64(78)) ->
        // INT64(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78))
        let expr_lt = cast(col("c2"), DataType::Int32).in_list(
            vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)],
            false,
        );
        let expected = col("c2").in_list(
            vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)],
            false,
        );

        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // decimal test case
        // c3 is decimal(18,2)
        let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
            vec![
                lit_decimal(12000, 19, 3),
                lit_decimal(24000, 19, 3),
                lit_decimal(1280, 19, 3),
                lit_decimal(1240, 19, 3),
            ],
            false,
        );
        let expected = col("c3").in_list(
            vec![
                lit_decimal(1200, 18, 2),
                lit_decimal(2400, 18, 2),
                lit_decimal(128, 18, 2),
                lit_decimal(124, 18, 2),
            ],
            false,
        );
        assert_eq!(optimize_test(expr_lt, &schema), expected);

        // cast(INT32(12), INT64) IN (.....) =>
        // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16))
        // => true
        let expr_lt = cast(lit(12i32), DataType::Int64).in_list(
            vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)],
            false,
        );
        let expected = lit(true);
        assert_eq!(optimize_test(expr_lt, &schema), expected);
    }

    #[test]
    fn aliased() {
        let schema = expr_test_schema();
        // c1 < INT64(16) -> c1 < cast(INT32(16))
        // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16)
        let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x");
        let expected = col("c1").lt(lit(16i32)).alias("x");
        assert_eq!(optimize_test(expr_lt, &schema), expected);
    }

    #[test]
    fn nested() {
        let schema = expr_test_schema();
        // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
        // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32
        let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast(
            col("c1"),
            DataType::Int64,
        )
        .gt(lit(32i64)));
        let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32)));
        assert_eq!(optimize_test(expr_lt, &schema), expected);
    }

    #[test]
    fn test_not_support_data_type() {
        // "c6 > 0" will be cast to `cast(c6 as float) > 0`
        // but the type of c6 is uint32
        // the rewriter will not throw error and just return the original expr
        let schema = expr_test_schema();
        let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64));
        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);

        // inlist for unsupported data type
        let expr_input = in_list(
            cast(col("c6"), DataType::Float64),
            // need more literals to avoid rewriting to binary expr
            vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)],
            false,
        );
        assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
    }

    #[test]
    /// Basic integration test for unwrapping casts with different timezones
    fn test_unwrap_cast_with_timestamp_nanos() {
        let schema = expr_test_schema();
        // cast(ts_nano as Timestamp(Nanosecond, UTC)) < 1666612093000000000::Timestamp(Nanosecond, Utc))
        let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type())
            .lt(lit_timestamp_nano_utc(1666612093000000000));
        let expected =
            col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000));
        assert_eq!(optimize_test(expr_lt, &schema), expected);
    }

    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
        let props = ExecutionProps::new();
        let simplifier = ExprSimplifier::new(
            SimplifyContext::new(&props).with_schema(Arc::clone(schema)),
        );

        simplifier.simplify(expr).unwrap()
    }

    fn expr_test_schema() -> DFSchemaRef {
        Arc::new(
            DFSchema::from_unqualified_fields(
                vec![
                    Field::new("c1", DataType::Int32, false),
                    Field::new("c2", DataType::Int64, false),
                    Field::new("c3", DataType::Decimal128(18, 2), false),
                    Field::new("c4", DataType::Decimal128(38, 37), false),
                    Field::new("c5", DataType::Float32, false),
                    Field::new("c6", DataType::UInt32, false),
                    Field::new("ts_nano_none", timestamp_nano_none_type(), false),
                    Field::new("ts_nano_utc", timestamp_nano_utc_type(), false),
                    Field::new("str1", DataType::Utf8, false),
                    Field::new("largestr", DataType::LargeUtf8, false),
                    Field::new("tag", dictionary_tag_type(), false),
                ]
                .into(),
                HashMap::new(),
            )
            .unwrap(),
        )
    }

    fn null_bool() -> Expr {
        lit(ScalarValue::Boolean(None))
    }

    fn null_i8() -> Expr {
        lit(ScalarValue::Int8(None))
    }

    fn null_i32() -> Expr {
        lit(ScalarValue::Int32(None))
    }

    fn null_i64() -> Expr {
        lit(ScalarValue::Int64(None))
    }

    fn lit_decimal(value: i128, precision: u8, scale: i8) -> Expr {
        lit(ScalarValue::Decimal128(Some(value), precision, scale))
    }

    fn lit_timestamp_nano_none(ts: i64) -> Expr {
        lit(ScalarValue::TimestampNanosecond(Some(ts), None))
    }

    fn lit_timestamp_nano_utc(ts: i64) -> Expr {
        let utc = Some("+0:00".into());
        lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
    }

    fn timestamp_nano_none_type() -> DataType {
        DataType::Timestamp(TimeUnit::Nanosecond, None)
    }

    // this is the type that now() returns
    fn timestamp_nano_utc_type() -> DataType {
        let utc = Some("+0:00".into());
        DataType::Timestamp(TimeUnit::Nanosecond, utc)
    }

    // a dictionary type for storing string tags
    fn dictionary_tag_type() -> DataType {
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
    }
}