datafusion_expr/expr_rewriter/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression rewriter
19
20use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::TableReference;
32use datafusion_common::{Column, DFSchema, Result};
33
34mod order_by;
35pub use order_by::rewrite_sort_cols_by_aggs;
36
/// Trait for rewriting [`Expr`]s into function calls.
///
/// This trait is used with `FunctionRegistry::register_function_rewrite` to
/// evaluate `Expr`s using functions that may not be built in to DataFusion.
///
/// For example, concatenating arrays `a || b` is represented as
/// `Operator::ArrowAt`, but can be implemented by calling a function
/// `array_concat` from the `functions-nested` crate.
// This is not used in DataFusion internally, but it is still helpful for
// downstream projects, so do not remove it.
pub trait FunctionRewrite: Debug {
    /// Return a human readable name for this rewrite
    fn name(&self) -> &str;

    /// Potentially rewrite `expr` to some other expression
    ///
    /// Note that recursion is handled by the caller -- this method should only
    /// handle `expr`, not recurse to its children.
    ///
    /// `schema` and `config` are supplied by the caller for the implementation
    /// to consult. The resulting expression is returned wrapped in
    /// [`Transformed`].
    fn rewrite(
        &self,
        expr: Expr,
        schema: &DFSchema,
        config: &ConfigOptions,
    ) -> Result<Transformed<Expr>>;
}
61
62/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions
63/// in the `expr` expression tree.
64pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65    expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
66        Ok({
67            if let Expr::Column(c) = expr {
68                if c.relation.is_some() || !lambdas_params.contains(c.name()) {
69                    let col = LogicalPlanBuilder::normalize(plan, c)?;
70                    Transformed::yes(Expr::Column(col))
71                } else {
72                    Transformed::no(Expr::Column(c))
73                }
74            } else {
75                Transformed::no(expr)
76            }
77        })
78    })
79    .data()
80}
81
82/// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage
83pub fn normalize_col_with_schemas_and_ambiguity_check(
84    expr: Expr,
85    schemas: &[&[&DFSchema]],
86    using_columns: &[HashSet<Column>],
87) -> Result<Expr> {
88    // Normalize column inside Unnest
89    if let Expr::Unnest(Unnest { expr }) = expr {
90        let e = normalize_col_with_schemas_and_ambiguity_check(
91            expr.as_ref().clone(),
92            schemas,
93            using_columns,
94        )?;
95        return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
96    }
97
98    expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
99        Ok({
100            match expr {
101                Expr::Column(c) => {
102                    if c.relation.is_none() && lambdas_params.contains(c.name()) {
103                        Transformed::no(Expr::Column(c))
104                    } else {
105                        let col = c.normalize_with_schemas_and_ambiguity_check(
106                            schemas,
107                            using_columns,
108                        )?;
109                        Transformed::yes(Expr::Column(col))
110                    }
111                }
112                _ => Transformed::no(expr),
113            }
114        })
115    })
116    .data()
117}
118
119/// Recursively normalize all [`Column`] expressions in a list of expression trees
120pub fn normalize_cols(
121    exprs: impl IntoIterator<Item = impl Into<Expr>>,
122    plan: &LogicalPlan,
123) -> Result<Vec<Expr>> {
124    exprs
125        .into_iter()
126        .map(|e| normalize_col(e.into(), plan))
127        .collect()
128}
129
130pub fn normalize_sorts(
131    sorts: impl IntoIterator<Item = impl Into<Sort>>,
132    plan: &LogicalPlan,
133) -> Result<Vec<Sort>> {
134    sorts
135        .into_iter()
136        .map(|e| {
137            let sort = e.into();
138            normalize_col(sort.expr, plan)
139                .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
140        })
141        .collect()
142}
143
144/// Recursively replace all [`Column`] expressions in a given expression tree with
145/// `Column` expressions provided by the hash map argument.
146pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
147    expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
148        Ok({
149            match &expr {
150                Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
151                    match replace_map.get(c) {
152                        Some(new_c) => {
153                            Transformed::yes(Expr::Column((*new_c).to_owned()))
154                        }
155                        None => Transformed::no(expr),
156                    }
157                }
158                _ => Transformed::no(expr),
159            }
160        })
161    })
162    .data()
163}
164
165/// Recursively 'unnormalize' (remove all qualifiers) from an
166/// expression tree.
167///
168/// For example, if there were expressions like `foo.bar` this would
169/// rewrite it to just `bar`.
170pub fn unnormalize_col(expr: Expr) -> Expr {
171    expr.transform(|expr| {
172        Ok({
173            if let Expr::Column(c) = expr {
174                let col = Column::new_unqualified(c.name);
175                Transformed::yes(Expr::Column(col))
176            } else {
177                Transformed::no(expr)
178            }
179        })
180    })
181    .data()
182    .expect("Unnormalize is infallible")
183}
184
185/// Create a Column from the Scalar Expr
186pub fn create_col_from_scalar_expr(
187    scalar_expr: &Expr,
188    subqry_alias: String,
189) -> Result<Column> {
190    match scalar_expr {
191        Expr::Alias(Alias { name, .. }) => Ok(Column::new(
192            Some::<TableReference>(subqry_alias.into()),
193            name,
194        )),
195        Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
196        _ => {
197            let scalar_column = scalar_expr.schema_name().to_string();
198            Ok(Column::new(
199                Some::<TableReference>(subqry_alias.into()),
200                scalar_column,
201            ))
202        }
203    }
204}
205
206/// Recursively un-normalize all [`Column`] expressions in a list of expression trees
207#[inline]
208pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
209    exprs.into_iter().map(unnormalize_col).collect()
210}
211
212/// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column
213/// in the expression tree.
214pub fn strip_outer_reference(expr: Expr) -> Expr {
215    expr.transform(|expr| {
216        Ok({
217            if let Expr::OuterReferenceColumn(_, col) = expr {
218                //todo: what if this col collides with a lambda parameter?
219                Transformed::yes(Expr::Column(col))
220            } else {
221                Transformed::no(expr)
222            }
223        })
224    })
225    .data()
226    .expect("strip_outer_reference is infallible")
227}
228
229/// Returns plan with expressions coerced to types compatible with
230/// schema types
231pub fn coerce_plan_expr_for_schema(
232    plan: LogicalPlan,
233    schema: &DFSchema,
234) -> Result<LogicalPlan> {
235    match plan {
236        // special case Projection to avoid adding multiple projections
237        LogicalPlan::Projection(Projection { expr, input, .. }) => {
238            let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
239            let projection = Projection::try_new(new_exprs, input)?;
240            Ok(LogicalPlan::Projection(projection))
241        }
242        _ => {
243            let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
244            let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
245            let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
246            if add_project {
247                let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
248                Ok(LogicalPlan::Projection(projection))
249            } else {
250                Ok(plan)
251            }
252        }
253    }
254}
255
256fn coerce_exprs_for_schema(
257    exprs: Vec<Expr>,
258    src_schema: &DFSchema,
259    dst_schema: &DFSchema,
260) -> Result<Vec<Expr>> {
261    exprs
262        .into_iter()
263        .enumerate()
264        .map(|(idx, expr)| {
265            let new_type = dst_schema.field(idx).data_type();
266            if new_type != &expr.get_type(src_schema)? {
267                match expr {
268                    Expr::Alias(Alias { expr, name, .. }) => {
269                        Ok(expr.cast_to(new_type, src_schema)?.alias(name))
270                    }
271                    #[expect(deprecated)]
272                    Expr::Wildcard { .. } => Ok(expr),
273                    _ => expr.cast_to(new_type, src_schema),
274                }
275            } else {
276                Ok(expr)
277            }
278        })
279        .collect::<Result<_>>()
280}
281
282/// Recursively un-alias an expressions
283#[inline]
284pub fn unalias(expr: Expr) -> Expr {
285    match expr {
286        Expr::Alias(Alias { expr, .. }) => unalias(*expr),
287        _ => expr,
288    }
289}
290
/// Handles ensuring the name of rewritten expressions is not changed.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes don't change after optimization.
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/datafusion/issues/3555> for details
#[derive(Debug)]
pub struct NamePreserver {
    /// When `true`, `save` records the expression's qualified name so it can
    /// be restored later; when `false`, `save` records nothing.
    use_alias: bool,
}
302
/// If the qualified name of an expression is remembered, it will be preserved
/// when rewriting the expression
///
/// Values are produced by `NamePreserver::save` and consumed by
/// `SavedName::restore`.
#[derive(Debug)]
pub enum SavedName {
    /// Saved qualified name to be preserved
    Saved {
        /// Optional table qualifier of the saved name
        relation: Option<TableReference>,
        /// Unqualified part of the saved name
        name: String,
    },
    /// Name is not preserved
    None,
}
315
316impl NamePreserver {
317    /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
318    pub fn new(plan: &LogicalPlan) -> Self {
319        Self {
320            // The expressions of these plans do not contribute to their output schema,
321            // so there is no need to preserve expression names to prevent a schema change.
322            use_alias: !matches!(
323                plan,
324                LogicalPlan::Filter(_)
325                    | LogicalPlan::Join(_)
326                    | LogicalPlan::TableScan(_)
327                    | LogicalPlan::Limit(_)
328                    | LogicalPlan::Statement(_)
329            ),
330        }
331    }
332
333    /// Create a new NamePreserver for rewriting the `expr`s in `Projection`
334    ///
335    /// This will use aliases
336    pub fn new_for_projection() -> Self {
337        Self { use_alias: true }
338    }
339
340    pub fn save(&self, expr: &Expr) -> SavedName {
341        if self.use_alias {
342            let (relation, name) = expr.qualified_name();
343            SavedName::Saved { relation, name }
344        } else {
345            SavedName::None
346        }
347    }
348}
349
350impl SavedName {
351    /// Ensures the qualified name of the rewritten expression is preserved
352    pub fn restore(self, expr: Expr) -> Expr {
353        match self {
354            SavedName::Saved { relation, name } => {
355                let (new_relation, new_name) = expr.qualified_name();
356                if new_relation != relation || new_name != name {
357                    expr.alias_qualified(relation, name)
358                } else {
359                    expr
360                }
361            }
362            SavedName::None => expr,
363        }
364    }
365}
366
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::literal::lit_with_metadata;
    use crate::{col, lit, Cast};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::ScalarValue;

    /// Rewriter that changes nothing and records, in order, each node it
    /// sees in `f_down` ("Previsited ...") and `f_up` ("Mutated ...").
    #[derive(Default)]
    struct RecordingRewriter {
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Previsited {expr}"));
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            self.v.push(format!("Mutated {expr}"));
            Ok(Transformed::no(expr))
        }
    }

    /// A closure passed to `transform` can rewrite matching literals while
    /// leaving the rest of the expression intact.
    #[test]
    fn rewriter_rewrite() {
        // rewrites all "foo" string literals to "bar"
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => {
                    let utf8_val = if utf8_val == "foo" {
                        "bar".to_string()
                    } else {
                        utf8_val
                    };
                    Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata)))
                }
                // otherwise, leave the expression unchanged
                _ => Ok(Transformed::no(expr)),
            }
        };

        // rewrites "foo" --> "bar"
        let rewritten = col("state")
            .eq(lit("foo"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("bar")));

        // doesn't rewrite
        let rewritten = col("state")
            .eq(lit("baz"))
            .transform(transformer)
            .data()
            .unwrap();
        assert_eq!(rewritten, col("state").eq(lit("baz")));
    }

    /// Unqualified columns are resolved to the qualifier of the schema that
    /// contains them, regardless of the order the schemas are given in.
    #[test]
    fn normalize_cols() {
        let expr = col("a") + col("b") + col("c");

        // Schemas with some matching and some non matching cols
        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        // non matching
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let schemas = [schema_c, schema_f, schema_b, schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let normalized_expr =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap();
        assert_eq!(
            normalized_expr,
            col("tableA.a") + col("tableB.b") + col("tableC.c")
        );
    }

    /// Normalizing a column that no schema contains yields a schema error
    /// listing the valid fields.
    #[test]
    fn normalize_cols_non_exist() {
        // test normalizing columns when the name doesn't exist
        let expr = col("a") + col("b");
        let schema_a =
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
        let schemas = [schema_a];
        let schemas = schemas.iter().collect::<Vec<_>>();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap_err()
                .strip_backtrace();
        let expected = "Schema error: No field named b. \
            Valid fields are \"tableA\".a.";
        assert_eq!(error, expected);
    }

    /// `unnormalize_col` removes the table qualifier from every column.
    #[test]
    fn unnormalize_cols() {
        let expr = col("tableA.a") + col("tableB.b");
        let unnormalized_expr = unnormalize_col(expr);
        assert_eq!(unnormalized_expr, col("a") + col("b"));
    }

    /// Builds a `DFSchema` of non-nullable `Int8` fields named `fields`, with
    /// `qualifiers` passed as the per-field qualifiers.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let fields = fields
            .iter()
            .map(|f| Arc::new(Field::new((*f).to_string(), DataType::Int8, false)))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
    }

    /// Checks the order in which `rewrite` invokes `f_down` and `f_up` on the
    /// nodes of `state = 'CO'`.
    #[test]
    fn rewriter_visit() {
        let mut rewriter = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();

        assert_eq!(
            rewriter.v,
            vec![
                "Previsited state = Utf8(\"CO\")",
                "Previsited state",
                "Mutated state",
                "Previsited Utf8(\"CO\")",
                "Mutated Utf8(\"CO\")",
                "Mutated state = Utf8(\"CO\")"
            ]
        )
    }

    /// Exercises `NamePreserver` across several rewrites; each case must keep
    /// the original qualified name.
    #[test]
    fn test_rewrite_preserving_name() {
        test_rewrite(col("a"), col("a"));

        test_rewrite(col("a"), col("b"));

        // cast data types
        test_rewrite(
            col("a"),
            Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
        );

        // change literal type from i32 to i64
        test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

        // test preserve qualifier
        test_rewrite(
            Expr::Column(Column::new(Some("test"), "a")),
            Expr::Column(Column::new_unqualified("test.a")),
        );
        test_rewrite(
            Expr::Column(Column::new_unqualified("test.a")),
            Expr::Column(Column::new(Some("test"), "a")),
        );
    }

    /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
    /// by using the `NamePreserver`
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        // Rewriter whose `f_up` replaces every node it is given with
        // `rewrite_to`.
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                Ok(Transformed::yes(self.rewrite_to.clone()))
            }
        }

        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        // Save the name before rewriting, then restore it on the result.
        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
        let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let new_expr = saved_name.restore(new_expr);

        let original_name = expr_from.qualified_name();
        let new_name = new_expr.qualified_name();
        assert_eq!(
            original_name, new_name,
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}