1use datafusion_expr::Expr;
21use indexmap::{Equivalent, IndexSet};
22
23#[derive(Debug)]
37pub struct JoinKeySet {
38 inner: IndexSet<(Expr, Expr)>,
39}
40
41impl JoinKeySet {
42 pub fn new() -> Self {
44 Self {
45 inner: IndexSet::new(),
46 }
47 }
48
49 pub fn contains(&self, left: &Expr, right: &Expr) -> bool {
52 self.inner.contains(&ExprPair::new(left, right))
53 || self.inner.contains(&ExprPair::new(right, left))
54 }
55
56 pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool {
61 if self.contains(left, right) {
62 false
63 } else {
64 self.inner.insert((left.clone(), right.clone()));
65 true
66 }
67 }
68
69 pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool {
72 if self.contains(&left, &right) {
73 false
74 } else {
75 self.inner.insert((left, right));
76 true
77 }
78 }
79
80 pub fn insert_all<'a>(
84 &mut self,
85 iter: impl IntoIterator<Item = &'a (Expr, Expr)>,
86 ) -> bool {
87 let mut inserted = false;
88 for (left, right) in iter.into_iter() {
89 inserted |= self.insert(left, right);
90 }
91 inserted
92 }
93
94 pub fn insert_all_owned(
99 &mut self,
100 iter: impl IntoIterator<Item = (Expr, Expr)>,
101 ) -> bool {
102 let mut inserted = false;
103 for (left, right) in iter.into_iter() {
104 inserted |= self.insert_owned(left, right);
105 }
106 inserted
107 }
108
109 pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) {
111 for (left, right) in s1.inner.iter() {
114 if s2.contains(left, right) {
115 self.insert(left, right);
116 }
117 }
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.inner.is_empty()
123 }
124
125 #[cfg(test)]
127 pub fn len(&self) -> usize {
128 self.inner.len()
129 }
130
131 pub fn iter(&self) -> impl Iterator<Item = (&Expr, &Expr)> {
133 self.inner.iter().map(|(l, r)| (l, r))
134 }
135}
136
137#[derive(Debug, Eq, PartialEq, Hash)]
143struct ExprPair<'a>(&'a Expr, &'a Expr);
144
145impl<'a> ExprPair<'a> {
146 fn new(left: &'a Expr, right: &'a Expr) -> Self {
147 Self(left, right)
148 }
149}
150
151impl Equivalent<(Expr, Expr)> for ExprPair<'_> {
152 fn equivalent(&self, other: &(Expr, Expr)) -> bool {
153 self.0 == &other.0 && self.1 == &other.1
154 }
155}
156
#[cfg(test)]
mod test {
    use crate::join_key_set::JoinKeySet;
    use datafusion_expr::{col, Expr};

    #[test]
    fn test_insert() {
        let mut set = JoinKeySet::new();
        // a freshly created set is empty
        assert!(set.is_empty());

        // inserting (a = b) succeeds
        assert!(set.insert(&col("a"), &col("b")));
        assert!(!set.is_empty());

        // inserting (a = b) again is a no-op
        assert!(!set.insert(&col("a"), &col("b")));
        assert_eq!(set.len(), 1);

        // (b = a) mirrors (a = b), so it is also a no-op
        assert!(!set.insert(&col("b"), &col("a")));
        assert_eq!(set.len(), 1);

        // (a = c) is a new pair
        assert!(set.insert(&col("a"), &col("c")));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_insert_owned() {
        let mut set = JoinKeySet::new();
        assert!(set.insert_owned(col("a"), col("b")));
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("a")));
        assert!(!set.contains(&col("a"), &col("c")));

        // duplicates are rejected in either orientation
        assert!(!set.insert_owned(col("a"), col("b")));
        assert!(!set.insert_owned(col("b"), col("a")));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_contains() {
        let mut set = JoinKeySet::new();
        assert!(set.insert(&col("a"), &col("b")));
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("a")));
        assert!(!set.contains(&col("a"), &col("c")));

        assert!(set.insert(&col("a"), &col("c")));
        assert!(set.contains(&col("a"), &col("c")));
        assert!(set.contains(&col("c"), &col("a")));
    }

    #[test]
    fn test_iterator() {
        let mut set = JoinKeySet::new();
        // insert (c = a) and (b = c)
        set.insert(&col("c"), &col("a"));
        set.insert(&col("b"), &col("c"));
        // (a = c) mirrors (c = a) and is not inserted again
        set.insert(&col("a"), &col("c"));
        // iteration follows insertion order and keeps the first orientation
        assert_contents(&set, vec![(&col("c"), &col("a")), (&col("b"), &col("c"))]);
    }

    #[test]
    fn test_insert_intersection() {
        // a = b, b = c, c = d
        let mut set1 = JoinKeySet::new();
        set1.insert(&col("a"), &col("b"));
        set1.insert(&col("b"), &col("c"));
        set1.insert(&col("c"), &col("d"));

        // a = a, b = b, b = c, d = c
        // only (b = c) and (d = c) overlap with set1 (the latter as a mirror)
        let mut set2 = JoinKeySet::new();
        set2.insert(&col("a"), &col("a"));
        set2.insert(&col("b"), &col("b"));
        set2.insert(&col("b"), &col("c"));
        set2.insert(&col("d"), &col("c"));

        // pre-existing content must be preserved
        let mut set = JoinKeySet::new();
        set.insert(&col("x"), &col("y"));
        set.insert_intersection(&set1, &set2);

        // intersected pairs keep set1's orientation
        assert_contents(
            &set,
            vec![
                (&col("x"), &col("y")),
                (&col("b"), &col("c")),
                (&col("c"), &col("d")),
            ],
        );
    }

    fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) {
        let contents: Vec<_> = set.iter().collect();
        assert_eq!(contents, expected);
    }

    #[test]
    fn test_insert_all() {
        let mut set = JoinKeySet::new();

        // (b = a) duplicates (a = b), so only two pairs are stored, but the
        // call still reports that the set changed
        assert!(set.insert_all(vec![
            &(col("a"), col("b")),
            &(col("b"), col("c")),
            &(col("b"), col("a")),
        ]));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("c")));
        assert!(set.contains(&col("b"), &col("a")));

        // should not contain (a = c)
        assert!(!set.contains(&col("a"), &col("c")));

        // only already-present pairs (in any orientation) => no change
        assert!(!set.insert_all(vec![&(col("c"), col("b")), &(col("a"), col("b"))]));
        assert_eq!(set.len(), 2);

        // an empty input never changes the set
        assert!(!set.insert_all(Vec::new()));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_insert_all_owned() {
        let mut set = JoinKeySet::new();

        // (b = a) duplicates (a = b), so only two pairs are stored, but the
        // call still reports that the set changed
        assert!(set.insert_all_owned(vec![
            (col("a"), col("b")),
            (col("b"), col("c")),
            (col("b"), col("a")),
        ]));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("c")));
        assert!(set.contains(&col("b"), &col("a")));

        // should not contain (a = c)
        assert!(!set.contains(&col("a"), &col("c")));

        // only an already-present pair (mirrored) => no change
        assert!(!set.insert_all_owned(vec![(col("c"), col("b"))]));
        assert_eq!(set.len(), 2);

        // one new pair among duplicates is enough to report a change
        assert!(set.insert_all_owned(vec![(col("b"), col("a")), (col("a"), col("c"))]));
        assert_eq!(set.len(), 3);
        assert!(set.contains(&col("c"), &col("a")));

        // an empty input never changes the set
        assert!(!set.insert_all_owned(Vec::new()));
        assert_eq!(set.len(), 3);
    }
}