datafusion_physical_plan/joins/
stream_join_utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file contains common subroutines for symmetric hash join
19//! related functionality, used both in join calculations and optimization rules.
20
21use std::collections::{HashMap, VecDeque};
22use std::mem::size_of;
23use std::sync::Arc;
24
25use crate::joins::join_hash_map::{
26    get_matched_indices, get_matched_indices_with_limit_offset, update_from_iter,
27    JoinHashMapOffset,
28};
29use crate::joins::utils::{JoinFilter, JoinHashMapType};
30use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
31use crate::{metrics, ExecutionPlan};
32
33use arrow::array::{
34    ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch,
35};
36use arrow::compute::concat_batches;
37use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
38use datafusion_common::tree_node::{Transformed, TransformedResult};
39use datafusion_common::utils::memory::estimate_memory_size;
40use datafusion_common::{
41    arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue,
42};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_physical_expr::expressions::Column;
45use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
46use datafusion_physical_expr::utils::collect_columns;
47use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr};
48
49use datafusion_physical_expr_common::sort_expr::LexOrdering;
50use hashbrown::HashTable;
51
52/// Implementation of `JoinHashMapType` for `PruningJoinHashMap`.
53impl JoinHashMapType for PruningJoinHashMap {
54    // Extend with zero
55    fn extend_zero(&mut self, len: usize) {
56        self.next.resize(self.next.len() + len, 0)
57    }
58
59    fn update_from_iter<'a>(
60        &mut self,
61        iter: Box<dyn Iterator<Item = (usize, &'a u64)> + Send + 'a>,
62        deleted_offset: usize,
63    ) {
64        let slice: &mut [u64] = self.next.make_contiguous();
65        update_from_iter::<u64>(&mut self.map, slice, iter, deleted_offset);
66    }
67
68    fn get_matched_indices<'a>(
69        &self,
70        iter: Box<dyn Iterator<Item = (usize, &'a u64)> + 'a>,
71        deleted_offset: Option<usize>,
72    ) -> (Vec<u32>, Vec<u64>) {
73        // Flatten the deque
74        let next: Vec<u64> = self.next.iter().copied().collect();
75        get_matched_indices::<u64>(&self.map, &next, iter, deleted_offset)
76    }
77
78    fn get_matched_indices_with_limit_offset(
79        &self,
80        hash_values: &[u64],
81        limit: usize,
82        offset: JoinHashMapOffset,
83    ) -> (Vec<u32>, Vec<u64>, Option<JoinHashMapOffset>) {
84        // Flatten the deque
85        let next: Vec<u64> = self.next.iter().copied().collect();
86        get_matched_indices_with_limit_offset::<u64>(
87            &self.map,
88            &next,
89            hash_values,
90            limit,
91            offset,
92        )
93    }
94
95    fn is_empty(&self) -> bool {
96        self.map.is_empty()
97    }
98}
99
100/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
101/// the capability of pruning elements in an efficient manner. This structure
102/// is particularly useful for cases where it's necessary to remove elements
103/// from the map based on their buffer order.
104///
105/// # Example
106///
107/// ``` text
108/// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would
109/// handle the pruning scenario.
110///
111/// Insert the pair (10,4) into the `PruningJoinHashMap`:
112/// map:
113/// ----------
114/// | 10 | 5 |
115/// | 20 | 3 |
116/// ----------
117/// list:
118/// ---------------------
119/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
120/// ---------------------
121///
122/// Now, let's prune 3 rows from `PruningJoinHashMap`:
123/// map:
124/// ---------
125/// | 1 | 5 |
126/// ---------
127/// list:
128/// ---------
129/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0)
130/// ---------
131///
132/// After pruning, the | 2 | 3 | entry is deleted from `PruningJoinHashMap` since
133/// there are no values left for this key.
134/// ```
135pub struct PruningJoinHashMap {
136    /// Stores hash value to last row index
137    pub map: HashTable<(u64, u64)>,
138    /// Stores indices in chained list data structure
139    pub next: VecDeque<u64>,
140}
141
142impl PruningJoinHashMap {
143    /// Constructs a new `PruningJoinHashMap` with the given capacity.
144    /// Both the map and the list are pre-allocated with the provided capacity.
145    ///
146    /// # Arguments
147    /// * `capacity`: The initial capacity of the hash map.
148    ///
149    /// # Returns
150    /// A new instance of `PruningJoinHashMap`.
151    pub(crate) fn with_capacity(capacity: usize) -> Self {
152        PruningJoinHashMap {
153            map: HashTable::with_capacity(capacity),
154            next: VecDeque::with_capacity(capacity),
155        }
156    }
157
158    /// Shrinks the capacity of the hash map, if necessary, based on the
159    /// provided scale factor.
160    ///
161    /// # Arguments
162    /// * `scale_factor`: The scale factor that determines how conservative the
163    ///   shrinking strategy is. The capacity will be reduced by 1/`scale_factor`
164    ///   when necessary.
165    ///
166    /// # Note
167    /// Increasing the scale factor results in less aggressive capacity shrinking,
168    /// leading to potentially higher memory usage but fewer resizes. Conversely,
169    /// decreasing the scale factor results in more aggressive capacity shrinking,
170    /// potentially leading to lower memory usage but more frequent resizing.
171    pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
172        let capacity = self.map.capacity();
173
174        if capacity > scale_factor * self.map.len() {
175            let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
176            // Resize the map with the new capacity.
177            self.map.shrink_to(new_capacity, |(hash, _)| *hash)
178        }
179    }
180
181    /// Calculates the size of the `PruningJoinHashMap` in bytes.
182    ///
183    /// # Returns
184    /// The size of the hash map in bytes.
185    pub(crate) fn size(&self) -> usize {
186        let fixed_size = size_of::<PruningJoinHashMap>();
187
188        // TODO: switch to using [HashTable::allocation_size] when available after upgrading hashbrown to 0.15
189        estimate_memory_size::<(u64, u64)>(self.map.capacity(), fixed_size).unwrap()
190            + self.next.capacity() * size_of::<u64>()
191    }
192
193    /// Removes hash values from the map and the list based on the given pruning
194    /// length and deleting offset.
195    ///
196    /// # Arguments
197    /// * `prune_length`: The number of elements to remove from the list.
198    /// * `deleting_offset`: The offset used to determine which hash values to remove from the map.
199    ///
200    /// # Returns
201    /// A `Result` indicating whether the operation was successful.
202    pub(crate) fn prune_hash_values(
203        &mut self,
204        prune_length: usize,
205        deleting_offset: u64,
206        shrink_factor: usize,
207    ) {
208        // Remove elements from the list based on the pruning length.
209        self.next.drain(0..prune_length);
210
211        // Calculate the keys that should be removed from the map.
212        let removable_keys = self
213            .map
214            .iter()
215            .filter_map(|(hash, tail_index)| {
216                (*tail_index < prune_length as u64 + deleting_offset).then_some(*hash)
217            })
218            .collect::<Vec<_>>();
219
220        // Remove the keys from the map.
221        removable_keys.into_iter().for_each(|hash_value| {
222            self.map
223                .find_entry(hash_value, |(hash, _)| hash_value == *hash)
224                .unwrap()
225                .remove();
226        });
227
228        // Shrink the map if necessary.
229        self.shrink_if_necessary(shrink_factor);
230    }
231}
232
233fn check_filter_expr_contains_sort_information(
234    expr: &Arc<dyn PhysicalExpr>,
235    reference: &Arc<dyn PhysicalExpr>,
236) -> bool {
237    expr.eq(reference)
238        || expr
239            .children()
240            .iter()
241            .any(|e| check_filter_expr_contains_sort_information(e, reference))
242}
243
244/// Create a one to one mapping from main columns to filter columns using
245/// filter column indices. A column index looks like:
246/// ```text
247/// ColumnIndex {
248///     index: 0, // field index in main schema
249///     side: JoinSide::Left, // child side
250/// }
251/// ```
252pub fn map_origin_col_to_filter_col(
253    filter: &JoinFilter,
254    schema: &SchemaRef,
255    side: &JoinSide,
256) -> Result<HashMap<Column, Column>> {
257    let filter_schema = filter.schema();
258    let mut col_to_col_map = HashMap::<Column, Column>::new();
259    for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
260        if index.side.eq(side) {
261            // Get the main field from column index:
262            let main_field = schema.field(index.index);
263            // Create a column expression:
264            let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?;
265            // Since the order of by filter.column_indices() is the same with
266            // that of intermediate schema fields, we can get the column directly.
267            let filter_field = filter_schema.field(filter_schema_index);
268            let filter_col = Column::new(filter_field.name(), filter_schema_index);
269            // Insert mapping:
270            col_to_col_map.insert(main_col, filter_col);
271        }
272    }
273    Ok(col_to_col_map)
274}
275
276/// This function analyzes [`PhysicalSortExpr`] graphs with respect to output orderings
277/// (sorting) properties. This is necessary since monotonically increasing and/or
278/// decreasing expressions are required when using join filter expressions for
279/// data pruning purposes.
280///
281/// The method works as follows:
282/// 1. Maps the original columns to the filter columns using the [`map_origin_col_to_filter_col`] function.
283/// 2. Collects all columns in the sort expression using the [`collect_columns`] function.
284/// 3. Checks if all columns are included in the map we obtain in the first step.
285/// 4. If all columns are included, the sort expression is converted into a filter expression using
286///    the [`convert_filter_columns`] function.
287/// 5. Searches for the converted filter expression in the filter expression using the
288///    [`check_filter_expr_contains_sort_information`] function.
289/// 6. If an exact match is found, returns the converted filter expression as `Some(Arc<dyn PhysicalExpr>)`.
290/// 7. If all columns are not included or an exact match is not found, returns [`None`].
291///
292/// Examples:
293/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100".
294/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
295/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
296/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However,
297///    there is no exact match, so this expression does not indicate pruning.
298pub fn convert_sort_expr_with_filter_schema(
299    side: &JoinSide,
300    filter: &JoinFilter,
301    schema: &SchemaRef,
302    sort_expr: &PhysicalSortExpr,
303) -> Result<Option<Arc<dyn PhysicalExpr>>> {
304    let column_map = map_origin_col_to_filter_col(filter, schema, side)?;
305    let expr = Arc::clone(&sort_expr.expr);
306    // Get main schema columns:
307    let expr_columns = collect_columns(&expr);
308    // Calculation is possible with `column_map` since sort exprs belong to a child.
309    let all_columns_are_included =
310        expr_columns.iter().all(|col| column_map.contains_key(col));
311    if all_columns_are_included {
312        // Since we are sure that one to one column mapping includes all columns, we convert
313        // the sort expression into a filter expression.
314        let converted_filter_expr = expr
315            .transform_up_with_lambdas_params(|p, lambdas_params| {
316                convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map(
317                    |transformed| match transformed {
318                        Some(transformed) => Transformed::yes(transformed),
319                        None => Transformed::no(p),
320                    },
321                )
322            })
323            .data()?;
324        // Search the converted `PhysicalExpr` in filter expression; if an exact
325        // match is found, use this sorted expression in graph traversals.
326        if check_filter_expr_contains_sort_information(
327            filter.expression(),
328            &converted_filter_expr,
329        ) {
330            return Ok(Some(converted_filter_expr));
331        }
332    }
333    Ok(None)
334}
335
336/// This function is used to build the filter expression based on the sort order of input columns.
337///
338/// It first calls the [`convert_sort_expr_with_filter_schema`] method to determine if the sort
339/// order of columns can be used in the filter expression. If it returns a [`Some`] value, the
340/// method wraps the result in a [`SortedFilterExpr`] instance with the original sort expression and
341/// the converted filter expression. Otherwise, this function returns an error.
342///
343/// The `SortedFilterExpr` instance contains information about the sort order of columns that can
344/// be used in the filter expression, which can be used to optimize the query execution process.
345pub fn build_filter_input_order(
346    side: JoinSide,
347    filter: &JoinFilter,
348    schema: &SchemaRef,
349    order: &PhysicalSortExpr,
350) -> Result<Option<SortedFilterExpr>> {
351    let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?;
352    opt_expr
353        .map(|filter_expr| {
354            SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema())
355        })
356        .transpose()
357}
358
359/// Convert a physical expression into a filter expression using the given
360/// column mapping information.
361fn convert_filter_columns(
362    input: &dyn PhysicalExpr,
363    column_map: &HashMap<Column, Column>,
364    lambdas_params: &HashSet<String>,
365) -> Result<Option<Arc<dyn PhysicalExpr>>> {
366    // Attempt to downcast the input expression to a Column type.
367    Ok(match input.as_any().downcast_ref::<Column>() {
368        Some(col) if !lambdas_params.contains(col.name()) => {
369            column_map.get(col).map(|c| Arc::new(c.clone()) as _)
370        }
371        _ => {
372            // If the downcast fails, return the input expression as is.
373            None
374        }
375    })
376}
377
378/// The [SortedFilterExpr] object represents a sorted filter expression. It
379/// contains the following information: The origin expression, the filter
380/// expression, an interval encapsulating expression bounds, and a stable
381/// index identifying the expression in the expression DAG.
382///
383/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides
384/// and uses new column names. In this process, a column exchange is done so
385/// we can utilize sorting information while traversing the filter expression
386/// DAG for interval calculations. When evaluating the inner buffer, we use
387/// `origin_sorted_expr`.
388#[derive(Debug, Clone)]
389pub struct SortedFilterExpr {
390    /// Sorted expression from a join side (i.e. a child of the join)
391    origin_sorted_expr: PhysicalSortExpr,
392    /// Expression adjusted for filter schema.
393    filter_expr: Arc<dyn PhysicalExpr>,
394    /// Interval containing expression bounds
395    interval: Interval,
396    /// Node index in the expression DAG
397    node_index: usize,
398}
399
400impl SortedFilterExpr {
401    /// Constructor
402    pub fn try_new(
403        origin_sorted_expr: PhysicalSortExpr,
404        filter_expr: Arc<dyn PhysicalExpr>,
405        filter_schema: &Schema,
406    ) -> Result<Self> {
407        let dt = filter_expr.data_type(filter_schema)?;
408        Ok(Self {
409            origin_sorted_expr,
410            filter_expr,
411            interval: Interval::make_unbounded(&dt)?,
412            node_index: 0,
413        })
414    }
415
416    /// Get origin expr information
417    pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
418        &self.origin_sorted_expr
419    }
420
421    /// Get filter expr information
422    pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
423        &self.filter_expr
424    }
425
426    /// Get interval information
427    pub fn interval(&self) -> &Interval {
428        &self.interval
429    }
430
431    /// Sets interval
432    pub fn set_interval(&mut self, interval: Interval) {
433        self.interval = interval;
434    }
435
436    /// Node index in ExprIntervalGraph
437    pub fn node_index(&self) -> usize {
438        self.node_index
439    }
440
441    /// Node index setter in ExprIntervalGraph
442    pub fn set_node_index(&mut self, node_index: usize) {
443        self.node_index = node_index;
444    }
445}
446
447/// Calculate the filter expression intervals.
448///
449/// This function updates the `interval` field of each `SortedFilterExpr` based
450/// on the first or the last value of the expression in `build_input_buffer`
451/// and `probe_batch`.
452///
453/// # Parameters
454///
455/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
456/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
457/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
458/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
459///
460/// ## Note
461///
462/// Utilizing interval arithmetic, this function computes feasible join intervals
463/// on the pruning side by evaluating the prospective value ranges that might
464/// emerge in subsequent data batches from the enforcer side. This is done by
465/// first creating an interval for join filter values in the pruning side of the
466/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/
467/// ascending) of the filter expression. Here, `FV` denotes the first value on the
468/// pruning side. This range is then compared with the enforcer side interval,
469/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/
470/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer
471/// side.
472///
473/// As a concrete example, consider the following query:
474///
475/// ```text
476///   SELECT * FROM left_table, right_table
477///   WHERE
478///     left_key = right_key AND
479///     a > b - 3 AND
480///     a < b + 10
481/// ```
482///
483/// where columns `a` and `b` come from tables `left_table` and `right_table`,
484/// respectively. When a new `RecordBatch` arrives at the right side, the
485/// condition `a > b - 3` will possibly indicate a prunable range for the left
486/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
487/// condition `a < b + 10` will possibly indicate prunability for the right side.
488/// Let’s inspect what happens when a new `RecordBatch` arrives at the right
489/// side (i.e. when the left side is the build side):
490///
491/// ```text
492///         Build      Probe
493///       +-------+  +-------+
494///       | a | z |  | b | y |
495///       |+--|--+|  |+--|--+|
496///       | 1 | 2 |  | 4 | 3 |
497///       |+--|--+|  |+--|--+|
498///       | 3 | 1 |  | 4 | 3 |
499///       |+--|--+|  |+--|--+|
500///       | 5 | 7 |  | 6 | 1 |
501///       |+--|--+|  |+--|--+|
502///       | 7 | 1 |  | 6 | 3 |
503///       +-------+  +-------+
504/// ```
505///
506/// In this case, the interval representing viable (i.e. joinable) values for
507/// column `a` is `[1, ∞]`, and the interval representing possible future values
508/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate
509/// intervals for the whole filter expression and propagate join constraint by
510/// traversing the expression graph.
511pub fn calculate_filter_expr_intervals(
512    build_input_buffer: &RecordBatch,
513    build_sorted_filter_expr: &mut SortedFilterExpr,
514    probe_batch: &RecordBatch,
515    probe_sorted_filter_expr: &mut SortedFilterExpr,
516) -> Result<()> {
517    // If either build or probe side has no data, return early:
518    if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
519        return Ok(());
520    }
521    // Calculate the interval for the build side filter expression (if present):
522    update_filter_expr_interval(
523        &build_input_buffer.slice(0, 1),
524        build_sorted_filter_expr,
525    )?;
526    // Calculate the interval for the probe side filter expression (if present):
527    update_filter_expr_interval(
528        &probe_batch.slice(probe_batch.num_rows() - 1, 1),
529        probe_sorted_filter_expr,
530    )
531}
532
533/// This is a subroutine of the function [`calculate_filter_expr_intervals`].
534/// It constructs the current interval using the given `batch` and updates
535/// the filter expression (i.e. `sorted_expr`) with this interval.
536pub fn update_filter_expr_interval(
537    batch: &RecordBatch,
538    sorted_expr: &mut SortedFilterExpr,
539) -> Result<()> {
540    // Evaluate the filter expression and convert the result to an array:
541    let array = sorted_expr
542        .origin_sorted_expr()
543        .expr
544        .evaluate(batch)?
545        .into_array(1)?;
546    // Convert the array to a ScalarValue:
547    let value = ScalarValue::try_from_array(&array, 0)?;
548    // Create a ScalarValue representing positive or negative infinity for the same data type:
549    let inf = ScalarValue::try_from(value.data_type())?;
550    // Update the interval with lower and upper bounds based on the sort option:
551    let interval = if sorted_expr.origin_sorted_expr().options.descending {
552        Interval::try_new(inf, value)?
553    } else {
554        Interval::try_new(value, inf)?
555    };
556    // Set the calculated interval for the sorted filter expression:
557    sorted_expr.set_interval(interval);
558    Ok(())
559}
560
561/// Get the anti join indices from the visited hash set.
562///
563/// This method returns the indices from the original input that were not present in the visited hash set.
564///
565/// # Arguments
566///
567/// * `prune_length` - The length of the pruned record batch.
568/// * `deleted_offset` - The offset to the indices.
569/// * `visited_rows` - The hash set of visited indices.
570///
571/// # Returns
572///
573/// A `PrimitiveArray` of the anti join indices.
574pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
575    prune_length: usize,
576    deleted_offset: usize,
577    visited_rows: &HashSet<usize>,
578) -> PrimitiveArray<T>
579where
580    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
581{
582    let mut bitmap = BooleanBufferBuilder::new(prune_length);
583    bitmap.append_n(prune_length, false);
584    // mark the indices as true if they are present in the visited hash set
585    for v in 0..prune_length {
586        let row = v + deleted_offset;
587        bitmap.set_bit(v, visited_rows.contains(&row));
588    }
589    // get the anti index
590    (0..prune_length)
591        .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
592        .collect()
593}
594
595/// This method creates a boolean buffer from the visited rows hash set
596/// and the indices of the pruned record batch slice.
597///
598/// It gets the indices from the original input that were present in the visited hash set.
599///
600/// # Arguments
601///
602/// * `prune_length` - The length of the pruned record batch.
603/// * `deleted_offset` - The offset to the indices.
604/// * `visited_rows` - The hash set of visited indices.
605///
606/// # Returns
607///
608/// A [PrimitiveArray] of the specified type T, containing the semi indices.
609pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
610    prune_length: usize,
611    deleted_offset: usize,
612    visited_rows: &HashSet<usize>,
613) -> PrimitiveArray<T>
614where
615    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
616{
617    let mut bitmap = BooleanBufferBuilder::new(prune_length);
618    bitmap.append_n(prune_length, false);
619    // mark the indices as true if they are present in the visited hash set
620    (0..prune_length).for_each(|v| {
621        let row = &(v + deleted_offset);
622        bitmap.set_bit(v, visited_rows.contains(row));
623    });
624    // get the semi index
625    (0..prune_length)
626        .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
627        .collect()
628}
629
630pub fn combine_two_batches(
631    output_schema: &SchemaRef,
632    left_batch: Option<RecordBatch>,
633    right_batch: Option<RecordBatch>,
634) -> Result<Option<RecordBatch>> {
635    match (left_batch, right_batch) {
636        (Some(batch), None) | (None, Some(batch)) => {
637            // If only one of the batches are present, return it:
638            Ok(Some(batch))
639        }
640        (Some(left_batch), Some(right_batch)) => {
641            // If both batches are present, concatenate them:
642            concat_batches(output_schema, &[left_batch, right_batch])
643                .map_err(|e| arrow_datafusion_err!(e))
644                .map(Some)
645        }
646        (None, None) => {
647            // If neither is present, return an empty batch:
648            Ok(None)
649        }
650    }
651}
652
653/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`.
654/// This function will insert the indices (offset by `offset`) into the `visited` hash set.
655///
656/// # Arguments
657///
658/// * `visited` - A hash set to store the visited indices.
659/// * `offset` - An offset to the indices in the `PrimitiveArray`.
660/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded.
661pub fn record_visited_indices<T: ArrowPrimitiveType>(
662    visited: &mut HashSet<usize>,
663    offset: usize,
664    indices: &PrimitiveArray<T>,
665) {
666    for i in indices.values() {
667        visited.insert(i.as_usize() + offset);
668    }
669}
670
671#[derive(Debug)]
672pub struct StreamJoinSideMetrics {
673    /// Number of batches consumed by this operator
674    pub(crate) input_batches: metrics::Count,
675    /// Number of rows consumed by this operator
676    pub(crate) input_rows: metrics::Count,
677}
678
679/// Metrics for HashJoinExec
680#[derive(Debug)]
681pub struct StreamJoinMetrics {
682    /// Number of left batches/rows consumed by this operator
683    pub(crate) left: StreamJoinSideMetrics,
684    /// Number of right batches/rows consumed by this operator
685    pub(crate) right: StreamJoinSideMetrics,
686    /// Memory used by sides in bytes
687    pub(crate) stream_memory_usage: metrics::Gauge,
688    /// Number of batches produced by this operator
689    pub(crate) output_batches: metrics::Count,
690    /// Number of rows produced by this operator
691    pub(crate) baseline_metrics: BaselineMetrics,
692}
693
694impl StreamJoinMetrics {
695    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
696        let input_batches =
697            MetricBuilder::new(metrics).counter("input_batches", partition);
698        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
699        let left = StreamJoinSideMetrics {
700            input_batches,
701            input_rows,
702        };
703
704        let input_batches =
705            MetricBuilder::new(metrics).counter("input_batches", partition);
706        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
707        let right = StreamJoinSideMetrics {
708            input_batches,
709            input_rows,
710        };
711
712        let stream_memory_usage =
713            MetricBuilder::new(metrics).gauge("stream_memory_usage", partition);
714
715        let output_batches =
716            MetricBuilder::new(metrics).counter("output_batches", partition);
717
718        Self {
719            left,
720            right,
721            output_batches,
722            stream_memory_usage,
723            baseline_metrics: BaselineMetrics::new(metrics, partition),
724        }
725    }
726}
727
728/// Updates sorted filter expressions with corresponding node indices from the
729/// expression interval graph.
730///
731/// This function iterates through the provided sorted filter expressions,
732/// gathers the corresponding node indices from the expression interval graph,
733/// and then updates the sorted expressions with these indices. It ensures
734/// that these sorted expressions are aligned with the structure of the graph.
735fn update_sorted_exprs_with_node_indices(
736    graph: &mut ExprIntervalGraph,
737    sorted_exprs: &mut [SortedFilterExpr],
738) {
739    // Extract filter expressions from the sorted expressions:
740    let filter_exprs = sorted_exprs
741        .iter()
742        .map(|expr| Arc::clone(expr.filter_expr()))
743        .collect::<Vec<_>>();
744
745    // Gather corresponding node indices for the extracted filter expressions from the graph:
746    let child_node_indices = graph.gather_node_indices(&filter_exprs);
747
748    // Iterate through the sorted expressions and the gathered node indices:
749    for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
750        // Update each sorted expression with the corresponding node index:
751        sorted_expr.set_node_index(index);
752    }
753}
754
755/// Prepares and sorts expressions based on a given filter, left and right schemas,
756/// and sort expressions.
757///
758/// This function prepares sorted filter expressions for both the left and right
759/// sides of a join operation. It first builds the filter order for each side
760/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter
761/// expressions, the function then constructs an expression interval graph and
762/// updates the sorted expressions with node indices. The final sorted filter
763/// expressions for both sides are then returned.
764///
765/// # Parameters
766///
767/// * `filter` - The join filter to base the sorting on.
768/// * `left` - The `ExecutionPlan` for the left side of the join.
769/// * `right` - The `ExecutionPlan` for the right side of the join.
770/// * `left_sort_exprs` - The expressions to sort on the left side.
771/// * `right_sort_exprs` - The expressions to sort on the right side.
772///
773/// # Returns
774///
775/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph.
776pub fn prepare_sorted_exprs(
777    filter: &JoinFilter,
778    left: &Arc<dyn ExecutionPlan>,
779    right: &Arc<dyn ExecutionPlan>,
780    left_sort_exprs: &LexOrdering,
781    right_sort_exprs: &LexOrdering,
782) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
783    let err = || {
784        datafusion_common::plan_datafusion_err!("Filter does not include the child order")
785    };
786
787    // Build the filter order for the left side:
788    let left_temp_sorted_filter_expr = build_filter_input_order(
789        JoinSide::Left,
790        filter,
791        &left.schema(),
792        &left_sort_exprs[0],
793    )?
794    .ok_or_else(err)?;
795
796    // Build the filter order for the right side:
797    let right_temp_sorted_filter_expr = build_filter_input_order(
798        JoinSide::Right,
799        filter,
800        &right.schema(),
801        &right_sort_exprs[0],
802    )?
803    .ok_or_else(err)?;
804
805    // Collect the sorted expressions
806    let mut sorted_exprs =
807        vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
808
809    // Build the expression interval graph
810    let mut graph =
811        ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?;
812
813    // Update sorted expressions with node indices
814    update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
815
816    // Swap and remove to get the final sorted filter expressions
817    let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
818    let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
819
820    Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
821}
822
#[cfg(test)]
pub mod tests {

    use super::*;
    use crate::{joins::test_utils::complicated_filter, joins::utils::ColumnIndex};

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{binary, cast, col};

    #[test]
    fn test_column_exchange() -> Result<()> {
        let left_child_schema =
            Schema::new(vec![Field::new("left_1", DataType::Int32, true)]);
        // Sorting information for the left side:
        let left_child_sort_expr = PhysicalSortExpr {
            expr: col("left_1", &left_child_schema)?,
            options: SortOptions::default(),
        };

        let right_child_schema = Schema::new(vec![
            Field::new("right_1", DataType::Int32, true),
            Field::new("right_2", DataType::Int32, true),
        ]);
        // Sorting information for the right side:
        let right_child_sort_expr = PhysicalSortExpr {
            expr: binary(
                col("right_1", &right_child_schema)?,
                Operator::Plus,
                col("right_2", &right_child_schema)?,
                &right_child_schema,
            )?,
            options: SortOptions::default(),
        };

        let intermediate_schema = Schema::new(vec![
            Field::new("filter_1", DataType::Int32, true),
            Field::new("filter_2", DataType::Int32, true),
            Field::new("filter_3", DataType::Int32, true),
        ]);
        // Our filter expression is: left_1 > right_1 + right_2.
        let filter_left = col("filter_1", &intermediate_schema)?;
        let filter_right = binary(
            col("filter_2", &intermediate_schema)?,
            Operator::Plus,
            col("filter_3", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let filter_expr = binary(
            Arc::clone(&filter_left),
            Operator::Gt,
            Arc::clone(&filter_right),
            &intermediate_schema,
        )?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        let left_sort_filter_expr = build_filter_input_order(
            JoinSide::Left,
            &filter,
            &Arc::new(left_child_schema),
            &left_child_sort_expr,
        )?
        .unwrap();
        assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()));

        let right_sort_filter_expr = build_filter_input_order(
            JoinSide::Right,
            &filter,
            &Arc::new(right_child_schema),
            &right_child_sort_expr,
        )?
        .unwrap();
        assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()));

        // Assert that the (left) filter-side expression is the filter-schema
        // counterpart of `left_child_sort_expr`:
        assert!(filter_left.eq(left_sort_filter_expr.filter_expr()));
        // Assert that the (right) filter-side expression is the filter-schema
        // counterpart of `right_child_sort_expr`:
        assert!(filter_right.eq(right_sort_filter_expr.filter_expr()));
        Ok(())
    }

    #[test]
    fn test_column_collector() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;
        let columns = collect_columns(&filter_expr);
        assert_eq!(columns.len(), 3);
        Ok(())
    }

    #[test]
    fn find_expr_inside_expr() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&schema)?;

        let expr_1 = Arc::new(Column::new("gnz", 0)) as _;
        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_1
        ));

        let expr_2 = col("1", &schema)? as _;

        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_2
        ));

        let expr_3 = cast(
            binary(
                col("0", &schema)?,
                Operator::Plus,
                col("1", &schema)?,
                &schema,
            )?,
            &schema,
            DataType::Int64,
        )?;

        assert!(check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_3
        ));

        let expr_4 = Arc::new(Column::new("1", 42)) as _;

        assert!(!check_filter_expr_contains_sort_information(
            &filter_expr,
            &expr_4,
        ));
        Ok(())
    }

    #[test]
    fn build_sorted_expr() -> Result<()> {
        let left_schema = Schema::new(vec![
            Field::new("la1", DataType::Int32, false),
            Field::new("lb1", DataType::Int32, false),
            Field::new("lc1", DataType::Int32, false),
            Field::new("lt1", DataType::Int32, false),
            Field::new("la2", DataType::Int32, false),
            Field::new("la1_des", DataType::Int32, false),
        ]);

        let right_schema = Schema::new(vec![
            Field::new("ra1", DataType::Int32, false),
            Field::new("rb1", DataType::Int32, false),
            Field::new("rc1", DataType::Int32, false),
            Field::new("rt1", DataType::Int32, false),
            Field::new("ra2", DataType::Int32, false),
            Field::new("ra1_des", DataType::Int32, false),
        ]);

        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let filter_expr = complicated_filter(&intermediate_schema)?;
        let column_indices = vec![
            ColumnIndex {
                index: left_schema.index_of("la1")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: left_schema.index_of("la2")?,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: right_schema.index_of("ra1")?,
                side: JoinSide::Right,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        let left_schema = Arc::new(left_schema);
        let right_schema = Arc::new(right_schema);

        assert!(build_filter_input_order(
            JoinSide::Left,
            &filter,
            &left_schema,
            &PhysicalSortExpr {
                expr: col("la1", left_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_some());
        assert!(build_filter_input_order(
            JoinSide::Left,
            &filter,
            &left_schema,
            &PhysicalSortExpr {
                expr: col("lt1", left_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_none());
        assert!(build_filter_input_order(
            JoinSide::Right,
            &filter,
            &right_schema,
            &PhysicalSortExpr {
                expr: col("ra1", right_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_some());
        assert!(build_filter_input_order(
            JoinSide::Right,
            &filter,
            &right_schema,
            &PhysicalSortExpr {
                expr: col("rb1", right_schema.as_ref())?,
                options: SortOptions::default(),
            }
        )?
        .is_none());

        Ok(())
    }

    // Test the case when we have an "ORDER BY a + b", and join filter condition includes "a - b".
    #[test]
    fn sorted_filter_expr_build() -> Result<()> {
        let intermediate_schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
        ]);
        let filter_expr = binary(
            col("0", &intermediate_schema)?,
            Operator::Minus,
            col("1", &intermediate_schema)?,
            &intermediate_schema,
        )?;
        let column_indices = vec![
            ColumnIndex {
                index: 0,
                side: JoinSide::Left,
            },
            ColumnIndex {
                index: 1,
                side: JoinSide::Left,
            },
        ];
        let filter =
            JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));

        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);

        let sorted = PhysicalSortExpr {
            expr: binary(
                col("a", &schema)?,
                Operator::Plus,
                col("b", &schema)?,
                &schema,
            )?,
            options: SortOptions::default(),
        };

        let res = convert_sort_expr_with_filter_schema(
            &JoinSide::Left,
            &filter,
            &Arc::new(schema),
            &sorted,
        )?;
        assert!(res.is_none());
        Ok(())
    }

    #[test]
    fn test_shrink_if_necessary() {
        let scale_factor = 4;
        let mut join_hash_map = PruningJoinHashMap::with_capacity(100);
        let data_size = 2000;
        let deleted_part = 3 * data_size / 4;
        // Add elements to the JoinHashMap
        for hash_value in 0..data_size {
            join_hash_map.map.insert_unique(
                hash_value,
                (hash_value, hash_value),
                |(hash, _)| *hash,
            );
        }

        assert_eq!(join_hash_map.map.len(), data_size as usize);
        assert!(join_hash_map.map.capacity() >= data_size as usize);

        // Remove some elements from the JoinHashMap
        for hash_value in 0..deleted_part {
            join_hash_map
                .map
                .find_entry(hash_value, |(hash, _)| hash_value == *hash)
                .unwrap()
                .remove();
        }

        assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize);

        // Capacity before shrinking
        let old_capacity = join_hash_map.map.capacity();

        // Test shrink_if_necessary
        join_hash_map.shrink_if_necessary(scale_factor);

        // Shrinking must not lose any live entries:
        assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize);

        // The capacity may drop to at most `(scale_factor - 1) / scale_factor`
        // of the *old* capacity, and must never grow. The bound is derived from
        // `old_capacity`; deriving it from the post-shrink capacity would make
        // the lower-bound check trivially true.
        let min_expected_capacity = old_capacity * (scale_factor - 1) / scale_factor;
        assert!(join_hash_map.map.capacity() >= min_expected_capacity);
        assert!(join_hash_map.map.capacity() <= old_capacity);
    }
}