1use std::cmp::Ordering;
25use std::collections::{HashMap, VecDeque};
26use std::fs::File;
27use std::io::BufReader;
28use std::mem::size_of;
29use std::ops::Range;
30use std::pin::Pin;
31use std::sync::atomic::AtomicUsize;
32use std::sync::atomic::Ordering::Relaxed;
33use std::sync::Arc;
34use std::task::{Context, Poll};
35
36use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
37use crate::joins::utils::{compare_join_arrays, JoinFilter};
38use crate::spill::spill_manager::SpillManager;
39use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
40
41use arrow::array::{types::UInt64Type, *};
42use arrow::compute::{
43 self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
44};
45use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
46use arrow::error::ArrowError;
47use arrow::ipc::reader::StreamReader;
48use datafusion_common::config::SpillCompression;
49use datafusion_common::{
50 exec_err, internal_err, not_impl_err, DataFusionError, HashSet, JoinSide, JoinType,
51 NullEquality, Result,
52};
53use datafusion_execution::disk_manager::RefCountedTempFile;
54use datafusion_execution::memory_pool::MemoryReservation;
55use datafusion_execution::runtime_env::RuntimeEnv;
56use datafusion_physical_expr_common::physical_expr::PhysicalExprRef;
57
58use futures::{Stream, StreamExt};
59
/// States of the top-level join state machine driven by `poll_next`.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum SortMergeJoinState {
    /// Decide which input to advance next, based on `current_ordering`.
    Init,
    /// Poll the streamed and/or buffered inputs until each is ready or exhausted.
    Polling,
    /// Produce join output for the current streamed row and buffered key group.
    JoinOutput,
    /// Both inputs are exhausted; flush remaining output, then end the stream.
    Exhausted,
}
72
/// States of the streamed side, advanced by `poll_streamed_row`.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum StreamedState {
    /// Move to the next row of the current streamed batch, or start polling
    /// when the batch has no more rows.
    Init,
    /// Waiting on the streamed input for a new (non-empty) batch.
    Polling,
    /// A streamed row is positioned at `streamed_batch.idx`.
    Ready,
    /// The streamed input returned `None`.
    Exhausted,
}
85
/// States of the buffered side, advanced by `poll_buffered_batches`.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum BufferedState {
    /// Drop fully consumed buffered batches and set up the next key group.
    Init,
    /// The buffer is empty; waiting on the buffered input for the first batch
    /// of a new key group.
    PollingFirst,
    /// Extending the current key group over following rows with equal join
    /// keys, polling more batches when the tail batch runs out.
    PollingRest,
    /// The current key group is complete.
    Ready,
    /// The buffered input returned `None` while the buffer was empty.
    Exhausted,
}
100
/// A run of (streamed row, buffered row) index pairs that all refer to the
/// same buffered batch, or to no buffered batch at all.
pub(super) struct StreamedJoinedChunk {
    /// Index into `BufferedData::batches` of the batch the buffered indices
    /// refer to; `None` when the streamed rows are output without a buffered match.
    buffered_batch_idx: Option<usize>,
    /// Row indices into the current streamed batch.
    streamed_indices: UInt64Builder,
    /// Row indices into the buffered batch; a null entry means "no buffered row".
    buffered_indices: UInt64Builder,
}
111
/// The streamed batch currently being joined, plus its pending output.
pub(super) struct StreamedBatch {
    /// The streamed record batch.
    pub batch: RecordBatch,
    /// Index of the current streamed row within `batch`.
    pub idx: usize,
    /// Join key arrays evaluated on `batch`.
    pub join_arrays: Vec<ArrayRef>,
    /// Pending output index pairs, grouped by buffered batch; materialized and
    /// cleared by `freeze_streamed`.
    pub output_indices: Vec<StreamedJoinedChunk>,
    /// Buffered batch index of the most recently opened output chunk.
    pub buffered_batch_idx: Option<usize>,
    /// Streamed row indices consulted by `join_partial` (semi/mark joins with a
    /// filter) to skip rows already matched. NOTE(review): populated outside
    /// this view.
    pub join_filter_matched_idxs: HashSet<u64>,
}
132
133impl StreamedBatch {
134 fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
135 let join_arrays = join_arrays(&batch, on_column);
136 StreamedBatch {
137 batch,
138 idx: 0,
139 join_arrays,
140 output_indices: vec![],
141 buffered_batch_idx: None,
142 join_filter_matched_idxs: HashSet::new(),
143 }
144 }
145
146 fn new_empty(schema: SchemaRef) -> Self {
147 StreamedBatch {
148 batch: RecordBatch::new_empty(schema),
149 idx: 0,
150 join_arrays: vec![],
151 output_indices: vec![],
152 buffered_batch_idx: None,
153 join_filter_matched_idxs: HashSet::new(),
154 }
155 }
156
157 fn append_output_pair(
160 &mut self,
161 buffered_batch_idx: Option<usize>,
162 buffered_idx: Option<usize>,
163 ) {
164 if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
167 {
168 self.output_indices.push(StreamedJoinedChunk {
169 buffered_batch_idx,
170 streamed_indices: UInt64Builder::with_capacity(1),
171 buffered_indices: UInt64Builder::with_capacity(1),
172 });
173 self.buffered_batch_idx = buffered_batch_idx;
174 };
175 let current_chunk = self.output_indices.last_mut().unwrap();
176
177 current_chunk.streamed_indices.append_value(self.idx as u64);
179 if let Some(idx) = buffered_idx {
180 current_chunk.buffered_indices.append_value(idx as u64);
181 } else {
182 current_chunk.buffered_indices.append_null();
183 }
184 }
185}
186
/// A batch from the buffered input, kept while some of its rows may still
/// belong to the current or a future key group.
#[derive(Debug)]
pub(super) struct BufferedBatch {
    /// The batch, in memory or spilled to disk.
    pub batch: BufferedBatchState,
    /// Rows of this batch that belong to the current buffered key group.
    pub range: Range<usize>,
    /// Join key arrays evaluated on the batch.
    pub join_arrays: Vec<ArrayRef>,
    /// Row indices joined with nulls on the streamed side (full join), drained
    /// by `freeze_buffered`.
    pub null_joined: Vec<usize>,
    /// Estimated memory footprint, used to grow and shrink the memory reservation.
    pub size_estimation: usize,
    /// Full join with a filter: buffered row index -> `true` while every
    /// filtered pair for that row has failed so far.
    pub join_filter_not_matched_map: HashMap<u64, bool>,
    /// Number of rows; cached so it can be read without touching `batch`,
    /// which may be spilled.
    pub num_rows: usize,
}
212
213impl BufferedBatch {
214 fn new(
215 batch: RecordBatch,
216 range: Range<usize>,
217 on_column: &[PhysicalExprRef],
218 ) -> Self {
219 let join_arrays = join_arrays(&batch, on_column);
220
221 let size_estimation = batch.get_array_memory_size()
228 + join_arrays
229 .iter()
230 .map(|arr| arr.get_array_memory_size())
231 .sum::<usize>()
232 + batch.num_rows().next_power_of_two() * size_of::<usize>()
233 + size_of::<Range<usize>>()
234 + size_of::<usize>();
235
236 let num_rows = batch.num_rows();
237 BufferedBatch {
238 batch: BufferedBatchState::InMemory(batch),
239 range,
240 join_arrays,
241 null_joined: vec![],
242 size_estimation,
243 join_filter_not_matched_map: HashMap::new(),
244 num_rows,
245 }
246 }
247}
248
/// Where a buffered batch's data lives.
#[derive(Debug)]
pub(super) enum BufferedBatchState {
    /// Held in memory and accounted in the memory reservation.
    InMemory(RecordBatch),
    /// Written to a temporary file by `allocate_reservation` because the
    /// reservation could not grow.
    Spilled(RefCountedTempFile),
}
258
/// Stream joining a streamed input against a buffered input by merging on
/// their join keys. NOTE(review): both inputs are assumed sorted on the join
/// keys per `sort_options`; that is not checked here.
pub(super) struct SortMergeJoinStream {
    /// Output schema of the join.
    pub schema: SchemaRef,
    /// Null-equality setting passed to `compare_join_arrays`.
    pub null_equality: NullEquality,
    /// Sort options of the join keys, passed to `compare_join_arrays`.
    pub sort_options: Vec<SortOptions>,
    /// Optional join filter evaluated on candidate pairs.
    pub filter: Option<JoinFilter>,
    /// Join type.
    pub join_type: JoinType,
    /// Target number of rows per output batch.
    pub batch_size: usize,

    /// Schema of the streamed input.
    pub streamed_schema: SchemaRef,
    /// Streamed input, consumed one row at a time.
    pub streamed: SendableRecordBatchStream,
    /// Current streamed batch and its pending output.
    pub streamed_batch: StreamedBatch,
    /// Whether the current streamed row has already produced output.
    pub streamed_joined: bool,
    /// State of the streamed side.
    pub streamed_state: StreamedState,
    /// Join key expressions for streamed batches.
    pub on_streamed: Vec<PhysicalExprRef>,

    /// Schema of the buffered input.
    pub buffered_schema: SchemaRef,
    /// Buffered input, consumed one key group at a time.
    pub buffered: SendableRecordBatchStream,
    /// Queue of buffered batches covering the current key group.
    pub buffered_data: BufferedData,
    /// Whether the current buffered key group has already produced output.
    pub buffered_joined: bool,
    /// State of the buffered side.
    pub buffered_state: BufferedState,
    /// Join key expressions for buffered batches.
    pub on_buffered: Vec<PhysicalExprRef>,

    /// State of the join state machine.
    pub state: SortMergeJoinState,
    /// Output batches produced but not yet emitted, with per-row filter metadata.
    pub staging_output_record_batches: JoinedRecordBatches,
    /// Filtered output accumulated until it reaches `batch_size` rows.
    pub output: RecordBatch,
    /// Output rows recorded by `join_partial` that have not yet been emitted.
    pub output_size: usize,
    /// Ordering of the current streamed key relative to the buffered key group.
    pub current_ordering: Ordering,
    /// Spills buffered batches that do not fit into `reservation`.
    pub spill_manager: SpillManager,

    /// Join metrics.
    pub join_metrics: SortMergeJoinMetrics,
    /// Memory reservation for in-memory buffered batches.
    pub reservation: MemoryReservation,
    /// Runtime environment; consulted for whether disk spilling is enabled.
    pub runtime_env: Arc<RuntimeEnv>,
    /// Number of non-empty streamed batches received; used as the batch id of
    /// staged output rows.
    pub streamed_batch_counter: AtomicUsize,
}
348
/// Staged output batches plus per-row metadata aligned with their rows.
pub(super) struct JoinedRecordBatches {
    /// Output batches in staging order.
    pub batches: Vec<RecordBatch>,
    /// Per-row filter result: the null-free mask for most joins, the raw
    /// (possibly null) mask for full joins, and null for rows produced from
    /// buffered rows only.
    pub filter_mask: BooleanBuilder,
    /// Per-row streamed row index; null for rows produced from buffered rows only.
    pub row_indices: UInt64Builder,
    /// Per-row streamed batch id (`streamed_batch_counter`); 0 for rows
    /// produced from buffered rows only.
    pub batch_ids: Vec<usize>,
}
362
363impl JoinedRecordBatches {
364 fn clear(&mut self) {
365 self.batches.clear();
366 self.batch_ids.clear();
367 self.filter_mask = BooleanBuilder::new();
368 self.row_indices = UInt64Builder::new();
369 }
370}
371impl RecordBatchStream for SortMergeJoinStream {
372 fn schema(&self) -> SchemaRef {
373 Arc::clone(&self.schema)
374 }
375}
376
377#[inline(always)]
382fn last_index_for_row(
383 row_index: usize,
384 indices: &UInt64Array,
385 batch_ids: &[usize],
386 indices_len: usize,
387) -> bool {
388 row_index == indices_len - 1
389 || batch_ids[row_index] != batch_ids[row_index + 1]
390 || indices.value(row_index) != indices.value(row_index + 1)
391}
392
/// Corrects the staged per-row join filter results of a filtered join.
///
/// Entries are grouped into runs of consecutive entries sharing the same
/// `batch_ids` value and the same `row_indices` value (see
/// `last_index_for_row`); each run holds the candidates staged for one streamed
/// row. From the raw `filter_mask` (one entry per candidate) this builds a
/// mask with three states per entry: `true`, `false`, or null. How the caller
/// interprets them is outside this function (NOTE(review): see
/// `filter_joined_batch`).
///
/// * `Left`/`Right`: passing entries stay `true`. In a run with no passing
///   entry, only the last entry is `false`; every other failing entry is null.
///   Padded with `false` to `expected_size`.
/// * `LeftMark`/`RightMark`: like `Left`, but only the first passing entry
///   of a run is `true`. Padded with `false` to `expected_size`.
/// * `LeftSemi`/`RightSemi`: the first passing entry of a run is `true`,
///   everything else null. No padding.
/// * `LeftAnti`/`RightAnti`: only the last entry of a run is non-null, and it
///   is `true` iff no entry in the run passed. Padded with `true` to
///   `expected_size`.
/// * `Full`: see the branch comments. No padding (`expected_size` unused).
/// * Any other join type: returns `None`.
///
/// NOTE(review): the padding subtracts `corrected_mask.len()` from
/// `expected_size`, so it assumes `expected_size >= row_indices.len()`.
pub(super) fn get_corrected_filter_mask(
    join_type: JoinType,
    row_indices: &UInt64Array,
    batch_ids: &[usize],
    filter_mask: &BooleanArray,
    expected_size: usize,
) -> Option<BooleanArray> {
    let row_indices_length = row_indices.len();
    let mut corrected_mask: BooleanBuilder =
        BooleanBuilder::with_capacity(row_indices_length);
    // Whether a passing entry has been seen in the current run; reset at the
    // end of every run.
    let mut seen_true = false;

    match join_type {
        JoinType::Left | JoinType::Right => {
            for i in 0..row_indices_length {
                let last_index =
                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
                if filter_mask.value(i) {
                    seen_true = true;
                    corrected_mask.append_value(true);
                } else if seen_true || !filter_mask.value(i) && !last_index {
                    // Failing entry after a pass in this run, or not the last
                    // entry of the run. (`!filter_mask.value(i)` is always true
                    // here.)
                    corrected_mask.append_null();
                } else {
                    // Last entry of a run in which nothing passed.
                    corrected_mask.append_value(false);
                }

                if last_index {
                    seen_true = false;
                }
            }

            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
            Some(corrected_mask.finish())
        }
        JoinType::LeftMark | JoinType::RightMark => {
            for i in 0..row_indices_length {
                let last_index =
                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
                if filter_mask.value(i) && !seen_true {
                    // Only the first passing entry of a run is kept.
                    seen_true = true;
                    corrected_mask.append_value(true);
                } else if seen_true || !filter_mask.value(i) && !last_index {
                    corrected_mask.append_null();
                } else {
                    // Last entry of a run in which nothing passed.
                    corrected_mask.append_value(false);
                }

                if last_index {
                    seen_true = false;
                }
            }

            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
            Some(corrected_mask.finish())
        }
        JoinType::LeftSemi | JoinType::RightSemi => {
            for i in 0..row_indices_length {
                let last_index =
                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
                if filter_mask.value(i) && !seen_true {
                    // First passing entry of the run.
                    seen_true = true;
                    corrected_mask.append_value(true);
                } else {
                    corrected_mask.append_null();
                }

                if last_index {
                    seen_true = false;
                }
            }

            Some(corrected_mask.finish())
        }
        JoinType::LeftAnti | JoinType::RightAnti => {
            for i in 0..row_indices_length {
                let last_index =
                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);

                if filter_mask.value(i) {
                    seen_true = true;
                }

                if last_index {
                    // The run's verdict is recorded on its last entry only.
                    if !seen_true {
                        corrected_mask.append_value(true);
                    } else {
                        corrected_mask.append_null();
                    }

                    seen_true = false;
                } else {
                    corrected_mask.append_null();
                }
            }
            corrected_mask.append_n(expected_size - corrected_mask.len(), true);
            Some(corrected_mask.finish())
        }
        JoinType::Full => {
            let mut mask: Vec<Option<bool>> = vec![Some(true); row_indices_length];
            // Despite the name, this holds the index of the FIRST passing entry
            // of the current run (it is only assigned while `seen_true` is false).
            let mut last_true_idx = 0;
            // Index of the first entry of the current run.
            let mut first_row_idx = 0;
            // Whether a non-null failing entry has been seen in the current run.
            let mut seen_false = false;

            for i in 0..row_indices_length {
                let last_index =
                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
                // NOTE(review): `value(i)` is read even when slot `i` is null;
                // Arrow leaves the value bit under a null unspecified, so a null
                // slot could set `seen_true` — confirm intended.
                let val = filter_mask.value(i);
                let is_null = filter_mask.is_null(i);

                if val {
                    if !seen_true {
                        last_true_idx = i;
                    }
                    seen_true = true;
                }

                // Null and passing entries stay `true`; a failing entry becomes
                // null if the run already had a pass or a failure, otherwise
                // it is the run's first failure and stays `false`.
                if is_null || val {
                    mask[i] = Some(true);
                } else if !is_null && !val && (seen_true || seen_false) {
                    mask[i] = None;
                } else {
                    mask[i] = Some(false);
                }

                if !is_null && !val {
                    seen_false = true;
                }

                if last_index {
                    if seen_true {
                        // The run had a pass: null out every entry before the
                        // first passing one (including its initial `false`).
                        #[allow(clippy::needless_range_loop)]
                        for j in first_row_idx..last_true_idx {
                            mask[j] = None;
                        }
                    }

                    seen_true = false;
                    seen_false = false;
                    last_true_idx = 0;
                    first_row_idx = i + 1;
                }
            }

            Some(BooleanArray::from(mask))
        }
        _ => None,
    }
}
554
555impl Stream for SortMergeJoinStream {
556 type Item = Result<RecordBatch>;
557
558 fn poll_next(
559 mut self: Pin<&mut Self>,
560 cx: &mut Context<'_>,
561 ) -> Poll<Option<Self::Item>> {
562 let join_time = self.join_metrics.join_time().clone();
563 let _timer = join_time.timer();
564 loop {
565 match &self.state {
566 SortMergeJoinState::Init => {
567 let streamed_exhausted =
568 self.streamed_state == StreamedState::Exhausted;
569 let buffered_exhausted =
570 self.buffered_state == BufferedState::Exhausted;
571 self.state = if streamed_exhausted && buffered_exhausted {
572 SortMergeJoinState::Exhausted
573 } else {
574 match self.current_ordering {
575 Ordering::Less | Ordering::Equal => {
576 if !streamed_exhausted {
577 if self.filter.is_some()
578 && matches!(
579 self.join_type,
580 JoinType::Left
581 | JoinType::LeftSemi
582 | JoinType::LeftMark
583 | JoinType::Right
584 | JoinType::RightSemi
585 | JoinType::RightMark
586 | JoinType::LeftAnti
587 | JoinType::RightAnti
588 | JoinType::Full
589 )
590 {
591 self.freeze_all()?;
592
593 if !self
596 .staging_output_record_batches
597 .batches
598 .is_empty()
599 {
600 let out_filtered_batch =
602 self.filter_joined_batch()?;
603
604 self.output = concat_batches(
606 &self.schema(),
607 [&self.output, &out_filtered_batch],
608 )?;
609
610 if self.output.num_rows() >= self.batch_size {
612 let record_batch = std::mem::replace(
613 &mut self.output,
614 RecordBatch::new_empty(
615 out_filtered_batch.schema(),
616 ),
617 );
618 return Poll::Ready(Some(Ok(
619 record_batch,
620 )));
621 }
622 }
623 }
624
625 self.streamed_joined = false;
626 self.streamed_state = StreamedState::Init;
627 }
628 }
629 Ordering::Greater => {
630 if !buffered_exhausted {
631 self.buffered_joined = false;
632 self.buffered_state = BufferedState::Init;
633 }
634 }
635 }
636 SortMergeJoinState::Polling
637 };
638 }
639 SortMergeJoinState::Polling => {
640 if ![StreamedState::Exhausted, StreamedState::Ready]
641 .contains(&self.streamed_state)
642 {
643 match self.poll_streamed_row(cx)? {
644 Poll::Ready(_) => {}
645 Poll::Pending => return Poll::Pending,
646 }
647 }
648
649 if ![BufferedState::Exhausted, BufferedState::Ready]
650 .contains(&self.buffered_state)
651 {
652 match self.poll_buffered_batches(cx)? {
653 Poll::Ready(_) => {}
654 Poll::Pending => return Poll::Pending,
655 }
656 }
657 let streamed_exhausted =
658 self.streamed_state == StreamedState::Exhausted;
659 let buffered_exhausted =
660 self.buffered_state == BufferedState::Exhausted;
661 if streamed_exhausted && buffered_exhausted {
662 self.state = SortMergeJoinState::Exhausted;
663 continue;
664 }
665 self.current_ordering = self.compare_streamed_buffered()?;
666 self.state = SortMergeJoinState::JoinOutput;
667 }
668 SortMergeJoinState::JoinOutput => {
669 self.join_partial()?;
670
671 if self.output_size < self.batch_size {
672 if self.buffered_data.scanning_finished() {
673 self.buffered_data.scanning_reset();
674 self.state = SortMergeJoinState::Init;
675 }
676 } else {
677 self.freeze_all()?;
678 if !self.staging_output_record_batches.batches.is_empty() {
679 let record_batch = self.output_record_batch_and_reset()?;
680 if self.filter.is_some()
686 && matches!(
687 self.join_type,
688 JoinType::Left
689 | JoinType::LeftSemi
690 | JoinType::Right
691 | JoinType::RightSemi
692 | JoinType::LeftAnti
693 | JoinType::RightAnti
694 | JoinType::LeftMark
695 | JoinType::RightMark
696 | JoinType::Full
697 )
698 {
699 continue;
700 }
701
702 return Poll::Ready(Some(Ok(record_batch)));
703 }
704 return Poll::Pending;
705 }
706 }
707 SortMergeJoinState::Exhausted => {
708 self.freeze_all()?;
709
710 if !self.staging_output_record_batches.batches.is_empty() {
712 if self.filter.is_some()
713 && matches!(
714 self.join_type,
715 JoinType::Left
716 | JoinType::LeftSemi
717 | JoinType::Right
718 | JoinType::RightSemi
719 | JoinType::LeftAnti
720 | JoinType::RightAnti
721 | JoinType::Full
722 | JoinType::LeftMark
723 | JoinType::RightMark
724 )
725 {
726 let record_batch = self.filter_joined_batch()?;
727 return Poll::Ready(Some(Ok(record_batch)));
728 } else {
729 let record_batch = self.output_record_batch_and_reset()?;
730 return Poll::Ready(Some(Ok(record_batch)));
731 }
732 } else if self.output.num_rows() > 0 {
733 let schema = self.output.schema();
735 let record_batch = std::mem::replace(
736 &mut self.output,
737 RecordBatch::new_empty(schema),
738 );
739 return Poll::Ready(Some(Ok(record_batch)));
740 } else {
741 return Poll::Ready(None);
742 }
743 }
744 }
745 }
746 }
747}
748
749impl SortMergeJoinStream {
750 #[allow(clippy::too_many_arguments)]
751 pub fn try_new(
752 spill_compression: SpillCompression,
754 schema: SchemaRef,
755 sort_options: Vec<SortOptions>,
756 null_equality: NullEquality,
757 streamed: SendableRecordBatchStream,
758 buffered: SendableRecordBatchStream,
759 on_streamed: Vec<Arc<dyn PhysicalExpr>>,
760 on_buffered: Vec<Arc<dyn PhysicalExpr>>,
761 filter: Option<JoinFilter>,
762 join_type: JoinType,
763 batch_size: usize,
764 join_metrics: SortMergeJoinMetrics,
765 reservation: MemoryReservation,
766 runtime_env: Arc<RuntimeEnv>,
767 ) -> Result<Self> {
768 let streamed_schema = streamed.schema();
769 let buffered_schema = buffered.schema();
770 let spill_manager = SpillManager::new(
771 Arc::clone(&runtime_env),
772 join_metrics.spill_metrics().clone(),
773 Arc::clone(&buffered_schema),
774 )
775 .with_compression_type(spill_compression);
776 Ok(Self {
777 state: SortMergeJoinState::Init,
778 sort_options,
779 null_equality,
780 schema: Arc::clone(&schema),
781 streamed_schema: Arc::clone(&streamed_schema),
782 buffered_schema,
783 streamed,
784 buffered,
785 streamed_batch: StreamedBatch::new_empty(streamed_schema),
786 buffered_data: BufferedData::default(),
787 streamed_joined: false,
788 buffered_joined: false,
789 streamed_state: StreamedState::Init,
790 buffered_state: BufferedState::Init,
791 current_ordering: Ordering::Equal,
792 on_streamed,
793 on_buffered,
794 filter,
795 staging_output_record_batches: JoinedRecordBatches {
796 batches: vec![],
797 filter_mask: BooleanBuilder::new(),
798 row_indices: UInt64Builder::new(),
799 batch_ids: vec![],
800 },
801 output: RecordBatch::new_empty(schema),
802 output_size: 0,
803 batch_size,
804 join_type,
805 join_metrics,
806 reservation,
807 runtime_env,
808 spill_manager,
809 streamed_batch_counter: AtomicUsize::new(0),
810 })
811 }
812
813 fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
815 loop {
816 match &self.streamed_state {
817 StreamedState::Init => {
818 if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
819 {
820 self.streamed_batch.idx += 1;
821 self.streamed_state = StreamedState::Ready;
822 return Poll::Ready(Some(Ok(())));
823 } else {
824 self.streamed_state = StreamedState::Polling;
825 }
826 }
827 StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? {
828 Poll::Pending => {
829 return Poll::Pending;
830 }
831 Poll::Ready(None) => {
832 self.streamed_state = StreamedState::Exhausted;
833 }
834 Poll::Ready(Some(batch)) => {
835 if batch.num_rows() > 0 {
836 self.freeze_streamed()?;
837 self.join_metrics.input_batches().add(1);
838 self.join_metrics.input_rows().add(batch.num_rows());
839 self.streamed_batch =
840 StreamedBatch::new(batch, &self.on_streamed);
841 self.streamed_batch_counter
844 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
845 self.streamed_state = StreamedState::Ready;
846 }
847 }
848 },
849 StreamedState::Ready => {
850 return Poll::Ready(Some(Ok(())));
851 }
852 StreamedState::Exhausted => {
853 return Poll::Ready(None);
854 }
855 }
856 }
857 }
858
859 fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> {
860 if let BufferedBatchState::InMemory(_) = buffered_batch.batch {
862 self.reservation
863 .try_shrink(buffered_batch.size_estimation)?;
864 }
865 Ok(())
866 }
867
868 fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> {
869 match self.reservation.try_grow(buffered_batch.size_estimation) {
870 Ok(_) => {
871 self.join_metrics
872 .peak_mem_used()
873 .set_max(self.reservation.size());
874 Ok(())
875 }
876 Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
877 match buffered_batch.batch {
880 BufferedBatchState::InMemory(batch) => {
881 let spill_file = self
882 .spill_manager
883 .spill_record_batch_and_finish(
884 &[batch],
885 "sort_merge_join_buffered_spill",
886 )?
887 .unwrap(); buffered_batch.batch = BufferedBatchState::Spilled(spill_file);
890 Ok(())
891 }
892 _ => internal_err!("Buffered batch has empty body"),
893 }
894 }
895 Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()),
896 }?;
897
898 self.buffered_data.batches.push_back(buffered_batch);
899 Ok(())
900 }
901
    /// Advances the buffered side to the next complete key group.
    ///
    /// A key group is a run of consecutive buffered rows whose join keys equal
    /// the key at the head batch's `range.start`. It may span several queued
    /// batches; each batch's `range` marks its share of the group.
    ///
    /// Returns `Ready(Some(Ok(())))` when the group is complete (state `Ready`),
    /// `Ready(None)` when the buffered input is exhausted with nothing queued,
    /// and `Pending` while the input has nothing to offer.
    ///
    /// NOTE(review): unlike `poll_streamed_row`, input metrics here also count
    /// empty buffered batches.
    fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll<Option<Result<()>>> {
        loop {
            match &self.buffered_state {
                BufferedState::Init => {
                    // Dequeue head batches whose rows have all been consumed
                    // by earlier key groups (`range.end == num_rows`).
                    while !self.buffered_data.batches.is_empty() {
                        let head_batch = self.buffered_data.head_batch();
                        if head_batch.range.end == head_batch.num_rows {
                            // Flush output that may still reference this batch
                            // before it leaves the queue.
                            self.freeze_dequeuing_buffered()?;
                            if let Some(mut buffered_batch) =
                                self.buffered_data.batches.pop_front()
                            {
                                self.produce_buffered_not_matched(&mut buffered_batch)?;
                                self.free_reservation(buffered_batch)?;
                            }
                        } else {
                            break;
                        }
                    }
                    if self.buffered_data.batches.is_empty() {
                        self.buffered_state = BufferedState::PollingFirst;
                    } else {
                        // The next group starts right after the previous one
                        // and initially covers a single row.
                        let tail_batch = self.buffered_data.tail_batch_mut();
                        tail_batch.range.start = tail_batch.range.end;
                        tail_batch.range.end += 1;
                        self.buffered_state = BufferedState::PollingRest;
                    }
                }
                BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? {
                    Poll::Pending => {
                        return Poll::Pending;
                    }
                    Poll::Ready(None) => {
                        self.buffered_state = BufferedState::Exhausted;
                        return Poll::Ready(None);
                    }
                    Poll::Ready(Some(batch)) => {
                        self.join_metrics.input_batches().add(1);
                        self.join_metrics.input_rows().add(batch.num_rows());

                        if batch.num_rows() > 0 {
                            // The queue was empty, so this batch starts a new
                            // group at row 0 covering one row.
                            let buffered_batch =
                                BufferedBatch::new(batch, 0..1, &self.on_buffered);

                            self.allocate_reservation(buffered_batch)?;
                            self.buffered_state = BufferedState::PollingRest;
                        }
                    }
                },
                BufferedState::PollingRest => {
                    if self.buffered_data.tail_batch().range.end
                        < self.buffered_data.tail_batch().num_rows
                    {
                        // Extend the group over tail rows whose key equals the
                        // group's first key; stop at the first differing key.
                        while self.buffered_data.tail_batch().range.end
                            < self.buffered_data.tail_batch().num_rows
                        {
                            if is_join_arrays_equal(
                                &self.buffered_data.head_batch().join_arrays,
                                self.buffered_data.head_batch().range.start,
                                &self.buffered_data.tail_batch().join_arrays,
                                self.buffered_data.tail_batch().range.end,
                            )? {
                                self.buffered_data.tail_batch_mut().range.end += 1;
                            } else {
                                self.buffered_state = BufferedState::Ready;
                                return Poll::Ready(Some(Ok(())));
                            }
                        }
                    } else {
                        // Tail batch used up: the group may continue in the
                        // next batch.
                        match self.buffered.poll_next_unpin(cx)? {
                            Poll::Pending => {
                                return Poll::Pending;
                            }
                            Poll::Ready(None) => {
                                // Input ended: the group ends here.
                                self.buffered_state = BufferedState::Ready;
                            }
                            Poll::Ready(Some(batch)) => {
                                self.join_metrics.input_batches().add(1);
                                self.join_metrics.input_rows().add(batch.num_rows());
                                if batch.num_rows() > 0 {
                                    // Empty range: comparison resumes at row 0
                                    // of the new tail batch.
                                    let buffered_batch = BufferedBatch::new(
                                        batch,
                                        0..0,
                                        &self.on_buffered,
                                    );
                                    self.allocate_reservation(buffered_batch)?;
                                }
                            }
                        }
                    }
                }
                BufferedState::Ready => {
                    return Poll::Ready(Some(Ok(())));
                }
                BufferedState::Exhausted => {
                    return Poll::Ready(None);
                }
            }
        }
    }
1007
1008 fn compare_streamed_buffered(&self) -> Result<Ordering> {
1010 if self.streamed_state == StreamedState::Exhausted {
1011 return Ok(Ordering::Greater);
1012 }
1013 if !self.buffered_data.has_buffered_rows() {
1014 return Ok(Ordering::Less);
1015 }
1016
1017 compare_join_arrays(
1018 &self.streamed_batch.join_arrays,
1019 self.streamed_batch.idx,
1020 &self.buffered_data.head_batch().join_arrays,
1021 self.buffered_data.head_batch().range.start,
1022 &self.sort_options,
1023 self.null_equality,
1024 )
1025 }
1026
    /// Records join output for the current streamed row and buffered key group,
    /// based on `current_ordering` and the join type.
    ///
    /// Output is recorded as index pairs on `streamed_batch`, or as
    /// `null_joined` rows on buffered batches, and counted in `output_size`.
    /// It is materialized later by the `freeze_*` methods. Scanning of the
    /// buffered group stops once `output_size` reaches `batch_size`, so one
    /// group may take several calls.
    fn join_partial(&mut self) -> Result<()> {
        // Whether to emit output for the current streamed row.
        let mut join_streamed = false;
        // Whether to scan the rows of the current buffered key group.
        let mut join_buffered = false;
        // Mark joins: whether the streamed row is output as matched.
        let mut mark_row_as_match = false;

        match self.current_ordering {
            // Streamed row sorts before the buffered group (or nothing is
            // buffered): the streamed row has no match in this group.
            Ordering::Less => {
                if matches!(
                    self.join_type,
                    JoinType::Left
                        | JoinType::Right
                        | JoinType::Full
                        | JoinType::LeftAnti
                        | JoinType::RightAnti
                        | JoinType::LeftMark
                        | JoinType::RightMark
                ) {
                    join_streamed = !self.streamed_joined;
                }
            }
            // Equal join keys.
            Ordering::Equal => {
                if matches!(
                    self.join_type,
                    JoinType::LeftSemi
                        | JoinType::LeftMark
                        | JoinType::RightSemi
                        | JoinType::RightMark
                ) {
                    mark_row_as_match = matches!(
                        self.join_type,
                        JoinType::LeftMark | JoinType::RightMark
                    );
                    if self.filter.is_some() {
                        // Skip streamed rows already recorded as matched;
                        // otherwise scan the buffered group so the filter can be
                        // evaluated on each pair.
                        join_streamed = !self
                            .streamed_batch
                            .join_filter_matched_idxs
                            .contains(&(self.streamed_batch.idx as u64))
                            && !self.streamed_joined;
                        join_buffered = join_streamed;
                    } else {
                        join_streamed = !self.streamed_joined;
                    }
                }
                if matches!(
                    self.join_type,
                    JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
                ) {
                    join_streamed = true;
                    join_buffered = true;
                };

                if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti)
                    && self.filter.is_some()
                {
                    join_streamed = !self.streamed_joined;
                    join_buffered = join_streamed;
                }
            }
            // Buffered group sorts before the streamed row (or the streamed
            // side is exhausted): only full joins output unmatched buffered rows.
            Ordering::Greater => {
                if matches!(self.join_type, JoinType::Full) {
                    join_buffered = !self.buffered_joined;
                };
            }
        }
        if !join_streamed && !join_buffered {
            // Nothing to output for this combination; finish the scan.
            self.buffered_data.scanning_finish();
            return Ok(());
        }

        if join_buffered {
            // Walk the buffered group, pairing each row with the streamed row
            // (or recording it as null-joined), until done or the batch is full.
            while !self.buffered_data.scanning_finished()
                && self.output_size < self.batch_size
            {
                let scanning_idx = self.buffered_data.scanning_idx();
                if join_streamed {
                    self.streamed_batch.append_output_pair(
                        Some(self.buffered_data.scanning_batch_idx),
                        Some(scanning_idx),
                    );
                } else {
                    self.buffered_data
                        .scanning_batch_mut()
                        .null_joined
                        .push(scanning_idx);
                }
                self.output_size += 1;
                self.buffered_data.scanning_advance();

                if self.buffered_data.scanning_finished() {
                    self.streamed_joined = join_streamed;
                    self.buffered_joined = true;
                }
            }
        } else {
            // Output the streamed row once, without scanning buffered rows.
            let scanning_batch_idx = if self.buffered_data.scanning_finished() {
                None
            } else {
                Some(self.buffered_data.scanning_batch_idx)
            };
            // Mark joins use a non-null dummy buffered index so that
            // `freeze_streamed` derives a `true` mark via `is_not_null`.
            let scanning_idx = mark_row_as_match.then_some(0);

            self.streamed_batch
                .append_output_pair(scanning_batch_idx, scanning_idx);
            self.output_size += 1;
            self.buffered_data.scanning_finish();
            self.streamed_joined = true;
        }
        Ok(())
    }
1154
1155 fn freeze_all(&mut self) -> Result<()> {
1156 self.freeze_buffered(self.buffered_data.batches.len())?;
1157 self.freeze_streamed()?;
1158 Ok(())
1159 }
1160
1161 fn freeze_dequeuing_buffered(&mut self) -> Result<()> {
1166 self.freeze_streamed()?;
1167 self.freeze_buffered(1)?;
1169 Ok(())
1170 }
1171
1172 fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> {
1178 if !matches!(self.join_type, JoinType::Full) {
1179 return Ok(());
1180 }
1181 for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
1182 let buffered_indices = UInt64Array::from_iter_values(
1183 buffered_batch.null_joined.iter().map(|&index| index as u64),
1184 );
1185 if let Some(record_batch) = produce_buffered_null_batch(
1186 &self.schema,
1187 &self.streamed_schema,
1188 &buffered_indices,
1189 buffered_batch,
1190 )? {
1191 let num_rows = record_batch.num_rows();
1192 self.staging_output_record_batches
1193 .filter_mask
1194 .append_nulls(num_rows);
1195 self.staging_output_record_batches
1196 .row_indices
1197 .append_nulls(num_rows);
1198 self.staging_output_record_batches.batch_ids.resize(
1199 self.staging_output_record_batches.batch_ids.len() + num_rows,
1200 0,
1201 );
1202
1203 self.staging_output_record_batches
1204 .batches
1205 .push(record_batch);
1206 }
1207 buffered_batch.null_joined.clear();
1208 }
1209 Ok(())
1210 }
1211
1212 fn produce_buffered_not_matched(
1213 &mut self,
1214 buffered_batch: &mut BufferedBatch,
1215 ) -> Result<()> {
1216 if !matches!(self.join_type, JoinType::Full) {
1217 return Ok(());
1218 }
1219
1220 let not_matched_buffered_indices = buffered_batch
1223 .join_filter_not_matched_map
1224 .iter()
1225 .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None })
1226 .collect::<Vec<_>>();
1227
1228 let buffered_indices =
1229 UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied());
1230
1231 if let Some(record_batch) = produce_buffered_null_batch(
1232 &self.schema,
1233 &self.streamed_schema,
1234 &buffered_indices,
1235 buffered_batch,
1236 )? {
1237 let num_rows = record_batch.num_rows();
1238
1239 self.staging_output_record_batches
1240 .filter_mask
1241 .append_nulls(num_rows);
1242 self.staging_output_record_batches
1243 .row_indices
1244 .append_nulls(num_rows);
1245 self.staging_output_record_batches.batch_ids.resize(
1246 self.staging_output_record_batches.batch_ids.len() + num_rows,
1247 0,
1248 );
1249 self.staging_output_record_batches
1250 .batches
1251 .push(record_batch);
1252 }
1253 buffered_batch.join_filter_not_matched_map.clear();
1254
1255 Ok(())
1256 }
1257
1258 fn freeze_streamed(&mut self) -> Result<()> {
1261 for chunk in self.streamed_batch.output_indices.iter_mut() {
1262 let left_indices = chunk.streamed_indices.finish();
1264
1265 if left_indices.is_empty() {
1266 continue;
1267 }
1268
1269 let mut left_columns = self
1270 .streamed_batch
1271 .batch
1272 .columns()
1273 .iter()
1274 .map(|column| take(column, &left_indices, None))
1275 .collect::<Result<Vec<_>, ArrowError>>()?;
1276
1277 let right_indices: UInt64Array = chunk.buffered_indices.finish();
1279 let mut right_columns =
1280 if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) {
1281 vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
1282 } else if matches!(
1283 self.join_type,
1284 JoinType::LeftSemi
1285 | JoinType::LeftAnti
1286 | JoinType::RightAnti
1287 | JoinType::RightSemi
1288 ) {
1289 vec![]
1290 } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
1291 fetch_right_columns_by_idxs(
1292 &self.buffered_data,
1293 buffered_idx,
1294 &right_indices,
1295 )?
1296 } else {
1297 create_unmatched_columns(
1300 self.join_type,
1301 &self.buffered_schema,
1302 right_indices.len(),
1303 )
1304 };
1305
1306 let filter_columns = if chunk.buffered_batch_idx.is_some() {
1309 if !matches!(self.join_type, JoinType::Right) {
1310 if matches!(
1311 self.join_type,
1312 JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
1313 ) {
1314 let right_cols = fetch_right_columns_by_idxs(
1315 &self.buffered_data,
1316 chunk.buffered_batch_idx.unwrap(),
1317 &right_indices,
1318 )?;
1319
1320 get_filter_column(&self.filter, &left_columns, &right_cols)
1321 } else if matches!(
1322 self.join_type,
1323 JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
1324 ) {
1325 let right_cols = fetch_right_columns_by_idxs(
1326 &self.buffered_data,
1327 chunk.buffered_batch_idx.unwrap(),
1328 &right_indices,
1329 )?;
1330
1331 get_filter_column(&self.filter, &right_cols, &left_columns)
1332 } else {
1333 get_filter_column(&self.filter, &left_columns, &right_columns)
1334 }
1335 } else {
1336 get_filter_column(&self.filter, &right_columns, &left_columns)
1337 }
1338 } else {
1339 vec![]
1342 };
1343
1344 let columns = if !matches!(self.join_type, JoinType::Right) {
1345 left_columns.extend(right_columns);
1346 left_columns
1347 } else {
1348 right_columns.extend(left_columns);
1349 right_columns
1350 };
1351
1352 let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
1353 if !filter_columns.is_empty() {
1355 if let Some(f) = &self.filter {
1356 let filter_batch =
1358 RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?;
1359
1360 let filter_result = f
1361 .expression()
1362 .evaluate(&filter_batch)?
1363 .into_array(filter_batch.num_rows())?;
1364
1365 let pre_mask =
1367 datafusion_common::cast::as_boolean_array(&filter_result)?;
1368
1369 let mask = if pre_mask.null_count() > 0 {
1372 compute::prep_null_mask_filter(
1373 datafusion_common::cast::as_boolean_array(&filter_result)?,
1374 )
1375 } else {
1376 pre_mask.clone()
1377 };
1378
1379 if matches!(
1381 self.join_type,
1382 JoinType::Left
1383 | JoinType::LeftSemi
1384 | JoinType::Right
1385 | JoinType::RightSemi
1386 | JoinType::LeftAnti
1387 | JoinType::RightAnti
1388 | JoinType::LeftMark
1389 | JoinType::RightMark
1390 | JoinType::Full
1391 ) {
1392 self.staging_output_record_batches
1393 .batches
1394 .push(output_batch);
1395 } else {
1396 let filtered_batch = filter_record_batch(&output_batch, &mask)?;
1397 self.staging_output_record_batches
1398 .batches
1399 .push(filtered_batch);
1400 }
1401
1402 if !matches!(self.join_type, JoinType::Full) {
1403 self.staging_output_record_batches.filter_mask.extend(&mask);
1404 } else {
1405 self.staging_output_record_batches
1406 .filter_mask
1407 .extend(pre_mask);
1408 }
1409 self.staging_output_record_batches
1410 .row_indices
1411 .extend(&left_indices);
1412 self.staging_output_record_batches.batch_ids.resize(
1413 self.staging_output_record_batches.batch_ids.len()
1414 + left_indices.len(),
1415 self.streamed_batch_counter.load(Relaxed),
1416 );
1417
1418 if matches!(self.join_type, JoinType::Full) {
1423 let buffered_batch = &mut self.buffered_data.batches
1424 [chunk.buffered_batch_idx.unwrap()];
1425
1426 for i in 0..pre_mask.len() {
1427 if right_indices.is_null(i) {
1430 continue;
1431 }
1432
1433 let buffered_index = right_indices.value(i);
1434
1435 buffered_batch.join_filter_not_matched_map.insert(
1436 buffered_index,
1437 *buffered_batch
1438 .join_filter_not_matched_map
1439 .get(&buffered_index)
1440 .unwrap_or(&true)
1441 && !pre_mask.value(i),
1442 );
1443 }
1444 }
1445 } else {
1446 self.staging_output_record_batches
1447 .batches
1448 .push(output_batch);
1449 }
1450 } else {
1451 self.staging_output_record_batches
1452 .batches
1453 .push(output_batch);
1454 }
1455 }
1456
1457 self.streamed_batch.output_indices.clear();
1458
1459 Ok(())
1460 }
1461
1462 fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
1463 let record_batch =
1464 concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1465 self.join_metrics.output_batches().add(1);
1466 self.join_metrics
1467 .baseline_metrics()
1468 .record_output(record_batch.num_rows());
1469 if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size {
1475 self.output_size = 0;
1476 } else {
1477 self.output_size -= record_batch.num_rows();
1478 }
1479
1480 if !(self.filter.is_some()
1481 && matches!(
1482 self.join_type,
1483 JoinType::Left
1484 | JoinType::LeftSemi
1485 | JoinType::Right
1486 | JoinType::RightSemi
1487 | JoinType::LeftAnti
1488 | JoinType::RightAnti
1489 | JoinType::LeftMark
1490 | JoinType::RightMark
1491 | JoinType::Full
1492 ))
1493 {
1494 self.staging_output_record_batches.batches.clear();
1495 }
1496
1497 Ok(record_batch)
1498 }
1499
1500 fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
1501 let record_batch =
1502 concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1503 let mut out_indices = self.staging_output_record_batches.row_indices.finish();
1504 let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
1505 let mut batch_ids = &self.staging_output_record_batches.batch_ids;
1506 let default_batch_ids = vec![0; record_batch.num_rows()];
1507
1508 if out_indices.null_count() == out_indices.len()
1512 && out_indices.len() != record_batch.num_rows()
1513 {
1514 out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]);
1515 out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]);
1516 batch_ids = &default_batch_ids;
1517 }
1518
1519 if out_mask.is_empty() {
1520 self.staging_output_record_batches.batches.clear();
1521 return Ok(record_batch);
1522 }
1523
1524 let maybe_corrected_mask = get_corrected_filter_mask(
1525 self.join_type,
1526 &out_indices,
1527 batch_ids,
1528 &out_mask,
1529 record_batch.num_rows(),
1530 );
1531
1532 let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask {
1533 filtered_join_mask
1534 } else {
1535 &out_mask
1536 };
1537
1538 self.filter_record_batch_by_join_type(record_batch, corrected_mask)
1539 }
1540
1541 fn filter_record_batch_by_join_type(
1542 &mut self,
1543 record_batch: RecordBatch,
1544 corrected_mask: &BooleanArray,
1545 ) -> Result<RecordBatch> {
1546 let mut filtered_record_batch =
1547 filter_record_batch(&record_batch, corrected_mask)?;
1548 let left_columns_length = self.streamed_schema.fields.len();
1549 let right_columns_length = self.buffered_schema.fields.len();
1550
1551 if matches!(
1552 self.join_type,
1553 JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
1554 ) {
1555 let null_mask = compute::not(corrected_mask)?;
1556 let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?;
1557
1558 let mut right_columns = create_unmatched_columns(
1559 self.join_type,
1560 &self.buffered_schema,
1561 null_joined_batch.num_rows(),
1562 );
1563
1564 let columns = if !matches!(self.join_type, JoinType::Right) {
1565 let mut left_columns = null_joined_batch
1566 .columns()
1567 .iter()
1568 .take(right_columns_length)
1569 .cloned()
1570 .collect::<Vec<_>>();
1571
1572 left_columns.extend(right_columns);
1573 left_columns
1574 } else {
1575 let left_columns = null_joined_batch
1576 .columns()
1577 .iter()
1578 .skip(left_columns_length)
1579 .cloned()
1580 .collect::<Vec<_>>();
1581
1582 right_columns.extend(left_columns);
1583 right_columns
1584 };
1585
1586 let null_joined_streamed_batch =
1588 RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
1589
1590 filtered_record_batch = concat_batches(
1591 &self.schema,
1592 &[filtered_record_batch, null_joined_streamed_batch],
1593 )?;
1594 } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
1595 let output_column_indices = (0..left_columns_length).collect::<Vec<_>>();
1596 filtered_record_batch =
1597 filtered_record_batch.project(&output_column_indices)?;
1598 } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) {
1599 let output_column_indices = (0..right_columns_length).collect::<Vec<_>>();
1600 filtered_record_batch =
1601 filtered_record_batch.project(&output_column_indices)?;
1602 } else if matches!(self.join_type, JoinType::Full)
1603 && corrected_mask.false_count() > 0
1604 {
1605 let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
1607 let joined_filter_not_matched_batch =
1608 filter_record_batch(&record_batch, &joined_filter_not_matched_mask)?;
1609
1610 let right_null_columns = self
1612 .buffered_schema
1613 .fields()
1614 .iter()
1615 .map(|f| {
1616 new_null_array(
1617 f.data_type(),
1618 joined_filter_not_matched_batch.num_rows(),
1619 )
1620 })
1621 .collect::<Vec<_>>();
1622
1623 let mut result_joined = joined_filter_not_matched_batch
1624 .columns()
1625 .iter()
1626 .take(left_columns_length)
1627 .cloned()
1628 .collect::<Vec<_>>();
1629
1630 result_joined.extend(right_null_columns);
1631
1632 let left_null_joined_batch =
1633 RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
1634
1635 let mut result_joined = self
1637 .streamed_schema
1638 .fields()
1639 .iter()
1640 .map(|f| {
1641 new_null_array(
1642 f.data_type(),
1643 joined_filter_not_matched_batch.num_rows(),
1644 )
1645 })
1646 .collect::<Vec<_>>();
1647
1648 let right_data = joined_filter_not_matched_batch
1649 .columns()
1650 .iter()
1651 .skip(left_columns_length)
1652 .cloned()
1653 .collect::<Vec<_>>();
1654
1655 result_joined.extend(right_data);
1656
1657 filtered_record_batch = concat_batches(
1658 &self.schema,
1659 &[filtered_record_batch, left_null_joined_batch],
1660 )?;
1661 }
1662
1663 self.staging_output_record_batches.clear();
1664
1665 Ok(filtered_record_batch)
1666 }
1667}
1668
1669fn create_unmatched_columns(
1670 join_type: JoinType,
1671 schema: &SchemaRef,
1672 size: usize,
1673) -> Vec<ArrayRef> {
1674 if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) {
1675 vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
1676 } else {
1677 schema
1678 .fields()
1679 .iter()
1680 .map(|f| new_null_array(f.data_type(), size))
1681 .collect::<Vec<_>>()
1682 }
1683}
1684
1685fn get_filter_column(
1687 join_filter: &Option<JoinFilter>,
1688 streamed_columns: &[ArrayRef],
1689 buffered_columns: &[ArrayRef],
1690) -> Vec<ArrayRef> {
1691 let mut filter_columns = vec![];
1692
1693 if let Some(f) = join_filter {
1694 let left_columns = f
1695 .column_indices()
1696 .iter()
1697 .filter(|col_index| col_index.side == JoinSide::Left)
1698 .map(|i| Arc::clone(&streamed_columns[i.index]))
1699 .collect::<Vec<_>>();
1700
1701 let right_columns = f
1702 .column_indices()
1703 .iter()
1704 .filter(|col_index| col_index.side == JoinSide::Right)
1705 .map(|i| Arc::clone(&buffered_columns[i.index]))
1706 .collect::<Vec<_>>();
1707
1708 filter_columns.extend(left_columns);
1709 filter_columns.extend(right_columns);
1710 }
1711
1712 filter_columns
1713}
1714
1715fn produce_buffered_null_batch(
1716 schema: &SchemaRef,
1717 streamed_schema: &SchemaRef,
1718 buffered_indices: &PrimitiveArray<UInt64Type>,
1719 buffered_batch: &BufferedBatch,
1720) -> Result<Option<RecordBatch>> {
1721 if buffered_indices.is_empty() {
1722 return Ok(None);
1723 }
1724
1725 let right_columns =
1727 fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?;
1728
1729 let mut left_columns = streamed_schema
1731 .fields()
1732 .iter()
1733 .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
1734 .collect::<Vec<_>>();
1735
1736 left_columns.extend(right_columns);
1737
1738 Ok(Some(RecordBatch::try_new(
1739 Arc::clone(schema),
1740 left_columns,
1741 )?))
1742}
1743
1744#[inline(always)]
1746fn fetch_right_columns_by_idxs(
1747 buffered_data: &BufferedData,
1748 buffered_batch_idx: usize,
1749 buffered_indices: &UInt64Array,
1750) -> Result<Vec<ArrayRef>> {
1751 fetch_right_columns_from_batch_by_idxs(
1752 &buffered_data.batches[buffered_batch_idx],
1753 buffered_indices,
1754 )
1755}
1756
1757#[inline(always)]
1758fn fetch_right_columns_from_batch_by_idxs(
1759 buffered_batch: &BufferedBatch,
1760 buffered_indices: &UInt64Array,
1761) -> Result<Vec<ArrayRef>> {
1762 match &buffered_batch.batch {
1763 BufferedBatchState::InMemory(batch) => Ok(batch
1765 .columns()
1766 .iter()
1767 .map(|column| take(column, &buffered_indices, None))
1768 .collect::<Result<Vec<_>, ArrowError>>()
1769 .map_err(Into::<DataFusionError>::into)?),
1770 BufferedBatchState::Spilled(spill_file) => {
1772 let mut buffered_cols: Vec<ArrayRef> =
1773 Vec::with_capacity(buffered_indices.len());
1774
1775 let file = BufReader::new(File::open(spill_file.path())?);
1776 let reader = StreamReader::try_new(file, None)?;
1777
1778 for batch in reader {
1779 batch?.columns().iter().for_each(|column| {
1780 buffered_cols.extend(take(column, &buffered_indices, None))
1781 });
1782 }
1783
1784 Ok(buffered_cols)
1785 }
1786 }
1787}
1788
/// Buffered-side batches together with a scan cursor over their rows.
///
/// NOTE(review): presumably these are the buffered batches that share the
/// current join key — confirm against the code that fills `batches`.
#[derive(Debug, Default)]
pub(super) struct BufferedData {
    /// Buffered batches; `head_batch()` is the front and `tail_batch()` the back.
    pub batches: VecDeque<BufferedBatch>,
    /// Index into `batches` of the batch the cursor is in; equal to
    /// `batches.len()` once scanning has finished.
    pub scanning_batch_idx: usize,
    /// Cursor position relative to the scanning batch's `range.start`.
    pub scanning_offset: usize,
}
1799
1800impl BufferedData {
1801 pub fn head_batch(&self) -> &BufferedBatch {
1802 self.batches.front().unwrap()
1803 }
1804
1805 pub fn tail_batch(&self) -> &BufferedBatch {
1806 self.batches.back().unwrap()
1807 }
1808
1809 pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch {
1810 self.batches.back_mut().unwrap()
1811 }
1812
1813 pub fn has_buffered_rows(&self) -> bool {
1814 self.batches.iter().any(|batch| !batch.range.is_empty())
1815 }
1816
1817 pub fn scanning_reset(&mut self) {
1818 self.scanning_batch_idx = 0;
1819 self.scanning_offset = 0;
1820 }
1821
1822 pub fn scanning_advance(&mut self) {
1823 self.scanning_offset += 1;
1824 while !self.scanning_finished() && self.scanning_batch_finished() {
1825 self.scanning_batch_idx += 1;
1826 self.scanning_offset = 0;
1827 }
1828 }
1829
1830 pub fn scanning_batch(&self) -> &BufferedBatch {
1831 &self.batches[self.scanning_batch_idx]
1832 }
1833
1834 pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
1835 &mut self.batches[self.scanning_batch_idx]
1836 }
1837
1838 pub fn scanning_idx(&self) -> usize {
1839 self.scanning_batch().range.start + self.scanning_offset
1840 }
1841
1842 pub fn scanning_batch_finished(&self) -> bool {
1843 self.scanning_offset == self.scanning_batch().range.len()
1844 }
1845
1846 pub fn scanning_finished(&self) -> bool {
1847 self.scanning_batch_idx == self.batches.len()
1848 }
1849
1850 pub fn scanning_finish(&mut self) {
1851 self.scanning_batch_idx = self.batches.len();
1852 self.scanning_offset = 0;
1853 }
1854}
1855
1856fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef> {
1858 on_column
1859 .iter()
1860 .map(|c| {
1861 let num_rows = batch.num_rows();
1862 let c = c.evaluate(batch).unwrap();
1863 c.into_array(num_rows).unwrap()
1864 })
1865 .collect()
1866}
1867
1868fn is_join_arrays_equal(
1871 left_arrays: &[ArrayRef],
1872 left: usize,
1873 right_arrays: &[ArrayRef],
1874 right: usize,
1875) -> Result<bool> {
1876 let mut is_equal = true;
1877 for (left_array, right_array) in left_arrays.iter().zip(right_arrays) {
1878 macro_rules! compare_value {
1879 ($T:ty) => {{
1880 match (left_array.is_null(left), right_array.is_null(right)) {
1881 (false, false) => {
1882 let left_array =
1883 left_array.as_any().downcast_ref::<$T>().unwrap();
1884 let right_array =
1885 right_array.as_any().downcast_ref::<$T>().unwrap();
1886 if left_array.value(left) != right_array.value(right) {
1887 is_equal = false;
1888 }
1889 }
1890 (true, false) => is_equal = false,
1891 (false, true) => is_equal = false,
1892 _ => {}
1893 }
1894 }};
1895 }
1896
1897 match left_array.data_type() {
1898 DataType::Null => {}
1899 DataType::Boolean => compare_value!(BooleanArray),
1900 DataType::Int8 => compare_value!(Int8Array),
1901 DataType::Int16 => compare_value!(Int16Array),
1902 DataType::Int32 => compare_value!(Int32Array),
1903 DataType::Int64 => compare_value!(Int64Array),
1904 DataType::UInt8 => compare_value!(UInt8Array),
1905 DataType::UInt16 => compare_value!(UInt16Array),
1906 DataType::UInt32 => compare_value!(UInt32Array),
1907 DataType::UInt64 => compare_value!(UInt64Array),
1908 DataType::Float32 => compare_value!(Float32Array),
1909 DataType::Float64 => compare_value!(Float64Array),
1910 DataType::Utf8 => compare_value!(StringArray),
1911 DataType::Utf8View => compare_value!(StringViewArray),
1912 DataType::LargeUtf8 => compare_value!(LargeStringArray),
1913 DataType::Binary => compare_value!(BinaryArray),
1914 DataType::BinaryView => compare_value!(BinaryViewArray),
1915 DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
1916 DataType::LargeBinary => compare_value!(LargeBinaryArray),
1917 DataType::Decimal32(..) => compare_value!(Decimal32Array),
1918 DataType::Decimal64(..) => compare_value!(Decimal64Array),
1919 DataType::Decimal128(..) => compare_value!(Decimal128Array),
1920 DataType::Decimal256(..) => compare_value!(Decimal256Array),
1921 DataType::Timestamp(time_unit, None) => match time_unit {
1922 TimeUnit::Second => compare_value!(TimestampSecondArray),
1923 TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
1924 TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
1925 TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
1926 },
1927 DataType::Date32 => compare_value!(Date32Array),
1928 DataType::Date64 => compare_value!(Date64Array),
1929 dt => {
1930 return not_impl_err!(
1931 "Unsupported data type in sort merge join comparator: {}",
1932 dt
1933 );
1934 }
1935 }
1936 if !is_equal {
1937 return Ok(false);
1938 }
1939 }
1940 Ok(true)
1941}