datafusion_physical_plan/sorts/merge.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Merge that handles an arbitrary number of streaming inputs.
//! This is an order-preserving merge.
20
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{ready, Context, Poll};
24
25use crate::metrics::BaselineMetrics;
26use crate::sorts::builder::BatchBuilder;
27use crate::sorts::cursor::{Cursor, CursorValues};
28use crate::sorts::stream::PartitionedStream;
29use crate::RecordBatchStream;
30
31use arrow::datatypes::SchemaRef;
32use arrow::record_batch::RecordBatch;
33use datafusion_common::Result;
34use datafusion_execution::memory_pool::MemoryReservation;
35
36use futures::Stream;
37
/// A fallible [`PartitionedStream`] whose items pair cursor values `C` with the
/// [`RecordBatch`] they accompany; errors from an input partition surface as `Err`.
///
/// Boxed as a trait object so the merge is independent of the concrete stream type.
type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
40
/// Merges a stream of sorted cursors and record batches into a single sorted stream
#[derive(Debug)]
pub(crate) struct SortPreservingMergeStream<C: CursorValues> {
    /// Builds the output: receives every input batch as it is polled plus each
    /// row picked by the merge, and assembles them into output [`RecordBatch`]es
    in_progress: BatchBuilder,

    /// The sorted input streams to merge together
    streams: CursorStream<C>,

    /// Used to record execution metrics
    metrics: BaselineMetrics,

    /// Set once the stream has encountered an error or reached the `fetch`
    /// limit; every later poll then returns `None`.
    done: bool,

    /// A loser tree that always produces the minimum cursor
    ///
    /// Every entry is a partition (stream) index. Node 0 stores the overall
    /// winner; nodes `1..num_streams` store the losers of the internal matches.
    ///
    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
    /// track of the current smallest element at the top. When the top
    /// record is taken, the tree structure is not modified, and only
    /// the path from bottom to top is visited, keeping the number of
    /// comparisons close to the theoretical limit of `log(S)`.
    ///
    /// The current implementation uses a vector to store the tree.
    /// Conceptually, it looks like this (assuming 8 streams):
    ///
    /// ```text
    ///     0 (winner)
    ///
    ///     1
    ///    / \
    ///   2   3
    ///  / \ / \
    /// 4  5 6  7
    /// ```
    ///
    /// Where element at index 0 in the vector is the current winner. Element
    /// at index 1 is the root of the loser tree, element at index 2 is the
    /// left child of the root, and element at index 3 is the right child of
    /// the root and so on.
    ///
    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
    loser_tree: Vec<usize>,

    /// Whether the most recently yielded overall winner has been replaced
    /// within the loser tree. A value of `false` indicates that the
    /// overall winner has been yielded but the loser tree has not
    /// been updated yet.
    loser_tree_adjusted: bool,

    /// Target batch size (number of rows per emitted batch)
    batch_size: usize,

    /// Cursors for each input partition. `None` means the current batch of the
    /// partition has been fully consumed (or the input is exhausted) and the
    /// partition must be polled before its next row can be compared.
    cursors: Vec<Option<Cursor<C>>>,

    /// Configuration parameter to enable round-robin selection of tied winners of the loser tree.
    ///
    /// This option controls the tie-breaker strategy and attempts to avoid the
    /// issue of unbalanced polling between partitions. It is only consulted for
    /// the final comparison of the tree (node 1).
    ///
    /// If `true`, when multiple partitions have the same value, the partition
    /// with the lowest poll count is selected. This strategy ensures that
    /// multiple partitions with the same value are chosen equally, distributing
    /// the polling load in a round-robin fashion. This approach balances the
    /// workload more effectively across partitions and avoids excessive buffer
    /// growth.
    ///
    /// If `false`, partitions with smaller indices are consistently chosen as
    /// the winners, which can lead to an uneven distribution of polling and can
    /// cause upstream operator buffers for the other partitions to grow
    /// excessively, as they keep receiving data that is not consumed.
    ///
    /// For example, an upstream operator like `RepartitionExec` would keep
    /// sending data to certain partitions, but those partitions would not
    /// consume the data if they were not selected as winners, resulting in
    /// inefficient buffer usage.
    enable_round_robin_tie_breaker: bool,

    /// Whether the round-robin tie breaker is currently active, i.e. the most
    /// recent final comparison in the loser tree was a tie. Cleared as soon as
    /// a final comparison is decided by value (or by an exhausted cursor).
    round_robin_tie_breaker_mode: bool,

    /// Per partition, the number of polls that returned the same value.
    /// In a tie in the loser tree, the partition with the smaller count wins.
    num_of_polled_with_same_value: Vec<usize>,

    /// Per partition, the value of `current_reset_epoch` when its poll count was
    /// last reset. A mismatch means the count predates the latest reset and is
    /// cleared the next time the partition's count is updated.
    poll_reset_epochs: Vec<usize>,

    /// Current reset epoch; incremented to logically reset all poll counts at
    /// once without touching every entry.
    current_reset_epoch: usize,

    /// The most recently finished cursor of each partition. It is compared
    /// against the partition's current cursor (via `Cursor::is_eq_to_prev_one`)
    /// when updating the same-value poll counts.
    prev_cursors: Vec<Option<Cursor<C>>>,

    /// Optional number of rows to fetch
    fetch: Option<usize>,

    /// Number of rows already emitted in output batches
    produced: usize,

    /// Indices of the partitions that have not yet delivered their first batch
    /// (or reported end of stream). The loser tree is built once this is empty.
    uninitiated_partitions: Vec<usize>,
}
149
150impl<C: CursorValues> SortPreservingMergeStream<C> {
151    pub(crate) fn new(
152        streams: CursorStream<C>,
153        schema: SchemaRef,
154        metrics: BaselineMetrics,
155        batch_size: usize,
156        fetch: Option<usize>,
157        reservation: MemoryReservation,
158        enable_round_robin_tie_breaker: bool,
159    ) -> Self {
160        let stream_count = streams.partitions();
161
162        Self {
163            in_progress: BatchBuilder::new(schema, stream_count, batch_size, reservation),
164            streams,
165            metrics,
166            done: false,
167            cursors: (0..stream_count).map(|_| None).collect(),
168            prev_cursors: (0..stream_count).map(|_| None).collect(),
169            round_robin_tie_breaker_mode: false,
170            num_of_polled_with_same_value: vec![0; stream_count],
171            current_reset_epoch: 0,
172            poll_reset_epochs: vec![0; stream_count],
173            loser_tree: vec![],
174            loser_tree_adjusted: false,
175            batch_size,
176            fetch,
177            produced: 0,
178            uninitiated_partitions: (0..stream_count).collect(),
179            enable_round_robin_tie_breaker,
180        }
181    }
182
183    /// If the stream at the given index is not exhausted, and the last cursor for the
184    /// stream is finished, poll the stream for the next RecordBatch and create a new
185    /// cursor for the stream from the returned result
186    fn maybe_poll_stream(
187        &mut self,
188        cx: &mut Context<'_>,
189        idx: usize,
190    ) -> Poll<Result<()>> {
191        if self.cursors[idx].is_some() {
192            // Cursor is not finished - don't need a new RecordBatch yet
193            return Poll::Ready(Ok(()));
194        }
195
196        match futures::ready!(self.streams.poll_next(cx, idx)) {
197            None => Poll::Ready(Ok(())),
198            Some(Err(e)) => Poll::Ready(Err(e)),
199            Some(Ok((cursor, batch))) => {
200                self.cursors[idx] = Some(Cursor::new(cursor));
201                Poll::Ready(self.in_progress.push_batch(idx, batch))
202            }
203        }
204    }
205
206    fn poll_next_inner(
207        &mut self,
208        cx: &mut Context<'_>,
209    ) -> Poll<Option<Result<RecordBatch>>> {
210        if self.done {
211            return Poll::Ready(None);
212        }
213        // Once all partitions have set their corresponding cursors for the loser tree,
214        // we skip the following block. Until then, this function may be called multiple
215        // times and can return Poll::Pending if any partition returns Poll::Pending.
216
217        if self.loser_tree.is_empty() {
218            // Manual indexing since we're iterating over the vector and shrinking it in the loop
219            let mut idx = 0;
220            while idx < self.uninitiated_partitions.len() {
221                let partition_idx = self.uninitiated_partitions[idx];
222                match self.maybe_poll_stream(cx, partition_idx) {
223                    Poll::Ready(Err(e)) => {
224                        self.done = true;
225                        return Poll::Ready(Some(Err(e)));
226                    }
227                    Poll::Pending => {
228                        // The polled stream is pending which means we're already set up to
229                        // be woken when necessary
230                        // Try the next stream
231                        idx += 1;
232                    }
233                    _ => {
234                        // The polled stream is ready
235                        // Remove it from uninitiated_partitions
236                        // Don't bump idx here, since a new element will have taken its
237                        // place which we'll try in the next loop iteration
238                        // swap_remove will change the partition poll order, but that shouldn't
239                        // make a difference since we're waiting for all streams to be ready.
240                        self.uninitiated_partitions.swap_remove(idx);
241                    }
242                }
243            }
244
245            if self.uninitiated_partitions.is_empty() {
246                // If there are no more uninitiated partitions, set up the loser tree and continue
247                // to the next phase.
248
249                // Claim the memory for the uninitiated partitions
250                self.uninitiated_partitions.shrink_to_fit();
251                self.init_loser_tree();
252            } else {
253                // There are still uninitiated partitions so return pending.
254                // We only get here if we've polled all uninitiated streams and at least one of them
255                // returned pending itself. That means we will be woken as soon as one of the
256                // streams would like to be polled again.
257                // There is no need to reschedule ourselves eagerly.
258                return Poll::Pending;
259            }
260        }
261
262        // NB timer records time taken on drop, so there are no
263        // calls to `timer.done()` below.
264        let elapsed_compute = self.metrics.elapsed_compute().clone();
265        let _timer = elapsed_compute.timer();
266
267        loop {
268            // Adjust the loser tree if necessary, returning control if needed
269            if !self.loser_tree_adjusted {
270                let winner = self.loser_tree[0];
271                if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
272                    self.done = true;
273                    return Poll::Ready(Some(Err(e)));
274                }
275                self.update_loser_tree();
276            }
277
278            let stream_idx = self.loser_tree[0];
279            if self.advance_cursors(stream_idx) {
280                self.loser_tree_adjusted = false;
281                self.in_progress.push_row(stream_idx);
282
283                // stop sorting if fetch has been reached
284                if self.fetch_reached() {
285                    self.done = true;
286                } else if self.in_progress.len() < self.batch_size {
287                    continue;
288                }
289            }
290
291            self.produced += self.in_progress.len();
292
293            return Poll::Ready(self.in_progress.build_record_batch().transpose());
294        }
295    }
296
297    /// For the given partition, updates the poll count. If the current value is the same
298    /// of the previous value, it increases the count by 1; otherwise, it is reset as 0.
299    fn update_poll_count_on_the_same_value(&mut self, partition_idx: usize) {
300        let cursor = &mut self.cursors[partition_idx];
301
302        // Check if the current partition's poll count is logically "reset"
303        if self.poll_reset_epochs[partition_idx] != self.current_reset_epoch {
304            self.poll_reset_epochs[partition_idx] = self.current_reset_epoch;
305            self.num_of_polled_with_same_value[partition_idx] = 0;
306        }
307
308        if let Some(c) = cursor.as_mut() {
309            // Compare with the last row in the previous batch
310            let prev_cursor = &self.prev_cursors[partition_idx];
311            if c.is_eq_to_prev_one(prev_cursor.as_ref()) {
312                self.num_of_polled_with_same_value[partition_idx] += 1;
313            } else {
314                self.num_of_polled_with_same_value[partition_idx] = 0;
315            }
316        }
317    }
318
319    fn fetch_reached(&mut self) -> bool {
320        self.fetch
321            .map(|fetch| self.produced + self.in_progress.len() >= fetch)
322            .unwrap_or(false)
323    }
324
325    /// Advances the actual cursor. If it reaches its end, update the
326    /// previous cursor with it.
327    ///
328    /// If the given partition is not exhausted, the function returns `true`.
329    fn advance_cursors(&mut self, stream_idx: usize) -> bool {
330        if let Some(cursor) = &mut self.cursors[stream_idx] {
331            let _ = cursor.advance();
332            if cursor.is_finished() {
333                // Take the current cursor, leaving `None` in its place
334                self.prev_cursors[stream_idx] = self.cursors[stream_idx].take();
335            }
336            true
337        } else {
338            false
339        }
340    }
341
342    /// Returns `true` if the cursor at index `a` is greater than at index `b`.
343    /// In an equality case, it compares the partition indices given.
344    #[inline]
345    fn is_gt(&self, a: usize, b: usize) -> bool {
346        match (&self.cursors[a], &self.cursors[b]) {
347            (None, _) => true,
348            (_, None) => false,
349            (Some(ac), Some(bc)) => ac.cmp(bc).then_with(|| a.cmp(&b)).is_gt(),
350        }
351    }
352
353    #[inline]
354    fn is_poll_count_gt(&self, a: usize, b: usize) -> bool {
355        let poll_a = self.num_of_polled_with_same_value[a];
356        let poll_b = self.num_of_polled_with_same_value[b];
357        poll_a.cmp(&poll_b).then_with(|| a.cmp(&b)).is_gt()
358    }
359
360    #[inline]
361    fn update_winner(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) {
362        self.loser_tree[cmp_node] = *winner;
363        *winner = challenger;
364    }
365
366    /// Find the leaf node index in the loser tree for the given cursor index
367    ///
368    /// Note that this is not necessarily a leaf node in the tree, but it can
369    /// also be a half-node (a node with only one child). This happens when the
370    /// number of cursors/streams is not a power of two. Thus, the loser tree
371    /// will be unbalanced, but it will still work correctly.
372    ///
373    /// For example, with 5 streams, the loser tree will look like this:
374    ///
375    /// ```text
376    ///           0 (winner)
377    ///
378    ///           1
379    ///        /     \
380    ///       2       3
381    ///     /  \     / \
382    ///    4    |   |   |
383    ///   / \   |   |   |
384    /// -+---+--+---+---+---- Below is not a part of loser tree
385    ///  S3 S4 S0   S1  S2
386    /// ```
387    ///
388    /// S0, S1, ... S4 are the streams (read: stream at index 0, stream at
389    /// index 1, etc.)
390    ///
391    /// Zooming in at node 2 in the loser tree as an example, we can see that
392    /// it takes as input the next item at (S0) and the loser of (S3, S4).
393    #[inline]
394    fn lt_leaf_node_index(&self, cursor_index: usize) -> usize {
395        (self.cursors.len() + cursor_index) / 2
396    }
397
398    /// Find the parent node index for the given node index
399    #[inline]
400    fn lt_parent_node_index(&self, node_idx: usize) -> usize {
401        node_idx / 2
402    }
403
404    /// Attempts to initialize the loser tree with one value from each
405    /// non exhausted input, if possible
406    fn init_loser_tree(&mut self) {
407        // Init loser tree
408        self.loser_tree = vec![usize::MAX; self.cursors.len()];
409        for i in 0..self.cursors.len() {
410            let mut winner = i;
411            let mut cmp_node = self.lt_leaf_node_index(i);
412            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
413                let challenger = self.loser_tree[cmp_node];
414                if self.is_gt(winner, challenger) {
415                    self.loser_tree[cmp_node] = winner;
416                    winner = challenger;
417                }
418
419                cmp_node = self.lt_parent_node_index(cmp_node);
420            }
421            self.loser_tree[cmp_node] = winner;
422        }
423        self.loser_tree_adjusted = true;
424    }
425
426    /// Resets the poll count by incrementing the reset epoch.
427    fn reset_poll_counts(&mut self) {
428        self.current_reset_epoch += 1;
429    }
430
431    /// Handles tie-breaking logic during the adjustment of the loser tree.
432    ///
433    /// When comparing elements from multiple partitions in the `update_loser_tree` process, a tie can occur
434    /// between the current winner and a challenger. This function is invoked when such a tie needs to be
435    /// resolved according to the round-robin tie-breaker mode.
436    ///
437    /// If round-robin tie-breaking is not active, it is enabled, and the poll counts for all elements are reset.
438    /// The function then compares the poll counts of the current winner and the challenger:
439    /// - If the winner remains at the top after the final comparison, it increments the winner's poll count.
440    /// - If the challenger has a lower poll count than the current winner, the challenger becomes the new winner.
441    /// - If the poll counts are equal but the challenger's index is smaller, the challenger is preferred.
442    ///
443    /// # Parameters
444    /// - `cmp_node`: The index of the comparison node in the loser tree where the tie-breaking is happening.
445    /// - `winner`: A mutable reference to the current winner, which may be updated based on the tie-breaking result.
446    /// - `challenger`: The index of the challenger being compared against the winner.
447    ///
448    /// This function ensures fair selection among elements with equal values when tie-breaking mode is enabled,
449    /// aiming to balance the polling across different partitions.
450    #[inline]
451    fn handle_tie(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) {
452        if !self.round_robin_tie_breaker_mode {
453            self.round_robin_tie_breaker_mode = true;
454            // Reset poll count for tie-breaker
455            self.reset_poll_counts();
456        }
457        // Update poll count if the winner survives in the final match
458        if *winner == self.loser_tree[0] {
459            self.update_poll_count_on_the_same_value(*winner);
460            if self.is_poll_count_gt(*winner, challenger) {
461                self.update_winner(cmp_node, winner, challenger);
462            }
463        } else if challenger < *winner {
464            // If the winner doesn’t survive in the final match, it indicates that the original winner
465            // has moved up in value, so the challenger now becomes the new winner.
466            // This also means that we’re in a new round of the tie breaker,
467            // and the polls count is outdated (though not yet cleaned up).
468            //
469            // By the time we reach this code, both the new winner and the current challenger
470            // have the same value, and neither has an updated polls count.
471            // Therefore, we simply select the one with the smaller index.
472            self.update_winner(cmp_node, winner, challenger);
473        }
474    }
475
476    /// Updates the loser tree to reflect the new winner after the previous winner is consumed.
477    /// This function adjusts the tree by comparing the current winner with challengers from
478    /// other partitions.
479    ///
480    /// If `enable_round_robin_tie_breaker` is true and a tie occurs at the final level, the
481    /// tie-breaker logic will be applied to ensure fair selection among equal elements.
482    fn update_loser_tree(&mut self) {
483        // Start with the current winner
484        let mut winner = self.loser_tree[0];
485
486        // Find the leaf node index of the winner in the loser tree.
487        let mut cmp_node = self.lt_leaf_node_index(winner);
488
489        // Traverse up the tree to adjust comparisons until reaching the root.
490        while cmp_node != 0 {
491            let challenger = self.loser_tree[cmp_node];
492            // If round-robin tie-breaker is enabled and we're at the final comparison (cmp_node == 1)
493            if self.enable_round_robin_tie_breaker && cmp_node == 1 {
494                match (&self.cursors[winner], &self.cursors[challenger]) {
495                    (Some(ac), Some(bc)) => {
496                        if ac == bc {
497                            self.handle_tie(cmp_node, &mut winner, challenger);
498                        } else {
499                            // Ends of tie breaker
500                            self.round_robin_tie_breaker_mode = false;
501                            if ac > bc {
502                                self.update_winner(cmp_node, &mut winner, challenger);
503                            }
504                        }
505                    }
506                    (None, _) => {
507                        // Challenger wins, update winner
508                        // Ends of tie breaker
509                        self.round_robin_tie_breaker_mode = false;
510                        self.update_winner(cmp_node, &mut winner, challenger);
511                    }
512                    (_, None) => {
513                        // Winner wins again
514                        // Ends of tie breaker
515                        self.round_robin_tie_breaker_mode = false;
516                    }
517                }
518            } else if self.is_gt(winner, challenger) {
519                self.update_winner(cmp_node, &mut winner, challenger);
520            }
521            cmp_node = self.lt_parent_node_index(cmp_node);
522        }
523        self.loser_tree[0] = winner;
524        self.loser_tree_adjusted = true;
525    }
526}
527
528impl<C: CursorValues + Unpin> Stream for SortPreservingMergeStream<C> {
529    type Item = Result<RecordBatch>;
530
531    fn poll_next(
532        mut self: Pin<&mut Self>,
533        cx: &mut Context<'_>,
534    ) -> Poll<Option<Self::Item>> {
535        let poll = self.poll_next_inner(cx);
536        self.metrics.record_poll(poll)
537    }
538}
539
540impl<C: CursorValues + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
541    fn schema(&self) -> SchemaRef {
542        Arc::clone(self.in_progress.schema())
543    }
544}