datafusion_physical_optimizer/enforce_sorting/
sort_pushdown.rs1use std::fmt::Debug;
19use std::sync::Arc;
20
21use crate::utils::{
22 add_sort_above, is_sort, is_sort_preserving_merge, is_union, is_window,
23};
24
25use arrow::datatypes::SchemaRef;
26use datafusion_common::tree_node::{Transformed, TreeNode};
27use datafusion_common::{internal_err, HashSet, JoinSide, Result};
28use datafusion_expr::JoinType;
29use datafusion_physical_expr::expressions::Column;
30use datafusion_physical_expr::utils::collect_columns;
31use datafusion_physical_expr::{
32 add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt,
33};
34use datafusion_physical_expr_common::sort_expr::{
35 LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr,
36 PhysicalSortRequirement,
37};
38use datafusion_physical_plan::execution_plan::CardinalityEffect;
39use datafusion_physical_plan::filter::FilterExec;
40use datafusion_physical_plan::joins::utils::{
41 calculate_join_output_ordering, ColumnIndex,
42};
43use datafusion_physical_plan::joins::{HashJoinExec, SortMergeJoinExec};
44use datafusion_physical_plan::projection::ProjectionExec;
45use datafusion_physical_plan::repartition::RepartitionExec;
46use datafusion_physical_plan::sorts::sort::SortExec;
47use datafusion_physical_plan::tree_node::PlanContext;
48use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
49
50#[derive(Default, Clone, Debug)]
58pub struct ParentRequirements {
59 ordering_requirement: Option<OrderingRequirements>,
60 fetch: Option<usize>,
61}
62
63pub type SortPushDown = PlanContext<ParentRequirements>;
64
65pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) {
67 let reqs = sort_push_down.plan.required_input_ordering();
68 for (child, requirement) in sort_push_down.children.iter_mut().zip(reqs) {
69 child.data = ParentRequirements {
70 ordering_requirement: requirement,
71 fetch: child.plan.fetch(),
74 };
75 }
76}
77
78pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result<SortPushDown> {
80 sort_push_down
81 .transform_down(pushdown_sorts_helper)
82 .map(|transformed| transformed.data)
83}
84
/// Combines two optional row limits into the tighter one.
///
/// Returns the smaller value when both are present, the only present value
/// when exactly one is, and `None` when neither is set.
fn min_fetch(f1: Option<usize>, f2: Option<usize>) -> Option<usize> {
    // Each `Option` yields zero or one item, so the minimum over the chained
    // items is exactly "the smaller of whatever limits exist".
    f1.into_iter().chain(f2).min()
}
93
94fn pushdown_sorts_helper(
95 mut sort_push_down: SortPushDown,
96) -> Result<Transformed<SortPushDown>> {
97 let plan = sort_push_down.plan;
98 let parent_fetch = sort_push_down.data.fetch;
99
100 let Some(parent_requirement) = sort_push_down.data.ordering_requirement.clone()
101 else {
102 if is_sort(&plan) {
105 let Some(sort_ordering) = plan.output_ordering().cloned() else {
106 return internal_err!("SortExec should have output ordering");
107 };
108 let fetch = min_fetch(plan.fetch(), parent_fetch);
111 sort_push_down = sort_push_down
112 .children
113 .swap_remove(0)
114 .update_plan_from_children()?;
115 sort_push_down.data.fetch = fetch;
116 sort_push_down.data.ordering_requirement =
117 Some(OrderingRequirements::from(sort_ordering));
118 return pushdown_sorts_helper(sort_push_down);
121 }
122 sort_push_down.plan = plan;
123 return Ok(Transformed::no(sort_push_down));
124 };
125
126 let eqp = plan.equivalence_properties();
127 let satisfy_parent =
128 eqp.ordering_satisfy_requirement(parent_requirement.first().clone())?;
129
130 if is_sort(&plan) {
131 let Some(sort_ordering) = plan.output_ordering().cloned() else {
132 return internal_err!("SortExec should have output ordering");
133 };
134
135 let sort_fetch = plan.fetch();
136 let parent_is_stricter = eqp.requirements_compatible(
137 parent_requirement.first().clone(),
138 sort_ordering.clone().into(),
139 );
140
141 sort_push_down = sort_push_down
144 .children
145 .swap_remove(0)
146 .update_plan_from_children()?;
147 if !satisfy_parent && !parent_is_stricter {
148 sort_push_down = add_sort_above(
153 sort_push_down,
154 parent_requirement.into_single(),
155 parent_fetch,
156 );
157 sort_push_down.children[0].data = ParentRequirements {
159 ordering_requirement: Some(OrderingRequirements::from(sort_ordering)),
160 fetch: sort_fetch,
161 };
162 return Ok(Transformed::yes(sort_push_down));
163 } else {
164 sort_push_down.data.fetch = min_fetch(sort_fetch, parent_fetch);
167 let current_is_stricter = eqp.requirements_compatible(
168 sort_ordering.clone().into(),
169 parent_requirement.first().clone(),
170 );
171 sort_push_down.data.ordering_requirement = if current_is_stricter {
172 Some(OrderingRequirements::from(sort_ordering))
173 } else {
174 Some(parent_requirement)
175 };
176 return pushdown_sorts_helper(sort_push_down);
179 }
180 }
181
182 sort_push_down.plan = plan;
183 if satisfy_parent {
184 let reqs = sort_push_down.plan.required_input_ordering();
186
187 for (child, order) in sort_push_down.children.iter_mut().zip(reqs) {
188 child.data.ordering_requirement = order;
189 child.data.fetch = min_fetch(parent_fetch, child.data.fetch);
190 }
191 } else if let Some(adjusted) = pushdown_requirement_to_children(
192 &sort_push_down.plan,
193 parent_requirement.clone(),
194 parent_fetch,
195 )? {
196 let current_fetch = sort_push_down.plan.fetch();
199 for (child, order) in sort_push_down.children.iter_mut().zip(adjusted) {
200 child.data.ordering_requirement = order;
201 child.data.fetch = min_fetch(current_fetch, parent_fetch);
202 }
203 sort_push_down.data.ordering_requirement = None;
204 } else {
205 sort_push_down = add_sort_above(
207 sort_push_down,
208 parent_requirement.into_single(),
209 parent_fetch,
210 );
211 assign_initial_requirements(&mut sort_push_down);
212 }
213 Ok(Transformed::yes(sort_push_down))
214}
215
216fn pushdown_requirement_to_children(
219 plan: &Arc<dyn ExecutionPlan>,
220 parent_required: OrderingRequirements,
221 parent_fetch: Option<usize>,
222) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
223 if parent_fetch.is_some() && !plan.supports_limit_pushdown() {
227 return Ok(None);
228 }
229 if parent_fetch.is_some() {
248 match plan.cardinality_effect() {
249 CardinalityEffect::Equal => {
250 }
252 _ => return Ok(None),
253 }
254 }
255
256 let maintains_input_order = plan.maintains_input_order();
257 if is_window(plan) {
258 let mut required_input_ordering = plan.required_input_ordering();
259 let maybe_child_requirement = required_input_ordering.swap_remove(0);
260 let child_plan = plan.children().swap_remove(0);
261 let Some(child_req) = maybe_child_requirement else {
262 return Ok(None);
263 };
264 match determine_children_requirement(&parent_required, &child_req, child_plan) {
265 RequirementsCompatibility::Satisfy => Ok(Some(vec![Some(child_req)])),
266 RequirementsCompatibility::Compatible(adjusted) => {
267 if !plan
272 .equivalence_properties()
273 .ordering_satisfy_requirement(parent_required.into_single())?
274 {
275 return Ok(None);
276 }
277
278 Ok(Some(vec![adjusted]))
279 }
280 RequirementsCompatibility::NonCompatible => Ok(None),
281 }
282 } else if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
283 let Some(sort_ordering) = sort_exec.properties().output_ordering().cloned()
284 else {
285 return internal_err!("SortExec should have output ordering");
286 };
287 sort_exec
288 .properties()
289 .eq_properties
290 .requirements_compatible(
291 parent_required.first().clone(),
292 sort_ordering.into(),
293 )
294 .then(|| Ok(vec![Some(parent_required)]))
295 .transpose()
296 } else if plan.fetch().is_some()
297 && plan.supports_limit_pushdown()
298 && plan
299 .maintains_input_order()
300 .into_iter()
301 .all(|maintain| maintain)
302 {
303 let Some(ordering) = plan.properties().output_ordering() else {
307 return Ok(Some(vec![Some(parent_required)]));
308 };
309 if plan.properties().eq_properties.requirements_compatible(
310 parent_required.first().clone(),
311 ordering.clone().into(),
312 ) {
313 Ok(Some(vec![Some(parent_required)]))
314 } else {
315 Ok(None)
316 }
317 } else if is_union(plan) {
318 Ok(Some(vec![Some(parent_required); plan.children().len()]))
321 } else if let Some(smj) = plan.as_any().downcast_ref::<SortMergeJoinExec>() {
322 let left_columns_len = smj.left().schema().fields().len();
323 let parent_ordering: Vec<PhysicalSortExpr> = parent_required
324 .first()
325 .iter()
326 .cloned()
327 .map(Into::into)
328 .collect();
329 let eqp = smj.properties().equivalence_properties();
330 match expr_source_side(eqp, parent_ordering, smj.join_type(), left_columns_len) {
331 Some((JoinSide::Left, ordering)) => try_pushdown_requirements_to_join(
332 smj,
333 parent_required.into_single(),
334 ordering,
335 JoinSide::Left,
336 ),
337 Some((JoinSide::Right, ordering)) => {
338 let right_offset =
339 smj.schema().fields.len() - smj.right().schema().fields.len();
340 let ordering = add_offset_to_physical_sort_exprs(
341 ordering,
342 -(right_offset as isize),
343 )?;
344 try_pushdown_requirements_to_join(
345 smj,
346 parent_required.into_single(),
347 ordering,
348 JoinSide::Right,
349 )
350 }
351 _ => {
352 Ok(None)
354 }
355 }
356 } else if maintains_input_order.is_empty()
357 || !maintains_input_order.iter().any(|o| *o)
358 || plan.as_any().is::<RepartitionExec>()
359 || plan.as_any().is::<FilterExec>()
360 || plan.as_any().is::<ProjectionExec>()
362 || pushdown_would_violate_requirements(&parent_required, plan.as_ref())
363 {
364 Ok(None)
368 } else if is_sort_preserving_merge(plan) {
369 let new_ordering = LexOrdering::from(parent_required.first().clone());
370 let mut spm_eqs = plan.equivalence_properties().clone();
371 let old_ordering = spm_eqs.output_ordering().unwrap();
372 let change = spm_eqs.reorder(new_ordering)?;
374 if !change || spm_eqs.ordering_satisfy(old_ordering)? {
375 Ok(Some(vec![Some(parent_required)]))
378 } else {
379 Ok(None)
382 }
383 } else if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
384 handle_hash_join(hash_join, parent_required)
385 } else {
386 handle_custom_pushdown(plan, parent_required, maintains_input_order)
387 }
388 }
390
391fn pushdown_would_violate_requirements(
394 parent_required: &OrderingRequirements,
395 child: &dyn ExecutionPlan,
396) -> bool {
397 child
398 .required_input_ordering()
399 .into_iter()
400 .flatten()
402 .any(|child_required| {
403 child_required
406 .into_single()
407 .iter()
408 .zip(parent_required.first().iter())
409 .all(|(c, p)| !c.compatible(p))
410 })
411}
412
413fn determine_children_requirement(
419 parent_required: &OrderingRequirements,
420 child_requirement: &OrderingRequirements,
421 child_plan: &Arc<dyn ExecutionPlan>,
422) -> RequirementsCompatibility {
423 let eqp = child_plan.equivalence_properties();
424 if eqp.requirements_compatible(
425 child_requirement.first().clone(),
426 parent_required.first().clone(),
427 ) {
428 RequirementsCompatibility::Satisfy
430 } else if eqp.requirements_compatible(
431 parent_required.first().clone(),
432 child_requirement.first().clone(),
433 ) {
434 RequirementsCompatibility::Compatible(Some(parent_required.clone()))
437 } else {
438 RequirementsCompatibility::NonCompatible
439 }
440}
441
442fn try_pushdown_requirements_to_join(
443 smj: &SortMergeJoinExec,
444 parent_required: LexRequirement,
445 sort_exprs: Vec<PhysicalSortExpr>,
446 push_side: JoinSide,
447) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
448 let mut smj_required_orderings = smj.required_input_ordering();
449
450 let ordering = LexOrdering::new(sort_exprs.clone());
451 let (new_left_ordering, new_right_ordering) = match push_side {
452 JoinSide::Left => {
453 let mut left_eq_properties = smj.left().equivalence_properties().clone();
454 left_eq_properties.reorder(sort_exprs)?;
455 let Some(left_requirement) = smj_required_orderings.swap_remove(0) else {
456 return Ok(None);
457 };
458 if !left_eq_properties
459 .ordering_satisfy_requirement(left_requirement.into_single())?
460 {
461 return Ok(None);
462 }
463 (ordering.as_ref(), smj.right().output_ordering())
465 }
466 JoinSide::Right => {
467 let mut right_eq_properties = smj.right().equivalence_properties().clone();
468 right_eq_properties.reorder(sort_exprs)?;
469 let Some(right_requirement) = smj_required_orderings.swap_remove(1) else {
470 return Ok(None);
471 };
472 if !right_eq_properties
473 .ordering_satisfy_requirement(right_requirement.into_single())?
474 {
475 return Ok(None);
476 }
477 (smj.left().output_ordering(), ordering.as_ref())
479 }
480 JoinSide::None => return Ok(None),
481 };
482 let join_type = smj.join_type();
483 let probe_side = SortMergeJoinExec::probe_side(&join_type);
484 let new_output_ordering = calculate_join_output_ordering(
485 new_left_ordering,
486 new_right_ordering,
487 join_type,
488 smj.left().schema().fields.len(),
489 &smj.maintains_input_order(),
490 Some(probe_side),
491 )?;
492 let mut smj_eqs = smj.properties().equivalence_properties().clone();
493 if let Some(new_output_ordering) = new_output_ordering {
494 smj_eqs.reorder(new_output_ordering)?;
496 }
497 let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required)?;
498 Ok(should_pushdown.then(|| {
499 let mut required_input_ordering = smj.required_input_ordering();
500 let new_req = ordering.map(Into::into);
501 match push_side {
502 JoinSide::Left => {
503 required_input_ordering[0] = new_req;
504 }
505 JoinSide::Right => {
506 required_input_ordering[1] = new_req;
507 }
508 JoinSide::None => unreachable!(),
509 }
510 required_input_ordering
511 }))
512}
513
514fn expr_source_side(
515 eqp: &EquivalenceProperties,
516 mut ordering: Vec<PhysicalSortExpr>,
517 join_type: JoinType,
518 left_columns_len: usize,
519) -> Option<(JoinSide, Vec<PhysicalSortExpr>)> {
520 match join_type {
523 JoinType::Inner
524 | JoinType::Left
525 | JoinType::Right
526 | JoinType::Full
527 | JoinType::LeftMark
528 | JoinType::RightMark => {
529 let eq_group = eqp.eq_group();
530 let mut right_ordering = ordering.clone();
531 let (mut valid_left, mut valid_right) = (true, true);
532 for (left, right) in ordering.iter_mut().zip(right_ordering.iter_mut()) {
533 let col = left.expr.as_any().downcast_ref::<Column>()?;
534 let eq_class = eq_group.get_equivalence_class(&left.expr);
535 if col.index() < left_columns_len {
536 if valid_right {
537 valid_right = eq_class.is_some_and(|cls| {
538 for expr in cls.iter() {
539 if expr
540 .as_any()
541 .downcast_ref::<Column>()
542 .is_some_and(|c| c.index() >= left_columns_len)
543 {
544 right.expr = Arc::clone(expr);
545 return true;
546 }
547 }
548 false
549 });
550 }
551 } else if valid_left {
552 valid_left = eq_class.is_some_and(|cls| {
553 for expr in cls.iter() {
554 if expr
555 .as_any()
556 .downcast_ref::<Column>()
557 .is_some_and(|c| c.index() < left_columns_len)
558 {
559 left.expr = Arc::clone(expr);
560 return true;
561 }
562 }
563 false
564 });
565 };
566 if !(valid_left || valid_right) {
567 return None;
568 }
569 }
570 if valid_left {
571 Some((JoinSide::Left, ordering))
572 } else if valid_right {
573 Some((JoinSide::Right, right_ordering))
574 } else {
575 None
577 }
578 }
579 JoinType::LeftSemi | JoinType::LeftAnti => ordering
580 .iter()
581 .all(|e| e.expr.as_any().is::<Column>())
582 .then_some((JoinSide::Left, ordering)),
583 JoinType::RightSemi | JoinType::RightAnti => ordering
584 .iter()
585 .all(|e| e.expr.as_any().is::<Column>())
586 .then_some((JoinSide::Right, ordering)),
587 }
588}
589
590fn handle_custom_pushdown(
605 plan: &Arc<dyn ExecutionPlan>,
606 parent_required: OrderingRequirements,
607 maintains_input_order: Vec<bool>,
608) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
609 if plan.children().is_empty() {
611 return Ok(None);
612 }
613
614 let requirement = parent_required.into_single();
617 let all_indices: HashSet<usize> = requirement
618 .iter()
619 .flat_map(|order| {
620 collect_columns(&order.expr)
621 .iter()
622 .map(|col| col.index())
623 .collect::<HashSet<_>>()
624 })
625 .collect();
626
627 let children_schema_lengths: Vec<usize> = plan
629 .children()
630 .iter()
631 .map(|c| c.schema().fields().len())
632 .collect();
633
634 let Some(maintained_child_idx) = maintains_input_order
636 .iter()
637 .enumerate()
638 .find(|(_, m)| **m)
639 .map(|pair| pair.0)
640 else {
641 return Ok(None);
642 };
643
644 let start_idx = children_schema_lengths[..maintained_child_idx]
646 .iter()
647 .sum::<usize>();
648 let end_idx = start_idx + children_schema_lengths[maintained_child_idx];
649 let all_from_maintained_child =
650 all_indices.iter().all(|i| i >= &start_idx && i < &end_idx);
651
652 if all_from_maintained_child {
654 let sub_offset = children_schema_lengths
655 .iter()
656 .take(maintained_child_idx)
657 .sum::<usize>();
658 let updated_parent_req = requirement
661 .into_iter()
662 .map(|req| {
663 let child_schema = plan.children()[maintained_child_idx].schema();
664 let updated_columns =
665 req.expr
666 .transform_up_with_lambdas_params(|expr, lambdas_params| {
667 match expr.as_any().downcast_ref::<Column>() {
668 Some(col) if !lambdas_params.contains(col.name()) => {
669 let new_index = col.index() - sub_offset;
670 Ok(Transformed::yes(Arc::new(Column::new(
671 child_schema.field(new_index).name(),
672 new_index,
673 ))))
674 }
675 _ => Ok(Transformed::no(expr)),
676 }
677 })?
678 .data;
679 Ok(PhysicalSortRequirement::new(updated_columns, req.options))
680 })
681 .collect::<Result<Vec<_>>>()?;
682
683 let result = maintains_input_order
685 .iter()
686 .map(|&maintains_order| {
687 if maintains_order {
688 LexRequirement::new(updated_parent_req.clone())
689 .map(OrderingRequirements::new)
690 } else {
691 None
692 }
693 })
694 .collect();
695
696 Ok(Some(result))
697 } else {
698 Ok(None)
699 }
700}
701
702fn handle_hash_join(
705 plan: &HashJoinExec,
706 parent_required: OrderingRequirements,
707) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
708 if !plan.maintains_input_order()[1] {
711 return Ok(None);
712 }
713
714 let requirement = parent_required.into_single();
716 let all_indices: HashSet<_> = requirement
717 .iter()
718 .flat_map(|order| {
719 collect_columns(&order.expr)
720 .into_iter()
721 .map(|col| col.index())
722 .collect::<HashSet<_>>()
723 })
724 .collect();
725
726 let column_indices = build_join_column_index(plan);
727 let projected_indices: Vec<_> = if let Some(projection) = &plan.projection {
728 projection.iter().map(|&i| &column_indices[i]).collect()
729 } else {
730 column_indices.iter().collect()
731 };
732 let len_of_left_fields = projected_indices
733 .iter()
734 .filter(|ci| ci.side == JoinSide::Left)
735 .count();
736
737 let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields);
738
739 if all_from_right_child {
741 let updated_parent_req = requirement
743 .into_iter()
744 .map(|req| {
745 let child_schema = plan.children()[1].schema();
746 let updated_columns =
747 req.expr
748 .transform_up_with_lambdas_params(|expr, lambdas_params| {
749 match expr.as_any().downcast_ref::<Column>() {
750 Some(col) if !lambdas_params.contains(col.name()) => {
751 let index = projected_indices[col.index()].index;
752 Ok(Transformed::yes(Arc::new(Column::new(
753 child_schema.field(index).name(),
754 index,
755 ))))
756 }
757 _ => Ok(Transformed::no(expr)),
758 }
759 })?
760 .data;
761 Ok(PhysicalSortRequirement::new(updated_columns, req.options))
762 })
763 .collect::<Result<Vec<_>>>()?;
764
765 Ok(Some(vec![
767 None,
768 LexRequirement::new(updated_parent_req).map(OrderingRequirements::new),
769 ]))
770 } else {
771 Ok(None)
772 }
773}
774
775fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> {
778 let map_fields = |schema: SchemaRef, side: JoinSide| {
779 schema
780 .fields()
781 .iter()
782 .enumerate()
783 .map(|(index, _)| ColumnIndex { index, side })
784 .collect::<Vec<_>>()
785 };
786
787 match plan.join_type() {
788 JoinType::Inner | JoinType::Right => {
789 map_fields(plan.left().schema(), JoinSide::Left)
790 .into_iter()
791 .chain(map_fields(plan.right().schema(), JoinSide::Right))
792 .collect::<Vec<_>>()
793 }
794 JoinType::RightSemi | JoinType::RightAnti => {
795 map_fields(plan.right().schema(), JoinSide::Right)
796 }
797 _ => unreachable!("unexpected join type: {}", plan.join_type()),
798 }
799}
800
801#[derive(Debug)]
803enum RequirementsCompatibility {
804 Satisfy,
806 Compatible(Option<OrderingRequirements>),
808 NonCompatible,
810}