datafusion_physical_plan/repartition/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! This file implements the [`RepartitionExec`] operator, which maps N input
//! partitions to M output partitions based on a partitioning scheme, optionally
//! maintaining the order of the input rows in the output.

use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::spill_manager::SpillManager;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::utils::transpose;
use datafusion_common::{internal_err, ColumnStatistics, HashMap};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use crate::filter_pushdown::{
    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
    FilterPushdownPropagation,
};
use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

/// A batch in the repartition queue - either in memory or spilled to disk
#[derive(Debug)]
enum RepartitionBatch {
    /// Batch held in memory (counts against memory reservation)
    Memory(RecordBatch),
    /// Batch spilled to disk (one file per batch for queue semantics)
    /// File automatically deleted when dropped via reference counting
    /// The size field stores the original batch size for validation when reading back
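    /// (A batch is spilled when the output partition's memory reservation
    /// cannot grow by the batch's size; see `pull_from_input`.)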
    Spilled {
        spill_file: RefCountedTempFile,
        size: usize,
    },
}

type MaybeBatch = Option<Result<RepartitionBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

/// Channels and resources for a single output partition
#[derive(Debug)]
struct PartitionChannels {
    /// Senders for each input partition to send data to this output partition
    tx: InputPartitionsToCurrentPartitionSender,
    /// Receivers for each input partition sending data to this output partition
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Memory reservation for this output partition
    reservation: SharedMemoryReservation,
    /// Spill manager for handling disk spills for this output partition
    spill_manager: Arc<SpillManager>,
}

#[derive(Debug)]
struct ConsumingInputStreamsState {
    /// Channels for sending batches from input partitions to output partitions.
    /// Key is the partition number.
    channels: HashMap<usize, PartitionChannels>,

    /// Helper that ensures that the background tasks are killed once they are no longer needed.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

/// Inner state of [`RepartitionExec`].
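///
/// The state advances monotonically:
/// `NotInitialized` -> `InputStreamsInitialized` -> `ConsumingInputStreams`.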
#[derive(Default)]
enum RepartitionExecState {
    /// Not initialized yet. This is the default state stored in the RepartitionExec node
    /// upon instantiation.
    #[default]
    NotInitialized,
    /// Input streams are initialized, but they are not yet being consumed. The node
    /// transitions to this state when the Arrow `RecordBatch` stream is created in
    /// `RepartitionExec::execute()`, but before any message is polled.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// The input streams are being consumed. The node transitions to this state when
    /// the first message in the Arrow `RecordBatch` stream is consumed.
    ConsumingInputStreams(ConsumingInputStreamsState),
}

impl Debug for RepartitionExecState {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
            RepartitionExecState::InputStreamsInitialized(v) => {
                write!(f, "InputStreamsInitialized({:?})", v.len())
            }
            RepartitionExecState::ConsumingInputStreams(v) => {
                write!(f, "ConsumingInputStreams({v:?})")
            }
        }
    }
}

impl RepartitionExecState {
    fn ensure_input_streams_initialized(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);

            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(&ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    fn consume_input_streams(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        partitioning: Partitioning,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning.partition_count(),
                    Arc::clone(&context),
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    // This cannot happen, as ensure_input_streams_initialized() was just called,
                    // but the compiler does not know.
                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // create one channel per *output* partition
            // note we use a custom channel that ensures there is always data for each receiver
            // but limits the amount of buffering if required.
            let (txs, rxs) = channels(num_output_partitions);
            // Clone the sender for each input partition
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };
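        // In both cases `txs` has one entry per *output* partition, each holding
        // one sender per *input* partition: e.g. with 2 inputs and 3 outputs,
        // `txs` has 3 entries of 2 senders each (clones of a single channel's
        // sender in the non-order-preserving case).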

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            ));
            let spill_metrics = SpillMetrics::new(&metrics, partition);
            let spill_manager = Arc::new(SpillManager::new(
                Arc::clone(&context.runtime_env()),
                spill_metrics,
                input.schema(),
            ));
            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_manager,
                },
            );
        }

        // launch one async task per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    (
                        *partition,
                        (
                            channels.tx[i].clone(),
                            Arc::clone(&channels.reservation),
                            Arc::clone(&channels.spill_manager),
                        ),
                    )
                })
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs.clone(),
                partitioning.clone(),
                metrics,
            ));

            // In a separate task, wait for each input to be done
            // (and pass along any errors, including panic!s)
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation, _spill_manager))| {
                        (partition, tx)
                    })
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}

/// A utility that can be used to partition batches based on [`Partitioning`]
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`]
    ///
    /// The time spent repartitioning will be recorded to `timer`
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Use a fixed seed for the random state so hashing is deterministic
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

    /// Partition the provided [`RecordBatch`] into one or more partitioned
    /// [`RecordBatch`]es based on the [`Partitioning`] specified on construction
    ///
    /// `f` will be called for each partitioned [`RecordBatch`] with the corresponding
    /// partition index. Any error returned by `f` will be immediately returned by this
    /// function without attempting to publish further [`RecordBatch`]es
    ///
    /// The time spent repartitioning, not including time spent in `f`, will be recorded
    /// to the [`metrics::Time`] provided on construction
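    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore` since it assumes a `batch` and a hash
    /// expression `expr` from surrounding code):
    ///
    /// ```ignore
    /// // Hash-partition each row of `batch` into one of 4 outputs
    /// let mut partitioner = BatchPartitioner::try_new(
    ///     Partitioning::Hash(vec![expr], 4),
    ///     metrics::Time::new(),
    /// )?;
    /// let mut outputs: Vec<Vec<RecordBatch>> = vec![vec![]; 4];
    /// partitioner.partition(batch, |partition, partitioned| {
    ///     outputs[partition].push(partitioned);
    ///     Ok(())
    /// })?;
    /// ```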
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

    /// Actual implementation of [`partition`](Self::partition).
    ///
    /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
    /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
    /// this (so we don't need to clone the entire implementation).
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    // Tracking time required for distributing indexes across output partitions
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

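                    // Route each row to partition `hash % num_partitions`,
                    // e.g. with 4 partitions a row hashing to 10 goes to
                    // partition 10 % 4 == 2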
                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    // Finished building index-arrays for output partitions
                    timer.done();

                    // Borrow the partitioner timer to avoid moving `self` into the closure
                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Tracking time required for repartitioned batches construction
                            let _timer = partitioner_timer.timer();

                            // Produce batches based on indices
                            let columns = take_arrays(batch.columns(), &indices, None)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

    // return the number of output partitions
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme.
///
/// # Background
///
/// DataFusion, like most other commercial systems (with the notable
/// exception of DuckDB), uses the "Exchange Operator" based approach to
/// parallelism, which works well in practice given sufficient care in
/// implementation.
///
/// DataFusion's planner picks the target number of partitions and
/// then [`RepartitionExec`] redistributes [`RecordBatch`]es to that number
/// of output partitions.
///
/// For example, given `target_partitions=3` (trying to use 3 cores)
/// but scanning an input with 2 partitions, `RepartitionExec` can be
/// used to get 3 even streams of `RecordBatch`es
///
///
///```text
///        ▲                  ▲                  ▲
///        │                  │                  │
///        │                  │                  │
///        │                  │                  │
/// ┌───────────────┐  ┌───────────────┐  ┌───────────────┐
/// │    GroupBy    │  │    GroupBy    │  │    GroupBy    │
/// │   (Partial)   │  │   (Partial)   │  │   (Partial)   │
/// └───────────────┘  └───────────────┘  └───────────────┘
///        ▲                  ▲                  ▲
///        └──────────────────┼──────────────────┘
///                           │
///              ┌─────────────────────────┐
///              │     RepartitionExec     │
///              │   (hash/round robin)    │
///              └─────────────────────────┘
///                         ▲   ▲
///             ┌───────────┘   └───────────┐
///             │                           │
///             │                           │
///        .─────────.                 .─────────.
///     ,─'           '─.           ,─'           '─.
///    ;      Input      :         ;      Input      :
///    :   Partition 0   ;         :   Partition 1   ;
///     ╲               ╱           ╲               ╱
///      '─.         ,─'             '─.         ,─'
///         `───────'                   `───────'
/// ```
///
/// # Error Handling
///
/// If any of the input partitions return an error, the error is propagated to
/// all output partitions and inputs are not polled again.
///
/// # Output Ordering
///
/// If more than one stream is being repartitioned, the output will be some
/// arbitrary interleaving (and thus unordered) unless
/// [`Self::with_preserve_order`] specifies otherwise.
///
/// # Footnote
///
/// The "Exchange Operator" was first described in the 1989 paper
/// [Encapsulation of parallelism in the Volcano query processing
/// system](https://dl.acm.org/doi/pdf/10.1145/93605.98720)
/// which uses the term "Exchange" for the concept of repartitioning
/// data across threads.
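///
/// # Example
///
/// A minimal construction sketch (marked `ignore` since `input` is assumed to
/// be an `Arc<dyn ExecutionPlan>` from surrounding code):
///
/// ```ignore
/// // Redistribute the input's partitions across 3 round-robin outputs
/// let repartition =
///     RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(3))?;
/// // Optionally preserve the input ordering; this is a no-op unless the
/// // input is ordered and has more than one partition
/// let repartition = repartition.with_preserve_order();
/// ```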
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls .execute() on this node
    /// and consumed as soon as the parent starts consuming this node.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to preserve ordering. If true, the operator behaves as a
    /// `SortPreservingRepartitionExec`; if false, as a plain `RepartitionExec`.
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
}

#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos for sending resulting batches to channels.
    ///
    /// One metric per output partition.
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Time in nanos to execute child operator and fetch batches
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        // Time in nanos to perform repartitioning
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        // Time in nanos for sending resulting batches to channels
        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Partitioning scheme to use
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Get the preserve_order flag of this RepartitionExec:
    /// `true` means `SortPreservingRepartitionExec`, `false` means `RepartitionExec`
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Get name used to display this Exec
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning())?;

                let input_partition_count =
                    self.input.output_partitioning().partition_count();
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;

                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Get existing ordering to use for merging
        let sort_exprs = self.sort_exprs().cloned();

        let state = Arc::clone(&self.state);
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                Arc::clone(&input),
                metrics.clone(),
                partitioning.partition_count(),
                Arc::clone(&context),
            )?;
        }

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            // lock scope
            let (mut rx, reservation, spill_manager, abort_helper) = {
                // lock mutexes
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning,
                    preserve_order,
                    name.clone(),
                    Arc::clone(&context),
                )?;

                // now return stream for the specified *output* partition which will
                // read from the channel
                let PartitionChannels {
                    rx,
                    reservation,
                    spill_manager,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (
                    rx,
                    reservation,
                    spill_manager,
                    Arc::clone(&state.abort_helper),
                )
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // Store streams from all the input partitions:
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                            spill_manager: Arc::clone(&spill_manager),
                            state: RepartitionStreamState::ReceivingFromChannel,
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();
                // Note that the receiver count (`rx.len()`) and `num_input_partitions` are the same.

                // Merge streams (while preserving ordering) coming from
                // input partitions to this partition:
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                    spill_manager,
                    state: RepartitionStreamState::ReceivingFromChannel,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            if partition >= partition_count {
                return internal_err!(
                    "RepartitionExec invalid partition {} (expected less than {})",
                    partition,
                    self.partitioning().partition_count()
                );
            }

            let mut stats = self.input.partition_statistics(None)?;

            // Distribute statistics across partitions
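            // e.g. an input with Exact(1000) rows split into 4 partitions is
            // reported as Inexact(250) rows per partition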
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            // Make all column stats unknown
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        // If pushdown is not beneficial or applicable, break it.
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = self.cache.clone();
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties,
        })))
    }
}

impl RepartitionExec {
    /// Create a new `RepartitionExec` that produces output `partitioning` and
    /// does not preserve the order of the input (see [`Self::with_preserve_order`]
    /// for more details)
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // We preserve ordering when this is the order-preserving variant or when the input has at most one partition
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there is more than one input partition, the partitions will be fused at the output.
        // Therefore, remove per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no-op,
    /// and the node remains a `RepartitionExec`.
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
                // If the input isn't ordered, there is no ordering to preserve
                self.input.output_ordering().is_some() &&
                // if there is only one input partition, merging is not required
                // to maintain order
                self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning
    ///
    /// txs hold the output sending channels for each output partition
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<
            usize,
            (
                DistributionSender<MaybeBatch>,
                SharedMemoryReservation,
                Arc<SpillManager>,
            ),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            // Handle empty batch
            if batch.num_rows() == 0 {
                continue;
            }

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation, spill_manager)) =
                    output_channels.get_mut(&partition)
                {
                    let (batch_to_send, is_memory_batch) =
                        match reservation.lock().try_grow(size) {
                            Ok(_) => {
                                // Memory available - send in-memory batch
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                // We're memory limited - spill this single batch to its own file
                                let spill_file = spill_manager
                                    .spill_record_batch_and_finish(
                                        &[batch],
                                        &format!(
                                            "RepartitionExec spill partition {partition}"
                                        ),
                                    )?
                                    // Note that we handled empty batches above, so this is safe
                                    .expect("non-empty batch should produce spill file");

                                // Store size for validation when reading back
                                (RepartitionBatch::Spilled { spill_file, size }, false)
                            }
                        };

                    if tx.send(Some(Ok(batch_to_send))).await.is_err() {
                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                        // Only shrink memory if it was a memory batch
                        if is_memory_batch {
                            reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // If the input stream is endless, we may spin forever and
            // never yield back to tokio.  See
            // https://github.com/apache/datafusion/issues/5278.
            //
            // However, yielding on every batch causes a bottleneck
            // when running with multiple cores. See
            // https://github.com/apache/datafusion/issues/6290
            //
            // Thus, heuristically yield after producing num_partition
            // batches
            //
            // In round robin this is ideal as each input will get a
            // new batch. In hash partitioning it may yield too often
            // on uneven distributions even if some partition cannot
            // make progress, but parallelism is going to be limited
            // in that case anyway
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

    /// Waits for `input_task` which is consuming one of the inputs to
    /// complete. Upon each successful completion, sends a `None` to
    /// each of the output tx channels to signal one of the inputs is
    /// complete. Upon error, propagates the errors to all output tx
    /// channels.
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate error
        // note we ignore errors on send (.ok) as that means the receiver has already shut down.

        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                // send the same Arc'd error to all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap it because we need to send the error to all output partitions
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

enum RepartitionStreamState {
    /// Waiting for next item from channel
    ReceivingFromChannel,
    /// Reading a spilled batch from disk (stream reads via tokio::fs)
    ReadingSpilledBatch(SendableRecordBatchStream),
}

struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,

    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    input: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,

    /// Spill manager for reading spilled batches
    spill_manager: Arc<SpillManager>,

    /// Current state of the stream
    state: RepartitionStreamState,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                RepartitionStreamState::ReceivingFromChannel => {
                    let value = futures::ready!(self.input.recv().poll_unpin(cx));
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Release memory and return
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled { spill_file, size }) => {
                                // Read from disk - SpillReaderStream uses tokio::fs internally
                                // Pass the original size for validation
                                let stream = self
                                    .spill_manager
                                    .read_spill_as_stream(spill_file, Some(size))?;
                                self.state =
                                    RepartitionStreamState::ReadingSpilledBatch(stream);
                                // Continue loop to poll the stream immediately
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            self.num_input_partitions_processed += 1;

                            if self.num_input_partitions
                                == self.num_input_partitions_processed
                            {
                                // all input partitions have finished sending batches
                                return Poll::Ready(None);
                            } else {
                                // other partitions still have data to send
                                continue;
                            }
                        }
                        None => {
                            return Poll::Ready(None);
                        }
                    }
                }
                RepartitionStreamState::ReadingSpilledBatch(stream) => {
                    match futures::ready!(stream.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            // Return batch and stay in ReadingSpilledBatch state to read more batches
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Some(Err(e)) => {
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // Spill stream ended - go back to receiving from channel
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            continue;
                        }
                    }
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

/// This struct converts a receiver into a stream.
/// The receiver receives data on an SPSC channel.
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// Channel containing the repartitioned batches
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed.
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation.
    reservation: SharedMemoryReservation,

    /// Spill manager for reading spilled batches
    spill_manager: Arc<SpillManager>,

    /// Current state of the stream
    state: RepartitionStreamState,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                RepartitionStreamState::ReceivingFromChannel => {
                    let value = futures::ready!(self.receiver.recv().poll_unpin(cx));
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Release memory and return
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled { spill_file, size }) => {
                                // Read from disk - SpillReaderStream uses tokio::fs internally
                                // Pass the original size for validation
                                let stream = self
                                    .spill_manager
                                    .read_spill_as_stream(spill_file, Some(size))?;
                                self.state =
                                    RepartitionStreamState::ReadingSpilledBatch(stream);
                                // Continue loop to poll the stream immediately
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // Input partition has finished sending batches
                            return Poll::Ready(None);
                        }
                        None => return Poll::Ready(None),
                    }
                }

                RepartitionStreamState::ReadingSpilledBatch(stream) => {
                    match futures::ready!(stream.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            // Return batch and stay in ReadingSpilledBatch state to read more batches
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Some(Err(e)) => {
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // Spill stream ended - go back to receiving from channel
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            continue;
                        }
                    }
                }
            }
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::{
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
        {collect, expressions::col},
    };

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::cast::as_string_array;
    use datafusion_common::exec_err;
    use datafusion_common::test_util::batches_to_sort_string;
    use datafusion_common_runtime::JoinSet;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;
    use itertools::Itertools;

    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition];

        // repartition from 1 input to 4 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 1 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 5 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;

        let total_rows: usize = output_partitions
            .iter()
            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
            .sum();

        assert_eq!(8, output_partitions.len());
        assert_eq!(total_rows, 8 * 50 * 3);

        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute and collect results
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            // execute this *output* partition and collect all batches
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }

    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
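        // RepartitionExec spawns background tasks to drive its input; this
        // test checks the pipeline also works when driven from inside another
        // spawned tokio task rather than directly from the runtime root.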
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                // define input partitions
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                // repartition from 3 input to 5 output
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        // have to send at least one batch through to provoke error
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        // This generates an error (partitioning type not supported)
        // but only after the plan is executed. The error should be
        // returned and no results produced
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn error_for_input_exec() {
        // This generates an error on a call to execute. The error
        // should be returned and no results produced.

        let task_ctx = Arc::new(TaskContext::default());
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Expect that an error is returned
        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();

        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // input stream returns one good batch and then one error. The
        // error should be returned.
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Note: this should pass (the stream can be created) but the
        // error when the input is executed should get passed back
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        // The mock exec doesn't return immediately (instead it
        // requires the input to wait at least once)
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }

    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // The barrier exec waits to be pinged
        // (it requires the input to wait at least once)
        let input = Arc::new(make_barrier_exec());

        // partition into two output streams
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);

        // Now, start sending input
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        // output stream 1 should *not* error and have one of the input batches
        let batches = crate::common::collect(output_stream1).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r#"
            +------------------+
            | my_awesome_field |
            +------------------+
            | baz              |
            | frob             |
            | gaz              |
            | grob             |
            +------------------+
            "#);
    }

    #[tokio::test]
    // As the hash results might be different on different platforms or
    // with different compilers, we will compare the same execution with
    // and without dropping the output stream.
    async fn hash_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // We first collect the results without dropping the output stream.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // run some checks on the result
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Now do the same but dropping the stream before waiting for the barrier
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
            batch
                .into_iter()
                .sorted_by_key(|b| format!("{b:?}"))
                .collect()
        }

        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
    }

    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    /// Create a BarrierExec that returns two partitions of two batches each
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();
        // The barrier exec waits to be pinged
        // (it requires the input to wait at least once)
        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
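        // The single "foo" row hashes to exactly one of the two output
        // partitions; the other partition should yield no batches at all
        // rather than an empty batch.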
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_spilling() -> Result<()> {
        // Test that repartition successfully spills to disk when memory is constrained
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Set up context with very tight memory limit to force spilling
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Collect all partitions - should succeed by spilling to disk
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // Verify we got all the data (50 batches * 8 rows each)
        assert_eq!(total_rows, 50 * 8);

        // Verify spilling metrics to confirm spilling actually happened
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spill_count > 0, but got {:?}",
            metrics.spill_count()
        );
        println!("Spilled {} times", metrics.spill_count().unwrap());
        assert!(
            metrics.spilled_bytes().unwrap() > 0,
            "Expected spilled_bytes > 0, but got {:?}",
            metrics.spilled_bytes()
        );
        println!(
            "Spilled {} bytes in {} spills",
            metrics.spilled_bytes().unwrap(),
            metrics.spill_count().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() > 0,
            "Expected spilled_rows > 0, but got {:?}",
            metrics.spilled_rows()
        );
        println!("Spilled {} rows", metrics.spilled_rows().unwrap());

        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_partial_spilling() -> Result<()> {
        // Test that repartition can handle partial spilling (some batches in memory, some spilled)
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Set up context with moderate memory limit to force partial spilling
        // 2KB should allow some batches in memory but force others to spill
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(2 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Collect all partitions - should succeed with partial spilling
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // Verify we got all the data (50 batches * 8 rows each)
        assert_eq!(total_rows, 50 * 8);

        // Verify partial spilling metrics
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();

        assert!(
            spill_count > 0,
            "Expected some spilling to occur, but got spill_count={spill_count}"
        );
        assert!(
            spilled_rows > 0 && spilled_rows < total_rows,
            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
        );
        assert!(
            spilled_bytes > 0,
            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
        );

        println!(
            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
            spilled_rows,
            total_rows,
            (spilled_rows as f64 / total_rows as f64) * 100.0,
            spill_count,
            spilled_bytes
        );

        Ok(())
    }

    #[tokio::test]
    async fn repartition_without_spilling() -> Result<()> {
        // Test that repartition does not spill when there's ample memory
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Set up context with generous memory limit - no spilling should occur
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(10 * 1024 * 1024, 1.0) // 10MB
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Collect all partitions - should succeed without spilling
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // Verify we got all the data (50 batches * 8 rows each)
        assert_eq!(total_rows, 50 * 8);

        // Verify no spilling occurred
        let metrics = exec.metrics().unwrap();
        assert_eq!(
            metrics.spill_count(),
            Some(0),
            "Expected no spilling, but got spill_count={:?}",
            metrics.spill_count()
        );
        assert_eq!(
            metrics.spilled_bytes(),
            Some(0),
            "Expected no bytes spilled, but got spilled_bytes={:?}",
            metrics.spilled_bytes()
        );
        assert_eq!(
            metrics.spilled_rows(),
            Some(0),
            "Expected no rows spilled, but got spilled_rows={:?}",
            metrics.spilled_rows()
        );

        println!("No spilling occurred - all data processed in memory");

        Ok(())
    }

    #[tokio::test]
    async fn oom() -> Result<()> {
        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};

        // Test that repartition fails with OOM when disk manager is disabled
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // Setup context with memory limit but NO disk manager (explicitly disabled)
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .with_disk_manager_builder(
                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
            )
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // Attempt to execute - should fail with ResourcesExhausted error
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err = stream.next().await.unwrap().unwrap_err();
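            // find_root() peels off wrapping error layers so the underlying
            // ResourcesExhausted error from the memory pool is visible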
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    /// Create `n` copies of the 8-row test batch
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    /// Create a single 8-row `UInt32` test batch
    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
}

#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the plan renders as expected
    ///
    /// `$PLAN`: the plan to display
    /// `$EXPECTED`: the expected indented rendering
    macro_rules! assert_plan {
        ($PLAN: expr,  @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // output has multiple partitions, and is sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // output has multiple partitions, but is not sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // Repartition should not preserve order, as there is no order to preserve
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // output is sorted, but has only a single partition, so no need to sort
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        // Repartition should not preserve order
        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}