datafusion_sql/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! SQL Utility Functions

use std::vec;

use arrow::datatypes::{
    DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
    exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic,
    HashMap, Result, ScalarValue,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{
    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
};
use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
    col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
};

use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};

/// Make a best-effort attempt at resolving all columns in the expression tree
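///
/// A rough illustration (`t`, `a`, and `b` are hypothetical); an unqualified
/// column reference is replaced by its qualified counterpart from the plan's
/// schema:
///
/// ```text
/// // plan schema: [t.a, t.b]
/// resolve_columns(&(col("a") + col("b")), &plan)
///   // => col("t.a") + col("t.b")
/// ```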
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up_with_lambdas_params(|nested_expr, lambdas_params| {
            match nested_expr {
                Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => {
                    let (qualifier, field) =
                        plan.schema().qualified_field_from_column(&col)?;
                    Ok(Transformed::yes(Expr::Column(Column::from((
                        qualifier, field,
                    )))))
                }
                _ => {
                    // keep recursing
                    Ok(Transformed::no(nested_expr))
                }
            }
        })
        .data()
}

/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
///
/// For example, the expression `a + b < 1` would require, as input, the 2
/// individual columns, `a` and `b`. But, if the base expressions already
/// contain the `a + b` result, then that may be used in lieu of the `a` and
/// `b` columns.
///
/// This is useful in the context of a query like:
///
/// SELECT a + b < 1 ... GROUP BY a + b
///
/// where post-aggregation, `a + b` need not be a projection against the
/// individual columns `a` and `b`, but rather it is a projection against the
/// `a + b` found in the GROUP BY.
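///
/// A rough sketch (assuming `a + b` is one of `base_exprs`, e.g. a GROUP BY
/// expression):
///
/// ```text
/// rebase_expr(a + b < 1, &[a + b], &plan)
///   // => col("a + b") < 1    (the whole sub-expression becomes a column reference)
/// ```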
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    // TODO: use transform_down_with_lambdas_params
    expr.clone()
        .transform_down(|nested_expr| {
            if base_exprs.contains(&nested_expr) {
                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
            } else {
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    Projection,
    Having,
    Qualify,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}

impl CheckColumnsSatisfyExprsPurpose {
    fn message_prefix(&self) -> &'static str {
        match self {
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
                "Column in SELECT must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
                "Column in HAVING must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
                "Column in QUALIFY must be in GROUP BY or an aggregate function"
            }
        }
    }

    fn diagnostic_message(&self, expr: &Expr) -> String {
        format!("'{expr}' must appear in GROUP BY clause because it's not an aggregate expression")
    }
}

/// Determines if the set of `Expr`s is a valid projection on the input
/// `Expr::Column`s.
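///
/// For illustration (hypothetical columns `a`, `b`, `c`):
///
/// ```text
/// // columns: [a, b]   e.g. GROUP BY columns plus aggregate output columns
/// check_columns_satisfy_exprs(&[a, b], &[a + b < 1], purpose)  // Ok: only `a` and `b` are referenced
/// check_columns_satisfy_exprs(&[a, b], &[a + c],     purpose)  // plan error: `c` is not covered
/// ```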
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    columns.iter().try_for_each(|c| match c {
        Expr::Column(_) => Ok(()),
        _ => internal_err!("Expr::Column are required"),
    })?;
    let column_exprs = find_column_exprs(exprs);
    for e in &column_exprs {
        match e {
            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for exprs in lists_of_exprs {
                    for e in exprs {
                        check_column_satisfies_expr(columns, e, purpose)?;
                    }
                }
            }
            _ => check_column_satisfies_expr(columns, e, purpose)?,
        }
    }
    Ok(())
}

fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if !columns.contains(expr) {
        let diagnostic = Diagnostic::new_error(
            purpose.diagnostic_message(expr),
            expr.spans().and_then(|spans| spans.first()),
        )
        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

        return plan_err!(
            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
            purpose.message_prefix(),
            expr,
            expr_vec_fmt!(columns);
            diagnostic=diagnostic
        );
    }
    Ok(())
}

/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
/// aliasing.
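///
/// For example (a sketch; `a`, `b`, `c` are hypothetical):
///
/// ```text
/// extract_aliases(&[(a + b).alias("total"), c])
///   // => {"total": a + b}
/// ```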
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    exprs
        .iter()
        .filter_map(|expr| match expr {
            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
            _ => None,
        })
        .collect::<HashMap<String, Expr>>()
}

/// Given an expression that is a literal integer encoding a 1-based position, look up the
/// corresponding expression in the `select_exprs` list. Returns a planning error if the
/// expression is a position literal but the index is out of bounds.
/// If the input expression is not an integer literal, it is returned as-is.
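///
/// For illustration (mirroring `test_resolve_positions_to_exprs` below):
///
/// ```text
/// // select_exprs: [c1, c2, count(1)]
/// resolve_positions_to_exprs(lit(1i64), &select_exprs)    // => col("c1")
/// resolve_positions_to_exprs(lit(5i64), &select_exprs)    // => plan error (out of bounds)
/// resolve_positions_to_exprs(lit("text"), &select_exprs)  // => lit("text"), unchanged
/// ```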
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    match expr {
        // sql_expr_to_logical_expr maps number to i64
        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
        Expr::Literal(ScalarValue::Int64(Some(position)), _)
            if position > 0_i64 && position <= select_exprs.len() as i64 =>
        {
            let index = (position - 1) as usize;
            let select_expr = &select_exprs[index];
            Ok(match select_expr {
                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
                _ => select_expr.clone(),
            })
        }
        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position, select_exprs.len()
        ),
        _ => Ok(expr),
    }
}

/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
/// alias' underlying `Expr`.
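///
/// A rough sketch (hypothetical alias `total` mapping to `a + b`):
///
/// ```text
/// // aliases: {"total": a + b}
/// resolve_aliases_to_exprs(col("total").gt(lit(10)), &aliases)
///   // => (a + b) > lit(10)
/// ```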
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr {
        Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => {
            if let Some(aliased_expr) = aliases.get(&c.name) {
                Ok(Transformed::yes(aliased_expr.clone()))
            } else {
                Ok(Transformed::no(Expr::Column(c)))
            }
        }
        _ => Ok(Transformed::no(nested_expr)),
    })
    .data()
}

/// Given a slice of window expressions sharing the same sort key, find their common partition
/// keys.
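///
/// For illustration (a sketch; the result is the shortest PARTITION BY list among
/// the window expressions):
///
/// ```text
/// // w1: SUM(x) OVER (PARTITION BY a, b ORDER BY t)
/// // w2: AVG(y) OVER (PARTITION BY a    ORDER BY t)
/// window_expr_common_partition_keys(&[w1, w2])  // => [a]
/// ```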
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    let all_partition_keys = window_exprs
        .iter()
        .map(|expr| match expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunction {
                    params: WindowFunctionParams { partition_by, .. },
                    ..
                } = window_fun.as_ref();
                Ok(partition_by)
            }
            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                Expr::WindowFunction(window_fun) => {
                    let WindowFunction {
                        params: WindowFunctionParams { partition_by, .. },
                        ..
                    } = window_fun.as_ref();
                    Ok(partition_by)
                }
                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
            },
            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
        })
        .collect::<Result<Vec<_>>>()?;
    let result = all_partition_keys
        .iter()
        .min_by_key(|s| s.len())
        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
    Ok(result)
}

/// Returns a validated `DataType` for the specified precision and
/// scale
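///
/// For illustration (`DECIMAL128_MAX_PRECISION` is 38, `DECIMAL256_MAX_PRECISION` is 76,
/// `DECIMAL_DEFAULT_SCALE` is 10):
///
/// ```text
/// make_decimal_type(Some(10), Some(2))  // => Decimal128(10, 2)
/// make_decimal_type(Some(50), Some(5))  // => Decimal256(50, 5)   precision > 38
/// make_decimal_type(None, None)         // => Decimal128(38, 10)  defaults
/// make_decimal_type(None, Some(2))      // => plan error
/// ```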
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // postgres like behavior
    let (precision, scale) = match (precision, scale) {
        (Some(p), Some(s)) => (p as u8, s as i8),
        (Some(p), None) => (p as u8, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type")
        }
        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
    };

    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION
        && precision <= DECIMAL256_MAX_PRECISION
    {
        Ok(DataType::Decimal256(precision, scale))
    } else {
        Ok(DataType::Decimal128(precision, scale))
    }
}

/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
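///
/// For example:
///
/// ```text
/// normalize_ident(Ident::new("Foo"))              // => "foo"
/// normalize_ident(Ident::with_quote('"', "Foo"))  // => "Foo"   (quoted, case preserved)
/// ```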
pub(crate) fn normalize_ident(id: Ident) -> String {
    match id.quote_style {
        Some(_) => id.value,
        None => id.value.to_ascii_lowercase(),
    }
}

pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::SingleQuotedString(s) => Some(s.to_string()),
        Value::DollarQuotedString(s) => Some(s.to_string()),
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
        Value::EscapedStringLiteral(s) => Some(s.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}

pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    Ok(original_exprs
        .iter()
        .map(|expr| {
            rewrite_recursive_unnest_bottom_up(
                input,
                unnest_placeholder_columns,
                inner_projection_exprs,
                expr,
            )
        })
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>())
}

pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";

/*
This rewriter is only useful when used with a combined top-down (`f_down`) and
bottom-up (`f_up`) traversal; see `f_up` for a full example of how the
transformation works.
 */
struct RecursiveUnnestRewriter<'a> {
    root_expr: &'a Expr,
    // Used to detect whether a child expr is part of an unnest operation or not
    top_most_unnest: Option<Unnest>,
    consecutive_unnest: Vec<Option<Unnest>>,
    inner_projection_exprs: &'a mut Vec<Expr>,
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    transformed_root_exprs: Option<Vec<Expr>>,
}
impl RecursiveUnnestRewriter<'_> {
    /// This struct stores the history of exprs visited
    /// during its tree traversal with a notation of, e.g.,
    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\];
    /// this function then returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
    ///
    /// The first item is the innermost expr
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            .skip_while(|item| item.is_none())
            .take_while(|item| item.is_some())
            .to_owned()
            .cloned()
            .map(|item| item.unwrap())
            .collect()
    }

    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
        input_schema: &DFSchema,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();

        // Full context: we are planning the execution as InnerProjection -> Unnest -> OuterProjection.
        // Inside the unnest execution, each column of the inner projection
        // is transformed into new columns, so we need to keep track of these placeholder column names.
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
        // The unnest transformation should keep the original column name as-is,
        // so that GROUP BY and ORDER BY still resolve correctly.
        let placeholder_column = Column::from_name(placeholder_name.clone());

        let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?;

        match data_type {
            DataType::Struct(inner_fields) => {
                if !struct_allowed {
                    return internal_err!("unnest on struct can only be applied at the root level of select expression");
                }
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                Ok(
                    get_struct_unnested_columns(&placeholder_name, &inner_fields)
                        .into_iter()
                        .map(Expr::Column)
                        .collect(),
                )
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );

                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}

impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> {
    type Node = Expr;
    type Payload<'a> = &'a DFSchema;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, maintain a stack of such information; this
    ///   is used to detect whether a recursive unnest expr exists (e.g. **unnest(unnest(unnest(3d column)))**)
    fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let (data_type, _) =
                unnest_expr.expr.data_type_and_nullable(input_schema)?;
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // If the expr inside the unnest is a struct, do not consider
            // the next unnest (if any) as consecutive,
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plans instead,
            // i.e.:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    ///
    fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find, inside consecutive_unnest, the sequence of continuous unnest exprs

            // Get the latest consecutive unnest exprs
            // and check if the current upward traversal is returning to the root expr.
            // For example, given an expr `unnest(unnest(col))` the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // and the result of such a traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the top-most unnest again
            // e.g. Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a. here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allowed to be done recursively;
                // it needs to be split into multiple unnest logical plans:
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                let unnest_recursion = unnest_stack.len();
                let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                    input_schema,
                )?;
                if struct_allowed {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node,
        // retain their projection,
        // e.g. given the expr tree unnest(col_a) + col_b, we have to retain the projection of col_b.
        // This condition can be checked by maintaining an Option<top most unnest>.
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}

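/// Push `expr` onto `projection` unless an expression with the same schema name
/// is already present (i.e. deduplicate by schema name).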
fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let schema_name = expr.schema_name().to_string();
    if !projection
        .iter()
        .any(|e| e.schema_name().to_string() == schema_name)
    {
        projection.push(expr);
    }
}

/// The context is that we want to rewrite `unnest()` into InnerProjection -> Unnest -> OuterProjection.
/// Given an expression that contains an unnest expr as one of its children,
/// the transformation applied depends on the unnest type:
/// - For a list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For a struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection.
/// If there are multiple unnest expressions along the path from root to bottom, the transformation
/// is done only for the bottom-most expression.
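///
/// A rough sketch of the list case (see `test_transform_bottom_unnest` below for
/// the exact expected output; `array_col` is a hypothetical list column):
///
/// ```text
/// // original expr: unnest(array_col) + 1
/// // after rewrite_recursive_unnest_bottom_up:
/// //   inner projection gains:  array_col AS __unnest_placeholder(array_col)
/// //   recorded unnesting:      __unnest_placeholder(array_col) => [depth = 1]
/// //   returned outer expr:     __unnest_placeholder(array_col,depth=1) AS UNNEST(array_col) + 1
/// ```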
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };

    // This transformation is only done for list unnest;
    // struct unnest is done at the root level, and at a later stage,
    // because the TreeNode API only supports transforming into 1 Expr, while
    // unnesting a struct is transformed into multiple Exprs
    // TODO: This can be resolved after this issue is fixed: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?;

    if !transformed {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection;
            // the outer projection just selects its name
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}

#[cfg(test)]
mod tests {
    use std::{ops::Add, sync::Arc};

    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
    use datafusion_common::{Column, DFSchema, Result};
    use datafusion_expr::{
        col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;
    use datafusion_functions_aggregate::expr_fn::count;

    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
    use indexmap::IndexMap;

    fn column_unnests_eq(
        l: Vec<&str>,
        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    ) {
        let r_formatted: Vec<String> = r
            .iter()
            .map(|i| match i.1 {
                None => format!("{}", i.0),
                Some(vec) => format!(
                    "{}=>[{}]",
                    i.0,
                    vec.iter()
                        .map(|i| format!("{i}"))
                        .collect::<Vec<String>>()
                        .join(", ")
                ),
            })
            .collect();
        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
        assert_eq!(l_formatted, r_formatted);
    }

    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom-most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(3d_col,depth=2)")
                .alias("UNNEST(UNNEST(3d_col))")
                .add(
                    col("__unnest_placeholder(3d_col,depth=2)")
                        .alias("UNNEST(UNNEST(3d_col))")
                )
                .add(col("i64_col"))]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"],
            &unnest_placeholder_columns,
        );
        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }

    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(array_col,depth=1)")
                .alias("UNNEST(array_col)")
                .add(lit(1i64))]
        );

        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }

    // Unnest -> field access -> unnest
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the innermost/bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // continue rewriting another expr in the select
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the innermost/bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // The unnest placeholder columns remain the same
        // because expr1 and expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }

    #[test]
    fn test_resolve_positions_to_exprs() -> Result<()> {
        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];

        // Assert 1 resolved as first column in select list
        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
        assert_eq!(resolved, col("c1"));

        // Assert error if index out of select clause bounds
        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
        )));

        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
        )));

        // Assert expression returned as-is
        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
        assert_eq!(resolved, lit("text"));

        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
        assert_eq!(resolved, col("fake"));

        Ok(())
    }
}