1use super::EquivalenceProperties;
19use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef};
20
21use arrow::datatypes::SchemaRef;
22use datafusion_common::{JoinSide, JoinType, Result};
23
24pub fn join_equivalence_properties(
26 left: EquivalenceProperties,
27 right: EquivalenceProperties,
28 join_type: &JoinType,
29 join_schema: SchemaRef,
30 maintains_input_order: &[bool],
31 probe_side: Option<JoinSide>,
32 on: &[(PhysicalExprRef, PhysicalExprRef)],
33) -> Result<EquivalenceProperties> {
34 let left_size = left.schema.fields.len();
35 let mut result = EquivalenceProperties::new(join_schema);
36 result.add_equivalence_group(left.eq_group().join(
37 right.eq_group(),
38 join_type,
39 left_size,
40 on,
41 )?)?;
42
43 let EquivalenceProperties {
44 oeq_class: left_oeq_class,
45 ..
46 } = left;
47 let EquivalenceProperties {
48 oeq_class: mut right_oeq_class,
49 ..
50 } = right;
51 match maintains_input_order {
52 [true, false] => {
53 if matches!(join_type, JoinType::Inner | JoinType::Left)
56 && probe_side == Some(JoinSide::Left)
57 {
58 updated_right_ordering_equivalence_class(
59 &mut right_oeq_class,
60 join_type,
61 left_size,
62 )?;
63
64 let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class);
73 result.add_orderings(out_oeq_class);
74 } else {
75 result.add_orderings(left_oeq_class);
76 }
77 }
78 [false, true] => {
79 updated_right_ordering_equivalence_class(
80 &mut right_oeq_class,
81 join_type,
82 left_size,
83 )?;
84 if matches!(join_type, JoinType::Inner | JoinType::Right)
87 && probe_side == Some(JoinSide::Right)
88 {
89 let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class);
98 result.add_orderings(out_oeq_class);
99 } else {
100 result.add_orderings(right_oeq_class);
101 }
102 }
103 [false, false] => {}
104 [true, true] => unreachable!("Cannot maintain ordering of both sides"),
105 _ => unreachable!("Join operators can not have more than two children"),
106 }
107 Ok(result)
108}
109
110pub fn updated_right_ordering_equivalence_class(
118 right_oeq_class: &mut OrderingEquivalenceClass,
119 join_type: &JoinType,
120 left_size: usize,
121) -> Result<()> {
122 if matches!(
123 join_type,
124 JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right
125 ) {
126 right_oeq_class.add_offset(left_size as _)?;
127 }
128 Ok(())
129}
130
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::equivalence::convert_to_orderings;
    use crate::equivalence::tests::create_test_schema;
    use crate::expressions::col;
    use crate::physical_expr::add_offset_to_expr;

    use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Fields, Schema};
    use datafusion_common::Result;

    /// Ascending order with nulls last, shared by every ordering in this module.
    const ASC: SortOptions = SortOptions {
        descending: false,
        nulls_first: false,
    };

    /// Builds nullable `Int32` fields with the given column names, in order.
    fn int32_fields(names: &[&str]) -> Fields {
        names
            .iter()
            .map(|name| Field::new(*name, DataType::Int32, true))
            .collect()
    }

    #[test]
    fn test_join_equivalence_properties() -> Result<()> {
        let schema = create_test_schema()?;
        let a = &col("a", &schema)?;
        let b = &col("b", &schema)?;
        let c = &col("c", &schema)?;
        // Right-side copies of `a` and `b`, shifted past every left column.
        let offset = schema.fields.len() as _;
        let right_a = &add_offset_to_expr(Arc::clone(a), offset)?;
        let right_b = &add_offset_to_expr(Arc::clone(b), offset)?;

        // Each case: (left orderings, right orderings, expected join orderings).
        let cases = [
            (
                vec![vec![(a, ASC)], vec![(b, ASC)]],
                vec![vec![(a, ASC)], vec![(b, ASC)]],
                vec![
                    vec![(a, ASC), (right_a, ASC)],
                    vec![(a, ASC), (right_b, ASC)],
                    vec![(b, ASC), (right_a, ASC)],
                    vec![(b, ASC), (right_b, ASC)],
                ],
            ),
            (
                vec![vec![(a, ASC)], vec![(b, ASC)], vec![(c, ASC)]],
                vec![vec![(a, ASC)], vec![(b, ASC)]],
                vec![
                    vec![(a, ASC), (right_a, ASC)],
                    vec![(a, ASC), (right_b, ASC)],
                    vec![(b, ASC), (right_a, ASC)],
                    vec![(b, ASC), (right_b, ASC)],
                    vec![(c, ASC), (right_a, ASC)],
                    vec![(c, ASC), (right_b, ASC)],
                ],
            ),
        ];

        for (left_orderings, right_orderings, expected) in cases {
            let mut left_props = EquivalenceProperties::new(Arc::clone(&schema));
            left_props.add_orderings(convert_to_orderings(&left_orderings));
            let mut right_props = EquivalenceProperties::new(Arc::clone(&schema));
            right_props.add_orderings(convert_to_orderings(&right_orderings));
            let expected = convert_to_orderings(&expected);

            let joined = join_equivalence_properties(
                left_props,
                right_props,
                &JoinType::Inner,
                Arc::new(Schema::empty()),
                &[true, false],
                Some(JoinSide::Left),
                &[],
            )?;

            let err_msg =
                format!("expected: {:?}, actual:{:?}", expected, &joined.oeq_class);
            assert_eq!(joined.oeq_class.len(), expected.len(), "{err_msg}");
            for ordering in joined.oeq_class {
                assert!(
                    expected.contains(&ordering),
                    "{err_msg}, ordering: {ordering:?}"
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_get_updated_right_ordering_equivalence_properties() -> Result<()> {
        // Orderings as reported by the right child, whose columns start at 0.
        let child_schema = Schema::new(int32_fields(&["x", "y", "z", "w"]));
        let child_x = &col("x", &child_schema)?;
        let child_y = &col("y", &child_schema)?;
        let child_z = &col("z", &child_schema)?;
        let child_w = &col("w", &child_schema)?;
        let child_orderings = vec![
            vec![(child_x, ASC), (child_y, ASC)],
            vec![(child_z, ASC), (child_w, ASC)],
        ];
        let mut right_oeq_class =
            OrderingEquivalenceClass::from(convert_to_orderings(&child_orderings));

        // The join output lists four left columns ahead of the right ones.
        let left_columns_len = 4;
        let schema =
            Schema::new(int32_fields(&["a", "b", "c", "d", "x", "y", "z", "w"]));
        let col_a = col("a", &schema)?;
        let col_d = col("d", &schema)?;
        let col_x = col("x", &schema)?;
        let col_y = col("y", &schema)?;
        let col_z = col("z", &schema)?;
        let col_w = col("w", &schema)?;

        let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema));
        // Equate a = x and d = w.
        join_eq_properties.add_equal_conditions(col_a, Arc::clone(&col_x))?;
        join_eq_properties.add_equal_conditions(col_d, Arc::clone(&col_w))?;

        updated_right_ordering_equivalence_class(
            &mut right_oeq_class,
            &JoinType::Inner,
            left_columns_len,
        )?;
        join_eq_properties.add_orderings(right_oeq_class);
        let actual = join_eq_properties.oeq_class().clone();

        let expected_orderings = vec![
            vec![(col_x, ASC), (col_y, ASC)],
            vec![(col_z, ASC), (col_w, ASC)],
        ];
        let expected =
            OrderingEquivalenceClass::from(convert_to_orderings(&expected_orderings));

        assert_eq!(actual, expected);
        Ok(())
    }
}