datafusion_physical_expr/simplifier/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Simplifier for Physical Expressions
19
20use arrow::datatypes::Schema;
21use datafusion_common::{
22    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
23    Result,
24};
25use std::sync::Arc;
26
27use crate::{PhysicalExpr, PhysicalExprExt};
28
29pub mod unwrap_cast;
30
31/// Simplifies physical expressions by applying various optimizations
32///
33/// This can be useful after adapting expressions from a table schema
34/// to a file schema. For example, casts added to match the types may
35/// potentially be unwrapped.
36pub struct PhysicalExprSimplifier<'a> {
37    schema: &'a Schema,
38}
39
40impl<'a> PhysicalExprSimplifier<'a> {
41    /// Create a new physical expression simplifier
42    pub fn new(schema: &'a Schema) -> Self {
43        Self { schema }
44    }
45
46    /// Simplify a physical expression
47    pub fn simplify(
48        &mut self,
49        expr: Arc<dyn PhysicalExpr>,
50    ) -> Result<Arc<dyn PhysicalExpr>> {
51        return expr
52            .transform_up_with_schema(self.schema, |node, schema| {
53                // Apply unwrap cast optimization
54                #[cfg(test)]
55                let original_type = node.data_type(schema).unwrap();
56                let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?;
57                #[cfg(test)]
58                assert_eq!(
59                    unwrapped.data.data_type(schema).unwrap(),
60                    original_type,
61                    "Simplified expression should have the same data type as the original"
62            );
63                Ok(unwrapped)
64            })
65            .data();
66
67        Ok(expr.rewrite(self)?.data)
68    }
69}
70
71impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
72    type Node = Arc<dyn PhysicalExpr>;
73
74    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
75        // Apply unwrap cast optimization
76        #[cfg(test)]
77        let original_type = node.data_type(self.schema).unwrap();
78        let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?;
79        #[cfg(test)]
80        assert_eq!(
81            unwrapped.data.data_type(self.schema).unwrap(),
82            original_type,
83            "Simplified expression should have the same data type as the original"
84        );
85        Ok(unwrapped)
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::ScalarValue;
    use datafusion_expr::Operator;

    /// Schema with three non-nullable columns:
    /// `c1: Int32`, `c2: Int64` and `c3: Utf8`.
    fn test_schema() -> Schema {
        let fields = vec![
            Field::new("c1", DataType::Int32, false),
            Field::new("c2", DataType::Int64, false),
            Field::new("c3", DataType::Utf8, false),
        ];
        Schema::new(fields)
    }

    /// Builds the expression `cast(<column> as <cast_to>) <op> <literal>`.
    fn cast_comparison(
        schema: &Schema,
        column: &str,
        cast_to: DataType,
        op: Operator,
        literal: ScalarValue,
    ) -> Arc<dyn PhysicalExpr> {
        let column_expr = col(column, schema).unwrap();
        let cast_expr = Arc::new(CastExpr::new(column_expr, cast_to, None));
        Arc::new(BinaryExpr::new(cast_expr, op, lit(literal)))
    }

    /// Asserts that `expr` is a binary expression whose left operand is no
    /// longer a (try-)cast and whose right operand is the literal `expected`.
    fn assert_cast_unwrapped(expr: &Arc<dyn PhysicalExpr>, expected: ScalarValue) {
        let binary = expr.as_any().downcast_ref::<BinaryExpr>().unwrap();

        let lhs = binary.left().as_any();
        assert!(
            lhs.downcast_ref::<CastExpr>().is_none()
                && lhs.downcast_ref::<TryCastExpr>().is_none()
        );

        let rhs = binary
            .right()
            .as_any()
            .downcast_ref::<Literal>()
            .unwrap();
        assert_eq!(rhs.value(), &expected);
    }

    #[test]
    fn test_simplify() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // Input: cast(c2 as INT32) != INT32(99)
        let input = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::NotEq,
            ScalarValue::Int32(Some(99)),
        );

        // Run the full simplification
        let optimized = simplifier.simplify(input).unwrap();

        // Expected: c2 != INT64(99) (c2 is INT64, literal cast to match)
        assert_cast_unwrapped(&optimized, ScalarValue::Int64(Some(99)));
    }

    #[test]
    fn test_nested_expression_simplification() {
        let schema = test_schema();
        let mut simplifier = PhysicalExprSimplifier::new(&schema);

        // Input: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10))
        let c1_cmp = cast_comparison(
            &schema,
            "c1",
            DataType::Int64,
            Operator::Gt,
            ScalarValue::Int64(Some(5)),
        );
        let c2_cmp = cast_comparison(
            &schema,
            "c2",
            DataType::Int32,
            Operator::LtEq,
            ScalarValue::Int32(Some(10)),
        );
        let or_expr = Arc::new(BinaryExpr::new(c1_cmp, Operator::Or, c2_cmp));

        // Run the simplification
        let optimized = simplifier.simplify(or_expr).unwrap();
        let or_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();

        // Left branch should now be: c1 > INT32(5)
        assert_cast_unwrapped(or_binary.left(), ScalarValue::Int32(Some(5)));

        // Right branch should now be: c2 <= INT64(10)
        assert_cast_unwrapped(or_binary.right(), ScalarValue::Int64(Some(10)));
    }
}