datafusion_physical_plan/sorts/merge.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Merge that handles an arbitrary number of sorted streaming inputs.
//! This is an order-preserving merge.
20
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{ready, Context, Poll};
24
25use crate::metrics::BaselineMetrics;
26use crate::sorts::builder::BatchBuilder;
27use crate::sorts::cursor::{Cursor, CursorValues};
28use crate::sorts::stream::PartitionedStream;
29use crate::RecordBatchStream;
30
31use arrow::datatypes::SchemaRef;
32use arrow::record_batch::RecordBatch;
33use datafusion_common::Result;
34use datafusion_execution::memory_pool::MemoryReservation;
35
36use futures::Stream;
37
38/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
39type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
40
41/// Merges a stream of sorted cursors and record batches into a single sorted stream
42#[derive(Debug)]
43pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
44 in_progress: BatchBuilder,
45
46 /// The sorted input streams to merge together
47 streams: CursorStream<C>,
48
49 /// used to record execution metrics
50 metrics: BaselineMetrics,
51
52 /// If the stream has encountered an error or reaches the
53 /// `fetch` limit.
54 done: bool,
55
56 /// A loser tree that always produces the minimum cursor
57 ///
58 /// Node 0 stores the top winner, Nodes 1..num_streams store
59 /// the loser nodes
60 ///
61 /// This implements a "Tournament Tree" (aka Loser Tree) to keep
62 /// track of the current smallest element at the top. When the top
63 /// record is taken, the tree structure is not modified, and only
64 /// the path from bottom to top is visited, keeping the number of
65 /// comparisons close to the theoretical limit of `log(S)`.
66 ///
67 /// The current implementation uses a vector to store the tree.
68 /// Conceptually, it looks like this (assuming 8 streams):
69 ///
70 /// ```text
71 /// 0 (winner)
72 ///
73 /// 1
74 /// / \
75 /// 2 3
76 /// / \ / \
77 /// 4 5 6 7
78 /// ```
79 ///
80 /// Where element at index 0 in the vector is the current winner. Element
81 /// at index 1 is the root of the loser tree, element at index 2 is the
82 /// left child of the root, and element at index 3 is the right child of
83 /// the root and so on.
84 ///
85 /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
86 loser_tree: Vec<usize>,
87
88 /// If the most recently yielded overall winner has been replaced
89 /// within the loser tree. A value of `false` indicates that the
90 /// overall winner has been yielded but the loser tree has not
91 /// been updated
92 loser_tree_adjusted: bool,
93
94 /// Target batch size
95 batch_size: usize,
96
97 /// Cursors for each input partition. `None` means the input is exhausted
98 cursors: Vec<Option<Cursor<C>>>,
99
100 /// Configuration parameter to enable round-robin selection of tied winners of loser tree.
101 ///
102 /// This option controls the tie-breaker strategy and attempts to avoid the
103 /// issue of unbalanced polling between partitions
104 ///
105 /// If `true`, when multiple partitions have the same value, the partition
106 /// that has the fewest poll counts is selected. This strategy ensures that
107 /// multiple partitions with the same value are chosen equally, distributing
108 /// the polling load in a round-robin fashion. This approach balances the
109 /// workload more effectively across partitions and avoids excessive buffer
110 /// growth.
111 ///
112 /// if `false`, partitions with smaller indices are consistently chosen as
113 /// the winners, which can lead to an uneven distribution of polling and potentially
114 /// causing upstream operator buffers for the other partitions to grow
115 /// excessively, as they continued receiving data without consuming it.
116 ///
117 /// For example, an upstream operator like `RepartitionExec` execution would
118 /// keep sending data to certain partitions, but those partitions wouldn't
119 /// consume the data if they weren't selected as winners. This resulted in
120 /// inefficient buffer usage.
121 enable_round_robin_tie_breaker: bool,
122
123 /// Flag indicating whether we are in the mode of round-robin
124 /// tie breaker for the loser tree winners.
125 round_robin_tie_breaker_mode: bool,
126
127 /// Total number of polls returning the same value, as per partition.
128 /// We select the one that has less poll counts for tie-breaker in loser tree.
129 num_of_polled_with_same_value: Vec<usize>,
130
131 /// To keep track of reset counts
132 poll_reset_epochs: Vec<usize>,
133
134 /// Current reset count
135 current_reset_epoch: usize,
136
137 /// Stores the previous value of each partitions for tracking the poll counts on the same value.
138 prev_cursors: Vec<Option<Cursor<C>>>,
139
140 /// Optional number of rows to fetch
141 fetch: Option<usize>,
142
143 /// number of rows produced
144 produced: usize,
145
146 /// This vector contains the indices of the partitions that have not started emitting yet.
147 uninitiated_partitions: Vec<usize>,
148}
149
150impl<C: CursorValues> SortPreservingMergeStream<C> {
151 pub(crate) fn new(
152 streams: CursorStream<C>,
153 schema: SchemaRef,
154 metrics: BaselineMetrics,
155 batch_size: usize,
156 fetch: Option<usize>,
157 reservation: MemoryReservation,
158 enable_round_robin_tie_breaker: bool,
159 ) -> Self {
160 let stream_count = streams.partitions();
161
162 Self {
163 in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
164 streams,
165 metrics,
166 done: false,
167 cursors: (0..stream_count).map(|_| None).collect(),
168 prev_cursors: (0..stream_count).map(|_| None).collect(),
169 round_robin_tie_breaker_mode: false,
170 num_of_polled_with_same_value: vec![0; stream_count],
171 current_reset_epoch: 0,
172 poll_reset_epochs: vec![0; stream_count],
173 loser_tree: vec![],
174 loser_tree_adjusted: false,
175 batch_size,
176 fetch,
177 produced: 0,
178 uninitiated_partitions: (0..stream_count).collect(),
179 enable_round_robin_tie_breaker,
180 }
181 }
182
183 /// If the stream at the given index is not exhausted, and the last cursor for the
184 /// stream is finished, poll the stream for the next RecordBatch and create a new
185 /// cursor for the stream from the returned result
186 fn maybe_poll_stream(
187 &mut self,
188 cx: &mut Context<'_>,
189 idx: usize,
190 ) -> Poll<Result<()>> {
191 if self.cursors[idx].is_some() {
192 // Cursor is not finished - don't need a new RecordBatch yet
193 return Poll::Ready(Ok(()));
194 }
195
196 match futures::ready!(self.streams.poll_next(cx, idx)) {
197 None => Poll::Ready(Ok(())),
198 Some(Err(e)) => Poll::Ready(Err(e)),
199 Some(Ok((cursor, batch))) => {
200 self.cursors[idx] = Some(Cursor::new(cursor));
201 Poll::Ready(self.in_progress.push_batch(idx, batch))
202 }
203 }
204 }
205
206 fn poll_next_inner(
207 &mut self,
208 cx: &mut Context<'_>,
209 ) -> Poll<Option<Result<RecordBatch>>> {
210 if self.done {
211 return Poll::Ready(None);
212 }
213 // Once all partitions have set their corresponding cursors for the loser tree,
214 // we skip the following block. Until then, this function may be called multiple
215 // times and can return Poll::Pending if any partition returns Poll::Pending.
216
217 if self.loser_tree.is_empty() {
218 // Manual indexing since we're iterating over the vector and shrinking it in the loop
219 let mut idx = 0;
220 while idx < self.uninitiated_partitions.len() {
221 let partition_idx = self.uninitiated_partitions[idx];
222 match self.maybe_poll_stream(cx, partition_idx) {
223 Poll::Ready(Err(e)) => {
224 self.done = true;
225 return Poll::Ready(Some(Err(e)));
226 }
227 Poll::Pending => {
228 // The polled stream is pending which means we're already set up to
229 // be woken when necessary
230 // Try the next stream
231 idx += 1;
232 }
233 _ => {
234 // The polled stream is ready
235 // Remove it from uninitiated_partitions
236 // Don't bump idx here, since a new element will have taken its
237 // place which we'll try in the next loop iteration
238 // swap_remove will change the partition poll order, but that shouldn't
239 // make a difference since we're waiting for all streams to be ready.
240 self.uninitiated_partitions.swap_remove(idx);
241 }
242 }
243 }
244
245 if self.uninitiated_partitions.is_empty() {
246 // If there are no more uninitiated partitions, set up the loser tree and continue
247 // to the next phase.
248
249 // Claim the memory for the uninitiated partitions
250 self.uninitiated_partitions.shrink_to_fit();
251 self.init_loser_tree();
252 } else {
253 // There are still uninitiated partitions so return pending.
254 // We only get here if we've polled all uninitiated streams and at least one of them
255 // returned pending itself. That means we will be woken as soon as one of the
256 // streams would like to be polled again.
257 // There is no need to reschedule ourselves eagerly.
258 return Poll::Pending;
259 }
260 }
261
262 // NB timer records time taken on drop, so there are no
263 // calls to `timer.done()` below.
264 let elapsed_compute = self.metrics.elapsed_compute().clone();
265 let _timer = elapsed_compute.timer();
266
267 loop {
268 // Adjust the loser tree if necessary, returning control if needed
269 if !self.loser_tree_adjusted {
270 let winner = self.loser_tree[0];
271 if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
272 self.done = true;
273 return Poll::Ready(Some(Err(e)));
274 }
275 self.update_loser_tree();
276 }
277
278 let stream_idx = self.loser_tree[0];
279 if self.advance_cursors(stream_idx) {
280 self.loser_tree_adjusted = false;
281 self.in_progress.push_row(stream_idx);
282
283 // stop sorting if fetch has been reached
284 if self.fetch_reached() {
285 self.done = true;
286 } else if self.in_progress.len() < self.batch_size {
287 continue;
288 }
289 }
290
291 self.produced += self.in_progress.len();
292
293 return Poll::Ready(self.in_progress.build_record_batch().transpose());
294 }
295 }
296
297 /// For the given partition, updates the poll count. If the current value is the same
298 /// of the previous value, it increases the count by 1; otherwise, it is reset as 0.
299 fn update_poll_count_on_the_same_value(&mut self, partition_idx: usize) {
300 let cursor = &mut self.cursors[partition_idx];
301
302 // Check if the current partition's poll count is logically "reset"
303 if self.poll_reset_epochs[partition_idx] != self.current_reset_epoch {
304 self.poll_reset_epochs[partition_idx] = self.current_reset_epoch;
305 self.num_of_polled_with_same_value[partition_idx] = 0;
306 }
307
308 if let Some(c) = cursor.as_mut() {
309 // Compare with the last row in the previous batch
310 let prev_cursor = &self.prev_cursors[partition_idx];
311 if c.is_eq_to_prev_one(prev_cursor.as_ref()) {
312 self.num_of_polled_with_same_value[partition_idx] += 1;
313 } else {
314 self.num_of_polled_with_same_value[partition_idx] = 0;
315 }
316 }
317 }
318
319 fn fetch_reached(&mut self) -> bool {
320 self.fetch
321 .map(|fetch| self.produced + self.in_progress.len() >= fetch)
322 .unwrap_or(false)
323 }
324
325 /// Advances the actual cursor. If it reaches its end, update the
326 /// previous cursor with it.
327 ///
328 /// If the given partition is not exhausted, the function returns `true`.
329 fn advance_cursors(&mut self, stream_idx: usize) -> bool {
330 if let Some(cursor) = &mut self.cursors[stream_idx] {
331 let _ = cursor.advance();
332 if cursor.is_finished() {
333 // Take the current cursor, leaving `None` in its place
334 self.prev_cursors[stream_idx] = self.cursors[stream_idx].take();
335 }
336 true
337 } else {
338 false
339 }
340 }
341
342 /// Returns `true` if the cursor at index `a` is greater than at index `b`.
343 /// In an equality case, it compares the partition indices given.
344 #[inline]
345 fn is_gt(&self, a: usize, b: usize) -> bool {
346 match (&self.cursors[a], &self.cursors[b]) {
347 (None, _) => true,
348 (_, None) => false,
349 (Some(ac), Some(bc)) => ac.cmp(bc).then_with(|| a.cmp(&b)).is_gt(),
350 }
351 }
352
353 #[inline]
354 fn is_poll_count_gt(&self, a: usize, b: usize) -> bool {
355 let poll_a = self.num_of_polled_with_same_value[a];
356 let poll_b = self.num_of_polled_with_same_value[b];
357 poll_a.cmp(&poll_b).then_with(|| a.cmp(&b)).is_gt()
358 }
359
360 #[inline]
361 fn update_winner(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) {
362 self.loser_tree[cmp_node] = *winner;
363 *winner = challenger;
364 }
365
366 /// Find the leaf node index in the loser tree for the given cursor index
367 ///
368 /// Note that this is not necessarily a leaf node in the tree, but it can
369 /// also be a half-node (a node with only one child). This happens when the
370 /// number of cursors/streams is not a power of two. Thus, the loser tree
371 /// will be unbalanced, but it will still work correctly.
372 ///
373 /// For example, with 5 streams, the loser tree will look like this:
374 ///
375 /// ```text
376 /// 0 (winner)
377 ///
378 /// 1
379 /// / \
380 /// 2 3
381 /// / \ / \
382 /// 4 | | |
383 /// / \ | | |
384 /// -+---+--+---+---+---- Below is not a part of loser tree
385 /// S3 S4 S0 S1 S2
386 /// ```
387 ///
388 /// S0, S1, ... S4 are the streams (read: stream at index 0, stream at
389 /// index 1, etc.)
390 ///
391 /// Zooming in at node 2 in the loser tree as an example, we can see that
392 /// it takes as input the next item at (S0) and the loser of (S3, S4).
393 #[inline]
394 fn lt_leaf_node_index(&self, cursor_index: usize) -> usize {
395 (self.cursors.len() + cursor_index) / 2
396 }
397
398 /// Find the parent node index for the given node index
399 #[inline]
400 fn lt_parent_node_index(&self, node_idx: usize) -> usize {
401 node_idx / 2
402 }
403
404 /// Attempts to initialize the loser tree with one value from each
405 /// non exhausted input, if possible
406 fn init_loser_tree(&mut self) {
407 // Init loser tree
408 self.loser_tree = vec![usize::MAX; self.cursors.len()];
409 for i in 0..self.cursors.len() {
410 let mut winner = i;
411 let mut cmp_node = self.lt_leaf_node_index(i);
412 while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
413 let challenger = self.loser_tree[cmp_node];
414 if self.is_gt(winner, challenger) {
415 self.loser_tree[cmp_node] = winner;
416 winner = challenger;
417 }
418
419 cmp_node = self.lt_parent_node_index(cmp_node);
420 }
421 self.loser_tree[cmp_node] = winner;
422 }
423 self.loser_tree_adjusted = true;
424 }
425
426 /// Resets the poll count by incrementing the reset epoch.
427 fn reset_poll_counts(&mut self) {
428 self.current_reset_epoch += 1;
429 }
430
431 /// Handles tie-breaking logic during the adjustment of the loser tree.
432 ///
433 /// When comparing elements from multiple partitions in the `update_loser_tree` process, a tie can occur
434 /// between the current winner and a challenger. This function is invoked when such a tie needs to be
435 /// resolved according to the round-robin tie-breaker mode.
436 ///
437 /// If round-robin tie-breaking is not active, it is enabled, and the poll counts for all elements are reset.
438 /// The function then compares the poll counts of the current winner and the challenger:
439 /// - If the winner remains at the top after the final comparison, it increments the winner's poll count.
440 /// - If the challenger has a lower poll count than the current winner, the challenger becomes the new winner.
441 /// - If the poll counts are equal but the challenger's index is smaller, the challenger is preferred.
442 ///
443 /// # Parameters
444 /// - `cmp_node`: The index of the comparison node in the loser tree where the tie-breaking is happening.
445 /// - `winner`: A mutable reference to the current winner, which may be updated based on the tie-breaking result.
446 /// - `challenger`: The index of the challenger being compared against the winner.
447 ///
448 /// This function ensures fair selection among elements with equal values when tie-breaking mode is enabled,
449 /// aiming to balance the polling across different partitions.
450 #[inline]
451 fn handle_tie(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) {
452 if !self.round_robin_tie_breaker_mode {
453 self.round_robin_tie_breaker_mode = true;
454 // Reset poll count for tie-breaker
455 self.reset_poll_counts();
456 }
457 // Update poll count if the winner survives in the final match
458 if *winner == self.loser_tree[0] {
459 self.update_poll_count_on_the_same_value(*winner);
460 if self.is_poll_count_gt(*winner, challenger) {
461 self.update_winner(cmp_node, winner, challenger);
462 }
463 } else if challenger < *winner {
464 // If the winner doesn’t survive in the final match, it indicates that the original winner
465 // has moved up in value, so the challenger now becomes the new winner.
466 // This also means that we’re in a new round of the tie breaker,
467 // and the polls count is outdated (though not yet cleaned up).
468 //
469 // By the time we reach this code, both the new winner and the current challenger
470 // have the same value, and neither has an updated polls count.
471 // Therefore, we simply select the one with the smaller index.
472 self.update_winner(cmp_node, winner, challenger);
473 }
474 }
475
476 /// Updates the loser tree to reflect the new winner after the previous winner is consumed.
477 /// This function adjusts the tree by comparing the current winner with challengers from
478 /// other partitions.
479 ///
480 /// If `enable_round_robin_tie_breaker` is true and a tie occurs at the final level, the
481 /// tie-breaker logic will be applied to ensure fair selection among equal elements.
482 fn update_loser_tree(&mut self) {
483 // Start with the current winner
484 let mut winner = self.loser_tree[0];
485
486 // Find the leaf node index of the winner in the loser tree.
487 let mut cmp_node = self.lt_leaf_node_index(winner);
488
489 // Traverse up the tree to adjust comparisons until reaching the root.
490 while cmp_node != 0 {
491 let challenger = self.loser_tree[cmp_node];
492 // If round-robin tie-breaker is enabled and we're at the final comparison (cmp_node == 1)
493 if self.enable_round_robin_tie_breaker && cmp_node == 1 {
494 match (&self.cursors[winner], &self.cursors[challenger]) {
495 (Some(ac), Some(bc)) => {
496 if ac == bc {
497 self.handle_tie(cmp_node, &mut winner, challenger);
498 } else {
499 // Ends of tie breaker
500 self.round_robin_tie_breaker_mode = false;
501 if ac > bc {
502 self.update_winner(cmp_node, &mut winner, challenger);
503 }
504 }
505 }
506 (None, _) => {
507 // Challenger wins, update winner
508 // Ends of tie breaker
509 self.round_robin_tie_breaker_mode = false;
510 self.update_winner(cmp_node, &mut winner, challenger);
511 }
512 (_, None) => {
513 // Winner wins again
514 // Ends of tie breaker
515 self.round_robin_tie_breaker_mode = false;
516 }
517 }
518 } else if self.is_gt(winner, challenger) {
519 self.update_winner(cmp_node, &mut winner, challenger);
520 }
521 cmp_node = self.lt_parent_node_index(cmp_node);
522 }
523 self.loser_tree[0] = winner;
524 self.loser_tree_adjusted = true;
525 }
526}
527
528impl<C: CursorValues + Unpin> Stream for SortPreservingMergeStream<C> {
529 type Item = Result<RecordBatch>;
530
531 fn poll_next(
532 mut self: Pin<&mut Self>,
533 cx: &mut Context<'_>,
534 ) -> Poll<Option<Self::Item>> {
535 let poll = self.poll_next_inner(cx);
536 self.metrics.record_poll(poll)
537 }
538}
539
540impl<C: CursorValues + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
541 fn schema(&self) -> SchemaRef {
542 Arc::clone(self.in_progress.schema())
543 }
544}