datafusion_physical_expr/simplifier/
mod.rs1use arrow::datatypes::Schema;
21use datafusion_common::{
22 tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
23 Result,
24};
25use std::sync::Arc;
26
27use crate::{PhysicalExpr, PhysicalExprExt};
28
29pub mod unwrap_cast;
30
31pub struct PhysicalExprSimplifier<'a> {
37 schema: &'a Schema,
38}
39
40impl<'a> PhysicalExprSimplifier<'a> {
41 pub fn new(schema: &'a Schema) -> Self {
43 Self { schema }
44 }
45
46 pub fn simplify(
48 &mut self,
49 expr: Arc<dyn PhysicalExpr>,
50 ) -> Result<Arc<dyn PhysicalExpr>> {
51 return expr
52 .transform_up_with_schema(self.schema, |node, schema| {
53 #[cfg(test)]
55 let original_type = node.data_type(schema).unwrap();
56 let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?;
57 #[cfg(test)]
58 assert_eq!(
59 unwrapped.data.data_type(schema).unwrap(),
60 original_type,
61 "Simplified expression should have the same data type as the original"
62 );
63 Ok(unwrapped)
64 })
65 .data();
66
67 Ok(expr.rewrite(self)?.data)
68 }
69}
70
71impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
72 type Node = Arc<dyn PhysicalExpr>;
73
74 fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
75 #[cfg(test)]
77 let original_type = node.data_type(self.schema).unwrap();
78 let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?;
79 #[cfg(test)]
80 assert_eq!(
81 unwrapped.data.data_type(self.schema).unwrap(),
82 original_type,
83 "Simplified expression should have the same data type as the original"
84 );
85 Ok(unwrapped)
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Schema with three non-nullable columns: `c1: Int32`, `c2: Int64`,
    /// `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields: Vec<Field> = [
            ("c1", DataType::Int32),
            ("c2", DataType::Int64),
            ("c3", DataType::Utf8),
        ]
        .into_iter()
        .map(|(name, data_type)| Field::new(name, data_type, false))
        .collect();
        Schema::new(fields)
    }

    /// Builds `CAST(<column> AS <cast_to>) <op> <literal>`.
    fn cast_comparison(
        schema: &Schema,
        column: &str,
        cast_to: DataType,
        op: Operator,
        literal: ScalarValue,
    ) -> Arc<BinaryExpr> {
        let column_expr = col(column, schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, cast_to, None));
        Arc::new(BinaryExpr::new(cast_expr, op, lit(literal)))
    }

    /// Downcasts `expr` to a `BinaryExpr`, panicking if it is anything else.
    fn as_binary(expr: &Arc<dyn PhysicalExpr>) -> &BinaryExpr {
        expr.as_any().downcast_ref::<BinaryExpr>().unwrap()
    }

    /// Returns true when `expr` is a `CastExpr` or a `TryCastExpr`.
    fn is_any_cast(expr: &Arc<dyn PhysicalExpr>) -> bool {
        let any = expr.as_any();
        any.is::<CastExpr>() || any.is::<TryCastExpr>()
    }

    /// Downcasts `expr` to a `Literal` and returns its scalar value,
    /// panicking if it is not a literal.
    fn literal_value(expr: &Arc<dyn PhysicalExpr>) -> &ScalarValue {
        expr.as_any().downcast_ref::<Literal>().unwrap().value()
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c2 AS Int32) != 99_i32
        let input = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::NotEq,
            ScalarValue::Int32(Some(99)),
        );

        let simplified = simplifier.simplify(input).unwrap();
        let comparison = as_binary(&simplified);

        // The cast on the column side must be gone ...
        assert!(!is_any_cast(comparison.left()));
        // ... and the literal must now carry the column's type (Int64).
        assert_eq!(
            literal_value(comparison.right()),
            &ScalarValue::Int64(Some(99))
        );
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // CAST(c1 AS Int64) > 5_i64
        let widened = cast_comparison(
            &schema,
            "c1",
            DataType::Int64,
            Operator::Gt,
            ScalarValue::Int64(Some(5)),
        );
        // CAST(c2 AS Int32) <= 10_i32
        let narrowed = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::LtEq,
            ScalarValue::Int32(Some(10)),
        );
        let disjunction = Arc::new(BinaryExpr::new(widened, Operator::Or, narrowed));

        let simplified = simplifier.simplify(disjunction).unwrap();
        let or_node = as_binary(&simplified);

        // Left branch: cast removed, literal retyped to c1's type (Int32).
        let lhs = as_binary(or_node.left());
        assert!(!is_any_cast(lhs.left()));
        assert_eq!(literal_value(lhs.right()), &ScalarValue::Int32(Some(5)));

        // Right branch: cast removed, literal retyped to c2's type (Int64).
        let rhs = as_binary(or_node.right());
        assert!(!is_any_cast(rhs.left()));
        assert_eq!(literal_value(rhs.right()), &ScalarValue::Int64(Some(10)));
    }
}