1use std::sync::Arc;
35
36use arrow::datatypes::{DataType, Schema};
37use datafusion_common::{tree_node::Transformed, Result, ScalarValue};
38use datafusion_expr::Operator;
39use datafusion_expr_common::casts::try_cast_literal_to_type;
40
41use crate::PhysicalExpr;
42use crate::{
43 expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr},
44 PhysicalExprExt,
45};
46
47pub(crate) fn unwrap_cast_in_comparison(
49 expr: Arc<dyn PhysicalExpr>,
50 schema: &Schema,
51) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
52 expr.transform_down_with_schema(schema, |e, schema| {
53 if let Some(binary) = e.as_any().downcast_ref::<BinaryExpr>() {
54 if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? {
55 return Ok(Transformed::yes(unwrapped));
56 }
57 }
58 Ok(Transformed::no(e))
59 })
60}
61
62fn try_unwrap_cast_binary(
64 binary: &BinaryExpr,
65 schema: &Schema,
66) -> Result<Option<Arc<dyn PhysicalExpr>>> {
67 if let (Some((inner_expr, _cast_type)), Some(literal)) = (
69 extract_cast_info(binary.left()),
70 binary.right().as_any().downcast_ref::<Literal>(),
71 ) {
72 if binary.op().supports_propagation() {
73 if let Some(unwrapped) = try_unwrap_cast_comparison(
74 Arc::clone(inner_expr),
75 literal.value(),
76 *binary.op(),
77 schema,
78 )? {
79 return Ok(Some(unwrapped));
80 }
81 }
82 }
83
84 if let (Some(literal), Some((inner_expr, _cast_type))) = (
86 binary.left().as_any().downcast_ref::<Literal>(),
87 extract_cast_info(binary.right()),
88 ) {
89 if let Some(swapped_op) = binary.op().swap() {
91 if binary.op().supports_propagation() {
92 if let Some(unwrapped) = try_unwrap_cast_comparison(
93 Arc::clone(inner_expr),
94 literal.value(),
95 swapped_op,
96 schema,
97 )? {
98 return Ok(Some(unwrapped));
99 }
100 }
101 }
102 }
105
106 Ok(None)
107}
108
109fn extract_cast_info(
114 expr: &Arc<dyn PhysicalExpr>,
115) -> Option<(&Arc<dyn PhysicalExpr>, &DataType)> {
116 if let Some(cast) = expr.as_any().downcast_ref::<CastExpr>() {
117 Some((cast.expr(), cast.cast_type()))
118 } else if let Some(try_cast) = expr.as_any().downcast_ref::<TryCastExpr>() {
119 Some((try_cast.expr(), try_cast.cast_type()))
120 } else {
121 None
122 }
123}
124
125fn try_unwrap_cast_comparison(
127 inner_expr: Arc<dyn PhysicalExpr>,
128 literal_value: &ScalarValue,
129 op: Operator,
130 schema: &Schema,
131) -> Result<Option<Arc<dyn PhysicalExpr>>> {
132 let inner_type = inner_expr.data_type(schema)?;
134
135 if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
137 let literal_expr = lit(casted_literal);
138 let binary_expr = BinaryExpr::new(inner_expr, op, literal_expr);
139 return Ok(Some(Arc::new(binary_expr)));
140 }
141
142 Ok(None)
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Returns `true` when `expr` is a `CastExpr` or a `TryCastExpr`.
    fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let any = expr.as_any();
        any.is::<CastExpr>() || any.is::<TryCastExpr>()
    }

    /// Returns `true` when one operand of `binary` is a cast and the other is
    /// a literal, in either order.
    fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
        let is_literal = |side: &Arc<dyn PhysicalExpr>| side.as_any().is::<Literal>();
        let (lhs, rhs) = (binary.left(), binary.right());

        (is_cast_expr(lhs) && is_literal(rhs)) || (is_literal(lhs) && is_cast_expr(rhs))
    }

    /// Schema shared by several tests: `c1: Int32`, `c2: Int64`, `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields = vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ];
        Schema::new(fields)
    }

    /// Wraps `input` in a `CastExpr` to `target` with no cast options.
    fn cast_to(input: Arc<dyn PhysicalExpr>, target: DataType) -> Arc<dyn PhysicalExpr> {
        Arc::new(CastExpr::new(input, target, None))
    }

    /// Wraps `input` in a `TryCastExpr` to `target`.
    fn try_cast_to(
        input: Arc<dyn PhysicalExpr>,
        target: DataType,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(TryCastExpr::new(input, target))
    }

    /// Builds the binary expression `lhs <op> rhs`.
    fn binary(
        lhs: Arc<dyn PhysicalExpr>,
        op: Operator,
        rhs: Arc<dyn PhysicalExpr>,
    ) -> Arc<dyn PhysicalExpr> {
        Arc::new(BinaryExpr::new(lhs, op, rhs))
    }

    /// Runs the rewrite under test and unwraps its result.
    fn optimize(
        expr: Arc<dyn PhysicalExpr>,
        schema: &Schema,
    ) -> Transformed<Arc<dyn PhysicalExpr>> {
        unwrap_cast_in_comparison(expr, schema).unwrap()
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking if it is anything else.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }

    /// Downcasts `expr` to a `Literal` and returns its scalar value, panicking
    /// if it is anything else.
    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any().downcast_ref::<Literal>().unwrap().value()
    }

    /// `cast(c1 AS Int64) > 10i64` becomes `c1 > 10i32`.
    #[test]
    fn test_unwrap_cast_in_binary_comparison() {
        let schema = test_schema();

        let input = binary(
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let rewritten = as_binary(&outcome.data);
        // The cast is gone from the left operand ...
        assert!(!is_cast_expr(rewritten.left()));
        // ... and the literal now has the column's type.
        assert_eq!(
            literal_value(rewritten.right()),
            &ScalarValue::Int32(Some(10))
        );
    }

    /// `10i64 < cast(c1 AS Int64)` is rewritten with the operator flipped to `>`.
    #[test]
    fn test_unwrap_cast_with_literal_on_left() {
        let schema = test_schema();

        let input = binary(
            lit(10i64),
            Operator::Lt,
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
        );

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        assert_eq!(*as_binary(&outcome.data).op(), Operator::Gt);
    }

    /// `cast(f1 AS Float64) > 10.5f64` over a `Float32` column is left alone.
    #[test]
    fn test_no_unwrap_when_types_unsupported() {
        let schema = Schema::new(vec![Field::new("f1", DataType::Float32, false)]);

        let input = binary(
            cast_to(col("f1", &schema).unwrap(), DataType::Float64),
            Operator::Gt,
            lit(10.5f64),
        );

        let outcome = optimize(input, &schema);
        assert!(!outcome.transformed);
    }

    /// The shape-detection helper recognises `cast(c1) > literal`.
    #[test]
    fn test_is_binary_expr_with_cast_and_literal() {
        let schema = test_schema();

        let input = binary(
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );

        assert!(is_binary_expr_with_cast_and_literal(as_binary(&input)));
    }

    /// `Decimal128(22, 2) literal <= cast(decimal_col)` is rewritten to
    /// `decimal_col >= literal` with the literal typed `Decimal128(9, 2)`.
    #[test]
    fn test_unwrap_cast_literal_on_left_side() {
        let schema = Schema::new(vec![Field::new(
            "decimal_col",
            DataType::Decimal128(9, 2),
            true,
        )]);

        let input = binary(
            lit(ScalarValue::Decimal128(Some(400), 22, 2)),
            Operator::LtEq,
            cast_to(
                col("decimal_col", &schema).unwrap(),
                DataType::Decimal128(22, 2),
            ),
        );

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let rewritten = as_binary(&outcome.data);
        assert_eq!(*rewritten.op(), Operator::GtEq);
        assert!(!is_cast_expr(rewritten.left()));
        assert_eq!(
            literal_value(rewritten.right()).data_type(),
            DataType::Decimal128(9, 2)
        );
    }

    /// For every comparison operator, `100i64 <op> cast(int_col AS Int64)` is
    /// rewritten to `int_col <swapped op> 100i32`.
    #[test]
    fn test_unwrap_cast_with_different_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        // (operator in the input, operator expected after the operands swap)
        let swap_table = [
            (Operator::Lt, Operator::Gt),
            (Operator::LtEq, Operator::GtEq),
            (Operator::Gt, Operator::Lt),
            (Operator::GtEq, Operator::LtEq),
            (Operator::Eq, Operator::Eq),
            (Operator::NotEq, Operator::NotEq),
        ];

        for &(original_op, expected_op) in &swap_table {
            let input = binary(
                lit(100i64),
                original_op,
                cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
            );

            let outcome = optimize(input, &schema);
            assert!(outcome.transformed);

            let rewritten = as_binary(&outcome.data);
            assert_eq!(
                *rewritten.op(),
                expected_op,
                "Failed for operator {original_op:?} -> {expected_op:?}"
            );
            assert!(!is_cast_expr(rewritten.left()));
            assert_eq!(
                literal_value(rewritten.right()),
                &ScalarValue::Int32(Some(100))
            );
        }
    }

    /// Decimal columns cast to a wider precision are unwrapped with the cast
    /// on either side of the comparison.
    #[test]
    fn test_unwrap_cast_with_decimal_types() {
        // (column precision, column scale, cast precision, cast scale, literal)
        let cases = [
            (9, 2, 22, 2, 400),
            (10, 3, 20, 3, 1000),
            (5, 1, 10, 1, 99),
        ];

        for &(col_p, col_s, cast_p, cast_s, value) in &cases {
            let schema = Schema::new(vec![Field::new(
                "decimal_col",
                DataType::Decimal128(col_p, col_s),
                true,
            )]);
            let column = col("decimal_col", &schema).unwrap();

            // cast(decimal_col) > literal
            let cast_on_left = binary(
                cast_to(Arc::clone(&column), DataType::Decimal128(cast_p, cast_s)),
                Operator::Gt,
                lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)),
            );
            assert!(optimize(cast_on_left, &schema).transformed);

            // literal < cast(decimal_col)
            let cast_on_right = binary(
                lit(ScalarValue::Decimal128(Some(value), cast_p, cast_s)),
                Operator::Lt,
                cast_to(column, DataType::Decimal128(cast_p, cast_s)),
            );
            assert!(optimize(cast_on_right, &schema).transformed);
        }
    }

    /// A NULL `Int64` literal is converted to a NULL `Int32` literal.
    #[test]
    fn test_unwrap_cast_with_null_literals() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, true)]);

        let input = binary(
            cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
            Operator::Eq,
            lit(ScalarValue::Int64(None)),
        );

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let rewritten = as_binary(&outcome.data);
        assert_eq!(literal_value(rewritten.right()), &ScalarValue::Int32(None));
    }

    /// `try_cast(str_col AS Int64) > 100i64` over a `Utf8` column is left alone.
    #[test]
    fn test_unwrap_cast_with_try_cast() {
        let schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]);

        let input = binary(
            try_cast_to(col("str_col", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(100i64),
        );

        let outcome = optimize(input, &schema);
        assert!(!outcome.transformed);
    }

    /// Both comparisons under an `AND` are unwrapped while the `AND` itself
    /// stays in place.
    #[test]
    fn test_unwrap_cast_preserves_non_comparison_operators() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);
        let column = col("int_col", &schema).unwrap();

        // cast(int_col) > 10 AND cast(int_col) < 20
        let lower_bound = binary(
            cast_to(Arc::clone(&column), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );
        let upper_bound = binary(
            cast_to(column, DataType::Int64),
            Operator::Lt,
            lit(20i64),
        );
        let input = binary(lower_bound, Operator::And, upper_bound);

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let conjunction = as_binary(&outcome.data);
        assert_eq!(*conjunction.op(), Operator::And);

        let left_cmp = as_binary(conjunction.left());
        let right_cmp = as_binary(conjunction.right());
        assert!(!is_cast_expr(left_cmp.left()));
        assert!(!is_cast_expr(right_cmp.left()));
    }

    /// `try_cast(c1 AS Int64) <= 100i64` becomes `c1 <= 100i32`.
    #[test]
    fn test_try_cast_unwrapping() {
        let schema = test_schema();

        let input = binary(
            try_cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::LtEq,
            lit(100i64),
        );

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let rewritten = as_binary(&outcome.data);
        assert!(!is_cast_expr(rewritten.left()));
        assert_eq!(
            literal_value(rewritten.right()),
            &ScalarValue::Int32(Some(100))
        );
    }

    /// `10i64 + cast(int_col AS Int64)` is not a comparison and is left alone.
    #[test]
    fn test_non_swappable_operator() {
        let schema = Schema::new(vec![Field::new("int_col", DataType::Int32, false)]);

        let input = binary(
            lit(10i64),
            Operator::Plus,
            cast_to(col("int_col", &schema).unwrap(), DataType::Int64),
        );

        let outcome = optimize(input, &schema);
        assert!(!outcome.transformed);
    }

    /// `cast(small_int AS Int64) > 1000i64` over an `Int8` column is left
    /// alone: 1000 does not fit in `Int8`.
    #[test]
    fn test_cast_that_cannot_be_unwrapped_overflow() {
        let schema = Schema::new(vec![Field::new("small_int", DataType::Int8, false)]);

        let input = binary(
            cast_to(col("small_int", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(1000i64),
        );

        let outcome = optimize(input, &schema);
        assert!(!outcome.transformed);
    }

    /// `cast(c1 AS Int64) > 10i64 AND cast(c2 AS Int32) = 20i32` has both casts
    /// removed, each literal taking the type of its column.
    #[test]
    fn test_complex_nested_expression() {
        let schema = test_schema();

        let c1_cmp = binary(
            cast_to(col("c1", &schema).unwrap(), DataType::Int64),
            Operator::Gt,
            lit(10i64),
        );
        let c2_cmp = binary(
            cast_to(col("c2", &schema).unwrap(), DataType::Int32),
            Operator::Eq,
            lit(20i32),
        );
        let input = binary(c1_cmp, Operator::And, c2_cmp);

        let outcome = optimize(input, &schema);
        assert!(outcome.transformed);

        let conjunction = as_binary(&outcome.data);

        // Left branch: c1 is Int32, so the literal becomes Int32.
        let left_cmp = as_binary(conjunction.left());
        assert!(!is_cast_expr(left_cmp.left()));
        assert_eq!(
            literal_value(left_cmp.right()),
            &ScalarValue::Int32(Some(10))
        );

        // Right branch: c2 is Int64, so the literal becomes Int64.
        let right_cmp = as_binary(conjunction.right());
        assert!(!is_cast_expr(right_cmp.left()));
        assert_eq!(
            literal_value(right_cmp.right()),
            &ScalarValue::Int64(Some(20))
        );
    }
}