1use std::sync::Arc;
24use std::task::Poll;
25
26use crate::joins::hash_join::exec::JoinLeftData;
27use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator;
28use crate::joins::utils::{
29 equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut,
30};
31use crate::joins::PartitionMode;
32use crate::{
33 handle_state,
34 hash_utils::create_hashes,
35 joins::join_hash_map::JoinHashMapOffset,
36 joins::utils::{
37 adjust_indices_by_join_type, apply_join_filter_to_indices,
38 build_batch_empty_build_side, build_batch_from_indices,
39 need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
40 JoinHashMapType, StatefulStreamResult,
41 },
42 RecordBatchStream, SendableRecordBatchStream,
43};
44
45use arrow::array::{ArrayRef, UInt32Array, UInt64Array};
46use arrow::datatypes::{Schema, SchemaRef};
47use arrow::record_batch::RecordBatch;
48use datafusion_common::{
49 internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result,
50};
51use datafusion_physical_expr::PhysicalExprRef;
52
53use ahash::RandomState;
54use futures::{ready, Stream, StreamExt};
55
/// Build-side progress tracked by a [`HashJoinStream`].
///
/// `HashJoinStream::collect_build_side` expects the `Initial` variant and
/// replaces it with `Ready` once the build-side future has resolved.
pub(super) enum BuildSide {
    /// The build-side data is still behind a future.
    Initial(BuildSideInitialState),
    /// The build-side data has been collected and is available for probing.
    Ready(BuildSideReadyState),
}
63
/// Payload of [`BuildSide::Initial`]: the build side has not been collected yet.
pub(super) struct BuildSideInitialState {
    /// Future that resolves to the collected build-side (left) data.
    pub(super) left_fut: OnceFut<JoinLeftData>,
}
69
/// Payload of [`BuildSide::Ready`]: the build side has been collected.
pub(super) struct BuildSideReadyState {
    /// Shared handle to the collected build-side (left) data.
    left_data: Arc<JoinLeftData>,
}
75
76impl BuildSide {
77 fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
80 match self {
81 BuildSide::Initial(state) => Ok(state),
82 _ => internal_err!("Expected build side in initial state"),
83 }
84 }
85
86 fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
89 match self {
90 BuildSide::Ready(state) => Ok(state),
91 _ => internal_err!("Expected build side in ready state"),
92 }
93 }
94
95 fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
98 match self {
99 BuildSide::Ready(state) => Ok(state),
100 _ => internal_err!("Expected build side in ready state"),
101 }
102 }
103}
104
/// States of the [`HashJoinStream`] state machine.
///
/// `HashJoinStream::poll_next_impl` dispatches on the current variant and each
/// handler stores the next variant in `HashJoinStream::state`.
#[derive(Debug, Clone)]
pub(super) enum HashJoinStreamState {
    /// Waiting for the build-side future; handled by `collect_build_side`.
    WaitBuildSide,
    /// Waiting for the build-side bounds report to finish; handled by
    /// `wait_for_partition_bounds_report`.
    WaitPartitionBoundsReport,
    /// The next batch has to be pulled from the probe-side stream.
    FetchProbeBatch,
    /// A probe-side batch is being joined; the payload carries the lookup progress.
    ProcessProbeBatch(ProcessProbeBatchState),
    /// The probe-side stream returned `None`; handled by
    /// `process_unmatched_build_batch`.
    ExhaustedProbeSide,
    /// Terminal state: polling yields `Poll::Ready(None)`.
    Completed,
}
134
135impl HashJoinStreamState {
136 fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
139 match self {
140 HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
141 _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
142 }
143 }
144}
145
/// Progress of joining a single probe-side batch.
///
/// Created by `HashJoinStream::fetch_probe_batch` and updated through
/// [`ProcessProbeBatchState::advance`] while the batch is processed in chunks.
#[derive(Debug, Clone)]
pub(super) struct ProcessProbeBatchState {
    /// Probe-side batch currently being joined.
    batch: RecordBatch,
    /// Join key arrays evaluated from `batch`.
    values: Vec<ArrayRef>,
    /// Offset passed to the next hash map lookup; starts at `(0, None)`.
    offset: JoinHashMapOffset,
    /// Last probe-side row index that appeared in a previous lookup's result, if any.
    joined_probe_idx: Option<usize>,
}
158
159impl ProcessProbeBatchState {
160 fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) {
161 self.offset = offset;
162 if joined_probe_idx.is_some() {
163 self.joined_probe_idx = joined_probe_idx;
164 }
165 }
166}
167
/// Stream that joins probe-side (right) batches against collected build-side
/// (left) data, driven by the [`HashJoinStreamState`] state machine.
pub(super) struct HashJoinStream {
    /// Partition number; reported as the build-side partition id when `mode`
    /// is `PartitionMode::Partitioned`.
    partition: usize,
    /// Schema of the batches produced by this stream.
    schema: Arc<Schema>,
    /// Expressions evaluated against each probe-side batch to obtain join keys.
    on_right: Vec<PhysicalExprRef>,
    /// Optional join filter applied to the matched index pairs.
    filter: Option<JoinFilter>,
    /// Type of join being performed.
    join_type: JoinType,
    /// Probe-side input stream.
    right: SendableRecordBatchStream,
    /// Hasher state used to hash the probe-side join keys.
    random_state: RandomState,
    /// Metrics updated while building and probing.
    join_metrics: BuildProbeJoinMetrics,
    /// Column descriptors forwarded to the output batch builders.
    column_indices: Vec<ColumnIndex>,
    /// Null-comparison behavior forwarded to the join key equality check.
    null_equality: NullEquality,
    /// Current state of the stream.
    state: HashJoinStreamState,
    /// Build-side data, or the future that will produce it.
    build_side: BuildSide,
    /// Passed as the `limit` of every hash map lookup.
    batch_size: usize,
    /// Scratch buffer holding the hashes of the current probe-side batch.
    hashes_buffer: Vec<u64>,
    /// Forwarded to `adjust_indices_by_join_type`.
    right_side_ordered: bool,
    /// Optional shared accumulator that receives the build-side bounds.
    bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
    /// Future for the bounds report; set by `collect_build_side` when a
    /// `bounds_accumulator` is present.
    bounds_waiter: Option<OnceFut<()>>,

    /// Partition mode; selects the partition id used for the bounds report.
    mode: PartitionMode,
}
217
impl RecordBatchStream for HashJoinStream {
    /// Returns a new shared handle to the stream's output schema.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
223
224#[allow(clippy::too_many_arguments)]
273pub(super) fn lookup_join_hashmap(
274 build_hashmap: &dyn JoinHashMapType,
275 build_side_values: &[ArrayRef],
276 probe_side_values: &[ArrayRef],
277 null_equality: NullEquality,
278 hashes_buffer: &[u64],
279 limit: usize,
280 offset: JoinHashMapOffset,
281) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> {
282 let (probe_indices, build_indices, next_offset) =
283 build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset);
284
285 let build_indices: UInt64Array = build_indices.into();
286 let probe_indices: UInt32Array = probe_indices.into();
287
288 let (build_indices, probe_indices) = equal_rows_arr(
289 &build_indices,
290 &probe_indices,
291 build_side_values,
292 probe_side_values,
293 null_equality,
294 )?;
295
296 Ok((build_indices, probe_indices, next_offset))
297}
298
299impl HashJoinStream {
/// Creates a stream from its parts.
///
/// Every argument is stored in the field of the same name; `bounds_waiter`
/// starts as `None` and is only populated by `collect_build_side`.
#[allow(clippy::too_many_arguments)]
pub(super) fn new(
    partition: usize,
    schema: Arc<Schema>,
    on_right: Vec<PhysicalExprRef>,
    filter: Option<JoinFilter>,
    join_type: JoinType,
    right: SendableRecordBatchStream,
    random_state: RandomState,
    join_metrics: BuildProbeJoinMetrics,
    column_indices: Vec<ColumnIndex>,
    null_equality: NullEquality,
    state: HashJoinStreamState,
    build_side: BuildSide,
    batch_size: usize,
    hashes_buffer: Vec<u64>,
    right_side_ordered: bool,
    bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
    mode: PartitionMode,
) -> Self {
    Self {
        partition,
        schema,
        on_right,
        filter,
        join_type,
        right,
        random_state,
        join_metrics,
        column_indices,
        null_equality,
        state,
        build_side,
        batch_size,
        hashes_buffer,
        right_side_ordered,
        bounds_accumulator,
        // No bounds report has been started yet.
        bounds_waiter: None,
        mode,
    }
}
341
/// Drives the state machine by dispatching on `self.state`.
///
/// The three waiting/fetching states poll asynchronous work, so `ready!`
/// returns `Poll::Pending` from this function while that work is not done.
/// The two processing states are synchronous and additionally pass their
/// poll result through `join_metrics.baseline.record_poll`. `Completed`
/// yields `Poll::Ready(None)`.
///
/// NOTE(review): `handle_state!` is defined outside this file. The enclosing
/// `loop` suggests the macro `continue`s when a handler only changed the
/// state and otherwise evaluates to the poll result — confirm against the
/// macro definition.
fn poll_next_impl(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
    loop {
        return match self.state {
            HashJoinStreamState::WaitBuildSide => {
                handle_state!(ready!(self.collect_build_side(cx)))
            }
            HashJoinStreamState::WaitPartitionBoundsReport => {
                handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
            }
            HashJoinStreamState::FetchProbeBatch => {
                handle_state!(ready!(self.fetch_probe_batch(cx)))
            }
            HashJoinStreamState::ProcessProbeBatch(_) => {
                let poll = handle_state!(self.process_probe_batch());
                self.join_metrics.baseline.record_poll(poll)
            }
            HashJoinStreamState::ExhaustedProbeSide => {
                let poll = handle_state!(self.process_unmatched_build_batch());
                self.join_metrics.baseline.record_poll(poll)
            }
            // Terminal state: the stream is finished.
            HashJoinStreamState::Completed => Poll::Ready(None),
        };
    }
}
371
372 fn wait_for_partition_bounds_report(
382 &mut self,
383 cx: &mut std::task::Context<'_>,
384 ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
385 if let Some(ref mut fut) = self.bounds_waiter {
386 ready!(fut.get_shared(cx))?;
387 }
388 self.state = HashJoinStreamState::FetchProbeBatch;
389 Poll::Ready(Ok(StatefulStreamResult::Continue))
390 }
391
392 fn collect_build_side(
396 &mut self,
397 cx: &mut std::task::Context<'_>,
398 ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
399 let build_timer = self.join_metrics.build_time.timer();
400 let left_data = ready!(self
402 .build_side
403 .try_as_initial_mut()?
404 .left_fut
405 .get_shared(cx))?;
406 build_timer.done();
407
408 if let Some(ref bounds_accumulator) = self.bounds_accumulator {
413 let bounds_accumulator = Arc::clone(bounds_accumulator);
414
415 let left_side_partition_id = match self.mode {
416 PartitionMode::Partitioned => self.partition,
417 PartitionMode::CollectLeft => 0,
418 PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"),
419 };
420
421 let left_data_bounds = left_data.bounds.clone();
422 self.bounds_waiter = Some(OnceFut::new(async move {
423 bounds_accumulator
424 .report_partition_bounds(left_side_partition_id, left_data_bounds)
425 .await
426 }));
427 self.state = HashJoinStreamState::WaitPartitionBoundsReport;
428 } else {
429 self.state = HashJoinStreamState::FetchProbeBatch;
430 }
431
432 self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
433 Poll::Ready(Ok(StatefulStreamResult::Continue))
434 }
435
436 fn fetch_probe_batch(
441 &mut self,
442 cx: &mut std::task::Context<'_>,
443 ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
444 match ready!(self.right.poll_next_unpin(cx)) {
445 None => {
446 self.state = HashJoinStreamState::ExhaustedProbeSide;
447 }
448 Some(Ok(batch)) => {
449 let keys_values = self
451 .on_right
452 .iter()
453 .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
454 .collect::<Result<Vec<_>>>()?;
455
456 self.hashes_buffer.clear();
457 self.hashes_buffer.resize(batch.num_rows(), 0);
458 create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
459
460 self.join_metrics.input_batches.add(1);
461 self.join_metrics.input_rows.add(batch.num_rows());
462
463 self.state =
464 HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
465 batch,
466 values: keys_values,
467 offset: (0, None),
468 joined_probe_idx: None,
469 });
470 }
471 Some(Err(err)) => return Poll::Ready(Err(err)),
472 };
473
474 Poll::Ready(Ok(StatefulStreamResult::Continue))
475 }
476
477 fn process_probe_batch(
481 &mut self,
482 ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
483 let state = self.state.try_as_process_probe_batch_mut()?;
484 let build_side = self.build_side.try_as_ready_mut()?;
485
486 let timer = self.join_metrics.join_time.timer();
487
488 if build_side.left_data.hash_map.is_empty() && self.filter.is_none() {
490 let result = build_batch_empty_build_side(
491 &self.schema,
492 build_side.left_data.batch(),
493 &state.batch,
494 &self.column_indices,
495 self.join_type,
496 )?;
497 self.join_metrics.output_batches.add(1);
498 timer.done();
499
500 self.state = HashJoinStreamState::FetchProbeBatch;
501
502 return Ok(StatefulStreamResult::Ready(Some(result)));
503 }
504
505 let (left_indices, right_indices, next_offset) = lookup_join_hashmap(
507 build_side.left_data.hash_map(),
508 build_side.left_data.values(),
509 &state.values,
510 self.null_equality,
511 &self.hashes_buffer,
512 self.batch_size,
513 state.offset,
514 )?;
515
516 let (left_indices, right_indices) = if let Some(filter) = &self.filter {
518 apply_join_filter_to_indices(
519 build_side.left_data.batch(),
520 &state.batch,
521 left_indices,
522 right_indices,
523 filter,
524 JoinSide::Left,
525 None,
526 )?
527 } else {
528 (left_indices, right_indices)
529 };
530
531 if need_produce_result_in_final(self.join_type) {
533 let mut bitmap = build_side.left_data.visited_indices_bitmap().lock();
534 left_indices.iter().flatten().for_each(|x| {
535 bitmap.set_bit(x as usize, true);
536 });
537 }
538
539 let last_joined_right_idx = match right_indices.len() {
557 0 => None,
558 n => Some(right_indices.value(n - 1) as usize),
559 };
560
561 let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
564 let index_alignment_range_end = if next_offset.is_none() {
565 state.batch.num_rows()
566 } else {
567 last_joined_right_idx.map_or(0, |v| v + 1)
568 };
569
570 let (left_indices, right_indices) = adjust_indices_by_join_type(
571 left_indices,
572 right_indices,
573 index_alignment_range_start..index_alignment_range_end,
574 self.join_type,
575 self.right_side_ordered,
576 )?;
577
578 let result = if self.join_type == JoinType::RightMark {
579 build_batch_from_indices(
580 &self.schema,
581 &state.batch,
582 build_side.left_data.batch(),
583 &left_indices,
584 &right_indices,
585 &self.column_indices,
586 JoinSide::Right,
587 )?
588 } else {
589 build_batch_from_indices(
590 &self.schema,
591 build_side.left_data.batch(),
592 &state.batch,
593 &left_indices,
594 &right_indices,
595 &self.column_indices,
596 JoinSide::Left,
597 )?
598 };
599
600 self.join_metrics.output_batches.add(1);
601 timer.done();
602
603 if next_offset.is_none() {
604 self.state = HashJoinStreamState::FetchProbeBatch;
605 } else {
606 state.advance(
607 next_offset
608 .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
609 last_joined_right_idx,
610 )
611 };
612
613 Ok(StatefulStreamResult::Ready(Some(result)))
614 }
615
616 fn process_unmatched_build_batch(
620 &mut self,
621 ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
622 let timer = self.join_metrics.join_time.timer();
623
624 if !need_produce_result_in_final(self.join_type) {
625 self.state = HashJoinStreamState::Completed;
626 return Ok(StatefulStreamResult::Continue);
627 }
628
629 let build_side = self.build_side.try_as_ready()?;
630 if !build_side.left_data.report_probe_completed() {
631 self.state = HashJoinStreamState::Completed;
632 return Ok(StatefulStreamResult::Continue);
633 }
634
635 let (left_side, right_side) = get_final_indices_from_shared_bitmap(
637 build_side.left_data.visited_indices_bitmap(),
638 self.join_type,
639 true,
640 );
641 let empty_right_batch = RecordBatch::new_empty(self.right.schema());
642 let result = build_batch_from_indices(
644 &self.schema,
645 build_side.left_data.batch(),
646 &empty_right_batch,
647 &left_side,
648 &right_side,
649 &self.column_indices,
650 JoinSide::Left,
651 );
652
653 if let Ok(ref batch) = result {
654 self.join_metrics.input_batches.add(1);
655 self.join_metrics.input_rows.add(batch.num_rows());
656
657 self.join_metrics.output_batches.add(1);
658 }
659 timer.done();
660
661 self.state = HashJoinStreamState::Completed;
662
663 Ok(StatefulStreamResult::Ready(Some(result?)))
664 }
665}
666
667impl Stream for HashJoinStream {
668 type Item = Result<RecordBatch>;
669
670 fn poll_next(
671 mut self: std::pin::Pin<&mut Self>,
672 cx: &mut std::task::Context<'_>,
673 ) -> Poll<Option<Self::Item>> {
674 self.poll_next_impl(cx)
675 }
676}