datafusion_physical_optimizer/
output_requirements.rs1use std::sync::Arc;
26
27use crate::PhysicalOptimizerRule;
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
31use datafusion_common::{Result, Statistics};
32use datafusion_execution::TaskContext;
33use datafusion_physical_expr::Distribution;
34use datafusion_physical_expr_common::sort_expr::OrderingRequirements;
35use datafusion_physical_plan::execution_plan::Boundedness;
36use datafusion_physical_plan::projection::{
37 make_with_child, update_expr, update_ordering_requirement, ProjectionExec,
38};
39use datafusion_physical_plan::sorts::sort::SortExec;
40use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
41use datafusion_physical_plan::{
42 DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
43 SendableRecordBatchStream,
44};
45
46#[derive(Debug)]
56pub struct OutputRequirements {
57 mode: RuleMode,
58}
59
60impl OutputRequirements {
61 pub fn new_add_mode() -> Self {
66 Self {
67 mode: RuleMode::Add,
68 }
69 }
70
71 pub fn new_remove_mode() -> Self {
78 Self {
79 mode: RuleMode::Remove,
80 }
81 }
82}
83
/// Operating mode of the [`OutputRequirements`] rule.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum RuleMode {
    /// Insert requirement-tracking nodes into the plan.
    Add,
    /// Strip requirement-tracking nodes out of the plan.
    Remove,
}
89
90#[derive(Debug)]
97pub struct OutputRequirementExec {
98 input: Arc<dyn ExecutionPlan>,
99 order_requirement: Option<OrderingRequirements>,
100 dist_requirement: Distribution,
101 cache: PlanProperties,
102 fetch: Option<usize>,
103}
104
105impl OutputRequirementExec {
106 pub fn new(
107 input: Arc<dyn ExecutionPlan>,
108 requirements: Option<OrderingRequirements>,
109 dist_requirement: Distribution,
110 fetch: Option<usize>,
111 ) -> Self {
112 let cache = Self::compute_properties(&input, &fetch);
113 Self {
114 input,
115 order_requirement: requirements,
116 dist_requirement,
117 cache,
118 fetch,
119 }
120 }
121
122 pub fn input(&self) -> Arc<dyn ExecutionPlan> {
123 Arc::clone(&self.input)
124 }
125
126 fn compute_properties(
128 input: &Arc<dyn ExecutionPlan>,
129 fetch: &Option<usize>,
130 ) -> PlanProperties {
131 let boundedness = if fetch.is_some() {
132 Boundedness::Bounded
133 } else {
134 input.boundedness()
135 };
136
137 PlanProperties::new(
138 input.equivalence_properties().clone(), input.output_partitioning().clone(), input.pipeline_behavior(), boundedness, )
143 }
144
145 pub fn fetch(&self) -> Option<usize> {
147 self.fetch
148 }
149}
150
151impl DisplayAs for OutputRequirementExec {
152 fn fmt_as(
153 &self,
154 t: DisplayFormatType,
155 f: &mut std::fmt::Formatter,
156 ) -> std::fmt::Result {
157 match t {
158 DisplayFormatType::Default | DisplayFormatType::Verbose => {
159 let order_cols = self
160 .order_requirement
161 .as_ref()
162 .map(|reqs| reqs.first())
163 .map(|lex| {
164 let pairs: Vec<String> = lex
165 .iter()
166 .map(|req| {
167 let direction = req
168 .options
169 .as_ref()
170 .map(
171 |opt| if opt.descending { "desc" } else { "asc" },
172 )
173 .unwrap_or("unspecified");
174 format!("({}, {direction})", req.expr)
175 })
176 .collect();
177 format!("[{}]", pairs.join(", "))
178 })
179 .unwrap_or_else(|| "[]".to_string());
180
181 write!(
182 f,
183 "OutputRequirementExec: order_by={}, dist_by={}",
184 order_cols, self.dist_requirement
185 )
186 }
187 DisplayFormatType::TreeRender => {
188 write!(f, "")
189 }
190 }
191 }
192}
193
194impl ExecutionPlan for OutputRequirementExec {
195 fn name(&self) -> &'static str {
196 "OutputRequirementExec"
197 }
198
199 fn as_any(&self) -> &dyn std::any::Any {
200 self
201 }
202
203 fn properties(&self) -> &PlanProperties {
204 &self.cache
205 }
206
207 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
208 vec![false]
209 }
210
211 fn required_input_distribution(&self) -> Vec<Distribution> {
212 vec![self.dist_requirement.clone()]
213 }
214
215 fn maintains_input_order(&self) -> Vec<bool> {
216 vec![true]
217 }
218
219 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
220 vec![&self.input]
221 }
222
223 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
224 vec![self.order_requirement.clone()]
225 }
226
227 fn with_new_children(
228 self: Arc<Self>,
229 mut children: Vec<Arc<dyn ExecutionPlan>>,
230 ) -> Result<Arc<dyn ExecutionPlan>> {
231 Ok(Arc::new(Self::new(
232 children.remove(0), self.order_requirement.clone(),
234 self.dist_requirement.clone(),
235 self.fetch,
236 )))
237 }
238
239 fn execute(
240 &self,
241 _partition: usize,
242 _context: Arc<TaskContext>,
243 ) -> Result<SendableRecordBatchStream> {
244 unreachable!();
245 }
246
247 fn statistics(&self) -> Result<Statistics> {
248 self.input.partition_statistics(None)
249 }
250
251 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
252 self.input.partition_statistics(partition)
253 }
254
255 fn try_swapping_with_projection(
256 &self,
257 projection: &ProjectionExec,
258 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
259 let proj_exprs = projection.expr();
261 if proj_exprs.len() >= projection.input().schema().fields().len() {
262 return Ok(None);
263 }
264
265 let mut requirements = self.required_input_ordering().swap_remove(0);
266 if let Some(reqs) = requirements {
267 let mut updated_reqs = vec![];
268 let (lexes, soft) = reqs.into_alternatives();
269 for lex in lexes.into_iter() {
270 let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)?
271 else {
272 return Ok(None);
273 };
274 updated_reqs.push(updated_lex);
275 }
276 requirements = OrderingRequirements::new_alternatives(updated_reqs, soft);
277 }
278
279 let dist_req = match &self.required_input_distribution()[0] {
280 Distribution::HashPartitioned(exprs) => {
281 let mut updated_exprs = vec![];
282 for expr in exprs {
283 let Some(new_expr) = update_expr(expr, projection.expr(), false)?
284 else {
285 return Ok(None);
286 };
287 updated_exprs.push(new_expr);
288 }
289 Distribution::HashPartitioned(updated_exprs)
290 }
291 dist => dist.clone(),
292 };
293
294 make_with_child(projection, &self.input()).map(|input| {
295 let e = OutputRequirementExec::new(input, requirements, dist_req, self.fetch);
296 Some(Arc::new(e) as _)
297 })
298 }
299
300 fn fetch(&self) -> Option<usize> {
301 self.fetch
302 }
303}
304
305impl PhysicalOptimizerRule for OutputRequirements {
306 fn optimize(
307 &self,
308 plan: Arc<dyn ExecutionPlan>,
309 _config: &ConfigOptions,
310 ) -> Result<Arc<dyn ExecutionPlan>> {
311 match self.mode {
312 RuleMode::Add => require_top_ordering(plan),
313 RuleMode::Remove => plan
314 .transform_up(|plan| {
315 if let Some(sort_req) =
316 plan.as_any().downcast_ref::<OutputRequirementExec>()
317 {
318 Ok(Transformed::yes(sort_req.input()))
319 } else {
320 Ok(Transformed::no(plan))
321 }
322 })
323 .data(),
324 }
325 }
326
327 fn name(&self) -> &str {
328 "OutputRequirements"
329 }
330
331 fn schema_check(&self) -> bool {
332 true
333 }
334}
335
336fn require_top_ordering(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
339 let (new_plan, is_changed) = require_top_ordering_helper(plan)?;
340 if is_changed {
341 Ok(new_plan)
342 } else {
343 Ok(Arc::new(OutputRequirementExec::new(
345 new_plan,
346 None,
348 Distribution::UnspecifiedDistribution,
349 None,
350 )) as _)
351 }
352}
353
354fn require_top_ordering_helper(
358 plan: Arc<dyn ExecutionPlan>,
359) -> Result<(Arc<dyn ExecutionPlan>, bool)> {
360 let mut children = plan.children();
361 if children.len() != 1 {
363 Ok((plan, false))
364 } else if let Some(sort_exec) = plan.as_any().downcast_ref::<SortExec>() {
365 let req_dist = sort_exec.required_input_distribution().swap_remove(0);
369 let req_ordering = sort_exec.expr();
370 let reqs = OrderingRequirements::from(req_ordering.clone());
371 let fetch = sort_exec.fetch();
372
373 Ok((
374 Arc::new(OutputRequirementExec::new(
375 plan,
376 Some(reqs),
377 req_dist,
378 fetch,
379 )) as _,
380 true,
381 ))
382 } else if let Some(spm) = plan.as_any().downcast_ref::<SortPreservingMergeExec>() {
383 let reqs = OrderingRequirements::from(spm.expr().clone());
384 let fetch = spm.fetch();
385 Ok((
386 Arc::new(OutputRequirementExec::new(
387 plan,
388 Some(reqs),
389 Distribution::SinglePartition,
390 fetch,
391 )) as _,
392 true,
393 ))
394 } else if plan.maintains_input_order()[0]
395 && (plan.required_input_ordering()[0]
396 .as_ref()
397 .is_none_or(|o| matches!(o, OrderingRequirements::Soft(_))))
398 {
399 let (new_child, is_changed) =
404 require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?;
405 Ok((plan.with_new_children(vec![new_child])?, is_changed))
406 } else {
407 Ok((plan, false))
409 }
410}
411
412