1use std::collections::{HashMap, HashSet};
21use std::sync::Arc;
22
23use arrow::datatypes::DataType;
24use indexmap::IndexSet;
25use itertools::Itertools;
26
27use datafusion_common::tree_node::{
28 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
29};
30use datafusion_common::{
31 internal_err, plan_err, qualified_name, Column, DFSchema, Result,
32};
33use datafusion_expr::expr::WindowFunction;
34use datafusion_expr::expr_rewriter::replace_col;
35use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union};
36use datafusion_expr::utils::{
37 conjunction, expr_to_columns, split_conjunction, split_conjunction_owned,
38};
39use datafusion_expr::{
40 and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
41};
42
43use crate::optimizer::ApplyOrder;
44use crate::simplify_expressions::simplify_predicates;
45use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
46use crate::{OptimizerConfig, OptimizerRule};
47
/// Optimizer rule that moves `Filter` predicates closer to the data sources.
///
/// For each `Filter` node it tries to move all or part of the predicate below
/// the filter's input. It handles these input types:
///
/// - another `Filter`: the two are merged;
/// - `Repartition`, `Distinct`, `Sort`;
/// - `SubqueryAlias`, `Projection`, `Union`: columns are rewritten;
/// - `Unnest`: only conjuncts not touching unnested columns are moved;
/// - `Aggregate`: only conjuncts over group-by columns are moved;
/// - `Window`: only conjuncts over common partition keys are moved;
/// - `Join`, and user `Extension` nodes;
/// - `TableScan`: conjuncts go into the scan's filters when the provider supports them.
///
/// A bare `Join` (without a parent `Filter`) also has its ON-clause conjuncts
/// pushed into its inputs where possible.
#[derive(Default, Debug)]
pub struct PushDownFilter {}
137
138pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
163 match join_type {
164 JoinType::Inner => (true, true),
165 JoinType::Left => (true, false),
166 JoinType::Right => (false, true),
167 JoinType::Full => (false, false),
168 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
171 JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
174 }
175}
176
177pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
187 match join_type {
188 JoinType::Inner => (true, true),
189 JoinType::Left => (false, true),
190 JoinType::Right => (true, false),
191 JoinType::Full => (false, false),
192 JoinType::LeftSemi | JoinType::RightSemi => (true, true),
193 JoinType::LeftAnti => (false, true),
194 JoinType::RightAnti => (true, false),
195 JoinType::LeftMark => (false, true),
196 JoinType::RightMark => (true, false),
197 }
198}
199
200#[derive(Debug)]
203struct ColumnChecker<'a> {
204 left_schema: &'a DFSchema,
206 left_columns: Option<HashSet<Column>>,
208 right_schema: &'a DFSchema,
210 right_columns: Option<HashSet<Column>>,
212}
213
214impl<'a> ColumnChecker<'a> {
215 fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self {
216 Self {
217 left_schema,
218 left_columns: None,
219 right_schema,
220 right_columns: None,
221 }
222 }
223
224 fn is_left_only(&mut self, predicate: &Expr) -> bool {
226 if self.left_columns.is_none() {
227 self.left_columns = Some(schema_columns(self.left_schema));
228 }
229 has_all_column_refs(predicate, self.left_columns.as_ref().unwrap())
230 }
231
232 fn is_right_only(&mut self, predicate: &Expr) -> bool {
234 if self.right_columns.is_none() {
235 self.right_columns = Some(schema_columns(self.right_schema));
236 }
237 has_all_column_refs(predicate, self.right_columns.as_ref().unwrap())
238 }
239}
240
241fn schema_columns(schema: &DFSchema) -> HashSet<Column> {
243 schema
244 .iter()
245 .flat_map(|(qualifier, field)| {
246 [
247 Column::new(qualifier.cloned(), field.name()),
248 Column::new_unqualified(field.name()),
250 ]
251 })
252 .collect::<HashSet<_>>()
253}
254
/// Returns whether `predicate` may be moved into a join's filter (used for
/// inner joins in `push_down_all_join`).
///
/// Returns `Ok(false)` as soon as the expression tree contains any of the
/// following; returns `Ok(true)` otherwise:
/// - `EXISTS`, `IN (subquery)` or a scalar subquery;
/// - an outer reference column;
/// - an `Unnest`.
///
/// # Errors
/// Returns an internal error ("Unsupported predicate type") if the tree
/// contains an aggregate function, window function, wildcard, grouping set
/// or lambda.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf-like expressions: nothing underneath needs inspecting.
        Expr::Column(_)
        | Expr::Literal(_, _)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Not evaluable as a join condition: record the answer and stop walking.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Composite expressions are acceptable if their children are.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. }
        | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue),
        // `Expr::Wildcard` is deprecated, hence the `expect` on this arm.
        #[expect(deprecated)]
        Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_)
        | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
301
302fn extract_or_clauses_for_join<'a>(
336 filters: &'a [Expr],
337 schema: &'a DFSchema,
338) -> impl Iterator<Item = Expr> + 'a {
339 let schema_columns = schema_columns(schema);
340
341 filters.iter().filter_map(move |expr| {
343 if let Expr::BinaryExpr(BinaryExpr {
344 left,
345 op: Operator::Or,
346 right,
347 }) = expr
348 {
349 let left_expr = extract_or_clause(left.as_ref(), &schema_columns);
350 let right_expr = extract_or_clause(right.as_ref(), &schema_columns);
351
352 if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
354 return Some(or(left_expr, right_expr));
355 }
356 }
357 None
358 })
359}
360
361fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
373 let mut predicate = None;
374
375 match expr {
376 Expr::BinaryExpr(BinaryExpr {
377 left: l_expr,
378 op: Operator::Or,
379 right: r_expr,
380 }) => {
381 let l_expr = extract_or_clause(l_expr, schema_columns);
382 let r_expr = extract_or_clause(r_expr, schema_columns);
383
384 if let (Some(l_expr), Some(r_expr)) = (l_expr, r_expr) {
385 predicate = Some(or(l_expr, r_expr));
386 }
387 }
388 Expr::BinaryExpr(BinaryExpr {
389 left: l_expr,
390 op: Operator::And,
391 right: r_expr,
392 }) => {
393 let l_expr = extract_or_clause(l_expr, schema_columns);
394 let r_expr = extract_or_clause(r_expr, schema_columns);
395
396 match (l_expr, r_expr) {
397 (Some(l_expr), Some(r_expr)) => {
398 predicate = Some(and(l_expr, r_expr));
399 }
400 (Some(l_expr), None) => {
401 predicate = Some(l_expr);
402 }
403 (None, Some(r_expr)) => {
404 predicate = Some(r_expr);
405 }
406 (None, None) => {
407 predicate = None;
408 }
409 }
410 }
411 _ => {
412 if has_all_column_refs(expr, schema_columns) {
413 predicate = Some(expr.clone());
414 }
415 }
416 }
417
418 predicate
419}
420
/// Pushes predicates through `join` into its inputs where that is safe for
/// the join type, then rebuilds the join (and a `Filter` above it if needed).
///
/// * `predicates`: conjuncts of a `Filter` above the join. Each one is handled
///   by the first rule that applies:
///   1. it is pushed into the left or right input if it references only that
///      side and that side is preserved (`lr_is_preserved`);
///   2. for inner joins, it is added to the join filter if
///      `can_evaluate_as_join_condition` allows it;
///   3. otherwise it stays in a `Filter` above the join.
/// * `inferred_join_predicates`: predicates derived through equi-join keys.
///   They are only ever pushed into an input and are silently dropped if
///   they cannot be.
/// * `on_filter`: conjuncts of the join's ON filter. Each is pushed into an
///   input according to `on_lr_is_preserved`; otherwise it remains in the
///   join filter.
///
/// Additionally, `extract_or_clauses_for_join` derives single-side
/// predicates from `OR` conjuncts that stay at or above the join. These are
/// pushed to each eligible side as extra filters; the originals are kept.
///
/// Always returns `Transformed::yes`.
///
/// # Errors
/// Propagates errors from `can_evaluate_as_join_condition` and
/// `Filter::try_new`.
fn push_down_all_join(
    predicates: Vec<Expr>,
    inferred_join_predicates: Vec<Expr>,
    mut join: Join,
    on_filter: Vec<Expr>,
) -> Result<Transformed<LogicalPlan>> {
    let is_inner_join = join.join_type == JoinType::Inner;
    // Which inputs may receive predicates coming from *above* the join.
    let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

    let left_schema = join.left.schema();
    let right_schema = join.right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    let mut checker = ColumnChecker::new(left_schema, right_schema);
    for predicate in predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Equality vs. non-equality is not distinguished here; the
            // predicate simply joins the join filter.
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // Inferred predicates are an optimization only: drop them if they cannot
    // be pushed into an input.
    for predicate in inferred_join_predicates {
        if left_preserved && checker.is_left_only(&predicate) {
            left_push.push(predicate);
        } else if right_preserved && checker.is_right_only(&predicate) {
            right_push.push(predicate);
        }
    }

    let mut on_filter_join_conditions = vec![];
    // ON-clause conjuncts follow different preservation rules than predicates
    // from above the join.
    let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);

    if !on_filter.is_empty() {
        for on in on_filter {
            if on_left_preserved && checker.is_left_only(&on) {
                left_push.push(on)
            } else if on_right_preserved && checker.is_right_only(&on) {
                right_push.push(on)
            } else {
                on_filter_join_conditions.push(on)
            }
        }
    }

    // Derive single-side predicates from OR conjuncts that could not be
    // pushed as a whole (only the not-yet-pushed ones are examined).
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // ON-clause leftovers use the ON-clause preservation flags.
    if on_left_preserved {
        left_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            left_schema,
        ));
    }
    if on_right_preserved {
        right_push.extend(extract_or_clauses_for_join(
            &on_filter_join_conditions,
            right_schema,
        ));
    }

    if let Some(predicate) = conjunction(left_push) {
        join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
    }
    if let Some(predicate) = conjunction(right_push) {
        join.right =
            Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
    }

    // The new join filter is the moved-in predicates plus the ON conjuncts
    // that were not pushed; it is `None` if both lists are empty.
    join_conditions.extend(on_filter_join_conditions);
    join.filter = conjunction(join_conditions);

    // Wrap the join in a Filter holding the predicates that must stay above it.
    let plan = LogicalPlan::Join(join);
    let plan = if let Some(predicate) = conjunction(keep_predicates) {
        LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
    } else {
        plan
    };
    Ok(Transformed::yes(plan))
}
528
529fn push_down_join(
530 join: Join,
531 parent_predicate: Option<&Expr>,
532) -> Result<Transformed<LogicalPlan>> {
533 let predicates = parent_predicate
535 .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));
536
537 let on_filters = join
539 .filter
540 .as_ref()
541 .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));
542
543 let inferred_join_predicates =
545 infer_join_predicates(&join, &predicates, &on_filters)?;
546
547 if on_filters.is_empty()
548 && predicates.is_empty()
549 && inferred_join_predicates.is_empty()
550 {
551 return Ok(Transformed::no(LogicalPlan::Join(join)));
552 }
553
554 push_down_all_join(predicates, inferred_join_predicates, join, on_filters)
555}
556
557fn infer_join_predicates(
567 join: &Join,
568 predicates: &[Expr],
569 on_filters: &[Expr],
570) -> Result<Vec<Expr>> {
571 let join_col_keys = join
573 .on
574 .iter()
575 .filter_map(|(l, r)| {
576 let left_col = l.try_as_col()?;
577 let right_col = r.try_as_col()?;
578 Some((left_col, right_col))
579 })
580 .collect::<Vec<_>>();
581
582 let join_type = join.join_type;
583
584 let mut inferred_predicates = InferredPredicates::new(join_type);
585
586 infer_join_predicates_from_predicates(
587 &join_col_keys,
588 predicates,
589 &mut inferred_predicates,
590 )?;
591
592 infer_join_predicates_from_on_filters(
593 &join_col_keys,
594 join_type,
595 on_filters,
596 &mut inferred_predicates,
597 )?;
598
599 Ok(inferred_predicates.predicates)
600}
601
602struct InferredPredicates {
611 predicates: Vec<Expr>,
612 is_inner_join: bool,
613}
614
615impl InferredPredicates {
616 fn new(join_type: JoinType) -> Self {
617 Self {
618 predicates: vec![],
619 is_inner_join: matches!(join_type, JoinType::Inner),
620 }
621 }
622
623 fn try_build_predicate(
624 &mut self,
625 predicate: Expr,
626 replace_map: &HashMap<&Column, &Column>,
627 ) -> Result<()> {
628 if self.is_inner_join
629 || matches!(
630 is_restrict_null_predicate(
631 predicate.clone(),
632 replace_map.keys().cloned()
633 ),
634 Ok(true)
635 )
636 {
637 self.predicates.push(replace_col(predicate, replace_map)?);
638 }
639
640 Ok(())
641 }
642}
643
644fn infer_join_predicates_from_predicates(
653 join_col_keys: &[(&Column, &Column)],
654 predicates: &[Expr],
655 inferred_predicates: &mut InferredPredicates,
656) -> Result<()> {
657 infer_join_predicates_impl::<true, true>(
658 join_col_keys,
659 predicates,
660 inferred_predicates,
661 )
662}
663
664fn infer_join_predicates_from_on_filters(
676 join_col_keys: &[(&Column, &Column)],
677 join_type: JoinType,
678 on_filters: &[Expr],
679 inferred_predicates: &mut InferredPredicates,
680) -> Result<()> {
681 match join_type {
682 JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()),
683 JoinType::Inner => infer_join_predicates_impl::<true, true>(
684 join_col_keys,
685 on_filters,
686 inferred_predicates,
687 ),
688 JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
689 infer_join_predicates_impl::<true, false>(
690 join_col_keys,
691 on_filters,
692 inferred_predicates,
693 )
694 }
695 JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
696 infer_join_predicates_impl::<false, true>(
697 join_col_keys,
698 on_filters,
699 inferred_predicates,
700 )
701 }
702 }
703}
704
705fn infer_join_predicates_impl<
721 const ENABLE_LEFT_TO_RIGHT: bool,
722 const ENABLE_RIGHT_TO_LEFT: bool,
723>(
724 join_col_keys: &[(&Column, &Column)],
725 input_predicates: &[Expr],
726 inferred_predicates: &mut InferredPredicates,
727) -> Result<()> {
728 for predicate in input_predicates {
729 let mut join_cols_to_replace = HashMap::new();
730
731 for &col in &predicate.column_refs() {
732 for (l, r) in join_col_keys.iter() {
733 if ENABLE_LEFT_TO_RIGHT && col == *l {
734 join_cols_to_replace.insert(col, *r);
735 break;
736 }
737 if ENABLE_RIGHT_TO_LEFT && col == *r {
738 join_cols_to_replace.insert(col, *l);
739 break;
740 }
741 }
742 }
743 if join_cols_to_replace.is_empty() {
744 continue;
745 }
746
747 inferred_predicates
748 .try_build_predicate(predicate.clone(), &join_cols_to_replace)?;
749 }
750 Ok(())
751}
752
impl OptimizerRule for PushDownFilter {
    fn name(&self) -> &str {
        "push_down_filter"
    }

    // Top-down: a node is processed before the nodes beneath it.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Pushes the predicate of a `Filter` node below its input where that is
    /// safe, and pushes a bare `Join`'s ON-clause conjuncts into its inputs.
    ///
    /// Returns `Transformed::yes` when the plan was restructured. Returns
    /// `Transformed::no` when the node is not a `Filter`/`Join` or the filter
    /// stays where it is.
    ///
    /// # Errors
    /// Propagates errors from predicate simplification, expression rewriting,
    /// `Filter::try_new`, the table provider's `supports_filters_pushdown`,
    /// and extension-node reconstruction. Returns an internal error if the
    /// provider reports a result count different from the number of filters
    /// passed.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A Join without a parent Filter can still push its own ON-clause
        // conjuncts (and predicates inferred from them) into its inputs.
        if let LogicalPlan::Join(join) = plan {
            return push_down_join(join, None);
        };

        let plan_schema = Arc::clone(plan.schema());

        let LogicalPlan::Filter(mut filter) = plan else {
            return Ok(Transformed::no(plan));
        };

        // Simplify the conjuncts first. Only a change in the *number* of
        // conjuncts is detected; a same-length result is discarded and the
        // original predicate kept.
        let predicate = split_conjunction_owned(filter.predicate.clone());
        let old_predicate_len = predicate.len();
        let new_predicates = simplify_predicates(predicate)?;
        if old_predicate_len != new_predicates.len() {
            let Some(new_predicate) = conjunction(new_predicates) else {
                // Every conjunct was simplified away: drop the Filter and
                // return its input.
                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
            };
            filter.predicate = new_predicate;
        }

        match Arc::unwrap_or_clone(filter.input) {
            // Filter over Filter: merge both into one Filter over the child's
            // input, then run this rule again on the merged node.
            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                let child_predicates = split_conjunction_owned(child_filter.predicate);
                // IndexSet removes duplicate conjuncts while keeping first-seen
                // order (parent conjuncts first, then child conjuncts).
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                let Some(new_predicate) = conjunction(new_predicates) else {
                    return plan_err!("at least one expression exists");
                };
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input,
                )?);
                #[allow(clippy::used_underscore_binding)]
                self.rewrite(new_filter, _config)
            }
            // Repartition / Distinct / Sort: the whole predicate is moved
            // beneath the node unchanged.
            LogicalPlan::Repartition(repartition) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Repartition(repartition), new_filter)
            }
            LogicalPlan::Distinct(distinct) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Distinct(distinct), new_filter)
            }
            LogicalPlan::Sort(sort) => {
                let new_filter =
                    Filter::try_new(filter.predicate, Arc::clone(&sort.input))
                        .map(LogicalPlan::Filter)?;
                insert_below(LogicalPlan::Sort(sort), new_filter)
            }
            LogicalPlan::SubqueryAlias(subquery_alias) => {
                // Map each aliased output column (matched by position) back
                // to the corresponding input column.
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    Arc::clone(&subquery_alias.input),
                )?);
                insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
            }
            LogicalPlan::Projection(projection) => {
                // Conjuncts not touching volatile projection expressions are
                // pushed below; the rest stay in a Filter above.
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let (new_projection, keep_predicate) =
                    rewrite_projection(predicates, projection)?;
                if new_projection.transformed {
                    match keep_predicate {
                        None => Ok(new_projection),
                        Some(keep_predicate) => new_projection.map_data(|child_plan| {
                            Filter::try_new(keep_predicate, Arc::new(child_plan))
                                .map(LogicalPlan::Filter)
                        }),
                    }
                } else {
                    // Nothing was pushed: restore the original Filter.
                    filter.input = Arc::new(new_projection.data);
                    Ok(Transformed::no(LogicalPlan::Filter(filter)))
                }
            }
            LogicalPlan::Unnest(mut unnest) => {
                // Conjuncts referencing unnested list outputs or flattened
                // struct fields must stay above the Unnest. All others are
                // evaluated on its input.
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let mut non_unnest_predicates = vec![];
                let mut unnest_predicates = vec![];
                let mut unnest_struct_columns = vec![];

                // Output columns produced by unnesting struct columns are
                // named `<field>.<child>`.
                for idx in &unnest.struct_type_columns {
                    let (sub_qualifier, field) =
                        unnest.input.schema().qualified_field(*idx);
                    let field_name = field.name().clone();

                    if let DataType::Struct(children) = field.data_type() {
                        for child in children {
                            let child_name = child.name().clone();
                            unnest_struct_columns.push(Column::new(
                                sub_qualifier.cloned(),
                                format!("{field_name}.{child_name}"),
                            ));
                        }
                    }
                }

                for predicate in predicates {
                    let mut accum: HashSet<Column> = HashSet::new();
                    expr_to_columns(&predicate, &mut accum)?;

                    let contains_list_columns =
                        unnest.list_type_columns.iter().any(|(_, unnest_list)| {
                            accum.contains(&unnest_list.output_column)
                        });
                    let contains_struct_columns =
                        unnest_struct_columns.iter().any(|c| accum.contains(c));

                    if contains_list_columns || contains_struct_columns {
                        unnest_predicates.push(predicate);
                    } else {
                        non_unnest_predicates.push(predicate);
                    }
                }

                if non_unnest_predicates.is_empty() {
                    // Nothing can be pushed: keep the original Filter over the Unnest.
                    filter.input = Arc::new(LogicalPlan::Unnest(unnest));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Take the input out temporarily; `insert_below` installs the
                // new Filter (wrapping that input) as the Unnest's child.
                let unnest_input = std::mem::take(&mut unnest.input);

                let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
                    // Cannot be None: emptiness was checked above.
                    conjunction(non_unnest_predicates).unwrap(),
                    unnest_input,
                )?);

                let unnest_plan =
                    insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?;

                match conjunction(unnest_predicates) {
                    None => Ok(unnest_plan),
                    Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
                        Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
                    ))),
                }
            }
            LogicalPlan::Union(ref union) => {
                // Push a copy of the predicate into every input. The Union's
                // output columns are rewritten, by position, to that input's
                // columns.
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        Arc::clone(input),
                    )?)))
                }
                Ok(Transformed::yes(LogicalPlan::Union(Union {
                    inputs,
                    // Schema of the original Filter node; presumably identical
                    // to the Union's own schema (TODO confirm).
                    schema: Arc::clone(&plan_schema),
                })))
            }
            LogicalPlan::Aggregate(agg) => {
                // Only conjuncts whose columns are all group-by outputs can be
                // evaluated below the aggregate.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| {
                        let (relation, name) = e.qualified_name();
                        Column::new(relation, name)
                    })
                    .collect::<HashSet<_>>();

                let predicates = split_conjunction_owned(filter.predicate);

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    // NOTE(review): a conjunct with no column references (a
                    // constant or a column-free volatile expression) satisfies
                    // `all` vacuously and is pushed below the aggregate. Confirm
                    // this is intended when `group_expr` is empty: a global
                    // aggregate still emits one row for empty input.
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // A pushed conjunct refers to group-by expressions by their
                // output name (e.g. `a + b`). Map each name back to the
                // expression itself so it can be evaluated on the input.
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.schema_name().to_string(), expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .into_iter()
                    .map(|expr| replace_cols_by_name(expr, &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let agg_input = Arc::clone(&agg.input);
                Transformed::yes(LogicalPlan::Aggregate(agg))
                    .transform_data(|new_plan| {
                        // Place the pushable conjuncts between the aggregate and its input.
                        if let Some(predicate) = conjunction(replaced_push_predicates) {
                            let new_filter = make_filter(predicate, agg_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Re-add the remaining conjuncts as a Filter above the aggregate.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Window(window) => {
                // Only conjuncts whose columns are partition keys shared by
                // *all* window expressions can move below the Window node.
                // Such a conjunct has the same value for every row of a
                // partition, so it keeps or drops whole partitions.
                let extract_partition_keys = |func: &WindowFunction| {
                    func.params
                        .partition_by
                        .iter()
                        .map(|c| {
                            let (relation, name) = c.qualified_name();
                            Column::new(relation, name)
                        })
                        .collect::<HashSet<_>>()
                };
                let potential_partition_keys = window
                    .window_expr
                    .iter()
                    .map(|e| {
                        match e {
                            Expr::WindowFunction(window_func) => {
                                extract_partition_keys(window_func)
                            }
                            Expr::Alias(alias) => {
                                if let Expr::WindowFunction(window_func) =
                                    alias.expr.as_ref()
                                {
                                    extract_partition_keys(window_func)
                                } else {
                                    // Assumes window_expr only holds (aliased)
                                    // window functions.
                                    unreachable!()
                                }
                            }
                            _ => {
                                // Same assumption as above.
                                unreachable!()
                            }
                        }
                    })
                    // Intersection of all expressions' partition keys; an
                    // empty `window_expr` yields the empty set.
                    .reduce(|a, b| &a & &b)
                    .unwrap_or_default();

                let predicates = split_conjunction_owned(filter.predicate);
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.column_refs();
                    if cols.iter().all(|c| potential_partition_keys.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                let window_input = Arc::clone(&window.input);
                Transformed::yes(LogicalPlan::Window(window))
                    .transform_data(|new_plan| {
                        // Place the pushable conjuncts between the Window and its input.
                        if let Some(predicate) = conjunction(push_predicates) {
                            let new_filter = make_filter(predicate, window_input)?;
                            insert_below(new_plan, new_filter)
                        } else {
                            Ok(Transformed::no(new_plan))
                        }
                    })?
                    .map_data(|child_plan| {
                        // Re-add the remaining conjuncts as a Filter above the Window.
                        if let Some(predicate) = conjunction(keep_predicates) {
                            make_filter(predicate, Arc::new(child_plan))
                        } else {
                            Ok(child_plan)
                        }
                    })
            }
            LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)),
            LogicalPlan::TableScan(scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);

                // Volatile conjuncts are never offered to the table provider.
                let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) =
                    filter_predicates
                        .into_iter()
                        .partition(|pred| pred.is_volatile());

                // Ask the provider how it can handle each non-volatile conjunct.
                let supported_filters = scan
                    .source
                    .supports_filters_pushdown(non_volatile_filters.as_slice())?;
                if non_volatile_filters.len() != supported_filters.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        supported_filters.len(),
                        non_volatile_filters.len());
                }

                // (conjunct, provider support) pairs.
                let zip = non_volatile_filters.into_iter().zip(supported_filters);

                // Every conjunct not reported `Unsupported` becomes a scan filter.
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| pred);

                // Merge with the scan's existing filters, removing duplicates.
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();

                // Conjuncts the provider does not handle `Exact`ly, plus all
                // volatile ones, must still be evaluated above the scan.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| pred)
                    .chain(volatile_filters)
                    .cloned()
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    filters: new_scan_filters,
                    ..scan
                });

                Transformed::yes(new_scan).transform_data(|new_scan| {
                    if let Some(predicate) = conjunction(new_predicate) {
                        make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes)
                    } else {
                        Ok(Transformed::no(new_scan))
                    }
                })
            }
            LogicalPlan::Extension(extension_plan) => {
                // An extension node with no inputs has nowhere to push to.
                if extension_plan.node.inputs().is_empty() {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                // One flag per conjunct: `true` = push below the node,
                // `false` = keep above it. A conjunct is kept if it references
                // a column the node blocks (matched by column name only).
                let predicate_push_or_keep = split_conjunction(&filter.predicate)
                    .iter()
                    .map(|expr| {
                        let cols = expr.column_refs();
                        if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                            Ok(false)
                        } else {
                            Ok(true)
                        }
                    })
                    .collect::<Result<Vec<_>>>()?;

                // Every conjunct must stay: keep the original Filter.
                if predicate_push_or_keep.iter().all(|&x| !x) {
                    filter.input = Arc::new(LogicalPlan::Extension(extension_plan));
                    return Ok(Transformed::no(LogicalPlan::Filter(filter)));
                }

                // Split the conjuncts according to the flags computed above.
                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for (push, expr) in predicate_push_or_keep
                    .into_iter()
                    .zip(split_conjunction_owned(filter.predicate).into_iter())
                {
                    if !push {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                // The pushable conjuncts are applied to *every* input of the node.
                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // Rebuild the extension node with its new (filtered) inputs.
                let child_plan = LogicalPlan::Extension(extension_plan);
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                let new_plan = match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                };
                Ok(Transformed::yes(new_plan))
            }
            // Any other input: pushing through it is not supported, so the
            // Filter stays where it is.
            child => {
                filter.input = Arc::new(child);
                Ok(Transformed::no(LogicalPlan::Filter(filter)))
            }
        }
    }
}
1264
1265fn rewrite_projection(
1293 predicates: Vec<Expr>,
1294 mut projection: Projection,
1295) -> Result<(Transformed<LogicalPlan>, Option<Expr>)> {
1296 let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection
1300 .schema
1301 .iter()
1302 .zip(projection.expr.iter())
1303 .map(|((qualifier, field), expr)| {
1304 let expr = expr.clone().unalias();
1306
1307 (qualified_name(qualifier, field.name()), expr)
1308 })
1309 .partition(|(_, value)| value.is_volatile());
1310
1311 let mut push_predicates = vec![];
1312 let mut keep_predicates = vec![];
1313 for expr in predicates {
1314 if contain(&expr, &volatile_map) {
1315 keep_predicates.push(expr);
1316 } else {
1317 push_predicates.push(expr);
1318 }
1319 }
1320
1321 match conjunction(push_predicates) {
1322 Some(expr) => {
1323 let new_filter = LogicalPlan::Filter(Filter::try_new(
1326 replace_cols_by_name(expr, &non_volatile_map)?,
1327 std::mem::take(&mut projection.input),
1328 )?);
1329
1330 projection.input = Arc::new(new_filter);
1331
1332 Ok((
1333 Transformed::yes(LogicalPlan::Projection(projection)),
1334 conjunction(keep_predicates),
1335 ))
1336 }
1337 None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)),
1338 }
1339}
1340
1341pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
1343 Filter::try_new(predicate, input).map(LogicalPlan::Filter)
1344}
1345
1346fn insert_below(
1360 plan: LogicalPlan,
1361 new_child: LogicalPlan,
1362) -> Result<Transformed<LogicalPlan>> {
1363 let mut new_child = Some(new_child);
1364 let transformed_plan = plan.map_children(|_child| {
1365 if let Some(new_child) = new_child.take() {
1366 Ok(Transformed::yes(new_child))
1367 } else {
1368 internal_err!("node had more than one input")
1370 }
1371 })?;
1372
1373 if new_child.is_some() {
1375 return internal_err!("node had no inputs");
1376 }
1377
1378 Ok(transformed_plan)
1379}
1380
1381impl PushDownFilter {
1382 #[allow(missing_docs)]
1383 pub fn new() -> Self {
1384 Self {}
1385 }
1386}
1387
1388pub fn replace_cols_by_name(
1390 e: Expr,
1391 replace_map: &HashMap<String, Expr>,
1392) -> Result<Expr> {
1393 e.transform_up_with_lambdas_params(|expr, lambdas_params| {
1394 Ok(match &expr {
1395 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1396 match replace_map.get(&c.flat_name()) {
1397 Some(new_c) => Transformed::yes(new_c.clone()),
1398 None => Transformed::no(expr),
1399 }
1400 }
1401 _ => Transformed::no(expr),
1402 })
1403 })
1404 .data()
1405}
1406
1407fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
1409 let mut is_contain = false;
1410 e.apply_with_lambdas_params(|expr, lambdas_params| {
1411 Ok(match &expr {
1412 Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
1413 match check_map.get(&c.flat_name()) {
1414 Some(_) => {
1415 is_contain = true;
1416 TreeNodeRecursion::Stop
1417 }
1418 None => TreeNodeRecursion::Continue,
1419 }
1420 }
1421 _ => TreeNodeRecursion::Continue,
1422 })
1423 })
1424 .unwrap();
1425 is_contain
1426}
1427
1428#[cfg(test)]
1429mod tests {
1430 use std::any::Any;
1431 use std::cmp::Ordering;
1432 use std::fmt::{Debug, Formatter};
1433
1434 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
1435 use async_trait::async_trait;
1436
1437 use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
1438 use datafusion_expr::expr::{ScalarFunction, WindowFunction};
1439 use datafusion_expr::logical_plan::table_scan;
1440 use datafusion_expr::{
1441 col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
1442 LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1443 TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
1444 WindowFunctionDefinition,
1445 };
1446
1447 use crate::assert_optimized_plan_eq_snapshot;
1448 use crate::optimizer::Optimizer;
1449 use crate::simplify_expressions::SimplifyExpressions;
1450 use crate::test::*;
1451 use crate::OptimizerContext;
1452 use datafusion_expr::test::function_stub::sum;
1453 use insta::assert_snapshot;
1454
1455 use super::*;
1456
1457 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
1458
1459 macro_rules! assert_optimized_plan_equal {
1460 (
1461 $plan:expr,
1462 @ $expected:literal $(,)?
1463 ) => {{
1464 let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
1465 let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
1466 assert_optimized_plan_eq_snapshot!(
1467 optimizer_ctx,
1468 rules,
1469 $plan,
1470 @ $expected,
1471 )
1472 }};
1473 }
1474
1475 macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
1476 (
1477 $plan:expr,
1478 @ $expected:literal $(,)?
1479 ) => {{
1480 let optimizer = Optimizer::with_rules(vec![
1481 Arc::new(SimplifyExpressions::new()),
1482 Arc::new(PushDownFilter::new()),
1483 ]);
1484 let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
1485 assert_snapshot!(optimized_plan, @ $expected);
1486 Ok::<(), DataFusionError>(())
1487 }};
1488 }
1489
1490 #[test]
1491 fn filter_before_projection() -> Result<()> {
1492 let table_scan = test_table_scan()?;
1493 let plan = LogicalPlanBuilder::from(table_scan)
1494 .project(vec![col("a"), col("b")])?
1495 .filter(col("a").eq(lit(1i64)))?
1496 .build()?;
1497 assert_optimized_plan_equal!(
1499 plan,
1500 @r"
1501 Projection: test.a, test.b
1502 TableScan: test, full_filters=[test.a = Int64(1)]
1503 "
1504 )
1505 }
1506
1507 #[test]
1508 fn filter_after_limit() -> Result<()> {
1509 let table_scan = test_table_scan()?;
1510 let plan = LogicalPlanBuilder::from(table_scan)
1511 .project(vec![col("a"), col("b")])?
1512 .limit(0, Some(10))?
1513 .filter(col("a").eq(lit(1i64)))?
1514 .build()?;
1515 assert_optimized_plan_equal!(
1517 plan,
1518 @r"
1519 Filter: test.a = Int64(1)
1520 Limit: skip=0, fetch=10
1521 Projection: test.a, test.b
1522 TableScan: test
1523 "
1524 )
1525 }
1526
1527 #[test]
1528 fn filter_no_columns() -> Result<()> {
1529 let table_scan = test_table_scan()?;
1530 let plan = LogicalPlanBuilder::from(table_scan)
1531 .filter(lit(0i64).eq(lit(1i64)))?
1532 .build()?;
1533 assert_optimized_plan_equal!(
1534 plan,
1535 @"TableScan: test, full_filters=[Int64(0) = Int64(1)]"
1536 )
1537 }
1538
1539 #[test]
1540 fn filter_jump_2_plans() -> Result<()> {
1541 let table_scan = test_table_scan()?;
1542 let plan = LogicalPlanBuilder::from(table_scan)
1543 .project(vec![col("a"), col("b"), col("c")])?
1544 .project(vec![col("c"), col("b")])?
1545 .filter(col("a").eq(lit(1i64)))?
1546 .build()?;
1547 assert_optimized_plan_equal!(
1549 plan,
1550 @r"
1551 Projection: test.c, test.b
1552 Projection: test.a, test.b, test.c
1553 TableScan: test, full_filters=[test.a = Int64(1)]
1554 "
1555 )
1556 }
1557
1558 #[test]
1559 fn filter_move_agg() -> Result<()> {
1560 let table_scan = test_table_scan()?;
1561 let plan = LogicalPlanBuilder::from(table_scan)
1562 .aggregate(vec![col("a")], vec![sum(col("b")).alias("total_salary")])?
1563 .filter(col("a").gt(lit(10i64)))?
1564 .build()?;
1565 assert_optimized_plan_equal!(
1567 plan,
1568 @r"
1569 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]
1570 TableScan: test, full_filters=[test.a > Int64(10)]
1571 "
1572 )
1573 }
1574
1575 #[test]
1577 fn filter_move_agg_special() -> Result<()> {
1578 let schema = Schema::new(vec![
1579 Field::new("$a", DataType::UInt32, false),
1580 Field::new("$b", DataType::UInt32, false),
1581 Field::new("$c", DataType::UInt32, false),
1582 ]);
1583 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1584
1585 let plan = LogicalPlanBuilder::from(table_scan)
1586 .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
1587 .filter(col("$a").gt(lit(10i64)))?
1588 .build()?;
1589 assert_optimized_plan_equal!(
1591 plan,
1592 @r"
1593 Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
1594 TableScan: test, full_filters=[test.$a > Int64(10)]
1595 "
1596 )
1597 }
1598
1599 #[test]
1600 fn filter_complex_group_by() -> Result<()> {
1601 let table_scan = test_table_scan()?;
1602 let plan = LogicalPlanBuilder::from(table_scan)
1603 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1604 .filter(col("b").gt(lit(10i64)))?
1605 .build()?;
1606 assert_optimized_plan_equal!(
1607 plan,
1608 @r"
1609 Filter: test.b > Int64(10)
1610 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1611 TableScan: test
1612 "
1613 )
1614 }
1615
1616 #[test]
1617 fn push_agg_need_replace_expr() -> Result<()> {
1618 let plan = LogicalPlanBuilder::from(test_table_scan()?)
1619 .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])?
1620 .filter(col("test.b + test.a").gt(lit(10i64)))?
1621 .build()?;
1622 assert_optimized_plan_equal!(
1623 plan,
1624 @r"
1625 Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]
1626 TableScan: test, full_filters=[test.b + test.a > Int64(10)]
1627 "
1628 )
1629 }
1630
1631 #[test]
1632 fn filter_keep_agg() -> Result<()> {
1633 let table_scan = test_table_scan()?;
1634 let plan = LogicalPlanBuilder::from(table_scan)
1635 .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])?
1636 .filter(col("b").gt(lit(10i64)))?
1637 .build()?;
1638 assert_optimized_plan_equal!(
1640 plan,
1641 @r"
1642 Filter: b > Int64(10)
1643 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]
1644 TableScan: test
1645 "
1646 )
1647 }
1648
1649 #[test]
1651 fn filter_move_window() -> Result<()> {
1652 let table_scan = test_table_scan()?;
1653
1654 let window = Expr::from(WindowFunction::new(
1655 WindowFunctionDefinition::WindowUDF(
1656 datafusion_functions_window::rank::rank_udwf(),
1657 ),
1658 vec![],
1659 ))
1660 .partition_by(vec![col("a"), col("b")])
1661 .order_by(vec![col("c").sort(true, true)])
1662 .build()
1663 .unwrap();
1664
1665 let plan = LogicalPlanBuilder::from(table_scan)
1666 .window(vec![window])?
1667 .filter(col("b").gt(lit(10i64)))?
1668 .build()?;
1669
1670 assert_optimized_plan_equal!(
1671 plan,
1672 @r"
1673 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1674 TableScan: test, full_filters=[test.b > Int64(10)]
1675 "
1676 )
1677 }
1678
1679 #[test]
1681 fn filter_window_special_identifier() -> Result<()> {
1682 let schema = Schema::new(vec![
1683 Field::new("$a", DataType::UInt32, false),
1684 Field::new("$b", DataType::UInt32, false),
1685 Field::new("$c", DataType::UInt32, false),
1686 ]);
1687 let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
1688
1689 let window = Expr::from(WindowFunction::new(
1690 WindowFunctionDefinition::WindowUDF(
1691 datafusion_functions_window::rank::rank_udwf(),
1692 ),
1693 vec![],
1694 ))
1695 .partition_by(vec![col("$a"), col("$b")])
1696 .order_by(vec![col("$c").sort(true, true)])
1697 .build()
1698 .unwrap();
1699
1700 let plan = LogicalPlanBuilder::from(table_scan)
1701 .window(vec![window])?
1702 .filter(col("$b").gt(lit(10i64)))?
1703 .build()?;
1704
1705 assert_optimized_plan_equal!(
1706 plan,
1707 @r"
1708 WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1709 TableScan: test, full_filters=[test.$b > Int64(10)]
1710 "
1711 )
1712 }
1713
1714 #[test]
1717 fn filter_move_complex_window() -> Result<()> {
1718 let table_scan = test_table_scan()?;
1719
1720 let window = Expr::from(WindowFunction::new(
1721 WindowFunctionDefinition::WindowUDF(
1722 datafusion_functions_window::rank::rank_udwf(),
1723 ),
1724 vec![],
1725 ))
1726 .partition_by(vec![col("a"), col("b")])
1727 .order_by(vec![col("c").sort(true, true)])
1728 .build()
1729 .unwrap();
1730
1731 let plan = LogicalPlanBuilder::from(table_scan)
1732 .window(vec![window])?
1733 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1734 .build()?;
1735
1736 assert_optimized_plan_equal!(
1737 plan,
1738 @r"
1739 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1740 TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]
1741 "
1742 )
1743 }
1744
1745 #[test]
1747 fn filter_move_partial_window() -> Result<()> {
1748 let table_scan = test_table_scan()?;
1749
1750 let window = Expr::from(WindowFunction::new(
1751 WindowFunctionDefinition::WindowUDF(
1752 datafusion_functions_window::rank::rank_udwf(),
1753 ),
1754 vec![],
1755 ))
1756 .partition_by(vec![col("a")])
1757 .order_by(vec![col("c").sort(true, true)])
1758 .build()
1759 .unwrap();
1760
1761 let plan = LogicalPlanBuilder::from(table_scan)
1762 .window(vec![window])?
1763 .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))?
1764 .build()?;
1765
1766 assert_optimized_plan_equal!(
1767 plan,
1768 @r"
1769 Filter: test.b = Int64(1)
1770 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1771 TableScan: test, full_filters=[test.a > Int64(10)]
1772 "
1773 )
1774 }
1775
1776 #[test]
1779 fn filter_expression_keep_window() -> Result<()> {
1780 let table_scan = test_table_scan()?;
1781
1782 let window = Expr::from(WindowFunction::new(
1783 WindowFunctionDefinition::WindowUDF(
1784 datafusion_functions_window::rank::rank_udwf(),
1785 ),
1786 vec![],
1787 ))
1788 .partition_by(vec![add(col("a"), col("b"))]) .order_by(vec![col("c").sort(true, true)])
1790 .build()
1791 .unwrap();
1792
1793 let plan = LogicalPlanBuilder::from(table_scan)
1794 .window(vec![window])?
1795 .filter(add(col("a"), col("b")).gt(lit(10i64)))?
1798 .build()?;
1799
1800 assert_optimized_plan_equal!(
1801 plan,
1802 @r"
1803 Filter: test.a + test.b > Int64(10)
1804 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1805 TableScan: test
1806 "
1807 )
1808 }
1809
1810 #[test]
1812 fn filter_order_keep_window() -> Result<()> {
1813 let table_scan = test_table_scan()?;
1814
1815 let window = Expr::from(WindowFunction::new(
1816 WindowFunctionDefinition::WindowUDF(
1817 datafusion_functions_window::rank::rank_udwf(),
1818 ),
1819 vec![],
1820 ))
1821 .partition_by(vec![col("a")])
1822 .order_by(vec![col("c").sort(true, true)])
1823 .build()
1824 .unwrap();
1825
1826 let plan = LogicalPlanBuilder::from(table_scan)
1827 .window(vec![window])?
1828 .filter(col("c").gt(lit(10i64)))?
1829 .build()?;
1830
1831 assert_optimized_plan_equal!(
1832 plan,
1833 @r"
1834 Filter: test.c > Int64(10)
1835 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1836 TableScan: test
1837 "
1838 )
1839 }
1840
1841 #[test]
1844 fn filter_multiple_windows_common_partitions() -> Result<()> {
1845 let table_scan = test_table_scan()?;
1846
1847 let window1 = Expr::from(WindowFunction::new(
1848 WindowFunctionDefinition::WindowUDF(
1849 datafusion_functions_window::rank::rank_udwf(),
1850 ),
1851 vec![],
1852 ))
1853 .partition_by(vec![col("a")])
1854 .order_by(vec![col("c").sort(true, true)])
1855 .build()
1856 .unwrap();
1857
1858 let window2 = Expr::from(WindowFunction::new(
1859 WindowFunctionDefinition::WindowUDF(
1860 datafusion_functions_window::rank::rank_udwf(),
1861 ),
1862 vec![],
1863 ))
1864 .partition_by(vec![col("b"), col("a")])
1865 .order_by(vec![col("c").sort(true, true)])
1866 .build()
1867 .unwrap();
1868
1869 let plan = LogicalPlanBuilder::from(table_scan)
1870 .window(vec![window1, window2])?
1871 .filter(col("a").gt(lit(10i64)))? .build()?;
1873
1874 assert_optimized_plan_equal!(
1875 plan,
1876 @r"
1877 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1878 TableScan: test, full_filters=[test.a > Int64(10)]
1879 "
1880 )
1881 }
1882
1883 #[test]
1886 fn filter_multiple_windows_disjoint_partitions() -> Result<()> {
1887 let table_scan = test_table_scan()?;
1888
1889 let window1 = Expr::from(WindowFunction::new(
1890 WindowFunctionDefinition::WindowUDF(
1891 datafusion_functions_window::rank::rank_udwf(),
1892 ),
1893 vec![],
1894 ))
1895 .partition_by(vec![col("a")])
1896 .order_by(vec![col("c").sort(true, true)])
1897 .build()
1898 .unwrap();
1899
1900 let window2 = Expr::from(WindowFunction::new(
1901 WindowFunctionDefinition::WindowUDF(
1902 datafusion_functions_window::rank::rank_udwf(),
1903 ),
1904 vec![],
1905 ))
1906 .partition_by(vec![col("b"), col("a")])
1907 .order_by(vec![col("c").sort(true, true)])
1908 .build()
1909 .unwrap();
1910
1911 let plan = LogicalPlanBuilder::from(table_scan)
1912 .window(vec![window1, window2])?
1913 .filter(col("b").gt(lit(10i64)))? .build()?;
1915
1916 assert_optimized_plan_equal!(
1917 plan,
1918 @r"
1919 Filter: test.b > Int64(10)
1920 WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1921 TableScan: test
1922 "
1923 )
1924 }
1925
1926 #[test]
1928 fn alias() -> Result<()> {
1929 let table_scan = test_table_scan()?;
1930 let plan = LogicalPlanBuilder::from(table_scan)
1931 .project(vec![col("a").alias("b"), col("c")])?
1932 .filter(col("b").eq(lit(1i64)))?
1933 .build()?;
1934 assert_optimized_plan_equal!(
1936 plan,
1937 @r"
1938 Projection: test.a AS b, test.c
1939 TableScan: test, full_filters=[test.a = Int64(1)]
1940 "
1941 )
1942 }
1943
1944 fn add(left: Expr, right: Expr) -> Expr {
1945 Expr::BinaryExpr(BinaryExpr::new(
1946 Box::new(left),
1947 Operator::Plus,
1948 Box::new(right),
1949 ))
1950 }
1951
1952 fn multiply(left: Expr, right: Expr) -> Expr {
1953 Expr::BinaryExpr(BinaryExpr::new(
1954 Box::new(left),
1955 Operator::Multiply,
1956 Box::new(right),
1957 ))
1958 }
1959
1960 #[test]
1962 fn complex_expression() -> Result<()> {
1963 let table_scan = test_table_scan()?;
1964 let plan = LogicalPlanBuilder::from(table_scan)
1965 .project(vec![
1966 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1967 col("c"),
1968 ])?
1969 .filter(col("b").eq(lit(1i64)))?
1970 .build()?;
1971
1972 assert_snapshot!(plan,
1974 @r"
1975 Filter: b = Int64(1)
1976 Projection: test.a * Int32(2) + test.c AS b, test.c
1977 TableScan: test
1978 ",
1979 );
1980 assert_optimized_plan_equal!(
1982 plan,
1983 @r"
1984 Projection: test.a * Int32(2) + test.c AS b, test.c
1985 TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]
1986 "
1987 )
1988 }
1989
1990 #[test]
1992 fn complex_plan() -> Result<()> {
1993 let table_scan = test_table_scan()?;
1994 let plan = LogicalPlanBuilder::from(table_scan)
1995 .project(vec![
1996 add(multiply(col("a"), lit(2)), col("c")).alias("b"),
1997 col("c"),
1998 ])?
1999 .project(vec![multiply(col("b"), lit(3)).alias("a"), col("c")])?
2001 .filter(col("a").eq(lit(1i64)))?
2002 .build()?;
2003
2004 assert_snapshot!(plan,
2006 @r"
2007 Filter: a = Int64(1)
2008 Projection: b * Int32(3) AS a, test.c
2009 Projection: test.a * Int32(2) + test.c AS b, test.c
2010 TableScan: test
2011 ",
2012 );
2013 assert_optimized_plan_equal!(
2015 plan,
2016 @r"
2017 Projection: b * Int32(3) AS a, test.c
2018 Projection: test.a * Int32(2) + test.c AS b, test.c
2019 TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]
2020 "
2021 )
2022 }
2023
2024 #[derive(Debug, PartialEq, Eq, Hash)]
2025 struct NoopPlan {
2026 input: Vec<LogicalPlan>,
2027 schema: DFSchemaRef,
2028 }
2029
2030 impl PartialOrd for NoopPlan {
2032 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2033 self.input
2034 .partial_cmp(&other.input)
2035 .filter(|cmp| *cmp != Ordering::Equal || self == other)
2037 }
2038 }
2039
2040 impl UserDefinedLogicalNodeCore for NoopPlan {
2041 fn name(&self) -> &str {
2042 "NoopPlan"
2043 }
2044
2045 fn inputs(&self) -> Vec<&LogicalPlan> {
2046 self.input.iter().collect()
2047 }
2048
2049 fn schema(&self) -> &DFSchemaRef {
2050 &self.schema
2051 }
2052
2053 fn expressions(&self) -> Vec<Expr> {
2054 self.input
2055 .iter()
2056 .flat_map(|child| child.expressions())
2057 .collect()
2058 }
2059
2060 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
2061 HashSet::from_iter(vec!["c".to_string()])
2062 }
2063
2064 fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
2065 write!(f, "NoopPlan")
2066 }
2067
2068 fn with_exprs_and_inputs(
2069 &self,
2070 _exprs: Vec<Expr>,
2071 inputs: Vec<LogicalPlan>,
2072 ) -> Result<Self> {
2073 Ok(Self {
2074 input: inputs,
2075 schema: Arc::clone(&self.schema),
2076 })
2077 }
2078
2079 fn supports_limit_pushdown(&self) -> bool {
2080 false }
2082 }
2083
2084 #[test]
2085 fn user_defined_plan() -> Result<()> {
2086 let table_scan = test_table_scan()?;
2087
2088 let custom_plan = LogicalPlan::Extension(Extension {
2089 node: Arc::new(NoopPlan {
2090 input: vec![table_scan.clone()],
2091 schema: Arc::clone(table_scan.schema()),
2092 }),
2093 });
2094 let plan = LogicalPlanBuilder::from(custom_plan)
2095 .filter(col("a").eq(lit(1i64)))?
2096 .build()?;
2097
2098 assert_optimized_plan_equal!(
2100 plan,
2101 @r"
2102 NoopPlan
2103 TableScan: test, full_filters=[test.a = Int64(1)]
2104 "
2105 )?;
2106
2107 let custom_plan = LogicalPlan::Extension(Extension {
2108 node: Arc::new(NoopPlan {
2109 input: vec![table_scan.clone()],
2110 schema: Arc::clone(table_scan.schema()),
2111 }),
2112 });
2113 let plan = LogicalPlanBuilder::from(custom_plan)
2114 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2115 .build()?;
2116
2117 assert_optimized_plan_equal!(
2119 plan,
2120 @r"
2121 Filter: test.c = Int64(2)
2122 NoopPlan
2123 TableScan: test, full_filters=[test.a = Int64(1)]
2124 "
2125 )?;
2126
2127 let custom_plan = LogicalPlan::Extension(Extension {
2128 node: Arc::new(NoopPlan {
2129 input: vec![table_scan.clone(), table_scan.clone()],
2130 schema: Arc::clone(table_scan.schema()),
2131 }),
2132 });
2133 let plan = LogicalPlanBuilder::from(custom_plan)
2134 .filter(col("a").eq(lit(1i64)))?
2135 .build()?;
2136
2137 assert_optimized_plan_equal!(
2139 plan,
2140 @r"
2141 NoopPlan
2142 TableScan: test, full_filters=[test.a = Int64(1)]
2143 TableScan: test, full_filters=[test.a = Int64(1)]
2144 "
2145 )?;
2146
2147 let custom_plan = LogicalPlan::Extension(Extension {
2148 node: Arc::new(NoopPlan {
2149 input: vec![table_scan.clone(), table_scan.clone()],
2150 schema: Arc::clone(table_scan.schema()),
2151 }),
2152 });
2153 let plan = LogicalPlanBuilder::from(custom_plan)
2154 .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
2155 .build()?;
2156
2157 assert_optimized_plan_equal!(
2159 plan,
2160 @r"
2161 Filter: test.c = Int64(2)
2162 NoopPlan
2163 TableScan: test, full_filters=[test.a = Int64(1)]
2164 TableScan: test, full_filters=[test.a = Int64(1)]
2165 "
2166 )
2167 }
2168
2169 #[test]
2172 fn multi_filter() -> Result<()> {
2173 let table_scan = test_table_scan()?;
2175 let plan = LogicalPlanBuilder::from(table_scan)
2176 .project(vec![col("a").alias("b"), col("c")])?
2177 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2178 .filter(col("b").gt(lit(10i64)))?
2179 .filter(col("sum(test.c)").gt(lit(10i64)))?
2180 .build()?;
2181
2182 assert_snapshot!(plan,
2184 @r"
2185 Filter: sum(test.c) > Int64(10)
2186 Filter: b > Int64(10)
2187 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2188 Projection: test.a AS b, test.c
2189 TableScan: test
2190 ",
2191 );
2192 assert_optimized_plan_equal!(
2194 plan,
2195 @r"
2196 Filter: sum(test.c) > Int64(10)
2197 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2198 Projection: test.a AS b, test.c
2199 TableScan: test, full_filters=[test.a > Int64(10)]
2200 "
2201 )
2202 }
2203
2204 #[test]
2207 fn split_filter() -> Result<()> {
2208 let table_scan = test_table_scan()?;
2210 let plan = LogicalPlanBuilder::from(table_scan)
2211 .project(vec![col("a").alias("b"), col("c")])?
2212 .aggregate(vec![col("b")], vec![sum(col("c"))])?
2213 .filter(and(
2214 col("sum(test.c)").gt(lit(10i64)),
2215 and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))),
2216 ))?
2217 .build()?;
2218
2219 assert_snapshot!(plan,
2221 @r"
2222 Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)
2223 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2224 Projection: test.a AS b, test.c
2225 TableScan: test
2226 ",
2227 );
2228 assert_optimized_plan_equal!(
2230 plan,
2231 @r"
2232 Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)
2233 Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]
2234 Projection: test.a AS b, test.c
2235 TableScan: test, full_filters=[test.a > Int64(10)]
2236 "
2237 )
2238 }
2239
2240 #[test]
2242 fn double_limit() -> Result<()> {
2243 let table_scan = test_table_scan()?;
2244 let plan = LogicalPlanBuilder::from(table_scan)
2245 .project(vec![col("a"), col("b")])?
2246 .limit(0, Some(20))?
2247 .limit(0, Some(10))?
2248 .project(vec![col("a"), col("b")])?
2249 .filter(col("a").eq(lit(1i64)))?
2250 .build()?;
2251 assert_optimized_plan_equal!(
2253 plan,
2254 @r"
2255 Projection: test.a, test.b
2256 Filter: test.a = Int64(1)
2257 Limit: skip=0, fetch=10
2258 Limit: skip=0, fetch=20
2259 Projection: test.a, test.b
2260 TableScan: test
2261 "
2262 )
2263 }
2264
2265 #[test]
2266 fn union_all() -> Result<()> {
2267 let table_scan = test_table_scan()?;
2268 let table_scan2 = test_table_scan_with_name("test2")?;
2269 let plan = LogicalPlanBuilder::from(table_scan)
2270 .union(LogicalPlanBuilder::from(table_scan2).build()?)?
2271 .filter(col("a").eq(lit(1i64)))?
2272 .build()?;
2273 assert_optimized_plan_equal!(
2275 plan,
2276 @r"
2277 Union
2278 TableScan: test, full_filters=[test.a = Int64(1)]
2279 TableScan: test2, full_filters=[test2.a = Int64(1)]
2280 "
2281 )
2282 }
2283
2284 #[test]
2285 fn union_all_on_projection() -> Result<()> {
2286 let table_scan = test_table_scan()?;
2287 let table = LogicalPlanBuilder::from(table_scan)
2288 .project(vec![col("a").alias("b")])?
2289 .alias("test2")?;
2290
2291 let plan = table
2292 .clone()
2293 .union(table.build()?)?
2294 .filter(col("b").eq(lit(1i64)))?
2295 .build()?;
2296
2297 assert_optimized_plan_equal!(
2299 plan,
2300 @r"
2301 Union
2302 SubqueryAlias: test2
2303 Projection: test.a AS b
2304 TableScan: test, full_filters=[test.a = Int64(1)]
2305 SubqueryAlias: test2
2306 Projection: test.a AS b
2307 TableScan: test, full_filters=[test.a = Int64(1)]
2308 "
2309 )
2310 }
2311
2312 #[test]
2313 fn test_union_different_schema() -> Result<()> {
2314 let left = LogicalPlanBuilder::from(test_table_scan()?)
2315 .project(vec![col("a"), col("b"), col("c")])?
2316 .build()?;
2317
2318 let schema = Schema::new(vec![
2319 Field::new("d", DataType::UInt32, false),
2320 Field::new("e", DataType::UInt32, false),
2321 Field::new("f", DataType::UInt32, false),
2322 ]);
2323 let right = table_scan(Some("test1"), &schema, None)?
2324 .project(vec![col("d"), col("e"), col("f")])?
2325 .build()?;
2326 let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2)));
2327 let plan = LogicalPlanBuilder::from(left)
2328 .cross_join(right)?
2329 .project(vec![col("test.a"), col("test1.d")])?
2330 .filter(filter)?
2331 .build()?;
2332
2333 assert_optimized_plan_equal!(
2334 plan,
2335 @r"
2336 Projection: test.a, test1.d
2337 Cross Join:
2338 Projection: test.a, test.b, test.c
2339 TableScan: test, full_filters=[test.a = Int32(1)]
2340 Projection: test1.d, test1.e, test1.f
2341 TableScan: test1, full_filters=[test1.d > Int32(2)]
2342 "
2343 )
2344 }
2345
2346 #[test]
2347 fn test_project_same_name_different_qualifier() -> Result<()> {
2348 let table_scan = test_table_scan()?;
2349 let left = LogicalPlanBuilder::from(table_scan)
2350 .project(vec![col("a"), col("b"), col("c")])?
2351 .build()?;
2352 let right_table_scan = test_table_scan_with_name("test1")?;
2353 let right = LogicalPlanBuilder::from(right_table_scan)
2354 .project(vec![col("a"), col("b"), col("c")])?
2355 .build()?;
2356 let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2)));
2357 let plan = LogicalPlanBuilder::from(left)
2358 .cross_join(right)?
2359 .project(vec![col("test.a"), col("test1.a")])?
2360 .filter(filter)?
2361 .build()?;
2362
2363 assert_optimized_plan_equal!(
2364 plan,
2365 @r"
2366 Projection: test.a, test1.a
2367 Cross Join:
2368 Projection: test.a, test.b, test.c
2369 TableScan: test, full_filters=[test.a = Int32(1)]
2370 Projection: test1.a, test1.b, test1.c
2371 TableScan: test1, full_filters=[test1.a > Int32(2)]
2372 "
2373 )
2374 }
2375
2376 #[test]
2378 fn filter_2_breaks_limits() -> Result<()> {
2379 let table_scan = test_table_scan()?;
2380 let plan = LogicalPlanBuilder::from(table_scan)
2381 .project(vec![col("a")])?
2382 .filter(col("a").lt_eq(lit(1i64)))?
2383 .limit(0, Some(1))?
2384 .project(vec![col("a")])?
2385 .filter(col("a").gt_eq(lit(1i64)))?
2386 .build()?;
2387 assert_snapshot!(plan,
2391 @r"
2392 Filter: test.a >= Int64(1)
2393 Projection: test.a
2394 Limit: skip=0, fetch=1
2395 Filter: test.a <= Int64(1)
2396 Projection: test.a
2397 TableScan: test
2398 ",
2399 );
2400 assert_optimized_plan_equal!(
2401 plan,
2402 @r"
2403 Projection: test.a
2404 Filter: test.a >= Int64(1)
2405 Limit: skip=0, fetch=1
2406 Projection: test.a
2407 TableScan: test, full_filters=[test.a <= Int64(1)]
2408 "
2409 )
2410 }
2411
2412 #[test]
2414 fn two_filters_on_same_depth() -> Result<()> {
2415 let table_scan = test_table_scan()?;
2416 let plan = LogicalPlanBuilder::from(table_scan)
2417 .limit(0, Some(1))?
2418 .filter(col("a").lt_eq(lit(1i64)))?
2419 .filter(col("a").gt_eq(lit(1i64)))?
2420 .project(vec![col("a")])?
2421 .build()?;
2422
2423 assert_snapshot!(plan,
2425 @r"
2426 Projection: test.a
2427 Filter: test.a >= Int64(1)
2428 Filter: test.a <= Int64(1)
2429 Limit: skip=0, fetch=1
2430 TableScan: test
2431 ",
2432 );
2433 assert_optimized_plan_equal!(
2434 plan,
2435 @r"
2436 Projection: test.a
2437 Filter: test.a >= Int64(1) AND test.a <= Int64(1)
2438 Limit: skip=0, fetch=1
2439 TableScan: test
2440 "
2441 )
2442 }
2443
2444 #[test]
2447 fn filters_user_defined_node() -> Result<()> {
2448 let table_scan = test_table_scan()?;
2449 let plan = LogicalPlanBuilder::from(table_scan)
2450 .filter(col("a").lt_eq(lit(1i64)))?
2451 .build()?;
2452
2453 let plan = user_defined::new(plan);
2454
2455 assert_snapshot!(plan,
2457 @r"
2458 TestUserDefined
2459 Filter: test.a <= Int64(1)
2460 TableScan: test
2461 ",
2462 );
2463 assert_optimized_plan_equal!(
2464 plan,
2465 @r"
2466 TestUserDefined
2467 TableScan: test, full_filters=[test.a <= Int64(1)]
2468 "
2469 )
2470 }
2471
2472 #[test]
2474 fn filter_on_join_on_common_independent() -> Result<()> {
2475 let table_scan = test_table_scan()?;
2476 let left = LogicalPlanBuilder::from(table_scan).build()?;
2477 let right_table_scan = test_table_scan_with_name("test2")?;
2478 let right = LogicalPlanBuilder::from(right_table_scan)
2479 .project(vec![col("a")])?
2480 .build()?;
2481 let plan = LogicalPlanBuilder::from(left)
2482 .join(
2483 right,
2484 JoinType::Inner,
2485 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2486 None,
2487 )?
2488 .filter(col("test.a").lt_eq(lit(1i64)))?
2489 .build()?;
2490
2491 assert_snapshot!(plan,
2493 @r"
2494 Filter: test.a <= Int64(1)
2495 Inner Join: test.a = test2.a
2496 TableScan: test
2497 Projection: test2.a
2498 TableScan: test2
2499 ",
2500 );
2501 assert_optimized_plan_equal!(
2503 plan,
2504 @r"
2505 Inner Join: test.a = test2.a
2506 TableScan: test, full_filters=[test.a <= Int64(1)]
2507 Projection: test2.a
2508 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2509 "
2510 )
2511 }
2512
2513 #[test]
2515 fn filter_using_join_on_common_independent() -> Result<()> {
2516 let table_scan = test_table_scan()?;
2517 let left = LogicalPlanBuilder::from(table_scan).build()?;
2518 let right_table_scan = test_table_scan_with_name("test2")?;
2519 let right = LogicalPlanBuilder::from(right_table_scan)
2520 .project(vec![col("a")])?
2521 .build()?;
2522 let plan = LogicalPlanBuilder::from(left)
2523 .join_using(
2524 right,
2525 JoinType::Inner,
2526 vec![Column::from_name("a".to_string())],
2527 )?
2528 .filter(col("a").lt_eq(lit(1i64)))?
2529 .build()?;
2530
2531 assert_snapshot!(plan,
2533 @r"
2534 Filter: test.a <= Int64(1)
2535 Inner Join: Using test.a = test2.a
2536 TableScan: test
2537 Projection: test2.a
2538 TableScan: test2
2539 ",
2540 );
2541 assert_optimized_plan_equal!(
2543 plan,
2544 @r"
2545 Inner Join: Using test.a = test2.a
2546 TableScan: test, full_filters=[test.a <= Int64(1)]
2547 Projection: test2.a
2548 TableScan: test2, full_filters=[test2.a <= Int64(1)]
2549 "
2550 )
2551 }
2552
2553 #[test]
2555 fn filter_join_on_common_dependent() -> Result<()> {
2556 let table_scan = test_table_scan()?;
2557 let left = LogicalPlanBuilder::from(table_scan)
2558 .project(vec![col("a"), col("c")])?
2559 .build()?;
2560 let right_table_scan = test_table_scan_with_name("test2")?;
2561 let right = LogicalPlanBuilder::from(right_table_scan)
2562 .project(vec![col("a"), col("b")])?
2563 .build()?;
2564 let plan = LogicalPlanBuilder::from(left)
2565 .join(
2566 right,
2567 JoinType::Inner,
2568 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2569 None,
2570 )?
2571 .filter(col("c").lt_eq(col("b")))?
2572 .build()?;
2573
2574 assert_snapshot!(plan,
2576 @r"
2577 Filter: test.c <= test2.b
2578 Inner Join: test.a = test2.a
2579 Projection: test.a, test.c
2580 TableScan: test
2581 Projection: test2.a, test2.b
2582 TableScan: test2
2583 ",
2584 );
2585 assert_optimized_plan_equal!(
2587 plan,
2588 @r"
2589 Inner Join: test.a = test2.a Filter: test.c <= test2.b
2590 Projection: test.a, test.c
2591 TableScan: test
2592 Projection: test2.a, test2.b
2593 TableScan: test2
2594 "
2595 )
2596 }
2597
2598 #[test]
2600 fn filter_join_on_one_side() -> Result<()> {
2601 let table_scan = test_table_scan()?;
2602 let left = LogicalPlanBuilder::from(table_scan)
2603 .project(vec![col("a"), col("b")])?
2604 .build()?;
2605 let table_scan_right = test_table_scan_with_name("test2")?;
2606 let right = LogicalPlanBuilder::from(table_scan_right)
2607 .project(vec![col("a"), col("c")])?
2608 .build()?;
2609
2610 let plan = LogicalPlanBuilder::from(left)
2611 .join(
2612 right,
2613 JoinType::Inner,
2614 (vec![Column::from_name("a")], vec![Column::from_name("a")]),
2615 None,
2616 )?
2617 .filter(col("b").lt_eq(lit(1i64)))?
2618 .build()?;
2619
2620 assert_snapshot!(plan,
2622 @r"
2623 Filter: test.b <= Int64(1)
2624 Inner Join: test.a = test2.a
2625 Projection: test.a, test.b
2626 TableScan: test
2627 Projection: test2.a, test2.c
2628 TableScan: test2
2629 ",
2630 );
2631 assert_optimized_plan_equal!(
2632 plan,
2633 @r"
2634 Inner Join: test.a = test2.a
2635 Projection: test.a, test.b
2636 TableScan: test, full_filters=[test.b <= Int64(1)]
2637 Projection: test2.a, test2.c
2638 TableScan: test2
2639 "
2640 )
2641 }
2642
2643 #[test]
2646 fn filter_using_left_join() -> Result<()> {
2647 let table_scan = test_table_scan()?;
2648 let left = LogicalPlanBuilder::from(table_scan).build()?;
2649 let right_table_scan = test_table_scan_with_name("test2")?;
2650 let right = LogicalPlanBuilder::from(right_table_scan)
2651 .project(vec![col("a")])?
2652 .build()?;
2653 let plan = LogicalPlanBuilder::from(left)
2654 .join_using(
2655 right,
2656 JoinType::Left,
2657 vec![Column::from_name("a".to_string())],
2658 )?
2659 .filter(col("test2.a").lt_eq(lit(1i64)))?
2660 .build()?;
2661
2662 assert_snapshot!(plan,
2664 @r"
2665 Filter: test2.a <= Int64(1)
2666 Left Join: Using test.a = test2.a
2667 TableScan: test
2668 Projection: test2.a
2669 TableScan: test2
2670 ",
2671 );
2672 assert_optimized_plan_equal!(
2674 plan,
2675 @r"
2676 Filter: test2.a <= Int64(1)
2677 Left Join: Using test.a = test2.a
2678 TableScan: test, full_filters=[test.a <= Int64(1)]
2679 Projection: test2.a
2680 TableScan: test2
2681 "
2682 )
2683 }
2684
    #[test]
    /// Post-join filter on a left-side column of a `RIGHT JOIN ... USING`.
    ///
    /// The expected plan keeps the `Filter` on `test.a` above the join
    /// (consistent with `lr_is_preserved(JoinType::Right)` returning
    /// `(false, true)`), while the equivalent predicate on the right-side join
    /// key, `test2.a <= 1`, is pushed into the `test2` scan.
    fn filter_using_right_join() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan).build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join_using(
                right,
                JoinType::Right,
                vec![Column::from_name("a".to_string())],
            )?
            .filter(col("test.a").lt_eq(lit(1i64)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test.a <= Int64(1)
          Right Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        ",
        );
        // The filter stays above the join; only the inferred right-side
        // predicate reaches the `test2` scan.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test.a <= Int64(1)
          Right Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
        )
    }
2725
    #[test]
    /// Post-join filter on the unqualified `USING` column of a LEFT JOIN.
    ///
    /// `col("a")` resolves to the left-side `test.a` (see the input plan). The
    /// expected plan moves the whole predicate into the `test` scan and removes
    /// the `Filter` node. It does not duplicate the predicate to the right
    /// side.
    fn filter_using_left_join_on_common() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan).build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join_using(
                right,
                JoinType::Left,
                vec![Column::from_name("a".to_string())],
            )?
            .filter(col("a").lt_eq(lit(1i64)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test.a <= Int64(1)
          Left Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        ",
        );
        // The filter is pushed to the left scan only.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: Using test.a = test2.a
          TableScan: test, full_filters=[test.a <= Int64(1)]
          Projection: test2.a
            TableScan: test2
        "
        )
    }
2766
    #[test]
    /// Post-join filter on the right-side `USING` column (`test2.a`) of a
    /// RIGHT JOIN.
    ///
    /// The expected plan moves the whole predicate into the `test2` scan and
    /// removes the `Filter` node. It does not duplicate the predicate to the
    /// left side.
    fn filter_using_right_join_on_common() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan).build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join_using(
                right,
                JoinType::Right,
                vec![Column::from_name("a".to_string())],
            )?
            .filter(col("test2.a").lt_eq(lit(1i64)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test2.a <= Int64(1)
          Right Join: Using test.a = test2.a
            TableScan: test
            Projection: test2.a
              TableScan: test2
        ",
        );
        // The filter is pushed to the right scan only.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Right Join: Using test.a = test2.a
          TableScan: test
          Projection: test2.a
            TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
        )
    }
2807
    #[test]
    /// Inner join whose `ON` filter mixes single-side and cross-side conjuncts.
    ///
    /// The single-side conjuncts (`test.c > 1`, `test2.c > 4`) are pushed into
    /// their respective scans. The cross-side `test.b < test2.b` remains in the
    /// join filter.
    fn join_on_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = col("test.c")
            .gt(lit(1u32))
            .and(col("test.b").lt(col("test2.b")))
            .and(col("test2.c").gt(lit(4u32)));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Inner,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Inner Join: test.a = test2.a Filter: test.b < test2.b
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.c > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.c > UInt32(4)]
        "
        )
    }
2853
    #[test]
    /// Inner join whose `ON` filter contains only single-side conjuncts.
    ///
    /// Every conjunct is pushed into its scan, so the expected plan has no join
    /// filter left at all.
    fn join_filter_removed() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = col("test.b")
            .gt(lit(1u32))
            .and(col("test2.c").gt(lit(4u32)));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Inner,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Inner Join: test.a = test2.a
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.b > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.c > UInt32(4)]
        "
        )
    }
2898
    #[test]
    /// Inner-join `ON` filter on the left join key (`test.a > 1`) with the
    /// equijoin `test.a = test2.b`.
    ///
    /// The expected plan pushes the predicate to the left scan and also
    /// transfers it through the equality to the right key as `test2.b > 1`.
    fn join_filter_on_common() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("b")])?
            .build()?;
        let filter = col("test.a").gt(lit(1u32));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Inner,
                (vec![Column::from_name("a")], vec![Column::from_name("b")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Inner Join: test.a = test2.b Filter: test.a > UInt32(1)
          Projection: test.a
            TableScan: test
          Projection: test2.b
            TableScan: test2
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Inner Join: test.a = test2.b
          Projection: test.a
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(1)]
        "
        )
    }
2941
    #[test]
    /// LEFT JOIN with an `ON` filter containing left-only, cross-side and
    /// right-only conjuncts.
    ///
    /// Only the right-only conjunct `test2.c > 4` is pushed into the `test2`
    /// scan (consistent with `on_lr_is_preserved(JoinType::Left)` returning
    /// `(false, true)`). Filtering the left input would drop rows that a LEFT
    /// JOIN must still emit with NULL padding, so `test.a > 1` and the
    /// cross-side conjunct stay in the join filter.
    fn left_join_on_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = col("test.a")
            .gt(lit(1u32))
            .and(col("test.b").lt(col("test2.b")))
            .and(col("test2.c").gt(lit(4u32)));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Left,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2, full_filters=[test2.c > UInt32(4)]
        "
        )
    }
2987
    #[test]
    /// RIGHT JOIN with an `ON` filter containing left-only, cross-side and
    /// right-only conjuncts.
    ///
    /// Only the left-only conjunct `test.a > 1` is pushed into the `test` scan
    /// (consistent with `on_lr_is_preserved(JoinType::Right)` returning
    /// `(true, false)`). The cross-side and right-only conjuncts stay in the
    /// join filter.
    fn right_join_on_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = col("test.a")
            .gt(lit(1u32))
            .and(col("test.b").lt(col("test2.b")))
            .and(col("test2.c").gt(lit(4u32)));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Right,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
        )
    }
3033
    #[test]
    /// FULL JOIN with an `ON` filter: no conjunct may be pushed to either input.
    ///
    /// This is consistent with `on_lr_is_preserved(JoinType::Full)` returning
    /// `(false, false)`. The optimized plan is identical to the input plan.
    fn full_join_on_with_filter() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let filter = col("test.a")
            .gt(lit(1u32))
            .and(col("test.b").lt(col("test2.b")))
            .and(col("test2.c").gt(lit(4u32)));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Full,
                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
                Some(filter),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        ",
        );
        // Unchanged: nothing is pushed down.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)
          Projection: test.a, test.b, test.c
            TableScan: test
          Projection: test2.a, test2.b, test2.c
            TableScan: test2
        "
        )
    }
3079
3080 struct PushDownProvider {
3081 pub filter_support: TableProviderFilterPushDown,
3082 }
3083
3084 #[async_trait]
3085 impl TableSource for PushDownProvider {
3086 fn schema(&self) -> SchemaRef {
3087 Arc::new(Schema::new(vec![
3088 Field::new("a", DataType::Int32, true),
3089 Field::new("b", DataType::Int32, true),
3090 ]))
3091 }
3092
3093 fn table_type(&self) -> TableType {
3094 TableType::Base
3095 }
3096
3097 fn supports_filters_pushdown(
3098 &self,
3099 filters: &[&Expr],
3100 ) -> Result<Vec<TableProviderFilterPushDown>> {
3101 Ok((0..filters.len())
3102 .map(|_| self.filter_support.clone())
3103 .collect())
3104 }
3105
3106 fn as_any(&self) -> &dyn Any {
3107 self
3108 }
3109 }
3110
3111 fn table_scan_with_pushdown_provider_builder(
3112 filter_support: TableProviderFilterPushDown,
3113 filters: Vec<Expr>,
3114 projection: Option<Vec<usize>>,
3115 ) -> Result<LogicalPlanBuilder> {
3116 let test_provider = PushDownProvider { filter_support };
3117
3118 let table_scan = LogicalPlan::TableScan(TableScan {
3119 table_name: "test".into(),
3120 filters,
3121 projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?),
3122 projection,
3123 source: Arc::new(test_provider),
3124 fetch: None,
3125 });
3126
3127 Ok(LogicalPlanBuilder::from(table_scan))
3128 }
3129
3130 fn table_scan_with_pushdown_provider(
3131 filter_support: TableProviderFilterPushDown,
3132 ) -> Result<LogicalPlan> {
3133 table_scan_with_pushdown_provider_builder(filter_support, vec![], None)?
3134 .filter(col("a").eq(lit(1i64)))?
3135 .build()
3136 }
3137
    #[test]
    /// With `Exact` support, the expected plan drops the `Filter` node and
    /// records the predicate in the scan's `full_filters`.
    fn filter_with_table_provider_exact() -> Result<()> {
        let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;

        assert_optimized_plan_equal!(
            plan,
            @"TableScan: test, full_filters=[a = Int64(1)]"
        )
    }
3147
    #[test]
    /// With `Inexact` support, the predicate is handed to the scan as
    /// `partial_filters`. The `Filter` node is kept above it, because an
    /// inexact scan may still return non-matching rows.
    fn filter_with_table_provider_inexact() -> Result<()> {
        let plan =
            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: a = Int64(1)
          TableScan: test, partial_filters=[a = Int64(1)]
        "
        )
    }
3161
    #[test]
    /// Pushdown into an `Inexact` provider must be stable across repeated
    /// invocations.
    ///
    /// The rule is run once by hand before the result goes to
    /// `assert_optimized_plan_equal!`. NOTE(review): that macro is defined
    /// outside this view and presumably applies the optimizer again; confirm.
    /// The expected plan still lists the predicate exactly once in
    /// `partial_filters`, i.e. it is not duplicated.
    fn filter_with_table_provider_multiple_invocations() -> Result<()> {
        let plan =
            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

        // First invocation of the rule.
        let optimized_plan = PushDownFilter::new()
            .rewrite(plan, &OptimizerContext::new())
            .expect("failed to optimize plan")
            .data;

        assert_optimized_plan_equal!(
            // Feed the already-optimized plan back through the assertion.
            optimized_plan,
            @r"
        Filter: a = Int64(1)
          TableScan: test, partial_filters=[a = Int64(1)]
        "
        )
    }
3182
    #[test]
    /// With `Unsupported` support, nothing is pushed into the scan and the
    /// `Filter` node remains.
    fn filter_with_table_provider_unsupported() -> Result<()> {
        let plan =
            table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: a = Int64(1)
          TableScan: test
        "
        )
    }
3196
    #[test]
    /// The scan already carries `a = 10` and `b > 11` as filters, and the same
    /// conjunction appears in a `Filter` above it.
    ///
    /// With `Inexact` support, the expected plan keeps the `Filter` and lists
    /// each predicate exactly once in `partial_filters`, i.e. no duplicates.
    fn multi_combined_filter() -> Result<()> {
        let plan = table_scan_with_pushdown_provider_builder(
            TableProviderFilterPushDown::Inexact,
            vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
            Some(vec![0]),
        )?
        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
        .project(vec![col("a"), col("b")])?
        .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: a, b
          Filter: a = Int64(10) AND b > Int64(11)
            TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
        "
        )
    }
3217
    #[test]
    /// With `Exact` support, both conjuncts of the `Filter` move into the
    /// scan's `full_filters` and the `Filter` node disappears. This holds even
    /// though the scan projects only column `a` (`projection=[a]`).
    fn multi_combined_filter_exact() -> Result<()> {
        let plan = table_scan_with_pushdown_provider_builder(
            TableProviderFilterPushDown::Exact,
            vec![],
            Some(vec![0]),
        )?
        .filter(and(col("a").eq(lit(10i64)), col("b").gt(lit(11i64))))?
        .project(vec![col("a"), col("b")])?
        .build()?;

        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: a, b
          TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
        "
        )
    }
3237
    #[test]
    /// A filter on an alias is pushed through the aliasing projection.
    ///
    /// The alias `b` (for `test.a`) is rewritten to the underlying column. Both
    /// conjuncts end up as `full_filters` on `test`, and the `Filter` node is
    /// removed.
    fn test_filter_with_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        // Plan: SELECT a AS b, c FROM test WHERE b > 10 AND c > 10
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b > Int64(10) AND test.c > Int64(10)
          Projection: test.a AS b, test.c
            TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
        )
    }
3266
    #[test]
    /// Same as the single-projection alias case, but the alias `b` is
    /// re-projected by a second `Projection`.
    ///
    /// The filter is pushed through both projections and lands on `test` in
    /// terms of `test.a`.
    fn test_filter_with_alias_2() -> Result<()> {
        let table_scan = test_table_scan()?;
        // Plan: SELECT b, c FROM (SELECT a AS b, c FROM test) WHERE b > 10 AND c > 10
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .project(vec![col("b"), col("c")])?
            .filter(and(col("b").gt(lit(10i64)), col("c").gt(lit(10i64))))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b > Int64(10) AND test.c > Int64(10)
          Projection: b, test.c
            Projection: test.a AS b, test.c
              TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: b, test.c
          Projection: test.a AS b, test.c
            TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
        )
    }
3298
    #[test]
    /// Both filtered columns are aliases (`b` for `test.a`, `d` for `test.c`).
    /// Each is rewritten to its source column and pushed into the scan.
    fn test_filter_with_multi_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c").alias("d")])?
            .filter(and(col("b").gt(lit(10i64)), col("d").gt(lit(10i64))))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b > Int64(10) AND d > Int64(10)
          Projection: test.a AS b, test.c AS d
            TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS b, test.c AS d
          TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]
        "
        )
    }
3324
    #[test]
    /// Inner-join `ON` filter on an aliased join key `c` (= `test.a`).
    ///
    /// The expected plan pushes it into the left input as `test.a > 1`.
    /// Through the equijoin `c = d` (`d` = `test2.b`), it also reaches the
    /// right input as `test2.b > 1`.
    fn join_filter_with_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("b").alias("d")])?
            .build()?;
        let filter = col("c").gt(lit(1u32));
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::Inner,
                (vec![Column::from_name("c")], vec![Column::from_name("d")]),
                Some(filter),
            )?
            .build()?;

        assert_snapshot!(plan,
        @r"
        Inner Join: c = d Filter: c > UInt32(1)
          Projection: test.a AS c
            TableScan: test
          Projection: test2.b AS d
            TableScan: test2
        ",
        );
        // The join filter is pushed to both sides, rewritten past the aliases.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Inner Join: c = d
          Projection: test.a AS c
            TableScan: test, full_filters=[test.a > UInt32(1)]
          Projection: test2.b AS d
            TableScan: test2, full_filters=[test2.b > UInt32(1)]
        "
        )
    }
3367
    #[test]
    /// An `IN (list)` predicate on alias `b` is rewritten to `test.a` and
    /// pushed into the scan.
    fn test_in_filter_with_alias() -> Result<()> {
        let table_scan = test_table_scan()?;
        // Plan: SELECT a AS b, c FROM test WHERE b IN (1, 2, 3, 4)
        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .filter(in_list(col("b"), filter_value, false))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
          Projection: test.a AS b, test.c
            TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
        "
        )
    }
3397
    #[test]
    /// An `IN (list)` predicate on alias `b`, re-projected by a second
    /// `Projection`. It is pushed through both projections into the scan.
    fn test_in_filter_with_alias_2() -> Result<()> {
        let table_scan = test_table_scan()?;
        // Plan: SELECT b, c FROM (SELECT a AS b, c FROM test) WHERE b IN (1, 2, 3, 4)
        let filter_value = vec![lit(1u32), lit(2u32), lit(3u32), lit(4u32)];
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .project(vec![col("b"), col("c")])?
            .filter(in_list(col("b"), filter_value, false))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])
          Projection: b, test.c
            Projection: test.a AS b, test.c
              TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: b, test.c
          Projection: test.a AS b, test.c
            TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]
        "
        )
    }
3430
    #[test]
    /// An `IN (<subquery>)` predicate on alias `b` is rewritten to `test.a`
    /// and pushed into the `test` scan. The subquery plan travels with the
    /// pushed-down filter.
    fn test_in_subquery_with_alias() -> Result<()> {
        // Plan: SELECT a AS b, c FROM test WHERE b IN (SELECT c FROM sq)
        let table_scan = test_table_scan()?;
        let table_scan_sq = test_table_scan_with_name("sq")?;
        let subplan = Arc::new(
            LogicalPlanBuilder::from(table_scan_sq)
                .project(vec![col("c")])?
                .build()?,
        );
        let plan = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a").alias("b"), col("c")])?
            .filter(in_subquery(col("b"), subplan))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: b IN (<subquery>)
          Subquery:
            Projection: sq.c
              TableScan: sq
          Projection: test.a AS b, test.c
            TableScan: test
        ",
        );
        assert_optimized_plan_equal!(
            plan,
            @r"
        Projection: test.a AS b, test.c
          TableScan: test, full_filters=[test.a IN (<subquery>)]
            Subquery:
              Projection: sq.c
                TableScan: sq
        "
        )
    }
3470
    #[test]
    /// A filter above nested `SubqueryAlias`/`Projection` layers is pushed all
    /// the way down to the `EmptyRelation`.
    ///
    /// On the way, `b.a` is replaced by the projected literal, so the expected
    /// plan ends with `Filter: Int64(0) = Int64(1)` directly above the relation.
    fn test_propagation_of_optimized_inner_filters_with_projections() -> Result<()> {
        let plan = LogicalPlanBuilder::empty(true)
            // Innermost layer: a single row with the literal 0 named `a`.
            .project(vec![lit(0i64).alias("a")])?
            .alias("b")?
            .project(vec![col("b.a")])?
            .alias("b")?
            .filter(col("b.a").eq(lit(1i64)))?
            .project(vec![col("b.a")])?
            .build()?;

        assert_snapshot!(plan,
        @r"
        Projection: b.a
          Filter: b.a = Int64(1)
            SubqueryAlias: b
              Projection: b.a
                SubqueryAlias: b
                  Projection: Int64(0) AS a
                    EmptyRelation: rows=1
        ",
        );
        assert_optimized_plan_equal!(
            // The filter now sits below every projection and alias, with the
            // column reference substituted by the projected literal.
            plan,
            @r"
        Projection: b.a
          SubqueryAlias: b
            Projection: b.a
              SubqueryAlias: b
                Projection: Int64(0) AS a
                  Filter: Int64(0) = Int64(1)
                    EmptyRelation: rows=1
        "
        )
    }
3509
    #[test]
    /// Cross join followed by a filter that is an `OR` of two conjunctions.
    ///
    /// In the expected plan the cross join becomes an inner join carrying the
    /// whole predicate. The derived left-only disjunction
    /// `test.b > 1 OR test.c < 10` is pushed into the `test` scan. The second
    /// half runs `PushDownFilter` once by hand before the assertion macro and
    /// expects the same plan.
    fn test_crossjoin_with_or_clause() -> Result<()> {
        // select * from test,test1 where (test.a = test1.a and test.b > 1) or (test.b = test1.b and test.c < 10);
        let table_scan = test_table_scan()?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b"), col("c")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test1")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a").alias("d"), col("a").alias("e")])?
            .build()?;
        let filter = or(
            and(col("a").eq(col("d")), col("b").gt(lit(1u32))),
            and(col("b").eq(col("e")), col("c").lt(lit(10u32))),
        );
        let plan = LogicalPlanBuilder::from(left)
            .cross_join(right)?
            .filter(filter)?
            .build()?;

        assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r"
        Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
          Projection: test1.a AS d, test1.a AS e
            TableScan: test1
        ")?;

        // Apply the rule once by hand; the assertion below must still see the
        // same plan.
        let optimized_plan = PushDownFilter::new()
            .rewrite(plan, &OptimizerContext::new())
            .expect("failed to optimize plan")
            .data;
        assert_optimized_plan_equal!(
            optimized_plan,
            @r"
        Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)
          Projection: test.a, test.b, test.c
            TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]
          Projection: test1.a AS d, test1.a AS e
            TableScan: test1
        "
        )
    }
3555
    #[test]
    /// Post-join filter on a right-side column over a LEFT SEMI join.
    ///
    /// The expected plan keeps the `Filter` on `test2.a` above the join
    /// (consistent with `lr_is_preserved(JoinType::LeftSemi)` returning
    /// `(true, false)`). The predicate inferred through `test1.a = test2.a` is
    /// pushed into the `test1` scan.
    fn left_semi_join() -> Result<()> {
        let left = test_table_scan_with_name("test1")?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::LeftSemi,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                None,
            )?
            .filter(col("test2.a").lt_eq(lit(1i64)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test2.a <= Int64(1)
          LeftSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        ",
        );
        // Only the inferred left-side predicate is pushed down.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test2.a <= Int64(1)
          LeftSemi Join: test1.a = test2.a
            TableScan: test1, full_filters=[test1.a <= Int64(1)]
            Projection: test2.a, test2.b
              TableScan: test2
        "
        )
    }
3598
    #[test]
    /// LEFT SEMI join with single-side `ON` conjuncts on both inputs.
    ///
    /// Both are pushed into their scans and the join filter disappears
    /// (consistent with `on_lr_is_preserved(JoinType::LeftSemi)` returning
    /// `(true, true)`).
    fn left_semi_join_with_filters() -> Result<()> {
        let left = test_table_scan_with_name("test1")?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::LeftSemi,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                Some(
                    col("test1.b")
                        .gt(lit(1u32))
                        .and(col("test2.b").gt(lit(2u32))),
                ),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        ",
        );
        // Both conjuncts are pushed down to their respective sides.
        assert_optimized_plan_equal!(
            plan,
            @r"
        LeftSemi Join: test1.a = test2.a
          TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
        )
    }
3642
    #[test]
    /// Post-join filter on a left-side column over a RIGHT SEMI join.
    ///
    /// The expected plan keeps the `Filter` on `test1.a` above the join
    /// (consistent with `lr_is_preserved(JoinType::RightSemi)` returning
    /// `(false, true)`). The predicate inferred through `test1.a = test2.a` is
    /// pushed into the `test2` scan.
    fn right_semi_join() -> Result<()> {
        let left = test_table_scan_with_name("test1")?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::RightSemi,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                None,
            )?
            .filter(col("test1.a").lt_eq(lit(1i64)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test1.a <= Int64(1)
          RightSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        ",
        );
        // Only the inferred right-side predicate is pushed down.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test1.a <= Int64(1)
          RightSemi Join: test1.a = test2.a
            TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2, full_filters=[test2.a <= Int64(1)]
        "
        )
    }
3685
    #[test]
    /// RIGHT SEMI join with single-side `ON` conjuncts on both inputs.
    ///
    /// Both are pushed into their scans and the join filter disappears
    /// (consistent with `on_lr_is_preserved(JoinType::RightSemi)` returning
    /// `(true, true)`).
    fn right_semi_join_with_filters() -> Result<()> {
        let left = test_table_scan_with_name("test1")?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::RightSemi,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                Some(
                    col("test1.b")
                        .gt(lit(1u32))
                        .and(col("test2.b").gt(lit(2u32))),
                ),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        ",
        );
        // Both conjuncts are pushed down to their respective sides.
        assert_optimized_plan_equal!(
            plan,
            @r"
        RightSemi Join: test1.a = test2.a
          TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
        )
    }
3729
    #[test]
    /// Post-join filter on a right-side column over a LEFT ANTI join.
    ///
    /// The expected plan keeps the `Filter` on `test2.a` above the join
    /// (consistent with `lr_is_preserved(JoinType::LeftAnti)` returning
    /// `(true, false)`). The predicate inferred through `test1.a = test2.a` is
    /// pushed into the `test1` scan.
    fn left_anti_join() -> Result<()> {
        let table_scan = test_table_scan_with_name("test1")?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::LeftAnti,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                None,
            )?
            .filter(col("test2.a").gt(lit(2u32)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test2.a > UInt32(2)
          LeftAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        ",
        );
        // Only the inferred left-side predicate is pushed down.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test2.a > UInt32(2)
          LeftAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1, full_filters=[test1.a > UInt32(2)]
            Projection: test2.a, test2.b
              TableScan: test2
        "
        )
    }
3777
    #[test]
    /// LEFT ANTI join with single-side `ON` conjuncts on both inputs.
    ///
    /// Only the right-side `test2.b > 2` is pushed into the `test2` scan
    /// (consistent with `on_lr_is_preserved(JoinType::LeftAnti)` returning
    /// `(false, true)`). The left-side `test1.b > 1` stays in the join filter.
    fn left_anti_join_with_filters() -> Result<()> {
        let table_scan = test_table_scan_with_name("test1")?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::LeftAnti,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                Some(
                    col("test1.b")
                        .gt(lit(1u32))
                        .and(col("test2.b").gt(lit(2u32))),
                ),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        ",
        );
        // The right-side conjunct is pushed down; the left-side one is kept.
        assert_optimized_plan_equal!(
            plan,
            @r"
        LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2, full_filters=[test2.b > UInt32(2)]
        "
        )
    }
3826
    #[test]
    /// Post-join filter on a left-side column over a RIGHT ANTI join.
    ///
    /// The expected plan keeps the `Filter` on `test1.a` above the join
    /// (consistent with `lr_is_preserved(JoinType::RightAnti)` returning
    /// `(false, true)`). The predicate inferred through `test1.a = test2.a` is
    /// pushed into the `test2` scan.
    fn right_anti_join() -> Result<()> {
        let table_scan = test_table_scan_with_name("test1")?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::RightAnti,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                None,
            )?
            .filter(col("test1.a").gt(lit(2u32)))?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        Filter: test1.a > UInt32(2)
          RightAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2
        ",
        );
        // Only the inferred right-side predicate is pushed down.
        assert_optimized_plan_equal!(
            plan,
            @r"
        Filter: test1.a > UInt32(2)
          RightAnti Join: test1.a = test2.a
            Projection: test1.a, test1.b
              TableScan: test1
            Projection: test2.a, test2.b
              TableScan: test2, full_filters=[test2.a > UInt32(2)]
        "
        )
    }
3874
    #[test]
    /// RIGHT ANTI join with single-side `ON` conjuncts on both inputs.
    ///
    /// Only the left-side `test1.b > 1` is pushed into the `test1` scan
    /// (consistent with `on_lr_is_preserved(JoinType::RightAnti)` returning
    /// `(true, false)`). The right-side `test2.b > 2` stays in the join filter.
    fn right_anti_join_with_filters() -> Result<()> {
        let table_scan = test_table_scan_with_name("test1")?;
        let left = LogicalPlanBuilder::from(table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let right_table_scan = test_table_scan_with_name("test2")?;
        let right = LogicalPlanBuilder::from(right_table_scan)
            .project(vec![col("a"), col("b")])?
            .build()?;
        let plan = LogicalPlanBuilder::from(left)
            .join(
                right,
                JoinType::RightAnti,
                (
                    vec![Column::from_qualified_name("test1.a")],
                    vec![Column::from_qualified_name("test2.a")],
                ),
                Some(
                    col("test1.b")
                        .gt(lit(1u32))
                        .and(col("test2.b").gt(lit(2u32))),
                ),
            )?
            .build()?;

        // Sanity check of the unoptimized input plan.
        assert_snapshot!(plan,
        @r"
        RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1
          Projection: test2.a, test2.b
            TableScan: test2
        ",
        );
        // The left-side conjunct is pushed down; the right-side one is kept.
        assert_optimized_plan_equal!(
            plan,
            @r"
        RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)
          Projection: test1.a, test1.b
            TableScan: test1, full_filters=[test1.b > UInt32(1)]
          Projection: test2.a, test2.b
            TableScan: test2
        "
        )
    }
3923
3924 #[derive(Debug, PartialEq, Eq, Hash)]
3925 struct TestScalarUDF {
3926 signature: Signature,
3927 }
3928
3929 impl ScalarUDFImpl for TestScalarUDF {
3930 fn as_any(&self) -> &dyn Any {
3931 self
3932 }
3933 fn name(&self) -> &str {
3934 "TestScalarUDF"
3935 }
3936
3937 fn signature(&self) -> &Signature {
3938 &self.signature
3939 }
3940
3941 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
3942 Ok(DataType::Int32)
3943 }
3944
3945 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
3946 Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
3947 }
3948 }
3949
#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
    let volatile_udf = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    }));
    let volatile_call =
        Expr::ScalarFunction(ScalarFunction::new_udf(volatile_udf, vec![]));

    // `t.a` is a plain group-by column, while `t.r` is computed from the
    // volatile call in the projection under the alias.
    let projection = vec![
        col("a"),
        sum(col("b")),
        add(volatile_call, lit(1)).alias("r"),
    ];
    let predicate = col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5)));

    let plan = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
        .aggregate(vec![col("a")], vec![sum(col("b"))])?
        .project(projection)?
        .alias("t")?
        .filter(predicate)?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.a > Int32(5) AND t.r > Float64(0.5)
        SubqueryAlias: t
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1
    ",
    );
    // The group-by predicate reaches the scan; the predicate on the volatile
    // column stays directly above the projection that evaluates it.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.5)
          Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r
            Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]
              TableScan: test1, full_filters=[test1.a > Int32(5)]
    "
    )
}
3989
#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
    let volatile_call = Expr::ScalarFunction(ScalarFunction::new_udf(
        Arc::new(ScalarUDF::new_from_impl(TestScalarUDF {
            signature: Signature::exact(vec![], Volatility::Volatile),
        })),
        vec![],
    ));

    let left = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?).build()?;
    let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?).build()?;
    let on_keys = (
        vec![Column::from_qualified_name("test1.a")],
        vec![Column::from_qualified_name("test2.a")],
    );

    // `t.r` is produced by the volatile call, so the filter on it must not be
    // pushed below the projection into the join.
    let plan = LogicalPlanBuilder::from(left)
        .join(right, JoinType::Inner, on_keys, None)?
        .project(vec![col("test1.a").alias("a"), volatile_call.alias("r")])?
        .alias("t")?
        .filter(col("t.r").gt(lit(0.8)))?
        .project(vec![col("t.a"), col("t.r")])?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Projection: t.a, t.r
      Filter: t.r > Float64(0.8)
        SubqueryAlias: t
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: t.a, t.r
      SubqueryAlias: t
        Filter: r > Float64(0.8)
          Projection: test1.a AS a, TestScalarUDF() AS r
            Inner Join: test1.a = test2.a
              TableScan: test1
              TableScan: test2
    "
    )
}
4041
#[test]
fn test_push_down_volatile_table_scan() -> Result<()> {
    let udf = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let volatile_predicate =
        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![]))
            .gt(lit(0.1));

    let plan = LogicalPlanBuilder::from(test_table_scan()?)
        .project(vec![col("a"), col("b")])?
        .filter(volatile_predicate)?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    // The filter may move beneath the projection, but the volatile predicate
    // is not handed to the scan as a pushed-down filter.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4071
/// Mixed volatile / non-volatile conjuncts over a scan that accepts filters:
/// only the non-volatile conjuncts may be pushed into the scan.
#[test]
fn test_push_down_volatile_mixed_table_scan() -> Result<()> {
    let table_scan = test_table_scan()?;
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    // Refer to the projected columns by name so they resolve against the
    // input schema (`test.a`, `test.b`); there is no relation `t` in this plan.
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a"), col("b")])?
        .filter(
            expr.gt(lit(0.1))
                .and(col("a").gt(lit(5)))
                .and(col("b").gt(lit(10))),
        )?
        .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND test.a > Int32(5) AND test.b > Int32(10)
      Projection: test.a, test.b
        TableScan: test
    ",
    );
    // The volatile conjunct stays in a Filter; the others become scan filters.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: test.a, test.b
      Filter: TestScalarUDF() > Float64(0.1)
        TableScan: test, full_filters=[test.a > Int32(5), test.b > Int32(10)]
    "
    )
}
4105
/// Mixed volatile / non-volatile conjuncts over a provider that rejects all
/// filter push-down: every conjunct must remain in a Filter above the scan.
#[test]
fn test_push_down_volatile_mixed_unsupported_table_scan() -> Result<()> {
    let fun = ScalarUDF::new_from_impl(TestScalarUDF {
        signature: Signature::exact(vec![], Volatility::Volatile),
    });
    let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![]));
    // This provider's schema is unqualified (`a`, `b`), so the predicate must
    // use unqualified column names; there is no relation `t` in this plan.
    let plan = table_scan_with_pushdown_provider_builder(
        TableProviderFilterPushDown::Unsupported,
        vec![],
        None,
    )?
    .project(vec![col("a"), col("b")])?
    .filter(
        expr.gt(lit(0.1))
            .and(col("a").gt(lit(5)))
            .and(col("b").gt(lit(10))),
    )?
    .build()?;

    assert_snapshot!(plan,
    @r"
    Filter: TestScalarUDF() > Float64(0.1) AND a > Int32(5) AND b > Int32(10)
      Projection: a, b
        TableScan: test
    ",
    );
    // Nothing reaches the scan; the rebuilt Filter lists the rejected
    // non-volatile conjuncts first, followed by the volatile one.
    assert_optimized_plan_equal!(
        plan,
        @r"
    Projection: a, b
      Filter: a > Int32(5) AND b > Int32(10) AND TestScalarUDF() > Float64(0.1)
        TableScan: test
    "
    )
}
4142
/// A filter over an input-less user-defined node has nowhere to go, so the
/// plan is expected to come out of the optimizer unchanged.
#[test]
fn test_push_down_filter_to_user_defined_node() -> Result<()> {
    /// Leaf extension node exposing a single non-nullable `Int64` column `a`.
    #[derive(Debug, Hash, Eq, PartialEq)]
    struct TestUserNode {
        schema: DFSchemaRef,
    }

    // Nodes have no natural ordering, but `PartialOrd` must agree with the
    // derived `PartialEq`: equal nodes compare as `Equal`, any other pair is
    // unordered.
    impl PartialOrd for TestUserNode {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            (self == other).then_some(Ordering::Equal)
        }
    }

    impl TestUserNode {
        fn new() -> Self {
            let schema = Arc::new(
                DFSchema::new_with_metadata(
                    vec![(None, Field::new("a", DataType::Int64, false).into())],
                    Default::default(),
                )
                .unwrap(),
            );

            Self { schema }
        }
    }

    impl UserDefinedLogicalNodeCore for TestUserNode {
        fn name(&self) -> &str {
            "test_node"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![]
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.schema
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
            write!(f, "TestUserNode")
        }

        fn with_exprs_and_inputs(
            &self,
            exprs: Vec<Expr>,
            inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            // The node has neither expressions nor inputs to replace.
            assert!(exprs.is_empty());
            assert!(inputs.is_empty());
            Ok(Self {
                schema: Arc::clone(&self.schema),
            })
        }
    }

    let node = LogicalPlan::Extension(Extension {
        node: Arc::new(TestUserNode::new()),
    });

    let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?;

    assert_snapshot!(plan,
    @r"
    Filter: Boolean(false)
      TestUserNode
    ",
    );
    assert_optimized_plan_equal!(
        plan,
        @r"
    Filter: Boolean(false)
      TestUserNode
    "
    )
}
4228}