datafusion_expr/utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression utilities
19
20use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27    and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37    internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
38    Result, TableReference,
39};
40
41#[cfg(not(feature = "sql"))]
42use crate::expr::{ExceptSelectItem, ExcludeSelectItem};
43use indexmap::IndexSet;
44#[cfg(feature = "sql")]
45use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
46
47pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
48
49/// The value to which `COUNT(*)` is expanded in
50/// `COUNT(<constant>)` expressions
51pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
52
53/// Count the number of distinct exprs in a list of group by expressions. If the
54/// first element is a `GroupingSet` expression then it must be the only expr.
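///
/// # Example
///
/// A minimal usage sketch (column names are arbitrary):
///
/// ```
/// # use datafusion_expr::{col, utils::grouping_set_expr_count};
/// // `a` appears twice but is only counted once
/// let group_expr = vec![col("a"), col("b"), col("a")];
/// assert_eq!(grouping_set_expr_count(&group_expr).unwrap(), 2);
/// ```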
55pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
56    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
57        if group_expr.len() > 1 {
58            return plan_err!(
59                "Invalid group by expressions, GroupingSet must be the only expression"
60            );
61        }
62        // Grouping sets have an additional integral column for the grouping id
63        Ok(grouping_set.distinct_expr().len() + 1)
64    } else {
65        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
66    }
67}
68
69/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
70/// including the empty set and S itself.
71///
72/// Example:
73///
74/// If S is the set {x, y, z}, then all the subsets of S are \
75///  {} \
76///  {x} \
77///  {y} \
78///  {z} \
79///  {x, y} \
80///  {x, z} \
81///  {y, z} \
82///  {x, y, z} \
83///  and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
84///
85/// [power set]: https://en.wikipedia.org/wiki/Power_set
86fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
87    if slice.len() >= 64 {
88        return Err("The size of the set must be less than 64.".into());
89    }
90
91    let mut v = Vec::new();
92    for mask in 0..(1 << slice.len()) {
93        let mut ss = vec![];
94        let mut bitset = mask;
95        while bitset > 0 {
96            let rightmost: u64 = bitset & !(bitset - 1);
97            let idx = rightmost.trailing_zeros();
98            let item = slice.get(idx as usize).unwrap();
99            ss.push(item);
100            // zero the trailing bit
101            bitset &= bitset - 1;
102        }
103        v.push(ss);
104    }
105    Ok(v)
106}
107
108/// Check that the number of expressions contained in a grouping set does not exceed the maximum limit
109fn check_grouping_set_size_limit(size: usize) -> Result<()> {
110    let max_grouping_set_size = 65535;
111    if size > max_grouping_set_size {
112        return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
113    }
114
115    Ok(())
116}
117
118/// Check that the number of grouping sets does not exceed the maximum limit
119fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
120    let max_grouping_sets_size = 4096;
121    if size > max_grouping_sets_size {
122        return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
123    }
124
125    Ok(())
126}
127
128/// Merge two grouping sets
129///
130/// # Example
131/// ```text
132/// (A, B), (C, D) -> (A, B, C, D)
133/// ```
134///
135/// # Error
136/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
137///
138/// [`DataFusionError`]: datafusion_common::DataFusionError
139fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
140    check_grouping_set_size_limit(left.len() + right.len())?;
141    Ok(left.iter().chain(right.iter()).cloned().collect())
142}
143
144/// Compute the cross product of two lists of grouping sets
145///
146/// # Example
147/// ```text
148/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
149/// ```
150///
151/// # Error
152/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
153/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
154///
155/// [`DataFusionError`]: datafusion_common::DataFusionError
156fn cross_join_grouping_sets<T: Clone>(
157    left: &[Vec<T>],
158    right: &[Vec<T>],
159) -> Result<Vec<Vec<T>>> {
160    let grouping_sets_size = left.len() * right.len();
161
162    check_grouping_sets_size_limit(grouping_sets_size)?;
163
164    let mut result = Vec::with_capacity(grouping_sets_size);
165    for le in left {
166        for re in right {
167            result.push(merge_grouping_set(le, re)?);
168        }
169    }
170    Ok(result)
171}
172
173/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`].\
174/// If the grouping expressions do not contain any [`Expr::GroupingSet`], or there is only one expression,\
175/// no conversion is performed.
176///
177/// e.g.
178///
179/// person.id,\
180/// GROUPING SETS ((person.age, person.salary),(person.age)),\
181/// ROLLUP(person.state, person.birth_date)
182///
183/// =>
184///
185/// GROUPING SETS (\
186///   (person.id, person.age, person.salary),\
187///   (person.id, person.age, person.salary, person.state),\
188///   (person.id, person.age, person.salary, person.state, person.birth_date),\
189///   (person.id, person.age),\
190///   (person.id, person.age, person.state),\
191///   (person.id, person.age, person.state, person.birth_date)\
192/// )
193pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
194    let has_grouping_set = group_expr
195        .iter()
196        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
197    if !has_grouping_set || group_expr.len() == 1 {
198        return Ok(group_expr);
199    }
200    // Only process mixed grouping sets
201    let partial_sets = group_expr
202        .iter()
203        .map(|expr| {
204            let exprs = match expr {
205                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
206                    check_grouping_sets_size_limit(grouping_sets.len())?;
207                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
208                }
209                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
210                    let grouping_sets = powerset(group_exprs)
211                        .map_err(|e| plan_datafusion_err!("{}", e))?;
212                    check_grouping_sets_size_limit(grouping_sets.len())?;
213                    grouping_sets
214                }
215                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
216                    let size = group_exprs.len();
217                    let slice = group_exprs.as_slice();
218                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
219                    (0..(size + 1))
220                        .map(|i| slice[0..i].iter().collect())
221                        .collect()
222                }
223                expr => vec![vec![expr]],
224            };
225            Ok(exprs)
226        })
227        .collect::<Result<Vec<_>>>()?;
228
229    // Cross Join
230    let grouping_sets = partial_sets
231        .into_iter()
232        .map(Ok)
233        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
234        .transpose()?
235        .map(|e| {
236            e.into_iter()
237                .map(|e| e.into_iter().cloned().collect())
238                .collect()
239        })
240        .unwrap_or_default();
241
242    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
243        grouping_sets,
244    ))])
245}
246
247/// Find all distinct exprs in a list of group by expressions. If the
248/// first element is a `GroupingSet` expression then it must be the only expr.
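///
/// # Example
///
/// A minimal usage sketch (column names are arbitrary):
///
/// ```
/// # use datafusion_expr::{col, utils::grouping_set_to_exprlist};
/// let a = col("a");
/// let b = col("b");
/// let group_expr = vec![a.clone(), b.clone(), a.clone()];
/// // duplicates are removed, order of first appearance is kept
/// let distinct = grouping_set_to_exprlist(&group_expr).unwrap();
/// assert_eq!(distinct, vec![&a, &b]);
/// ```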
249pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
250    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
251        if group_expr.len() > 1 {
252            return plan_err!(
253                "Invalid group by expressions, GroupingSet must be the only expression"
254            );
255        }
256        Ok(grouping_set.distinct_expr())
257    } else {
258        Ok(group_expr
259            .iter()
260            .collect::<IndexSet<_>>()
261            .into_iter()
262            .collect())
263    }
264}
265
266/// Recursively walk an expression tree, collecting the unique set of columns
267/// referenced in the expression
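///
/// # Example
///
/// A minimal usage sketch (column names are arbitrary):
///
/// ```
/// # use std::collections::HashSet;
/// # use datafusion_expr::{col, lit, utils::expr_to_columns};
/// // `a > 1 AND b IS NOT NULL` references columns `a` and `b`
/// let expr = col("a").gt(lit(1)).and(col("b").is_not_null());
/// let mut accum = HashSet::new();
/// expr_to_columns(&expr, &mut accum).unwrap();
/// assert_eq!(accum.len(), 2);
/// ```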
268pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
269    expr.apply_with_lambdas_params(|expr, lambdas_params| {
270        match expr {
271            Expr::Column(qc) => {
272                if qc.relation.is_some() || !lambdas_params.contains(qc.name()) {
273                    accum.insert(qc.clone());
274                }
275            }
276            // Use explicit pattern match instead of a default
277            // implementation, so that in the future if someone adds
278            // new Expr types, they will check here as well
279            // TODO: remove the next line after `Expr::Wildcard` is removed
280            #[expect(deprecated)]
281            Expr::Unnest(_)
282            | Expr::ScalarVariable(_, _)
283            | Expr::Alias(_)
284            | Expr::Literal(_, _)
285            | Expr::BinaryExpr { .. }
286            | Expr::Like { .. }
287            | Expr::SimilarTo { .. }
288            | Expr::Not(_)
289            | Expr::IsNotNull(_)
290            | Expr::IsNull(_)
291            | Expr::IsTrue(_)
292            | Expr::IsFalse(_)
293            | Expr::IsUnknown(_)
294            | Expr::IsNotTrue(_)
295            | Expr::IsNotFalse(_)
296            | Expr::IsNotUnknown(_)
297            | Expr::Negative(_)
298            | Expr::Between { .. }
299            | Expr::Case { .. }
300            | Expr::Cast { .. }
301            | Expr::TryCast { .. }
302            | Expr::ScalarFunction(..)
303            | Expr::WindowFunction { .. }
304            | Expr::AggregateFunction { .. }
305            | Expr::GroupingSet(_)
306            | Expr::InList { .. }
307            | Expr::Exists { .. }
308            | Expr::InSubquery(_)
309            | Expr::ScalarSubquery(_)
310            | Expr::Wildcard { .. }
311            | Expr::Placeholder(_)
312            | Expr::OuterReferenceColumn { .. }
313            | Expr::Lambda { .. } => {}
314        }
315        Ok(TreeNodeRecursion::Continue)
316    })
317    .map(|_| ())
318}
319
320/// Find excluded columns in the schema, if any.
321/// For example, `SELECT * EXCLUDE(col1, col2)` would return `vec![col1, col2]`
322fn get_excluded_columns(
323    opt_exclude: Option<&ExcludeSelectItem>,
324    opt_except: Option<&ExceptSelectItem>,
325    schema: &DFSchema,
326    qualifier: Option<&TableReference>,
327) -> Result<Vec<Column>> {
328    let mut idents = vec![];
329    if let Some(excepts) = opt_except {
330        idents.push(&excepts.first_element);
331        idents.extend(&excepts.additional_elements);
332    }
333    if let Some(exclude) = opt_exclude {
334        match exclude {
335            ExcludeSelectItem::Single(ident) => idents.push(ident),
336            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
337        }
338    }
339    // Excluded columns should be unique
340    let n_elem = idents.len();
341    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
342    // If the HashSet size and the vector length differ, then some of the excluded columns
343    // are not unique; in this case return an error.
344    if n_elem != unique_idents.len() {
345        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
346    }
347
348    let mut result = vec![];
349    for ident in unique_idents.into_iter() {
350        let col_name = ident.value.as_str();
351        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
352        result.push(Column::from((qualifier, field)));
353    }
354    Ok(result)
355}
356
357/// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip`
358fn get_exprs_except_skipped(
359    schema: &DFSchema,
360    columns_to_skip: HashSet<Column>,
361) -> Vec<Expr> {
362    if columns_to_skip.is_empty() {
363        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
364    } else {
365        schema
366            .columns()
367            .iter()
368            .filter_map(|c| {
369                if !columns_to_skip.contains(c) {
370                    Some(Expr::Column(c.clone()))
371                } else {
372                    None
373                }
374            })
375            .collect::<Vec<Expr>>()
376    }
377}
378
379/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice
380/// (once for each join side), but an unqualified wildcard should include it only once.
381/// This function returns the columns that should be excluded.
382fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
383    let using_columns = plan.using_columns()?;
384    let excluded = using_columns
385        .into_iter()
386        // For each USING JOIN condition, only expand to one of each join column in projection
387        .flat_map(|cols| {
388            let mut cols = cols.into_iter().collect::<Vec<_>>();
389            // sort join columns to make sure we consistently keep the same
390            // qualified column
391            cols.sort();
392            let mut out_column_names: HashSet<String> = HashSet::new();
393            cols.into_iter().filter_map(move |c| {
394                if out_column_names.contains(&c.name) {
395                    Some(c)
396                } else {
397                    out_column_names.insert(c.name);
398                    None
399                }
400            })
401        })
402        .collect::<HashSet<_>>();
403    Ok(excluded)
404}
405
406/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
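///
/// # Example
///
/// An illustrative sketch (table and column names are arbitrary):
///
/// ```text
/// -- given a table t(a, b, c)
/// SELECT * FROM t               -- the wildcard expands to [t.a, t.b, t.c]
/// SELECT * EXCLUDE (b) FROM t   -- the wildcard expands to [t.a, t.c]
/// ```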
407pub fn expand_wildcard(
408    schema: &DFSchema,
409    plan: &LogicalPlan,
410    wildcard_options: Option<&WildcardOptions>,
411) -> Result<Vec<Expr>> {
412    let mut columns_to_skip = exclude_using_columns(plan)?;
413    let excluded_columns = if let Some(WildcardOptions {
414        exclude: opt_exclude,
415        except: opt_except,
416        ..
417    }) = wildcard_options
418    {
419        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
420    } else {
421        vec![]
422    };
423    // Add each excluded `Column` to columns_to_skip
424    columns_to_skip.extend(excluded_columns);
425    Ok(get_exprs_except_skipped(schema, columns_to_skip))
426}
427
428/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
429pub fn expand_qualified_wildcard(
430    qualifier: &TableReference,
431    schema: &DFSchema,
432    wildcard_options: Option<&WildcardOptions>,
433) -> Result<Vec<Expr>> {
434    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
435    let projected_func_dependencies = schema
436        .functional_dependencies()
437        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
438    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
439    if fields_with_qualified.is_empty() {
440        return plan_err!("Invalid qualifier {qualifier}");
441    }
442
443    let qualified_schema = Arc::new(Schema::new_with_metadata(
444        fields_with_qualified,
445        schema.metadata().clone(),
446    ));
447    let qualified_dfschema =
448        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
449            .with_functional_dependencies(projected_func_dependencies)?;
450    let excluded_columns = if let Some(WildcardOptions {
451        exclude: opt_exclude,
452        except: opt_except,
453        ..
454    }) = wildcard_options
455    {
456        get_excluded_columns(
457            opt_exclude.as_ref(),
458            opt_except.as_ref(),
459            schema,
460            Some(qualifier),
461        )?
462    } else {
463        vec![]
464    };
465    // Add each excluded `Column` to columns_to_skip
466    let mut columns_to_skip = HashSet::new();
467    columns_to_skip.extend(excluded_columns);
468    Ok(get_exprs_except_skipped(
469        &qualified_dfschema,
470        columns_to_skip,
471    ))
472}
473
474/// A window sort key: `(Sort, bool)`, where the `Sort` comes from either a `PARTITION BY` or an `ORDER BY` column.
475/// The bool is `true` if the `Sort` comes from a `PARTITION BY` column and `false` if it comes from an `ORDER BY` column.
476type WindowSortKey = Vec<(Sort, bool)>;
477
478/// Generate a sort key for a given window expr's partition_by and order_by expr
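///
/// # Example
///
/// An illustrative sketch of the combined sort key (column names are arbitrary):
///
/// ```text
/// PARTITION BY a, b   ORDER BY b DESC, c ASC
/// => [(a ASC NULLS LAST, true), (b DESC, true), (c ASC, false)]
/// ```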
479pub fn generate_sort_key(
480    partition_by: &[Expr],
481    order_by: &[Sort],
482) -> Result<WindowSortKey> {
483    let normalized_order_by_keys = order_by
484        .iter()
485        .map(|e| {
486            let Sort { expr, .. } = e;
487            Sort::new(expr.clone(), true, false)
488        })
489        .collect::<Vec<_>>();
490
491    let mut final_sort_keys = vec![];
492    let mut is_partition_flag = vec![];
493    partition_by.iter().for_each(|e| {
494        // By default, create the sort key with ASC = true and NULLS LAST to be consistent with
495        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
496        let e = e.clone().sort(true, false);
497        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
498            let order_by_key = &order_by[pos];
499            if !final_sort_keys.contains(order_by_key) {
500                final_sort_keys.push(order_by_key.clone());
501                is_partition_flag.push(true);
502            }
503        } else if !final_sort_keys.contains(&e) {
504            final_sort_keys.push(e);
505            is_partition_flag.push(true);
506        }
507    });
508
509    order_by.iter().for_each(|e| {
510        if !final_sort_keys.contains(e) {
511            final_sort_keys.push(e.clone());
512            is_partition_flag.push(false);
513        }
514    });
515    let res = final_sort_keys
516        .into_iter()
517        .zip(is_partition_flag)
518        .collect::<Vec<_>>();
519    Ok(res)
520}
521
522/// Compare sort expressions following PostgreSQL's common_prefix_cmp():
523/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
524pub fn compare_sort_expr(
525    sort_expr_a: &Sort,
526    sort_expr_b: &Sort,
527    schema: &DFSchemaRef,
528) -> Ordering {
529    let Sort {
530        expr: expr_a,
531        asc: asc_a,
532        nulls_first: nulls_first_a,
533    } = sort_expr_a;
534
535    let Sort {
536        expr: expr_b,
537        asc: asc_b,
538        nulls_first: nulls_first_b,
539    } = sort_expr_b;
540
541    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
542    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
543    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
544        match idx_a.cmp(idx_b) {
545            Ordering::Less => {
546                return Ordering::Less;
547            }
548            Ordering::Greater => {
549                return Ordering::Greater;
550            }
551            Ordering::Equal => {}
552        }
553    }
554    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
555        Ordering::Less => return Ordering::Greater,
556        Ordering::Greater => {
557            return Ordering::Less;
558        }
559        Ordering::Equal => {}
560    }
561    match (asc_a, asc_b) {
562        (true, false) => {
563            return Ordering::Greater;
564        }
565        (false, true) => {
566            return Ordering::Less;
567        }
568        _ => {}
569    }
570    match (nulls_first_a, nulls_first_b) {
571        (true, false) => {
572            return Ordering::Less;
573        }
574        (false, true) => {
575            return Ordering::Greater;
576        }
577        _ => {}
578    }
579    Ordering::Equal
580}
581
582/// Group window expressions by their sort keys (derived from their PARTITION BY and ORDER BY expressions)
583pub fn group_window_expr_by_sort_keys(
584    window_expr: impl IntoIterator<Item = Expr>,
585) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
586    let mut result = vec![];
587    window_expr.into_iter().try_for_each(|expr| match &expr {
588        Expr::WindowFunction(window_fun) => {
589            let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
590            let sort_key = generate_sort_key(partition_by, order_by)?;
591            if let Some((_, values)) = result.iter_mut().find(
592                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
593            ) {
594                values.push(expr);
595            } else {
596                result.push((sort_key, vec![expr]))
597            }
598            Ok(())
599        }
600        other => internal_err!(
601            "Impossibly got non-window expr {other:?}"
602        ),
603    })?;
604    Ok(result)
605}
606
607/// Collect all deeply nested `Expr::AggregateFunction`.
608/// They are returned in order of occurrence (depth
609/// first), with duplicates omitted.
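///
/// # Example
///
/// An illustrative sketch:
///
/// ```text
/// SUM(a) + 1 > MAX(b)   =>   [SUM(a), MAX(b)]
/// ```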
610pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
611    find_exprs_in_exprs(exprs, &|nested_expr| {
612        matches!(nested_expr, Expr::AggregateFunction { .. })
613    })
614}
615
616/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
617/// (depth first), with duplicates omitted.
618pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
619    find_exprs_in_exprs(exprs, &|nested_expr| {
620        matches!(nested_expr, Expr::WindowFunction { .. })
621    })
622}
623
624/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
625/// (depth first), with duplicates omitted.
626pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
627    find_exprs_in_expr(expr, &|nested_expr| {
628        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
629    })
630}
631
632/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
633/// pass the provided test. The returned `Expr`'s are deduplicated and returned
634/// in order of appearance (depth first).
635fn find_exprs_in_exprs<'a, F>(
636    exprs: impl IntoIterator<Item = &'a Expr>,
637    test_fn: &F,
638) -> Vec<Expr>
639where
640    F: Fn(&Expr) -> bool,
641{
642    exprs
643        .into_iter()
644        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
645        .fold(vec![], |mut acc, expr| {
646            if !acc.contains(&expr) {
647                acc.push(expr)
648            }
649            acc
650        })
651}
652
653/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
654/// provided test. The returned `Expr`'s are deduplicated and returned in order
655/// of appearance (depth first).
656/// TODO: document that matched columns may refer to a lambda parameter
657fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
658where
659    F: Fn(&Expr) -> bool,
660{
661    let mut exprs = vec![];
662    expr.apply(|expr| {
663        if test_fn(expr) {
664            if !(exprs.contains(expr)) {
665                exprs.push(expr.clone())
666            }
667            // Stop recursing down this expr once we find a match
668            return Ok(TreeNodeRecursion::Jump);
669        }
670
671        Ok(TreeNodeRecursion::Continue)
672    })
673    // the closure passed to `apply` always returns Ok, so this will too
674    .expect("no way to return error during recursion");
675    exprs
676}
677
678/// Recursively inspect an [`Expr`] and all its children.
679/// TODO: document that visited columns may refer to a lambda parameter
680pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
681where
682    F: FnMut(&Expr) -> Result<(), E>,
683{
684    let mut err = Ok(());
685    expr.apply(|expr| {
686        if let Err(e) = f(expr) {
687            // Save the error for later (it may not be a DataFusionError)
688            err = Err(e);
689            Ok(TreeNodeRecursion::Stop)
690        } else {
691            // keep going
692            Ok(TreeNodeRecursion::Continue)
693        }
694    })
695    // The closure always returns Ok, so this will too
696    .expect("no way to return error during recursion");
697
698    err
699}
700
701/// Create schema fields from an expression list, for use in result set schema construction
702///
703/// This function converts a list of expressions into a list of complete schema fields,
704/// making comprehensive determinations about each field's properties including:
705/// - **Data type**: Resolved based on expression type and input schema context
706/// - **Nullability**: Determined by expression-specific nullability rules
707/// - **Metadata**: Computed based on expression type (preserving, merging, or generating new metadata)
708/// - **Table reference scoping**: Establishing proper qualified field references
709///
710/// Each expression is converted to a field by calling [`Expr::to_field`], which performs
711/// the complete field resolution process for all field properties.
712///
713/// # Returns
714///
715/// A `Result` containing a vector of `(Option<TableReference>, Arc<Field>)` tuples,
716/// where each Field contains complete schema information (type, nullability, metadata)
717/// and proper table reference scoping for the corresponding expression.
718pub fn exprlist_to_fields<'a>(
719    exprs: impl IntoIterator<Item = &'a Expr>,
720    plan: &LogicalPlan,
721) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
722    // Look for exact match in plan's output schema
723    let input_schema = plan.schema();
724    exprs
725        .into_iter()
726        .map(|e| e.to_field(input_schema))
727        .collect()
728}
729
730/// Convert an expression into a `Column` expression if it is already provided by the input plan.
731///
732/// For example, it rewrites:
733///
734/// ```text
735/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
736/// .project(vec![col("c1"), sum(col("c2"))])?
737/// ```
738///
739/// Into:
740///
741/// ```text
742/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
743/// .project(vec![col("c1"), col("SUM(c2)")])?
744/// ```
745pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
746    let output_exprs = match input.columnized_output_exprs() {
747        Ok(exprs) if !exprs.is_empty() => exprs,
748        _ => return Ok(e),
749    };
750    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
751    e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| {
752        if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) {
753            return Ok(Transformed::no(node));
754        }
755
756        match exprs_map.get(&node) {
757            Some(column) => Ok(Transformed::new(
758                Expr::Column(column.clone()),
759                true,
760                TreeNodeRecursion::Jump,
761            )),
762            None => Ok(Transformed::no(node)),
763        }
764    })
765    .data()
766}
767
768/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
769/// appearance (depth first), and may contain duplicates.
770pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
771    exprs
772        .iter()
773        .flat_map(find_columns_referenced_by_expr)
774        .map(Expr::Column)
775        .collect()
776}
777
778pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
779    let mut exprs = vec![];
780    e.apply_with_lambdas_params(|expr, lambdas_params| {
781        if let Expr::Column(c) = expr {
782            if !c.is_lambda_parameter(lambdas_params) {
783                exprs.push(c.clone())
784            }
785        }
786        Ok(TreeNodeRecursion::Continue)
787    })
788    // As the closure always returns Ok, this "can't" error
789    .expect("Unexpected error");
790    exprs
791}
792
793/// Convert any `Expr` to an `Expr::Column`.
794pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
795    match expr {
796        Expr::Column(col) => {
797            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
798            Ok(Expr::from(Column::from((qualifier, field))))
799        }
800        _ => Ok(Expr::Column(Column::from_name(
801            expr.schema_name().to_string(),
802        ))),
803    }
804}
805
806/// Recursively walk an expression tree, collecting the column indexes
807/// referenced in the expression
808pub(crate) fn find_column_indexes_referenced_by_expr(
809    e: &Expr,
810    schema: &DFSchemaRef,
811) -> Vec<usize> {
812    let mut indexes = vec![];
813    e.apply_with_lambdas_params(|expr, lambdas_params| {
814        match expr {
815            Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => {
816                if let Ok(idx) = schema.index_of_column(qc) {
817                    indexes.push(idx);
818                }
819            }
820            Expr::Literal(_, _) => {
821                indexes.push(usize::MAX);
822            }
823            _ => {}
824        }
825        Ok(TreeNodeRecursion::Continue)
826    })
827    .unwrap();
828    indexes
829}
830
831/// Can this data type be used in hash join equal conditions?
832/// The data types here come from the function `equal_rows`; if more data types are supported
833/// in `create_hashes`, add them here so that the join logical plan can be generated.
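///
/// # Example
///
/// A minimal sketch:
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field};
/// # use datafusion_expr::utils::can_hash;
/// assert!(can_hash(&DataType::Int32));
/// // nested types are hashable when their value type is hashable
/// let list = DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)));
/// assert!(can_hash(&list));
/// ```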
834pub fn can_hash(data_type: &DataType) -> bool {
835    match data_type {
836        DataType::Null => true,
837        DataType::Boolean => true,
838        DataType::Int8 => true,
839        DataType::Int16 => true,
840        DataType::Int32 => true,
841        DataType::Int64 => true,
842        DataType::UInt8 => true,
843        DataType::UInt16 => true,
844        DataType::UInt32 => true,
845        DataType::UInt64 => true,
846        DataType::Float16 => true,
847        DataType::Float32 => true,
848        DataType::Float64 => true,
849        DataType::Decimal32(_, _) => true,
850        DataType::Decimal64(_, _) => true,
851        DataType::Decimal128(_, _) => true,
852        DataType::Decimal256(_, _) => true,
853        DataType::Timestamp(_, _) => true,
854        DataType::Utf8 => true,
855        DataType::LargeUtf8 => true,
856        DataType::Utf8View => true,
857        DataType::Binary => true,
858        DataType::LargeBinary => true,
859        DataType::BinaryView => true,
860        DataType::Date32 => true,
861        DataType::Date64 => true,
862        DataType::Time32(_) => true,
863        DataType::Time64(_) => true,
864        DataType::Duration(_) => true,
865        DataType::Interval(_) => true,
866        DataType::FixedSizeBinary(_) => true,
867        DataType::Dictionary(key_type, value_type) => {
868            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
869        }
870        DataType::List(value_type) => can_hash(value_type.data_type()),
871        DataType::LargeList(value_type) => can_hash(value_type.data_type()),
872        DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
873        DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
874        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
875
876        DataType::ListView(_)
877        | DataType::LargeListView(_)
878        | DataType::Union(_, _)
879        | DataType::RunEndEncoded(_, _) => false,
880    }
881}
882
883/// Check whether all columns are from the schema.
884pub fn check_all_columns_from_schema(
885    columns: &HashSet<&Column>,
886    schema: &DFSchema,
887) -> Result<bool> {
888    for col in columns.iter() {
889        let exist = schema.is_column_from_schema(col);
890        if !exist {
891            return Ok(false);
892        }
893    }
894
895    Ok(true)
896}
897
898/// Given the two sides of an equijoin predicate, return a valid join key pair.
899/// If there is no valid join key pair, return `None`.
900///
901/// A valid join key pair means either:
902/// 1. All columns referenced by the left side come from the left schema, and
903///    all columns referenced by the right side come from the right schema; or
904/// 2. The opposite: all columns referenced by the left side come from the right schema,
905///    and all columns referenced by the right side come from the left schema.
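///
/// # Example
///
/// An illustrative sketch (table and column names are arbitrary):
///
/// ```text
/// ... ON t1.a + 1 = t2.b
/// left_key  = t1.a + 1   -- references only the left schema
/// right_key = t2.b       -- references only the right schema
/// => Some((t1.a + 1, t2.b))
/// ```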
906pub fn find_valid_equijoin_key_pair(
907    left_key: &Expr,
908    right_key: &Expr,
909    left_schema: &DFSchema,
910    right_schema: &DFSchema,
911) -> Result<Option<(Expr, Expr)>> {
912    let left_using_columns = left_key.column_refs();
913    let right_using_columns = right_key.column_refs();
914
915    // Conditions like `a = 10` will be added to the non-equijoin conditions.
916    if left_using_columns.is_empty() || right_using_columns.is_empty() {
917        return Ok(None);
918    }
919
920    if check_all_columns_from_schema(&left_using_columns, left_schema)?
921        && check_all_columns_from_schema(&right_using_columns, right_schema)?
922    {
923        return Ok(Some((left_key.clone(), right_key.clone())));
924    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
925        && check_all_columns_from_schema(&left_using_columns, right_schema)?
926    {
927        return Ok(Some((right_key.clone(), left_key.clone())));
928    }
929
930    Ok(None)
931}
932
933/// Creates a detailed error message for a function with wrong signature.
934///
935/// For example, a query like `select round(3.14, 1.1);` would yield:
936/// ```text
937/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
938///     Candidate functions:
939///     round(Float64, Int64)
940///     round(Float32, Int64)
941///     round(Float64)
942///     round(Float32)
943/// ```
944pub fn generate_signature_error_msg(
945    func_name: &str,
946    func_signature: Signature,
947    input_expr_types: &[DataType],
948) -> String {
949    let candidate_signatures = func_signature
950        .type_signature
951        .to_string_repr_with_names(func_signature.parameter_names.as_deref())
952        .iter()
953        .map(|args_str| format!("\t{func_name}({args_str})"))
954        .collect::<Vec<String>>()
955        .join("\n");
956
957    format!(
958            "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
959            func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
960        )
961}
962
963/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
964///
965/// See [`split_conjunction_owned`] for more details and an example.
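///
/// # Example
///
/// The borrowed counterpart of the [`split_conjunction_owned`] example:
///
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::split_conjunction;
/// // a=1 AND b=2
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
///
/// // [a=1, b=2]
/// let a_eq = col("a").eq(lit(1));
/// let b_eq = col("b").eq(lit(2));
/// assert_eq!(split_conjunction(&expr), vec![&a_eq, &b_eq]);
/// ```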
966pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
967    split_conjunction_impl(expr, vec![])
968}
969
970fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
971    match expr {
972        Expr::BinaryExpr(BinaryExpr {
973            right,
974            op: Operator::And,
975            left,
976        }) => {
977            let exprs = split_conjunction_impl(left, exprs);
978            split_conjunction_impl(right, exprs)
979        }
980        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
981        other => {
982            exprs.push(other);
983            exprs
984        }
985    }
986}
987
988/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
989///
990/// See [`split_conjunction_owned`] for more details and an example.
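///
/// # Example
///
/// A minimal usage sketch:
///
/// ```
/// # use datafusion_expr::{col, lit, Expr};
/// # use datafusion_expr::utils::iter_conjunction;
/// // a=1 AND b=2
/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
/// let parts: Vec<&Expr> = iter_conjunction(&expr).collect();
/// assert_eq!(parts.len(), 2);
/// ```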
991pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
992    let mut stack = vec![expr];
993    std::iter::from_fn(move || {
994        while let Some(expr) = stack.pop() {
995            match expr {
996                Expr::BinaryExpr(BinaryExpr {
997                    right,
998                    op: Operator::And,
999                    left,
1000                }) => {
1001                    stack.push(right);
1002                    stack.push(left);
1003                }
1004                Expr::Alias(Alias { expr, .. }) => stack.push(expr),
1005                other => return Some(other),
1006            }
1007        }
1008        None
1009    })
1010}
1011
1012/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1013///
1014/// See [`split_conjunction_owned`] for more details and an example.
1015pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
1016    let mut stack = vec![expr];
1017    std::iter::from_fn(move || {
1018        while let Some(expr) = stack.pop() {
1019            match expr {
1020                Expr::BinaryExpr(BinaryExpr {
1021                    right,
1022                    op: Operator::And,
1023                    left,
1024                }) => {
1025                    stack.push(*right);
1026                    stack.push(*left);
1027                }
1028                Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
1029                other => return Some(other),
1030            }
1031        }
1032        None
1033    })
1034}
1035
1036/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1037///
1038/// This is often used to "split" filter expressions such as `col1 = 5
1039/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1040///
1041/// # Example
1042/// ```
1043/// # use datafusion_expr::{col, lit};
1044/// # use datafusion_expr::utils::split_conjunction_owned;
1045/// // a=1 AND b=2
1046/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1047///
1048/// // [a=1, b=2]
1049/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1050///
1051/// // use split_conjunction_owned to split them
1052/// assert_eq!(split_conjunction_owned(expr), split);
1053/// ```
1054pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1055    split_binary_owned(expr, Operator::And)
1056}
1057
1058/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1059///
1060/// This is often used to "split" expressions such as `col1 = 5
1061/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1062///
1063/// # Example
1064/// ```
1065/// # use datafusion_expr::{col, lit, Operator};
1066/// # use datafusion_expr::utils::split_binary_owned;
1067/// # use std::ops::Add;
1068/// // a=1 + b=2
1069/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1070///
1071/// // [a=1, b=2]
1072/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1073///
1074/// // use split_binary_owned to split them
1075/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1076/// ```
1077pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1078    split_binary_owned_impl(expr, op, vec![])
1079}
1080
1081fn split_binary_owned_impl(
1082    expr: Expr,
1083    operator: Operator,
1084    mut exprs: Vec<Expr>,
1085) -> Vec<Expr> {
1086    match expr {
1087        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1088            let exprs = split_binary_owned_impl(*left, operator, exprs);
1089            split_binary_owned_impl(*right, operator, exprs)
1090        }
1091        Expr::Alias(Alias { expr, .. }) => {
1092            split_binary_owned_impl(*expr, operator, exprs)
1093        }
1094        other => {
1095            exprs.push(other);
1096            exprs
1097        }
1098    }
1099}
1100
1101/// Splits a binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1102///
1103/// See [`split_binary_owned`] for more details and an example.
1104pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1105    split_binary_impl(expr, op, vec![])
1106}
1107
1108fn split_binary_impl<'a>(
1109    expr: &'a Expr,
1110    operator: Operator,
1111    mut exprs: Vec<&'a Expr>,
1112) -> Vec<&'a Expr> {
1113    match expr {
1114        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1115            let exprs = split_binary_impl(left, operator, exprs);
1116            split_binary_impl(right, operator, exprs)
1117        }
1118        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1119        other => {
1120            exprs.push(other);
1121            exprs
1122        }
1123    }
1124}
1125
1126/// Combines an array of filter expressions into a single filter
1127/// expression consisting of the input filter expressions joined with
1128/// logical AND.
1129///
1130/// Returns None if the filters array is empty.
1131///
1132/// # Example
1133/// ```
1134/// # use datafusion_expr::{col, lit};
1135/// # use datafusion_expr::utils::conjunction;
1136/// // a=1 AND b=2
1137/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1138///
1139/// // [a=1, b=2]
1140/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1141///
1142/// // use conjunction to join them together with `AND`
1143/// assert_eq!(conjunction(split), Some(expr));
1144/// ```
1145pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1146    filters.into_iter().reduce(Expr::and)
1147}
1148
1149/// Combines an array of filter expressions into a single filter
1150/// expression consisting of the input filter expressions joined with
1151/// logical OR.
1152///
1153/// Returns None if the filters array is empty.
1154///
1155/// # Example
1156/// ```
1157/// # use datafusion_expr::{col, lit};
1158/// # use datafusion_expr::utils::disjunction;
1159/// // a=1 OR b=2
1160/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1161///
1162/// // [a=1, b=2]
1163/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1164///
1165/// // use disjunction to join them together with `OR`
1166/// assert_eq!(disjunction(split), Some(expr));
1167/// ```
1168pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1169    filters.into_iter().reduce(Expr::or)
1170}
1171
1172/// Returns a new [LogicalPlan] that filters the output of `plan` with a
1173/// [LogicalPlan::Filter] with all `predicates` ANDed.
1174///
1175/// # Example
1176/// Before:
1177/// ```text
1178/// plan
1179/// ```
1180///
1181/// After:
1182/// ```text
1183/// Filter(predicate)
1184///   plan
1185/// ```
1186pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1187    // reduce filters to a single filter with an AND
1188    let predicate = predicates
1189        .iter()
1190        .skip(1)
1191        .fold(predicates[0].clone(), |acc, predicate| {
1192            and(acc, (*predicate).to_owned())
1193        });
1194
1195    Ok(LogicalPlan::Filter(Filter::try_new(
1196        predicate,
1197        Arc::new(plan),
1198    )?))
1199}
1200
1201/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
1202/// one not in the subquery (captured from the outer scope).
1203///
1204/// # Arguments
1205///
1206/// * `exprs` - List of expressions that may or may not be joins
1207///
1208/// # Return value
1209///
1210/// Tuple of (expressions containing joins, remaining non-join expressions)
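///
/// # Example
///
/// An illustrative sketch (`outer.a` is a column from the outer query):
///
/// ```text
/// filters: [outer.a = inner.b, inner.c > 5]
/// => joins:  [a = inner.b]     -- the outer reference is stripped
///    others: [inner.c > 5]
/// ```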
1211pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1212    let mut joins = vec![];
1213    let mut others = vec![];
1214    for filter in exprs.into_iter() {
1215        // If the expression contains correlated predicates, add it to join filters
1216        if filter.contains_outer() {
1217            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1218            {
1219                joins.push(strip_outer_reference((*filter).clone()));
1220            }
1221        } else {
1222            others.push((*filter).clone());
1223        }
1224    }
1225
1226    Ok((joins, others))
1227}
1228
1229/// Returns the first (and only) element in a slice, or an error
1230///
1231/// # Arguments
1232///
1233/// * `slice` - The slice to extract from
1234///
1235/// # Return value
1236///
1237/// The first element, or an error
1238pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1239    match slice {
1240        [it] => Ok(it),
1241        [] => plan_err!("No items found!"),
1242        _ => plan_err!("More than one item found!"),
1243    }
1244}
1245
1246/// Merge the input schemas into a single schema.
1247///
1248/// This function merges schemas from multiple logical plan inputs using [`DFSchema::merge`].
1249/// Refer to that documentation for details on precedence and metadata handling.
1250pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1251    if inputs.len() == 1 {
1252        inputs[0].schema().as_ref().clone()
1253    } else {
1254        inputs.iter().map(|input| input.schema()).fold(
1255            DFSchema::empty(),
1256            |mut lhs, rhs| {
1257                lhs.merge(rhs);
1258                lhs
1259            },
1260        )
1261    }
1262}
1263
1264/// Build a state name. State is the intermediate state of an aggregate function.
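///
/// # Example
///
/// A minimal sketch:
///
/// ```
/// # use datafusion_expr::utils::format_state_name;
/// assert_eq!(format_state_name("avg", "count"), "avg[count]");
/// ```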
1265pub fn format_state_name(name: &str, state_name: &str) -> String {
1266    format!("{name}[{state_name}]")
1267}
1268
1269/// Determine the set of [`Column`]s produced by the subquery.
1270pub fn collect_subquery_cols(
1271    exprs: &[Expr],
1272    subquery_schema: &DFSchema,
1273) -> Result<BTreeSet<Column>> {
1274    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1275        let mut using_cols: Vec<Column> = vec![];
1276        for col in expr.column_refs().into_iter() {
1277            if subquery_schema.has_column(col) {
1278                using_cols.push(col.clone());
1279            }
1280        }
1281
1282        cols.extend(using_cols);
1283        Result::<_>::Ok(cols)
1284    })
1285}
1286
1287#[cfg(test)]
1288mod tests {
1289    use super::*;
1290    use crate::{
1291        col, cube,
1292        expr::WindowFunction,
1293        expr_vec_fmt, grouping_set, lit, rollup,
1294        test::function_stub::{max_udaf, min_udaf, sum_udaf},
1295        Cast, ExprFunctionExt, WindowFunctionDefinition,
1296    };
1297    use arrow::datatypes::{UnionFields, UnionMode};
1298    use datafusion_expr_common::signature::{TypeSignature, Volatility};
1299
1300    #[test]
1301    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1302        let result = group_window_expr_by_sort_keys(vec![])?;
1303        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1304        assert_eq!(expected, result);
1305        Ok(())
1306    }
1307
1308    #[test]
1309    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1310        let max1 = Expr::from(WindowFunction::new(
1311            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1312            vec![col("name")],
1313        ));
1314        let max2 = Expr::from(WindowFunction::new(
1315            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1316            vec![col("name")],
1317        ));
1318        let min3 = Expr::from(WindowFunction::new(
1319            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1320            vec![col("name")],
1321        ));
1322        let sum4 = Expr::from(WindowFunction::new(
1323            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1324            vec![col("age")],
1325        ));
1326        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1327        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1328        let key = vec![];
1329        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1330            vec![(key, vec![max1, max2, min3, sum4])];
1331        assert_eq!(expected, result);
1332        Ok(())
1333    }
1334
1335    #[test]
1336    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1337        let age_asc = Sort::new(col("age"), true, true);
1338        let name_desc = Sort::new(col("name"), false, true);
1339        let created_at_desc = Sort::new(col("created_at"), false, true);
1340        let max1 = Expr::from(WindowFunction::new(
1341            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1342            vec![col("name")],
1343        ))
1344        .order_by(vec![age_asc.clone(), name_desc.clone()])
1345        .build()
1346        .unwrap();
1347        let max2 = Expr::from(WindowFunction::new(
1348            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1349            vec![col("name")],
1350        ));
1351        let min3 = Expr::from(WindowFunction::new(
1352            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1353            vec![col("name")],
1354        ))
1355        .order_by(vec![age_asc.clone(), name_desc.clone()])
1356        .build()
1357        .unwrap();
1358        let sum4 = Expr::from(WindowFunction::new(
1359            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1360            vec![col("age")],
1361        ))
1362        .order_by(vec![
1363            name_desc.clone(),
1364            age_asc.clone(),
1365            created_at_desc.clone(),
1366        ])
1367        .build()
1368        .unwrap();
1369        // FIXME use as_ref
1370        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1371        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1372
1373        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1374        let key2 = vec![];
1375        let key3 = vec![
1376            (name_desc, false),
1377            (age_asc, false),
1378            (created_at_desc, false),
1379        ];
1380
1381        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1382            (key1, vec![max1, min3]),
1383            (key2, vec![max2]),
1384            (key3, vec![sum4]),
1385        ];
1386        assert_eq!(expected, result);
1387        Ok(())
1388    }
1389
1390    #[test]
1391    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1392        let asc_or_desc = [true, false];
1393        let nulls_first_or_last = [true, false];
1394        let partition_by = &[col("age"), col("name"), col("created_at")];
1395        for asc_ in asc_or_desc {
1396            for nulls_first_ in nulls_first_or_last {
1397                let order_by = &[
1398                    Sort {
1399                        expr: col("age"),
1400                        asc: asc_,
1401                        nulls_first: nulls_first_,
1402                    },
1403                    Sort {
1404                        expr: col("name"),
1405                        asc: asc_,
1406                        nulls_first: nulls_first_,
1407                    },
1408                ];
1409
1410                let expected = vec![
1411                    (
1412                        Sort {
1413                            expr: col("age"),
1414                            asc: asc_,
1415                            nulls_first: nulls_first_,
1416                        },
1417                        true,
1418                    ),
1419                    (
1420                        Sort {
1421                            expr: col("name"),
1422                            asc: asc_,
1423                            nulls_first: nulls_first_,
1424                        },
1425                        true,
1426                    ),
1427                    (
1428                        Sort {
1429                            expr: col("created_at"),
1430                            asc: true,
1431                            nulls_first: false,
1432                        },
1433                        true,
1434                    ),
1435                ];
1436                let result = generate_sort_key(partition_by, order_by)?;
1437                assert_eq!(expected, result);
1438            }
1439        }
1440        Ok(())
1441    }
1442
1443    #[test]
1444    fn test_enumerate_grouping_sets() -> Result<()> {
1445        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1446        let simple_col = col("simple_col");
1447        let cube = cube(multi_cols.clone());
1448        let rollup = rollup(multi_cols.clone());
1449        let grouping_set = grouping_set(vec![multi_cols]);
1450
1451        // 1. col
1452        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1453        let result = format!("[{}]", expr_vec_fmt!(sets));
1454        assert_eq!("[simple_col]", &result);
1455
1456        // 2. cube
1457        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1458        let result = format!("[{}]", expr_vec_fmt!(sets));
1459        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1460
1461        // 3. rollup
1462        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1463        let result = format!("[{}]", expr_vec_fmt!(sets));
1464        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1465
1466        // 4. col + cube
1467        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1468        let result = format!("[{}]", expr_vec_fmt!(sets));
1469        assert_eq!(
1470            "[GROUPING SETS (\
1471            (simple_col), \
1472            (simple_col, col1), \
1473            (simple_col, col2), \
1474            (simple_col, col1, col2), \
1475            (simple_col, col3), \
1476            (simple_col, col1, col3), \
1477            (simple_col, col2, col3), \
1478            (simple_col, col1, col2, col3))]",
1479            &result
1480        );
1481
1482        // 5. col + rollup
1483        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1484        let result = format!("[{}]", expr_vec_fmt!(sets));
1485        assert_eq!(
1486            "[GROUPING SETS (\
1487            (simple_col), \
1488            (simple_col, col1), \
1489            (simple_col, col1, col2), \
1490            (simple_col, col1, col2, col3))]",
1491            &result
1492        );
1493
1494        // 6. col + grouping_set
1495        let sets =
1496            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1497        let result = format!("[{}]", expr_vec_fmt!(sets));
1498        assert_eq!(
1499            "[GROUPING SETS (\
1500            (simple_col, col1, col2, col3))]",
1501            &result
1502        );
1503
1504        // 7. col + grouping_set + rollup
1505        let sets = enumerate_grouping_sets(vec![
1506            simple_col.clone(),
1507            grouping_set,
1508            rollup.clone(),
1509        ])?;
1510        let result = format!("[{}]", expr_vec_fmt!(sets));
1511        assert_eq!(
1512            "[GROUPING SETS (\
1513            (simple_col, col1, col2, col3), \
1514            (simple_col, col1, col2, col3, col1), \
1515            (simple_col, col1, col2, col3, col1, col2), \
1516            (simple_col, col1, col2, col3, col1, col2, col3))]",
1517            &result
1518        );
1519
1520        // 8. col + cube + rollup
        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(
            "[GROUPING SETS (\
            (simple_col), \
            (simple_col, col1), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1), \
            (simple_col, col1, col1), \
            (simple_col, col1, col1, col2), \
            (simple_col, col1, col1, col2, col3), \
            (simple_col, col2), \
            (simple_col, col2, col1), \
            (simple_col, col2, col1, col2), \
            (simple_col, col2, col1, col2, col3), \
            (simple_col, col1, col2), \
            (simple_col, col1, col2, col1), \
            (simple_col, col1, col2, col1, col2), \
            (simple_col, col1, col2, col1, col2, col3), \
            (simple_col, col3), \
            (simple_col, col3, col1), \
            (simple_col, col3, col1, col2), \
            (simple_col, col3, col1, col2, col3), \
            (simple_col, col1, col3), \
            (simple_col, col1, col3, col1), \
            (simple_col, col1, col3, col1, col2), \
            (simple_col, col1, col3, col1, col2, col3), \
            (simple_col, col2, col3), \
            (simple_col, col2, col3, col1), \
            (simple_col, col2, col3, col1, col2), \
            (simple_col, col2, col3, col1, col2, col3), \
            (simple_col, col1, col2, col3), \
            (simple_col, col1, col2, col3, col1), \
            (simple_col, col1, col2, col3, col1, col2), \
            (simple_col, col1, col2, col3, col1, col2, col3))]",
            &result
        );

        Ok(())
    }

    #[test]
    fn test_split_conjunction() {
        let expr = col("a");
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_conjunction_two() {
        let expr = col("a").eq(lit(5)).and(col("b"));
        let expr1 = col("a").eq(lit(5));
        let expr2 = col("b");

        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    #[test]
    fn test_split_conjunction_alias() {
        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
        let expr1 = col("a").eq(lit(5));
        let expr2 = col("b"); // has no alias

        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr1, &expr2]);
    }

    #[test]
    fn test_split_conjunction_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        let result = split_conjunction(&expr);
        assert_eq!(result, vec![&expr]);
    }

    #[test]
    fn test_split_binary_owned() {
        let expr = col("a");
        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
    }

    #[test]
    fn test_split_binary_owned_two() {
        assert_eq!(
            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_binary_owned_different_op() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(
            // expr is joined by OR, but we split on AND, so it is returned whole
            split_binary_owned(expr.clone(), Operator::And),
            vec![expr]
        );
    }

    #[test]
    fn test_split_conjunction_owned() {
        let expr = col("a");
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    #[test]
    fn test_split_conjunction_owned_two() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
            vec![col("a").eq(lit(5)), col("b")]
        );
    }

    #[test]
    fn test_split_conjunction_owned_alias() {
        assert_eq!(
            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
            vec![
                col("a").eq(lit(5)),
                // no alias on b
                col("b"),
            ]
        );
    }

    #[test]
    fn test_conjunction_empty() {
        assert_eq!(conjunction(vec![]), None);
    }

    #[test]
    fn test_conjunction() {
        // `[A, B, C]`
        let expr = conjunction(vec![col("a"), col("b"), col("c")]);

        // --> `(A AND B) AND C`
        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));

        // which is different than `A AND (B AND C)`
        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
    }

    #[test]
    fn test_disjunction_empty() {
        assert_eq!(disjunction(vec![]), None);
    }

    #[test]
    fn test_disjunction() {
        // `[A, B, C]`
        let expr = disjunction(vec![col("a"), col("b"), col("c")]);

        // --> `(A OR B) OR C`
        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));

        // which is different than `A OR (B OR C)`
        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
    }

    #[test]
    fn test_split_conjunction_owned_or() {
        let expr = col("a").eq(lit(5)).or(col("b"));
        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
    }

    #[test]
    fn test_collect_expr() -> Result<()> {
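        // Visiting the same column twice (via two identical casts) should
        // record it only once in the accumulator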
        let mut accum: HashSet<Column> = HashSet::new();
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        expr_to_columns(
            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
            &mut accum,
        )?;
        assert_eq!(1, accum.len());
        assert!(accum.contains(&Column::from_name("a")));
        Ok(())
    }

    #[test]
    fn test_can_hash() {
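        // Union types cannot be hashed, and the same holds for nested types
        // (such as List) that contain a union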
        let union_fields: UnionFields = [
            (0, Arc::new(Field::new("A", DataType::Int32, true))),
            (1, Arc::new(Field::new("B", DataType::Float64, true))),
        ]
        .into_iter()
        .collect();

        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
        assert!(!can_hash(&union_type));

        let list_union_type =
            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
        assert!(!can_hash(&list_union_type));
    }

    #[test]
    fn test_generate_signature_error_msg_with_parameter_names() {
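        // When parameter names are attached to the signature, each candidate
        // signature in the error message should pair every name with its type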
        let sig = Signature::one_of(
            vec![
                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
                TypeSignature::Exact(vec![
                    DataType::Utf8,
                    DataType::Int64,
                    DataType::Int64,
                ]),
            ],
            Volatility::Immutable,
        )
        .with_parameter_names(vec![
            "str".to_string(),
            "start_pos".to_string(),
            "length".to_string(),
        ])
        .expect("valid parameter names");

        // Generate the error message when only one argument is provided
        let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]);

        assert!(
            error_msg.contains("str: Utf8, start_pos: Int64"),
            "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
        );
        assert!(
            error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
            "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
        );
    }

    #[test]
    fn test_generate_signature_error_msg_without_parameter_names() {
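        // Without parameter names, the message falls back to the raw type
        // signatures, e.g. "Any, Any"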
        let sig = Signature::one_of(
            vec![TypeSignature::Any(2), TypeSignature::Any(3)],
            Volatility::Immutable,
        );

        let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]);

        assert!(
            error_msg.contains("Any, Any"),
            "Expected 'Any, Any' without parameter names, got: {error_msg}"
        );
    }
}