1use std::collections::BTreeSet;
21use std::fmt::Debug;
22use std::sync::Arc;
23
24use crate::{OptimizerConfig, OptimizerRule};
25
26use crate::optimizer::ApplyOrder;
27use crate::utils::NamePreserver;
28use datafusion_common::alias::AliasGenerator;
29
30use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE};
31use datafusion_common::tree_node::{Transformed, TreeNode};
32use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result};
33use datafusion_expr::expr::{Alias, ScalarFunction};
34use datafusion_expr::logical_plan::{
35 Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
36};
37use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr};
38
/// Prefix passed to the alias generator when naming extracted common
/// sub-expressions (see `ExprCSEController::generate_alias`).
const CSE_PREFIX: &str = "__common_expr";
40
/// Optimizer rule that performs common sub-expression elimination.
///
/// The `OptimizerRule::rewrite` implementation in this file handles
/// `Projection`, `Sort`, `Filter`, `Window` and `Aggregate` nodes: common
/// sub-expressions found by `CSE::extract_common_nodes` are computed once,
/// under a generated alias starting with `CSE_PREFIX`, and the original
/// expressions are rewritten to reference that alias.
///
/// The struct carries no state.
#[derive(Debug)]
pub struct CommonSubexprEliminate {}
69
70impl CommonSubexprEliminate {
71 pub fn new() -> Self {
72 Self {}
73 }
74
75 fn try_optimize_proj(
76 &self,
77 projection: Projection,
78 config: &dyn OptimizerConfig,
79 ) -> Result<Transformed<LogicalPlan>> {
80 let Projection {
81 expr,
82 input,
83 schema,
84 ..
85 } = projection;
86 let input = Arc::unwrap_or_clone(input);
87 self.try_unary_plan(expr, input, config)?
88 .map_data(|(new_expr, new_input)| {
89 Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema)
90 .map(LogicalPlan::Projection)
91 })
92 }
93
94 fn try_optimize_sort(
95 &self,
96 sort: Sort,
97 config: &dyn OptimizerConfig,
98 ) -> Result<Transformed<LogicalPlan>> {
99 let Sort { expr, input, fetch } = sort;
100 let input = Arc::unwrap_or_clone(input);
101 let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
102 .into_iter()
103 .map(|sort| (sort.expr, (sort.asc, sort.nulls_first)))
104 .unzip();
105 let new_sort = self
106 .try_unary_plan(sort_expressions, input, config)?
107 .update_data(|(new_expr, new_input)| {
108 LogicalPlan::Sort(Sort {
109 expr: new_expr
110 .into_iter()
111 .zip(sort_params)
112 .map(|(expr, (asc, nulls_first))| SortExpr {
113 expr,
114 asc,
115 nulls_first,
116 })
117 .collect(),
118 input: Arc::new(new_input),
119 fetch,
120 })
121 });
122 Ok(new_sort)
123 }
124
125 fn try_optimize_filter(
126 &self,
127 filter: Filter,
128 config: &dyn OptimizerConfig,
129 ) -> Result<Transformed<LogicalPlan>> {
130 let Filter {
131 predicate, input, ..
132 } = filter;
133 let input = Arc::unwrap_or_clone(input);
134 let expr = vec![predicate];
135 self.try_unary_plan(expr, input, config)?
136 .map_data(|(mut new_expr, new_input)| {
137 assert_eq!(new_expr.len(), 1); let new_predicate = new_expr.pop().unwrap();
139 Filter::try_new(new_predicate, Arc::new(new_input))
140 .map(LogicalPlan::Filter)
141 })
142 }
143
/// Eliminates common sub-expressions across a chain of consecutive `Window`
/// nodes.
///
/// `window` and every `Window` directly nested below it are collected by
/// `get_consecutive_window_exprs` and their expression lists are examined
/// together. If common sub-expressions are found they are projected on top
/// of the first non-window input. The chain of `Window` nodes is then
/// rebuilt over the recursively rewritten input.
fn try_optimize_window(
    &self,
    window: Window,
    config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
    // One expression list and one schema per collected window node,
    // outermost first; `input` is the first non-window plan below them.
    let (window_expr_list, window_schemas, input) =
        get_consecutive_window_exprs(window);

    // Look for common sub-expressions across all collected expression lists.
    match CSE::new(ExprCSEController::new(
        config.alias_generator().as_ref(),
        ExprMask::Normal,
    ))
    .extract_common_nodes(window_expr_list)?
    {
        // Common expressions found: project them on top of `input` and keep
        // the original expressions so their names can be restored later.
        FoundCommonNodes::Yes {
            common_nodes: common_exprs,
            new_nodes_list: new_exprs_list,
            original_nodes_list: original_exprs_list,
        } => build_common_expr_project_plan(input, common_exprs).map(|new_input| {
            Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list)))
        }),
        // Nothing found: continue with the original expressions and input.
        FoundCommonNodes::No {
            original_nodes_list: original_exprs_list,
        } => Ok(Transformed::no((original_exprs_list, input, None))),
    }?
    // Recurse into the (possibly new) input.
    .transform_data(|(new_window_expr_list, new_input, window_expr_list)| {
        self.rewrite(new_input, config)?.map_data(|new_input| {
            Ok((new_window_expr_list, new_input, window_expr_list))
        })
    })?
    // Rebuild the chain of window nodes. `try_rfold` walks the lists from
    // the last (innermost) entry to the first (outermost) one.
    .map_data(|(new_window_expr_list, new_input, window_expr_list)| {
        if let Some(window_expr_list) = window_expr_list {
            // Expressions were rewritten: save the names of the original
            // expressions and restore them on the rewritten ones.
            let name_preserver = NamePreserver::new_for_projection();
            let saved_names = window_expr_list
                .iter()
                .map(|exprs| {
                    exprs
                        .iter()
                        .map(|expr| name_preserver.save(expr))
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>();
            new_window_expr_list.into_iter().zip(saved_names).try_rfold(
                new_input,
                |plan, (new_window_expr, saved_names)| {
                    let new_window_expr = new_window_expr
                        .into_iter()
                        .zip(saved_names)
                        .map(|(new_window_expr, saved_name)| {
                            saved_name.restore(new_window_expr)
                        })
                        .collect::<Vec<_>>();
                    Window::try_new(new_window_expr, Arc::new(plan))
                        .map(LogicalPlan::Window)
                },
            )
        } else {
            // Expressions are unchanged: rebuild each window node with the
            // schema it had before.
            new_window_expr_list
                .into_iter()
                .zip(window_schemas)
                .try_rfold(new_input, |plan, (new_window_expr, schema)| {
                    Window::try_new_with_schema(
                        new_window_expr,
                        Arc::new(plan),
                        schema,
                    )
                    .map(LogicalPlan::Window)
                })
        }
    })
}
233
234 fn try_optimize_aggregate(
235 &self,
236 aggregate: Aggregate,
237 config: &dyn OptimizerConfig,
238 ) -> Result<Transformed<LogicalPlan>> {
239 let Aggregate {
240 group_expr,
241 aggr_expr,
242 input,
243 schema,
244 ..
245 } = aggregate;
246 let input = Arc::unwrap_or_clone(input);
247 match CSE::new(ExprCSEController::new(
249 config.alias_generator().as_ref(),
250 ExprMask::Normal,
251 ))
252 .extract_common_nodes(vec![group_expr, aggr_expr])?
253 {
254 FoundCommonNodes::Yes {
258 common_nodes: common_exprs,
259 new_nodes_list: mut new_exprs_list,
260 original_nodes_list: mut original_exprs_list,
261 } => {
262 let new_aggr_expr = new_exprs_list.pop().unwrap();
263 let new_group_expr = new_exprs_list.pop().unwrap();
264
265 build_common_expr_project_plan(input, common_exprs).map(|new_input| {
266 let aggr_expr = original_exprs_list.pop().unwrap();
267 Transformed::yes((
268 new_aggr_expr,
269 new_group_expr,
270 new_input,
271 Some(aggr_expr),
272 ))
273 })
274 }
275
276 FoundCommonNodes::No {
277 original_nodes_list: mut original_exprs_list,
278 } => {
279 let new_aggr_expr = original_exprs_list.pop().unwrap();
280 let new_group_expr = original_exprs_list.pop().unwrap();
281
282 Ok(Transformed::no((
283 new_aggr_expr,
284 new_group_expr,
285 input,
286 None,
287 )))
288 }
289 }?
290 .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| {
293 self.rewrite(new_input, config)?.map_data(|new_input| {
294 Ok((
295 new_aggr_expr,
296 new_group_expr,
297 aggr_expr,
298 Arc::new(new_input),
299 ))
300 })
301 })?
302 .transform_data(
304 |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| {
305 match CSE::new(ExprCSEController::new(
307 config.alias_generator().as_ref(),
308 ExprMask::NormalAndAggregates,
309 ))
310 .extract_common_nodes(vec![new_aggr_expr])?
311 {
312 FoundCommonNodes::Yes {
313 common_nodes: common_exprs,
314 new_nodes_list: mut new_exprs_list,
315 original_nodes_list: mut original_exprs_list,
316 } => {
317 let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
318 let new_aggr_expr = original_exprs_list.pop().unwrap();
319 let saved_names = if let Some(aggr_expr) = aggr_expr {
320 let name_preserver = NamePreserver::new_for_projection();
321 aggr_expr
322 .iter()
323 .map(|expr| Some(name_preserver.save(expr)))
324 .collect::<Vec<_>>()
325 } else {
326 new_aggr_expr
327 .clone()
328 .into_iter()
329 .map(|_| None)
330 .collect::<Vec<_>>()
331 };
332
333 let mut agg_exprs = common_exprs
334 .into_iter()
335 .map(|(expr, expr_alias)| expr.alias(expr_alias))
336 .collect::<Vec<_>>();
337
338 let mut proj_exprs = vec![];
339 for expr in &new_group_expr {
340 extract_expressions(expr, &mut proj_exprs)
341 }
342 for ((expr_rewritten, expr_orig), saved_name) in
343 rewritten_aggr_expr
344 .into_iter()
345 .zip(new_aggr_expr)
346 .zip(saved_names)
347 {
348 if expr_rewritten == expr_orig {
349 let expr_rewritten = if let Some(saved_name) = saved_name
350 {
351 saved_name.restore(expr_rewritten)
352 } else {
353 expr_rewritten
354 };
355 if let Expr::Alias(Alias { expr, name, .. }) =
356 expr_rewritten
357 {
358 agg_exprs.push(expr.alias(&name));
359 proj_exprs
360 .push(Expr::Column(Column::from_name(name)));
361 } else {
362 let expr_alias =
363 config.alias_generator().next(CSE_PREFIX);
364 let (qualifier, field_name) =
365 expr_rewritten.qualified_name();
366 let out_name =
367 qualified_name(qualifier.as_ref(), &field_name);
368
369 agg_exprs.push(expr_rewritten.alias(&expr_alias));
370 proj_exprs.push(
371 Expr::Column(Column::from_name(expr_alias))
372 .alias(out_name),
373 );
374 }
375 } else {
376 proj_exprs.push(expr_rewritten);
377 }
378 }
379
380 let agg = LogicalPlan::Aggregate(Aggregate::try_new(
381 new_input,
382 new_group_expr,
383 agg_exprs,
384 )?);
385 Projection::try_new(proj_exprs, Arc::new(agg))
386 .map(|p| Transformed::yes(LogicalPlan::Projection(p)))
387 }
388
389 FoundCommonNodes::No {
392 original_nodes_list: mut original_exprs_list,
393 } => {
394 let rewritten_aggr_expr = original_exprs_list.pop().unwrap();
395
396 if let Some(aggr_expr) = aggr_expr {
407 let name_preserver = NamePreserver::new_for_projection();
408 let saved_names = aggr_expr
409 .iter()
410 .map(|expr| name_preserver.save(expr))
411 .collect::<Vec<_>>();
412 let new_aggr_expr = rewritten_aggr_expr
413 .into_iter()
414 .zip(saved_names)
415 .map(|(new_expr, saved_name)| {
416 saved_name.restore(new_expr)
417 })
418 .collect::<Vec<Expr>>();
419
420 Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
423 .map(LogicalPlan::Aggregate)
424 .map(Transformed::no)
425 } else {
426 Aggregate::try_new_with_schema(
427 new_input,
428 new_group_expr,
429 rewritten_aggr_expr,
430 schema,
431 )
432 .map(LogicalPlan::Aggregate)
433 .map(Transformed::no)
434 }
435 }
436 }
437 },
438 )
439 }
440
441 fn try_unary_plan(
456 &self,
457 exprs: Vec<Expr>,
458 input: LogicalPlan,
459 config: &dyn OptimizerConfig,
460 ) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
461 match CSE::new(ExprCSEController::new(
463 config.alias_generator().as_ref(),
464 ExprMask::Normal,
465 ))
466 .extract_common_nodes(vec![exprs])?
467 {
468 FoundCommonNodes::Yes {
469 common_nodes: common_exprs,
470 new_nodes_list: mut new_exprs_list,
471 original_nodes_list: _,
472 } => {
473 let new_exprs = new_exprs_list.pop().unwrap();
474 build_common_expr_project_plan(input, common_exprs)
475 .map(|new_input| Transformed::yes((new_exprs, new_input)))
476 }
477 FoundCommonNodes::No {
478 original_nodes_list: mut original_exprs_list,
479 } => {
480 let new_exprs = original_exprs_list.pop().unwrap();
481 Ok(Transformed::no((new_exprs, input)))
482 }
483 }?
484 .transform_data(|(new_exprs, new_input)| {
487 self.rewrite(new_input, config)?
488 .map_data(|new_input| Ok((new_exprs, new_input)))
489 })
490 }
491}
492
493fn get_consecutive_window_exprs(
525 window: Window,
526) -> (Vec<Vec<Expr>>, Vec<DFSchemaRef>, LogicalPlan) {
527 let mut window_expr_list = vec![];
528 let mut window_schemas = vec![];
529 let mut plan = LogicalPlan::Window(window);
530 while let LogicalPlan::Window(Window {
531 input,
532 window_expr,
533 schema,
534 }) = plan
535 {
536 window_expr_list.push(window_expr);
537 window_schemas.push(schema);
538
539 plan = Arc::unwrap_or_clone(input);
540 }
541 (window_expr_list, window_schemas, plan)
542}
543
544impl OptimizerRule for CommonSubexprEliminate {
545 fn supports_rewrite(&self) -> bool {
546 true
547 }
548
549 fn apply_order(&self) -> Option<ApplyOrder> {
550 None
554 }
555
556 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
557 fn rewrite(
558 &self,
559 plan: LogicalPlan,
560 config: &dyn OptimizerConfig,
561 ) -> Result<Transformed<LogicalPlan>> {
562 let original_schema = Arc::clone(plan.schema());
563
564 let optimized_plan = match plan {
565 LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?,
566 LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?,
567 LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?,
568 LogicalPlan::Window(window) => self.try_optimize_window(window, config)?,
569 LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?,
570 LogicalPlan::Join(_)
571 | LogicalPlan::Repartition(_)
572 | LogicalPlan::Union(_)
573 | LogicalPlan::TableScan(_)
574 | LogicalPlan::Values(_)
575 | LogicalPlan::EmptyRelation(_)
576 | LogicalPlan::Subquery(_)
577 | LogicalPlan::SubqueryAlias(_)
578 | LogicalPlan::Limit(_)
579 | LogicalPlan::Ddl(_)
580 | LogicalPlan::Explain(_)
581 | LogicalPlan::Analyze(_)
582 | LogicalPlan::Statement(_)
583 | LogicalPlan::DescribeTable(_)
584 | LogicalPlan::Distinct(_)
585 | LogicalPlan::Extension(_)
586 | LogicalPlan::Dml(_)
587 | LogicalPlan::Copy(_)
588 | LogicalPlan::Unnest(_)
589 | LogicalPlan::RecursiveQuery(_) => {
590 plan.map_children(|c| self.rewrite(c, config))?
593 }
594 };
595
596 if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema
598 {
599 optimized_plan.map_data(|optimized_plan| {
600 build_recover_project_plan(&original_schema, optimized_plan)
601 })
602 } else {
603 Ok(optimized_plan)
604 }
605 }
606
607 fn name(&self) -> &str {
608 "common_sub_expression_eliminate"
609 }
610}
611
/// Selects which expression nodes `ExprCSEController::is_ignored` reports as
/// ignored.
#[derive(Debug, Clone, Copy)]
enum ExprMask {
    /// Ignores literals, columns, scalar variables, aliases, wildcards and
    /// aggregate functions.
    Normal,

    /// Ignores the same nodes as `Normal`, except that aggregate functions
    /// are not ignored.
    NormalAndAggregates,
}
628
/// `CSEController` implementation for `Expr` trees.
struct ExprCSEController<'a> {
    // Source of the aliases returned by `generate_alias`.
    alias_generator: &'a AliasGenerator,
    // Decides which nodes `is_ignored` skips.
    mask: ExprMask,

    // Incremented when the rewrite descends into an `Expr::Alias` and
    // decremented when it leaves one. `rewrite` also increments it the first
    // time it adds an alias itself.
    alias_counter: usize,
    // Parameter names of the `Expr::Lambda` nodes entered by `visit_f_down`
    // and not yet left by `visit_f_up`.
    lambdas_params: HashSet<String>,
}
637
638impl<'a> ExprCSEController<'a> {
639 fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self {
640 Self {
641 alias_generator,
642 mask,
643 alias_counter: 0,
644 lambdas_params: HashSet::new(),
645 }
646 }
647}
648
649impl CSEController for ExprCSEController<'_> {
650 type Node = Expr;
651
652 fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
653 match node {
654 Expr::ScalarFunction(ScalarFunction { func, args }) => {
658 func.conditional_arguments(args)
659 }
660
661 Expr::BinaryExpr(BinaryExpr {
664 left,
665 op: Operator::And | Operator::Or,
666 right,
667 }) => Some((vec![left.as_ref()], vec![right.as_ref()])),
668
669 Expr::Case(Case {
673 expr,
674 when_then_expr,
675 else_expr,
676 }) => Some((
677 expr.iter()
678 .map(|e| e.as_ref())
679 .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref()))
680 .collect(),
681 when_then_expr
682 .iter()
683 .take(1)
684 .map(|(_, then)| then.as_ref())
685 .chain(
686 when_then_expr
687 .iter()
688 .skip(1)
689 .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]),
690 )
691 .chain(else_expr.iter().map(|e| e.as_ref()))
692 .collect(),
693 )),
694 _ => None,
695 }
696 }
697
698 fn visit_f_down(&mut self, node: &Expr) {
699 if let Expr::Lambda(lambda) = node {
700 self.lambdas_params
701 .extend(lambda.params.iter().cloned());
702 }
703 }
704
705 fn visit_f_up(&mut self, node: &Expr) {
706 if let Expr::Lambda(lambda) = node {
707 for param in &lambda.params {
708 self.lambdas_params.remove(param);
709 }
710 }
711 }
712
713 fn is_valid(node: &Expr) -> bool {
714 !node.is_volatile_node()
715 }
716
717 fn is_ignored(&self, node: &Expr) -> bool {
718 if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) {
719 return true
720 }
721
722 #[expect(deprecated)]
724 let is_normal_minus_aggregates = matches!(
725 node,
726 Expr::Literal(..)
727 | Expr::Column(..)
728 | Expr::ScalarVariable(..)
729 | Expr::Alias(..)
730 | Expr::Wildcard { .. }
731 );
732
733 let is_aggr = matches!(node, Expr::AggregateFunction(..));
734
735 match self.mask {
736 ExprMask::Normal => is_normal_minus_aggregates || is_aggr,
737 ExprMask::NormalAndAggregates => is_normal_minus_aggregates,
738 }
739 }
740
741 fn generate_alias(&self) -> String {
742 self.alias_generator.next(CSE_PREFIX)
743 }
744
745 fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
746 if self.alias_counter > 0 {
748 col(alias)
749 } else {
750 self.alias_counter += 1;
751 col(alias).alias(node.schema_name().to_string())
752 }
753 }
754
755 fn rewrite_f_down(&mut self, node: &Expr) {
756 if matches!(node, Expr::Alias(_)) {
757 self.alias_counter += 1;
758 }
759 }
760 fn rewrite_f_up(&mut self, node: &Expr) {
761 if matches!(node, Expr::Alias(_)) {
762 self.alias_counter -= 1
763 }
764 }
765}
766
767impl Default for CommonSubexprEliminate {
768 fn default() -> Self {
769 Self::new()
770 }
771}
772
773fn build_common_expr_project_plan(
784 input: LogicalPlan,
785 common_exprs: Vec<(Expr, String)>,
786) -> Result<LogicalPlan> {
787 let mut fields_set = BTreeSet::new();
788 let mut project_exprs = common_exprs
789 .into_iter()
790 .map(|(expr, expr_alias)| {
791 fields_set.insert(expr_alias.clone());
792 Ok(expr.alias(expr_alias))
793 })
794 .collect::<Result<Vec<_>>>()?;
795
796 for (qualifier, field) in input.schema().iter() {
797 if fields_set.insert(qualified_name(qualifier, field.name())) {
798 project_exprs.push(Expr::from((qualifier, field)));
799 }
800 }
801
802 Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection)
803}
804
805fn build_recover_project_plan(
811 schema: &DFSchema,
812 input: LogicalPlan,
813) -> Result<LogicalPlan> {
814 let col_exprs = schema.iter().map(Expr::from).collect();
815 Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection)
816}
817
818fn extract_expressions(expr: &Expr, result: &mut Vec<Expr>) {
819 if let Expr::GroupingSet(groupings) = expr {
820 for e in groupings.distinct_expr() {
821 let (qualifier, field_name) = e.qualified_name();
822 let col = Column::new(qualifier, field_name);
823 result.push(Expr::Column(col))
824 }
825 } else {
826 let (qualifier, field_name) = expr.qualified_name();
827 let col = Column::new(qualifier, field_name);
828 result.push(Expr::Column(col));
829 }
830}
831
832#[cfg(test)]
833mod test {
834 use std::any::Any;
835 use std::iter;
836
837 use arrow::datatypes::{DataType, Field, Schema};
838 use datafusion_expr::logical_plan::{table_scan, JoinType};
839 use datafusion_expr::{
840 grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF,
841 ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
842 SimpleAggregateUDF, Volatility,
843 };
844 use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
845
846 use super::*;
847 use crate::assert_optimized_plan_eq_snapshot;
848 use crate::optimizer::OptimizerContext;
849 use crate::test::*;
850 use datafusion_expr::test::function_stub::{avg, sum};
851
// Runs only the `CommonSubexprEliminate` rule over `$plan` and compares the
// result with the inline snapshot `$expected`, by delegating to
// `assert_optimized_plan_eq_snapshot!`.
macro_rules! assert_optimized_plan_equal {
    // Form with an explicit optimizer config.
    (
        $config:expr,
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        assert_optimized_plan_eq_snapshot!(
            $config,
            rules,
            $plan,
            @ $expected,
        )
    }};

    // Form that uses a fresh `OptimizerContext`.
    (
        $plan:expr,
        @ $expected:literal $(,)?
    ) => {{
        let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(CommonSubexprEliminate::new())];
        let optimizer_ctx = OptimizerContext::new();
        assert_optimized_plan_eq_snapshot!(
            optimizer_ctx,
            rules,
            $plan,
            @ $expected,
        )
    }};
}
881
/// Two `sum` aggregates share the argument sub-expression `a * (1 - b)`.
/// The expected plan computes it once in a projection below the aggregate.
#[test]
fn tpch_q1_simplified() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                sum(col("a") * (lit(1) - col("b"))),
                sum((col("a") * (lit(1) - col("b"))) * (lit(1) + col("c"))),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]
      Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
913
/// The common sub-expression `a + b` occurs inside an aliased expression,
/// next to it, and as a projection expression of its own.
#[test]
fn nested_aliases() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")),
            col("a") + col("b"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b
      Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
934
/// Exercises the rule on aggregates, using the built-in `avg` and a
/// user-defined aggregate `my_agg`, in five scenarios.
#[test]
fn aggregate() -> Result<()> {
    let table_scan = test_table_scan()?;

    // `my_agg` is only planned, never executed, so its accumulator factory
    // is left unimplemented.
    let return_type = DataType::UInt32;
    let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
    let udf_agg = |inner: Expr| {
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
                "my_agg",
                Signature::exact(vec![DataType::UInt32], Volatility::Stable),
                return_type.clone(),
                Arc::clone(&accumulator),
                vec![Field::new("value", DataType::UInt32, true).into()],
            ))),
            vec![inner],
            false,
            None,
            vec![],
            None,
        ))
    };

    // Scenario 1: the same aggregate call appears more than once.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(col("a")).alias("col1"),
                avg(col("a")).alias("col2"),
                avg(col("b")).alias("col3"),
                avg(col("c")),
                udf_agg(col("a")).alias("col4"),
                udf_agg(col("a")).alias("col5"),
                udf_agg(col("b")).alias("col6"),
                udf_agg(col("c")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)
      Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]
        TableScan: test
    "
    )?;

    // Scenario 2: the same aggregate call is nested in different expressions.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                lit(1) + avg(col("a")),
                lit(1) - avg(col("a")),
                lit(1) + udf_agg(col("a")),
                lit(1) - udf_agg(col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)
      Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]
        TableScan: test
    "
    )?;

    // Scenario 3: different aggregates share the same argument expression.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            iter::empty::<Expr>(),
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
      Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // Scenario 4: the shared argument expression is also the group expression.
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                avg(lit(1u32) + col("a")).alias("col1"),
                udf_agg(lit(1u32) + col("a")).alias("col2"),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]
      Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // Scenario 5: shared argument expressions and repeated aggregate calls
    // combined.
    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![lit(1u32) + col("a")],
            vec![
                (lit(1u32) + avg(lit(1u32) + col("a"))).alias("col1"),
                (lit(1u32) - avg(lit(1u32) + col("a"))).alias("col2"),
                avg(lit(1u32) + col("a")),
                (lit(1u32) + udf_agg(lit(1u32) + col("a"))).alias("col3"),
                (lit(1u32) - udf_agg(lit(1u32) + col("a"))).alias("col4"),
                udf_agg(lit(1u32) + col("a")),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)
      Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]
        Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1075
/// Same idea as the aggregate scenarios, but the table name (`table.test`)
/// and the column name (`col.a`) both contain dots.
#[test]
fn aggregate_with_relations_and_dots() -> Result<()> {
    let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]);
    let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?;

    let col_a = Expr::Column(Column::new(Some("table.test"), "col.a"));

    let plan = LogicalPlanBuilder::from(table_scan)
        .aggregate(
            vec![col_a.clone()],
            vec![
                (lit(1u32) + avg(lit(1u32) + col_a.clone())),
                avg(lit(1u32) + col_a),
            ],
        )?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)
      Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]
        Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a
          TableScan: table.test
    "
    )
}
1103
/// The same expression `1 + a` is projected twice under different aliases.
#[test]
fn subexpr_in_same_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (lit(1) + col("a")).alias("first"),
            (lit(1) + col("a")).alias("second"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS first, __common_expr_1 AS second
      Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1124
/// `1 + a` and `a + 1` differ only in operand order. The expected plan
/// replaces both with the same extracted expression.
#[test]
fn subexpr_in_different_order() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a") + lit(1)])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
      Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1142
/// The same expression appears in two different projection nodes, once in
/// each. The expected plan shows both projections without any extracted
/// expression.
#[test]
fn cross_plans_subexpr() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![lit(1) + col("a"), col("a")])?
        .project(vec![lit(1) + col("a")])?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: Int32(1) + test.a
      Projection: Int32(1) + test.a, test.a
        TableScan: test
    "
    )
}
1161
/// Stacks two projections built by `build_common_expr_project_plan` over a
/// table scan and checks that the final schema has no duplicate field names.
#[test]
fn redundant_project_fields() {
    // Alias in the form used for extracted common expressions.
    let alias = |n: usize| format!("{CSE_PREFIX}_{n}");
    let scan = test_table_scan().unwrap();

    let first_round = vec![
        (col("c") + col("a"), alias(1)),
        (col("b") + col("a"), alias(2)),
    ];
    // The second round re-projects the columns produced by the first one.
    let second_round = vec![(col(alias(1)), alias(3)), (col(alias(2)), alias(4))];

    let project = build_common_expr_project_plan(scan, first_round).unwrap();
    let project_2 = build_common_expr_project_plan(project, second_round).unwrap();

    let mut field_set = BTreeSet::new();
    for name in project_2.schema().field_names() {
        assert!(field_set.insert(name));
    }
}
1185
/// Same check as `redundant_project_fields`, but the input is an inner join
/// of two tables that have the same column names.
#[test]
fn redundant_project_fields_join_input() {
    // Alias in the form used for extracted common expressions.
    let alias = |n: usize| format!("{CSE_PREFIX}_{n}");

    let left = test_table_scan_with_name("test1").unwrap();
    let right = test_table_scan_with_name("test2").unwrap();
    let join = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)
        .unwrap()
        .build()
        .unwrap();

    let first_round = vec![
        (col("test1.c") + col("test1.a"), alias(1)),
        (col("test1.b") + col("test1.a"), alias(2)),
    ];
    // The second round re-projects the columns produced by the first one.
    let second_round = vec![(col(alias(1)), alias(3)), (col(alias(2)), alias(4))];

    let project = build_common_expr_project_plan(join, first_round).unwrap();
    let project_2 = build_common_expr_project_plan(project, second_round).unwrap();

    let mut field_set = BTreeSet::new();
    for name in project_2.schema().field_names() {
        assert!(field_set.insert(name));
    }
}
1215
/// Rewrites a filter whose predicate uses `CAST(a AS Int64)` twice and
/// checks that the plan is reported as transformed while its output schema
/// still lists `a`, `b` and `c` as `UInt64`.
#[test]
fn eliminated_subexpr_datatype() {
    use datafusion_expr::cast;

    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt64, false),
        Field::new("b", DataType::UInt64, false),
        Field::new("c", DataType::UInt64, false),
    ]);

    let plan = table_scan(Some("table"), &schema, None)
        .unwrap()
        .filter(
            cast(col("a"), DataType::Int64)
                .lt(lit(1_i64))
                .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
        )
        .unwrap()
        .build()
        .unwrap();
    let rule = CommonSubexprEliminate::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
    assert!(optimized_plan.transformed);
    let optimized_plan = optimized_plan.data;

    let schema = optimized_plan.schema();
    let fields_with_datatypes: Vec<_> = schema
        .fields()
        .iter()
        .map(|field| (field.name(), field.data_type()))
        .collect();
    // Compared against the pretty-printed (`{:#?}`) form of the list.
    let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
    let expected = r#"[
    (
        "a",
        UInt64,
    ),
    (
        "b",
        UInt64,
    ),
    (
        "c",
        UInt64,
    ),
]"#;
    assert_eq!(expected, formatted_fields_with_datatype);
}
1264
/// The filter predicate uses `1 + a` twice. The expected plan has a
/// projection below the filter that adds `__common_expr_1`, and a projection
/// above it that lists only the original columns.
#[test]
fn filter_schema_changed() -> Result<()> {
    let table_scan = test_table_scan()?;

    let plan = LogicalPlanBuilder::from(table_scan)
        .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 - Int32(10) > __common_expr_1
        Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1283
/// A grouping set over `(a, b)` and `(c)` yields three extracted columns.
#[test]
fn test_extract_expressions_from_grouping_set() -> Result<()> {
    let mut result = Vec::with_capacity(3);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]);
    extract_expressions(&grouping, &mut result);

    // `assert_eq!` reports the actual length when the check fails.
    assert_eq!(result.len(), 3);
    Ok(())
}
1293
/// A grouping set over `(a, b)` and `(a)` yields two extracted columns,
/// because `a` is only counted once.
#[test]
fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> {
    let mut result = Vec::with_capacity(2);
    let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]);
    extract_expressions(&grouping, &mut result);
    // `assert_eq!` reports the actual length when the check fails.
    assert_eq!(result.len(), 2);
    Ok(())
}
1302
/// The input plan already contains a column named like a generated alias,
/// obtained from the same alias generator that the rule later uses. The
/// expected plans show the rule's new alias using the next number, so it
/// does not collide with the existing column.
#[test]
fn test_alias_collision() -> Result<()> {
    let table_scan = test_table_scan()?;

    // Case 1: the plan already uses the first generated alias.
    let config = OptimizerContext::new();
    let common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan.clone())
        .project(vec![
            (col("a") + col("b")).alias(common_expr_1.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_1.clone()).alias("c1"),
            col(common_expr_1).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4
      Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c
        Projection: test.a + test.b AS __common_expr_1, test.c
          TableScan: test
    "
    )?;

    // Case 2: the first generated alias is discarded and the plan uses the
    // second one.
    let config = OptimizerContext::new();
    let _common_expr_1 = config.alias_generator().next(CSE_PREFIX);
    let common_expr_2 = config.alias_generator().next(CSE_PREFIX);
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![
            (col("a") + col("b")).alias(common_expr_2.clone()),
            col("c"),
        ])?
        .project(vec![
            col(common_expr_2.clone()).alias("c1"),
            col(common_expr_2).alias("c2"),
            (col("c") + lit(2)).alias("c3"),
            (col("c") + lit(2)).alias("c4"),
        ])?
        .build()?;

    assert_optimized_plan_equal!(
        config,
        plan,
        @ r"
    Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4
      Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c
        Projection: test.a + test.b AS __common_expr_2, test.c
          TableScan: test
    "
    )?;

    Ok(())
}
1362
/// `extract_expressions` applied to a plain column reference is expected to
/// collect exactly one expression.
#[test]
fn test_extract_expressions_from_col() -> Result<()> {
    let mut result = Vec::with_capacity(1);
    extract_expressions(&col("a"), &mut result);

    // `assert_eq!` prints the actual length when the check fails, which a bare
    // `assert!(len == 1)` does not.
    assert_eq!(result.len(), 1);
    Ok(())
}
1370
/// Covers common expressions that sit inside short-circuiting `OR` / `AND`.
///
/// In the expected plan, a whole repeated `OR` expression and the left-hand
/// operands are pulled into the lower projection, while the right-hand operand
/// `a - b = 0` of `c3` / `c4` stays inline.
#[test]
fn test_short_circuits() -> Result<()> {
    let scan = test_table_scan()?;

    let either_is_zero = col("a").eq(lit(0)).or(col("b").eq(lit(0)));
    let sum_is_zero = (col("a") + col("b")).eq(lit(0));
    let diff_is_zero = (col("a") - col("b")).eq(lit(0));
    let product_is_zero = (col("a") * col("b")).eq(lit(0));

    let projected = vec![
        either_is_zero.clone().alias("c1"),
        either_is_zero.alias("c2"),
        sum_is_zero.clone().or(diff_is_zero.clone()).alias("c3"),
        sum_is_zero.and(diff_is_zero).alias("c4"),
        product_is_zero.clone().or(product_is_zero).alias("c5"),
    ];
    let plan = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5
      Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1406
/// An expression containing a volatile call (`random()`) is not extracted as a
/// whole, but its deterministic child `a + b` is, per the expected plan.
#[test]
fn test_volatile() -> Result<()> {
    let scan = test_table_scan()?;

    let deterministic_part = col("a") + col("b");
    let volatile_sum = deterministic_part + rand_func().call(vec![]);

    let projected = vec![
        volatile_sum.clone().alias("c1"),
        volatile_sum.alias("c2"),
    ];
    let plan = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2
      Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1430
/// Combines volatile calls with short-circuiting `OR`.
///
/// In the expected plan only `a = 0` (the left operand of the first `OR`) is
/// extracted; `random() = 0` and the `b = 0` operand that follows it stay inline.
#[test]
fn test_volatile_short_circuits() -> Result<()> {
    let scan = test_table_scan()?;

    let random_is_zero = rand_func().call(vec![]).eq(lit(0));
    let a_or_random = col("a").eq(lit(0)).or(random_is_zero.clone());
    let random_or_b = random_is_zero.or(col("b").eq(lit(0)));

    let projected = vec![
        a_or_random.clone().alias("c1"),
        a_or_random.alias("c2"),
        random_or_b.clone().alias("c3"),
        random_or_b.alias("c4"),
    ];
    let plan = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4
      Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1460
/// The repeated expression lives in a projection that is not the root of the
/// plan; the expected plan still shows it extracted beneath that projection.
#[test]
fn test_non_top_level_common_expression() -> Result<()> {
    let scan = test_table_scan()?;

    let sum = col("a") + col("b");
    let inner_exprs = vec![sum.clone().alias("c1"), sum.alias("c2")];
    let outer_exprs = vec![col("c1"), col("c2")];

    let plan = LogicalPlanBuilder::from(scan)
        .project(inner_exprs)?
        .project(outer_exprs)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: c1, c2
      Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
        Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1484
/// A common expression that itself contains a repeated sub-expression: the
/// expected plan stacks two extraction projections, one per nesting level.
#[test]
fn test_nested_common_expression() -> Result<()> {
    let scan = test_table_scan()?;

    let inner_sum = col("a") + col("b");
    let squared = inner_sum.clone() * inner_sum;

    let projected = vec![squared.clone().alias("c1"), squared.alias("c2")];
    let plan = LogicalPlanBuilder::from(scan).project(projected)?.build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS c1, __common_expr_1 AS c2
      Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c
        Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1508
/// `a + b` and `b + a` are expected to be treated as one common expression.
#[test]
fn test_normalize_add_expression() -> Result<()> {
    let lhs = col("a") + col("b");
    let rhs = col("b") + col("a");
    let predicate = (lhs * rhs).eq(lit(30));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 * __common_expr_1 = Int32(30)
        Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1526
/// `a * b` and `b * a` are expected to be treated as one common expression.
#[test]
fn test_normalize_multi_expression() -> Result<()> {
    let lhs = col("a") * col("b");
    let rhs = col("b") * col("a");
    let predicate = (lhs + rhs).eq(lit(30));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1544
/// `a & b` and `b & a` are expected to be treated as one common expression.
#[test]
fn test_normalize_bitset_and_expression() -> Result<()> {
    let lhs = col("a") & col("b");
    let rhs = col("b") & col("a");
    let predicate = (lhs + rhs).eq(lit(30));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1562
/// `a | b` and `b | a` are expected to be treated as one common expression.
#[test]
fn test_normalize_bitset_or_expression() -> Result<()> {
    let lhs = col("a") | col("b");
    let rhs = col("b") | col("a");
    let predicate = (lhs + rhs).eq(lit(30));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1580
/// `a ^ b` and `b ^ a` are expected to be treated as one common expression
/// (displayed as `BIT_XOR` in the plan).
#[test]
fn test_normalize_bitset_xor_expression() -> Result<()> {
    let lhs = col("a") ^ col("b");
    let rhs = col("b") ^ col("a");
    let predicate = (lhs + rhs).eq(lit(30));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 + __common_expr_1 = Int32(30)
        Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1598
/// `a = b` and `b = a` are expected to be treated as one common expression.
#[test]
fn test_normalize_eq_expression() -> Result<()> {
    let lhs = col("a").eq(col("b"));
    let rhs = col("b").eq(col("a"));
    let predicate = lhs.and(rhs);

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 AND __common_expr_1
        Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1616
/// `a != b` and `b != a` are expected to be treated as one common expression.
#[test]
fn test_normalize_ne_expression() -> Result<()> {
    let lhs = col("a").not_eq(col("b"));
    let rhs = col("b").not_eq(col("a"));
    let predicate = lhs.and(rhs);

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 AND __common_expr_1
        Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )
}
1634
/// Operand-order normalization on nested arithmetic: in each of the three
/// cases the two operands differ only in the order of `+` / `*` arguments and
/// are expected to collapse into a single common expression.
#[test]
fn test_normalize_complex_expression() -> Result<()> {
    // Case 1: `a + b * c` versus `b * c + a`.
    let lhs = col("a") + col("b") * col("c");
    let rhs = col("b") * col("c") + col("a");
    let predicate = (lhs - rhs).eq(lit(30));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 - __common_expr_1 = Int32(30)
        Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    // Case 2: `(a + b / c) * c` versus `c * (b / c + a)`, followed by `+ a`.
    let numerator = (col("a") + col("b") / col("c")) * col("c");
    let denominator = col("c") * (col("b") / col("c") + col("a"));
    let predicate = (numerator / denominator + col("a")).eq(lit(30));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)
        Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    // Case 3: `b / (a + c)` versus `b / (c + a)`.
    let lhs = col("b") / (col("a") + col("c"));
    let rhs = col("b") / (col("c") + col("a"));
    let predicate = (lhs * rhs).eq(lit(30));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .filter(predicate)?
        .build()?;

    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: test.a, test.b, test.c
      Filter: __common_expr_1 * __common_expr_1 = Int32(30)
        Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c
          TableScan: test
    "
    )?;

    Ok(())
}
1689
1690 #[derive(Debug, PartialEq, Eq, Hash)]
1691 pub struct TestUdf {
1692 signature: Signature,
1693 }
1694
1695 impl TestUdf {
1696 pub fn new() -> Self {
1697 Self {
1698 signature: Signature::numeric(1, Volatility::Immutable),
1699 }
1700 }
1701 }
1702
1703 impl ScalarUDFImpl for TestUdf {
1704 fn as_any(&self) -> &dyn Any {
1705 self
1706 }
1707 fn name(&self) -> &str {
1708 "my_udf"
1709 }
1710
1711 fn signature(&self) -> &Signature {
1712 &self.signature
1713 }
1714
1715 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
1716 Ok(DataType::Int32)
1717 }
1718
1719 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1720 panic!("not implemented")
1721 }
1722 }
1723
/// Operand-order normalization when the commutative binary expression is
/// wrapped by another expression: `NOT`, `IS NULL`, `BETWEEN` (as the tested
/// value and as a bound) and a scalar UDF call. In every case the two projected
/// expressions are expected to share one `__common_expr_1`, while each keeps its
/// own original display name as the output alias.
#[test]
fn test_normalize_inner_binary_expression() -> Result<()> {
    // NOT (a = b)  versus  NOT (b = a)
    let first = not(col("a").eq(col("b")));
    let second = not(col("b").eq(col("a")));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![first, second])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a
      Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // (a = b) IS NULL  versus  (b = a) IS NULL
    let first = is_null(col("a").eq(col("b")));
    let second = is_null(col("b").eq(col("a")));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![first, second])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL
      Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // (a + b) BETWEEN 0 AND 10  versus  (b + a) BETWEEN 0 AND 10
    let first = (col("a") + col("b")).between(lit(0), lit(10));
    let second = (col("b") + col("a")).between(lit(0), lit(10));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![first, second])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)
      Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // c BETWEEN (a + b) AND 10  versus  c BETWEEN (b + a) AND 10
    let first = col("c").between(col("a") + col("b"), lit(10));
    let second = col("c").between(col("b") + col("a"), lit(10));
    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![first, second])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)
      Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )?;

    // my_udf(a + b)  versus  my_udf(b + a)
    let my_udf = ScalarUDF::from(TestUdf::new());
    let scan = test_table_scan()?;
    let first = my_udf.call(vec![col("a") + col("b")]);
    let second = my_udf.call(vec![col("b") + col("a")]);
    let plan = LogicalPlanBuilder::from(scan)
        .project(vec![first, second])?
        .build()?;
    assert_optimized_plan_equal!(
        plan,
        @ r"
    Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)
      Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c
        TableScan: test
    "
    )
}
1807
1808 fn rand_func() -> ScalarUDF {
1814 ScalarUDF::new_from_impl(RandomStub::new())
1815 }
1816
1817 #[derive(Debug, PartialEq, Eq, Hash)]
1818 struct RandomStub {
1819 signature: Signature,
1820 }
1821
1822 impl RandomStub {
1823 fn new() -> Self {
1824 Self {
1825 signature: Signature::exact(vec![], Volatility::Volatile),
1826 }
1827 }
1828 }
1829 impl ScalarUDFImpl for RandomStub {
1830 fn as_any(&self) -> &dyn Any {
1831 self
1832 }
1833
1834 fn name(&self) -> &str {
1835 "random"
1836 }
1837
1838 fn signature(&self) -> &Signature {
1839 &self.signature
1840 }
1841
1842 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
1843 Ok(DataType::Float64)
1844 }
1845
1846 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1847 panic!("dummy - not implemented")
1848 }
1849 }
1850}