use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use super::common::SharedMemoryReservation;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::hash_utils::create_hashes;
use crate::metrics::{BaselineMetrics, SpillMetrics};
use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec};
use crate::repartition::distributor_channels::{
    channels, partition_aware_channels, DistributionReceiver, DistributionSender,
};
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::spill_manager::SpillManager;
use crate::stream::RecordBatchStreamAdapter;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions};
use arrow::compute::take_arrays;
use arrow::datatypes::{SchemaRef, UInt32Type};
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::utils::transpose;
use datafusion_common::{internal_err, ColumnStatistics, HashMap};
use datafusion_common::{not_impl_err, DataFusionError, Result};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use crate::filter_pushdown::{
    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
    FilterPushdownPropagation,
};
use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use log::trace;
use parking_lot::Mutex;

mod distributor_channels;

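/// A batch sent through a repartition channel: either kept in memory (counted
/// against the output partition's memory reservation) or spilled to a temp
/// file when the reservation cannot grow.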
#[derive(Debug)]
enum RepartitionBatch {
    /// Batch kept in memory; its size is tracked by the memory reservation.
    Memory(RecordBatch),
    /// Batch written to a temp file because memory could not be reserved;
    /// `size` is the original in-memory size of the batch.
    Spilled {
        spill_file: RefCountedTempFile,
        size: usize,
    },
}

type MaybeBatch = Option<Result<RepartitionBatch>>;
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;

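/// Channel state for one output partition, created when the input streams
/// start being consumed.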
#[derive(Debug)]
struct PartitionChannels {
    /// One sender per input partition, all feeding this output partition.
    tx: InputPartitionsToCurrentPartitionSender,
    /// Receiver(s) for this output partition.
    rx: InputPartitionsToCurrentPartitionReceiver,
    /// Memory reservation tracking batches buffered in these channels.
    reservation: SharedMemoryReservation,
    /// Spill manager used when the reservation cannot grow.
    spill_manager: Arc<SpillManager>,
}

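/// State of a [`RepartitionExec`] whose input streams are being actively
/// consumed by background tasks.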
#[derive(Debug)]
struct ConsumingInputStreamsState {
    /// Channels for sending batches from input partitions to output
    /// partitions, keyed by output partition number.
    channels: HashMap<usize, PartitionChannels>,

    /// Helper that ensures the background input tasks are aborted when all
    /// output streams are dropped.
    abort_helper: Arc<Vec<SpawnedTask<()>>>,
}

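/// Lazily initialized state of a [`RepartitionExec`].
///
/// Input streams are executed when `execute` is first called, but the
/// background distribution tasks are only spawned once an output partition is
/// actually polled.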
#[derive(Default)]
enum RepartitionExecState {
    /// Not initialized yet; no input streams have been executed.
    #[default]
    NotInitialized,
    /// Input streams have been executed, but are not being consumed yet.
    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
    /// Input streams are being consumed by background tasks that distribute
    /// batches to the output channels.
    ConsumingInputStreams(ConsumingInputStreamsState),
}

impl Debug for RepartitionExecState {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
            RepartitionExecState::InputStreamsInitialized(v) => {
                write!(f, "InputStreamsInitialized({:?})", v.len())
            }
            RepartitionExecState::ConsumingInputStreams(v) => {
                write!(f, "ConsumingInputStreams({v:?})")
            }
        }
    }
}

impl RepartitionExecState {
    fn ensure_input_streams_initialized(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        output_partitions: usize,
        ctx: Arc<TaskContext>,
    ) -> Result<()> {
        if !matches!(self, RepartitionExecState::NotInitialized) {
            return Ok(());
        }

        let num_input_partitions = input.output_partitioning().partition_count();
        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);

        for i in 0..num_input_partitions {
            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);

            let timer = metrics.fetch_time.timer();
            let stream = input.execute(i, Arc::clone(&ctx))?;
            timer.done();

            streams_and_metrics.push((stream, metrics));
        }
        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
        Ok(())
    }

    fn consume_input_streams(
        &mut self,
        input: Arc<dyn ExecutionPlan>,
        metrics: ExecutionPlanMetricsSet,
        partitioning: Partitioning,
        preserve_order: bool,
        name: String,
        context: Arc<TaskContext>,
    ) -> Result<&mut ConsumingInputStreamsState> {
        let streams_and_metrics = match self {
            RepartitionExecState::NotInitialized => {
                self.ensure_input_streams_initialized(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning.partition_count(),
                    Arc::clone(&context),
                )?;
                let RepartitionExecState::InputStreamsInitialized(value) = self else {
                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
                };
                value
            }
            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
            RepartitionExecState::InputStreamsInitialized(value) => value,
        };

        let num_input_partitions = streams_and_metrics.len();
        let num_output_partitions = partitioning.partition_count();

        let (txs, rxs) = if preserve_order {
            let (txs, rxs) =
                partition_aware_channels(num_input_partitions, num_output_partitions);
            // Take the transpose of senders and receivers: `channels` keeps
            // track of entries per output partition
            let txs = transpose(txs);
            let rxs = transpose(rxs);
            (txs, rxs)
        } else {
            // create one channel per *output* partition
            let (txs, rxs) = channels(num_output_partitions);
            // clone the sender for each input partition
            let txs = txs
                .into_iter()
                .map(|item| vec![item; num_input_partitions])
                .collect::<Vec<_>>();
            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
            (txs, rxs)
        };

        let mut channels = HashMap::with_capacity(txs.len());
        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
            let reservation = Arc::new(Mutex::new(
                MemoryConsumer::new(format!("{name}[{partition}]"))
                    .with_can_spill(true)
                    .register(context.memory_pool()),
            ));
            let spill_metrics = SpillMetrics::new(&metrics, partition);
            let spill_manager = Arc::new(SpillManager::new(
                Arc::clone(&context.runtime_env()),
                spill_metrics,
                input.schema(),
            ));
            channels.insert(
                partition,
                PartitionChannels {
                    tx,
                    rx,
                    reservation,
                    spill_manager,
                },
            );
        }

        // launch one async task per *input* partition
        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
        for (i, (stream, metrics)) in
            std::mem::take(streams_and_metrics).into_iter().enumerate()
        {
            let txs: HashMap<_, _> = channels
                .iter()
                .map(|(partition, channels)| {
                    (
                        *partition,
                        (
                            channels.tx[i].clone(),
                            Arc::clone(&channels.reservation),
                            Arc::clone(&channels.spill_manager),
                        ),
                    )
                })
                .collect();

            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                stream,
                txs.clone(),
                partitioning.clone(),
                metrics,
            ));

            // In a separate task, wait for each input to be done
            // (and pass along any errors, including panics)
            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
                input_task,
                txs.into_iter()
                    .map(|(partition, (tx, _reservation, _spill_manager))| {
                        (partition, tx)
                    })
                    .collect(),
            ));
            spawned_tasks.push(wait_for_task);
        }
        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
            channels,
            abort_helper: Arc::new(spawned_tasks),
        });
        match self {
            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
            _ => unreachable!(),
        }
    }
}

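/// Partitions [`RecordBatch`]es into a fixed number of output partitions
/// according to a [`Partitioning`] scheme (round-robin or hash).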
pub struct BatchPartitioner {
    state: BatchPartitionerState,
    timer: metrics::Time,
}

enum BatchPartitionerState {
    Hash {
        random_state: ahash::RandomState,
        exprs: Vec<Arc<dyn PhysicalExpr>>,
        num_partitions: usize,
        hash_buffer: Vec<u64>,
    },
    RoundRobin {
        num_partitions: usize,
        next_idx: usize,
    },
}

impl BatchPartitioner {
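    /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`].
    ///
    /// Time spent repartitioning is recorded to `timer`.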
    pub fn try_new(partitioning: Partitioning, timer: metrics::Time) -> Result<Self> {
        let state = match partitioning {
            Partitioning::RoundRobinBatch(num_partitions) => {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx: 0,
                }
            }
            Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash {
                exprs,
                num_partitions,
                // Use a fixed random state so hashing is deterministic
                random_state: ahash::RandomState::with_seeds(0, 0, 0, 0),
                hash_buffer: vec![],
            },
            other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"),
        };

        Ok(Self { state, timer })
    }

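    /// Partition the provided [`RecordBatch`] into one or more partitioned
    /// [`RecordBatch`]es, calling `f(partition, batch)` for each non-empty
    /// output batch.
    ///
    /// Time spent repartitioning (but not the time spent in `f`) is recorded
    /// to the timer provided on construction.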
    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
    where
        F: FnMut(usize, RecordBatch) -> Result<()>,
    {
        self.partition_iter(batch)?.try_for_each(|res| match res {
            Ok((partition, batch)) => f(partition, batch),
            Err(e) => Err(e),
        })
    }

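    /// Actual implementation of [`Self::partition`], returning an iterator of
    /// `(partition index, batch)` pairs instead of invoking a callback, so
    /// callers can drive the work lazily, batch by batch.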
    fn partition_iter(
        &mut self,
        batch: RecordBatch,
    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
        let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
            match &mut self.state {
                BatchPartitionerState::RoundRobin {
                    num_partitions,
                    next_idx,
                } => {
                    let idx = *next_idx;
                    *next_idx = (*next_idx + 1) % *num_partitions;
                    Box::new(std::iter::once(Ok((idx, batch))))
                }
                BatchPartitionerState::Hash {
                    random_state,
                    exprs,
                    num_partitions: partitions,
                    hash_buffer,
                } => {
                    // Track time spent distributing row indices across output partitions
                    let timer = self.timer.timer();

                    let arrays = exprs
                        .iter()
                        .map(|expr| expr.evaluate(&batch)?.into_array(batch.num_rows()))
                        .collect::<Result<Vec<_>>>()?;

                    hash_buffer.clear();
                    hash_buffer.resize(batch.num_rows(), 0);

                    create_hashes(&arrays, random_state, hash_buffer)?;

                    let mut indices: Vec<_> = (0..*partitions)
                        .map(|_| Vec::with_capacity(batch.num_rows()))
                        .collect();

                    for (index, hash) in hash_buffer.iter().enumerate() {
                        indices[(*hash % *partitions as u64) as usize].push(index as u32);
                    }

                    // Finished building index arrays for the output partitions
                    timer.done();

                    // Borrow the timer to avoid moving `self` into the closure
                    let partitioner_timer = &self.timer;
                    let it = indices
                        .into_iter()
                        .enumerate()
                        .filter_map(|(partition, indices)| {
                            let indices: PrimitiveArray<UInt32Type> = indices.into();
                            (!indices.is_empty()).then_some((partition, indices))
                        })
                        .map(move |(partition, indices)| {
                            // Track time spent constructing the repartitioned batches
                            let _timer = partitioner_timer.timer();

                            // Produce batches based on the computed indices
                            let columns = take_arrays(batch.columns(), &indices, None)?;

                            let mut options = RecordBatchOptions::new();
                            options = options.with_row_count(Some(indices.len()));
                            let batch = RecordBatch::try_new_with_options(
                                batch.schema(),
                                columns,
                                &options,
                            )
                            .unwrap();

                            Ok((partition, batch))
                        });

                    Box::new(it)
                }
            };

        Ok(it)
    }

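    /// The number of output partitions this partitioner produces.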
    fn num_partitions(&self) -> usize {
        match self.state {
            BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions,
            BatchPartitionerState::Hash { num_partitions, .. } => num_partitions,
        }
    }
}

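/// Maps `N` input partitions to `M` output partitions based on a
/// [`Partitioning`] scheme (e.g. round-robin or hash).
///
/// This operator is useful to:
/// * increase parallelism, e.g. fanning a single input partition out to many
///   output partitions so downstream operators can run on multiple cores
/// * redistribute rows so that related rows land in the same partition, e.g.
///   hash partitioning before an aggregation or join
///
/// Conceptually:
///
/// ```text
/// Input partition 0 ──┐                   ┌──▶ Output partition 0
///                     ├─ RepartitionExec ─┼──▶ Output partition 1
/// Input partition 1 ──┘                   └──▶ Output partition 2
/// ```
///
/// Internally, one background task per input partition pulls batches and
/// routes each one to the appropriate output channel. If an output channel
/// cannot reserve memory for a batch, the batch is spilled to disk via the
/// channel's [`SpillManager`] and read back when the consumer polls it.
///
/// When order preservation is requested (see [`Self::with_preserve_order`]),
/// partition-aware channels are used and the per-input streams are combined
/// with a streaming merge so the input sort order is maintained.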
#[derive(Debug, Clone)]
pub struct RepartitionExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Inner state that is initialized when the parent calls `.execute()` on
    /// this node and consumed as soon as the parent starts consuming it.
    state: Arc<Mutex<RepartitionExecState>>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Whether to preserve the ordering of the input across partitions
    preserve_order: bool,
    /// Cache holding plan properties like equivalences, output partitioning, etc.
    cache: PlanProperties,
}

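/// Metrics collected for a single input partition of a [`RepartitionExec`].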
#[derive(Debug, Clone)]
struct RepartitionMetrics {
    /// Time in nanos to execute the child operator and fetch batches
    fetch_time: metrics::Time,
    /// Repartitioning elapsed time in nanos
    repartition_time: metrics::Time,
    /// Time in nanos spent sending resulting batches to channels, one metric
    /// per output partition
    send_time: Vec<metrics::Time>,
}

impl RepartitionMetrics {
    pub fn new(
        input_partition: usize,
        num_output_partitions: usize,
        metrics: &ExecutionPlanMetricsSet,
    ) -> Self {
        // Time in nanos to execute the child operator and fetch batches
        let fetch_time =
            MetricBuilder::new(metrics).subset_time("fetch_time", input_partition);

        // Time in nanos to perform repartitioning
        let repartition_time =
            MetricBuilder::new(metrics).subset_time("repartition_time", input_partition);

        // Time in nanos for sending resulting batches to channels
        let send_time = (0..num_output_partitions)
            .map(|output_partition| {
                let label =
                    metrics::Label::new("outputPartition", output_partition.to_string());
                MetricBuilder::new(metrics)
                    .with_label(label)
                    .subset_time("send_time", input_partition)
            })
            .collect();

        Self {
            fetch_time,
            repartition_time,
            send_time,
        }
    }
}

impl RepartitionExec {
    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Partitioning scheme to use
    pub fn partitioning(&self) -> &Partitioning {
        &self.cache.partitioning
    }

    /// Get the `preserve_order` flag of this RepartitionExec
    pub fn preserve_order(&self) -> bool {
        self.preserve_order
    }

    /// Get the name used to display this Exec
    pub fn name(&self) -> &str {
        "RepartitionExec"
    }
}

impl DisplayAs for RepartitionExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "{}: partitioning={}, input_partitions={}",
                    self.name(),
                    self.partitioning(),
                    self.input.output_partitioning().partition_count()
                )?;

                if self.preserve_order {
                    write!(f, ", preserve_order=true")?;
                }

                if let Some(sort_exprs) = self.sort_exprs() {
                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
                }
                Ok(())
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "partitioning_scheme={}", self.partitioning())?;

                let input_partition_count =
                    self.input.output_partitioning().partition_count();
                let output_partition_count = self.partitioning().partition_count();
                let input_to_output_partition_str =
                    format!("{input_partition_count} -> {output_partition_count}");
                writeln!(
                    f,
                    "partition_count(in->out)={input_to_output_partition_str}"
                )?;

                if self.preserve_order {
                    writeln!(f, "preserve_order={}", self.preserve_order)?;
                }
                Ok(())
            }
        }
    }
}

impl ExecutionPlan for RepartitionExec {
    fn name(&self) -> &'static str {
        "RepartitionExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        mut children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut repartition = RepartitionExec::try_new(
            children.swap_remove(0),
            self.partitioning().clone(),
        )?;
        if self.preserve_order {
            repartition = repartition.with_preserve_order();
        }
        Ok(Arc::new(repartition))
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        Self::maintains_input_order_helper(self.input(), self.preserve_order)
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start {}::execute for partition: {}",
            self.name(),
            partition
        );

        let input = Arc::clone(&self.input);
        let partitioning = self.partitioning().clone();
        let metrics = self.metrics.clone();
        let preserve_order = self.sort_exprs().is_some();
        let name = self.name().to_owned();
        let schema = self.schema();
        let schema_captured = Arc::clone(&schema);

        // Get the existing ordering to use for merging
        let sort_exprs = self.sort_exprs().cloned();

        let state = Arc::clone(&self.state);
        if let Some(mut state) = state.try_lock() {
            state.ensure_input_streams_initialized(
                Arc::clone(&input),
                metrics.clone(),
                partitioning.partition_count(),
                Arc::clone(&context),
            )?;
        }

        let stream = futures::stream::once(async move {
            let num_input_partitions = input.output_partitioning().partition_count();

            // lock scope
            let (mut rx, reservation, spill_manager, abort_helper) = {
                // lock mutexes
                let mut state = state.lock();
                let state = state.consume_input_streams(
                    Arc::clone(&input),
                    metrics.clone(),
                    partitioning,
                    preserve_order,
                    name.clone(),
                    Arc::clone(&context),
                )?;

                // now return the stream for the specified *output* partition,
                // which will read from the channel
                let PartitionChannels {
                    rx,
                    reservation,
                    spill_manager,
                    ..
                } = state
                    .channels
                    .remove(&partition)
                    .expect("partition not used yet");

                (
                    rx,
                    reservation,
                    spill_manager,
                    Arc::clone(&state.abort_helper),
                )
            };

            trace!(
                "Before returning stream in {name}::execute for partition: {partition}"
            );

            if preserve_order {
                // Store streams from all the input partitions:
                let input_streams = rx
                    .into_iter()
                    .map(|receiver| {
                        Box::pin(PerPartitionStream {
                            schema: Arc::clone(&schema_captured),
                            receiver,
                            _drop_helper: Arc::clone(&abort_helper),
                            reservation: Arc::clone(&reservation),
                            spill_manager: Arc::clone(&spill_manager),
                            state: RepartitionStreamState::ReceivingFromChannel,
                        }) as SendableRecordBatchStream
                    })
                    .collect::<Vec<_>>();

                // Merge streams (while preserving ordering) coming from
                // input partitions to this partition:
                let fetch = None;
                let merge_reservation =
                    MemoryConsumer::new(format!("{name}[Merge {partition}]"))
                        .register(context.memory_pool());
                StreamingMergeBuilder::new()
                    .with_streams(input_streams)
                    .with_schema(schema_captured)
                    .with_expressions(&sort_exprs.unwrap())
                    .with_metrics(BaselineMetrics::new(&metrics, partition))
                    .with_batch_size(context.session_config().batch_size())
                    .with_fetch(fetch)
                    .with_reservation(merge_reservation)
                    .build()
            } else {
                Ok(Box::pin(RepartitionStream {
                    num_input_partitions,
                    num_input_partitions_processed: 0,
                    schema: input.schema(),
                    input: rx.swap_remove(0),
                    _drop_helper: abort_helper,
                    reservation,
                    spill_manager,
                    state: RepartitionStreamState::ReceivingFromChannel,
                }) as SendableRecordBatchStream)
            }
        })
        .try_flatten();
        let stream = RecordBatchStreamAdapter::new(schema, stream);
        Ok(Box::pin(stream))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.input.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        if let Some(partition) = partition {
            let partition_count = self.partitioning().partition_count();
            if partition_count == 0 {
                return Ok(Statistics::new_unknown(&self.schema()));
            }

            if partition >= partition_count {
                return internal_err!(
                    "RepartitionExec invalid partition {} (expected less than {})",
                    partition,
                    self.partitioning().partition_count()
                );
            }

            let mut stats = self.input.partition_statistics(None)?;

            // Distribute rows and bytes evenly across the output partitions,
            // downgrading the precision to inexact
            stats.num_rows = stats
                .num_rows
                .get_value()
                .map(|rows| Precision::Inexact(rows / partition_count))
                .unwrap_or(Precision::Absent);
            stats.total_byte_size = stats
                .total_byte_size
                .get_value()
                .map(|bytes| Precision::Inexact(bytes / partition_count))
                .unwrap_or(Precision::Absent);

            // Column-level statistics cannot be meaningfully scaled, so reset them
            stats.column_statistics = stats
                .column_statistics
                .iter()
                .map(|_| ColumnStatistics::new_unknown())
                .collect();

            Ok(stats)
        } else {
            self.input.partition_statistics(None)
        }
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, we should not try to push it down.
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }

        // If pushdown is not beneficial or applicable, break it.
        if projection.benefits_from_input_partitioning()[0]
            || !all_columns(projection.expr())
        {
            return Ok(None);
        }

        let new_projection = make_with_child(projection, self.input())?;

        let new_partitioning = match self.partitioning() {
            Partitioning::Hash(partitions, size) => {
                let mut new_partitions = vec![];
                for partition in partitions {
                    let Some(new_partition) =
                        update_expr(partition, projection.expr(), false)?
                    else {
                        return Ok(None);
                    };
                    new_partitions.push(new_partition);
                }
                Partitioning::Hash(new_partitions, *size)
            }
            others => others.clone(),
        };

        Ok(Some(Arc::new(RepartitionExec::try_new(
            new_projection,
            new_partitioning,
        )?)))
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }

    fn handle_child_pushdown_result(
        &self,
        _phase: FilterPushdownPhase,
        child_pushdown_result: ChildPushdownResult,
        _config: &ConfigOptions,
    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
    }

    fn repartitioned(
        &self,
        target_partitions: usize,
        _config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        use Partitioning::*;
        let mut new_properties = self.cache.clone();
        new_properties.partitioning = match new_properties.partitioning {
            RoundRobinBatch(_) => RoundRobinBatch(target_partitions),
            Hash(hash, _) => Hash(hash, target_partitions),
            UnknownPartitioning(_) => UnknownPartitioning(target_partitions),
        };
        Ok(Some(Arc::new(Self {
            input: Arc::clone(&self.input),
            state: Arc::clone(&self.state),
            metrics: self.metrics.clone(),
            preserve_order: self.preserve_order,
            cache: new_properties,
        })))
    }
}

impl RepartitionExec {
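    /// Create a new RepartitionExec, which produces output `partitioning` and
    /// does not preserve the order of the input (see
    /// [`Self::with_preserve_order`] for more details)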
    pub fn try_new(
        input: Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
    ) -> Result<Self> {
        let preserve_order = false;
        let cache =
            Self::compute_properties(&input, partitioning.clone(), preserve_order);
        Ok(RepartitionExec {
            input,
            state: Default::default(),
            metrics: ExecutionPlanMetricsSet::new(),
            preserve_order,
            cache,
        })
    }

    fn maintains_input_order_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> Vec<bool> {
        // Ordering is maintained when this is the order-preserving variant or
        // when the input has at most one partition
        vec![preserve_order || input.output_partitioning().partition_count() <= 1]
    }

    fn eq_properties_helper(
        input: &Arc<dyn ExecutionPlan>,
        preserve_order: bool,
    ) -> EquivalenceProperties {
        // Equivalence Properties
        let mut eq_properties = input.equivalence_properties().clone();
        // If the ordering is lost, reset the ordering equivalence class:
        if !Self::maintains_input_order_helper(input, preserve_order)[0] {
            eq_properties.clear_orderings();
        }
        // When there is more than one input partition, they will be fused at
        // the output, so remove per-partition constants.
        if input.output_partitioning().partition_count() > 1 {
            eq_properties.clear_per_partition_constants();
        }
        eq_properties
    }

    /// Creates the cache object that stores the plan properties such as
    /// schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(
        input: &Arc<dyn ExecutionPlan>,
        partitioning: Partitioning,
        preserve_order: bool,
    ) -> PlanProperties {
        PlanProperties::new(
            Self::eq_properties_helper(input, preserve_order),
            partitioning,
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_scheduling_type(SchedulingType::Cooperative)
        .with_evaluation_type(EvaluationType::Eager)
    }

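    /// Specify if this repartitioning operation should preserve the order of
    /// rows from its input when producing output. Preserving order is more
    /// expensive at runtime, so should only be set if the output of this
    /// operator can take advantage of it.
    ///
    /// If the input is not ordered, or has only one partition, this is a no
    /// op, and the node remains a `RepartitionExec`.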
    pub fn with_preserve_order(mut self) -> Self {
        self.preserve_order =
            // If the input isn't ordered, there is no ordering to preserve
            self.input.output_ordering().is_some() &&
            // if there is only one input partition, merging is not required
            // to maintain order
            self.input.output_partitioning().partition_count() > 1;
        let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order);
        self.cache = self.cache.with_eq_properties(eq_properties);
        self
    }

    /// Return the sort expressions that are used to merge
    fn sort_exprs(&self) -> Option<&LexOrdering> {
        if self.preserve_order {
            self.input.output_ordering()
        } else {
            None
        }
    }

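    /// Pulls data from the specified input plan, feeding it to the
    /// output partitions based on the desired partitioning.
    ///
    /// `output_channels` holds the sending channel, memory reservation, and
    /// spill manager for each output partition.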
    async fn pull_from_input(
        mut stream: SendableRecordBatchStream,
        mut output_channels: HashMap<
            usize,
            (
                DistributionSender<MaybeBatch>,
                SharedMemoryReservation,
                Arc<SpillManager>,
            ),
        >,
        partitioning: Partitioning,
        metrics: RepartitionMetrics,
    ) -> Result<()> {
        let mut partitioner =
            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;

        // While there are still outputs to send to, keep pulling inputs
        let mut batches_until_yield = partitioner.num_partitions();
        while !output_channels.is_empty() {
            // fetch the next batch
            let timer = metrics.fetch_time.timer();
            let result = stream.next().await;
            timer.done();

            // Input is done
            let batch = match result {
                Some(result) => result?,
                None => break,
            };

            // Skip empty batches
            if batch.num_rows() == 0 {
                continue;
            }

            for res in partitioner.partition_iter(batch)? {
                let (partition, batch) = res?;
                let size = batch.get_array_memory_size();

                let timer = metrics.send_time[partition].timer();
                // if there is still a receiver, send to it
                if let Some((tx, reservation, spill_manager)) =
                    output_channels.get_mut(&partition)
                {
                    let (batch_to_send, is_memory_batch) =
                        match reservation.lock().try_grow(size) {
                            Ok(_) => {
                                // Memory is available: send the batch in memory
                                (RepartitionBatch::Memory(batch), true)
                            }
                            Err(_) => {
                                // Memory is exhausted: spill the batch to disk
                                // and send only a handle to the temp file
                                let spill_file = spill_manager
                                    .spill_record_batch_and_finish(
                                        &[batch],
                                        &format!(
                                            "RepartitionExec spill partition {partition}"
                                        ),
                                    )?
                                    .expect("non-empty batch should produce spill file");

                                (RepartitionBatch::Spilled { spill_file, size }, false)
                            }
                        };

                    if tx.send(Some(Ok(batch_to_send))).await.is_err() {
                        // The other end has hung up (e.g. early shutdown on a
                        // LIMIT); only shrink the reservation for in-memory
                        // batches, since spilled batches never grew it
                        if is_memory_batch {
                            reservation.lock().shrink(size);
                        }
                        output_channels.remove(&partition);
                    }
                }
                timer.done();
            }

            // Yield back to tokio periodically: an endless input could
            // otherwise starve other tasks, but yielding on every batch
            // bottlenecks multi-core execution, so heuristically yield
            // after producing num_partitions batches
            if batches_until_yield == 0 {
                tokio::task::yield_now().await;
                batches_until_yield = partitioner.num_partitions();
            } else {
                batches_until_yield -= 1;
            }
        }

        Ok(())
    }

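    /// Waits for `input_task` (which is consuming one of the inputs) to
    /// complete. On successful completion, sends a `None` to each output
    /// channel to signal that this input is done. On error (including a
    /// panic), propagates the error to all output channels.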
    async fn wait_for_task(
        input_task: SpawnedTask<Result<()>>,
        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
    ) {
        // wait for completion, and propagate errors
        // note we ignore errors on send (.ok) as that means the receiver has already shutdown
        match input_task.join().await {
            // Error in joining task
            Err(e) => {
                let e = Arc::new(e);

                for (_, tx) in txs {
                    let err = Err(DataFusionError::Context(
                        "Join Error".to_string(),
                        Box::new(DataFusionError::External(Box::new(Arc::clone(&e)))),
                    ));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Error from running input task
            Ok(Err(e)) => {
                // send the same Arc'd error to all output partitions
                let e = Arc::new(e);

                for (_, tx) in txs {
                    // wrap the error because it needs to go to every output partition
                    let err = Err(DataFusionError::from(&e));
                    tx.send(Some(err)).await.ok();
                }
            }
            // Input task completed successfully
            Ok(Ok(())) => {
                // notify each output partition that this input partition has no more data
                for (_, tx) in txs {
                    tx.send(None).await.ok();
                }
            }
        }
    }
}

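/// State machine for a stream reading from a repartition channel: either
/// waiting for the next channel message, or draining batches read back from a
/// spill file.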
enum RepartitionStreamState {
    /// Waiting for the next message from the channel.
    ReceivingFromChannel,
    /// Draining batches read back from a spill file.
    ReadingSpilledBatch(SendableRecordBatchStream),
}

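/// Output stream for the non-order-preserving case: a single channel carries
/// the batches from all input partitions destined for one output partition.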
struct RepartitionStream {
    /// Number of input partitions that will be sending batches to this output channel
    num_input_partitions: usize,

    /// Number of input partitions that have finished sending batches to this output channel
    num_input_partitions_processed: usize,

    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    input: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation shared with the senders; shrunk as in-memory
    /// batches are received
    reservation: SharedMemoryReservation,

    /// Spill manager used to read back spilled batches
    spill_manager: Arc<SpillManager>,

    /// Current state: receiving from the channel or draining a spill file
    state: RepartitionStreamState,
}

impl Stream for RepartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                RepartitionStreamState::ReceivingFromChannel => {
                    let value = futures::ready!(self.input.recv().poll_unpin(cx));
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Release the memory reserved for this batch
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled { spill_file, size }) => {
                                // Read the spilled batch back from disk; it was
                                // never counted against the reservation
                                let stream = self
                                    .spill_manager
                                    .read_spill_as_stream(spill_file, Some(size))?;
                                self.state =
                                    RepartitionStreamState::ReadingSpilledBatch(stream);
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            self.num_input_partitions_processed += 1;

                            if self.num_input_partitions
                                == self.num_input_partitions_processed
                            {
                                // all input partitions have finished sending batches
                                return Poll::Ready(None);
                            } else {
                                // other partitions still have data to send
                                continue;
                            }
                        }
                        None => {
                            return Poll::Ready(None);
                        }
                    }
                }
                RepartitionStreamState::ReadingSpilledBatch(stream) => {
                    match futures::ready!(stream.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Some(Err(e)) => {
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // Finished reading the spilled data, go back to the channel
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            continue;
                        }
                    }
                }
            }
        }
    }
}

impl RecordBatchStream for RepartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

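/// This struct converts a receiver to a stream.
///
/// The receiver receives data on an SPSC channel from a single input
/// partition, so this stream ends as soon as that input signals completion.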
struct PerPartitionStream {
    /// Schema wrapped by Arc
    schema: SchemaRef,

    /// channel containing the repartitioned batches
    receiver: DistributionReceiver<MaybeBatch>,

    /// Handle to ensure background tasks are killed when no longer needed
    _drop_helper: Arc<Vec<SpawnedTask<()>>>,

    /// Memory reservation shared with the senders; shrunk as in-memory
    /// batches are received
    reservation: SharedMemoryReservation,

    /// Spill manager used to read back spilled batches
    spill_manager: Arc<SpillManager>,

    /// Current state: receiving from the channel or draining a spill file
    state: RepartitionStreamState,
}

impl Stream for PerPartitionStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match &mut self.state {
                RepartitionStreamState::ReceivingFromChannel => {
                    let value = futures::ready!(self.receiver.recv().poll_unpin(cx));
                    match value {
                        Some(Some(v)) => match v {
                            Ok(RepartitionBatch::Memory(batch)) => {
                                // Release the memory reserved for this batch
                                self.reservation
                                    .lock()
                                    .shrink(batch.get_array_memory_size());
                                return Poll::Ready(Some(Ok(batch)));
                            }
                            Ok(RepartitionBatch::Spilled { spill_file, size }) => {
                                // Read the spilled batch back from disk
                                let stream = self
                                    .spill_manager
                                    .read_spill_as_stream(spill_file, Some(size))?;
                                self.state =
                                    RepartitionStreamState::ReadingSpilledBatch(stream);
                            }
                            Err(e) => {
                                return Poll::Ready(Some(Err(e)));
                            }
                        },
                        Some(None) => {
                            // The input partition has finished sending batches
                            return Poll::Ready(None);
                        }
                        None => return Poll::Ready(None),
                    }
                }

                RepartitionStreamState::ReadingSpilledBatch(stream) => {
                    match futures::ready!(stream.poll_next_unpin(cx)) {
                        Some(Ok(batch)) => {
                            return Poll::Ready(Some(Ok(batch)));
                        }
                        Some(Err(e)) => {
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            return Poll::Ready(Some(Err(e)));
                        }
                        None => {
                            // Finished reading the spilled data, go back to the channel
                            self.state = RepartitionStreamState::ReceivingFromChannel;
                            continue;
                        }
                    }
                }
            }
        }
    }
}

impl RecordBatchStream for PerPartitionStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::{
        test::{
            assert_is_pending,
            exec::{
                assert_strong_count_converges_to_zero, BarrierExec, BlockingExec,
                ErrorExec, MockExec,
            },
        },
        {collect, expressions::col},
    };

    use arrow::array::{ArrayRef, StringArray, UInt32Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::cast::as_string_array;
    use datafusion_common::exec_err;
    use datafusion_common::test_util::batches_to_sort_string;
    use datafusion_common_runtime::JoinSet;
    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
    use insta::assert_snapshot;
    use itertools::Itertools;

    #[tokio::test]
    async fn one_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition];

        // repartition from 1 input to 4 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?;

        assert_eq!(4, output_partitions.len());
        assert_eq!(13, output_partitions[0].len());
        assert_eq!(13, output_partitions[1].len());
        assert_eq!(12, output_partitions[2].len());
        assert_eq!(12, output_partitions[3].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_one_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 1 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?;

        assert_eq!(1, output_partitions.len());
        assert_eq!(150, output_partitions[0].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_round_robin() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        // repartition from 3 input to 5 output
        let output_partitions =
            repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?;

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn many_to_many_hash_partition() -> Result<()> {
        // define input partitions
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let partitions = vec![partition.clone(), partition.clone(), partition.clone()];

        let output_partitions = repartition(
            &schema,
            partitions,
            Partitioning::Hash(vec![col("c0", &schema)?], 8),
        )
        .await?;

        let total_rows: usize = output_partitions
            .iter()
            .map(|x| x.iter().map(|x| x.num_rows()).sum::<usize>())
            .sum();

        assert_eq!(8, output_partitions.len());
        assert_eq!(total_rows, 8 * 50 * 3);

        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    async fn repartition(
        schema: &SchemaRef,
        input_partitions: Vec<Vec<RecordBatch>>,
        partitioning: Partitioning,
    ) -> Result<Vec<Vec<RecordBatch>>> {
        let task_ctx = Arc::new(TaskContext::default());
        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // execute and collect results
        let mut output_partitions = vec![];
        for i in 0..exec.partitioning().partition_count() {
            // execute this *output* partition and collect all batches
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let mut batches = vec![];
            while let Some(result) = stream.next().await {
                batches.push(result?);
            }
            output_partitions.push(batches);
        }
        Ok(output_partitions)
    }

    #[tokio::test]
    async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
        let handle: SpawnedTask<Result<Vec<Vec<RecordBatch>>>> =
            SpawnedTask::spawn(async move {
                // define input partitions
                let schema = test_schema();
                let partition = create_vec_batches(50);
                let partitions =
                    vec![partition.clone(), partition.clone(), partition.clone()];

                // repartition from 3 input to 5 output
                repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await
            });

        let output_partitions = handle.join().await.unwrap().unwrap();

        assert_eq!(5, output_partitions.len());
        assert_eq!(30, output_partitions[0].len());
        assert_eq!(30, output_partitions[1].len());
        assert_eq!(30, output_partitions[2].len());
        assert_eq!(30, output_partitions[3].len());
        assert_eq!(30, output_partitions[4].len());

        Ok(())
    }

    #[tokio::test]
    async fn unsupported_partitioning() {
        let task_ctx = Arc::new(TaskContext::default());
        // have to send at least one batch through to provoke the error
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        // This generates an error (partitioning type not supported)
        // but only after the plan is executed. The error should be
        // returned and no results produced.
        let partitioning = Partitioning::UnknownPartitioning(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string
                .contains("Unsupported repartitioning scheme UnknownPartitioning(1)"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn error_for_input_exec() {
        let task_ctx = Arc::new(TaskContext::default());
        // The input exec errors on a call to execute. The error should
        // be returned and no results produced.
        let input = ErrorExec::new();
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Expect the error from executing the input to be passed back
        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();

        assert!(
            result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_error_in_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        // input stream returns one good batch and then one error;
        // the error should be returned
        let err = exec_err!("bad data error");

        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch), err], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        // Note: this should pass (the stream can be created) but the
        // error should be returned when the stream is consumed
        let output_stream = exec.execute(0, task_ctx).unwrap();

        // Expect that an error is returned
        let result_string = crate::common::collect(output_stream)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            result_string.contains("bad data error"),
            "actual: {result_string}"
        );
    }

    #[tokio::test]
    async fn repartition_with_delayed_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        // The mock exec doesn't return immediately (instead it
        // requires the input to wait at least once)
        let schema = batch1.schema();
        let expected_batches = vec![batch1.clone(), batch2.clone()];
        let input = MockExec::new(vec![Ok(batch1), Ok(batch2)], schema);
        let partitioning = Partitioning::RoundRobinBatch(1);

        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

        assert_snapshot!(batches_to_sort_string(&expected_batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");

        let output_stream = exec.execute(0, task_ctx).unwrap();
        let batches = crate::common::collect(output_stream).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r"
        +------------------+
        | my_awesome_field |
        +------------------+
        | bar              |
        | baz              |
        | foo              |
        | frob             |
        +------------------+
        ");
    }

    #[tokio::test]
    async fn robin_repartition_with_dropping_output_stream() {
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::RoundRobinBatch(2);
        // The barrier exec waits to be pinged before sending batches
        let input = Arc::new(make_barrier_exec());

        // partition into two output streams
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();

        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();

        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);

        // now, start sending input
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });

        // output stream 1 should *not* error and should have one of the input batches
        let batches = crate::common::collect(output_stream1).await.unwrap();

        assert_snapshot!(batches_to_sort_string(&batches), @r#"
        +------------------+
        | my_awesome_field |
        +------------------+
        | baz              |
        | frob             |
        | gaz              |
        | grob             |
        +------------------+
        "#);
    }

    #[tokio::test]
    async fn hash_repartition_with_dropping_output_stream() {
        // Since hash values can differ between platforms, compare the same
        // execution with and without dropping the other output stream.
        let task_ctx = Arc::new(TaskContext::default());
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new(
                "my_awesome_field",
                0,
            ))],
            2,
        );

        // We first collect the results without dropping the output stream.
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning.clone(),
        )
        .unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

        // run some checks on the result
        let items_vec = str_batches_to_vec(&batches_without_drop);
        let items_set: HashSet<&str> = items_vec.iter().copied().collect();
        assert_eq!(items_vec.len(), items_set.len());
        let source_str_set: HashSet<&str> =
            ["foo", "bar", "frob", "baz", "goo", "gar", "grob", "gaz"]
                .iter()
                .copied()
                .collect();
        assert_eq!(items_set.difference(&source_str_set).count(), 0);

        // Now do the same, dropping the stream before waiting for the barrier
        let input = Arc::new(make_barrier_exec());
        let exec = RepartitionExec::try_new(
            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
            partitioning,
        )
        .unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        // now, purposely drop output stream 0
        // *before* any outputs are produced
        drop(output_stream0);
        let mut background_task = JoinSet::new();
        background_task.spawn(async move {
            input.wait().await;
        });
        let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
            batch
                .into_iter()
                .sorted_by_key(|b| format!("{b:?}"))
                .collect()
        }

        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
    }

    fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
        batches
            .iter()
            .flat_map(|batch| {
                assert_eq!(batch.columns().len(), 1);
                let string_array = as_string_array(batch.column(0))
                    .expect("Unexpected type for repartitioned batch");

                string_array
                    .iter()
                    .map(|v| v.expect("Unexpected null"))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    /// Create a BarrierExec with two partitions of two string batches each
    fn make_barrier_exec() -> BarrierExec {
        let batch1 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef,
        )])
        .unwrap();

        let batch2 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["frob", "baz"])) as ArrayRef,
        )])
        .unwrap();

        let batch3 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["goo", "gar"])) as ArrayRef,
        )])
        .unwrap();

        let batch4 = RecordBatch::try_from_iter(vec![(
            "my_awesome_field",
            Arc::new(StringArray::from(vec!["grob", "gaz"])) as ArrayRef,
        )])
        .unwrap();

        // The barrier exec waits to be pinged before sending any batches
        let schema = batch1.schema();
        BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], schema)
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let repartition_exec = Arc::new(RepartitionExec::try_new(
            blocking_exec,
            Partitioning::UnknownPartitioning(1),
        )?);

        let fut = collect(repartition_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    async fn hash_repartition_avoid_empty_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let batch = RecordBatch::try_from_iter(vec![(
            "a",
            Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
        )])
        .unwrap();
        let partitioning = Partitioning::Hash(
            vec![Arc::new(crate::expressions::Column::new("a", 0))],
            2,
        );
        let schema = batch.schema();
        let input = MockExec::new(vec![Ok(batch)], schema);
        let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
        let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap();
        let batch0 = crate::common::collect(output_stream0).await.unwrap();
        let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap();
        let batch1 = crate::common::collect(output_stream1).await.unwrap();
        // all rows hash to a single partition; the other stream must stay empty
        assert!(batch0.is_empty() || batch1.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // set up a context with a tiny memory limit to force spilling
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions and count rows
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // verify that all rows arrived despite spilling
        assert_eq!(total_rows, 50 * 8);

        // verify that spilling occurred and was recorded in the metrics
        let metrics = exec.metrics().unwrap();
        assert!(
            metrics.spill_count().unwrap() > 0,
            "Expected spill_count > 0, but got {:?}",
            metrics.spill_count()
        );
        println!("Spilled {} times", metrics.spill_count().unwrap());
        assert!(
            metrics.spilled_bytes().unwrap() > 0,
            "Expected spilled_bytes > 0, but got {:?}",
            metrics.spilled_bytes()
        );
        println!(
            "Spilled {} bytes in {} spills",
            metrics.spilled_bytes().unwrap(),
            metrics.spill_count().unwrap()
        );
        assert!(
            metrics.spilled_rows().unwrap() > 0,
            "Expected spilled_rows > 0, but got {:?}",
            metrics.spilled_rows()
        );
        println!("Spilled {} rows", metrics.spilled_rows().unwrap());

        Ok(())
    }

    #[tokio::test]
    async fn repartition_with_partial_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // memory limit large enough to hold some batches in memory,
        // but small enough that some must be spilled
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(2 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions and count rows
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // verify that all rows arrived
        assert_eq!(total_rows, 50 * 8);

        // verify that only part of the data was spilled
        let metrics = exec.metrics().unwrap();
        let spill_count = metrics.spill_count().unwrap();
        let spilled_rows = metrics.spilled_rows().unwrap();
        let spilled_bytes = metrics.spilled_bytes().unwrap();

        assert!(
            spill_count > 0,
            "Expected some spilling to occur, but got spill_count={spill_count}"
        );
        assert!(
            spilled_rows > 0 && spilled_rows < total_rows,
            "Expected partial spilling (0 < spilled_rows < {total_rows}), but got spilled_rows={spilled_rows}"
        );
        assert!(
            spilled_bytes > 0,
            "Expected some bytes to be spilled, but got spilled_bytes={spilled_bytes}"
        );

        println!(
            "Partial spilling: spilled {} out of {} rows ({:.1}%) in {} spills, {} bytes",
            spilled_rows,
            total_rows,
            (spilled_rows as f64 / total_rows as f64) * 100.0,
            spill_count,
            spilled_bytes
        );

        Ok(())
    }

    #[tokio::test]
    async fn repartition_without_spilling() -> Result<()> {
        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // plenty of memory: no spilling should occur
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(10 * 1024 * 1024, 1.0)
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions and count rows
        let mut total_rows = 0;
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            while let Some(result) = stream.next().await {
                let batch = result?;
                total_rows += batch.num_rows();
            }
        }

        // verify that all rows arrived
        assert_eq!(total_rows, 50 * 8);

        // verify that no spilling was recorded in the metrics
        let metrics = exec.metrics().unwrap();
        assert_eq!(
            metrics.spill_count(),
            Some(0),
            "Expected no spilling, but got spill_count={:?}",
            metrics.spill_count()
        );
        assert_eq!(
            metrics.spilled_bytes(),
            Some(0),
            "Expected no bytes spilled, but got spilled_bytes={:?}",
            metrics.spilled_bytes()
        );
        assert_eq!(
            metrics.spilled_rows(),
            Some(0),
            "Expected no rows spilled, but got spilled_rows={:?}",
            metrics.spilled_rows()
        );

        println!("No spilling occurred - all data processed in memory");

        Ok(())
    }

    #[tokio::test]
    async fn oom() -> Result<()> {
        use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};

        let schema = test_schema();
        let partition = create_vec_batches(50);
        let input_partitions = vec![partition];
        let partitioning = Partitioning::RoundRobinBatch(4);

        // set up a context with a tiny memory limit and a disabled disk
        // manager, so the operator can neither buffer nor spill
        let runtime = RuntimeEnvBuilder::default()
            .with_memory_limit(1, 1.0)
            .with_disk_manager_builder(
                DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled),
            )
            .build_arc()?;

        let task_ctx = TaskContext::default().with_runtime(runtime);
        let task_ctx = Arc::new(task_ctx);

        // create physical plan
        let exec =
            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
        let exec = RepartitionExec::try_new(exec, partitioning)?;

        // pull partitions; each should report a resources-exhausted error
        for i in 0..exec.partitioning().partition_count() {
            let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
            let err = stream.next().await.unwrap().unwrap_err();
            let err = err.find_root();
            assert!(
                matches!(err, DataFusionError::ResourcesExhausted(_)),
                "Wrong error type: {err}",
            );
        }

        Ok(())
    }

    /// Create a vector of `n` copies of the same test batch
    fn create_vec_batches(n: usize) -> Vec<RecordBatch> {
        let batch = create_batch();
        (0..n).map(|_| batch.clone()).collect()
    }

    /// Create a batch of 8 rows with a single UInt32 column
    fn create_batch() -> RecordBatch {
        let schema = test_schema();
        RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
        )
        .unwrap()
    }
}

#[cfg(test)]
mod test {
    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::test::TestMemoryExec;
    use crate::union::UnionExec;

    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};

    /// Asserts that the formatted plan matches the expected snapshot
    ///
    /// `$PLAN`: the plan to display
    /// `$EXPECTED`: the expected indented plan string
    macro_rules! assert_plan {
        ($PLAN: expr, @ $EXPECTED: expr) => {
            let formatted = crate::displayable($PLAN).indent(true).to_string();

            insta::assert_snapshot!(
                formatted,
                @$EXPECTED
            );
        };
    }

    #[tokio::test]
    async fn test_preserve_order() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source1 = sorted_memory_exec(&schema, sort_exprs.clone());
        let source2 = sorted_memory_exec(&schema, sort_exprs);
        // the input has multiple partitions, and is sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // repartition should preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true, sort_exprs=c0@0 ASC
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
            DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_one_partition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // the input is sorted but has only one partition, so no merge is needed
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // repartition should not attempt to preserve order
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");

        Ok(())
    }

    #[tokio::test]
    async fn test_preserve_order_input_not_sorted() -> Result<()> {
        let schema = test_schema();
        let source1 = memory_exec(&schema);
        let source2 = memory_exec(&schema);
        // the input has multiple partitions, but is not sorted
        let union = UnionExec::try_new(vec![source1, source2])?;
        let exec = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))?
            .with_preserve_order();

        // repartition should not preserve order, as there is no order to preserve
        assert_plan!(&exec, @r"
        RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2
          UnionExec
            DataSourceExec: partitions=1, partition_sizes=[0]
            DataSourceExec: partitions=1, partition_sizes=[0]
        ");
        Ok(())
    }

    #[tokio::test]
    async fn test_repartition() -> Result<()> {
        let schema = test_schema();
        let sort_exprs = sort_exprs(&schema);
        let source = sorted_memory_exec(&schema, sort_exprs);
        // repartitioned() should rewrite the target partition count
        let exec = RepartitionExec::try_new(source, Partitioning::RoundRobinBatch(10))?
            .repartitioned(20, &Default::default())?
            .unwrap();

        assert_plan!(exec.as_ref(), @r"
        RepartitionExec: partitioning=RoundRobinBatch(20), input_partitions=1
          DataSourceExec: partitions=1, partition_sizes=[0], output_ordering=c0@0 ASC
        ");
        Ok(())
    }

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
    }

    fn sort_exprs(schema: &Schema) -> LexOrdering {
        [PhysicalSortExpr {
            expr: col("c0", schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into()
    }

    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
        TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(schema), None).unwrap()
    }

    fn sorted_memory_exec(
        schema: &SchemaRef,
        sort_exprs: LexOrdering,
    ) -> Arc<dyn ExecutionPlan> {
        Arc::new(TestMemoryExec::update_cache(Arc::new(
            TestMemoryExec::try_new(&[vec![]], Arc::clone(schema), None)
                .unwrap()
                .try_with_sort_information(vec![sort_exprs])
                .unwrap(),
        )))
    }
}