datafusion_physical_plan/joins/hash_join/
stream.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Stream implementation for Hash Join
19//!
20//! This module implements [`HashJoinStream`], the streaming engine for
21//! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details.
22
23use std::sync::Arc;
24use std::task::Poll;
25
26use crate::joins::hash_join::exec::JoinLeftData;
27use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator;
28use crate::joins::utils::{
29    equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut,
30};
31use crate::joins::PartitionMode;
32use crate::{
33    handle_state,
34    hash_utils::create_hashes,
35    joins::join_hash_map::JoinHashMapOffset,
36    joins::utils::{
37        adjust_indices_by_join_type, apply_join_filter_to_indices,
38        build_batch_empty_build_side, build_batch_from_indices,
39        need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
40        JoinHashMapType, StatefulStreamResult,
41    },
42    RecordBatchStream, SendableRecordBatchStream,
43};
44
45use arrow::array::{ArrayRef, UInt32Array, UInt64Array};
46use arrow::datatypes::{Schema, SchemaRef};
47use arrow::record_batch::RecordBatch;
48use datafusion_common::{
49    internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result,
50};
51use datafusion_physical_expr::PhysicalExprRef;
52
53use ahash::RandomState;
54use futures::{ready, Stream, StreamExt};
55
/// Represents the build side of the hash join.
///
/// Starts as [`BuildSide::Initial`] and is replaced by [`BuildSide::Ready`]
/// once the build-side future has resolved (see `collect_build_side`).
pub(super) enum BuildSide {
    /// The build side has not been collected yet.
    Initial(BuildSideInitialState),
    /// The build-side data has been collected.
    Ready(BuildSideReadyState),
}
63
/// Container for data related to [`BuildSide::Initial`].
pub(super) struct BuildSideInitialState {
    /// Future that builds the hash table from the build-side input; it is
    /// polled via `get_shared` until it yields the shared [`JoinLeftData`].
    pub(super) left_fut: OnceFut<JoinLeftData>,
}
69
/// Container for data related to [`BuildSide::Ready`].
pub(super) struct BuildSideReadyState {
    /// Collected build-side data (hash map, build batch, join-key values,
    /// visited-rows bitmap and bounds), shared via `Arc`.
    left_data: Arc<JoinLeftData>,
}
75
76impl BuildSide {
77    /// Tries to extract BuildSideInitialState from BuildSide enum.
78    /// Returns an error if state is not Initial.
79    fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
80        match self {
81            BuildSide::Initial(state) => Ok(state),
82            _ => internal_err!("Expected build side in initial state"),
83        }
84    }
85
86    /// Tries to extract BuildSideReadyState from BuildSide enum.
87    /// Returns an error if state is not Ready.
88    fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
89        match self {
90            BuildSide::Ready(state) => Ok(state),
91            _ => internal_err!("Expected build side in ready state"),
92        }
93    }
94
95    /// Tries to extract BuildSideReadyState from BuildSide enum.
96    /// Returns an error if state is not Ready.
97    fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
98        match self {
99            BuildSide::Ready(state) => Ok(state),
100            _ => internal_err!("Expected build side in ready state"),
101        }
102    }
103}
104
/// Represents the state of a [`HashJoinStream`].
///
/// Expected state transitions performed by `HashJoinStream` are:
///
/// ```text
///
///       WaitBuildSide
///             │
///             ├─────────────────────────────┐
///             ▼                             │ (no bounds accumulator)
///   WaitPartitionBoundsReport               │
///             │                             │
///             ▼                             │
///  ┌─► FetchProbeBatch ◄────────────────────┘
///  │          │    │
///  │          │    └──► ExhaustedProbeSide ───► Completed
///  │          ▼
///  └─ ProcessProbeBatch ◄─┐
///             │           │ (lookup returned a continuation offset)
///             └───────────┘
/// ```
///
/// `WaitPartitionBoundsReport` is entered only when a bounds accumulator is
/// configured. `ProcessProbeBatch` may repeat for the same probe batch until
/// the hash-map lookup no longer returns a continuation offset.
#[derive(Debug, Clone)]
pub(super) enum HashJoinStreamState {
    /// Initial state: build-side data has not been collected yet
    WaitBuildSide,
    /// Waiting for bounds to be reported by all partitions (only entered when a
    /// bounds accumulator is present)
    WaitPartitionBoundsReport,
    /// Build side has been collected, and the stream is ready to fetch the next probe-side batch
    FetchProbeBatch,
    /// A probe-side batch (which may have zero rows) has been fetched and is being processed
    ProcessProbeBatch(ProcessProbeBatchState),
    /// Probe side has been fully processed; unmatched build-side rows may still be emitted
    ExhaustedProbeSide,
    /// HashJoinStream execution is completed
    Completed,
}
134
135impl HashJoinStreamState {
136    /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum.
137    /// Returns an error if state is not ProcessProbeBatchState.
138    fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
139        match self {
140            HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
141            _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
142        }
143    }
144}
145
/// Container for data related to [`HashJoinStreamState::ProcessProbeBatch`].
#[derive(Debug, Clone)]
pub(super) struct ProcessProbeBatchState {
    /// Current probe-side batch
    batch: RecordBatch,
    /// Join-key values evaluated from the probe-side `on` expressions for `batch`
    values: Vec<ArrayRef>,
    /// Starting offset for the next JoinHashMap lookup; `(0, None)` for a freshly fetched batch
    offset: JoinHashMapOffset,
    /// Probe-side index of the last joined row seen so far in `batch` (`None` until
    /// some row has been joined); the next index-alignment range starts right after it
    joined_probe_idx: Option<usize>,
}
158
159impl ProcessProbeBatchState {
160    fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) {
161        self.offset = offset;
162        if joined_probe_idx.is_some() {
163            self.joined_probe_idx = joined_probe_idx;
164        }
165    }
166}
167
/// [`Stream`] for [`super::HashJoinExec`] that does the actual join.
///
/// This stream:
///
/// - Awaits collection of the build side (left input) into a hash map
/// - Iterates over the probe side (right input) in streaming fashion
/// - Looks up matches against the hash table and applies join filters
/// - Produces joined [`RecordBatch`]es incrementally
/// - Emits unmatched build-side rows in a final stage for join types that
///   require it (see `need_produce_result_in_final`)
pub(super) struct HashJoinStream {
    /// Partition index of this stream; in `PartitionMode::Partitioned` it is also
    /// the build-side partition id reported to the bounds accumulator
    partition: usize,
    /// Output schema of the join (returned by `RecordBatchStream::schema` and used
    /// to build every output batch)
    schema: Arc<Schema>,
    /// equijoin columns from the right (probe side)
    on_right: Vec<PhysicalExprRef>,
    /// optional join filter
    filter: Option<JoinFilter>,
    /// type of the join (left, right, semi, etc)
    join_type: JoinType,
    /// right (probe) input
    right: SendableRecordBatchStream,
    /// Random state used to hash probe-side join keys
    /// (presumably must match the state used to build the hash map — confirm in caller)
    random_state: RandomState,
    /// Metrics
    join_metrics: BuildProbeJoinMetrics,
    /// Information of index and left / right placement of columns
    column_indices: Vec<ColumnIndex>,
    /// Defines the null equality for the join.
    null_equality: NullEquality,
    /// State of the stream
    state: HashJoinStreamState,
    /// Build side
    build_side: BuildSide,
    /// Maximum number of hash-map matches looked up per probe step (the `limit`
    /// passed to `lookup_join_hashmap`)
    batch_size: usize,
    /// Scratch space holding the hashes of the current probe batch's join keys
    hashes_buffer: Vec<u64>,
    /// Specifies whether the right side has an ordering to potentially preserve
    right_side_ordered: bool,
    /// Shared bounds accumulator for coordinating dynamic filter updates (optional)
    bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
    /// Optional future to signal when bounds have been reported by all partitions
    /// and the dynamic filter has been updated
    bounds_waiter: Option<OnceFut<()>>,

    /// Partitioning mode to use
    mode: PartitionMode,
}
217
218impl RecordBatchStream for HashJoinStream {
219    fn schema(&self) -> SchemaRef {
220        Arc::clone(&self.schema)
221    }
222}
223
224/// Executes lookups by hash against JoinHashMap and resolves potential
225/// hash collisions.
226/// Returns build/probe indices satisfying the equality condition, along with
227/// (optional) starting point for next iteration.
228///
229/// # Example
230///
231/// For `LEFT.b1 = RIGHT.b2`:
232/// LEFT (build) Table:
233/// ```text
234///  a1  b1  c1
235///  1   1   10
236///  3   3   30
237///  5   5   50
238///  7   7   70
239///  9   8   90
240///  11  8   110
241///  13   10  130
242/// ```
243///
244/// RIGHT (probe) Table:
245/// ```text
246///  a2   b2  c2
247///  2    2   20
248///  4    4   40
249///  6    6   60
250///  8    8   80
251/// 10   10  100
252/// 12   10  120
253/// ```
254///
255/// The result is
256/// ```text
257/// "+----+----+-----+----+----+-----+",
258/// "| a1 | b1 | c1  | a2 | b2 | c2  |",
259/// "+----+----+-----+----+----+-----+",
260/// "| 9  | 8  | 90  | 8  | 8  | 80  |",
261/// "| 11 | 8  | 110 | 8  | 8  | 80  |",
262/// "| 13 | 10 | 130 | 10 | 10 | 100 |",
263/// "| 13 | 10 | 130 | 12 | 10 | 120 |",
264/// "+----+----+-----+----+----+-----+"
265/// ```
266///
267/// And the result of build and probe indices are:
268/// ```text
269/// Build indices: 4, 5, 6, 6
270/// Probe indices: 3, 3, 4, 5
271/// ```
272#[allow(clippy::too_many_arguments)]
273pub(super) fn lookup_join_hashmap(
274    build_hashmap: &dyn JoinHashMapType,
275    build_side_values: &[ArrayRef],
276    probe_side_values: &[ArrayRef],
277    null_equality: NullEquality,
278    hashes_buffer: &[u64],
279    limit: usize,
280    offset: JoinHashMapOffset,
281) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> {
282    let (probe_indices, build_indices, next_offset) =
283        build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset);
284
285    let build_indices: UInt64Array = build_indices.into();
286    let probe_indices: UInt32Array = probe_indices.into();
287
288    let (build_indices, probe_indices) = equal_rows_arr(
289        &build_indices,
290        &probe_indices,
291        build_side_values,
292        probe_side_values,
293        null_equality,
294    )?;
295
296    Ok((build_indices, probe_indices, next_offset))
297}
298
299impl HashJoinStream {
300    #[allow(clippy::too_many_arguments)]
301    pub(super) fn new(
302        partition: usize,
303        schema: Arc<Schema>,
304        on_right: Vec<PhysicalExprRef>,
305        filter: Option<JoinFilter>,
306        join_type: JoinType,
307        right: SendableRecordBatchStream,
308        random_state: RandomState,
309        join_metrics: BuildProbeJoinMetrics,
310        column_indices: Vec<ColumnIndex>,
311        null_equality: NullEquality,
312        state: HashJoinStreamState,
313        build_side: BuildSide,
314        batch_size: usize,
315        hashes_buffer: Vec<u64>,
316        right_side_ordered: bool,
317        bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
318        mode: PartitionMode,
319    ) -> Self {
320        Self {
321            partition,
322            schema,
323            on_right,
324            filter,
325            join_type,
326            right,
327            random_state,
328            join_metrics,
329            column_indices,
330            null_equality,
331            state,
332            build_side,
333            batch_size,
334            hashes_buffer,
335            right_side_ordered,
336            bounds_accumulator,
337            bounds_waiter: None,
338            mode,
339        }
340    }
341
342    /// Separate implementation function that unpins the [`HashJoinStream`] so
343    /// that partial borrows work correctly
344    fn poll_next_impl(
345        &mut self,
346        cx: &mut std::task::Context<'_>,
347    ) -> Poll<Option<Result<RecordBatch>>> {
348        loop {
349            return match self.state {
350                HashJoinStreamState::WaitBuildSide => {
351                    handle_state!(ready!(self.collect_build_side(cx)))
352                }
353                HashJoinStreamState::WaitPartitionBoundsReport => {
354                    handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
355                }
356                HashJoinStreamState::FetchProbeBatch => {
357                    handle_state!(ready!(self.fetch_probe_batch(cx)))
358                }
359                HashJoinStreamState::ProcessProbeBatch(_) => {
360                    let poll = handle_state!(self.process_probe_batch());
361                    self.join_metrics.baseline.record_poll(poll)
362                }
363                HashJoinStreamState::ExhaustedProbeSide => {
364                    let poll = handle_state!(self.process_unmatched_build_batch());
365                    self.join_metrics.baseline.record_poll(poll)
366                }
367                HashJoinStreamState::Completed => Poll::Ready(None),
368            };
369        }
370    }
371
372    /// Optional step to wait until bounds have been reported by all partitions.
373    /// This state is only entered if a bounds accumulator is present.
374    ///
375    /// ## Why wait?
376    ///
377    /// The dynamic filter is only built once all partitions have reported their bounds.
378    /// If we do not wait here, the probe-side scan may start before the filter is ready.
379    /// This can lead to the probe-side scan missing the opportunity to apply the filter
380    /// and skip reading unnecessary data.
381    fn wait_for_partition_bounds_report(
382        &mut self,
383        cx: &mut std::task::Context<'_>,
384    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
385        if let Some(ref mut fut) = self.bounds_waiter {
386            ready!(fut.get_shared(cx))?;
387        }
388        self.state = HashJoinStreamState::FetchProbeBatch;
389        Poll::Ready(Ok(StatefulStreamResult::Continue))
390    }
391
392    /// Collects build-side data by polling `OnceFut` future from initialized build-side
393    ///
394    /// Updates build-side to `Ready`, and state to `FetchProbeSide`
395    fn collect_build_side(
396        &mut self,
397        cx: &mut std::task::Context<'_>,
398    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
399        let build_timer = self.join_metrics.build_time.timer();
400        // build hash table from left (build) side, if not yet done
401        let left_data = ready!(self
402            .build_side
403            .try_as_initial_mut()?
404            .left_fut
405            .get_shared(cx))?;
406        build_timer.done();
407
408        // Handle dynamic filter bounds accumulation
409        //
410        // Dynamic filter coordination between partitions:
411        // Report bounds to the accumulator which will handle synchronization and filter updates
412        if let Some(ref bounds_accumulator) = self.bounds_accumulator {
413            let bounds_accumulator = Arc::clone(bounds_accumulator);
414
415            let left_side_partition_id = match self.mode {
416                PartitionMode::Partitioned => self.partition,
417                PartitionMode::CollectLeft => 0,
418                PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"),
419            };
420
421            let left_data_bounds = left_data.bounds.clone();
422            self.bounds_waiter = Some(OnceFut::new(async move {
423                bounds_accumulator
424                    .report_partition_bounds(left_side_partition_id, left_data_bounds)
425                    .await
426            }));
427            self.state = HashJoinStreamState::WaitPartitionBoundsReport;
428        } else {
429            self.state = HashJoinStreamState::FetchProbeBatch;
430        }
431
432        self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
433        Poll::Ready(Ok(StatefulStreamResult::Continue))
434    }
435
436    /// Fetches next batch from probe-side
437    ///
438    /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`,
439    /// otherwise updates state to `ExhaustedProbeSide`
440    fn fetch_probe_batch(
441        &mut self,
442        cx: &mut std::task::Context<'_>,
443    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
444        match ready!(self.right.poll_next_unpin(cx)) {
445            None => {
446                self.state = HashJoinStreamState::ExhaustedProbeSide;
447            }
448            Some(Ok(batch)) => {
449                // Precalculate hash values for fetched batch
450                let keys_values = self
451                    .on_right
452                    .iter()
453                    .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
454                    .collect::<Result<Vec<_>>>()?;
455
456                self.hashes_buffer.clear();
457                self.hashes_buffer.resize(batch.num_rows(), 0);
458                create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
459
460                self.join_metrics.input_batches.add(1);
461                self.join_metrics.input_rows.add(batch.num_rows());
462
463                self.state =
464                    HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
465                        batch,
466                        values: keys_values,
467                        offset: (0, None),
468                        joined_probe_idx: None,
469                    });
470            }
471            Some(Err(err)) => return Poll::Ready(Err(err)),
472        };
473
474        Poll::Ready(Ok(StatefulStreamResult::Continue))
475    }
476
477    /// Joins current probe batch with build-side data and produces batch with matched output
478    ///
479    /// Updates state to `FetchProbeBatch`
480    fn process_probe_batch(
481        &mut self,
482    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
483        let state = self.state.try_as_process_probe_batch_mut()?;
484        let build_side = self.build_side.try_as_ready_mut()?;
485
486        let timer = self.join_metrics.join_time.timer();
487
488        // if the left side is empty, we can skip the (potentially expensive) join operation
489        if build_side.left_data.hash_map.is_empty() && self.filter.is_none() {
490            let result = build_batch_empty_build_side(
491                &self.schema,
492                build_side.left_data.batch(),
493                &state.batch,
494                &self.column_indices,
495                self.join_type,
496            )?;
497            self.join_metrics.output_batches.add(1);
498            timer.done();
499
500            self.state = HashJoinStreamState::FetchProbeBatch;
501
502            return Ok(StatefulStreamResult::Ready(Some(result)));
503        }
504
505        // get the matched by join keys indices
506        let (left_indices, right_indices, next_offset) = lookup_join_hashmap(
507            build_side.left_data.hash_map(),
508            build_side.left_data.values(),
509            &state.values,
510            self.null_equality,
511            &self.hashes_buffer,
512            self.batch_size,
513            state.offset,
514        )?;
515
516        // apply join filter if exists
517        let (left_indices, right_indices) = if let Some(filter) = &self.filter {
518            apply_join_filter_to_indices(
519                build_side.left_data.batch(),
520                &state.batch,
521                left_indices,
522                right_indices,
523                filter,
524                JoinSide::Left,
525                None,
526            )?
527        } else {
528            (left_indices, right_indices)
529        };
530
531        // mark joined left-side indices as visited, if required by join type
532        if need_produce_result_in_final(self.join_type) {
533            let mut bitmap = build_side.left_data.visited_indices_bitmap().lock();
534            left_indices.iter().flatten().for_each(|x| {
535                bitmap.set_bit(x as usize, true);
536            });
537        }
538
539        // The goals of index alignment for different join types are:
540        //
541        // 1) Right & FullJoin -- to append all missing probe-side indices between
542        //    previous (excluding) and current joined indices.
543        // 2) SemiJoin -- deduplicate probe indices in range between previous
544        //    (excluding) and current joined indices.
545        // 3) AntiJoin -- return only missing indices in range between
546        //    previous and current joined indices.
547        //    Inclusion/exclusion of the indices themselves don't matter
548        //
549        // As a summary -- alignment range can be produced based only on
550        // joined (matched with filters applied) probe side indices, excluding starting one
551        // (left from previous iteration).
552
553        // if any rows have been joined -- get last joined probe-side (right) row
554        // it's important that index counts as "joined" after hash collisions checks
555        // and join filters applied.
556        let last_joined_right_idx = match right_indices.len() {
557            0 => None,
558            n => Some(right_indices.value(n - 1) as usize),
559        };
560
561        // Calculate range and perform alignment.
562        // In case probe batch has been processed -- align all remaining rows.
563        let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
564        let index_alignment_range_end = if next_offset.is_none() {
565            state.batch.num_rows()
566        } else {
567            last_joined_right_idx.map_or(0, |v| v + 1)
568        };
569
570        let (left_indices, right_indices) = adjust_indices_by_join_type(
571            left_indices,
572            right_indices,
573            index_alignment_range_start..index_alignment_range_end,
574            self.join_type,
575            self.right_side_ordered,
576        )?;
577
578        let result = if self.join_type == JoinType::RightMark {
579            build_batch_from_indices(
580                &self.schema,
581                &state.batch,
582                build_side.left_data.batch(),
583                &left_indices,
584                &right_indices,
585                &self.column_indices,
586                JoinSide::Right,
587            )?
588        } else {
589            build_batch_from_indices(
590                &self.schema,
591                build_side.left_data.batch(),
592                &state.batch,
593                &left_indices,
594                &right_indices,
595                &self.column_indices,
596                JoinSide::Left,
597            )?
598        };
599
600        self.join_metrics.output_batches.add(1);
601        timer.done();
602
603        if next_offset.is_none() {
604            self.state = HashJoinStreamState::FetchProbeBatch;
605        } else {
606            state.advance(
607                next_offset
608                    .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
609                last_joined_right_idx,
610            )
611        };
612
613        Ok(StatefulStreamResult::Ready(Some(result)))
614    }
615
616    /// Processes unmatched build-side rows for certain join types and produces output batch
617    ///
618    /// Updates state to `Completed`
619    fn process_unmatched_build_batch(
620        &mut self,
621    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
622        let timer = self.join_metrics.join_time.timer();
623
624        if !need_produce_result_in_final(self.join_type) {
625            self.state = HashJoinStreamState::Completed;
626            return Ok(StatefulStreamResult::Continue);
627        }
628
629        let build_side = self.build_side.try_as_ready()?;
630        if !build_side.left_data.report_probe_completed() {
631            self.state = HashJoinStreamState::Completed;
632            return Ok(StatefulStreamResult::Continue);
633        }
634
635        // use the global left bitmap to produce the left indices and right indices
636        let (left_side, right_side) = get_final_indices_from_shared_bitmap(
637            build_side.left_data.visited_indices_bitmap(),
638            self.join_type,
639            true,
640        );
641        let empty_right_batch = RecordBatch::new_empty(self.right.schema());
642        // use the left and right indices to produce the batch result
643        let result = build_batch_from_indices(
644            &self.schema,
645            build_side.left_data.batch(),
646            &empty_right_batch,
647            &left_side,
648            &right_side,
649            &self.column_indices,
650            JoinSide::Left,
651        );
652
653        if let Ok(ref batch) = result {
654            self.join_metrics.input_batches.add(1);
655            self.join_metrics.input_rows.add(batch.num_rows());
656
657            self.join_metrics.output_batches.add(1);
658        }
659        timer.done();
660
661        self.state = HashJoinStreamState::Completed;
662
663        Ok(StatefulStreamResult::Ready(Some(result?)))
664    }
665}
666
667impl Stream for HashJoinStream {
668    type Item = Result<RecordBatch>;
669
670    fn poll_next(
671        mut self: std::pin::Pin<&mut Self>,
672        cx: &mut std::task::Context<'_>,
673    ) -> Poll<Option<Self::Item>> {
674        self.poll_next_impl(cx)
675    }
676}