datafusion_physical_expr/equivalence/properties/
joins.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::EquivalenceProperties;
19use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef};
20
21use arrow::datatypes::SchemaRef;
22use datafusion_common::{JoinSide, JoinType, Result};
23
24/// Calculate ordering equivalence properties for the given join operation.
25pub fn join_equivalence_properties(
26    left: EquivalenceProperties,
27    right: EquivalenceProperties,
28    join_type: &JoinType,
29    join_schema: SchemaRef,
30    maintains_input_order: &[bool],
31    probe_side: Option<JoinSide>,
32    on: &[(PhysicalExprRef, PhysicalExprRef)],
33) -> Result<EquivalenceProperties> {
34    let left_size = left.schema.fields.len();
35    let mut result = EquivalenceProperties::new(join_schema);
36    result.add_equivalence_group(left.eq_group().join(
37        right.eq_group(),
38        join_type,
39        left_size,
40        on,
41    )?)?;
42
43    let EquivalenceProperties {
44        oeq_class: left_oeq_class,
45        ..
46    } = left;
47    let EquivalenceProperties {
48        oeq_class: mut right_oeq_class,
49        ..
50    } = right;
51    match maintains_input_order {
52        [true, false] => {
53            // In this special case, right side ordering can be prefixed with
54            // the left side ordering.
55            if matches!(join_type, JoinType::Inner | JoinType::Left)
56                && probe_side == Some(JoinSide::Left)
57            {
58                updated_right_ordering_equivalence_class(
59                    &mut right_oeq_class,
60                    join_type,
61                    left_size,
62                )?;
63
64                // Right side ordering equivalence properties should be prepended
65                // with those of the left side while constructing output ordering
66                // equivalence properties since stream side is the left side.
67                //
68                // For example, if the right side ordering equivalences contain
69                // `b ASC`, and the left side ordering equivalences contain `a ASC`,
70                // then we should add `a ASC, b ASC` to the ordering equivalences
71                // of the join output.
72                let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class);
73                result.add_orderings(out_oeq_class);
74            } else {
75                result.add_orderings(left_oeq_class);
76            }
77        }
78        [false, true] => {
79            updated_right_ordering_equivalence_class(
80                &mut right_oeq_class,
81                join_type,
82                left_size,
83            )?;
84            // In this special case, left side ordering can be prefixed with
85            // the right side ordering.
86            if matches!(join_type, JoinType::Inner | JoinType::Right)
87                && probe_side == Some(JoinSide::Right)
88            {
89                // Left side ordering equivalence properties should be prepended
90                // with those of the right side while constructing output ordering
91                // equivalence properties since stream side is the right side.
92                //
93                // For example, if the left side ordering equivalences contain
94                // `a ASC`, and the right side ordering equivalences contain `b ASC`,
95                // then we should add `b ASC, a ASC` to the ordering equivalences
96                // of the join output.
97                let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class);
98                result.add_orderings(out_oeq_class);
99            } else {
100                result.add_orderings(right_oeq_class);
101            }
102        }
103        [false, false] => {}
104        [true, true] => unreachable!("Cannot maintain ordering of both sides"),
105        _ => unreachable!("Join operators can not have more than two children"),
106    }
107    Ok(result)
108}
109
110/// In the context of a join, update the right side `OrderingEquivalenceClass`
111/// so that they point to valid indices in the join output schema.
112///
113/// To do so, we increment column indices by the size of the left table when
114/// join schema consists of a combination of the left and right schemas. This
115/// is the case for `Inner`, `Left`, `Full` and `Right` joins. For other cases,
116/// indices do not change.
117pub fn updated_right_ordering_equivalence_class(
118    right_oeq_class: &mut OrderingEquivalenceClass,
119    join_type: &JoinType,
120    left_size: usize,
121) -> Result<()> {
122    if matches!(
123        join_type,
124        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right
125    ) {
126        right_oeq_class.add_offset(left_size as _)?;
127    }
128    Ok(())
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::equivalence::convert_to_orderings;
    use crate::equivalence::tests::create_test_schema;
    use crate::expressions::col;
    use crate::physical_expr::add_offset_to_expr;

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Fields, Schema};
    use datafusion_common::Result;

    /// Builds a schema of nullable `Int32` fields with the given names.
    fn int32_schema(names: &[&str]) -> Schema {
        let fields: Fields = names
            .iter()
            .map(|name| Field::new(*name, DataType::Int32, true))
            .collect();
        Schema::new(fields)
    }

    /// For an inner join that probes on the left and only keeps the left
    /// input order, every left ordering must be extended by every right
    /// ordering (with right-side columns offset past the left columns).
    #[test]
    fn test_join_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let a = &col("a", &schema)?;
        let b = &col("b", &schema)?;
        let c = &col("c", &schema)?;
        // Right-side columns are shifted by the number of left-side fields.
        let shift = schema.fields.len() as _;
        let a_right = &add_offset_to_expr(Arc::clone(a), shift)?;
        let b_right = &add_offset_to_expr(Arc::clone(b), shift)?;
        let asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // Each case is (left orderings, right orderings, expected orderings).
        let cases = vec![
            // Case 1: two orderings on each side.
            (
                // left: [a ASC], [b ASC]
                vec![vec![(a, asc)], vec![(b, asc)]],
                // right: [a ASC], [b ASC]
                vec![vec![(a, asc)], vec![(b, asc)]],
                // expected: every left ordering followed by every right one
                vec![
                    vec![(a, asc), (a_right, asc)],
                    vec![(a, asc), (b_right, asc)],
                    vec![(b, asc), (a_right, asc)],
                    vec![(b, asc), (b_right, asc)],
                ],
            ),
            // Case 2: three orderings on the left, two on the right.
            (
                // left: [a ASC], [b ASC], [c ASC]
                vec![vec![(a, asc)], vec![(b, asc)], vec![(c, asc)]],
                // right: [a ASC], [b ASC]
                vec![vec![(a, asc)], vec![(b, asc)]],
                // expected: every left ordering followed by every right one
                vec![
                    vec![(a, asc), (a_right, asc)],
                    vec![(a, asc), (b_right, asc)],
                    vec![(b, asc), (a_right, asc)],
                    vec![(b, asc), (b_right, asc)],
                    vec![(c, asc), (a_right, asc)],
                    vec![(c, asc), (b_right, asc)],
                ],
            ),
        ];

        for (left_case, right_case, expected_case) in cases {
            let mut left_props = EquivalenceProperties::new(Arc::clone(&schema));
            let mut right_props = EquivalenceProperties::new(Arc::clone(&schema));
            let left_case = convert_to_orderings(&left_case);
            let right_case = convert_to_orderings(&right_case);
            let expected = convert_to_orderings(&expected_case);
            left_props.add_orderings(left_case);
            right_props.add_orderings(right_case);

            let joined = join_equivalence_properties(
                left_props,
                right_props,
                &JoinType::Inner,
                Arc::new(Schema::empty()),
                &[true, false],
                Some(JoinSide::Left),
                &[],
            )?;
            let actual = joined.oeq_class;

            let err_msg = format!("expected: {:?}, actual:{:?}", expected, &actual);
            assert_eq!(actual.len(), expected.len(), "{err_msg}");
            // Order of the orderings is irrelevant; only membership matters.
            actual.into_iter().for_each(|ordering| {
                assert!(
                    expected.contains(&ordering),
                    "{err_msg}, ordering: {ordering:?}"
                );
            });
        }
        Ok(())
    }

    /// Offsetting the right child's orderings by the left column count must
    /// yield orderings expressed over the join schema's columns.
    #[test]
    fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> {
        let join_type = JoinType::Inner;
        let asc = SortOptions {
            descending: false,
            nulls_first: false,
        };

        // Right child of the join: columns x, y, z, w.
        let right_schema = int32_schema(&["x", "y", "z", "w"]);
        let right_x = &col("x", &right_schema)?;
        let right_y = &col("y", &right_schema)?;
        let right_z = &col("z", &right_schema)?;
        let right_w = &col("w", &right_schema)?;

        // Right child ordering equivalences: [x ASC, y ASC], [z ASC, w ASC]
        let right_orderings = vec![
            vec![(right_x, asc), (right_y, asc)],
            vec![(right_z, asc), (right_w, asc)],
        ];
        let mut right_oeq_class =
            OrderingEquivalenceClass::from(convert_to_orderings(&right_orderings));

        let left_columns_len = 4;

        // Join output: the four left columns followed by the right columns.
        let join_schema = int32_schema(&["a", "b", "c", "d", "x", "y", "z", "w"]);
        let out_a = col("a", &join_schema)?;
        let out_d = col("d", &join_schema)?;
        let out_x = col("x", &join_schema)?;
        let out_y = col("y", &join_schema)?;
        let out_z = col("z", &join_schema)?;
        let out_w = col("w", &join_schema)?;

        let mut join_props = EquivalenceProperties::new(Arc::new(join_schema));
        // a = x and d = w
        join_props.add_equal_conditions(out_a, Arc::clone(&out_x))?;
        join_props.add_equal_conditions(out_d, Arc::clone(&out_w))?;

        updated_right_ordering_equivalence_class(
            &mut right_oeq_class,
            &join_type,
            left_columns_len,
        )?;
        join_props.add_orderings(right_oeq_class);
        let result = join_props.oeq_class().clone();

        // Expected, over the join schema: [x ASC, y ASC], [z ASC, w ASC]
        let expected_orderings = vec![
            vec![(out_x, asc), (out_y, asc)],
            vec![(out_z, asc), (out_w, asc)],
        ];
        let expected =
            OrderingEquivalenceClass::from(convert_to_orderings(&expected_orderings));

        assert_eq!(result, expected);

        Ok(())
    }
}