datafusion_physical_plan/joins/stream_join_utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file contains common subroutines for symmetric hash join
19//! related functionality, used both in join calculations and optimization rules.
20
21use std::collections::{HashMap, VecDeque};
22use std::mem::size_of;
23use std::sync::Arc;
24
25use crate::joins::join_hash_map::{
26 get_matched_indices, get_matched_indices_with_limit_offset, update_from_iter,
27 JoinHashMapOffset,
28};
29use crate::joins::utils::{JoinFilter, JoinHashMapType};
30use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder};
31use crate::{metrics, ExecutionPlan};
32
33use arrow::array::{
34 ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch,
35};
36use arrow::compute::concat_batches;
37use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef};
38use datafusion_common::tree_node::{Transformed, TransformedResult};
39use datafusion_common::utils::memory::estimate_memory_size;
40use datafusion_common::{
41 arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue,
42};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_physical_expr::expressions::Column;
45use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
46use datafusion_physical_expr::utils::collect_columns;
47use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr};
48
49use datafusion_physical_expr_common::sort_expr::LexOrdering;
50use hashbrown::HashTable;
51
52/// Implementation of `JoinHashMapType` for `PruningJoinHashMap`.
53impl JoinHashMapType for PruningJoinHashMap {
54 // Extend with zero
55 fn extend_zero(&mut self, len: usize) {
56 self.next.resize(self.next.len() + len, 0)
57 }
58
59 fn update_from_iter<'a>(
60 &mut self,
61 iter: Box<dyn Iterator<Item = (usize, &'a u64)> + Send + 'a>,
62 deleted_offset: usize,
63 ) {
64 let slice: &mut [u64] = self.next.make_contiguous();
65 update_from_iter::<u64>(&mut self.map, slice, iter, deleted_offset);
66 }
67
68 fn get_matched_indices<'a>(
69 &self,
70 iter: Box<dyn Iterator<Item = (usize, &'a u64)> + 'a>,
71 deleted_offset: Option<usize>,
72 ) -> (Vec<u32>, Vec<u64>) {
73 // Flatten the deque
74 let next: Vec<u64> = self.next.iter().copied().collect();
75 get_matched_indices::<u64>(&self.map, &next, iter, deleted_offset)
76 }
77
78 fn get_matched_indices_with_limit_offset(
79 &self,
80 hash_values: &[u64],
81 limit: usize,
82 offset: JoinHashMapOffset,
83 ) -> (Vec<u32>, Vec<u64>, Option<JoinHashMapOffset>) {
84 // Flatten the deque
85 let next: Vec<u64> = self.next.iter().copied().collect();
86 get_matched_indices_with_limit_offset::<u64>(
87 &self.map,
88 &next,
89 hash_values,
90 limit,
91 offset,
92 )
93 }
94
95 fn is_empty(&self) -> bool {
96 self.map.is_empty()
97 }
98}
99
100/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
101/// the capability of pruning elements in an efficient manner. This structure
102/// is particularly useful for cases where it's necessary to remove elements
103/// from the map based on their buffer order.
104///
105/// # Example
106///
107/// ``` text
108/// Let's continue the example of `JoinHashMap` and then show how `PruningJoinHashMap` would
109/// handle the pruning scenario.
110///
111/// Insert the pair (10,4) into the `PruningJoinHashMap`:
112/// map:
113/// ----------
114/// | 10 | 5 |
115/// | 20 | 3 |
116/// ----------
117/// list:
118/// ---------------------
119/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
120/// ---------------------
121///
122/// Now, let's prune 3 rows from `PruningJoinHashMap`:
123/// map:
124/// ----------
125/// | 10 | 5 |
126/// ----------
127/// list:
128/// ---------
129/// | 2 | 4 | <--- hash value 10 maps to 2 (5 - 3), 1 (4 - 3), NA (2 - 3) (which means indices values 1,0)
130/// ---------
131///
132/// After pruning, the | 20 | 3 | entry is deleted from `PruningJoinHashMap` since
133/// there are no values left for this key.
134/// ```
135pub struct PruningJoinHashMap {
136 /// Stores hash value to last row index
137 pub map: HashTable<(u64, u64)>,
138 /// Stores indices in chained list data structure
139 pub next: VecDeque<u64>,
140}
141
142impl PruningJoinHashMap {
143 /// Constructs a new `PruningJoinHashMap` with the given capacity.
144 /// Both the map and the list are pre-allocated with the provided capacity.
145 ///
146 /// # Arguments
147 /// * `capacity`: The initial capacity of the hash map.
148 ///
149 /// # Returns
150 /// A new instance of `PruningJoinHashMap`.
151 pub(crate) fn with_capacity(capacity: usize) -> Self {
152 PruningJoinHashMap {
153 map: HashTable::with_capacity(capacity),
154 next: VecDeque::with_capacity(capacity),
155 }
156 }
157
158 /// Shrinks the capacity of the hash map, if necessary, based on the
159 /// provided scale factor.
160 ///
161 /// # Arguments
162 /// * `scale_factor`: The scale factor that determines how conservative the
163 /// shrinking strategy is. The capacity will be reduced by 1/`scale_factor`
164 /// when necessary.
165 ///
166 /// # Note
167 /// Increasing the scale factor results in less aggressive capacity shrinking,
168 /// leading to potentially higher memory usage but fewer resizes. Conversely,
169 /// decreasing the scale factor results in more aggressive capacity shrinking,
170 /// potentially leading to lower memory usage but more frequent resizing.
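    ///
    /// # Example
    ///
    /// As a hypothetical illustration of the arithmetic: with `scale_factor = 4`,
    /// a map holding 200 entries in a capacity of 1000 satisfies `1000 > 4 * 200`,
    /// so the map is asked to shrink its capacity toward `1000 * (4 - 1) / 4 = 750`.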
171 pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
172 let capacity = self.map.capacity();
173
174 if capacity > scale_factor * self.map.len() {
175 let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
176 // Resize the map with the new capacity.
177 self.map.shrink_to(new_capacity, |(hash, _)| *hash)
178 }
179 }
180
181 /// Calculates the size of the `PruningJoinHashMap` in bytes.
182 ///
183 /// # Returns
184 /// The size of the hash map in bytes.
185 pub(crate) fn size(&self) -> usize {
186 let fixed_size = size_of::<PruningJoinHashMap>();
187
188 // TODO: switch to using [HashTable::allocation_size] when available after upgrading hashbrown to 0.15
189 estimate_memory_size::<(u64, u64)>(self.map.capacity(), fixed_size).unwrap()
190 + self.next.capacity() * size_of::<u64>()
191 }
192
193 /// Removes hash values from the map and the list based on the given pruning
194 /// length and deleting offset.
195 ///
196 /// # Arguments
197 /// * `prune_length`: The number of elements to remove from the list.
198 /// * `deleting_offset`: The offset used to determine which hash values to remove from the map.
199    /// * `shrink_factor`: The scale factor used when shrinking the map after
200    ///   pruning; see `shrink_if_necessary`.
201    ///
202 pub(crate) fn prune_hash_values(
203 &mut self,
204 prune_length: usize,
205 deleting_offset: u64,
206 shrink_factor: usize,
207 ) {
208 // Remove elements from the list based on the pruning length.
209 self.next.drain(0..prune_length);
210
211 // Calculate the keys that should be removed from the map.
212 let removable_keys = self
213 .map
214 .iter()
215 .filter_map(|(hash, tail_index)| {
216 (*tail_index < prune_length as u64 + deleting_offset).then_some(*hash)
217 })
218 .collect::<Vec<_>>();
219
220 // Remove the keys from the map.
221 removable_keys.into_iter().for_each(|hash_value| {
222 self.map
223 .find_entry(hash_value, |(hash, _)| hash_value == *hash)
224 .unwrap()
225 .remove();
226 });
227
228 // Shrink the map if necessary.
229 self.shrink_if_necessary(shrink_factor);
230 }
231}
232
233fn check_filter_expr_contains_sort_information(
234 expr: &Arc<dyn PhysicalExpr>,
235 reference: &Arc<dyn PhysicalExpr>,
236) -> bool {
237 expr.eq(reference)
238 || expr
239 .children()
240 .iter()
241 .any(|e| check_filter_expr_contains_sort_information(e, reference))
242}
243
244/// Create a one to one mapping from main columns to filter columns using
245/// filter column indices. A column index looks like:
246/// ```text
247/// ColumnIndex {
248/// index: 0, // field index in main schema
249/// side: JoinSide::Left, // child side
250/// }
251/// ```
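///
/// For example (with hypothetical names), if the left child's column `a` at index 0
/// appears in the filter's intermediate schema as `filter_1` at index 0, calling this
/// function with `JoinSide::Left` yields the single mapping `a@0 -> filter_1@0`.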
252pub fn map_origin_col_to_filter_col(
253 filter: &JoinFilter,
254 schema: &SchemaRef,
255 side: &JoinSide,
256) -> Result<HashMap<Column, Column>> {
257 let filter_schema = filter.schema();
258 let mut col_to_col_map = HashMap::<Column, Column>::new();
259 for (filter_schema_index, index) in filter.column_indices().iter().enumerate() {
260 if index.side.eq(side) {
261 // Get the main field from column index:
262 let main_field = schema.field(index.index);
263 // Create a column expression:
264 let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?;
265            // Since the order of `filter.column_indices()` is the same as that of
266            // the intermediate schema fields, we can get the column directly.
267 let filter_field = filter_schema.field(filter_schema_index);
268 let filter_col = Column::new(filter_field.name(), filter_schema_index);
269 // Insert mapping:
270 col_to_col_map.insert(main_col, filter_col);
271 }
272 }
273 Ok(col_to_col_map)
274}
275
276/// This function analyzes [`PhysicalSortExpr`] graphs with respect to output orderings
277/// (sorting) properties. This is necessary since monotonically increasing and/or
278/// decreasing expressions are required when using join filter expressions for
279/// data pruning purposes.
280///
281/// The method works as follows:
282/// 1. Maps the original columns to the filter columns using the [`map_origin_col_to_filter_col`] function.
283/// 2. Collects all columns in the sort expression using the [`collect_columns`] function.
284/// 3. Checks if all columns are included in the map we obtain in the first step.
285/// 4. If all columns are included, the sort expression is converted into a filter expression using
286/// the [`convert_filter_columns`] function.
287/// 5. Searches for the converted filter expression in the filter expression using the
288/// [`check_filter_expr_contains_sort_information`] function.
289/// 6. If an exact match is found, returns the converted filter expression as `Some(Arc<dyn PhysicalExpr>)`.
290/// 7. If not all columns are included, or no exact match is found, returns [`None`].
291///
292/// Examples:
293/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100".
294/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
295/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter.
296/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However,
297/// there is no exact match, so this expression does not indicate pruning.
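/// 4. If the expression "a@ + b@" is sorted, it converts into filter columns and matches the
///    sub-expression "a + b" of the filter exactly, so the converted expression is returned for pruning.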
298pub fn convert_sort_expr_with_filter_schema(
299 side: &JoinSide,
300 filter: &JoinFilter,
301 schema: &SchemaRef,
302 sort_expr: &PhysicalSortExpr,
303) -> Result<Option<Arc<dyn PhysicalExpr>>> {
304 let column_map = map_origin_col_to_filter_col(filter, schema, side)?;
305 let expr = Arc::clone(&sort_expr.expr);
306 // Get main schema columns:
307 let expr_columns = collect_columns(&expr);
308 // Calculation is possible with `column_map` since sort exprs belong to a child.
309 let all_columns_are_included =
310 expr_columns.iter().all(|col| column_map.contains_key(col));
311 if all_columns_are_included {
312 // Since we are sure that one to one column mapping includes all columns, we convert
313 // the sort expression into a filter expression.
314 let converted_filter_expr = expr
315 .transform_up_with_lambdas_params(|p, lambdas_params| {
316 convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map(
317 |transformed| match transformed {
318 Some(transformed) => Transformed::yes(transformed),
319 None => Transformed::no(p),
320 },
321 )
322 })
323 .data()?;
324 // Search the converted `PhysicalExpr` in filter expression; if an exact
325 // match is found, use this sorted expression in graph traversals.
326 if check_filter_expr_contains_sort_information(
327 filter.expression(),
328 &converted_filter_expr,
329 ) {
330 return Ok(Some(converted_filter_expr));
331 }
332 }
333 Ok(None)
334}
335
336/// This function is used to build the filter expression based on the sort order of input columns.
337///
338/// It first calls the [`convert_sort_expr_with_filter_schema`] method to determine if the sort
339/// order of columns can be used in the filter expression. If it returns a [`Some`] value, the
340/// method wraps the result in a [`SortedFilterExpr`] instance with the original sort expression and
341/// the converted filter expression. Otherwise, this function returns `Ok(None)`.
342///
343/// The `SortedFilterExpr` instance contains information about the sort order of columns that can
344/// be used in the filter expression, which can be used to optimize the query execution process.
345pub fn build_filter_input_order(
346 side: JoinSide,
347 filter: &JoinFilter,
348 schema: &SchemaRef,
349 order: &PhysicalSortExpr,
350) -> Result<Option<SortedFilterExpr>> {
351 let opt_expr = convert_sort_expr_with_filter_schema(&side, filter, schema, order)?;
352 opt_expr
353 .map(|filter_expr| {
354 SortedFilterExpr::try_new(order.clone(), filter_expr, filter.schema())
355 })
356 .transpose()
357}
358
359/// Convert a physical expression into a filter expression using the given
360/// column mapping information.
361fn convert_filter_columns(
362 input: &dyn PhysicalExpr,
363 column_map: &HashMap<Column, Column>,
364 lambdas_params: &HashSet<String>,
365) -> Result<Option<Arc<dyn PhysicalExpr>>> {
366 // Attempt to downcast the input expression to a Column type.
367 Ok(match input.as_any().downcast_ref::<Column>() {
368 Some(col) if !lambdas_params.contains(col.name()) => {
369 column_map.get(col).map(|c| Arc::new(c.clone()) as _)
370 }
371 _ => {
372            // If the downcast fails or the column is a lambda parameter, signal no change by returning `None`.
373 None
374 }
375 })
376}
377
378/// The [SortedFilterExpr] object represents a sorted filter expression. It
379/// contains the following information: The origin expression, the filter
380/// expression, an interval encapsulating expression bounds, and a stable
381/// index identifying the expression in the expression DAG.
382///
383/// The physical schema of a [`JoinFilter`]'s intermediate batch combines the two
384/// join sides and uses new column names. In this process, a column exchange is
385/// performed so that we can utilize sorting information while traversing the
386/// filter expression DAG for interval calculations. When evaluating the inner
387/// buffer, we use `origin_sorted_expr`.
388#[derive(Debug, Clone)]
389pub struct SortedFilterExpr {
390 /// Sorted expression from a join side (i.e. a child of the join)
391 origin_sorted_expr: PhysicalSortExpr,
392 /// Expression adjusted for filter schema.
393 filter_expr: Arc<dyn PhysicalExpr>,
394 /// Interval containing expression bounds
395 interval: Interval,
396 /// Node index in the expression DAG
397 node_index: usize,
398}
399
400impl SortedFilterExpr {
401 /// Constructor
402 pub fn try_new(
403 origin_sorted_expr: PhysicalSortExpr,
404 filter_expr: Arc<dyn PhysicalExpr>,
405 filter_schema: &Schema,
406 ) -> Result<Self> {
407 let dt = filter_expr.data_type(filter_schema)?;
408 Ok(Self {
409 origin_sorted_expr,
410 filter_expr,
411 interval: Interval::make_unbounded(&dt)?,
412 node_index: 0,
413 })
414 }
415
416 /// Get origin expr information
417 pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
418 &self.origin_sorted_expr
419 }
420
421 /// Get filter expr information
422 pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
423 &self.filter_expr
424 }
425
426 /// Get interval information
427 pub fn interval(&self) -> &Interval {
428 &self.interval
429 }
430
431 /// Sets interval
432 pub fn set_interval(&mut self, interval: Interval) {
433 self.interval = interval;
434 }
435
436 /// Node index in ExprIntervalGraph
437 pub fn node_index(&self) -> usize {
438 self.node_index
439 }
440
441 /// Node index setter in ExprIntervalGraph
442 pub fn set_node_index(&mut self, node_index: usize) {
443 self.node_index = node_index;
444 }
445}
446
447/// Calculate the filter expression intervals.
448///
449/// This function updates the `interval` field of each `SortedFilterExpr` based
450/// on the first or the last value of the expression in `build_input_buffer`
451/// and `probe_batch`.
452///
453/// # Parameters
454///
455/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
456/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
457/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
458/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
459///
460/// ## Note
461///
462/// Utilizing interval arithmetic, this function computes feasible join intervals
463/// on the pruning side by evaluating the prospective value ranges that might
464/// emerge in subsequent data batches from the enforcer side. This is done by
465/// first creating an interval for join filter values in the pruning side of the
466/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/
467/// ascending) of the filter expression. Here, `FV` denotes the first value on the
468/// pruning side. This range is then compared with the enforcer side interval,
469/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/
470/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer
471/// side.
472///
473/// As a concrete example, consider the following query:
474///
475/// ```text
476/// SELECT * FROM left_table, right_table
477/// WHERE
478/// left_key = right_key AND
479/// a > b - 3 AND
480/// a < b + 10
481/// ```
482///
483/// where columns `a` and `b` come from tables `left_table` and `right_table`,
484/// respectively. When a new `RecordBatch` arrives at the right side, the
485/// condition `a > b - 3` will possibly indicate a prunable range for the left
486/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
487/// condition `a < b + 10` will possibly indicate prunability for the right side.
488/// Let’s inspect what happens when a new `RecordBatch` arrives at the right
489/// side (i.e. when the left side is the build side):
490///
491/// ```text
492/// Build Probe
493/// +-------+ +-------+
494/// | a | z | | b | y |
495/// |+--|--+| |+--|--+|
496/// | 1 | 2 | | 4 | 3 |
497/// |+--|--+| |+--|--+|
498/// | 3 | 1 | | 4 | 3 |
499/// |+--|--+| |+--|--+|
500/// | 5 | 7 | | 6 | 1 |
501/// |+--|--+| |+--|--+|
502/// | 7 | 1 | | 6 | 3 |
503/// +-------+ +-------+
504/// ```
505///
506/// In this case, the interval representing viable (i.e. joinable) values for
507/// column `a` is `[1, ∞]`, and the interval representing possible future values
508/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate
509/// intervals for the whole filter expression and propagate join constraint by
510/// traversing the expression graph.
511pub fn calculate_filter_expr_intervals(
512 build_input_buffer: &RecordBatch,
513 build_sorted_filter_expr: &mut SortedFilterExpr,
514 probe_batch: &RecordBatch,
515 probe_sorted_filter_expr: &mut SortedFilterExpr,
516) -> Result<()> {
517 // If either build or probe side has no data, return early:
518 if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
519 return Ok(());
520 }
521 // Calculate the interval for the build side filter expression (if present):
522 update_filter_expr_interval(
523 &build_input_buffer.slice(0, 1),
524 build_sorted_filter_expr,
525 )?;
526 // Calculate the interval for the probe side filter expression (if present):
527 update_filter_expr_interval(
528 &probe_batch.slice(probe_batch.num_rows() - 1, 1),
529 probe_sorted_filter_expr,
530 )
531}
532
533/// This is a subroutine of the function [`calculate_filter_expr_intervals`].
534/// It constructs the current interval using the given `batch` and updates
535/// the filter expression (i.e. `sorted_expr`) with this interval.
536pub fn update_filter_expr_interval(
537 batch: &RecordBatch,
538 sorted_expr: &mut SortedFilterExpr,
539) -> Result<()> {
540 // Evaluate the filter expression and convert the result to an array:
541 let array = sorted_expr
542 .origin_sorted_expr()
543 .expr
544 .evaluate(batch)?
545 .into_array(1)?;
546 // Convert the array to a ScalarValue:
547 let value = ScalarValue::try_from_array(&array, 0)?;
548 // Create a ScalarValue representing positive or negative infinity for the same data type:
549 let inf = ScalarValue::try_from(value.data_type())?;
550 // Update the interval with lower and upper bounds based on the sort option:
551 let interval = if sorted_expr.origin_sorted_expr().options.descending {
552 Interval::try_new(inf, value)?
553 } else {
554 Interval::try_new(value, inf)?
555 };
556 // Set the calculated interval for the sorted filter expression:
557 sorted_expr.set_interval(interval);
558 Ok(())
559}
560
561/// Get the anti join indices from the visited hash set.
562///
563/// This method returns the indices from the original input that were not present in the visited hash set.
564///
565/// # Arguments
566///
567/// * `prune_length` - The length of the pruned record batch.
568/// * `deleted_offset` - The offset to the indices.
569/// * `visited_rows` - The hash set of visited indices.
570///
571/// # Returns
572///
573/// A `PrimitiveArray` of the anti join indices.
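///
/// For example, with the hypothetical inputs `prune_length = 3`, `deleted_offset = 10`
/// and `visited_rows = {10, 12}`, rows 10 and 12 (local indices 0 and 2) were visited,
/// so the returned anti join indices are `[1]`.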
574pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
575 prune_length: usize,
576 deleted_offset: usize,
577 visited_rows: &HashSet<usize>,
578) -> PrimitiveArray<T>
579where
580 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
581{
582 let mut bitmap = BooleanBufferBuilder::new(prune_length);
583 bitmap.append_n(prune_length, false);
584 // mark the indices as true if they are present in the visited hash set
585 for v in 0..prune_length {
586 let row = v + deleted_offset;
587 bitmap.set_bit(v, visited_rows.contains(&row));
588 }
589 // get the anti index
590 (0..prune_length)
591 .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
592 .collect()
593}
594
595/// This method creates a boolean buffer from the visited rows hash set
596/// and the indices of the pruned record batch slice.
597///
598/// It gets the indices from the original input that were present in the visited hash set.
599///
600/// # Arguments
601///
602/// * `prune_length` - The length of the pruned record batch.
603/// * `deleted_offset` - The offset to the indices.
604/// * `visited_rows` - The hash set of visited indices.
605///
606/// # Returns
607///
608/// A [PrimitiveArray] of the specified type T, containing the semi indices.
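///
/// For example, with the hypothetical inputs `prune_length = 3`, `deleted_offset = 10`
/// and `visited_rows = {10, 12}`, rows 10 and 12 (local indices 0 and 2) were visited,
/// so the returned semi join indices are `[0, 2]`.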
609pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
610 prune_length: usize,
611 deleted_offset: usize,
612 visited_rows: &HashSet<usize>,
613) -> PrimitiveArray<T>
614where
615 NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
616{
617 let mut bitmap = BooleanBufferBuilder::new(prune_length);
618 bitmap.append_n(prune_length, false);
619 // mark the indices as true if they are present in the visited hash set
620 (0..prune_length).for_each(|v| {
621 let row = &(v + deleted_offset);
622 bitmap.set_bit(v, visited_rows.contains(row));
623 });
624 // get the semi index
625 (0..prune_length)
626 .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
627 .collect()
628}
629
630pub fn combine_two_batches(
631 output_schema: &SchemaRef,
632 left_batch: Option<RecordBatch>,
633 right_batch: Option<RecordBatch>,
634) -> Result<Option<RecordBatch>> {
635 match (left_batch, right_batch) {
636 (Some(batch), None) | (None, Some(batch)) => {
637            // If only one of the batches is present, return it:
638 Ok(Some(batch))
639 }
640 (Some(left_batch), Some(right_batch)) => {
641 // If both batches are present, concatenate them:
642 concat_batches(output_schema, &[left_batch, right_batch])
643 .map_err(|e| arrow_datafusion_err!(e))
644 .map(Some)
645 }
646 (None, None) => {
647            // If neither batch is present, return `None`:
648 Ok(None)
649 }
650 }
651}
652
653/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`.
654/// This function will insert the indices (offset by `offset`) into the `visited` hash set.
655///
656/// # Arguments
657///
658/// * `visited` - A hash set to store the visited indices.
659/// * `offset` - An offset to the indices in the `PrimitiveArray`.
660/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded.
661pub fn record_visited_indices<T: ArrowPrimitiveType>(
662 visited: &mut HashSet<usize>,
663 offset: usize,
664 indices: &PrimitiveArray<T>,
665) {
666 for i in indices.values() {
667 visited.insert(i.as_usize() + offset);
668 }
669}
670
671#[derive(Debug)]
672pub struct StreamJoinSideMetrics {
673 /// Number of batches consumed by this operator
674 pub(crate) input_batches: metrics::Count,
675 /// Number of rows consumed by this operator
676 pub(crate) input_rows: metrics::Count,
677}
678
679/// Metrics for stream join operators such as the symmetric hash join
680#[derive(Debug)]
681pub struct StreamJoinMetrics {
682 /// Number of left batches/rows consumed by this operator
683 pub(crate) left: StreamJoinSideMetrics,
684 /// Number of right batches/rows consumed by this operator
685 pub(crate) right: StreamJoinSideMetrics,
686 /// Memory used by sides in bytes
687 pub(crate) stream_memory_usage: metrics::Gauge,
688 /// Number of batches produced by this operator
689 pub(crate) output_batches: metrics::Count,
690    /// Baseline metrics for this operator, including the number of rows produced
691 pub(crate) baseline_metrics: BaselineMetrics,
692}
693
694impl StreamJoinMetrics {
695 pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
696 let input_batches =
697 MetricBuilder::new(metrics).counter("input_batches", partition);
698 let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
699 let left = StreamJoinSideMetrics {
700 input_batches,
701 input_rows,
702 };
703
704 let input_batches =
705 MetricBuilder::new(metrics).counter("input_batches", partition);
706 let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
707 let right = StreamJoinSideMetrics {
708 input_batches,
709 input_rows,
710 };
711
712 let stream_memory_usage =
713 MetricBuilder::new(metrics).gauge("stream_memory_usage", partition);
714
715 let output_batches =
716 MetricBuilder::new(metrics).counter("output_batches", partition);
717
718 Self {
719 left,
720 right,
721 output_batches,
722 stream_memory_usage,
723 baseline_metrics: BaselineMetrics::new(metrics, partition),
724 }
725 }
726}
727
728/// Updates sorted filter expressions with corresponding node indices from the
729/// expression interval graph.
730///
731/// This function iterates through the provided sorted filter expressions,
732/// gathers the corresponding node indices from the expression interval graph,
733/// and then updates the sorted expressions with these indices. It ensures
734/// that these sorted expressions are aligned with the structure of the graph.
735fn update_sorted_exprs_with_node_indices(
736 graph: &mut ExprIntervalGraph,
737 sorted_exprs: &mut [SortedFilterExpr],
738) {
739 // Extract filter expressions from the sorted expressions:
740 let filter_exprs = sorted_exprs
741 .iter()
742 .map(|expr| Arc::clone(expr.filter_expr()))
743 .collect::<Vec<_>>();
744
745 // Gather corresponding node indices for the extracted filter expressions from the graph:
746 let child_node_indices = graph.gather_node_indices(&filter_exprs);
747
748 // Iterate through the sorted expressions and the gathered node indices:
749 for (sorted_expr, (_, index)) in sorted_exprs.iter_mut().zip(child_node_indices) {
750 // Update each sorted expression with the corresponding node index:
751 sorted_expr.set_node_index(index);
752 }
753}
754
755/// Prepares and sorts expressions based on a given filter, left and right schemas,
756/// and sort expressions.
757///
758/// This function prepares sorted filter expressions for both the left and right
759/// sides of a join operation. It first builds the filter order for each side
760/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter
761/// expressions, the function then constructs an expression interval graph and
762/// updates the sorted expressions with node indices. The final sorted filter
763/// expressions for both sides are then returned.
764///
765/// # Parameters
766///
767/// * `filter` - The join filter to base the sorting on.
768/// * `left` - The `ExecutionPlan` for the left side of the join.
769/// * `right` - The `ExecutionPlan` for the right side of the join.
770/// * `left_sort_exprs` - The expressions to sort on the left side.
771/// * `right_sort_exprs` - The expressions to sort on the right side.
772///
773/// # Returns
774///
775/// * A tuple consisting of the sorted filter expression for the left and right sides, and an expression interval graph.
776pub fn prepare_sorted_exprs(
777 filter: &JoinFilter,
778 left: &Arc<dyn ExecutionPlan>,
779 right: &Arc<dyn ExecutionPlan>,
780 left_sort_exprs: &LexOrdering,
781 right_sort_exprs: &LexOrdering,
782) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
783 let err = || {
784 datafusion_common::plan_datafusion_err!("Filter does not include the child order")
785 };
786
787 // Build the filter order for the left side:
788 let left_temp_sorted_filter_expr = build_filter_input_order(
789 JoinSide::Left,
790 filter,
791 &left.schema(),
792 &left_sort_exprs[0],
793 )?
794 .ok_or_else(err)?;
795
796 // Build the filter order for the right side:
797 let right_temp_sorted_filter_expr = build_filter_input_order(
798 JoinSide::Right,
799 filter,
800 &right.schema(),
801 &right_sort_exprs[0],
802 )?
803 .ok_or_else(err)?;
804
805 // Collect the sorted expressions
806 let mut sorted_exprs =
807 vec![left_temp_sorted_filter_expr, right_temp_sorted_filter_expr];
808
809 // Build the expression interval graph
810 let mut graph =
811 ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?;
812
813 // Update sorted expressions with node indices
814 update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs);
815
816 // Swap and remove to get the final sorted filter expressions
817 let right_sorted_filter_expr = sorted_exprs.swap_remove(1);
818 let left_sorted_filter_expr = sorted_exprs.swap_remove(0);
819
820 Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
821}
822
823#[cfg(test)]
824pub mod tests {
825
826 use super::*;
827 use crate::{joins::test_utils::complicated_filter, joins::utils::ColumnIndex};
828
829 use arrow::compute::SortOptions;
830 use arrow::datatypes::{DataType, Field};
831 use datafusion_expr::Operator;
832 use datafusion_physical_expr::expressions::{binary, cast, col};
833
834 #[test]
835 fn test_column_exchange() -> Result<()> {
836 let left_child_schema =
837 Schema::new(vec![Field::new("left_1", DataType::Int32, true)]);
838 // Sorting information for the left side:
839 let left_child_sort_expr = PhysicalSortExpr {
840 expr: col("left_1", &left_child_schema)?,
841 options: SortOptions::default(),
842 };
843
844 let right_child_schema = Schema::new(vec![
845 Field::new("right_1", DataType::Int32, true),
846 Field::new("right_2", DataType::Int32, true),
847 ]);
848 // Sorting information for the right side:
849 let right_child_sort_expr = PhysicalSortExpr {
850 expr: binary(
851 col("right_1", &right_child_schema)?,
852 Operator::Plus,
853 col("right_2", &right_child_schema)?,
854 &right_child_schema,
855 )?,
856 options: SortOptions::default(),
857 };
858
859 let intermediate_schema = Schema::new(vec![
860 Field::new("filter_1", DataType::Int32, true),
861 Field::new("filter_2", DataType::Int32, true),
862 Field::new("filter_3", DataType::Int32, true),
863 ]);
864 // Our filter expression is: left_1 > right_1 + right_2.
865 let filter_left = col("filter_1", &intermediate_schema)?;
866 let filter_right = binary(
867 col("filter_2", &intermediate_schema)?,
868 Operator::Plus,
869 col("filter_3", &intermediate_schema)?,
870 &intermediate_schema,
871 )?;
872 let filter_expr = binary(
873 Arc::clone(&filter_left),
874 Operator::Gt,
875 Arc::clone(&filter_right),
876 &intermediate_schema,
877 )?;
878 let column_indices = vec![
879 ColumnIndex {
880 index: 0,
881 side: JoinSide::Left,
882 },
883 ColumnIndex {
884 index: 0,
885 side: JoinSide::Right,
886 },
887 ColumnIndex {
888 index: 1,
889 side: JoinSide::Right,
890 },
891 ];
892 let filter =
893 JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
894
895 let left_sort_filter_expr = build_filter_input_order(
896 JoinSide::Left,
897 &filter,
898 &Arc::new(left_child_schema),
899 &left_child_sort_expr,
900 )?
901 .unwrap();
902 assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr()));
903
904 let right_sort_filter_expr = build_filter_input_order(
905 JoinSide::Right,
906 &filter,
907 &Arc::new(right_child_schema),
908 &right_child_sort_expr,
909 )?
910 .unwrap();
911 assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr()));
912
913 // Assert that adjusted (left) filter expression matches with `left_child_sort_expr`:
914 assert!(filter_left.eq(left_sort_filter_expr.filter_expr()));
915 // Assert that adjusted (right) filter expression matches with `right_child_sort_expr`:
916 assert!(filter_right.eq(right_sort_filter_expr.filter_expr()));
917 Ok(())
918 }
919
920 #[test]
921 fn test_column_collector() -> Result<()> {
922 let schema = Schema::new(vec![
923 Field::new("0", DataType::Int32, true),
924 Field::new("1", DataType::Int32, true),
925 Field::new("2", DataType::Int32, true),
926 ]);
927 let filter_expr = complicated_filter(&schema)?;
928 let columns = collect_columns(&filter_expr);
929 assert_eq!(columns.len(), 3);
930 Ok(())
931 }
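
    // A small illustrative sketch of how `record_visited_indices`,
    // `get_pruning_semi_indices` and `get_pruning_anti_indices` fit together;
    // the row values and offsets used here are arbitrary.
    #[test]
    fn test_visited_and_pruning_indices() {
        use arrow::array::UInt64Array;
        use arrow::datatypes::UInt32Type;

        // Matched indices 1 and 3, recorded relative to an absolute offset of 2,
        // mark absolute rows 3 and 5 as visited.
        let mut visited = HashSet::new();
        record_visited_indices(&mut visited, 2, &UInt64Array::from(vec![1u64, 3]));
        assert!(visited.contains(&3) && visited.contains(&5));

        // Prune the first 5 rows (absolute rows 2..7): the visited rows 3 and 5
        // correspond to local indices 1 and 3 within the pruned slice.
        let semi: PrimitiveArray<UInt32Type> = get_pruning_semi_indices(5, 2, &visited);
        let anti: PrimitiveArray<UInt32Type> = get_pruning_anti_indices(5, 2, &visited);
        assert_eq!(semi.values().to_vec(), vec![1, 3]);
        assert_eq!(anti.values().to_vec(), vec![0, 2, 4]);
    }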
932
933 #[test]
934 fn find_expr_inside_expr() -> Result<()> {
935 let schema = Schema::new(vec![
936 Field::new("0", DataType::Int32, true),
937 Field::new("1", DataType::Int32, true),
938 Field::new("2", DataType::Int32, true),
939 ]);
940 let filter_expr = complicated_filter(&schema)?;
941
942 let expr_1 = Arc::new(Column::new("gnz", 0)) as _;
943 assert!(!check_filter_expr_contains_sort_information(
944 &filter_expr,
945 &expr_1
946 ));
947
948 let expr_2 = col("1", &schema)? as _;
949
950 assert!(check_filter_expr_contains_sort_information(
951 &filter_expr,
952 &expr_2
953 ));
954
955 let expr_3 = cast(
956 binary(
957 col("0", &schema)?,
958 Operator::Plus,
959 col("1", &schema)?,
960 &schema,
961 )?,
962 &schema,
963 DataType::Int64,
964 )?;
965
966 assert!(check_filter_expr_contains_sort_information(
967 &filter_expr,
968 &expr_3
969 ));
970
971 let expr_4 = Arc::new(Column::new("1", 42)) as _;
972
973 assert!(!check_filter_expr_contains_sort_information(
974 &filter_expr,
975 &expr_4,
976 ));
977 Ok(())
978 }
979
980 #[test]
981 fn build_sorted_expr() -> Result<()> {
982 let left_schema = Schema::new(vec![
983 Field::new("la1", DataType::Int32, false),
984 Field::new("lb1", DataType::Int32, false),
985 Field::new("lc1", DataType::Int32, false),
986 Field::new("lt1", DataType::Int32, false),
987 Field::new("la2", DataType::Int32, false),
988 Field::new("la1_des", DataType::Int32, false),
989 ]);
990
991 let right_schema = Schema::new(vec![
992 Field::new("ra1", DataType::Int32, false),
993 Field::new("rb1", DataType::Int32, false),
994 Field::new("rc1", DataType::Int32, false),
995 Field::new("rt1", DataType::Int32, false),
996 Field::new("ra2", DataType::Int32, false),
997 Field::new("ra1_des", DataType::Int32, false),
998 ]);
999
1000 let intermediate_schema = Schema::new(vec![
1001 Field::new("0", DataType::Int32, true),
1002 Field::new("1", DataType::Int32, true),
1003 Field::new("2", DataType::Int32, true),
1004 ]);
1005 let filter_expr = complicated_filter(&intermediate_schema)?;
1006 let column_indices = vec![
1007 ColumnIndex {
1008 index: left_schema.index_of("la1")?,
1009 side: JoinSide::Left,
1010 },
1011 ColumnIndex {
1012 index: left_schema.index_of("la2")?,
1013 side: JoinSide::Left,
1014 },
1015 ColumnIndex {
1016 index: right_schema.index_of("ra1")?,
1017 side: JoinSide::Right,
1018 },
1019 ];
1020 let filter =
1021 JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1022
1023 let left_schema = Arc::new(left_schema);
1024 let right_schema = Arc::new(right_schema);
1025
1026 assert!(build_filter_input_order(
1027 JoinSide::Left,
1028 &filter,
1029 &left_schema,
1030 &PhysicalSortExpr {
1031 expr: col("la1", left_schema.as_ref())?,
1032 options: SortOptions::default(),
1033 }
1034 )?
1035 .is_some());
1036 assert!(build_filter_input_order(
1037 JoinSide::Left,
1038 &filter,
1039 &left_schema,
1040 &PhysicalSortExpr {
1041 expr: col("lt1", left_schema.as_ref())?,
1042 options: SortOptions::default(),
1043 }
1044 )?
1045 .is_none());
1046 assert!(build_filter_input_order(
1047 JoinSide::Right,
1048 &filter,
1049 &right_schema,
1050 &PhysicalSortExpr {
1051 expr: col("ra1", right_schema.as_ref())?,
1052 options: SortOptions::default(),
1053 }
1054 )?
1055 .is_some());
1056 assert!(build_filter_input_order(
1057 JoinSide::Right,
1058 &filter,
1059 &right_schema,
1060 &PhysicalSortExpr {
1061 expr: col("rb1", right_schema.as_ref())?,
1062 options: SortOptions::default(),
1063 }
1064 )?
1065 .is_none());
1066
1067 Ok(())
1068 }
1069
1070    // Test the case where we have an "ORDER BY a + b" and the join filter condition includes "a - b".
1071 #[test]
1072 fn sorted_filter_expr_build() -> Result<()> {
1073 let intermediate_schema = Schema::new(vec![
1074 Field::new("0", DataType::Int32, true),
1075 Field::new("1", DataType::Int32, true),
1076 ]);
1077 let filter_expr = binary(
1078 col("0", &intermediate_schema)?,
1079 Operator::Minus,
1080 col("1", &intermediate_schema)?,
1081 &intermediate_schema,
1082 )?;
1083 let column_indices = vec![
1084 ColumnIndex {
1085 index: 0,
1086 side: JoinSide::Left,
1087 },
1088 ColumnIndex {
1089 index: 1,
1090 side: JoinSide::Left,
1091 },
1092 ];
1093 let filter =
1094 JoinFilter::new(filter_expr, column_indices, Arc::new(intermediate_schema));
1095
1096 let schema = Schema::new(vec![
1097 Field::new("a", DataType::Int32, false),
1098 Field::new("b", DataType::Int32, false),
1099 ]);
1100
1101 let sorted = PhysicalSortExpr {
1102 expr: binary(
1103 col("a", &schema)?,
1104 Operator::Plus,
1105 col("b", &schema)?,
1106 &schema,
1107 )?,
1108 options: SortOptions::default(),
1109 };
1110
1111 let res = convert_sort_expr_with_filter_schema(
1112 &JoinSide::Left,
1113 &filter,
1114 &Arc::new(schema),
1115 &sorted,
1116 )?;
1117 assert!(res.is_none());
1118 Ok(())
1119 }
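
    // A small sketch of how `update_filter_expr_interval` turns the first value of
    // an ascending sorted column into the one-sided interval `[first_value, ∞]`.
    // The column names ("a", "0") and data used here are arbitrary.
    #[test]
    fn test_update_filter_expr_interval() -> Result<()> {
        use arrow::array::{ArrayRef, Int32Array};

        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(Int32Array::from(vec![1, 3, 5, 7])) as ArrayRef],
        )?;

        // A single filter column standing in for "a" in the filter schema:
        let filter_schema = Schema::new(vec![Field::new("0", DataType::Int32, true)]);
        let mut sorted_expr = SortedFilterExpr::try_new(
            PhysicalSortExpr {
                expr: col("a", schema.as_ref())?,
                options: SortOptions::default(),
            },
            col("0", &filter_schema)?,
            &filter_schema,
        )?;

        // Use the first row of the ascending batch to bound the interval from below:
        update_filter_expr_interval(&batch.slice(0, 1), &mut sorted_expr)?;

        let expected =
            Interval::try_new(ScalarValue::Int32(Some(1)), ScalarValue::Int32(None))?;
        assert_eq!(sorted_expr.interval(), &expected);
        Ok(())
    }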
1120
1121 #[test]
1122 fn test_shrink_if_necessary() {
1123 let scale_factor = 4;
1124 let mut join_hash_map = PruningJoinHashMap::with_capacity(100);
1125 let data_size = 2000;
1126 let deleted_part = 3 * data_size / 4;
1127 // Add elements to the JoinHashMap
1128 for hash_value in 0..data_size {
1129 join_hash_map.map.insert_unique(
1130 hash_value,
1131 (hash_value, hash_value),
1132 |(hash, _)| *hash,
1133 );
1134 }
1135
1136 assert_eq!(join_hash_map.map.len(), data_size as usize);
1137 assert!(join_hash_map.map.capacity() >= data_size as usize);
1138
1139 // Remove some elements from the JoinHashMap
1140 for hash_value in 0..deleted_part {
1141 join_hash_map
1142 .map
1143 .find_entry(hash_value, |(hash, _)| hash_value == *hash)
1144 .unwrap()
1145 .remove();
1146 }
1147
1148 assert_eq!(join_hash_map.map.len(), (data_size - deleted_part) as usize);
1149
1150 // Old capacity
1151 let old_capacity = join_hash_map.map.capacity();
1152
1153 // Test shrink_if_necessary
1154 join_hash_map.shrink_if_necessary(scale_factor);
1155
1156 // The capacity should be reduced by the scale factor
1157 let new_expected_capacity =
1158 join_hash_map.map.capacity() * (scale_factor - 1) / scale_factor;
1159 assert!(join_hash_map.map.capacity() >= new_expected_capacity);
1160 assert!(join_hash_map.map.capacity() <= old_capacity);
1161 }
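
    // An illustrative test mirroring the `PruningJoinHashMap` documentation example
    // above (with a deleting offset of 1): hash value 10 chains three rows and hash
    // value 20 chains one row; pruning the first three rows removes the entry for
    // hash value 20 entirely.
    #[test]
    fn test_prune_hash_values() {
        let mut join_hash_map = PruningJoinHashMap::with_capacity(8);
        join_hash_map
            .map
            .insert_unique(10, (10, 5), |(hash, _)| *hash);
        join_hash_map
            .map
            .insert_unique(20, (20, 3), |(hash, _)| *hash);
        join_hash_map.next.extend([0u64, 0, 0, 2, 4]);

        // Prune 3 rows: every map entry whose tail index is smaller than
        // 3 (prune length) + 1 (deleting offset) is removed.
        join_hash_map.prune_hash_values(3, 1, 4);

        assert_eq!(join_hash_map.next.len(), 2);
        assert_eq!(join_hash_map.map.len(), 1);
        assert!(join_hash_map
            .map
            .find(10, |(hash, _)| *hash == 10)
            .is_some());
        assert!(join_hash_map
            .map
            .find(20, |(hash, _)| *hash == 20)
            .is_none());
    }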
1162}