datafusion_physical_expr/equivalence/
mod.rs1use std::borrow::Borrow;
19use std::sync::Arc;
20
21use crate::PhysicalExpr;
22
23use arrow::compute::SortOptions;
24use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
25
26mod class;
27mod ordering;
28mod properties;
29
30pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup};
31pub use ordering::OrderingEquivalenceClass;
32pub use crate::projection::{project_ordering, project_orderings, ProjectionMapping};
35pub use properties::{
36 calculate_union, join_equivalence_properties, EquivalenceProperties,
37};
38
39pub fn convert_to_sort_exprs<T: Borrow<Arc<dyn PhysicalExpr>>>(
41 args: &[(T, SortOptions)],
42) -> Vec<PhysicalSortExpr> {
43 args.iter()
44 .map(|(expr, options)| PhysicalSortExpr::new(Arc::clone(expr.borrow()), *options))
45 .collect()
46}
47
48pub fn convert_to_orderings<T: Borrow<Arc<dyn PhysicalExpr>>>(
50 args: &[Vec<(T, SortOptions)>],
51) -> Vec<LexOrdering> {
52 args.iter()
53 .filter_map(|sort_exprs| LexOrdering::new(convert_to_sort_exprs(sort_exprs)))
54 .collect()
55}
56
57#[cfg(test)]
58mod tests {
59 use super::*;
60 use crate::expressions::{col, Column};
61 use crate::{LexRequirement, PhysicalSortExpr};
62
63 use arrow::compute::SortOptions;
64 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
65 use datafusion_common::Result;
66 use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;
67
68 pub fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr {
74 let mut parts = name.split_whitespace();
75 let name = parts.next().expect("empty sort expression");
76 let mut sort_expr = PhysicalSortExpr::new(
77 col(name, schema).expect("invalid column name"),
78 SortOptions::default(),
79 );
80
81 if let Some(options) = parts.next() {
82 sort_expr = match options {
83 "ASC" => sort_expr.asc(),
84 "DESC" => sort_expr.desc(),
85 _ => panic!(
86 "unknown sort options. Expected 'ASC' or 'DESC', got {options}"
87 ),
88 }
89 }
90
91 assert!(
92 parts.next().is_none(),
93 "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got '{name}'"
94 );
95
96 sort_expr
97 }
98
99 pub fn create_test_schema() -> Result<SchemaRef> {
101 let a = Field::new("a", DataType::Int32, true);
102 let b = Field::new("b", DataType::Int32, true);
103 let c = Field::new("c", DataType::Int32, true);
104 let d = Field::new("d", DataType::Int32, true);
105 let e = Field::new("e", DataType::Int32, true);
106 let f = Field::new("f", DataType::Int32, true);
107 let g = Field::new("g", DataType::Int32, true);
108 let h = Field::new("h", DataType::Int32, true);
109 let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h]));
110
111 Ok(schema)
112 }
113
114 pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> {
120 let test_schema = create_test_schema()?;
121 let col_a = col("a", &test_schema)?;
122 let col_b = col("b", &test_schema)?;
123 let col_c = col("c", &test_schema)?;
124 let col_d = col("d", &test_schema)?;
125 let col_e = col("e", &test_schema)?;
126 let col_f = col("f", &test_schema)?;
127 let col_g = col("g", &test_schema)?;
128 let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema));
129 eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_c))?;
130
131 let option_asc = SortOptions {
132 descending: false,
133 nulls_first: false,
134 };
135 let option_desc = SortOptions {
136 descending: true,
137 nulls_first: true,
138 };
139 let orderings = vec![
140 vec![(col_a, option_asc)],
142 vec![(col_d, option_asc), (col_b, option_asc)],
144 vec![
146 (col_e, option_desc),
147 (col_f, option_asc),
148 (col_g, option_asc),
149 ],
150 ];
151 let orderings = convert_to_orderings(&orderings);
152 eq_properties.add_orderings(orderings);
153 Ok((test_schema, eq_properties))
154 }
155
156 pub fn convert_to_sort_reqs(
159 args: &[(&Arc<dyn PhysicalExpr>, Option<SortOptions>)],
160 ) -> LexRequirement {
161 let exprs = args.iter().map(|(expr, options)| {
162 PhysicalSortRequirement::new(Arc::clone(*expr), *options)
163 });
164 LexRequirement::new(exprs).unwrap()
165 }
166
/// Checks how the equivalence group evolves as equalities are added:
/// redundant equalities change nothing, overlapping ones grow a class, and a
/// bridging equality merges two classes into one.
#[test]
fn add_equal_conditions_test() -> Result<()> {
    /// Asserts that `props` holds exactly one equivalence class and that this
    /// class consists of exactly `members.len()` entries including every
    /// expression in `members`.
    fn assert_single_class(
        props: &EquivalenceProperties,
        members: &[&Arc<dyn PhysicalExpr>],
    ) {
        assert_eq!(props.eq_group().len(), 1);
        let class = props.eq_group().iter().next().unwrap();
        assert_eq!(class.len(), members.len());
        for member in members {
            assert!(class.contains(*member));
        }
    }

    let fields = vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Int64, true),
        Field::new("x", DataType::Int64, true),
        Field::new("y", DataType::Int64, true),
    ];
    let mut eq_properties = EquivalenceProperties::new(Arc::new(Schema::new(fields)));

    let col_a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
    let col_b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
    let col_c: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c", 2));
    let col_x: Arc<dyn PhysicalExpr> = Arc::new(Column::new("x", 3));
    let col_y: Arc<dyn PhysicalExpr> = Arc::new(Column::new("y", 4));

    // a = b creates the first class.
    eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?;
    assert_eq!(eq_properties.eq_group().len(), 1);

    // b = a repeats the same fact, so nothing should grow.
    eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?;
    assert_single_class(&eq_properties, &[&col_a, &col_b]);

    // b = c extends the existing class instead of opening a new one.
    eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?;
    assert_single_class(&eq_properties, &[&col_a, &col_b, &col_c]);

    // x = y is unrelated to the first class, so a second class appears.
    eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?;
    assert_eq!(eq_properties.eq_group().len(), 2);

    // x = a bridges both classes, merging them into a single class of five.
    eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?;
    assert_single_class(
        &eq_properties,
        &[&col_a, &col_b, &col_c, &col_x, &col_y],
    );

    Ok(())
}
224}