1use std::{cmp::Ordering, sync::Arc, vec};
19
20use super::{
21 dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
22 rewrite::TableAliasRewriter, Unparser,
23};
24use datafusion_common::{
25 internal_err,
26 tree_node::{Transformed, TransformedResult, TreeNode},
27 Column, DataFusionError, Result, ScalarValue,
28};
29use datafusion_expr::{
30 expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan,
31 LogicalPlanBuilder, Projection, SortExpr, Unnest, Window,
32};
33
34use indexmap::IndexSet;
35use sqlparser::ast;
36use sqlparser::tokenizer::Span;
37
38pub(crate) fn find_agg_node_within_select(
43 plan: &LogicalPlan,
44 already_projected: bool,
45) -> Option<&Aggregate> {
46 let input = plan.inputs();
49 let input = if input.len() > 1 {
50 return None;
51 } else {
52 input.first()?
53 };
54 if let LogicalPlan::Aggregate(agg) = input {
56 Some(agg)
57 } else if let LogicalPlan::TableScan(_) = input {
58 None
59 } else if let LogicalPlan::Projection(_) = input {
60 if already_projected {
61 None
62 } else {
63 find_agg_node_within_select(input, true)
64 }
65 } else {
66 find_agg_node_within_select(input, already_projected)
67 }
68}
69
70pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> {
72 let input = plan.inputs();
75 let input = if input.len() > 1 {
76 return None;
77 } else {
78 input.first()?
79 };
80
81 if let LogicalPlan::Unnest(unnest) = input {
82 Some(unnest)
83 } else if let LogicalPlan::TableScan(_) = input {
84 None
85 } else if let LogicalPlan::Projection(_) = input {
86 None
87 } else {
88 find_unnest_node_within_select(input)
89 }
90}
91
92pub(crate) fn find_unnest_node_until_relation(plan: &LogicalPlan) -> Option<&Unnest> {
95 let input = plan.inputs();
98 let input = if input.len() > 1 {
99 return None;
100 } else {
101 input.first()?
102 };
103
104 if let LogicalPlan::Unnest(unnest) = input {
105 Some(unnest)
106 } else if let LogicalPlan::TableScan(_) = input {
107 None
108 } else if let LogicalPlan::Subquery(_) = input {
109 None
110 } else if let LogicalPlan::SubqueryAlias(_) = input {
111 None
112 } else {
113 find_unnest_node_within_select(input)
114 }
115}
116
117pub(crate) fn find_window_nodes_within_select<'a>(
122 plan: &'a LogicalPlan,
123 mut prev_windows: Option<Vec<&'a Window>>,
124 already_projected: bool,
125) -> Option<Vec<&'a Window>> {
126 let input = plan.inputs();
129 let input = if input.len() > 1 {
130 return prev_windows;
131 } else {
132 input.first()?
133 };
134
135 match input {
137 LogicalPlan::Window(window) => {
138 prev_windows = match &mut prev_windows {
139 Some(windows) => {
140 windows.push(window);
141 prev_windows
142 }
143 _ => Some(vec![window]),
144 };
145 find_window_nodes_within_select(input, prev_windows, already_projected)
146 }
147 LogicalPlan::Projection(_) => {
148 if already_projected {
149 prev_windows
150 } else {
151 find_window_nodes_within_select(input, prev_windows, true)
152 }
153 }
154 LogicalPlan::TableScan(_) => prev_windows,
155 _ => find_window_nodes_within_select(input, prev_windows, already_projected),
156 }
157}
158
159pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
164 expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
165 if let Expr::Column(col_ref) = &sub_expr {
166 if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) {
169 if let Ok(idx) = unnest.schema.index_of_column(col_ref) {
170 if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() {
171 if let Some(unprojected_expr) = expr.get(idx) {
172 let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone()));
173 return Ok(Transformed::yes(unnest_expr));
174 }
175 }
176 }
177 return internal_err!(
178 "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name
179 );
180 }
181 }
182
183 Ok(Transformed::no(sub_expr))
184
185 }).map(|e| e.data)
186}
187
188pub(crate) fn unproject_agg_exprs(
194 expr: Expr,
195 agg: &Aggregate,
196 windows: Option<&[&Window]>,
197) -> Result<Expr> {
198 expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
199 match sub_expr {
200 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
201 Ok(Transformed::yes(unprojected_expr.clone()))
202 } else if let Some(unprojected_expr) =
203 windows.and_then(|w| find_window_expr(w, &c.name).cloned())
204 {
205 Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?))
207 } else {
208 internal_err!(
209 "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
210 )
211 },
212 _ => Ok(Transformed::no(sub_expr)),
213 }
214 })
215 .map(|e| e.data)
216}
217
218pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
224 expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr {
225 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
226 if let Some(unproj) = find_window_expr(windows, &c.name) {
227 Ok(Transformed::yes(unproj.clone()))
228 } else {
229 Ok(Transformed::no(Expr::Column(c)))
230 }
231 }
232 _ => Ok(Transformed::no(sub_expr)),
233 })
234 .map(|e| e.data)
235}
236
237fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a Expr>> {
238 if let Ok(index) = agg.schema.index_of_column(column) {
239 if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
240 let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
242 match index.cmp(&grouping_expr.len()) {
243 Ordering::Less => Ok(grouping_expr.into_iter().nth(index)),
244 Ordering::Equal => {
245 internal_err!(
246 "Tried to unproject column referring to internal grouping id"
247 )
248 }
249 Ordering::Greater => {
250 Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1))
251 }
252 }
253 } else {
254 Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
255 }
256 } else {
257 Ok(None)
258 }
259}
260
261fn find_window_expr<'a>(
262 windows: &'a [&'a Window],
263 column_name: &'a str,
264) -> Option<&'a Expr> {
265 windows
266 .iter()
267 .flat_map(|w| w.window_expr.iter())
268 .find(|expr| expr.schema_name().to_string() == column_name)
269}
270
271pub(crate) fn unproject_sort_expr(
276 mut sort_expr: SortExpr,
277 agg: Option<&Aggregate>,
278 input: &LogicalPlan,
279) -> Result<SortExpr> {
280 sort_expr.expr = sort_expr
281 .expr
282 .transform(|sub_expr| {
283 match sub_expr {
284 Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)),
286 Expr::Column(col) => {
287 if col.relation.is_some() {
288 return Ok(Transformed::no(Expr::Column(col)));
289 }
290
291 if let Some(agg) = agg {
293 if agg.schema.is_column_from_schema(&col) {
294 return Ok(Transformed::yes(unproject_agg_exprs(
295 Expr::Column(col),
296 agg,
297 None,
298 )?));
299 }
300 }
301
302 if let LogicalPlan::Projection(Projection { expr, schema, .. }) =
306 input
307 {
308 if let Ok(idx) = schema.index_of_column(&col) {
309 if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) {
310 return Ok(Transformed::yes(Expr::ScalarFunction(
311 scalar_fn.clone(),
312 )));
313 }
314 }
315 }
316
317 Ok(Transformed::no(Expr::Column(col)))
318 }
319 _ => Ok(Transformed::no(sub_expr)),
320 }
321 })
322 .map(|e| e.data)?;
323 Ok(sort_expr)
324}
325
326pub(crate) fn try_transform_to_simple_table_scan_with_filters(
343 plan: &LogicalPlan,
344) -> Result<Option<(LogicalPlan, Vec<Expr>)>> {
345 let mut filters: IndexSet<Expr> = IndexSet::new();
346 let mut plan_stack = vec![plan];
347 let mut table_alias = None;
348
349 while let Some(current_plan) = plan_stack.pop() {
350 match current_plan {
351 LogicalPlan::SubqueryAlias(alias) => {
352 table_alias = Some(alias.alias.clone());
353 plan_stack.push(alias.input.as_ref());
354 }
355 LogicalPlan::Filter(filter) => {
356 if !filters.contains(&filter.predicate) {
357 filters.insert(filter.predicate.clone());
358 }
359 plan_stack.push(filter.input.as_ref());
360 }
361 LogicalPlan::TableScan(table_scan) => {
362 let table_schema = table_scan.source.schema();
363 let mut filter_alias_rewriter =
365 table_alias.as_ref().map(|alias_name| TableAliasRewriter {
366 table_schema: &table_schema,
367 alias_name: alias_name.clone(),
368 });
369
370 let table_scan_filters = table_scan
372 .filters
373 .iter()
374 .cloned()
375 .map(|expr| {
376 if let Some(ref mut rewriter) = filter_alias_rewriter {
377 expr.rewrite_with_lambdas_params(rewriter).data()
378 } else {
379 Ok(expr)
380 }
381 })
382 .collect::<Result<Vec<_>, DataFusionError>>()?;
383
384 for table_scan_filter in table_scan_filters {
385 if !filters.contains(&table_scan_filter) {
386 filters.insert(table_scan_filter);
387 }
388 }
389
390 let mut builder = LogicalPlanBuilder::scan(
391 table_scan.table_name.clone(),
392 Arc::clone(&table_scan.source),
393 table_scan.projection.clone(),
394 )?;
395
396 if let Some(alias) = table_alias.take() {
397 builder = builder.alias(alias)?;
398 }
399
400 let plan = builder.build()?;
401 let filters = filters.into_iter().collect();
402
403 return Ok(Some((plan, filters)));
404 }
405 _ => {
406 return Ok(None);
407 }
408 }
409 }
410
411 Ok(None)
412}
413
414pub(crate) fn date_part_to_sql(
416 unparser: &Unparser,
417 style: DateFieldExtractStyle,
418 date_part_args: &[Expr],
419) -> Result<Option<ast::Expr>> {
420 match (style, date_part_args.len()) {
421 (DateFieldExtractStyle::Extract, 2) => {
422 let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
423 if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
424 let field = match field.to_lowercase().as_str() {
425 "year" => ast::DateTimeField::Year,
426 "month" => ast::DateTimeField::Month,
427 "day" => ast::DateTimeField::Day,
428 "hour" => ast::DateTimeField::Hour,
429 "minute" => ast::DateTimeField::Minute,
430 "second" => ast::DateTimeField::Second,
431 _ => return Ok(None),
432 };
433
434 return Ok(Some(ast::Expr::Extract {
435 field,
436 expr: Box::new(date_expr),
437 syntax: ast::ExtractSyntax::From,
438 }));
439 }
440 }
441 (DateFieldExtractStyle::Strftime, 2) => {
442 let column = unparser.expr_to_sql(&date_part_args[1])?;
443
444 if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
445 let field = match field.to_lowercase().as_str() {
446 "year" => "%Y",
447 "month" => "%m",
448 "day" => "%d",
449 "hour" => "%H",
450 "minute" => "%M",
451 "second" => "%S",
452 _ => return Ok(None),
453 };
454
455 return Ok(Some(ast::Expr::Function(ast::Function {
456 name: ast::ObjectName::from(vec![ast::Ident {
457 value: "strftime".to_string(),
458 quote_style: None,
459 span: Span::empty(),
460 }]),
461 args: ast::FunctionArguments::List(ast::FunctionArgumentList {
462 duplicate_treatment: None,
463 args: vec![
464 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
465 ast::Expr::value(ast::Value::SingleQuotedString(
466 field.to_string(),
467 )),
468 )),
469 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
470 ],
471 clauses: vec![],
472 }),
473 filter: None,
474 null_treatment: None,
475 over: None,
476 within_group: vec![],
477 parameters: ast::FunctionArguments::None,
478 uses_odbc_syntax: false,
479 })));
480 }
481 }
482 (DateFieldExtractStyle::DatePart, _) => {
483 return Ok(Some(
484 unparser.scalar_function_to_sql("date_part", date_part_args)?,
485 ));
486 }
487 _ => {}
488 };
489
490 Ok(None)
491}
492
493pub(crate) fn character_length_to_sql(
494 unparser: &Unparser,
495 style: CharacterLengthStyle,
496 character_length_args: &[Expr],
497) -> Result<Option<ast::Expr>> {
498 let func_name = match style {
499 CharacterLengthStyle::CharacterLength => "character_length",
500 CharacterLengthStyle::Length => "length",
501 };
502
503 Ok(Some(unparser.scalar_function_to_sql(
504 func_name,
505 character_length_args,
506 )?))
507}
508
509pub(crate) fn sqlite_from_unixtime_to_sql(
518 unparser: &Unparser,
519 from_unixtime_args: &[Expr],
520) -> Result<Option<ast::Expr>> {
521 if from_unixtime_args.len() != 1 {
522 return internal_err!(
523 "from_unixtime for SQLite expects 1 argument, found {}",
524 from_unixtime_args.len()
525 );
526 }
527
528 Ok(Some(unparser.scalar_function_to_sql(
529 "datetime",
530 &[
531 from_unixtime_args[0].clone(),
532 Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None),
533 ],
534 )?))
535}
536
537pub(crate) fn sqlite_date_trunc_to_sql(
545 unparser: &Unparser,
546 date_trunc_args: &[Expr],
547) -> Result<Option<ast::Expr>> {
548 if date_trunc_args.len() != 2 {
549 return internal_err!(
550 "date_trunc for SQLite expects 2 arguments, found {}",
551 date_trunc_args.len()
552 );
553 }
554
555 if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] {
556 let format = match unit.to_lowercase().as_str() {
557 "year" => "%Y",
558 "month" => "%Y-%m",
559 "day" => "%Y-%m-%d",
560 "hour" => "%Y-%m-%d %H",
561 "minute" => "%Y-%m-%d %H:%M",
562 "second" => "%Y-%m-%d %H:%M:%S",
563 _ => return Ok(None),
564 };
565
566 return Ok(Some(unparser.scalar_function_to_sql(
567 "strftime",
568 &[
569 Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None),
570 date_trunc_args[1].clone(),
571 ],
572 )?));
573 }
574
575 Ok(None)
576}