datafusion_expr/expr_rewriter/
mod.rs1use std::collections::HashMap;
21use std::collections::HashSet;
22use std::fmt::Debug;
23use std::sync::Arc;
24
25use crate::expr::{Alias, Sort, Unnest};
26use crate::logical_plan::Projection;
27use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::TableReference;
32use datafusion_common::{Column, DFSchema, Result};
33
34mod order_by;
35pub use order_by::rewrite_sort_cols_by_aggs;
36
37pub trait FunctionRewrite: Debug {
47 fn name(&self) -> &str;
49
50 fn rewrite(
55 &self,
56 expr: Expr,
57 schema: &DFSchema,
58 config: &ConfigOptions,
59 ) -> Result<Transformed<Expr>>;
60}
61
62pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
65 expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
66 Ok({
67 if let Expr::Column(c) = expr {
68 if c.relation.is_some() || !lambdas_params.contains(c.name()) {
69 let col = LogicalPlanBuilder::normalize(plan, c)?;
70 Transformed::yes(Expr::Column(col))
71 } else {
72 Transformed::no(Expr::Column(c))
73 }
74 } else {
75 Transformed::no(expr)
76 }
77 })
78 })
79 .data()
80}
81
82pub fn normalize_col_with_schemas_and_ambiguity_check(
84 expr: Expr,
85 schemas: &[&[&DFSchema]],
86 using_columns: &[HashSet<Column>],
87) -> Result<Expr> {
88 if let Expr::Unnest(Unnest { expr }) = expr {
90 let e = normalize_col_with_schemas_and_ambiguity_check(
91 expr.as_ref().clone(),
92 schemas,
93 using_columns,
94 )?;
95 return Ok(Expr::Unnest(Unnest { expr: Box::new(e) }));
96 }
97
98 expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
99 Ok({
100 match expr {
101 Expr::Column(c) => {
102 if c.relation.is_none() && lambdas_params.contains(c.name()) {
103 Transformed::no(Expr::Column(c))
104 } else {
105 let col = c.normalize_with_schemas_and_ambiguity_check(
106 schemas,
107 using_columns,
108 )?;
109 Transformed::yes(Expr::Column(col))
110 }
111 }
112 _ => Transformed::no(expr),
113 }
114 })
115 })
116 .data()
117}
118
119pub fn normalize_cols(
121 exprs: impl IntoIterator<Item = impl Into<Expr>>,
122 plan: &LogicalPlan,
123) -> Result<Vec<Expr>> {
124 exprs
125 .into_iter()
126 .map(|e| normalize_col(e.into(), plan))
127 .collect()
128}
129
130pub fn normalize_sorts(
131 sorts: impl IntoIterator<Item = impl Into<Sort>>,
132 plan: &LogicalPlan,
133) -> Result<Vec<Sort>> {
134 sorts
135 .into_iter()
136 .map(|e| {
137 let sort = e.into();
138 normalize_col(sort.expr, plan)
139 .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first))
140 })
141 .collect()
142}
143
144pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
147 expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
148 Ok({
149 match &expr {
150 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
151 match replace_map.get(c) {
152 Some(new_c) => {
153 Transformed::yes(Expr::Column((*new_c).to_owned()))
154 }
155 None => Transformed::no(expr),
156 }
157 }
158 _ => Transformed::no(expr),
159 }
160 })
161 })
162 .data()
163}
164
165pub fn unnormalize_col(expr: Expr) -> Expr {
171 expr.transform(|expr| {
172 Ok({
173 if let Expr::Column(c) = expr {
174 let col = Column::new_unqualified(c.name);
175 Transformed::yes(Expr::Column(col))
176 } else {
177 Transformed::no(expr)
178 }
179 })
180 })
181 .data()
182 .expect("Unnormalize is infallible")
183}
184
185pub fn create_col_from_scalar_expr(
187 scalar_expr: &Expr,
188 subqry_alias: String,
189) -> Result<Column> {
190 match scalar_expr {
191 Expr::Alias(Alias { name, .. }) => Ok(Column::new(
192 Some::<TableReference>(subqry_alias.into()),
193 name,
194 )),
195 Expr::Column(col) => Ok(col.with_relation(subqry_alias.into())),
196 _ => {
197 let scalar_column = scalar_expr.schema_name().to_string();
198 Ok(Column::new(
199 Some::<TableReference>(subqry_alias.into()),
200 scalar_column,
201 ))
202 }
203 }
204}
205
206#[inline]
208pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
209 exprs.into_iter().map(unnormalize_col).collect()
210}
211
212pub fn strip_outer_reference(expr: Expr) -> Expr {
215 expr.transform(|expr| {
216 Ok({
217 if let Expr::OuterReferenceColumn(_, col) = expr {
218 Transformed::yes(Expr::Column(col))
220 } else {
221 Transformed::no(expr)
222 }
223 })
224 })
225 .data()
226 .expect("strip_outer_reference is infallible")
227}
228
229pub fn coerce_plan_expr_for_schema(
232 plan: LogicalPlan,
233 schema: &DFSchema,
234) -> Result<LogicalPlan> {
235 match plan {
236 LogicalPlan::Projection(Projection { expr, input, .. }) => {
238 let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?;
239 let projection = Projection::try_new(new_exprs, input)?;
240 Ok(LogicalPlan::Projection(projection))
241 }
242 _ => {
243 let exprs: Vec<Expr> = plan.schema().iter().map(Expr::from).collect();
244 let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?;
245 let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none());
246 if add_project {
247 let projection = Projection::try_new(new_exprs, Arc::new(plan))?;
248 Ok(LogicalPlan::Projection(projection))
249 } else {
250 Ok(plan)
251 }
252 }
253 }
254}
255
256fn coerce_exprs_for_schema(
257 exprs: Vec<Expr>,
258 src_schema: &DFSchema,
259 dst_schema: &DFSchema,
260) -> Result<Vec<Expr>> {
261 exprs
262 .into_iter()
263 .enumerate()
264 .map(|(idx, expr)| {
265 let new_type = dst_schema.field(idx).data_type();
266 if new_type != &expr.get_type(src_schema)? {
267 match expr {
268 Expr::Alias(Alias { expr, name, .. }) => {
269 Ok(expr.cast_to(new_type, src_schema)?.alias(name))
270 }
271 #[expect(deprecated)]
272 Expr::Wildcard { .. } => Ok(expr),
273 _ => expr.cast_to(new_type, src_schema),
274 }
275 } else {
276 Ok(expr)
277 }
278 })
279 .collect::<Result<_>>()
280}
281
282#[inline]
284pub fn unalias(expr: Expr) -> Expr {
285 match expr {
286 Expr::Alias(Alias { expr, .. }) => unalias(*expr),
287 _ => expr,
288 }
289}
290
/// Preserves the qualified name of an expression across a rewrite.
///
/// Call `save` before rewriting an expression and `SavedName::restore` on the
/// result. If the rewrite changed the qualified name (e.g. `1 + 2` folded to
/// `3`), the result is aliased back to the original name.
///
/// `Debug` is derived so this public type can be inspected and embedded in
/// other `Debug` types, consistent with `SavedName`.
#[derive(Debug)]
pub struct NamePreserver {
    /// When `false`, `save` records nothing and restoring is a no-op.
    use_alias: bool,
}
302
303#[derive(Debug)]
306pub enum SavedName {
307 Saved {
309 relation: Option<TableReference>,
310 name: String,
311 },
312 None,
314}
315
316impl NamePreserver {
317 pub fn new(plan: &LogicalPlan) -> Self {
319 Self {
320 use_alias: !matches!(
323 plan,
324 LogicalPlan::Filter(_)
325 | LogicalPlan::Join(_)
326 | LogicalPlan::TableScan(_)
327 | LogicalPlan::Limit(_)
328 | LogicalPlan::Statement(_)
329 ),
330 }
331 }
332
333 pub fn new_for_projection() -> Self {
337 Self { use_alias: true }
338 }
339
340 pub fn save(&self, expr: &Expr) -> SavedName {
341 if self.use_alias {
342 let (relation, name) = expr.qualified_name();
343 SavedName::Saved { relation, name }
344 } else {
345 SavedName::None
346 }
347 }
348}
349
350impl SavedName {
351 pub fn restore(self, expr: Expr) -> Expr {
353 match self {
354 SavedName::Saved { relation, name } => {
355 let (new_relation, new_name) = expr.qualified_name();
356 if new_relation != relation || new_name != name {
357 expr.alias_qualified(relation, name)
358 } else {
359 expr
360 }
361 }
362 SavedName::None => expr,
363 }
364 }
365}
366
#[cfg(test)]
mod test {
    use std::ops::Add;

    use super::*;
    use crate::literal::lit_with_metadata;
    use crate::{col, lit, Cast};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::tree_node::TreeNodeRewriter;
    use datafusion_common::ScalarValue;

    /// Rewriter that leaves every node unchanged but records each pre-order
    /// (`f_down`) and post-order (`f_up`) visit, in order.
    #[derive(Default)]
    struct RecordingRewriter {
        v: Vec<String>,
    }

    impl TreeNodeRewriter for RecordingRewriter {
        type Node = Expr;

        fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            let entry = format!("Previsited {expr}");
            self.v.push(entry);
            Ok(Transformed::no(expr))
        }

        fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
            let entry = format!("Mutated {expr}");
            self.v.push(entry);
            Ok(Transformed::no(expr))
        }
    }

    #[test]
    fn rewriter_rewrite() {
        // Rewrites the Utf8 literal "foo" to "bar" (keeping the literal's
        // metadata); other Utf8 literals are rebuilt with the same value and
        // every other expression is left alone.
        let transformer = |expr: Expr| -> Result<Transformed<Expr>> {
            match expr {
                Expr::Literal(ScalarValue::Utf8(Some(text)), metadata) => {
                    let text = if text == "foo" {
                        String::from("bar")
                    } else {
                        text
                    };
                    Ok(Transformed::yes(lit_with_metadata(text, metadata)))
                }
                other => Ok(Transformed::no(other)),
            }
        };

        // (input literal, expected literal after the rewrite)
        for (input, expected) in [("foo", "bar"), ("baz", "baz")] {
            let rewritten = col("state")
                .eq(lit(input))
                .transform(transformer)
                .data()
                .unwrap();
            assert_eq!(rewritten, col("state").eq(lit(expected)));
        }
    }

    #[test]
    fn normalize_cols() {
        let expr = col("a") + col("b") + col("c");

        let schema_a = make_schema_with_empty_metadata(
            vec![Some("tableA".into()), Some("tableA".into())],
            vec!["a", "aa"],
        );
        let schema_c = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["cc", "c"],
        );
        let schema_b =
            make_schema_with_empty_metadata(vec![Some("tableB".into())], vec!["b"]);
        // Shares the `tableC` qualifier but holds none of the referenced columns.
        let schema_f = make_schema_with_empty_metadata(
            vec![Some("tableC".into()), Some("tableC".into())],
            vec!["f", "ff"],
        );
        let owned_schemas = [schema_c, schema_f, schema_b, schema_a];
        let schemas: Vec<&DFSchema> = owned_schemas.iter().collect();

        let normalized =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap();
        let expected = col("tableA.a") + col("tableB.b") + col("tableC.c");
        assert_eq!(normalized, expected);
    }

    #[test]
    fn normalize_cols_non_exist() {
        let expr = col("a") + col("b");
        let schema_a =
            make_schema_with_empty_metadata(vec![Some("\"tableA\"".into())], vec!["a"]);
        let owned_schemas = [schema_a];
        let schemas: Vec<&DFSchema> = owned_schemas.iter().collect();

        let error =
            normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[])
                .unwrap_err()
                .strip_backtrace();
        assert_eq!(
            error,
            "Schema error: No field named b. Valid fields are \"tableA\".a."
        );
    }

    #[test]
    fn unnormalize_cols() {
        let qualified = col("tableA.a") + col("tableB.b");
        assert_eq!(unnormalize_col(qualified), col("a") + col("b"));
    }

    /// Builds a [`DFSchema`] with one non-nullable `Int8` field per entry of
    /// `fields`, paired with `qualifiers` via
    /// `DFSchema::from_field_specific_qualified_schema`.
    fn make_schema_with_empty_metadata(
        qualifiers: Vec<Option<TableReference>>,
        fields: Vec<&str>,
    ) -> DFSchema {
        let arrow_fields: Vec<Arc<Field>> = fields
            .into_iter()
            .map(|name| Arc::new(Field::new(name, DataType::Int8, false)))
            .collect();
        let schema = Arc::new(Schema::new(arrow_fields));
        DFSchema::from_field_specific_qualified_schema(qualifiers, &schema).unwrap()
    }

    #[test]
    fn rewriter_visit() {
        let mut rewriter = RecordingRewriter::default();
        col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();

        let expected = vec![
            "Previsited state = Utf8(\"CO\")",
            "Previsited state",
            "Mutated state",
            "Previsited Utf8(\"CO\")",
            "Mutated Utf8(\"CO\")",
            "Mutated state = Utf8(\"CO\")",
        ];
        assert_eq!(rewriter.v, expected);
    }

    #[test]
    fn test_rewrite_preserving_name() {
        let cases = vec![
            // Identity rewrite.
            (col("a"), col("a")),
            // Rewrite to a differently named column.
            (col("a"), col("b")),
            // Wrapping the column in a cast.
            (
                col("a"),
                Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)),
            ),
            // Literal type change inside a binary expression.
            (col("a").add(lit(1i32)), col("a").add(lit(1i64))),
            // Qualified column vs. unqualified column whose name contains a dot.
            (
                Expr::Column(Column::new(Some("test"), "a")),
                Expr::Column(Column::new_unqualified("test.a")),
            ),
            (
                Expr::Column(Column::new_unqualified("test.a")),
                Expr::Column(Column::new(Some("test"), "a")),
            ),
        ];
        for (expr_from, rewrite_to) in cases {
            test_rewrite(expr_from, rewrite_to);
        }
    }

    /// Rewrites `expr_from` into `rewrite_to`, restores the saved name, and
    /// asserts the qualified name is unchanged.
    fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
        /// Rewriter that replaces every node, bottom-up, with a fixed expression.
        struct TestRewriter {
            rewrite_to: Expr,
        }

        impl TreeNodeRewriter for TestRewriter {
            type Node = Expr;

            fn f_up(&mut self, _: Expr) -> Result<Transformed<Expr>> {
                Ok(Transformed::yes(self.rewrite_to.clone()))
            }
        }

        let saved_name = NamePreserver { use_alias: true }.save(&expr_from);
        let mut rewriter = TestRewriter {
            rewrite_to: rewrite_to.clone(),
        };
        let rewritten = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
        let restored = saved_name.restore(rewritten);

        assert_eq!(
            expr_from.qualified_name(),
            restored.qualified_name(),
            "mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
        )
    }
}
574}