datafusion_expr/expr_rewriter/
order_by.rs1use crate::expr::Alias;
21use crate::expr_rewriter::normalize_col;
22use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast};
23
24use datafusion_common::tree_node::{
25 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
26};
27use datafusion_common::{Column, Result};
28
29pub fn rewrite_sort_cols_by_aggs(
32 sorts: impl IntoIterator<Item = impl Into<Sort>>,
33 plan: &LogicalPlan,
34) -> Result<Vec<Sort>> {
35 sorts
36 .into_iter()
37 .map(|e| {
38 let sort = e.into();
39 Ok(Sort::new(
40 rewrite_sort_col_by_aggs(sort.expr, plan)?,
41 sort.asc,
42 sort.nulls_first,
43 ))
44 })
45 .collect()
46}
47
48fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
49 let plan_inputs = plan.inputs();
50
51 if plan_inputs.len() == 1 {
54 let proj_exprs = plan.expressions();
55 rewrite_in_terms_of_projection(expr, proj_exprs, plan_inputs[0])
56 } else {
57 Ok(expr)
58 }
59}
60
61fn rewrite_in_terms_of_projection(
73 expr: Expr,
74 proj_exprs: Vec<Expr>,
75 input: &LogicalPlan,
76) -> Result<Expr> {
77 expr.transform(|expr| {
80 if matches!(expr, Expr::Lambda(_)) {
81 return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump))
82 }
83
84 if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) {
86 let (qualifier, field_name) = found.qualified_name();
87 let col = Expr::Column(Column::new(qualifier, field_name));
88 return Ok(Transformed::yes(col));
89 }
90
91 let normalized_expr = if let Ok(e) = normalize_col(expr.clone(), input) {
97 e
98 } else {
99 return Ok(Transformed::no(expr));
101 };
102
103 let name = normalized_expr.schema_name().to_string();
106
107 let search_col = Expr::Column(Column::new_unqualified(name));
108
109 let mut found = None;
111 for proj_expr in &proj_exprs {
112 proj_expr.apply(|e| {
113 if expr_match(&search_col, e) {
114 found = Some(e.clone());
115 return Ok(TreeNodeRecursion::Stop);
116 }
117 Ok(TreeNodeRecursion::Continue)
118 })?;
119 }
120
121 if let Some(found) = found {
122 return Ok(Transformed::yes(match normalized_expr {
123 Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
124 expr: Box::new(found),
125 data_type,
126 }),
127 Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast {
128 expr: Box::new(found),
129 data_type,
130 }),
131 _ => found,
132 }));
133 }
134
135 Ok(Transformed::no(expr))
136 })
137 .data()
138}
139
140fn expr_match(needle: &Expr, expr: &Expr) -> bool {
143 if let Expr::Alias(Alias { expr, .. }) = &expr {
145 expr.as_ref() == needle
146 } else {
147 expr == needle
148 }
149}
150
#[cfg(test)]
mod test {
    use std::ops::Add;
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::function_stub::{avg, min};
    use crate::{
        cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast,
        LogicalPlanBuilder,
    };

    /// One rewrite scenario: `input` is rewritten against a plan and the
    /// result must equal `expected`.
    struct TestCase {
        desc: &'static str,
        input: Sort,
        expected: Sort,
    }

    impl TestCase {
        /// Rewrites `input` against `input_plan` and asserts that exactly one
        /// sort expression comes back and that it equals `expected`.
        fn run(self, input_plan: &LogicalPlan) {
            let desc = self.desc;
            let input = self.input;
            let expected = self.expected;

            println!("running: '{desc}'");
            let rewritten_all =
                rewrite_sort_cols_by_aggs(vec![input.clone()], input_plan).unwrap();

            assert_eq!(rewritten_all.len(), 1);
            let rewritten = rewritten_all.into_iter().next().unwrap();

            assert_eq!(
                rewritten, expected,
                "\n\ninput:{input:?}\nrewritten:{rewritten:?}\nexpected:{expected:?}\n"
            );
        }
    }

    /// Runs every case, in order, against the same plan.
    fn run_cases(cases: Vec<TestCase>, plan: &LogicalPlan) {
        cases.into_iter().for_each(|case| case.run(plan));
    }

    /// Builder for a scan of table `t` with columns `c1: Int32`,
    /// `c2: Utf8` and `c3: Float64` (all nullable), with no scan projection.
    fn make_input() -> LogicalPlanBuilder {
        let fields = vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ];
        let source = LogicalTableSource::new(Arc::new(Schema::new(fields)));
        LogicalPlanBuilder::scan("t", Arc::new(source), None).unwrap()
    }

    /// Wraps `expr` in an ascending, nulls-first [`Sort`].
    fn sort(expr: Expr) -> Sort {
        expr.sort(true, true)
    }

    /// Sort keys over a bare aggregate (`GROUP BY c1`, `min(c2)`) are
    /// expected to come back unchanged.
    #[test]
    fn rewrite_sort_cols_by_agg() {
        let agg = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2"))])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            // NOTE(review): the description says "c1 + c2" but the
            // expressions below are `c1 + c1` -- confirm which is intended.
            TestCase {
                desc: "c1 + c2 --> c1 + c2",
                input: sort(col("c1") + col("c1")),
                expected: sort(col("c1") + col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)"#,
                input: sort(min(col("c2"))),
                expected: sort(min(col("c2"))),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + min(col("c2"))),
            },
        ];

        run_cases(cases, &agg);
    }

    /// Sort keys over a projection placed on top of an aggregate: aggregate
    /// calls are expected to become references to the projection's columns.
    #[test]
    fn rewrite_sort_cols_by_agg_alias() {
        let agg = make_input()
            .aggregate(vec![col("c1")], vec![min(col("c2")), avg(col("c3"))])
            .unwrap()
            .project(vec![
                col("c1").add(lit(1)).alias("c1"),
                min(col("c2")),
                avg(col("c3")).alias("average"),
            ])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "c1 --> c1 -- column *named* c1 that came out of the projection, (not t.c1)",
                input: sort(col("c1")),
                expected: sort(col("c1")),
            },
            TestCase {
                desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(min(col("c2"))),
                expected: sort(col("min(t.c2)")),
            },
            TestCase {
                desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#,
                input: sort(col("c1") + min(col("c2"))),
                expected: sort(col("c1") + col("min(t.c2)")),
            },
            TestCase {
                desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
                input: sort(avg(col("c3"))),
                expected: sort(col("avg(t.c3)").alias("average")),
            },
        ];

        run_cases(cases, &agg);
    }

    /// A `Cast` / `TryCast` around a column must still wrap the rewritten
    /// (aliased) column afterwards.
    #[test]
    fn preserve_cast() {
        let plan = make_input()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .project(vec![col("c2").alias("c2")])
            .unwrap()
            .build()
            .unwrap();

        let cases = vec![
            TestCase {
                desc: "Cast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(cast(col("c2"), DataType::Int64)),
                expected: sort(cast(col("c2").alias("c2"), DataType::Int64)),
            },
            TestCase {
                desc: "TryCast is preserved by rewrite_sort_cols_by_aggs",
                input: sort(try_cast(col("c2"), DataType::Int64)),
                expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)),
            },
        ];

        run_cases(cases, &plan);
    }
}