datafusion_physical_plan/joins/sort_merge_join/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Sort-Merge Join execution
19//!
20//! This module implements the runtime state machine for the Sort-Merge Join
21//! operator. It drives two sorted input streams (the *streamed* side and the
22//! *buffered* side), compares join keys, and produces joined `RecordBatch`es.
23
24use std::cmp::Ordering;
25use std::collections::{HashMap, VecDeque};
26use std::fs::File;
27use std::io::BufReader;
28use std::mem::size_of;
29use std::ops::Range;
30use std::pin::Pin;
31use std::sync::atomic::AtomicUsize;
32use std::sync::atomic::Ordering::Relaxed;
33use std::sync::Arc;
34use std::task::{Context, Poll};
35
36use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
37use crate::joins::utils::{compare_join_arrays, JoinFilter};
38use crate::spill::spill_manager::SpillManager;
39use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
40
41use arrow::array::{types::UInt64Type, *};
42use arrow::compute::{
43    self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
44};
45use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
46use arrow::error::ArrowError;
47use arrow::ipc::reader::StreamReader;
48use datafusion_common::config::SpillCompression;
49use datafusion_common::{
50    exec_err, internal_err, not_impl_err, DataFusionError, HashSet, JoinSide, JoinType,
51    NullEquality, Result,
52};
53use datafusion_execution::disk_manager::RefCountedTempFile;
54use datafusion_execution::memory_pool::MemoryReservation;
55use datafusion_execution::runtime_env::RuntimeEnv;
56use datafusion_physical_expr_common::physical_expr::PhysicalExprRef;
57
58use futures::{Stream, StreamExt};
59
/// Top-level state of the sort-merge join (SMJ) stream.
///
/// `SortMergeJoinStream::poll_next` dispatches on this value in a loop and
/// moves between the variants as the inputs are polled and joined.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum SortMergeJoinState {
    /// Start joining with a new streamed row or with new buffered batches
    Init,
    /// Polling one streamed row, one buffered batch, or both
    Polling,
    /// Joining the polled data and producing output
    JoinOutput,
    /// No more output
    Exhausted,
}
72
/// State of the streamed input side.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum StreamedState {
    /// Polling has to be (re)started
    Init,
    /// Polling one streamed row
    Polling,
    /// One streamed row is ready to be joined
    Ready,
    /// No more streamed rows
    Exhausted,
}
85
/// State of the buffered input side.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum BufferedState {
    /// Polling has to be (re)started
    Init,
    /// Polling the first row of the next batch
    PollingFirst,
    /// Polling the remaining rows of the next batch
    PollingRest,
    /// One batch is ready to be joined
    Ready,
    /// No more buffered batches
    Exhausted,
}
100
/// A chunk of joined row pairs between the streamed side and one buffered batch.
///
/// `streamed_indices` and `buffered_indices` are appended to in lockstep (see
/// `StreamedBatch::append_output_pair`), so entry `i` of both builders
/// describes the same output row.
pub(super) struct StreamedJoinedChunk {
    /// Index of the batch in `buffered_data` that this chunk refers to.
    /// NOTE(review): `None` presumably means the rows in this chunk have no
    /// buffered counterpart — confirm against the callers.
    buffered_batch_idx: Option<usize>,
    /// Array builder for streamed row indices
    streamed_indices: UInt64Builder,
    /// Array builder for buffered row indices.
    /// This can contain nulls if the streamed row is null-joined.
    buffered_indices: UInt64Builder,
}
111
/// Represents a record batch from the streamed input.
///
/// Also stores information about matching rows from the buffered batches.
pub(super) struct StreamedBatch {
    /// The streamed record batch
    pub batch: RecordBatch,
    /// The index of the row in the streamed batch that is currently compared
    /// with the buffered batches
    pub idx: usize,
    /// The join key arrays of the streamed batch, used to compare with the
    /// buffered batches and to produce output. They are produced by evaluating
    /// the `on` expressions.
    pub join_arrays: Vec<ArrayRef>,
    /// Chunks of indices from the buffered side (may be null) joined to
    /// streamed rows
    pub output_indices: Vec<StreamedJoinedChunk>,
    /// Index of the currently scanned batch from the buffered data; this is the
    /// batch index of the most recently created chunk in `output_indices`
    pub buffered_batch_idx: Option<usize>,
    /// Indices that found a match for the given join filter.
    /// Used for semi joins to keep track of the streamed indices which got a
    /// join filter match and were already emitted to the output.
    pub join_filter_matched_idxs: HashSet<u64>,
}
132
133impl StreamedBatch {
134    fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
135        let join_arrays = join_arrays(&batch, on_column);
136        StreamedBatch {
137            batch,
138            idx: 0,
139            join_arrays,
140            output_indices: vec![],
141            buffered_batch_idx: None,
142            join_filter_matched_idxs: HashSet::new(),
143        }
144    }
145
146    fn new_empty(schema: SchemaRef) -> Self {
147        StreamedBatch {
148            batch: RecordBatch::new_empty(schema),
149            idx: 0,
150            join_arrays: vec![],
151            output_indices: vec![],
152            buffered_batch_idx: None,
153            join_filter_matched_idxs: HashSet::new(),
154        }
155    }
156
157    /// Appends new pair consisting of current streamed index and `buffered_idx`
158    /// index of buffered batch with `buffered_batch_idx` index.
159    fn append_output_pair(
160        &mut self,
161        buffered_batch_idx: Option<usize>,
162        buffered_idx: Option<usize>,
163    ) {
164        // If no current chunk exists or current chunk is not for current buffered batch,
165        // create a new chunk
166        if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
167        {
168            self.output_indices.push(StreamedJoinedChunk {
169                buffered_batch_idx,
170                streamed_indices: UInt64Builder::with_capacity(1),
171                buffered_indices: UInt64Builder::with_capacity(1),
172            });
173            self.buffered_batch_idx = buffered_batch_idx;
174        };
175        let current_chunk = self.output_indices.last_mut().unwrap();
176
177        // Append index of streamed batch and index of buffered batch into current chunk
178        current_chunk.streamed_indices.append_value(self.idx as u64);
179        if let Some(idx) = buffered_idx {
180            current_chunk.buffered_indices.append_value(idx as u64);
181        } else {
182            current_chunk.buffered_indices.append_null();
183        }
184    }
185}
186
/// A buffered batch that contains contiguous rows with the same join key.
///
/// `BufferedBatch` can exist as either an in-memory `RecordBatch` or a `RefCountedTempFile` on disk.
#[derive(Debug)]
pub(super) struct BufferedBatch {
    /// The record batch, either in memory or spilled to disk
    pub batch: BufferedBatchState,
    /// The range of rows that share the same join key
    pub range: Range<usize>,
    /// Array refs of the join key
    pub join_arrays: Vec<ArrayRef>,
    /// Buffered joined index (null joining buffered)
    pub null_joined: Vec<usize>,
    /// Size estimation used for reserving / releasing memory
    pub size_estimation: usize,
    /// The indices of the buffered batch for which the join filter is not satisfied.
    /// This is a map from a right row index to a boolean value indicating whether all
    /// joined rows of that right row fail the filter.
    /// When dequeuing the buffered batch, we need to produce null-joined rows for these indices.
    pub join_filter_not_matched_map: HashMap<u64, bool>,
    /// Number of rows in the current buffered batch. Equal to `batch.num_rows()`,
    /// but if the batch is spilled to disk this field is preferable and less
    /// expensive to read.
    pub num_rows: usize,
}
212
213impl BufferedBatch {
214    fn new(
215        batch: RecordBatch,
216        range: Range<usize>,
217        on_column: &[PhysicalExprRef],
218    ) -> Self {
219        let join_arrays = join_arrays(&batch, on_column);
220
221        // Estimation is calculated as
222        //   inner batch size
223        // + join keys size
224        // + worst case null_joined (as vector capacity * element size)
225        // + Range size
226        // + size of this estimation
227        let size_estimation = batch.get_array_memory_size()
228            + join_arrays
229                .iter()
230                .map(|arr| arr.get_array_memory_size())
231                .sum::<usize>()
232            + batch.num_rows().next_power_of_two() * size_of::<usize>()
233            + size_of::<Range<usize>>()
234            + size_of::<usize>();
235
236        let num_rows = batch.num_rows();
237        BufferedBatch {
238            batch: BufferedBatchState::InMemory(batch),
239            range,
240            join_arrays,
241            null_joined: vec![],
242            size_estimation,
243            join_filter_not_matched_map: HashMap::new(),
244            num_rows,
245        }
246    }
247}
248
// TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429)
/// Represents whether the buffered data is currently held in memory or has
/// been written to disk.
#[derive(Debug)]
pub(super) enum BufferedBatchState {
    /// In-memory record batch
    InMemory(RecordBatch),
    /// Spilled temp file
    Spilled(RefCountedTempFile),
}
258
/// Sort-merge join stream that consumes the streamed and buffered data streams
/// and produces the joined output stream.
pub(super) struct SortMergeJoinStream {
    // ========================================================================
    // PROPERTIES:
    // These fields are initialized at the start and remain constant throughout
    // the execution.
    // ========================================================================
    /// Output schema
    pub schema: SchemaRef,
    /// Defines the null equality for the join.
    pub null_equality: NullEquality,
    /// Sort options of the join columns used to sort the streamed and buffered data streams
    pub sort_options: Vec<SortOptions>,
    /// Optional join filter
    pub filter: Option<JoinFilter>,
    /// How the join is performed
    pub join_type: JoinType,
    /// Target output batch size
    pub batch_size: usize,

    // ========================================================================
    // STREAMED FIELDS:
    // These fields manage the properties and state of the streamed input.
    // ========================================================================
    /// Input schema of the streamed side
    pub streamed_schema: SchemaRef,
    /// Streamed data stream
    pub streamed: SendableRecordBatchStream,
    /// Record batch of the streamed side that is currently being processed
    pub streamed_batch: StreamedBatch,
    /// (used in outer join) Has the current streamed row been joined at least once?
    pub streamed_joined: bool,
    /// State of the streamed side
    pub streamed_state: StreamedState,
    /// Join key columns of the streamed side
    pub on_streamed: Vec<PhysicalExprRef>,

    // ========================================================================
    // BUFFERED FIELDS:
    // These fields manage the properties and state of the buffered input.
    // ========================================================================
    /// Input schema of the buffered side
    pub buffered_schema: SchemaRef,
    /// Buffered data stream
    pub buffered: SendableRecordBatchStream,
    /// Current buffered data
    pub buffered_data: BufferedData,
    /// (used in outer join) Have the current buffered batches been joined at least once?
    pub buffered_joined: bool,
    /// State of the buffered side
    pub buffered_state: BufferedState,
    /// Join key columns of the buffered side
    pub on_buffered: Vec<PhysicalExprRef>,

    // ========================================================================
    // MERGE JOIN STATES:
    // These fields track the execution state of merge join and are updated
    // during the execution.
    // ========================================================================
    /// Current state of the stream
    pub state: SortMergeJoinState,
    /// Staging output array builders
    pub staging_output_record_batches: JoinedRecordBatches,
    /// Output buffer. Currently used by filtered joins, which require double buffering
    /// to avoid small/empty batches. Non-filtered joins output directly from
    /// `staging_output_record_batches.batches`.
    pub output: RecordBatch,
    /// Staging output size, including output batches and staging joined results.
    /// Increased when rows are put into the buffer and decreased after batches are
    /// actually output. Used to trigger output when sufficient rows are ready.
    pub output_size: usize,
    /// The comparison result of the current streamed row and the buffered batches
    pub current_ordering: Ordering,
    /// Manages the process of spilling and reading back intermediate data
    pub spill_manager: SpillManager,

    // ========================================================================
    // EXECUTION RESOURCES:
    // Fields related to managing execution resources and monitoring performance.
    // ========================================================================
    /// Metrics
    pub join_metrics: SortMergeJoinMetrics,
    /// Memory reservation
    pub reservation: MemoryReservation,
    /// Runtime env
    pub runtime_env: Arc<RuntimeEnv>,
    /// A unique number for each batch
    pub streamed_batch_counter: AtomicUsize,
}
348
/// Joined batches with attached join filter information.
pub(super) struct JoinedRecordBatches {
    /// Joined batches. Each batch already contains the joined columns from the
    /// left and right sources.
    pub batches: Vec<RecordBatch>,
    /// Filter match mask for each row (matched / non-matched)
    pub filter_mask: BooleanBuilder,
    /// Left row indices used to glue together rows in `batches` and `filter_mask`
    pub row_indices: UInt64Builder,
    /// The unique batch id each row belongs to.
    /// It is needed to tell apart rows that carry the same row index but come
    /// from different batches.
    pub batch_ids: Vec<usize>,
}
362
363impl JoinedRecordBatches {
364    fn clear(&mut self) {
365        self.batches.clear();
366        self.batch_ids.clear();
367        self.filter_mask = BooleanBuilder::new();
368        self.row_indices = UInt64Builder::new();
369    }
370}
371impl RecordBatchStream for SortMergeJoinStream {
372    fn schema(&self) -> SchemaRef {
373        Arc::clone(&self.schema)
374    }
375}
376
377/// True if next index refers to either:
378/// - another batch id
379/// - another row index within same batch id
380/// - end of row indices
381#[inline(always)]
382fn last_index_for_row(
383    row_index: usize,
384    indices: &UInt64Array,
385    batch_ids: &[usize],
386    indices_len: usize,
387) -> bool {
388    row_index == indices_len - 1
389        || batch_ids[row_index] != batch_ids[row_index + 1]
390        || indices.value(row_index) != indices.value(row_index + 1)
391}
392
393// Returns a corrected boolean bitmask for the given join type
394// Values in the corrected bitmask can be: true, false, null
395// `true` - the row found its match and sent to the output
396// `null` - the row ignored, no output
397// `false` - the row sent as NULL joined row
398pub(super) fn get_corrected_filter_mask(
399    join_type: JoinType,
400    row_indices: &UInt64Array,
401    batch_ids: &[usize],
402    filter_mask: &BooleanArray,
403    expected_size: usize,
404) -> Option<BooleanArray> {
405    let row_indices_length = row_indices.len();
406    let mut corrected_mask: BooleanBuilder =
407        BooleanBuilder::with_capacity(row_indices_length);
408    let mut seen_true = false;
409
410    match join_type {
411        JoinType::Left | JoinType::Right => {
412            for i in 0..row_indices_length {
413                let last_index =
414                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
415                if filter_mask.value(i) {
416                    seen_true = true;
417                    corrected_mask.append_value(true);
418                } else if seen_true || !filter_mask.value(i) && !last_index {
419                    corrected_mask.append_null(); // to be ignored and not set to output
420                } else {
421                    corrected_mask.append_value(false); // to be converted to null joined row
422                }
423
424                if last_index {
425                    seen_true = false;
426                }
427            }
428
429            // Generate null joined rows for records which have no matching join key
430            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
431            Some(corrected_mask.finish())
432        }
433        JoinType::LeftMark | JoinType::RightMark => {
434            for i in 0..row_indices_length {
435                let last_index =
436                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
437                if filter_mask.value(i) && !seen_true {
438                    seen_true = true;
439                    corrected_mask.append_value(true);
440                } else if seen_true || !filter_mask.value(i) && !last_index {
441                    corrected_mask.append_null(); // to be ignored and not set to output
442                } else {
443                    corrected_mask.append_value(false); // to be converted to null joined row
444                }
445
446                if last_index {
447                    seen_true = false;
448                }
449            }
450
451            // Generate null joined rows for records which have no matching join key
452            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
453            Some(corrected_mask.finish())
454        }
455        JoinType::LeftSemi | JoinType::RightSemi => {
456            for i in 0..row_indices_length {
457                let last_index =
458                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
459                if filter_mask.value(i) && !seen_true {
460                    seen_true = true;
461                    corrected_mask.append_value(true);
462                } else {
463                    corrected_mask.append_null(); // to be ignored and not set to output
464                }
465
466                if last_index {
467                    seen_true = false;
468                }
469            }
470
471            Some(corrected_mask.finish())
472        }
473        JoinType::LeftAnti | JoinType::RightAnti => {
474            for i in 0..row_indices_length {
475                let last_index =
476                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
477
478                if filter_mask.value(i) {
479                    seen_true = true;
480                }
481
482                if last_index {
483                    if !seen_true {
484                        corrected_mask.append_value(true);
485                    } else {
486                        corrected_mask.append_null();
487                    }
488
489                    seen_true = false;
490                } else {
491                    corrected_mask.append_null();
492                }
493            }
494            // Generate null joined rows for records which have no matching join key,
495            // for LeftAnti non-matched considered as true
496            corrected_mask.append_n(expected_size - corrected_mask.len(), true);
497            Some(corrected_mask.finish())
498        }
499        JoinType::Full => {
500            let mut mask: Vec<Option<bool>> = vec![Some(true); row_indices_length];
501            let mut last_true_idx = 0;
502            let mut first_row_idx = 0;
503            let mut seen_false = false;
504
505            for i in 0..row_indices_length {
506                let last_index =
507                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
508                let val = filter_mask.value(i);
509                let is_null = filter_mask.is_null(i);
510
511                if val {
512                    // memoize the first seen matched row
513                    if !seen_true {
514                        last_true_idx = i;
515                    }
516                    seen_true = true;
517                }
518
519                if is_null || val {
520                    mask[i] = Some(true);
521                } else if !is_null && !val && (seen_true || seen_false) {
522                    mask[i] = None;
523                } else {
524                    mask[i] = Some(false);
525                }
526
527                if !is_null && !val {
528                    seen_false = true;
529                }
530
531                if last_index {
532                    // If the left row seen as true its needed to output it once
533                    // To do that we mark all other matches for same row as null to avoid the output
534                    if seen_true {
535                        #[allow(clippy::needless_range_loop)]
536                        for j in first_row_idx..last_true_idx {
537                            mask[j] = None;
538                        }
539                    }
540
541                    seen_true = false;
542                    seen_false = false;
543                    last_true_idx = 0;
544                    first_row_idx = i + 1;
545                }
546            }
547
548            Some(BooleanArray::from(mask))
549        }
550        // Only outer joins needs to keep track of processed rows and apply corrected filter mask
551        _ => None,
552    }
553}
554
impl Stream for SortMergeJoinStream {
    type Item = Result<RecordBatch>;

    /// Drives the join state machine until a batch can be returned, an input
    /// is pending, an error occurs, or the join is finished.
    ///
    /// The loop dispatches on `self.state`:
    /// - `Init` decides which side to advance based on `current_ordering`, and
    ///   for filtered joins flushes already joined rows through the filter.
    /// - `Polling` polls the streamed and/or buffered side and compares keys.
    /// - `JoinOutput` joins the polled data and emits once `batch_size` is hit.
    /// - `Exhausted` flushes whatever is left and then ends the stream.
    ///
    /// NOTE(review): the predicate `self.filter.is_some() && matches!(self.join_type, ...)`
    /// appears three times below with the same nine join types (in different
    /// order); it is a candidate for a shared helper.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // The timer guard is kept alive for the whole call.
        // NOTE(review): presumably it adds the elapsed time to `join_time`
        // when dropped — confirm against the metrics API.
        let join_time = self.join_metrics.join_time().clone();
        let _timer = join_time.timer();
        loop {
            match &self.state {
                SortMergeJoinState::Init => {
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    self.state = if streamed_exhausted && buffered_exhausted {
                        SortMergeJoinState::Exhausted
                    } else {
                        match self.current_ordering {
                            // The streamed side has to advance to its next row.
                            Ordering::Less | Ordering::Equal => {
                                if !streamed_exhausted {
                                    if self.filter.is_some()
                                        && matches!(
                                            self.join_type,
                                            JoinType::Left
                                                | JoinType::LeftSemi
                                                | JoinType::LeftMark
                                                | JoinType::Right
                                                | JoinType::RightSemi
                                                | JoinType::RightMark
                                                | JoinType::LeftAnti
                                                | JoinType::RightAnti
                                                | JoinType::Full
                                        )
                                    {
                                        self.freeze_all()?;

                                        // If the join is filtered and there are joined tuples
                                        // waiting to be filtered
                                        if !self
                                            .staging_output_record_batches
                                            .batches
                                            .is_empty()
                                        {
                                            // Apply filter on joined tuples and get filtered batch
                                            let out_filtered_batch =
                                                self.filter_joined_batch()?;

                                            // Append filtered batch to the output buffer
                                            self.output = concat_batches(
                                                &self.schema(),
                                                [&self.output, &out_filtered_batch],
                                            )?;

                                            // Send to output if the output buffer reached the `batch_size`.
                                            // The state stays `Init`, so the next poll resumes here.
                                            if self.output.num_rows() >= self.batch_size {
                                                let record_batch = std::mem::replace(
                                                    &mut self.output,
                                                    RecordBatch::new_empty(
                                                        out_filtered_batch.schema(),
                                                    ),
                                                );
                                                return Poll::Ready(Some(Ok(
                                                    record_batch,
                                                )));
                                            }
                                        }
                                    }

                                    // Request a fresh streamed row from the `Polling` state.
                                    self.streamed_joined = false;
                                    self.streamed_state = StreamedState::Init;
                                }
                            }
                            // The buffered side has to advance to its next batches.
                            Ordering::Greater => {
                                if !buffered_exhausted {
                                    self.buffered_joined = false;
                                    self.buffered_state = BufferedState::Init;
                                }
                            }
                        }
                        SortMergeJoinState::Polling
                    };
                }
                SortMergeJoinState::Polling => {
                    // Poll the streamed side unless it is already ready or exhausted.
                    if ![StreamedState::Exhausted, StreamedState::Ready]
                        .contains(&self.streamed_state)
                    {
                        match self.poll_streamed_row(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }

                    // Poll the buffered side unless it is already ready or exhausted.
                    if ![BufferedState::Exhausted, BufferedState::Ready]
                        .contains(&self.buffered_state)
                    {
                        match self.poll_buffered_batches(cx)? {
                            Poll::Ready(_) => {}
                            Poll::Pending => return Poll::Pending,
                        }
                    }
                    let streamed_exhausted =
                        self.streamed_state == StreamedState::Exhausted;
                    let buffered_exhausted =
                        self.buffered_state == BufferedState::Exhausted;
                    if streamed_exhausted && buffered_exhausted {
                        self.state = SortMergeJoinState::Exhausted;
                        continue;
                    }
                    self.current_ordering = self.compare_streamed_buffered()?;
                    self.state = SortMergeJoinState::JoinOutput;
                }
                SortMergeJoinState::JoinOutput => {
                    self.join_partial()?;

                    if self.output_size < self.batch_size {
                        // Not enough rows staged yet: return to `Init` once the buffered
                        // data has been scanned completely, otherwise stay in this state.
                        if self.buffered_data.scanning_finished() {
                            self.buffered_data.scanning_reset();
                            self.state = SortMergeJoinState::Init;
                        }
                    } else {
                        self.freeze_all()?;
                        if !self.staging_output_record_batches.batches.is_empty() {
                            let record_batch = self.output_record_batch_and_reset()?;
                            // Non-filtered joins output whenever the target output batch size
                            // is hit. Filtered joins have to output in a later phase
                            // because the target output batch size can be hit in the middle of
                            // filtering, leaving the filtering incomplete and causing
                            // correctness issues.
                            if self.filter.is_some()
                                && matches!(
                                    self.join_type,
                                    JoinType::Left
                                        | JoinType::LeftSemi
                                        | JoinType::Right
                                        | JoinType::RightSemi
                                        | JoinType::LeftAnti
                                        | JoinType::RightAnti
                                        | JoinType::LeftMark
                                        | JoinType::RightMark
                                        | JoinType::Full
                                )
                            {
                                // `record_batch` is dropped unused on this path.
                                // NOTE(review): presumably `output_record_batch_and_reset`
                                // keeps the staged batches for filtered joins so they are
                                // filtered later — confirm.
                                continue;
                            }

                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                        // NOTE(review): `Poll::Pending` is returned here without polling an
                        // input or registering `cx`'s waker on this path, so nothing visible
                        // here schedules another poll. Confirm whether this branch is
                        // reachable and whether a wake-up is needed.
                        return Poll::Pending;
                    }
                }
                SortMergeJoinState::Exhausted => {
                    self.freeze_all()?;

                    // If there is still something not processed
                    if !self.staging_output_record_batches.batches.is_empty() {
                        if self.filter.is_some()
                            && matches!(
                                self.join_type,
                                JoinType::Left
                                    | JoinType::LeftSemi
                                    | JoinType::Right
                                    | JoinType::RightSemi
                                    | JoinType::LeftAnti
                                    | JoinType::RightAnti
                                    | JoinType::Full
                                    | JoinType::LeftMark
                                    | JoinType::RightMark
                            )
                        {
                            // The filtered batch is returned directly; rows still held in
                            // `self.output` are returned by a later poll (branch below).
                            let record_batch = self.filter_joined_batch()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        } else {
                            let record_batch = self.output_record_batch_and_reset()?;
                            return Poll::Ready(Some(Ok(record_batch)));
                        }
                    } else if self.output.num_rows() > 0 {
                        // Rows were processed but not output yet because the buffer did not
                        // reach the batch size before
                        let schema = self.output.schema();
                        let record_batch = std::mem::replace(
                            &mut self.output,
                            RecordBatch::new_empty(schema),
                        );
                        return Poll::Ready(Some(Ok(record_batch)));
                    } else {
                        // Nothing staged and nothing buffered: the stream is finished.
                        return Poll::Ready(None);
                    }
                }
            }
        }
    }
}
748
749impl SortMergeJoinStream {
750    #[allow(clippy::too_many_arguments)]
751    pub fn try_new(
752        // Configured via `datafusion.execution.spill_compression`.
753        spill_compression: SpillCompression,
754        schema: SchemaRef,
755        sort_options: Vec<SortOptions>,
756        null_equality: NullEquality,
757        streamed: SendableRecordBatchStream,
758        buffered: SendableRecordBatchStream,
759        on_streamed: Vec<Arc<dyn PhysicalExpr>>,
760        on_buffered: Vec<Arc<dyn PhysicalExpr>>,
761        filter: Option<JoinFilter>,
762        join_type: JoinType,
763        batch_size: usize,
764        join_metrics: SortMergeJoinMetrics,
765        reservation: MemoryReservation,
766        runtime_env: Arc<RuntimeEnv>,
767    ) -> Result<Self> {
768        let streamed_schema = streamed.schema();
769        let buffered_schema = buffered.schema();
770        let spill_manager = SpillManager::new(
771            Arc::clone(&runtime_env),
772            join_metrics.spill_metrics().clone(),
773            Arc::clone(&buffered_schema),
774        )
775        .with_compression_type(spill_compression);
776        Ok(Self {
777            state: SortMergeJoinState::Init,
778            sort_options,
779            null_equality,
780            schema: Arc::clone(&schema),
781            streamed_schema: Arc::clone(&streamed_schema),
782            buffered_schema,
783            streamed,
784            buffered,
785            streamed_batch: StreamedBatch::new_empty(streamed_schema),
786            buffered_data: BufferedData::default(),
787            streamed_joined: false,
788            buffered_joined: false,
789            streamed_state: StreamedState::Init,
790            buffered_state: BufferedState::Init,
791            current_ordering: Ordering::Equal,
792            on_streamed,
793            on_buffered,
794            filter,
795            staging_output_record_batches: JoinedRecordBatches {
796                batches: vec![],
797                filter_mask: BooleanBuilder::new(),
798                row_indices: UInt64Builder::new(),
799                batch_ids: vec![],
800            },
801            output: RecordBatch::new_empty(schema),
802            output_size: 0,
803            batch_size,
804            join_type,
805            join_metrics,
806            reservation,
807            runtime_env,
808            spill_manager,
809            streamed_batch_counter: AtomicUsize::new(0),
810        })
811    }
812
813    /// Poll next streamed row
814    fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
815        loop {
816            match &self.streamed_state {
817                StreamedState::Init => {
818                    if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
819                    {
820                        self.streamed_batch.idx += 1;
821                        self.streamed_state = StreamedState::Ready;
822                        return Poll::Ready(Some(Ok(())));
823                    } else {
824                        self.streamed_state = StreamedState::Polling;
825                    }
826                }
827                StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
828                    Poll::Pending => {
829                        return Poll::Pending;
830                    }
831                    Poll::Ready(None) => {
832                        self.streamed_state = StreamedState::Exhausted;
833                    }
834                    Poll::Ready(Some(batch)) => {
835                        if batch.num_rows() > 0 {
836                            self.freeze_streamed()?;
837                            self.join_metrics.input_batches().add(1);
838                            self.join_metrics.input_rows().add(batch.num_rows());
839                            self.streamed_batch =
840                                StreamedBatch::new(batch, &self.on_streamed);
841                            // Every incoming streaming batch should have its unique id
842                            // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation
843                            self.streamed_batch_counter
844                                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
845                            self.streamed_state = StreamedState::Ready;
846                        }
847                    }
848                },
849                StreamedState::Ready => {
850                    return Poll::Ready(Some(Ok(())));
851                }
852                StreamedState::Exhausted => {
853                    return Poll::Ready(None);
854                }
855            }
856        }
857    }
858
859    fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
860        // Shrink memory usage for in-memory batches only
861        if let BufferedBatchState::InMemory(_) = buffered_batch.batch {
862            self.reservation
863                .try_shrink(buffered_batch.size_estimation)?;
864        }
865        Ok(())
866    }
867
868    fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
869        match self.reservation.try_grow(buffered_batch.size_estimation) {
870            Ok(_) => {
871                self.join_metrics
872                    .peak_mem_used()
873                    .set_max(self.reservation.size());
874                Ok(())
875            }
876            Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
877                // Spill buffered batch to disk
878
879                match buffered_batch.batch {
880                    BufferedBatchState::InMemory(batch) => {
881                        let spill_file = self
882                            .spill_manager
883                            .spill_record_batch_and_finish(
884                                &[batch],
885                                "sort_merge_join_buffered_spill",
886                            )?
887                            .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled
888
889                        buffered_batch.batch = BufferedBatchState::Spilled(spill_file);
890                        Ok(())
891                    }
892                    _ => internal_err!("Buffered batch has empty body"),
893                }
894            }
895            Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
896        }?;
897
898        self.buffered_data.batches.push_back(buffered_batch);
899        Ok(())
900    }
901
902    /// Poll next buffered batches
903    fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
904        loop {
905            match &self.buffered_state {
906                BufferedState::Init => {
907                    // pop previous buffered batches
908                    while !self.buffered_data.batches.is_empty() {
909                        let head_batch = self.buffered_data.head_batch();
910                        // If the head batch is fully processed, dequeue it and produce output of it.
911                        if head_batch.range.end == head_batch.num_rows {
912                            self.freeze_dequeuing_buffered()?;
913                            if let Some(mut buffered_batch) =
914                                self.buffered_data.batches.pop_front()
915                            {
916                                self.produce_buffered_not_matched(&mut buffered_batch)?;
917                                self.free_reservation(buffered_batch)?;
918                            }
919                        } else {
920                            // If the head batch is not fully processed, break the loop.
921                            // Streamed batch will be joined with the head batch in the next step.
922                            break;
923                        }
924                    }
925                    if self.buffered_data.batches.is_empty() {
926                        self.buffered_state = BufferedState::PollingFirst;
927                    } else {
928                        let tail_batch = self.buffered_data.tail_batch_mut();
929                        tail_batch.range.start = tail_batch.range.end;
930                        tail_batch.range.end += 1;
931                        self.buffered_state = BufferedState::PollingRest;
932                    }
933                }
934                BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
935                    Poll::Pending => {
936                        return Poll::Pending;
937                    }
938                    Poll::Ready(None) => {
939                        self.buffered_state = BufferedState::Exhausted;
940                        return Poll::Ready(None);
941                    }
942                    Poll::Ready(Some(batch)) => {
943                        self.join_metrics.input_batches().add(1);
944                        self.join_metrics.input_rows().add(batch.num_rows());
945
946                        if batch.num_rows() > 0 {
947                            let buffered_batch =
948                                BufferedBatch::new(batch, 0..1, &self.on_buffered);
949
950                            self.allocate_reservation(buffered_batch)?;
951                            self.buffered_state = BufferedState::PollingRest;
952                        }
953                    }
954                },
955                BufferedState::PollingRest => {
956                    if self.buffered_data.tail_batch().range.end
957                        < self.buffered_data.tail_batch().num_rows
958                    {
959                        while self.buffered_data.tail_batch().range.end
960                            < self.buffered_data.tail_batch().num_rows
961                        {
962                            if is_join_arrays_equal(
963                                &self.buffered_data.head_batch().join_arrays,
964                                self.buffered_data.head_batch().range.start,
965                                &self.buffered_data.tail_batch().join_arrays,
966                                self.buffered_data.tail_batch().range.end,
967                            )? {
968                                self.buffered_data.tail_batch_mut().range.end += 1;
969                            } else {
970                                self.buffered_state = BufferedState::Ready;
971                                return Poll::Ready(Some(Ok(())));
972                            }
973                        }
974                    } else {
975                        match self.buffered.poll_next_unpin(cx)? {
976                            Poll::Pending => {
977                                return Poll::Pending;
978                            }
979                            Poll::Ready(None) => {
980                                self.buffered_state = BufferedState::Ready;
981                            }
982                            Poll::Ready(Some(batch)) => {
983                                // Polling batches coming concurrently as multiple partitions
984                                self.join_metrics.input_batches().add(1);
985                                self.join_metrics.input_rows().add(batch.num_rows());
986                                if batch.num_rows() > 0 {
987                                    let buffered_batch = BufferedBatch::new(
988                                        batch,
989                                        0..0,
990                                        &self.on_buffered,
991                                    );
992                                    self.allocate_reservation(buffered_batch)?;
993                                }
994                            }
995                        }
996                    }
997                }
998                BufferedState::Ready => {
999                    return Poll::Ready(Some(Ok(())));
1000                }
1001                BufferedState::Exhausted => {
1002                    return Poll::Ready(None);
1003                }
1004            }
1005        }
1006    }
1007
1008    /// Get comparison result of streamed row and buffered batches
1009    fn compare_streamed_buffered(&self) -> Result<Ordering> {
1010        if self.streamed_state == StreamedState::Exhausted {
1011            return Ok(Ordering::Greater);
1012        }
1013        if !self.buffered_data.has_buffered_rows() {
1014            return Ok(Ordering::Less);
1015        }
1016
1017        compare_join_arrays(
1018            &self.streamed_batch.join_arrays,
1019            self.streamed_batch.idx,
1020            &self.buffered_data.head_batch().join_arrays,
1021            self.buffered_data.head_batch().range.start,
1022            &self.sort_options,
1023            self.null_equality,
1024        )
1025    }
1026
1027    /// Produce join and fill output buffer until reaching target batch size
1028    /// or the join is finished
1029    fn join_partial(&mut self) -> Result<()> {
1030        // Whether to join streamed rows
1031        let mut join_streamed = false;
1032        // Whether to join buffered rows
1033        let mut join_buffered = false;
1034        // For Mark join we store a dummy id to indicate the row has a match
1035        let mut mark_row_as_match = false;
1036
1037        // determine whether we need to join streamed/buffered rows
1038        match self.current_ordering {
1039            Ordering::Less => {
1040                if matches!(
1041                    self.join_type,
1042                    JoinType::Left
1043                        | JoinType::Right
1044                        | JoinType::Full
1045                        | JoinType::LeftAnti
1046                        | JoinType::RightAnti
1047                        | JoinType::LeftMark
1048                        | JoinType::RightMark
1049                ) {
1050                    join_streamed = !self.streamed_joined;
1051                }
1052            }
1053            Ordering::Equal => {
1054                if matches!(
1055                    self.join_type,
1056                    JoinType::LeftSemi
1057                        | JoinType::LeftMark
1058                        | JoinType::RightSemi
1059                        | JoinType::RightMark
1060                ) {
1061                    mark_row_as_match = matches!(
1062                        self.join_type,
1063                        JoinType::LeftMark | JoinType::RightMark
1064                    );
1065                    // if the join filter is specified then its needed to output the streamed index
1066                    // only if it has not been emitted before
1067                    // the `join_filter_matched_idxs` keeps track on if streamed index has a successful
1068                    // filter match and prevents the same index to go into output more than once
1069                    if self.filter.is_some() {
1070                        join_streamed = !self
1071                            .streamed_batch
1072                            .join_filter_matched_idxs
1073                            .contains(&(self.streamed_batch.idx as u64))
1074                            && !self.streamed_joined;
1075                        // if the join filter specified there can be references to buffered columns
1076                        // so buffered columns are needed to access them
1077                        join_buffered = join_streamed;
1078                    } else {
1079                        join_streamed = !self.streamed_joined;
1080                    }
1081                }
1082                if matches!(
1083                    self.join_type,
1084                    JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
1085                ) {
1086                    join_streamed = true;
1087                    join_buffered = true;
1088                };
1089
1090                if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti)
1091                    && self.filter.is_some()
1092                {
1093                    join_streamed = !self.streamed_joined;
1094                    join_buffered = join_streamed;
1095                }
1096            }
1097            Ordering::Greater => {
1098                if matches!(self.join_type, JoinType::Full) {
1099                    join_buffered = !self.buffered_joined;
1100                };
1101            }
1102        }
1103        if !join_streamed && !join_buffered {
1104            // no joined data
1105            self.buffered_data.scanning_finish();
1106            return Ok(());
1107        }
1108
1109        if join_buffered {
1110            // joining streamed/nulls and buffered
1111            while !self.buffered_data.scanning_finished()
1112                && self.output_size < self.batch_size
1113            {
1114                let scanning_idx = self.buffered_data.scanning_idx();
1115                if join_streamed {
1116                    // Join streamed row and buffered row
1117                    self.streamed_batch.append_output_pair(
1118                        Some(self.buffered_data.scanning_batch_idx),
1119                        Some(scanning_idx),
1120                    );
1121                } else {
1122                    // Join nulls and buffered row for FULL join
1123                    self.buffered_data
1124                        .scanning_batch_mut()
1125                        .null_joined
1126                        .push(scanning_idx);
1127                }
1128                self.output_size += 1;
1129                self.buffered_data.scanning_advance();
1130
1131                if self.buffered_data.scanning_finished() {
1132                    self.streamed_joined = join_streamed;
1133                    self.buffered_joined = true;
1134                }
1135            }
1136        } else {
1137            // joining streamed and nulls
1138            let scanning_batch_idx = if self.buffered_data.scanning_finished() {
1139                None
1140            } else {
1141                Some(self.buffered_data.scanning_batch_idx)
1142            };
1143            // For Mark join we store a dummy id to indicate the row has a match
1144            let scanning_idx = mark_row_as_match.then_some(0);
1145
1146            self.streamed_batch
1147                .append_output_pair(scanning_batch_idx, scanning_idx);
1148            self.output_size += 1;
1149            self.buffered_data.scanning_finish();
1150            self.streamed_joined = true;
1151        }
1152        Ok(())
1153    }
1154
1155    fn freeze_all(&mut self) -> Result<()> {
1156        self.freeze_buffered(self.buffered_data.batches.len())?;
1157        self.freeze_streamed()?;
1158        Ok(())
1159    }
1160
1161    // Produces and stages record batches to ensure dequeued buffered batch
1162    // no longer needed:
1163    //   1. freezes all indices joined to streamed side
1164    //   2. freezes NULLs joined to dequeued buffered batch to "release" it
1165    fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
1166        self.freeze_streamed()?;
1167        // Only freeze and produce the first batch in buffered_data as the batch is fully processed
1168        self.freeze_buffered(1)?;
1169        Ok(())
1170    }
1171
1172    // Produces and stages record batch from buffered indices with corresponding
1173    // NULLs on streamed side.
1174    //
1175    // Applicable only in case of Full join.
1176    //
1177    fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> {
1178        if !matches!(self.join_type, JoinType::Full) {
1179            return Ok(());
1180        }
1181        for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
1182            let buffered_indices = UInt64Array::from_iter_values(
1183                buffered_batch.null_joined.iter().map(|&index| index as u64),
1184            );
1185            if let Some(record_batch) = produce_buffered_null_batch(
1186                &self.schema,
1187                &self.streamed_schema,
1188                &buffered_indices,
1189                buffered_batch,
1190            )? {
1191                let num_rows = record_batch.num_rows();
1192                self.staging_output_record_batches
1193                    .filter_mask
1194                    .append_nulls(num_rows);
1195                self.staging_output_record_batches
1196                    .row_indices
1197                    .append_nulls(num_rows);
1198                self.staging_output_record_batches.batch_ids.resize(
1199                    self.staging_output_record_batches.batch_ids.len() + num_rows,
1200                    0,
1201                );
1202
1203                self.staging_output_record_batches
1204                    .batches
1205                    .push(record_batch);
1206            }
1207            buffered_batch.null_joined.clear();
1208        }
1209        Ok(())
1210    }
1211
1212    fn produce_buffered_not_matched(
1213        &mut self,
1214        buffered_batch: &mut BufferedBatch,
1215    ) -> Result<()> {
1216        if !matches!(self.join_type, JoinType::Full) {
1217            return Ok(());
1218        }
1219
1220        // For buffered row which is joined with streamed side rows but all joined rows
1221        // don't satisfy the join filter
1222        let not_matched_buffered_indices = buffered_batch
1223            .join_filter_not_matched_map
1224            .iter()
1225            .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
1226            .collect::<Vec<_>>();
1227
1228        let buffered_indices =
1229            UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied());
1230
1231        if let Some(record_batch) = produce_buffered_null_batch(
1232            &self.schema,
1233            &self.streamed_schema,
1234            &buffered_indices,
1235            buffered_batch,
1236        )? {
1237            let num_rows = record_batch.num_rows();
1238
1239            self.staging_output_record_batches
1240                .filter_mask
1241                .append_nulls(num_rows);
1242            self.staging_output_record_batches
1243                .row_indices
1244                .append_nulls(num_rows);
1245            self.staging_output_record_batches.batch_ids.resize(
1246                self.staging_output_record_batches.batch_ids.len() + num_rows,
1247                0,
1248            );
1249            self.staging_output_record_batches
1250                .batches
1251                .push(record_batch);
1252        }
1253        buffered_batch.join_filter_not_matched_map.clear();
1254
1255        Ok(())
1256    }
1257
1258    // Produces and stages record batch for all output indices found
1259    // for current streamed batch and clears staged output indices.
1260    fn freeze_streamed(&mut self) -> Result<()> {
1261        for chunk in self.streamed_batch.output_indices.iter_mut() {
1262            // The row indices of joined streamed batch
1263            let left_indices = chunk.streamed_indices.finish();
1264
1265            if left_indices.is_empty() {
1266                continue;
1267            }
1268
1269            let mut left_columns = self
1270                .streamed_batch
1271                .batch
1272                .columns()
1273                .iter()
1274                .map(|column| take(column, &left_indices, None))
1275                .collect::<Result<Vec<_>, ArrowError>>()?;
1276
1277            // The row indices of joined buffered batch
1278            let right_indices: UInt64Array = chunk.buffered_indices.finish();
1279            let mut right_columns =
1280                if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) {
1281                    vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
1282                } else if matches!(
1283                    self.join_type,
1284                    JoinType::LeftSemi
1285                        | JoinType::LeftAnti
1286                        | JoinType::RightAnti
1287                        | JoinType::RightSemi
1288                ) {
1289                    vec![]
1290                } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1291                    fetch_right_columns_by_idxs(
1292                        &self.buffered_data,
1293                        buffered_idx,
1294                        &right_indices,
1295                    )?
1296                } else {
1297                    // If buffered batch none, meaning it is null joined batch.
1298                    // We need to create null arrays for buffered columns to join with streamed rows.
1299                    create_unmatched_columns(
1300                        self.join_type,
1301                        &self.buffered_schema,
1302                        right_indices.len(),
1303                    )
1304                };
1305
1306            // Prepare the columns we apply join filter on later.
1307            // Only for joined rows between streamed and buffered.
1308            let filter_columns = if chunk.buffered_batch_idx.is_some() {
1309                if !matches!(self.join_type, JoinType::Right) {
1310                    if matches!(
1311                        self.join_type,
1312                        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
1313                    ) {
1314                        let right_cols = fetch_right_columns_by_idxs(
1315                            &self.buffered_data,
1316                            chunk.buffered_batch_idx.unwrap(),
1317                            &right_indices,
1318                        )?;
1319
1320                        get_filter_column(&self.filter, &left_columns, &right_cols)
1321                    } else if matches!(
1322                        self.join_type,
1323                        JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
1324                    ) {
1325                        let right_cols = fetch_right_columns_by_idxs(
1326                            &self.buffered_data,
1327                            chunk.buffered_batch_idx.unwrap(),
1328                            &right_indices,
1329                        )?;
1330
1331                        get_filter_column(&self.filter, &right_cols, &left_columns)
1332                    } else {
1333                        get_filter_column(&self.filter, &left_columns, &right_columns)
1334                    }
1335                } else {
1336                    get_filter_column(&self.filter, &right_columns, &left_columns)
1337                }
1338            } else {
1339                // This chunk is totally for null joined rows (outer join), we don't need to apply join filter.
1340                // Any join filter applied only on either streamed or buffered side will be pushed already.
1341                vec![]
1342            };
1343
1344            let columns = if !matches!(self.join_type, JoinType::Right) {
1345                left_columns.extend(right_columns);
1346                left_columns
1347            } else {
1348                right_columns.extend(left_columns);
1349                right_columns
1350            };
1351
1352            let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
1353            // Apply join filter if any
1354            if !filter_columns.is_empty() {
1355                if let Some(f) = &self.filter {
1356                    // Construct batch with only filter columns
1357                    let filter_batch =
1358                        RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;
1359
1360                    let filter_result = f
1361                        .expression()
1362                        .evaluate(&filter_batch)?
1363                        .into_array(filter_batch.num_rows())?;
1364
1365                    // The boolean selection mask of the join filter result
1366                    let pre_mask =
1367                        datafusion_common::cast::as_boolean_array(&filter_result)?;
1368
1369                    // If there are nulls in join filter result, exclude them from selecting
1370                    // the rows to output.
1371                    let mask = if pre_mask.null_count() > 0 {
1372                        compute::prep_null_mask_filter(
1373                            datafusion_common::cast::as_boolean_array(&filter_result)?,
1374                        )
1375                    } else {
1376                        pre_mask.clone()
1377                    };
1378
1379                    // Push the filtered batch which contains rows passing join filter to the output
1380                    if matches!(
1381                        self.join_type,
1382                        JoinType::Left
1383                            | JoinType::LeftSemi
1384                            | JoinType::Right
1385                            | JoinType::RightSemi
1386                            | JoinType::LeftAnti
1387                            | JoinType::RightAnti
1388                            | JoinType::LeftMark
1389                            | JoinType::RightMark
1390                            | JoinType::Full
1391                    ) {
1392                        self.staging_output_record_batches
1393                            .batches
1394                            .push(output_batch);
1395                    } else {
1396                        let filtered_batch = filter_record_batch(&output_batch, &mask)?;
1397                        self.staging_output_record_batches
1398                            .batches
1399                            .push(filtered_batch);
1400                    }
1401
1402                    if !matches!(self.join_type, JoinType::Full) {
1403                        self.staging_output_record_batches.filter_mask.extend(&mask);
1404                    } else {
1405                        self.staging_output_record_batches
1406                            .filter_mask
1407                            .extend(pre_mask);
1408                    }
1409                    self.staging_output_record_batches
1410                        .row_indices
1411                        .extend(&left_indices);
1412                    self.staging_output_record_batches.batch_ids.resize(
1413                        self.staging_output_record_batches.batch_ids.len()
1414                            + left_indices.len(),
1415                        self.streamed_batch_counter.load(Relaxed),
1416                    );
1417
1418                    // For outer joins, we need to push the null joined rows to the output if
1419                    // all joined rows are failed on the join filter.
1420                    // I.e., if all rows joined from a streamed row are failed with the join filter,
1421                    // we need to join it with nulls as buffered side.
1422                    if matches!(self.join_type, JoinType::Full) {
1423                        let buffered_batch = &mut self.buffered_data.batches
1424                            [chunk.buffered_batch_idx.unwrap()];
1425
1426                        for i in 0..pre_mask.len() {
1427                            // If the buffered row is not joined with streamed side,
1428                            // skip it.
1429                            if right_indices.is_null(i) {
1430                                continue;
1431                            }
1432
1433                            let buffered_index = right_indices.value(i);
1434
1435                            buffered_batch.join_filter_not_matched_map.insert(
1436                                buffered_index,
1437                                *buffered_batch
1438                                    .join_filter_not_matched_map
1439                                    .get(&buffered_index)
1440                                    .unwrap_or(&true)
1441                                    && !pre_mask.value(i),
1442                            );
1443                        }
1444                    }
1445                } else {
1446                    self.staging_output_record_batches
1447                        .batches
1448                        .push(output_batch);
1449                }
1450            } else {
1451                self.staging_output_record_batches
1452                    .batches
1453                    .push(output_batch);
1454            }
1455        }
1456
1457        self.streamed_batch.output_indices.clear();
1458
1459        Ok(())
1460    }
1461
1462    fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
1463        let record_batch =
1464            concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1465        self.join_metrics.output_batches().add(1);
1466        self.join_metrics
1467            .baseline_metrics()
1468            .record_output(record_batch.num_rows());
1469        // If join filter exists, `self.output_size` is not accurate as we don't know the exact
1470        // number of rows in the output record batch. If streamed row joined with buffered rows,
1471        // once join filter is applied, the number of output rows may be more than 1.
1472        // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened
1473        // when the join filter is applied and all rows are filtered out.
1474        if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size {
1475            self.output_size = 0;
1476        } else {
1477            self.output_size -= record_batch.num_rows();
1478        }
1479
1480        if !(self.filter.is_some()
1481            && matches!(
1482                self.join_type,
1483                JoinType::Left
1484                    | JoinType::LeftSemi
1485                    | JoinType::Right
1486                    | JoinType::RightSemi
1487                    | JoinType::LeftAnti
1488                    | JoinType::RightAnti
1489                    | JoinType::LeftMark
1490                    | JoinType::RightMark
1491                    | JoinType::Full
1492            ))
1493        {
1494            self.staging_output_record_batches.batches.clear();
1495        }
1496
1497        Ok(record_batch)
1498    }
1499
1500    fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
1501        let record_batch =
1502            concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1503        let mut out_indices = self.staging_output_record_batches.row_indices.finish();
1504        let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
1505        let mut batch_ids = &self.staging_output_record_batches.batch_ids;
1506        let default_batch_ids = vec![0; record_batch.num_rows()];
1507
1508        // If only nulls come in and indices sizes doesn't match with expected record batch count
1509        // generate missing indices
1510        // Happens for null joined batches for Full Join
1511        if out_indices.null_count() == out_indices.len()
1512            && out_indices.len() != record_batch.num_rows()
1513        {
1514            out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]);
1515            out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]);
1516            batch_ids = &default_batch_ids;
1517        }
1518
1519        if out_mask.is_empty() {
1520            self.staging_output_record_batches.batches.clear();
1521            return Ok(record_batch);
1522        }
1523
1524        let maybe_corrected_mask = get_corrected_filter_mask(
1525            self.join_type,
1526            &out_indices,
1527            batch_ids,
1528            &out_mask,
1529            record_batch.num_rows(),
1530        );
1531
1532        let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask {
1533            filtered_join_mask
1534        } else {
1535            &out_mask
1536        };
1537
1538        self.filter_record_batch_by_join_type(record_batch, corrected_mask)
1539    }
1540
1541    fn filter_record_batch_by_join_type(
1542        &mut self,
1543        record_batch: RecordBatch,
1544        corrected_mask: &BooleanArray,
1545    ) -> Result<RecordBatch> {
1546        let mut filtered_record_batch =
1547            filter_record_batch(&record_batch, corrected_mask)?;
1548        let left_columns_length = self.streamed_schema.fields.len();
1549        let right_columns_length = self.buffered_schema.fields.len();
1550
1551        if matches!(
1552            self.join_type,
1553            JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
1554        ) {
1555            let null_mask = compute::not(corrected_mask)?;
1556            let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
1557
1558            let mut right_columns = create_unmatched_columns(
1559                self.join_type,
1560                &self.buffered_schema,
1561                null_joined_batch.num_rows(),
1562            );
1563
1564            let columns = if !matches!(self.join_type, JoinType::Right) {
1565                let mut left_columns = null_joined_batch
1566                    .columns()
1567                    .iter()
1568                    .take(right_columns_length)
1569                    .cloned()
1570                    .collect::<Vec<_>>();
1571
1572                left_columns.extend(right_columns);
1573                left_columns
1574            } else {
1575                let left_columns = null_joined_batch
1576                    .columns()
1577                    .iter()
1578                    .skip(left_columns_length)
1579                    .cloned()
1580                    .collect::<Vec<_>>();
1581
1582                right_columns.extend(left_columns);
1583                right_columns
1584            };
1585
1586            // Push the streamed/buffered batch joined nulls to the output
1587            let null_joined_streamed_batch =
1588                RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
1589
1590            filtered_record_batch = concat_batches(
1591                &self.schema,
1592                &[filtered_record_batch, null_joined_streamed_batch],
1593            )?;
1594        } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1595            let output_column_indices = (0..left_columns_length).collect::<Vec<_>>();
1596            filtered_record_batch =
1597                filtered_record_batch.project(&output_column_indices)?;
1598        } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) {
1599            let output_column_indices = (0..right_columns_length).collect::<Vec<_>>();
1600            filtered_record_batch =
1601                filtered_record_batch.project(&output_column_indices)?;
1602        } else if matches!(self.join_type, JoinType::Full)
1603            && corrected_mask.false_count() > 0
1604        {
1605            // Find rows which joined by key but Filter predicate evaluated as false
1606            let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
1607            let joined_filter_not_matched_batch =
1608                filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?;
1609
1610            // Add left unmatched rows adding the right side as nulls
1611            let right_null_columns = self
1612                .buffered_schema
1613                .fields()
1614                .iter()
1615                .map(|f| {
1616                    new_null_array(
1617                        f.data_type(),
1618                        joined_filter_not_matched_batch.num_rows(),
1619                    )
1620                })
1621                .collect::<Vec<_>>();
1622
1623            let mut result_joined = joined_filter_not_matched_batch
1624                .columns()
1625                .iter()
1626                .take(left_columns_length)
1627                .cloned()
1628                .collect::<Vec<_>>();
1629
1630            result_joined.extend(right_null_columns);
1631
1632            let left_null_joined_batch =
1633                RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
1634
1635            // Add right unmatched rows adding the left side as nulls
1636            let mut result_joined = self
1637                .streamed_schema
1638                .fields()
1639                .iter()
1640                .map(|f| {
1641                    new_null_array(
1642                        f.data_type(),
1643                        joined_filter_not_matched_batch.num_rows(),
1644                    )
1645                })
1646                .collect::<Vec<_>>();
1647
1648            let right_data = joined_filter_not_matched_batch
1649                .columns()
1650                .iter()
1651                .skip(left_columns_length)
1652                .cloned()
1653                .collect::<Vec<_>>();
1654
1655            result_joined.extend(right_data);
1656
1657            filtered_record_batch = concat_batches(
1658                &self.schema,
1659                &[filtered_record_batch, left_null_joined_batch],
1660            )?;
1661        }
1662
1663        self.staging_output_record_batches.clear();
1664
1665        Ok(filtered_record_batch)
1666    }
1667}
1668
1669fn create_unmatched_columns(
1670    join_type: JoinType,
1671    schema: &SchemaRef,
1672    size: usize,
1673) -> Vec<ArrayRef> {
1674    if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
1675        vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
1676    } else {
1677        schema
1678            .fields()
1679            .iter()
1680            .map(|f| new_null_array(f.data_type(), size))
1681            .collect::<Vec<_>>()
1682    }
1683}
1684
1685/// Gets the arrays which join filters are applied on.
1686fn get_filter_column(
1687    join_filter: &Option<JoinFilter>,
1688    streamed_columns: &[ArrayRef],
1689    buffered_columns: &[ArrayRef],
1690) -> Vec<ArrayRef> {
1691    let mut filter_columns = vec![];
1692
1693    if let Some(f) = join_filter {
1694        let left_columns = f
1695            .column_indices()
1696            .iter()
1697            .filter(|col_index| col_index.side == JoinSide::Left)
1698            .map(|i| Arc::clone(&streamed_columns[i.index]))
1699            .collect::<Vec<_>>();
1700
1701        let right_columns = f
1702            .column_indices()
1703            .iter()
1704            .filter(|col_index| col_index.side == JoinSide::Right)
1705            .map(|i| Arc::clone(&buffered_columns[i.index]))
1706            .collect::<Vec<_>>();
1707
1708        filter_columns.extend(left_columns);
1709        filter_columns.extend(right_columns);
1710    }
1711
1712    filter_columns
1713}
1714
1715fn produce_buffered_null_batch(
1716    schema: &SchemaRef,
1717    streamed_schema: &SchemaRef,
1718    buffered_indices: &PrimitiveArray<UInt64Type>,
1719    buffered_batch: &BufferedBatch,
1720) -> Result<Option<RecordBatch>> {
1721    if buffered_indices.is_empty() {
1722        return Ok(None);
1723    }
1724
1725    // Take buffered (right) columns
1726    let right_columns =
1727        fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?;
1728
1729    // Create null streamed (left) columns
1730    let mut left_columns = streamed_schema
1731        .fields()
1732        .iter()
1733        .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
1734        .collect::<Vec<_>>();
1735
1736    left_columns.extend(right_columns);
1737
1738    Ok(Some(RecordBatch::try_new(
1739        Arc::clone(schema),
1740        left_columns,
1741    )?))
1742}
1743
1744/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices
1745#[inline(always)]
1746fn fetch_right_columns_by_idxs(
1747    buffered_data: &BufferedData,
1748    buffered_batch_idx: usize,
1749    buffered_indices: &UInt64Array,
1750) -> Result<Vec<ArrayRef>> {
1751    fetch_right_columns_from_batch_by_idxs(
1752        &buffered_data.batches[buffered_batch_idx],
1753        buffered_indices,
1754    )
1755}
1756
1757#[inline(always)]
1758fn fetch_right_columns_from_batch_by_idxs(
1759    buffered_batch: &BufferedBatch,
1760    buffered_indices: &UInt64Array,
1761) -> Result<Vec<ArrayRef>> {
1762    match &buffered_batch.batch {
1763        // In memory batch
1764        BufferedBatchState::InMemory(batch) => Ok(batch
1765            .columns()
1766            .iter()
1767            .map(|column| take(column, &buffered_indices, None))
1768            .collect::<Result<Vec<_>, ArrowError>>()
1769            .map_err(Into::<DataFusionError>::into)?),
1770        // If the batch was spilled to disk, less likely
1771        BufferedBatchState::Spilled(spill_file) => {
1772            let mut buffered_cols: Vec<ArrayRef> =
1773                Vec::with_capacity(buffered_indices.len());
1774
1775            let file = BufReader::new(File::open(spill_file.path())?);
1776            let reader = StreamReader::try_new(file, None)?;
1777
1778            for batch in reader {
1779                batch?.columns().iter().for_each(|column| {
1780                    buffered_cols.extend(take(column, &buffered_indices, None))
1781                });
1782            }
1783
1784            Ok(buffered_cols)
1785        }
1786    }
1787}
1788
/// Buffered data contains all buffered batches with one unique join key,
/// together with a scan cursor (`scanning_batch_idx`, `scanning_offset`) that
/// walks over the rows of those batches.
#[derive(Debug, Default)]
pub(super) struct BufferedData {
    /// Buffered batches with the same key
    pub batches: VecDeque<BufferedBatch>,
    /// Index into `batches` of the batch currently being scanned, used in
    /// join_partial(). Equals `batches.len()` once scanning has finished.
    pub scanning_batch_idx: usize,
    /// Offset of the current row within the scanned batch, relative to that
    /// batch's `range.start`; used in join_partial()
    pub scanning_offset: usize,
}
1799
1800impl BufferedData {
1801    pub fn head_batch(&self) -> &BufferedBatch {
1802        self.batches.front().unwrap()
1803    }
1804
1805    pub fn tail_batch(&self) -> &BufferedBatch {
1806        self.batches.back().unwrap()
1807    }
1808
1809    pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
1810        self.batches.back_mut().unwrap()
1811    }
1812
1813    pub fn has_buffered_rows(&self) -> bool {
1814        self.batches.iter().any(|batch| !batch.range.is_empty())
1815    }
1816
1817    pub fn scanning_reset(&mut self) {
1818        self.scanning_batch_idx = 0;
1819        self.scanning_offset = 0;
1820    }
1821
1822    pub fn scanning_advance(&mut self) {
1823        self.scanning_offset += 1;
1824        while !self.scanning_finished() && self.scanning_batch_finished() {
1825            self.scanning_batch_idx += 1;
1826            self.scanning_offset = 0;
1827        }
1828    }
1829
1830    pub fn scanning_batch(&self) -> &BufferedBatch {
1831        &self.batches[self.scanning_batch_idx]
1832    }
1833
1834    pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
1835        &mut self.batches[self.scanning_batch_idx]
1836    }
1837
1838    pub fn scanning_idx(&self) -> usize {
1839        self.scanning_batch().range.start + self.scanning_offset
1840    }
1841
1842    pub fn scanning_batch_finished(&self) -> bool {
1843        self.scanning_offset == self.scanning_batch().range.len()
1844    }
1845
1846    pub fn scanning_finished(&self) -> bool {
1847        self.scanning_batch_idx == self.batches.len()
1848    }
1849
1850    pub fn scanning_finish(&mut self) {
1851        self.scanning_batch_idx = self.batches.len();
1852        self.scanning_offset = 0;
1853    }
1854}
1855
1856/// Get join array refs of given batch and join columns
1857fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> {
1858    on_column
1859        .iter()
1860        .map(|c| {
1861            let num_rows = batch.num_rows();
1862            let c = c.evaluate(batch).unwrap();
1863            c.into_array(num_rows).unwrap()
1864        })
1865        .collect()
1866}
1867
1868/// A faster version of compare_join_arrays() that only output whether
1869/// the given two rows are equal
1870fn is_join_arrays_equal(
1871    left_arrays: &[ArrayRef],
1872    left: usize,
1873    right_arrays: &[ArrayRef],
1874    right: usize,
1875) -> Result<bool> {
1876    let mut is_equal = true;
1877    for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
1878        macro_rules! compare_value {
1879            ($T:ty) => {{
1880                match (left_array.is_null(left), right_array.is_null(right)) {
1881                    (false, false) => {
1882                        let left_array =
1883                            left_array.as_any().downcast_ref::<$T>().unwrap();
1884                        let right_array =
1885                            right_array.as_any().downcast_ref::<$T>().unwrap();
1886                        if left_array.value(left) != right_array.value(right) {
1887                            is_equal = false;
1888                        }
1889                    }
1890                    (true, false) => is_equal = false,
1891                    (false, true) => is_equal = false,
1892                    _ => {}
1893                }
1894            }};
1895        }
1896
1897        match left_array.data_type() {
1898            DataType::Null => {}
1899            DataType::Boolean => compare_value!(BooleanArray),
1900            DataType::Int8 => compare_value!(Int8Array),
1901            DataType::Int16 => compare_value!(Int16Array),
1902            DataType::Int32 => compare_value!(Int32Array),
1903            DataType::Int64 => compare_value!(Int64Array),
1904            DataType::UInt8 => compare_value!(UInt8Array),
1905            DataType::UInt16 => compare_value!(UInt16Array),
1906            DataType::UInt32 => compare_value!(UInt32Array),
1907            DataType::UInt64 => compare_value!(UInt64Array),
1908            DataType::Float32 => compare_value!(Float32Array),
1909            DataType::Float64 => compare_value!(Float64Array),
1910            DataType::Utf8 => compare_value!(StringArray),
1911            DataType::Utf8View => compare_value!(StringViewArray),
1912            DataType::LargeUtf8 => compare_value!(LargeStringArray),
1913            DataType::Binary => compare_value!(BinaryArray),
1914            DataType::BinaryView => compare_value!(BinaryViewArray),
1915            DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
1916            DataType::LargeBinary => compare_value!(LargeBinaryArray),
1917            DataType::Decimal32(..) => compare_value!(Decimal32Array),
1918            DataType::Decimal64(..) => compare_value!(Decimal64Array),
1919            DataType::Decimal128(..) => compare_value!(Decimal128Array),
1920            DataType::Decimal256(..) => compare_value!(Decimal256Array),
1921            DataType::Timestamp(time_unit, None) => match time_unit {
1922                TimeUnit::Second => compare_value!(TimestampSecondArray),
1923                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
1924                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
1925                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
1926            },
1927            DataType::Date32 => compare_value!(Date32Array),
1928            DataType::Date64 => compare_value!(Date64Array),
1929            dt => {
1930                return not_impl_err!(
1931                    "Unsupported data type in sort merge join comparator: {}",
1932                    dt
1933                );
1934            }
1935        }
1936        if !is_equal {
1937            return Ok(false);
1938        }
1939    }
1940    Ok(true)
1941}