datafusion_sql/unparser/rewrite.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{collections::HashSet, sync::Arc};
19
20use arrow::datatypes::Schema;
21use datafusion_common::tree_node::TreeNodeContainer;
22use datafusion_common::{
23 tree_node::{Transformed, TransformedResult, TreeNode},
24 Column, HashMap, Result, TableReference,
25};
26use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX};
27use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
28use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
29use sqlparser::ast::Ident;
30
31/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions.
32///
33/// DataFusion will return an error if two columns in the schema have the same name with no table qualifiers.
34/// There are certain types of UNION queries that can result in having two columns with the same name, and the
35/// solution was to add table qualifiers to the schema fields.
36/// See <https://github.com/apache/datafusion/issues/5410> for more context on this decision.
37///
38/// However, this causes a problem when unparsing these queries back to SQL - as the table qualifier has
39/// logically been erased and is no longer a valid reference.
40///
41/// The following input SQL:
42/// ```sql
43/// SELECT table1.foo FROM table1
44/// UNION ALL
45/// SELECT table2.foo FROM table2
46/// ORDER BY foo
47/// ```
48///
49/// Would be unparsed into the following invalid SQL without this transformation:
50/// ```sql
51/// SELECT table1.foo FROM table1
52/// UNION ALL
53/// SELECT table2.foo FROM table2
54/// ORDER BY table1.foo
55/// ```
56///
57/// Which would result in a SQL error, as `table1.foo` is not a valid reference in the context of the UNION.
58pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
59 let plan = plan.clone();
60
61 let transformed_plan = plan.transform_up(|plan| match plan {
62 LogicalPlan::Union(mut union) => {
63 let schema = Arc::unwrap_or_clone(union.schema);
64 let schema = schema.strip_qualifiers();
65
66 union.schema = Arc::new(schema);
67 Ok(Transformed::yes(LogicalPlan::Union(union)))
68 }
69 LogicalPlan::Sort(sort) => {
70 // Only rewrite Sort expressions that have a UNION as their input
71 if !matches!(&*sort.input, LogicalPlan::Union(_)) {
72 return Ok(Transformed::no(LogicalPlan::Sort(sort)));
73 }
74
75 Ok(Transformed::yes(LogicalPlan::Sort(Sort {
76 expr: rewrite_sort_expr_for_union(sort.expr)?,
77 input: sort.input,
78 fetch: sort.fetch,
79 })))
80 }
81 _ => Ok(Transformed::no(plan)),
82 });
83 transformed_plan.data()
84}
85
86/// Rewrite sort expressions that have a UNION plan as their input to remove the table reference.
87fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
88 let sort_exprs = exprs
89 .map_elements(&mut |expr: Expr| {
90 expr.transform_up(|expr| {
91 if let Expr::Column(mut col) = expr {
92 col.relation = None;
93 Ok(Transformed::yes(Expr::Column(col)))
94 } else {
95 Ok(Transformed::no(expr))
96 }
97 })
98 })
99 .data()?;
100
101 Ok(sort_exprs)
102}
103
104/// Rewrite Filter plans that have a Window as their input by inserting a SubqueryAlias.
105///
106/// When a Filter directly operates on a Window plan, it can cause issues during SQL unparsing
107/// because window functions in a WHERE clause are not valid SQL. The solution is to wrap
108/// the Window plan in a SubqueryAlias, effectively creating a derived table.
109///
110/// Example transformation:
111///
112/// Filter: condition
113/// Window: window_function
114/// TableScan: table
115///
116/// becomes:
117///
118/// Filter: condition
119/// SubqueryAlias: __qualify_subquery
120/// Projection: table.column1, table.column2
121/// Window: window_function
122/// TableScan: table
123///
124pub(super) fn rewrite_qualify(plan: LogicalPlan) -> Result<LogicalPlan> {
125 let transformed_plan = plan.transform_up(|plan| match plan {
126 // Check if the filter's input is a Window plan
127 LogicalPlan::Filter(mut filter) => {
128 if matches!(&*filter.input, LogicalPlan::Window(_)) {
129 // Create a SubqueryAlias around the Window plan
130 let qualifier = filter
131 .input
132 .schema()
133 .iter()
134 .find_map(|(q, _)| q)
135 .map(|q| q.to_string())
136 .unwrap_or_else(|| "__qualify_subquery".to_string());
137
138 // for Postgres, name of column for 'rank() over (...)' is 'rank'
139 // but in Datafusion, it is 'rank() over (...)'
140 // without projection, it's still an invalid sql in Postgres
141
142 let project_exprs = filter
143 .input
144 .schema()
145 .iter()
146 .map(|(_, f)| datafusion_expr::col(f.name()).alias(f.name()))
147 .collect::<Vec<_>>();
148
149 let input =
150 datafusion_expr::LogicalPlanBuilder::from(Arc::clone(&filter.input))
151 .project(project_exprs)?
152 .build()?;
153
154 let subquery_alias =
155 datafusion_expr::SubqueryAlias::try_new(Arc::new(input), qualifier)?;
156
157 filter.input = Arc::new(LogicalPlan::SubqueryAlias(subquery_alias));
158 Ok(Transformed::yes(LogicalPlan::Filter(filter)))
159 } else {
160 Ok(Transformed::no(LogicalPlan::Filter(filter)))
161 }
162 }
163
164 _ => Ok(Transformed::no(plan)),
165 });
166
167 transformed_plan.data()
168}
169
170/// Rewrite logic plan for query that order by columns are not in projections
171/// Plan before rewrite:
172///
173/// Projection: j1.j1_string, j2.j2_string
174/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
175/// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id
176/// Inner Join: Filter: j1.j1_id = j2.j2_id
177/// TableScan: j1
178/// TableScan: j2
179///
180/// Plan after rewrite
181///
182/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST
183/// Projection: j1.j1_string, j2.j2_string
184/// Inner Join: Filter: j1.j1_id = j2.j2_id
185/// TableScan: j1
186/// TableScan: j2
187///
188/// This prevents the original plan generate query with derived table but missing alias.
189pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
190 p: &Projection,
191) -> Option<LogicalPlan> {
192 let LogicalPlan::Sort(sort) = p.input.as_ref() else {
193 return None;
194 };
195
196 let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else {
197 return None;
198 };
199
200 let mut map = HashMap::new();
201 let inner_exprs = inner_p
202 .expr
203 .iter()
204 .enumerate()
205 .map(|(i, f)| match f {
206 Expr::Alias(alias) => {
207 let a = Expr::Column(alias.name.clone().into());
208 map.insert(a.clone(), f.clone());
209 a
210 }
211 Expr::Column(_) => {
212 map.insert(
213 Expr::Column(inner_p.schema.field(i).name().into()),
214 f.clone(),
215 );
216 f.clone()
217 }
218 _ => {
219 let a = Expr::Column(inner_p.schema.field(i).name().into());
220 map.insert(a.clone(), f.clone());
221 a
222 }
223 })
224 .collect::<Vec<_>>();
225
226 let mut collects = p.expr.clone();
227 for sort in &sort.expr {
228 collects.push(sort.expr.clone());
229 }
230
231 // Compare outer collects Expr::to_string with inner collected transformed values
232 // alias -> alias column
233 // column -> remain
234 // others, extract schema field name
235 let outer_collects = collects.iter().map(Expr::to_string).collect::<HashSet<_>>();
236 let inner_collects = inner_exprs
237 .iter()
238 .map(Expr::to_string)
239 .collect::<HashSet<_>>();
240
241 if outer_collects == inner_collects {
242 let mut sort = sort.clone();
243 let mut inner_p = inner_p.clone();
244
245 let new_exprs = p
246 .expr
247 .iter()
248 .map(|e| map.get(e).unwrap_or(e).clone())
249 .collect::<Vec<_>>();
250
251 inner_p.expr.clone_from(&new_exprs);
252 sort.input = Arc::new(LogicalPlan::Projection(inner_p));
253
254 Some(LogicalPlan::Sort(sort))
255 } else {
256 None
257 }
258}
259
260/// This logic is to work out the columns and inner query for SubqueryAlias plan for some types of
261/// subquery or unnest
262/// - `(SELECT column_a as a from table) AS A`
263/// - `(SELECT column_a from table) AS A (a)`
264/// - `SELECT * FROM t1 CROSS JOIN UNNEST(t1.c1) AS u(c1)` (see [find_unnest_column_alias])
265///
266/// A roundtrip example for table alias with columns
267///
268/// query: SELECT id FROM (SELECT j1_id from j1) AS c (id)
269///
270/// LogicPlan:
271/// Projection: c.id
272/// SubqueryAlias: c
273/// Projection: j1.j1_id AS id
274/// Projection: j1.j1_id
275/// TableScan: j1
276///
277/// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS
278/// id FROM (SELECT j1.j1_id FROM j1)) AS c`.
279/// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table
280/// `(SELECT j1.j1_id FROM j1)`
281///
282/// With this logic, the unparsed query will be:
283/// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)`
284///
285/// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)`
286/// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and
287/// Column in the Projections. Once the parser side is fixed, this logic should work
288pub(super) fn subquery_alias_inner_query_and_columns(
289 subquery_alias: &datafusion_expr::SubqueryAlias,
290) -> (&LogicalPlan, Vec<Ident>) {
291 let plan: &LogicalPlan = subquery_alias.input.as_ref();
292
293 if let LogicalPlan::Subquery(subquery) = plan {
294 let (inner_projection, Some(column)) =
295 find_unnest_column_alias(subquery.subquery.as_ref())
296 else {
297 return (plan, vec![]);
298 };
299 return (inner_projection, vec![Ident::new(column)]);
300 }
301
302 let LogicalPlan::Projection(outer_projections) = plan else {
303 return (plan, vec![]);
304 };
305
306 // Check if it's projection inside projection
307 let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else {
308 return (plan, vec![]);
309 };
310
311 let mut columns: Vec<Ident> = vec![];
312 // Check if the inner projection and outer projection have a matching pattern like
313 // Projection: j1.j1_id AS id
314 // Projection: j1.j1_id
315 for (i, inner_expr) in inner_projection.expr.iter().enumerate() {
316 let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else {
317 return (plan, vec![]);
318 };
319
320 // Inner projection schema fields store the projection name which is used in outer
321 // projection expr
322 let inner_expr_string = match inner_expr {
323 Expr::Column(_) => inner_expr.to_string(),
324 _ => inner_projection.schema.field(i).name().clone(),
325 };
326
327 if outer_alias.expr.to_string() != inner_expr_string {
328 return (plan, vec![]);
329 };
330
331 columns.push(outer_alias.name.as_str().into());
332 }
333
334 (outer_projections.input.as_ref(), columns)
335}
336
337/// Try to find the column alias for UNNEST in the inner projection.
338/// For example:
339/// ```sql
340/// SELECT * FROM t1 CROSS JOIN UNNEST(t1.c1) AS u(c1)
341/// ```
342/// The above query will be parsed into the following plan:
343/// ```text
344/// Projection: *
345/// Cross Join:
346/// SubqueryAlias: t1
347/// TableScan: t
348/// SubqueryAlias: u
349/// Subquery:
350/// Projection: UNNEST(outer_ref(t1.c1)) AS c1
351/// Projection: __unnest_placeholder(outer_ref(t1.c1),depth=1) AS UNNEST(outer_ref(t1.c1))
352/// Unnest: lists[__unnest_placeholder(outer_ref(t1.c1))|depth=1] structs[]
353/// Projection: outer_ref(t1.c1) AS __unnest_placeholder(outer_ref(t1.c1))
354/// EmptyRelation
355/// ```
356/// The function will return the inner projection and the column alias `c1` if the column name
357/// starts with `UNNEST(` (the `Display` result of [Expr::Unnest]) in the inner projection.
358pub(super) fn find_unnest_column_alias(
359 plan: &LogicalPlan,
360) -> (&LogicalPlan, Option<String>) {
361 if let LogicalPlan::Projection(projection) = plan {
362 if projection.expr.len() != 1 {
363 return (plan, None);
364 }
365 if let Some(Expr::Alias(alias)) = projection.expr.first() {
366 if alias
367 .expr
368 .schema_name()
369 .to_string()
370 .starts_with(&format!("{UNNEST_COLUMN_PREFIX}("))
371 {
372 return (projection.input.as_ref(), Some(alias.name.clone()));
373 }
374 }
375 }
376 (plan, None)
377}
378
379/// Injects column aliases into a subquery's logical plan. The function searches for a `Projection`
380/// within the given plan, which may be wrapped by other operators (e.g., LIMIT, SORT).
381/// If the top-level plan is a `Projection`, it directly injects the column aliases.
382/// Otherwise, it iterates through the plan's children to locate and transform the `Projection`.
383///
384/// Example:
385/// - `SELECT col1, col2 FROM table LIMIT 10` plan with aliases `["alias_1", "some_alias_2"]` will be transformed to
386/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table LIMIT 10`
387pub(super) fn inject_column_aliases_into_subquery(
388 plan: LogicalPlan,
389 aliases: Vec<Ident>,
390) -> Result<LogicalPlan> {
391 match &plan {
392 LogicalPlan::Projection(inner_p) => Ok(inject_column_aliases(inner_p, aliases)),
393 _ => {
394 // projection is wrapped by other operator (LIMIT, SORT, etc), iterate through the plan to find it
395 plan.map_children(|child| {
396 if let LogicalPlan::Projection(p) = &child {
397 Ok(Transformed::yes(inject_column_aliases(p, aliases.clone())))
398 } else {
399 Ok(Transformed::no(child))
400 }
401 })
402 .map(|plan| plan.data)
403 }
404 }
405}
406
407/// Injects column aliases into the projection of a logical plan by wrapping expressions
408/// with `Expr::Alias` using the provided list of aliases.
409///
410/// Example:
411/// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to
412/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table`
413pub(super) fn inject_column_aliases(
414 projection: &Projection,
415 aliases: impl IntoIterator<Item = Ident>,
416) -> LogicalPlan {
417 let mut updated_projection = projection.clone();
418
419 let new_exprs = updated_projection
420 .expr
421 .into_iter()
422 .zip(aliases)
423 .map(|(expr, col_alias)| {
424 let relation = match &expr {
425 Expr::Column(col) => col.relation.clone(),
426 _ => None,
427 };
428
429 Expr::Alias(Alias {
430 expr: Box::new(expr.clone()),
431 relation,
432 name: col_alias.value,
433 metadata: None,
434 })
435 })
436 .collect::<Vec<_>>();
437
438 updated_projection.expr = new_exprs;
439
440 LogicalPlan::Projection(updated_projection)
441}
442
443fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
444 match logical_plan {
445 LogicalPlan::Projection(p) => Some(p),
446 LogicalPlan::Limit(p) => find_projection(p.input.as_ref()),
447 LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()),
448 LogicalPlan::Sort(p) => find_projection(p.input.as_ref()),
449 _ => None,
450 }
451}
452
/// An expression rewriter that re-qualifies `Expr::Column` references with a table alias
/// when a field with the column's name exists in the provided schema.
///
/// This is typically used to apply table aliases in query plans, ensuring that
/// the column references in the expressions use the correct table alias.
///
/// The lookup uses the column's name only; any qualifier the column already carries is
/// ignored for the lookup and replaced by the alias on a match.
///
/// # Fields
///
/// * `table_schema`: The schema (`&Schema`) representing the table structure
///   from which the columns are referenced. This is used to look up columns by their names.
/// * `alias_name`: The alias (`TableReference`) that will replace the table qualifier
///   in the column references when applicable.
pub struct TableAliasRewriter<'a> {
    /// Schema of the aliased table; columns are matched against its field names.
    pub table_schema: &'a Schema,
    /// Qualifier given to every column that matches a field of `table_schema`.
    pub alias_name: TableReference,
}
469
470impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> {
471 type Node = Expr;
472 type Payload<'a> = &'a datafusion_common::HashSet<String>;
473
474 fn f_down(
475 &mut self,
476 expr: Expr,
477 lambdas_params: &datafusion_common::HashSet<String>,
478 ) -> Result<Transformed<Expr>> {
479 match expr {
480 Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => {
481 if let Ok(field) = self.table_schema.field_with_name(&column.name) {
482 let new_column =
483 Column::new(Some(self.alias_name.clone()), field.name().clone());
484 Ok(Transformed::yes(Expr::Column(new_column)))
485 } else {
486 Ok(Transformed::no(Expr::Column(column)))
487 }
488 }
489 _ => Ok(Transformed::no(expr)),
490 }
491 }
492}