datafusion_physical_plan/joins/piecewise_merge_join/classic_join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner)
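//!
//! The key property exploited here: once both sides are sorted on the compare key, every
//! streamed row matches a contiguous suffix of the buffered rows. A rough sketch for the
//! predicate `buffered.key < streamed.key` with both sides sorted descending:
//!
//! ```text
//! buffered keys: 3 2 1        streamed keys: 4 3 2
//! streamed 4  ->  buffered suffix [3, 2, 1]
//! streamed 3  ->  buffered suffix [2, 1]
//! streamed 2  ->  buffered suffix [1]
//! ```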
19
20use arrow::array::{new_null_array, Array, PrimitiveBuilder};
21use arrow::compute::{take, BatchCoalescer};
22use arrow::datatypes::UInt32Type;
23use arrow::{
24    array::{ArrayRef, RecordBatch, UInt32Array},
25    compute::{sort_to_indices, take_record_batch},
26};
27use arrow_schema::{Schema, SchemaRef, SortOptions};
28use datafusion_common::NullEquality;
29use datafusion_common::{internal_err, Result};
30use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream};
31use datafusion_expr::{JoinType, Operator};
32use datafusion_physical_expr::PhysicalExprRef;
33use futures::{Stream, StreamExt};
34use std::{cmp::Ordering, task::ready};
35use std::{sync::Arc, task::Poll};
36
37use crate::handle_state;
38use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState};
39use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final;
40use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap};
41use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult};
42
43pub(super) enum PiecewiseMergeJoinStreamState {
44    WaitBufferedSide,
45    FetchStreamBatch,
46    ProcessStreamBatch(SortedStreamBatch),
47    ProcessUnmatched,
48    Completed,
49}
50
51impl PiecewiseMergeJoinStreamState {
52    // Grab mutable reference to the current stream batch
53    fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> {
54        match self {
55            PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state),
56            _ => internal_err!("Expected streamed batch in StreamBatch"),
57        }
58    }
59}
60
61/// The stream side incoming batch with required sort order.
62///
63/// Note the compare key in the join predicate might include expressions on the original
64/// columns, so we store the evaluated compare key separately.
65/// e.g. For join predicate `buffer.v1 < (stream.v1 + 1)`, the `compare_key_values` field stores
66/// the evaluated `stream.v1 + 1` array.
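///
/// Note: the join has a single range predicate, so `fetch_stream_batch` currently stores a
/// single evaluated key array here.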
67pub(super) struct SortedStreamBatch {
68    pub batch: RecordBatch,
69    compare_key_values: Vec<ArrayRef>,
70}
71
72impl SortedStreamBatch {
73    #[allow(dead_code)]
74    fn new(batch: RecordBatch, compare_key_values: Vec<ArrayRef>) -> Self {
75        Self {
76            batch,
77            compare_key_values,
78        }
79    }
80
81    fn compare_key_values(&self) -> &Vec<ArrayRef> {
82        &self.compare_key_values
83    }
84}
85
86pub(super) struct ClassicPWMJStream {
87    // Output schema of the `PiecewiseMergeJoin`
88    pub schema: Arc<Schema>,
89
    // Physical expression that is evaluated on the streamed side.
    // We do not need `on_buffered` because it is already evaluated
    // when the buffered side is built, before this stream is created.
94    pub on_streamed: PhysicalExprRef,
95    // Type of join
96    pub join_type: JoinType,
97    // Comparison operator
98    pub operator: Operator,
99    // Streamed batch
100    pub streamed: SendableRecordBatchStream,
101    // Streamed schema
102    streamed_schema: SchemaRef,
103    // Buffered side data
104    buffered_side: BufferedSide,
105    // Tracks the state of the `PiecewiseMergeJoin`
106    state: PiecewiseMergeJoinStreamState,
107    // Sort option for streamed side (specifies whether
108    // the sort is ascending or descending)
109    sort_option: SortOptions,
110    // Metrics for build + probe joins
111    join_metrics: BuildProbeJoinMetrics,
112    // Tracking incremental state for emitting record batches
113    batch_process_state: BatchProcessState,
114}
115
116impl RecordBatchStream for ClassicPWMJStream {
117    fn schema(&self) -> SchemaRef {
118        Arc::clone(&self.schema)
119    }
120}
121
// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
// `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`.
//
// Classic Joins
//  1. `WaitBufferedSide` - Load the buffered side data into memory.
//  2. `FetchStreamBatch` - Fetch + sort incoming stream batches. Once the stream side is
//     exhausted, the state switches to `Completed` if other partitions are still running;
//     only the last partition to finish switches to `ProcessUnmatched`.
//  3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data.
//  4. `ProcessUnmatched` - Right and Inner joins go straight to `Completed`; Left and Full
//     joins first emit the buffered rows that never matched.
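//
// Sketched as transitions (matching the code below):
//   WaitBufferedSide -> FetchStreamBatch -> ProcessStreamBatch -> FetchStreamBatch -> ...
//   FetchStreamBatch (stream exhausted, last partition)    -> ProcessUnmatched -> Completed
//   FetchStreamBatch (stream exhausted, others still open) -> Completed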
133impl ClassicPWMJStream {
    // Creates a new `ClassicPWMJStream` instance
135    #[allow(clippy::too_many_arguments)]
136    pub fn try_new(
137        schema: Arc<Schema>,
138        on_streamed: PhysicalExprRef,
139        join_type: JoinType,
140        operator: Operator,
141        streamed: SendableRecordBatchStream,
142        buffered_side: BufferedSide,
143        state: PiecewiseMergeJoinStreamState,
144        sort_option: SortOptions,
145        join_metrics: BuildProbeJoinMetrics,
146        batch_size: usize,
147    ) -> Self {
148        Self {
149            schema: Arc::clone(&schema),
150            on_streamed,
151            join_type,
152            operator,
153            streamed_schema: streamed.schema(),
154            streamed,
155            buffered_side,
156            state,
157            sort_option,
158            join_metrics,
159            batch_process_state: BatchProcessState::new(schema, batch_size),
160        }
161    }
162
163    fn poll_next_impl(
164        &mut self,
165        cx: &mut std::task::Context<'_>,
166    ) -> Poll<Option<Result<RecordBatch>>> {
167        loop {
168            return match self.state {
169                PiecewiseMergeJoinStreamState::WaitBufferedSide => {
170                    handle_state!(ready!(self.collect_buffered_side(cx)))
171                }
172                PiecewiseMergeJoinStreamState::FetchStreamBatch => {
173                    handle_state!(ready!(self.fetch_stream_batch(cx)))
174                }
175                PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => {
176                    handle_state!(self.process_stream_batch())
177                }
178                PiecewiseMergeJoinStreamState::ProcessUnmatched => {
179                    handle_state!(self.process_unmatched_buffered_batch())
180                }
181                PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None),
182            };
183        }
184    }
185
186    // Collects buffered side data
187    fn collect_buffered_side(
188        &mut self,
189        cx: &mut std::task::Context<'_>,
190    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
191        let build_timer = self.join_metrics.build_time.timer();
192        let buffered_data = ready!(self
193            .buffered_side
194            .try_as_initial_mut()?
195            .buffered_fut
196            .get_shared(cx))?;
197        build_timer.done();
198
199        // We will start fetching stream batches for classic joins
200        self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
201
202        self.buffered_side =
203            BufferedSide::Ready(BufferedSideReadyState { buffered_data });
204
205        Poll::Ready(Ok(StatefulStreamResult::Continue))
206    }
207
208    // Fetches incoming stream batches
209    fn fetch_stream_batch(
210        &mut self,
211        cx: &mut std::task::Context<'_>,
212    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
213        match ready!(self.streamed.poll_next_unpin(cx)) {
214            None => {
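                // `fetch_sub` returns the previous value, so seeing 1 here means this partition
                // is the last one still draining its stream; it is then responsible for emitting
                // the unmatched buffered rows.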
215                if self
216                    .buffered_side
217                    .try_as_ready_mut()?
218                    .buffered_data
219                    .remaining_partitions
220                    .fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
221                    == 1
222                {
223                    self.batch_process_state.reset();
224                    self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched;
225                } else {
226                    self.state = PiecewiseMergeJoinStreamState::Completed;
227                }
228            }
229            Some(Ok(batch)) => {
230                // Evaluate the streamed physical expression on the stream batch
231                let stream_values: ArrayRef = self
232                    .on_streamed
233                    .evaluate(&batch)?
234                    .into_array(batch.num_rows())?;
235
236                self.join_metrics.input_batches.add(1);
237                self.join_metrics.input_rows.add(batch.num_rows());
238
239                // Sort stream values and change the streamed record batch accordingly
240                let indices = sort_to_indices(
241                    stream_values.as_ref(),
242                    Some(self.sort_option),
243                    None,
244                )?;
245                let stream_batch = take_record_batch(&batch, &indices)?;
246                let stream_values = take(stream_values.as_ref(), &indices, None)?;
247
248                // Reset BatchProcessState before processing a new stream batch
249                self.batch_process_state.reset();
250                self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch(
251                    SortedStreamBatch {
252                        batch: stream_batch,
253                        compare_key_values: vec![stream_values],
254                    },
255                );
256            }
257            Some(Err(err)) => return Poll::Ready(Err(err)),
258        };
259
260        Poll::Ready(Ok(StatefulStreamResult::Continue))
261    }
262
    // Only called for classic joins. Processes the current stream batch and evaluates it
    // against the buffered side data.
265    fn process_stream_batch(
266        &mut self,
267    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
268        let buffered_side = self.buffered_side.try_as_ready_mut()?;
269        let stream_batch = self.state.try_as_process_stream_batch_mut()?;
270
271        if let Some(batch) = self
272            .batch_process_state
273            .output_batches
274            .next_completed_batch()
275        {
276            return Ok(StatefulStreamResult::Ready(Some(batch)));
277        }
278
        // Produce more output from the current stream batch
280        let batch = resolve_classic_join(
281            buffered_side,
282            stream_batch,
283            Arc::clone(&self.schema),
284            self.operator,
285            self.sort_option,
286            self.join_type,
287            &mut self.batch_process_state,
288        )?;
289
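        // `continue_process == false` means `resolve_classic_join` scanned the whole stream
        // batch; flush the coalescer and, once it is drained, move on to the next stream batch.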
290        if !self.batch_process_state.continue_process {
291            // We finished scanning this stream batch.
292            self.batch_process_state
293                .output_batches
294                .finish_buffered_batch()?;
295            if let Some(b) = self
296                .batch_process_state
297                .output_batches
298                .next_completed_batch()
299            {
300                self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
301                return Ok(StatefulStreamResult::Ready(Some(b)));
302            }
303
            // Nothing pending; hand back whatever `resolve_classic_join` returned (often empty) and move on.
305            if self.batch_process_state.output_batches.is_empty() {
306                self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
307
308                return Ok(StatefulStreamResult::Ready(Some(batch)));
309            }
310        }
311
312        Ok(StatefulStreamResult::Ready(Some(batch)))
313    }
314
    // Process the remaining unmatched buffered rows (Left and Full joins only)
316    fn process_unmatched_buffered_batch(
317        &mut self,
318    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
319        // Return early for `JoinType::Right` and `JoinType::Inner`
320        if matches!(self.join_type, JoinType::Right | JoinType::Inner) {
321            self.state = PiecewiseMergeJoinStreamState::Completed;
322            return Ok(StatefulStreamResult::Ready(None));
323        }
324
325        if !self.batch_process_state.continue_process {
326            if let Some(batch) = self
327                .batch_process_state
328                .output_batches
329                .next_completed_batch()
330            {
331                return Ok(StatefulStreamResult::Ready(Some(batch)));
332            }
333
334            self.batch_process_state
335                .output_batches
336                .finish_buffered_batch()?;
337            if let Some(batch) = self
338                .batch_process_state
339                .output_batches
340                .next_completed_batch()
341            {
342                self.state = PiecewiseMergeJoinStreamState::Completed;
343                return Ok(StatefulStreamResult::Ready(Some(batch)));
344            }
345        }
346
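        // For Left and Full joins, the unmatched buffered rows are those whose bits were never
        // set in the shared visited bitmap; emit them once, padding the streamed columns with nulls.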
347        let buffered_data =
348            Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data);
349
350        let (buffered_indices, _streamed_indices) = get_final_indices_from_shared_bitmap(
351            &buffered_data.visited_indices_bitmap,
352            self.join_type,
353            true,
354        );
355
356        let new_buffered_batch =
357            take_record_batch(buffered_data.batch(), &buffered_indices)?;
358        let mut buffered_columns = new_buffered_batch.columns().to_vec();
359
360        let streamed_columns: Vec<ArrayRef> = self
361            .streamed_schema
362            .fields()
363            .iter()
364            .map(|f| new_null_array(f.data_type(), new_buffered_batch.num_rows()))
365            .collect();
366
367        buffered_columns.extend(streamed_columns);
368
369        let batch = RecordBatch::try_new(Arc::clone(&self.schema), buffered_columns)?;
370
371        self.batch_process_state.output_batches.push_batch(batch)?;
372
373        self.batch_process_state.continue_process = false;
374        if let Some(batch) = self
375            .batch_process_state
376            .output_batches
377            .next_completed_batch()
378        {
379            return Ok(StatefulStreamResult::Ready(Some(batch)));
380        }
381
382        self.batch_process_state
383            .output_batches
384            .finish_buffered_batch()?;
385        if let Some(batch) = self
386            .batch_process_state
387            .output_batches
388            .next_completed_batch()
389        {
390            self.state = PiecewiseMergeJoinStreamState::Completed;
391            return Ok(StatefulStreamResult::Ready(Some(batch)));
392        }
393
394        self.state = PiecewiseMergeJoinStreamState::Completed;
395        self.batch_process_state.reset();
396        Ok(StatefulStreamResult::Ready(None))
397    }
398}
399
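// Incremental-emission bookkeeping for a single pass over the buffered data.
// `resolve_classic_join` pushes matched rows into `output_batches` and returns early whenever
// the coalescer completes a target-sized batch; `start_buffer_idx` / `start_stream_idx` record
// where to resume on the next call, and `continue_process` flips to false once the whole
// stream batch has been scanned.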
400struct BatchProcessState {
    // Coalesces output rows into batches of the configured target size
402    output_batches: Box<BatchCoalescer>,
403    // Used to store the unmatched stream indices for `JoinType::Right` and `JoinType::Full`
404    unmatched_indices: PrimitiveBuilder<UInt32Type>,
    // Start index on the buffered side, used to resume processing at the correct row
    start_buffer_idx: usize,
    // Start index on the stream side, used to resume processing at the correct row
    start_stream_idx: usize,
411    // Signals if we found a match for the current stream row
412    found: bool,
413    // Signals to continue processing the current stream batch
414    continue_process: bool,
    // Tracks whether the leading null rows have already been skipped for the current stream batch
416    processed_null_count: bool,
417}
418
419impl BatchProcessState {
420    pub(crate) fn new(schema: Arc<Schema>, batch_size: usize) -> Self {
421        Self {
422            output_batches: Box::new(BatchCoalescer::new(schema, batch_size)),
423            unmatched_indices: PrimitiveBuilder::new(),
424            start_buffer_idx: 0,
425            start_stream_idx: 0,
426            found: false,
427            continue_process: true,
428            processed_null_count: false,
429        }
430    }
431
432    pub(crate) fn reset(&mut self) {
433        self.unmatched_indices = PrimitiveBuilder::new();
434        self.start_buffer_idx = 0;
435        self.start_stream_idx = 0;
436        self.found = false;
437        self.continue_process = true;
438        self.processed_null_count = false;
439    }
440}
441
442impl Stream for ClassicPWMJStream {
443    type Item = Result<RecordBatch>;
444
445    fn poll_next(
446        mut self: std::pin::Pin<&mut Self>,
447        cx: &mut std::task::Context<'_>,
448    ) -> Poll<Option<Self::Item>> {
449        self.poll_next_impl(cx)
450    }
451}
452
// For Left, Right, Full, and Inner joins, the incoming stream batch has already been sorted by
// `fetch_stream_batch`.
454#[allow(clippy::too_many_arguments)]
455fn resolve_classic_join(
456    buffered_side: &mut BufferedSideReadyState,
457    stream_batch: &SortedStreamBatch,
458    join_schema: Arc<Schema>,
459    operator: Operator,
460    sort_options: SortOptions,
461    join_type: JoinType,
462    batch_process_state: &mut BatchProcessState,
463) -> Result<RecordBatch> {
464    let buffered_len = buffered_side.buffered_data.values().len();
465    let stream_values = stream_batch.compare_key_values();
466
467    let mut buffer_idx = batch_process_state.start_buffer_idx;
468    let mut stream_idx = batch_process_state.start_stream_idx;
469
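    // Null compare keys can never produce a match (`NullEquality::NullEqualsNothing` is used
    // below), and the sort is expected to place nulls at the front of each side, so skip the
    // null prefix once per stream batch.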
470    if !batch_process_state.processed_null_count {
471        let buffered_null_idx = buffered_side.buffered_data.values().null_count();
472        let stream_null_idx = stream_values[0].null_count();
473        buffer_idx = buffered_null_idx;
474        stream_idx = stream_null_idx;
475        batch_process_state.processed_null_count = true;
476    }
477
    // `buffer_idx` lets each stream row resume probing the buffered side from where the previous
    // stream row found its first match, instead of restarting from the beginning.
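    // Concretely, with buffered keys [3, 2, 1] and stream keys [4, 3, 2] under
    // `buffered < stream` (both sorted descending): stream 4 matches at buffered index 0,
    // stream 3 resumes there and matches at index 1, and stream 2 resumes at index 1 and
    // matches at index 2. Each stream row pairs with a contiguous suffix of the buffered side.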
480    for row_idx in stream_idx..stream_batch.batch.num_rows() {
481        while buffer_idx < buffered_len {
482            let compare = {
483                let buffered_values = buffered_side.buffered_data.values();
484                compare_join_arrays(
485                    &[Arc::clone(&stream_values[0])],
486                    row_idx,
487                    &[Arc::clone(buffered_values)],
488                    buffer_idx,
489                    &[sort_options],
490                    NullEquality::NullEqualsNothing,
491                )?
492            };
493
494            // If we find a match we append all indices and move to the next stream row index
495            match operator {
496                Operator::Gt | Operator::Lt => {
497                    if matches!(compare, Ordering::Less) {
498                        batch_process_state.found = true;
499                        let count = buffered_len - buffer_idx;
500
501                        let batch = build_matched_indices_and_set_buffered_bitmap(
502                            (buffer_idx, count),
503                            (row_idx, count),
504                            buffered_side,
505                            stream_batch,
506                            join_type,
507                            Arc::clone(&join_schema),
508                        )?;
509
510                        batch_process_state.output_batches.push_batch(batch)?;
511
512                        // Flush batch and update pointers if we have a completed batch
513                        if let Some(batch) =
514                            batch_process_state.output_batches.next_completed_batch()
515                        {
516                            batch_process_state.found = false;
517                            batch_process_state.start_buffer_idx = buffer_idx;
518                            batch_process_state.start_stream_idx = row_idx + 1;
519                            return Ok(batch);
520                        }
521
522                        break;
523                    }
524                }
525                Operator::GtEq | Operator::LtEq => {
526                    if matches!(compare, Ordering::Equal | Ordering::Less) {
527                        batch_process_state.found = true;
528                        let count = buffered_len - buffer_idx;
529                        let batch = build_matched_indices_and_set_buffered_bitmap(
530                            (buffer_idx, count),
531                            (row_idx, count),
532                            buffered_side,
533                            stream_batch,
534                            join_type,
535                            Arc::clone(&join_schema),
536                        )?;
537
538                        // Flush batch and update pointers if we have a completed batch
539                        batch_process_state.output_batches.push_batch(batch)?;
540                        if let Some(batch) =
541                            batch_process_state.output_batches.next_completed_batch()
542                        {
543                            batch_process_state.found = false;
544                            batch_process_state.start_buffer_idx = buffer_idx;
545                            batch_process_state.start_stream_idx = row_idx + 1;
546                            return Ok(batch);
547                        }
548
549                        break;
550                    }
551                }
552                _ => {
553                    return internal_err!(
                        "PiecewiseMergeJoin should not contain operator {}",
555                        operator
556                    )
557                }
558            };
559
560            // Increment buffer_idx after every row
561            buffer_idx += 1;
562        }
563
        // If no match was found for the current stream row, its index is appended to the
        // unmatched indices to be flushed later.
566        if matches!(join_type, JoinType::Right | JoinType::Full)
567            && !batch_process_state.found
568        {
569            batch_process_state
570                .unmatched_indices
571                .append_value(row_idx as u32);
572        }
573
574        batch_process_state.found = false;
575    }
576
    // Flush all unmatched indices on the streamed side
578    if matches!(join_type, JoinType::Right | JoinType::Full) {
579        let batch = create_unmatched_batch(
580            &mut batch_process_state.unmatched_indices,
581            stream_batch,
582            Arc::clone(&join_schema),
583        )?;
584
585        batch_process_state.output_batches.push_batch(batch)?;
586    }
587
588    batch_process_state.continue_process = false;
589    Ok(RecordBatch::new_empty(Arc::clone(&join_schema)))
590}
591
// Builds a record batch from index ranges on the buffered and streamed sides and marks the
// emitted buffered rows as visited.
//
// Both ranges are `(start index, count)`, matching the signature of `RecordBatch::slice(start, count)`.
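// For example, `buffered_range = (5, 3)` selects buffered rows 5..8, while `streamed_range = (1, 3)`
// repeats stream row 1 three times so the two sides line up row by row.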
596fn build_matched_indices_and_set_buffered_bitmap(
597    buffered_range: (usize, usize),
598    streamed_range: (usize, usize),
599    buffered_side: &mut BufferedSideReadyState,
600    stream_batch: &SortedStreamBatch,
601    join_type: JoinType,
602    join_schema: Arc<Schema>,
603) -> Result<RecordBatch> {
604    // Mark the buffered indices as visited
605    if need_produce_result_in_final(join_type) {
606        let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock();
607        for i in buffered_range.0..buffered_range.0 + buffered_range.1 {
608            bitmap.set_bit(i, true);
609        }
610    }
611
612    let new_buffered_batch = buffered_side
613        .buffered_data
614        .batch()
615        .slice(buffered_range.0, buffered_range.1);
616    let mut buffered_columns = new_buffered_batch.columns().to_vec();
617
618    let indices = UInt32Array::from_value(streamed_range.0 as u32, streamed_range.1);
619    let new_stream_batch = take_record_batch(&stream_batch.batch, &indices)?;
620    let streamed_columns = new_stream_batch.columns().to_vec();
621
622    buffered_columns.extend(streamed_columns);
623
624    Ok(RecordBatch::try_new(
625        Arc::clone(&join_schema),
626        buffered_columns,
627    )?)
628}
629
630// Creates a record batch from the unmatched indices on the streamed side
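// The buffered columns are filled with nulls, since by construction these stream rows had no
// buffered match.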
631fn create_unmatched_batch(
632    streamed_indices: &mut PrimitiveBuilder<UInt32Type>,
633    stream_batch: &SortedStreamBatch,
634    join_schema: Arc<Schema>,
635) -> Result<RecordBatch> {
636    let streamed_indices = streamed_indices.finish();
637    let new_stream_batch = take_record_batch(&stream_batch.batch, &streamed_indices)?;
638    let streamed_columns = new_stream_batch.columns().to_vec();
639    let buffered_cols_len = join_schema.fields().len() - streamed_columns.len();
640
641    let num_rows = new_stream_batch.num_rows();
642    let mut buffered_columns: Vec<ArrayRef> = join_schema
643        .fields()
644        .iter()
645        .take(buffered_cols_len)
646        .map(|field| new_null_array(field.data_type(), num_rows))
647        .collect();
648
649    buffered_columns.extend(streamed_columns);
650
651    Ok(RecordBatch::try_new(
652        Arc::clone(&join_schema),
653        buffered_columns,
654    )?)
655}
656
657#[cfg(test)]
658mod tests {
659    use super::*;
660    use crate::{
661        common,
662        joins::PiecewiseMergeJoinExec,
663        test::{build_table_i32, TestMemoryExec},
664        ExecutionPlan,
665    };
666    use arrow::array::{Date32Array, Date64Array};
667    use arrow_schema::{DataType, Field};
668    use datafusion_common::test_util::batches_to_string;
669    use datafusion_execution::TaskContext;
670    use datafusion_expr::JoinType;
671    use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
672    use insta::assert_snapshot;
673    use std::sync::Arc;
674
675    fn columns(schema: &Schema) -> Vec<String> {
676        schema.fields().iter().map(|f| f.name().clone()).collect()
677    }
678
679    fn build_table(
680        a: (&str, &Vec<i32>),
681        b: (&str, &Vec<i32>),
682        c: (&str, &Vec<i32>),
683    ) -> Arc<dyn ExecutionPlan> {
684        let batch = build_table_i32(a, b, c);
685        let schema = batch.schema();
686        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
687    }
688
689    fn build_date_table(
690        a: (&str, &Vec<i32>),
691        b: (&str, &Vec<i32>),
692        c: (&str, &Vec<i32>),
693    ) -> Arc<dyn ExecutionPlan> {
694        let schema = Schema::new(vec![
695            Field::new(a.0, DataType::Date32, false),
696            Field::new(b.0, DataType::Date32, false),
697            Field::new(c.0, DataType::Date32, false),
698        ]);
699
700        let batch = RecordBatch::try_new(
701            Arc::new(schema),
702            vec![
703                Arc::new(Date32Array::from(a.1.clone())),
704                Arc::new(Date32Array::from(b.1.clone())),
705                Arc::new(Date32Array::from(c.1.clone())),
706            ],
707        )
708        .unwrap();
709
710        let schema = batch.schema();
711        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
712    }
713
714    fn build_date64_table(
715        a: (&str, &Vec<i64>),
716        b: (&str, &Vec<i64>),
717        c: (&str, &Vec<i64>),
718    ) -> Arc<dyn ExecutionPlan> {
719        let schema = Schema::new(vec![
720            Field::new(a.0, DataType::Date64, false),
721            Field::new(b.0, DataType::Date64, false),
722            Field::new(c.0, DataType::Date64, false),
723        ]);
724
725        let batch = RecordBatch::try_new(
726            Arc::new(schema),
727            vec![
728                Arc::new(Date64Array::from(a.1.clone())),
729                Arc::new(Date64Array::from(b.1.clone())),
730                Arc::new(Date64Array::from(c.1.clone())),
731            ],
732        )
733        .unwrap();
734
735        let schema = batch.schema();
736        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
737    }
738
739    fn join(
740        left: Arc<dyn ExecutionPlan>,
741        right: Arc<dyn ExecutionPlan>,
742        on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
743        operator: Operator,
744        join_type: JoinType,
745    ) -> Result<PiecewiseMergeJoinExec> {
746        PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1)
747    }
748
749    async fn join_collect(
750        left: Arc<dyn ExecutionPlan>,
751        right: Arc<dyn ExecutionPlan>,
752        on: (PhysicalExprRef, PhysicalExprRef),
753        operator: Operator,
754        join_type: JoinType,
755    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
756        join_collect_with_options(left, right, on, operator, join_type).await
757    }
758
759    async fn join_collect_with_options(
760        left: Arc<dyn ExecutionPlan>,
761        right: Arc<dyn ExecutionPlan>,
762        on: (PhysicalExprRef, PhysicalExprRef),
763        operator: Operator,
764        join_type: JoinType,
765    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
766        let task_ctx = Arc::new(TaskContext::default());
767        let join = join(left, right, on, operator, join_type)?;
768        let columns = columns(&join.schema());
769
770        let stream = join.execute(0, task_ctx)?;
771        let batches = common::collect(stream).await?;
772        Ok((columns, batches))
773    }
774
775    #[tokio::test]
776    async fn join_inner_less_than() -> Result<()> {
777        // +----+----+----+
778        // | a1 | b1 | c1 |
779        // +----+----+----+
780        // | 1  | 3  | 7  |
781        // | 2  | 2  | 8  |
782        // | 3  | 1  | 9  |
783        // +----+----+----+
784        let left = build_table(
785            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![3, 2, 1]),
787            ("c1", &vec![7, 8, 9]),
788        );
789
790        // +----+----+----+
791        // | a2 | b1 | c2 |
792        // +----+----+----+
793        // | 10 | 2  | 70 |
794        // | 20 | 3  | 80 |
795        // | 30 | 4  | 90 |
796        // +----+----+----+
797        let right = build_table(
798            ("a2", &vec![10, 20, 30]),
799            ("b1", &vec![2, 3, 4]),
800            ("c2", &vec![70, 80, 90]),
801        );
802
803        let on = (
804            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
805            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
806        );
807
808        let (_, batches) =
809            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
810
811        assert_snapshot!(batches_to_string(&batches), @r#"
812        +----+----+----+----+----+----+
813        | a1 | b1 | c1 | a2 | b1 | c2 |
814        +----+----+----+----+----+----+
815        | 1  | 3  | 7  | 30 | 4  | 90 |
816        | 2  | 2  | 8  | 30 | 4  | 90 |
817        | 3  | 1  | 9  | 30 | 4  | 90 |
818        | 2  | 2  | 8  | 20 | 3  | 80 |
819        | 3  | 1  | 9  | 20 | 3  | 80 |
820        | 3  | 1  | 9  | 10 | 2  | 70 |
821        +----+----+----+----+----+----+
822        "#);
823        Ok(())
824    }
825
826    #[tokio::test]
827    async fn join_inner_less_than_unsorted() -> Result<()> {
828        // +----+----+----+
829        // | a1 | b1 | c1 |
830        // +----+----+----+
831        // | 1  | 3  | 7  |
832        // | 2  | 2  | 8  |
833        // | 3  | 1  | 9  |
834        // +----+----+----+
835        let left = build_table(
836            ("a1", &vec![1, 2, 3]),
            ("b1", &vec![3, 2, 1]),
838            ("c1", &vec![7, 8, 9]),
839        );
840
841        // +----+----+----+
842        // | a2 | b1 | c2 |
843        // +----+----+----+
844        // | 10 | 3  | 70 |
845        // | 20 | 2  | 80 |
846        // | 30 | 4  | 90 |
847        // +----+----+----+
848        let right = build_table(
849            ("a2", &vec![10, 20, 30]),
850            ("b1", &vec![3, 2, 4]),
851            ("c2", &vec![70, 80, 90]),
852        );
853
854        let on = (
855            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
856            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
857        );
858
859        let (_, batches) =
860            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
861
862        assert_snapshot!(batches_to_string(&batches), @r#"
863            +----+----+----+----+----+----+
864            | a1 | b1 | c1 | a2 | b1 | c2 |
865            +----+----+----+----+----+----+
866            | 1  | 3  | 7  | 30 | 4  | 90 |
867            | 2  | 2  | 8  | 30 | 4  | 90 |
868            | 3  | 1  | 9  | 30 | 4  | 90 |
869            | 2  | 2  | 8  | 10 | 3  | 70 |
870            | 3  | 1  | 9  | 10 | 3  | 70 |
871            | 3  | 1  | 9  | 20 | 2  | 80 |
872            +----+----+----+----+----+----+
873        "#);
874        Ok(())
875    }
876
877    #[tokio::test]
878    async fn join_inner_greater_than_equal_to() -> Result<()> {
879        // +----+----+----+
880        // | a1 | b1 | c1 |
881        // +----+----+----+
882        // | 1  | 2  | 7  |
883        // | 2  | 3  | 8  |
884        // | 3  | 4  | 9  |
885        // +----+----+----+
886        let left = build_table(
887            ("a1", &vec![1, 2, 3]),
888            ("b1", &vec![2, 3, 4]),
889            ("c1", &vec![7, 8, 9]),
890        );
891
892        // +----+----+----+
893        // | a2 | b1 | c2 |
894        // +----+----+----+
895        // | 10 | 3  | 70 |
896        // | 20 | 2  | 80 |
897        // | 30 | 1  | 90 |
898        // +----+----+----+
899        let right = build_table(
900            ("a2", &vec![10, 20, 30]),
901            ("b1", &vec![3, 2, 1]),
902            ("c2", &vec![70, 80, 90]),
903        );
904
905        let on = (
906            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
907            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
908        );
909
910        let (_, batches) =
911            join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?;
912
913        assert_snapshot!(batches_to_string(&batches), @r#"
914        +----+----+----+----+----+----+
915        | a1 | b1 | c1 | a2 | b1 | c2 |
916        +----+----+----+----+----+----+
917        | 1  | 2  | 7  | 30 | 1  | 90 |
918        | 2  | 3  | 8  | 30 | 1  | 90 |
919        | 3  | 4  | 9  | 30 | 1  | 90 |
920        | 1  | 2  | 7  | 20 | 2  | 80 |
921        | 2  | 3  | 8  | 20 | 2  | 80 |
922        | 3  | 4  | 9  | 20 | 2  | 80 |
923        | 2  | 3  | 8  | 10 | 3  | 70 |
924        | 3  | 4  | 9  | 10 | 3  | 70 |
925        +----+----+----+----+----+----+
926        "#);
927        Ok(())
928    }
929
930    #[tokio::test]
931    async fn join_inner_empty_left() -> Result<()> {
932        // +----+----+----+
933        // | a1 | b1 | c1 |
934        // +----+----+----+
935        // (empty)
936        // +----+----+----+
937        let left = build_table(
938            ("a1", &Vec::<i32>::new()),
939            ("b1", &Vec::<i32>::new()),
940            ("c1", &Vec::<i32>::new()),
941        );
942
943        // +----+----+----+
944        // | a2 | b1 | c2 |
945        // +----+----+----+
946        // | 1  | 1  | 1  |
947        // | 2  | 2  | 2  |
948        // +----+----+----+
949        let right = build_table(
950            ("a2", &vec![1, 2]),
951            ("b1", &vec![1, 2]),
952            ("c2", &vec![1, 2]),
953        );
954
955        let on = (
956            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
957            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
958        );
959        let (_, batches) =
960            join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?;
961        assert_snapshot!(batches_to_string(&batches), @r#"
962        +----+----+----+----+----+----+
963        | a1 | b1 | c1 | a2 | b1 | c2 |
964        +----+----+----+----+----+----+
965        +----+----+----+----+----+----+
966        "#);
967        Ok(())
968    }
969
970    #[tokio::test]
971    async fn join_full_greater_than_equal_to() -> Result<()> {
972        // +----+----+-----+
973        // | a1 | b1 | c1  |
974        // +----+----+-----+
975        // | 1  | 1  | 100 |
976        // | 2  | 2  | 200 |
977        // +----+----+-----+
978        let left = build_table(
979            ("a1", &vec![1, 2]),
980            ("b1", &vec![1, 2]),
981            ("c1", &vec![100, 200]),
982        );
983
984        // +----+----+-----+
985        // | a2 | b1 | c2  |
986        // +----+----+-----+
987        // | 10 | 3  | 300 |
988        // | 20 | 2  | 400 |
989        // +----+----+-----+
990        let right = build_table(
991            ("a2", &vec![10, 20]),
992            ("b1", &vec![3, 2]),
993            ("c2", &vec![300, 400]),
994        );
995
996        let on = (
997            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
998            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
999        );
1000
1001        let (_, batches) =
1002            join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?;
1003
1004        assert_snapshot!(batches_to_string(&batches), @r#"
1005        +----+----+-----+----+----+-----+
1006        | a1 | b1 | c1  | a2 | b1 | c2  |
1007        +----+----+-----+----+----+-----+
1008        | 2  | 2  | 200 | 20 | 2  | 400 |
1009        |    |    |     | 10 | 3  | 300 |
1010        | 1  | 1  | 100 |    |    |     |
1011        +----+----+-----+----+----+-----+
1012        "#);
1013
1014        Ok(())
1015    }
1016
1017    #[tokio::test]
1018    async fn join_left_greater_than() -> Result<()> {
1019        // +----+----+----+
1020        // | a1 | b1 | c1 |
1021        // +----+----+----+
1022        // | 1  | 1  | 7  |
1023        // | 2  | 3  | 8  |
1024        // | 3  | 4  | 9  |
1025        // +----+----+----+
1026        let left = build_table(
1027            ("a1", &vec![1, 2, 3]),
1028            ("b1", &vec![1, 3, 4]),
1029            ("c1", &vec![7, 8, 9]),
1030        );
1031
1032        // +----+----+----+
1033        // | a2 | b1 | c2 |
1034        // +----+----+----+
1035        // | 10 | 3  | 70 |
1036        // | 20 | 2  | 80 |
1037        // | 30 | 1  | 90 |
1038        // +----+----+----+
1039        let right = build_table(
1040            ("a2", &vec![10, 20, 30]),
1041            ("b1", &vec![3, 2, 1]),
1042            ("c2", &vec![70, 80, 90]),
1043        );
1044
1045        let on = (
1046            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1047            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1048        );
1049
1050        let (_, batches) =
1051            join_collect(left, right, on, Operator::Gt, JoinType::Left).await?;
1052
1053        assert_snapshot!(batches_to_string(&batches), @r#"
1054        +----+----+----+----+----+----+
1055        | a1 | b1 | c1 | a2 | b1 | c2 |
1056        +----+----+----+----+----+----+
1057        | 2  | 3  | 8  | 30 | 1  | 90 |
1058        | 3  | 4  | 9  | 30 | 1  | 90 |
1059        | 2  | 3  | 8  | 20 | 2  | 80 |
1060        | 3  | 4  | 9  | 20 | 2  | 80 |
1061        | 3  | 4  | 9  | 10 | 3  | 70 |
1062        | 1  | 1  | 7  |    |    |    |
1063        +----+----+----+----+----+----+
1064        "#);
1065        Ok(())
1066    }
1067
1068    #[tokio::test]
1069    async fn join_right_greater_than() -> Result<()> {
1070        // +----+----+----+
1071        // | a1 | b1 | c1 |
1072        // +----+----+----+
1073        // | 1  | 1  | 7  |
1074        // | 2  | 3  | 8  |
1075        // | 3  | 4  | 9  |
1076        // +----+----+----+
1077        let left = build_table(
1078            ("a1", &vec![1, 2, 3]),
1079            ("b1", &vec![1, 3, 4]),
1080            ("c1", &vec![7, 8, 9]),
1081        );
1082
1083        // +----+----+----+
1084        // | a2 | b1 | c2 |
1085        // +----+----+----+
1086        // | 10 | 5  | 70 |
1087        // | 20 | 3  | 80 |
1088        // | 30 | 2  | 90 |
1089        // +----+----+----+
1090        let right = build_table(
1091            ("a2", &vec![10, 20, 30]),
1092            ("b1", &vec![5, 3, 2]),
1093            ("c2", &vec![70, 80, 90]),
1094        );
1095
1096        let on = (
1097            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1098            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1099        );
1100
1101        let (_, batches) =
1102            join_collect(left, right, on, Operator::Gt, JoinType::Right).await?;
1103
1104        assert_snapshot!(batches_to_string(&batches), @r#"
1105        +----+----+----+----+----+----+
1106        | a1 | b1 | c1 | a2 | b1 | c2 |
1107        +----+----+----+----+----+----+
1108        | 2  | 3  | 8  | 30 | 2  | 90 |
1109        | 3  | 4  | 9  | 30 | 2  | 90 |
1110        | 3  | 4  | 9  | 20 | 3  | 80 |
1111        |    |    |    | 10 | 5  | 70 |
1112        +----+----+----+----+----+----+
1113        "#);
1114        Ok(())
1115    }
1116
1117    #[tokio::test]
1118    async fn join_right_less_than() -> Result<()> {
1119        // +----+----+----+
1120        // | a1 | b1 | c1 |
1121        // +----+----+----+
1122        // | 1  | 4  | 7  |
1123        // | 2  | 3  | 8  |
1124        // | 3  | 1  | 9  |
1125        // +----+----+----+
1126        let left = build_table(
1127            ("a1", &vec![1, 2, 3]),
1128            ("b1", &vec![4, 3, 1]),
1129            ("c1", &vec![7, 8, 9]),
1130        );
1131
1132        // +----+----+----+
1133        // | a2 | b1 | c2 |
1134        // +----+----+----+
1135        // | 10 | 2  | 70 |
1136        // | 20 | 3  | 80 |
1137        // | 30 | 5  | 90 |
1138        // +----+----+----+
1139        let right = build_table(
1140            ("a2", &vec![10, 20, 30]),
1141            ("b1", &vec![2, 3, 5]),
1142            ("c2", &vec![70, 80, 90]),
1143        );
1144
1145        let on = (
1146            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1147            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1148        );
1149
1150        let (_, batches) =
1151            join_collect(left, right, on, Operator::Lt, JoinType::Right).await?;
1152
1153        assert_snapshot!(batches_to_string(&batches), @r#"
1154        +----+----+----+----+----+----+
1155        | a1 | b1 | c1 | a2 | b1 | c2 |
1156        +----+----+----+----+----+----+
1157        | 1  | 4  | 7  | 30 | 5  | 90 |
1158        | 2  | 3  | 8  | 30 | 5  | 90 |
1159        | 3  | 1  | 9  | 30 | 5  | 90 |
1160        | 3  | 1  | 9  | 20 | 3  | 80 |
1161        | 3  | 1  | 9  | 10 | 2  | 70 |
1162        +----+----+----+----+----+----+
1163        "#);
1164        Ok(())
1165    }
1166
1167    #[tokio::test]
1168    async fn join_inner_less_than_equal_with_dups() -> Result<()> {
1169        // +----+----+----+
1170        // | a1 | b1 | c1 |
1171        // +----+----+----+
1172        // | 1  | 4  | 7  |
1173        // | 2  | 4  | 8  |
1174        // | 3  | 2  | 9  |
1175        // +----+----+----+
1176        let left = build_table(
1177            ("a1", &vec![1, 2, 3]),
1178            ("b1", &vec![4, 4, 2]),
1179            ("c1", &vec![7, 8, 9]),
1180        );
1181
1182        // +----+----+----+
1183        // | a2 | b1 | c2 |
1184        // +----+----+----+
1185        // | 10 | 4  | 70 |
1186        // | 20 | 3  | 80 |
1187        // | 30 | 2  | 90 |
1188        // +----+----+----+
1189        let right = build_table(
1190            ("a2", &vec![10, 20, 30]),
1191            ("b1", &vec![4, 3, 2]),
1192            ("c2", &vec![70, 80, 90]),
1193        );
1194
1195        let on = (
1196            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1197            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1198        );
1199
1200        let (_, batches) =
1201            join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?;
1202
1203        // Expected grouping follows right.b1 descending (4, 3, 2)
1204        assert_snapshot!(batches_to_string(&batches), @r#"
1205        +----+----+----+----+----+----+
1206        | a1 | b1 | c1 | a2 | b1 | c2 |
1207        +----+----+----+----+----+----+
1208        | 1  | 4  | 7  | 10 | 4  | 70 |
1209        | 2  | 4  | 8  | 10 | 4  | 70 |
1210        | 3  | 2  | 9  | 10 | 4  | 70 |
1211        | 3  | 2  | 9  | 20 | 3  | 80 |
1212        | 3  | 2  | 9  | 30 | 2  | 90 |
1213        +----+----+----+----+----+----+
1214        "#);
1215        Ok(())
1216    }
1217
1218    #[tokio::test]
1219    async fn join_inner_greater_than_unsorted_right() -> Result<()> {
1220        // +----+----+----+
1221        // | a1 | b1 | c1 |
1222        // +----+----+----+
1223        // | 1  | 1  | 7  |
1224        // | 2  | 2  | 8  |
1225        // | 3  | 4  | 9  |
1226        // +----+----+----+
1227        let left = build_table(
1228            ("a1", &vec![1, 2, 3]),
1229            ("b1", &vec![1, 2, 4]),
1230            ("c1", &vec![7, 8, 9]),
1231        );
1232
1233        // +----+----+----+
1234        // | a2 | b1 | c2 |
1235        // +----+----+----+
1236        // | 10 | 3  | 70 |
1237        // | 20 | 1  | 80 |
1238        // | 30 | 2  | 90 |
1239        // +----+----+----+
1240        let right = build_table(
1241            ("a2", &vec![10, 20, 30]),
1242            ("b1", &vec![3, 1, 2]),
1243            ("c2", &vec![70, 80, 90]),
1244        );
1245
1246        let on = (
1247            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1248            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1249        );
1250
1251        let (_, batches) =
1252            join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?;
1253
        // Expected grouping follows right.b1 ascending (1, 2, 3)
1255        assert_snapshot!(batches_to_string(&batches), @r#"
1256        +----+----+----+----+----+----+
1257        | a1 | b1 | c1 | a2 | b1 | c2 |
1258        +----+----+----+----+----+----+
1259        | 2  | 2  | 8  | 20 | 1  | 80 |
1260        | 3  | 4  | 9  | 20 | 1  | 80 |
1261        | 3  | 4  | 9  | 30 | 2  | 90 |
1262        | 3  | 4  | 9  | 10 | 3  | 70 |
1263        +----+----+----+----+----+----+
1264        "#);
1265        Ok(())
1266    }
1267
1268    #[tokio::test]
1269    async fn join_left_less_than_equal_with_left_nulls_on_no_match() -> Result<()> {
1270        // +----+----+----+
1271        // | a1 | b1 | c1 |
1272        // +----+----+----+
1273        // | 1  | 5  | 7  |
1274        // | 2  | 4  | 8  |
1275        // | 3  | 1  | 9  |
1276        // +----+----+----+
1277        let left = build_table(
1278            ("a1", &vec![1, 2, 3]),
1279            ("b1", &vec![5, 4, 1]),
1280            ("c1", &vec![7, 8, 9]),
1281        );
1282
1283        // +----+----+----+
1284        // | a2 | b1 | c2 |
1285        // +----+----+----+
1286        // | 10 | 3  | 70 |
1287        // +----+----+----+
1288        let right = build_table(("a2", &vec![10]), ("b1", &vec![3]), ("c2", &vec![70]));
1289
1290        let on = (
1291            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1292            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1293        );
1294
1295        let (_, batches) =
1296            join_collect(left, right, on, Operator::LtEq, JoinType::Left).await?;
1297
1298        assert_snapshot!(batches_to_string(&batches), @r#"
1299        +----+----+----+----+----+----+
1300        | a1 | b1 | c1 | a2 | b1 | c2 |
1301        +----+----+----+----+----+----+
1302        | 3  | 1  | 9  | 10 | 3  | 70 |
1303        | 1  | 5  | 7  |    |    |    |
1304        | 2  | 4  | 8  |    |    |    |
1305        +----+----+----+----+----+----+
1306        "#);
1307        Ok(())
1308    }
1309
1310    #[tokio::test]
1311    async fn join_right_greater_than_equal_with_right_nulls_on_no_match() -> Result<()> {
1312        // +----+----+----+
1313        // | a1 | b1 | c1 |
1314        // +----+----+----+
1315        // | 1  | 1  | 7  |
1316        // | 2  | 2  | 8  |
1317        // +----+----+----+
1318        let left = build_table(
1319            ("a1", &vec![1, 2]),
1320            ("b1", &vec![1, 2]),
1321            ("c1", &vec![7, 8]),
1322        );
1323
1324        // +----+----+----+
1325        // | a2 | b1 | c2 |
1326        // +----+----+----+
1327        // | 10 | 3  | 70 |
1328        // | 20 | 5  | 80 |
1329        // +----+----+----+
1330        let right = build_table(
1331            ("a2", &vec![10, 20]),
1332            ("b1", &vec![3, 5]),
1333            ("c2", &vec![70, 80]),
1334        );
1335
1336        let on = (
1337            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1338            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1339        );
1340
1341        let (_, batches) =
1342            join_collect(left, right, on, Operator::GtEq, JoinType::Right).await?;
1343
1344        assert_snapshot!(batches_to_string(&batches), @r#"
1345        +----+----+----+----+----+----+
1346        | a1 | b1 | c1 | a2 | b1 | c2 |
1347        +----+----+----+----+----+----+
1348        |    |    |    | 10 | 3  | 70 |
1349        |    |    |    | 20 | 5  | 80 |
1350        +----+----+----+----+----+----+
1351        "#);
1352        Ok(())
1353    }
1354
1355    #[tokio::test]
1356    async fn join_inner_single_row_left_less_than() -> Result<()> {
1357        let left = build_table(("a1", &vec![42]), ("b1", &vec![5]), ("c1", &vec![999]));
1358
1359        let right = build_table(
1360            ("a2", &vec![10, 20, 30]),
1361            ("b1", &vec![1, 5, 7]),
1362            ("c2", &vec![70, 80, 90]),
1363        );
1364
1365        let on = (
1366            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1367            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1368        );
1369
1370        let (_, batches) =
1371            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
1372
1373        assert_snapshot!(batches_to_string(&batches), @r#"
1374        +----+----+-----+----+----+----+
1375        | a1 | b1 | c1  | a2 | b1 | c2 |
1376        +----+----+-----+----+----+----+
1377        | 42 | 5  | 999 | 30 | 7  | 90 |
1378        +----+----+-----+----+----+----+
1379        "#);
1380        Ok(())
1381    }
1382
1383    #[tokio::test]
1384    async fn join_inner_empty_right() -> Result<()> {
1385        let left = build_table(
1386            ("a1", &vec![1, 2, 3]),
1387            ("b1", &vec![1, 2, 3]),
1388            ("c1", &vec![7, 8, 9]),
1389        );
1390
1391        let right = build_table(
1392            ("a2", &Vec::<i32>::new()),
1393            ("b1", &Vec::<i32>::new()),
1394            ("c2", &Vec::<i32>::new()),
1395        );
1396
1397        let on = (
1398            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1399            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1400        );
1401
1402        let (_, batches) =
1403            join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?;
1404
1405        assert_snapshot!(batches_to_string(&batches), @r#"
1406        +----+----+----+----+----+----+
1407        | a1 | b1 | c1 | a2 | b1 | c2 |
1408        +----+----+----+----+----+----+
1409        +----+----+----+----+----+----+
1410        "#);
1411        Ok(())
1412    }
1413
1414    #[tokio::test]
1415    async fn join_date32_inner_less_than() -> Result<()> {
1416        // +----+-------+----+
1417        // | a1 |  b1   | c1 |
1418        // +----+-------+----+
1419        // | 1  | 19107 | 7  |
1420        // | 2  | 19107 | 8  |
1421        // | 3  | 19105 | 9  |
1422        // +----+-------+----+
1423        let left = build_date_table(
1424            ("a1", &vec![1, 2, 3]),
1425            ("b1", &vec![19107, 19107, 19105]),
1426            ("c1", &vec![7, 8, 9]),
1427        );
1428
1429        // +----+-------+----+
1430        // | a2 |  b1   | c2 |
1431        // +----+-------+----+
1432        // | 10 | 19105 | 70 |
1433        // | 20 | 19103 | 80 |
1434        // | 30 | 19107 | 90 |
1435        // +----+-------+----+
1436        let right = build_date_table(
1437            ("a2", &vec![10, 20, 30]),
1438            ("b1", &vec![19105, 19103, 19107]),
1439            ("c2", &vec![70, 80, 90]),
1440        );
1441
1442        let on = (
1443            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1444            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1445        );
1446
1447        let (_, batches) =
1448            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
1449
1450        assert_snapshot!(batches_to_string(&batches), @r#"
1451    +------------+------------+------------+------------+------------+------------+
1452    | a1         | b1         | c1         | a2         | b1         | c2         |
1453    +------------+------------+------------+------------+------------+------------+
1454    | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 |
1455    +------------+------------+------------+------------+------------+------------+
1456    "#);
1457        Ok(())
1458    }
1459
1460    #[tokio::test]
1461    async fn join_date64_inner_less_than() -> Result<()> {
1462        // +----+---------------+----+
1463        // | a1 |     b1        | c1 |
1464        // +----+---------------+----+
1465        // | 1  | 1650903441000 |  7 |
1466        // | 2  | 1650903441000 |  8 |
1467        // | 3  | 1650703441000 |  9 |
1468        // +----+---------------+----+
1469        let left = build_date64_table(
1470            ("a1", &vec![1, 2, 3]),
1471            ("b1", &vec![1650903441000, 1650903441000, 1650703441000]),
1472            ("c1", &vec![7, 8, 9]),
1473        );
1474
1475        // +----+---------------+----+
1476        // | a2 |     b1        | c2 |
1477        // +----+---------------+----+
1478        // | 10 | 1650703441000 | 70 |
1479        // | 20 | 1650503441000 | 80 |
1480        // | 30 | 1650903441000 | 90 |
1481        // +----+---------------+----+
1482        let right = build_date64_table(
1483            ("a2", &vec![10, 20, 30]),
1484            ("b1", &vec![1650703441000, 1650503441000, 1650903441000]),
1485            ("c2", &vec![70, 80, 90]),
1486        );
1487
1488        let on = (
1489            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1490            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1491        );
1492
1493        let (_, batches) =
1494            join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?;
1495
1496        assert_snapshot!(batches_to_string(&batches), @r#"
1497        +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1498        | a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |
1499        +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1500        | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |
1501        +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1502        "#);
1503        Ok(())
1504    }
1505
1506    #[tokio::test]
1507    async fn join_date64_right_less_than() -> Result<()> {
1508        // +----+---------------+----+
1509        // | a1 |     b1        | c1 |
1510        // +----+---------------+----+
1511        // | 1  | 1650903441000 |  7 |
1512        // | 2  | 1650703441000 |  8 |
1513        // +----+---------------+----+
1514        let left = build_date64_table(
1515            ("a1", &vec![1, 2]),
1516            ("b1", &vec![1650903441000, 1650703441000]),
1517            ("c1", &vec![7, 8]),
1518        );
1519
1520        // +----+---------------+----+
1521        // | a2 |     b1        | c2 |
1522        // +----+---------------+----+
1523        // | 10 | 1650703441000 | 80 |
1524        // | 20 | 1650903441000 | 90 |
1525        // +----+---------------+----+
1526        let right = build_date64_table(
1527            ("a2", &vec![10, 20]),
1528            ("b1", &vec![1650703441000, 1650903441000]),
1529            ("c2", &vec![80, 90]),
1530        );
1531
1532        let on = (
1533            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
1534            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
1535        );
1536
1537        let (_, batches) =
1538            join_collect(left, right, on, Operator::Lt, JoinType::Right).await?;
1539
1540        assert_snapshot!(batches_to_string(&batches), @r#"
1541    +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1542    | a1                      | b1                  | c1                      | a2                      | b1                  | c2                      |
1543    +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1544    | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |
1545    |                         |                     |                         | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 |
1546    +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+
1547"#);
1548        Ok(())
1549    }
1550}