1use arrow::datatypes::DataType;
58use datafusion_common::{internal_err, tree_node::Transformed};
59use datafusion_common::{Result, ScalarValue};
60use datafusion_expr::{lit, BinaryExpr};
61use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast};
62use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type};
63
64pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
65 info: &S,
66 cast_expr: Expr,
67 literal: Expr,
68 op: Operator,
69) -> Result<Transformed<Expr>> {
70 match (cast_expr, literal) {
71 (
72 Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }),
73 Expr::Literal(lit_value, _),
74 ) => {
75 let Ok(expr_type) = info.get_data_type(&expr) else {
76 return internal_err!("Can't get the data type of the expr {:?}", &expr);
77 };
78
79 if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
80 {
81 return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
82 left: expr,
83 op,
84 right: Box::new(lit(value)),
85 })));
86 };
87
88 let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
91 return internal_err!(
92 "Can't cast the literal expr {:?} to type {}",
93 &lit_value,
94 &expr_type
95 );
96 };
97 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
98 left: expr,
99 op,
100 right: Box::new(lit(value)),
101 })))
102 }
103 _ => internal_err!("Expect cast expr and literal"),
104 }
105}
106
107pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
108 S: SimplifyInfo,
109>(
110 info: &S,
111 expr: &Expr,
112 op: Operator,
113 literal: &Expr,
114) -> bool {
115 match (expr, literal) {
116 (
117 Expr::TryCast(TryCast {
118 expr: left_expr, ..
119 })
120 | Expr::Cast(Cast {
121 expr: left_expr, ..
122 }),
123 Expr::Literal(lit_val, _),
124 ) => {
125 let Ok(expr_type) = info.get_data_type(left_expr) else {
126 return false;
127 };
128
129 let Ok(lit_type) = info.get_data_type(literal) else {
130 return false;
131 };
132
133 if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
134 return true;
135 }
136
137 try_cast_literal_to_type(lit_val, &expr_type).is_some()
138 && is_supported_type(&expr_type)
139 && is_supported_type(&lit_type)
140 }
141 _ => false,
142 }
143}
144
145pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
146 S: SimplifyInfo,
147>(
148 info: &S,
149 expr: &Expr,
150 list: &[Expr],
151) -> bool {
152 let (Expr::TryCast(TryCast {
153 expr: left_expr, ..
154 })
155 | Expr::Cast(Cast {
156 expr: left_expr, ..
157 })) = expr
158 else {
159 return false;
160 };
161
162 let Ok(expr_type) = info.get_data_type(left_expr) else {
163 return false;
164 };
165
166 if !is_supported_type(&expr_type) {
167 return false;
168 }
169
170 for right in list {
171 let Ok(right_type) = info.get_data_type(right) else {
172 return false;
173 };
174
175 if !is_supported_type(&right_type) {
176 return false;
177 }
178
179 match right {
180 Expr::Literal(lit_val, _)
181 if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {}
182 _ => return false,
183 }
184 }
185
186 true
187}
188
189fn cast_literal_to_type_with_op(
203 lit_value: &ScalarValue,
204 target_type: &DataType,
205 op: Operator,
206) -> Option<ScalarValue> {
207 match (op, lit_value) {
208 (
209 Operator::Eq | Operator::NotEq,
210 ScalarValue::Utf8(Some(_))
211 | ScalarValue::Utf8View(Some(_))
212 | ScalarValue::LargeUtf8(Some(_)),
213 ) => {
214 use DataType::*;
217 if matches!(
218 target_type,
219 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
220 ) {
221 let casted = lit_value.cast_to(target_type).ok()?;
222 let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?;
223 if lit_value != &round_tripped {
224 return None;
225 }
226 Some(casted)
227 } else {
228 None
229 }
230 }
231 _ => None,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    use crate::simplify_expressions::ExprSimplifier;
    use arrow::datatypes::{Field, TimeUnit};
    use datafusion_common::{DFSchema, DFSchemaRef};
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::simplify::SimplifyContext;
    use datafusion_expr::{cast, col, in_list, try_cast};

    /// Expressions the simplifier must return unchanged.
    #[test]
    fn test_not_unwrap_cast_comparison() {
        let schema = expr_test_schema();
        let untouched = vec![
            // cast(c1 as Int64) > c2: the other side is a column, not a literal
            cast(col("c1"), DataType::Int64).gt(col("c2")),
            // c1 < 16: there is no cast to remove
            col("c1").lt(lit(16i32)),
            // cast(c1 as Int64) < 99999999999: literal is outside the Int32 range
            cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)),
            // cast(c1 as Utf8) < '123': string literal with a non-equality operator
            cast(col("c1"), DataType::Utf8).lt(lit("123")),
            // cast(c1 as Utf8) < '0123': same, with a leading zero
            cast(col("c1"), DataType::Utf8).lt(lit("0123")),
            // cast(c1 as Utf8) = 'not a number': literal is not an integer
            cast(col("c1"), DataType::Utf8).eq(lit("not a number")),
            // cast(c1 as Utf8) = '99999999999': literal is outside the Int32 range
            cast(col("c1"), DataType::Utf8).eq(lit("99999999999")),
        ];

        for expr in untouched {
            assert_unchanged(expr, &schema);
        }
    }

    /// Comparisons between a cast column and a literal that get rewritten.
    #[test]
    fn test_unwrap_cast_comparison() {
        let schema = expr_test_schema();
        let cases = vec![
            // cast(c1 as Int64) < 16  =>  c1 < 16i32
            (
                cast(col("c1"), DataType::Int64).lt(lit(16i64)),
                col("c1").lt(lit(16i32)),
            ),
            // try_cast(c1 as Int64) < 16  =>  c1 < 16i32
            (
                try_cast(col("c1"), DataType::Int64).lt(lit(16i64)),
                col("c1").lt(lit(16i32)),
            ),
            // cast(c2 as Int32) = 16  =>  c2 = 16i64
            (
                cast(col("c2"), DataType::Int32).eq(lit(16i32)),
                col("c2").eq(lit(16i64)),
            ),
            // cast(c1 as Int64) < NULL  =>  NULL
            (
                cast(col("c1"), DataType::Int64).lt(null_i64()),
                null_bool(),
            ),
            // cast(NULL as Int32) < 12  =>  NULL
            (
                cast(null_i8(), DataType::Int32).lt(lit(12i32)),
                null_bool(),
            ),
            // cast(c1 as Utf8) = '123'  =>  c1 = 123
            (
                cast(col("c1"), DataType::Utf8).eq(lit("123")),
                col("c1").eq(lit(123i32)),
            ),
            // cast(c1 as Utf8) != '123'  =>  c1 != 123
            (
                cast(col("c1"), DataType::Utf8).not_eq(lit("123")),
                col("c1").not_eq(lit(123i32)),
            ),
            // cast(c1 as Utf8) = NULL  =>  NULL
            (
                cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))),
                null_bool(),
            ),
        ];

        for (input, expected) in cases {
            assert_simplifies_to(input, expected, &schema);
        }
    }

    /// Same rewrites against the unsigned column `c6` (UInt32).
    #[test]
    fn test_unwrap_cast_comparison_unsigned() {
        let schema = expr_test_schema();
        let cases = vec![
            // cast(c6 as UInt64) = 0  =>  c6 = 0u32
            (
                cast(col("c6"), DataType::UInt64).eq(lit(0u64)),
                col("c6").eq(lit(0u32)),
            ),
            // cast(c6 as Utf8) = '123'  =>  c6 = 123u32
            (
                cast(col("c6"), DataType::Utf8).eq(lit("123")),
                col("c6").eq(lit(123u32)),
            ),
            // cast(c6 as Utf8) != '123'  =>  c6 != 123u32
            (
                cast(col("c6"), DataType::Utf8).not_eq(lit("123")),
                col("c6").not_eq(lit(123u32)),
            ),
        ];

        for (input, expected) in cases {
            assert_simplifies_to(input, expected, &schema);
        }
    }

    /// Casts between plain strings and dictionary-encoded strings.
    #[test]
    fn test_unwrap_cast_comparison_string() {
        let schema = expr_test_schema();
        let dict = ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(ScalarValue::from("value")),
        );
        let dict_type = dict.data_type();

        let cases = vec![
            // cast(str1 as Dictionary<Int32, Utf8>) = dict('value')
            //   =>  str1 = 'value'
            (
                cast(col("str1"), dict_type.clone()).eq(lit(dict.clone())),
                col("str1").eq(lit("value")),
            ),
            // cast(tag as Utf8) = 'value'  =>  tag = dict('value')
            (
                cast(col("tag"), DataType::Utf8).eq(lit("value")),
                col("tag").eq(lit(dict.clone())),
            ),
            // Literal on the left:
            // dict('value') = cast(str1 as Dictionary<Int32, Utf8>)
            //   =>  str1 = 'value'
            (
                lit(dict.clone()).eq(cast(col("str1"), dict_type)),
                col("str1").eq(lit("value")),
            ),
        ];

        for (input, expected) in cases {
            assert_simplifies_to(input, expected, &schema);
        }
    }

    /// Dictionary whose values are `LargeUtf8` compared with a `LargeUtf8` column.
    #[test]
    fn test_unwrap_cast_comparison_large_string() {
        let schema = expr_test_schema();
        let large_value = || ScalarValue::LargeUtf8(Some("value".to_owned()));
        let dict = ScalarValue::Dictionary(
            Box::new(DataType::Int32),
            Box::new(large_value()),
        );

        let input = cast(col("largestr"), dict.data_type()).eq(lit(dict));
        let expected = col("largestr").eq(lit(large_value()));
        assert_simplifies_to(input, expected, &schema);
    }

    /// Decimal-related comparisons the simplifier must return unchanged.
    #[test]
    fn test_not_unwrap_cast_with_decimal_comparison() {
        let schema = expr_test_schema();
        let untouched = vec![
            // c3 is Decimal128(18, 2): 10^17 at scale 2 needs more than 18 digits
            cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)),
            // c4 is Decimal128(38, 37): 1000 at scale 37 needs more than 38 digits
            cast(col("c4"), DataType::Int64).eq(lit(1000i64)),
            // 1.2340 (scale 4) cannot be represented at scale 2 without loss
            cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)),
            // c1 is Int32: 12.3 is not an integer
            cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)),
            // c1 is Int32: 12.30 is not an integer
            cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)),
        ];

        for expr in untouched {
            assert_unchanged(expr, &schema);
        }
    }

    /// Decimal-related comparisons that get rewritten.
    #[test]
    fn test_unwrap_cast_with_decimal_lit_comparison() {
        let schema = expr_test_schema();
        let cases = vec![
            // try_cast(c3 as Int64) < 16  =>  c3 < 16.00 as Decimal128(18, 2)
            (
                try_cast(col("c3"), DataType::Int64).lt(lit(16i64)),
                col("c3").lt(lit_decimal(1600, 18, 2)),
            ),
            // cast(c3 as Int64) < NULL  =>  NULL
            (
                cast(col("c3"), DataType::Int64).lt(null_i64()),
                null_bool(),
            ),
            // cast(c3 as Decimal128(10, 0)) < 123  =>  c3 < 123.00
            (
                cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)),
                col("c3").lt(lit_decimal(12300, 18, 2)),
            ),
            // cast(c3 as Decimal128(10, 3)) < 1.230  =>  c3 < 1.23
            (
                cast(col("c3"), DataType::Decimal128(10, 3))
                    .lt(lit_decimal(1230, 10, 3)),
                col("c3").lt(lit_decimal(123, 18, 2)),
            ),
            // cast(c1 as Decimal128(10, 2)) < 123.00  =>  c1 < 123i32
            (
                cast(col("c1"), DataType::Decimal128(10, 2))
                    .lt(lit_decimal(12300, 10, 2)),
                col("c1").lt(lit(123i32)),
            ),
        ];

        for (input, expected) in cases {
            assert_simplifies_to(input, expected, &schema);
        }
    }

    /// `IN` lists the simplifier must return unchanged.
    #[test]
    fn test_not_unwrap_list_cast_lit_comparison() {
        let schema = expr_test_schema();
        let untouched = vec![
            // c5 is a Float32 column
            cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false),
            // the list holds Float32 literals
            cast(col("c1"), DataType::Float32)
                .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false),
            // one list element is outside the Int32 range
            cast(col("c1"), DataType::Int64)
                .in_list(vec![lit(12i32), lit(99999999999i64)], false),
            // scale-3 values that are not exact at scale 2
            cast(col("c3"), DataType::Decimal128(12, 3)).in_list(
                vec![
                    lit_decimal(12, 12, 3),
                    lit_decimal(12, 12, 3),
                    lit_decimal(128, 12, 3),
                ],
                false,
            ),
        ];

        for expr in untouched {
            assert_unchanged(expr, &schema);
        }
    }

    /// `IN` lists that get rewritten.
    #[test]
    fn test_unwrap_list_cast_comparison() {
        let schema = expr_test_schema();
        let cases = vec![
            // cast(c1 as Int64) IN (i64 literals)  =>  c1 IN (i32 literals)
            (
                cast(col("c1"), DataType::Int64).in_list(
                    vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)],
                    false,
                ),
                col("c1").in_list(
                    vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)],
                    false,
                ),
            ),
            // cast(c2 as Int32) IN (NULL, mixed i32/i64)  =>  c2 IN (i64 literals)
            (
                cast(col("c2"), DataType::Int32).in_list(
                    vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)],
                    false,
                ),
                col("c2").in_list(
                    vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)],
                    false,
                ),
            ),
            // cast(c3 as Decimal128(19, 3)) IN (scale-3 literals)
            //   =>  c3 IN (scale-2 literals)
            (
                cast(col("c3"), DataType::Decimal128(19, 3)).in_list(
                    vec![
                        lit_decimal(12000, 19, 3),
                        lit_decimal(24000, 19, 3),
                        lit_decimal(1280, 19, 3),
                        lit_decimal(1240, 19, 3),
                    ],
                    false,
                ),
                col("c3").in_list(
                    vec![
                        lit_decimal(1200, 18, 2),
                        lit_decimal(2400, 18, 2),
                        lit_decimal(128, 18, 2),
                        lit_decimal(124, 18, 2),
                    ],
                    false,
                ),
            ),
            // cast(12i32 as Int64) IN (12, 13, 14, 15, 16)  =>  true
            (
                cast(lit(12i32), DataType::Int64).in_list(
                    vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)],
                    false,
                ),
                lit(true),
            ),
        ];

        for (input, expected) in cases {
            assert_simplifies_to(input, expected, &schema);
        }
    }

    /// The rewrite is applied underneath an alias.
    #[test]
    fn aliased() {
        let schema = expr_test_schema();
        // (cast(c1 as Int64) < 16) AS x  =>  (c1 < 16i32) AS x
        let input = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x");
        let expected = col("c1").lt(lit(16i32)).alias("x");
        assert_simplifies_to(input, expected, &schema);
    }

    /// The rewrite is applied to both operands of an `OR`.
    #[test]
    fn nested() {
        let schema = expr_test_schema();
        // cast(c1 as Int64) < 16 OR cast(c1 as Int64) > 32
        //   =>  c1 < 16i32 OR c1 > 32i32
        let below = cast(col("c1"), DataType::Int64).lt(lit(16i64));
        let above = cast(col("c1"), DataType::Int64).gt(lit(32i64));
        let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32)));
        assert_simplifies_to(below.or(above), expected, &schema);
    }

    /// Casts to Float64 are left alone, both in comparisons and in `IN` lists.
    #[test]
    fn test_not_support_data_type() {
        let schema = expr_test_schema();
        let untouched = vec![
            // cast(c6 as Float64) = 0.0
            cast(col("c6"), DataType::Float64).eq(lit(0f64)),
            // cast(c6 as Float64) IN (0.0, 1.0, 2.0, 3.0, 4.0)
            in_list(
                cast(col("c6"), DataType::Float64),
                vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)],
                false,
            ),
        ];

        for expr in untouched {
            assert_unchanged(expr, &schema);
        }
    }

    /// Timestamp with a UTC offset compared against a timezone-less column.
    #[test]
    fn test_unwrap_cast_with_timestamp_nanos() {
        let schema = expr_test_schema();
        // try_cast(ts_nano_none as Timestamp(ns, "+0:00")) < ts_utc
        //   =>  ts_nano_none < ts_without_timezone
        let input = try_cast(col("ts_nano_none"), timestamp_nano_utc_type())
            .lt(lit_timestamp_nano_utc(1666612093000000000));
        let expected =
            col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000));
        assert_simplifies_to(input, expected, &schema);
    }

    /// Asserts that simplifying `expr` returns it unchanged.
    #[track_caller]
    fn assert_unchanged(expr: Expr, schema: &DFSchemaRef) {
        assert_eq!(optimize_test(expr.clone(), schema), expr);
    }

    /// Asserts that simplifying `input` produces `expected`.
    #[track_caller]
    fn assert_simplifies_to(input: Expr, expected: Expr, schema: &DFSchemaRef) {
        assert_eq!(optimize_test(input, schema), expected);
    }

    /// Runs `expr` through an [`ExprSimplifier`] configured with `schema`.
    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
        let props = ExecutionProps::new();
        let context = SimplifyContext::new(&props).with_schema(Arc::clone(schema));

        ExprSimplifier::new(context).simplify(expr).unwrap()
    }

    /// Schema shared by every test in this module; all columns are non-nullable.
    fn expr_test_schema() -> DFSchemaRef {
        let columns = vec![
            ("c1", DataType::Int32),
            ("c2", DataType::Int64),
            ("c3", DataType::Decimal128(18, 2)),
            ("c4", DataType::Decimal128(38, 37)),
            ("c5", DataType::Float32),
            ("c6", DataType::UInt32),
            ("ts_nano_none", timestamp_nano_none_type()),
            // NOTE(review): name looks like a typo for `ts_nano_utc`; no test in
            // this module references the column.
            ("ts_nano_utf", timestamp_nano_utc_type()),
            ("str1", DataType::Utf8),
            ("largestr", DataType::LargeUtf8),
            ("tag", dictionary_tag_type()),
        ];
        let fields: Vec<Field> = columns
            .into_iter()
            .map(|(name, data_type)| Field::new(name, data_type, false))
            .collect();

        let schema =
            DFSchema::from_unqualified_fields(fields.into(), HashMap::new()).unwrap();
        Arc::new(schema)
    }

    /// Null Boolean literal.
    fn null_bool() -> Expr {
        lit(ScalarValue::Boolean(None))
    }

    /// Null Int8 literal.
    fn null_i8() -> Expr {
        lit(ScalarValue::Int8(None))
    }

    /// Null Int32 literal.
    fn null_i32() -> Expr {
        lit(ScalarValue::Int32(None))
    }

    /// Null Int64 literal.
    fn null_i64() -> Expr {
        lit(ScalarValue::Int64(None))
    }

    /// Decimal128 literal with the given unscaled `value`, `precision` and `scale`.
    fn lit_decimal(value: i128, precision: u8, scale: i8) -> Expr {
        let scalar = ScalarValue::Decimal128(Some(value), precision, scale);
        lit(scalar)
    }

    /// Nanosecond timestamp literal without a timezone.
    fn lit_timestamp_nano_none(ts: i64) -> Expr {
        let scalar = ScalarValue::TimestampNanosecond(Some(ts), None);
        lit(scalar)
    }

    /// Nanosecond timestamp literal with the `"+0:00"` timezone.
    fn lit_timestamp_nano_utc(ts: i64) -> Expr {
        let scalar = ScalarValue::TimestampNanosecond(Some(ts), Some("+0:00".into()));
        lit(scalar)
    }

    /// Nanosecond timestamp type without a timezone.
    fn timestamp_nano_none_type() -> DataType {
        DataType::Timestamp(TimeUnit::Nanosecond, None)
    }

    /// Nanosecond timestamp type with the `"+0:00"` timezone.
    fn timestamp_nano_utc_type() -> DataType {
        DataType::Timestamp(TimeUnit::Nanosecond, Some("+0:00".into()))
    }

    /// Dictionary type with Int32 keys and Utf8 values, used by the `tag` column.
    fn dictionary_tag_type() -> DataType {
        let key_type = Box::new(DataType::Int32);
        let value_type = Box::new(DataType::Utf8);
        DataType::Dictionary(key_type, value_type)
    }
}