datafusion_expr/expr_rewriter/
order_by.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Rewrite for order by expressions
19
20use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast};
23
24use datafusion_common::tree_node::{
25    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
26};
27use datafusion_common::{Column, Result};
28
29/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
30/// For example, `max(x)` is written to `col("max(x)")`
31pub fn rewrite_sort_cols_by_aggs(
32    sorts: impl IntoIterator<Item = impl Into<Sort>>,
33    plan: &LogicalPlan,
34) -> Result<Vec<Sort>> {
35    sorts
36        .into_iter()
37        .map(|e| {
38            let sort = e.into();
39            Ok(Sort::new(
40                rewrite_sort_col_by_aggs(sort.expr, plan)?,
41                sort.asc,
42                sort.nulls_first,
43            ))
44        })
45        .collect()
46}
47
48fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
49    let plan_inputs = plan.inputs();
50
51    // Joins, and Unions are not yet handled (should have a projection
52    // on top of them)
53    if plan_inputs.len() == 1 {
54        let proj_exprs = plan.expressions();
55        rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
56    } else {
57        Ok(expr)
58    }
59}
60
61/// Rewrites a sort expression in terms of the output of the previous [`LogicalPlan`]
62///
63/// Example:
64///
65/// Given an input expression such as `col(a) + col(b) + col(c)`
66///
67/// into `col(a) + col("b + c")`
68///
69/// Remember that:
70/// 1. given a projection with exprs: [a, b + c]
71/// 2. t produces an output schema with two columns "a", "b + c"
72fn rewrite_in_terms_of_projection(
73    expr: Expr,
74    proj_exprs: Vec<Expr>,
75    input: &LogicalPlan,
76) -> Result<Expr> {
77    // assumption is that each item in exprs, such as "b + c" is
78    // available as an output column named "b + c"
79    expr.transform(|expr| {
80        if matches!(expr, Expr::Lambda(_)) {
81            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump))
82        }
83
84        // search for unnormalized names first such as "c1" (such as aliases)
85        if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
86            let (qualifier, field_name) = found.qualified_name();
87            let col = Expr::Column(Column::new(qualifier, field_name));
88            return Ok(Transformed::yes(col));
89        }
90
91        // if that doesn't work, try to match the expression as an
92        // output column -- however first it must be "normalized"
93        // (e.g. "c1" --> "t.c1") because that normalization is done
94        // at the input of the aggregate.
95
96        let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
97            e
98        } else {
99            // The expr is not based on Aggregate plan output. Skip it.
100            return Ok(Transformed::no(expr));
101        };
102
103        // expr is an actual expr like min(t.c2), but we are looking
104        // for a column with the same "MIN(C2)", so translate there
105        let name = normalized_expr.schema_name().to_string();
106
107        let search_col = Expr::Column(Column::new_unqualified(name));
108
109        // look for the column named the same as this expr
110        let mut found = None;
111        for proj_expr in &proj_exprs {
112            proj_expr.apply(|e| {
113                if expr_match(&search_col, e) {
114                    found = Some(e.clone());
115                    return Ok(TreeNodeRecursion::Stop);
116                }
117                Ok(TreeNodeRecursion::Continue)
118            })?;
119        }
120
121        if let Some(found) = found {
122            return Ok(Transformed::yes(match normalized_expr {
123                Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
124                    expr: Box::new(found),
125                    data_type,
126                }),
127                Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
128                    expr: Box::new(found),
129                    data_type,
130                }),
131                _ => found,
132            }));
133        }
134
135        Ok(Transformed::no(expr))
136    })
137    .data()
138}
139
140/// Does the underlying expr match e?
141/// so avg(c) as average will match avgc
142fn expr_match(needle: &Expr, expr: &Expr) -> bool {
143    // check inside aliases
144    if let Expr::Alias(Alias { expr, .. }) = &expr {
145        expr.as_ref() == needle
146    } else {
147        expr == needle
148    }
149}
150
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::{
        cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
        LogicalPlanBuilder,
    };

    use super::*;
    use crate::test::function_stub::avg;
    use crate::test::function_stub::min;

    /// Sorting directly on top of an aggregate plan: every case below expects
    /// the sort expression to come back unchanged.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        //  gby c1, agg: min(c2)
        let agg = make_input()
            .aggregate(
                // gby: c1
                vec![col("c1")],
                // agg: min(c2)
                vec![min(col("c2"))],
            )
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                // Uses two distinct columns, as the description states.
                desc: "c1 + c2 --> c1 + c2",
                input: sort(col("c1") + col("c2")),
                expected: sort(col("c1") + col("c2")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// Sorting on top of a projection over an aggregate: aggregate expressions
    /// in the sort are expected to be replaced by the projection's output
    /// columns, including aliased ones.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(
                // gby c1
                vec![col("c1")],
                // agg: min(c2), avg(c3)
                vec![min(col("c2")), avg(col("c3"))],
            )
            .unwrap()
            //  projects out an expression "c1" that is different than the column "c1"
            .project(vec![
                // c1 + 1 as c1,
                col("c1").add(lit(1)).alias("c1"),
                // min(c2)
                min(col("c2")),
                // avg("c3") as average
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1  -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                // should be "c1" not t.c1
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                // should be "c1" not t.c1
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        for case in cases {
            case.run(&agg)
        }
    }

    /// A `Cast` / `TryCast` wrapped around a sort column must still wrap the
    /// rewritten expression after the rewrite.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        for case in cases {
            case.run(&plan)
        }
    }

    /// One rewrite scenario: a sort expression and the rewrite expected for it.
    struct TestCase {
        /// Human readable description, printed when the case runs.
        desc: &'static str,
        /// Sort expression handed to `rewrite_sort_cols_by_aggs`.
        input: Sort,
        /// Sort expression the rewrite must produce.
        expected: Sort,
    }

    impl TestCase {
        /// calls rewrite_sort_cols_by_aggs for expr and compares it to expected_expr
        fn run(self, input_plan: &LogicalPlan) {
            let Self {
                desc,
                input,
                expected,
            } = self;

            println!("running: '{desc}'");
            let mut exprs =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            // A single input sort must yield a single rewritten sort.
            assert_eq!(exprs.len(), 1);
            let rewritten = exprs.pop().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Scan of a table: t(c1 int, c2 varchar, c3 float)
    fn make_input() -> LogicalPlanBuilder {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]));
        let projection = None;
        LogicalPlanBuilder::scan(
            "t",
            Arc::new(LogicalTableSource::new(schema)),
            projection,
        )
        .unwrap()
    }

    /// Wraps `expr` in an ascending, nulls-first [`Sort`].
    fn sort(expr: Expr) -> Sort {
        let asc = true;
        let nulls_first = true;
        expr.sort(asc, nulls_first)
    }
}