// datafusion_sql/unparser/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::{cmp::Ordering, sync::Arc, vec};
19
20use super::{
21    dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
22    rewrite::TableAliasRewriter, Unparser,
23};
24use datafusion_common::{
25    internal_err,
26    tree_node::{Transformed, TransformedResult, TreeNode},
27    Column, DataFusionError, Result, ScalarValue,
28};
29use datafusion_expr::{
30    expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
31    LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
32};
33
34use indexmap::IndexSet;
35use sqlparser::ast;
36use sqlparser::tokenizer::Span;
37
38/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
39/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
40/// If an Aggregate or node is not found prior to this or at all before reaching the end
41/// of the tree, None is returned.
42pub(crate) fn find_agg_node_within_select(
43    plan: &LogicalPlan,
44    already_projected: bool,
45) -> Option<&Aggregate> {
46    // Note that none of the nodes that have a corresponding node can have more
47    // than 1 input node. E.g. Projection / Filter always have 1 input node.
48    let input = plan.inputs();
49    let input = if input.len() > 1 {
50        return None;
51    } else {
52        input.first()?
53    };
54    // Agg nodes explicitly return immediately with a single node
55    if let LogicalPlan::Aggregate(agg) = input {
56        Some(agg)
57    } else if let LogicalPlan::TableScan(_) = input {
58        None
59    } else if let LogicalPlan::Projection(_) = input {
60        if already_projected {
61            None
62        } else {
63            find_agg_node_within_select(input, true)
64        }
65    } else {
66        find_agg_node_within_select(input, already_projected)
67    }
68}
69
70/// Recursively searches children of [LogicalPlan] to find Unnest node if exist
71pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> {
72    // Note that none of the nodes that have a corresponding node can have more
73    // than 1 input node. E.g. Projection / Filter always have 1 input node.
74    let input = plan.inputs();
75    let input = if input.len() > 1 {
76        return None;
77    } else {
78        input.first()?
79    };
80
81    if let LogicalPlan::Unnest(unnest) = input {
82        Some(unnest)
83    } else if let LogicalPlan::TableScan(_) = input {
84        None
85    } else if let LogicalPlan::Projection(_) = input {
86        None
87    } else {
88        find_unnest_node_within_select(input)
89    }
90}
91
92/// Recursively searches children of [LogicalPlan] to find Unnest node if exist
93/// until encountering a Relation node with single input
94pub(crate) fn find_unnest_node_until_relation(plan: &LogicalPlan) -> Option<&Unnest> {
95    // Note that none of the nodes that have a corresponding node can have more
96    // than 1 input node. E.g. Projection / Filter always have 1 input node.
97    let input = plan.inputs();
98    let input = if input.len() > 1 {
99        return None;
100    } else {
101        input.first()?
102    };
103
104    if let LogicalPlan::Unnest(unnest) = input {
105        Some(unnest)
106    } else if let LogicalPlan::TableScan(_) = input {
107        None
108    } else if let LogicalPlan::Subquery(_) = input {
109        None
110    } else if let LogicalPlan::SubqueryAlias(_) = input {
111        None
112    } else {
113        find_unnest_node_within_select(input)
114    }
115}
116
117/// Recursively searches children of [LogicalPlan] to find Window nodes if exist
118/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
119/// If Window node is not found prior to this or at all before reaching the end
120/// of the tree, None is returned.
121pub(crate) fn find_window_nodes_within_select<'a>(
122    plan: &'a LogicalPlan,
123    mut prev_windows: Option<Vec<&'a Window>>,
124    already_projected: bool,
125) -> Option<Vec<&'a Window>> {
126    // Note that none of the nodes that have a corresponding node can have more
127    // than 1 input node. E.g. Projection / Filter always have 1 input node.
128    let input = plan.inputs();
129    let input = if input.len() > 1 {
130        return prev_windows;
131    } else {
132        input.first()?
133    };
134
135    // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection
136    match input {
137        LogicalPlan::Window(window) => {
138            prev_windows = match &mut prev_windows {
139                Some(windows) => {
140                    windows.push(window);
141                    prev_windows
142                }
143                _ => Some(vec![window]),
144            };
145            find_window_nodes_within_select(input, prev_windows, already_projected)
146        }
147        LogicalPlan::Projection(_) => {
148            if already_projected {
149                prev_windows
150            } else {
151                find_window_nodes_within_select(input, prev_windows, true)
152            }
153        }
154        LogicalPlan::TableScan(_) => prev_windows,
155        _ => find_window_nodes_within_select(input, prev_windows, already_projected),
156    }
157}
158
159/// Recursively identify Column expressions and transform them into the appropriate unnest expression
160///
161/// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)"
162/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL])
163pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
164    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
165            if let Expr::Column(col_ref) = &sub_expr {
166                // Check if the column is among the columns to run unnest on. 
167                // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. 
168                if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) {
169                    if let Ok(idx) = unnest.schema.index_of_column(col_ref) {
170                        if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() {
171                            if let Some(unprojected_expr) = expr.get(idx) {
172                                let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone()));
173                                return Ok(Transformed::yes(unnest_expr));
174                            }
175                        }
176                    }
177                    return internal_err!(
178                        "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name
179                    );
180                }
181            }
182
183            Ok(Transformed::no(sub_expr))
184
185        }).map(|e| e.data)
186}
187
188/// Recursively identify all Column expressions and transform them into the appropriate
189/// aggregate expression contained in agg.
190///
191/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
192/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
193pub(crate) fn unproject_agg_exprs(
194    expr: Expr,
195    agg: &Aggregate,
196    windows: Option<&[&Window]>,
197) -> Result<Expr> {
198    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
199            match sub_expr {
200                Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
201                                        Ok(Transformed::yes(unprojected_expr.clone()))
202                                    } else if let Some(unprojected_expr) =
203                                        windows.and_then(|w| find_window_expr(w, &c.name).cloned())
204                                    {
205                                        // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
206                                        Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?))
207                                    } else {
208                                        internal_err!(
209                                            "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
210                                        )
211                                    },
212                _ => Ok(Transformed::no(sub_expr)),
213            }
214        })
215        .map(|e| e.data)
216}
217
218/// Recursively identify all Column expressions and transform them into the appropriate
219/// window expression contained in window.
220///
221/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
222/// into an actual window expression as identified in the window node.
223pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
224    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr {
225        Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
226            if let Some(unproj) = find_window_expr(windows, &c.name) {
227                Ok(Transformed::yes(unproj.clone()))
228            } else {
229                Ok(Transformed::no(Expr::Column(c)))
230            }
231        }
232        _ => Ok(Transformed::no(sub_expr)),
233    })
234    .map(|e| e.data)
235}
236
237fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
238    if let Ok(index) = agg.schema.index_of_column(column) {
239        if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
240            // For grouping set expr, we must operate by expression list from the grouping set
241            let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
242            match index.cmp(&grouping_expr.len()) {
243                Ordering::Less => Ok(grouping_expr.into_iter().nth(index)),
244                Ordering::Equal => {
245                    internal_err!(
246                        "Tried to unproject column referring to internal grouping id"
247                    )
248                }
249                Ordering::Greater => {
250                    Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1))
251                }
252            }
253        } else {
254            Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
255        }
256    } else {
257        Ok(None)
258    }
259}
260
261fn find_window_expr<'a>(
262    windows: &'a [&'a Window],
263    column_name: &'a str,
264) -> Option<&'a Expr> {
265    windows
266        .iter()
267        .flat_map(|w| w.window_expr.iter())
268        .find(|expr| expr.schema_name().to_string() == column_name)
269}
270
271/// Transforms all Column expressions in a sort expression into the actual expression from aggregation or projection if found.
272/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced
273/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to
274/// the actual expression, such as sum("catalog_returns"."cr_net_loss").
275pub(crate) fn unproject_sort_expr(
276    mut sort_expr: SortExpr,
277    agg: Option<&Aggregate>,
278    input: &LogicalPlan,
279) -> Result<SortExpr> {
280    sort_expr.expr = sort_expr
281        .expr
282        .transform(|sub_expr| {
283            match sub_expr {
284                // Remove alias if present, because ORDER BY cannot use aliases
285                Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)),
286                Expr::Column(col) => {
287                    if col.relation.is_some() {
288                        return Ok(Transformed::no(Expr::Column(col)));
289                    }
290
291                    // In case of aggregation there could be columns containing aggregation functions we need to unproject
292                    if let Some(agg) = agg {
293                        if agg.schema.is_column_from_schema(&col) {
294                            return Ok(Transformed::yes(unproject_agg_exprs(
295                                Expr::Column(col),
296                                agg,
297                                None,
298                            )?));
299                        }
300                    }
301
302                    // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will
303                    // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need
304                    // to transform it back to the actual expression.
305                    if let LogicalPlan::Projection(Projection { expr, schema, .. }) =
306                        input
307                    {
308                        if let Ok(idx) = schema.index_of_column(&col) {
309                            if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
310                                return Ok(Transformed::yes(Expr::ScalarFunction(
311                                    scalar_fn.clone(),
312                                )));
313                            }
314                        }
315                    }
316
317                    Ok(Transformed::no(Expr::Column(col)))
318                }
319                _ => Ok(Transformed::no(sub_expr)),
320            }
321        })
322        .map(|e| e.data)?;
323    Ok(sort_expr)
324}
325
326/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering
327/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan.
328/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately.
329/// If the plan contains a Projection, returns None.
330///
331/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias.
332///
333/// LogicalPlan example:
334///   Filter: ta.j1_id < 5
335///     Alias:  ta
336///       TableScan: j1, j1_id > 10
337///
338/// Will return LogicalPlan below:
339///     Alias:  ta
340///       TableScan: j1
341/// And filters: [ta.j1_id < 5, ta.j1_id > 10]
342pub(crate) fn try_transform_to_simple_table_scan_with_filters(
343    plan: &LogicalPlan,
344) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
345    let mut filters: IndexSet<Expr> = IndexSet::new();
346    let mut plan_stack = vec![plan];
347    let mut table_alias = None;
348
349    while let Some(current_plan) = plan_stack.pop() {
350        match current_plan {
351            LogicalPlan::SubqueryAlias(alias) => {
352                table_alias = Some(alias.alias.clone());
353                plan_stack.push(alias.input.as_ref());
354            }
355            LogicalPlan::Filter(filter) => {
356                if !filters.contains(&filter.predicate) {
357                    filters.insert(filter.predicate.clone());
358                }
359                plan_stack.push(filter.input.as_ref());
360            }
361            LogicalPlan::TableScan(table_scan) => {
362                let table_schema = table_scan.source.schema();
363                // optional rewriter if table has an alias
364                let mut filter_alias_rewriter =
365                    table_alias.as_ref().map(|alias_name| TableAliasRewriter {
366                        table_schema: &table_schema,
367                        alias_name: alias_name.clone(),
368                    });
369
370                // rewrite filters to use table alias if present
371                let table_scan_filters = table_scan
372                    .filters
373                    .iter()
374                    .cloned()
375                    .map(|expr| {
376                        if let Some(ref mut rewriter) = filter_alias_rewriter {
377                            expr.rewrite_with_lambdas_params(rewriter).data()
378                        } else {
379                            Ok(expr)
380                        }
381                    })
382                    .collect::<Result<Vec<_>, DataFusionError>>()?;
383
384                for table_scan_filter in table_scan_filters {
385                    if !filters.contains(&table_scan_filter) {
386                        filters.insert(table_scan_filter);
387                    }
388                }
389
390                let mut builder = LogicalPlanBuilder::scan(
391                    table_scan.table_name.clone(),
392                    Arc::clone(&table_scan.source),
393                    table_scan.projection.clone(),
394                )?;
395
396                if let Some(alias) = table_alias.take() {
397                    builder = builder.alias(alias)?;
398                }
399
400                let plan = builder.build()?;
401                let filters = filters.into_iter().collect();
402
403                return Ok(Some((plan, filters)));
404            }
405            _ => {
406                return Ok(None);
407            }
408        }
409    }
410
411    Ok(None)
412}
413
414/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
415pub(crate) fn date_part_to_sql(
416    unparser: &Unparser,
417    style: DateFieldExtractStyle,
418    date_part_args: &[Expr],
419) -> Result<Option<ast::Expr>> {
420    match (style, date_part_args.len()) {
421        (DateFieldExtractStyle::Extract, 2) => {
422            let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
423            if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
424                let field = match field.to_lowercase().as_str() {
425                    "year" => ast::DateTimeField::Year,
426                    "month" => ast::DateTimeField::Month,
427                    "day" => ast::DateTimeField::Day,
428                    "hour" => ast::DateTimeField::Hour,
429                    "minute" => ast::DateTimeField::Minute,
430                    "second" => ast::DateTimeField::Second,
431                    _ => return Ok(None),
432                };
433
434                return Ok(Some(ast::Expr::Extract {
435                    field,
436                    expr: Box::new(date_expr),
437                    syntax: ast::ExtractSyntax::From,
438                }));
439            }
440        }
441        (DateFieldExtractStyle::Strftime, 2) => {
442            let column = unparser.expr_to_sql(&date_part_args[1])?;
443
444            if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
445                let field = match field.to_lowercase().as_str() {
446                    "year" => "%Y",
447                    "month" => "%m",
448                    "day" => "%d",
449                    "hour" => "%H",
450                    "minute" => "%M",
451                    "second" => "%S",
452                    _ => return Ok(None),
453                };
454
455                return Ok(Some(ast::Expr::Function(ast::Function {
456                    name: ast::ObjectName::from(vec![ast::Ident {
457                        value: "strftime".to_string(),
458                        quote_style: None,
459                        span: Span::empty(),
460                    }]),
461                    args: ast::FunctionArguments::List(ast::FunctionArgumentList {
462                        duplicate_treatment: None,
463                        args: vec![
464                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
465                                ast::Expr::value(ast::Value::SingleQuotedString(
466                                    field.to_string(),
467                                )),
468                            )),
469                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
470                        ],
471                        clauses: vec![],
472                    }),
473                    filter: None,
474                    null_treatment: None,
475                    over: None,
476                    within_group: vec![],
477                    parameters: ast::FunctionArguments::None,
478                    uses_odbc_syntax: false,
479                })));
480            }
481        }
482        (DateFieldExtractStyle::DatePart, _) => {
483            return Ok(Some(
484                unparser.scalar_function_to_sql("date_part", date_part_args)?,
485            ));
486        }
487        _ => {}
488    };
489
490    Ok(None)
491}
492
493pub(crate) fn character_length_to_sql(
494    unparser: &Unparser,
495    style: CharacterLengthStyle,
496    character_length_args: &[Expr],
497) -> Result<Option<ast::Expr>> {
498    let func_name = match style {
499        CharacterLengthStyle::CharacterLength => "character_length",
500        CharacterLengthStyle::Length => "length",
501    };
502
503    Ok(Some(unparser.scalar_function_to_sql(
504        func_name,
505        character_length_args,
506    )?))
507}
508
509/// SQLite does not support timestamp/date scalars like `to_timestamp`, `from_unixtime`, `date_trunc`, etc.
510/// This remaps `from_unixtime` to `datetime(expr, 'unixepoch')`, expecting the input to be in seconds.
511/// It supports no other arguments, so if any are supplied it will return an error.
512///
513/// # Errors
514///
515/// - If the number of arguments is not 1 - the column or expression to convert.
516/// - If the scalar function cannot be converted to SQL.
517pub(crate) fn sqlite_from_unixtime_to_sql(
518    unparser: &Unparser,
519    from_unixtime_args: &[Expr],
520) -> Result<Option<ast::Expr>> {
521    if from_unixtime_args.len() != 1 {
522        return internal_err!(
523            "from_unixtime for SQLite expects 1 argument, found {}",
524            from_unixtime_args.len()
525        );
526    }
527
528    Ok(Some(unparser.scalar_function_to_sql(
529        "datetime",
530        &[
531            from_unixtime_args[0].clone(),
532            Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None),
533        ],
534    )?))
535}
536
537/// SQLite does not support timestamp/date scalars like `to_timestamp`, `from_unixtime`, `date_trunc`, etc.
538/// This uses the `strftime` function to format the timestamp as a string depending on the truncation unit.
539///
540/// # Errors
541///
542/// - If the number of arguments is not 2 - truncation unit and the column or expression to convert.
543/// - If the scalar function cannot be converted to SQL.
544pub(crate) fn sqlite_date_trunc_to_sql(
545    unparser: &Unparser,
546    date_trunc_args: &[Expr],
547) -> Result<Option<ast::Expr>> {
548    if date_trunc_args.len() != 2 {
549        return internal_err!(
550            "date_trunc for SQLite expects 2 arguments, found {}",
551            date_trunc_args.len()
552        );
553    }
554
555    if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] {
556        let format = match unit.to_lowercase().as_str() {
557            "year" => "%Y",
558            "month" => "%Y-%m",
559            "day" => "%Y-%m-%d",
560            "hour" => "%Y-%m-%d %H",
561            "minute" => "%Y-%m-%d %H:%M",
562            "second" => "%Y-%m-%d %H:%M:%S",
563            _ => return Ok(None),
564        };
565
566        return Ok(Some(unparser.scalar_function_to_sql(
567            "strftime",
568            &[
569                Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None),
570                date_trunc_args[1].clone(),
571            ],
572        )?));
573    }
574
575    Ok(None)
576}