datafusion_optimizer/
join_key_set.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`JoinKeySet`] for tracking the set of join keys in a plan.
19
20use datafusion_expr::Expr;
21use indexmap::{Equivalent, IndexSet};
22
23/// Tracks a set of equality Join keys
24///
25/// A join key is an expression that is used to join two tables via an equality
26/// predicate such as `a.x = b.y`
27///
28/// This struct models `a.x + 5 = b.y AND a.z = b.z` as two join keys
29/// 1. `(a.x + 5,  b.y)`
30/// 2. `(a.z,      b.z)`
31///
32/// # Important properties:
33///
34/// 1. Retains insert order
35/// 2. Can quickly look up if a pair of expressions are in the set.
36#[derive(Debug)]
37pub struct JoinKeySet {
38    inner: IndexSet<(Expr, Expr)>,
39}
40
41impl JoinKeySet {
42    /// Create a new empty set
43    pub fn new() -> Self {
44        Self {
45            inner: IndexSet::new(),
46        }
47    }
48
49    /// Return true if the set contains a join pair
50    /// where left = right or right = left
51    pub fn contains(&self, left: &Expr, right: &Expr) -> bool {
52        self.inner.contains(&ExprPair::new(left, right))
53            || self.inner.contains(&ExprPair::new(right, left))
54    }
55
56    /// Insert the join key `(left = right)` into the set  if join pair `(right =
57    /// left)` is not already in the set
58    ///
59    /// returns true if the pair was inserted
60    pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool {
61        if self.contains(left, right) {
62            false
63        } else {
64            self.inner.insert((left.clone(), right.clone()));
65            true
66        }
67    }
68
69    /// Same as [`Self::insert`] but avoids cloning expression if they
70    /// are owned
71    pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool {
72        if self.contains(&left, &right) {
73            false
74        } else {
75            self.inner.insert((left, right));
76            true
77        }
78    }
79
80    /// Inserts potentially many join keys into the set, copying only when necessary
81    ///
82    /// returns true if any of the pairs were inserted
83    pub fn insert_all<'a>(
84        &mut self,
85        iter: impl IntoIterator<Item = &'a (Expr, Expr)>,
86    ) -> bool {
87        let mut inserted = false;
88        for (left, right) in iter.into_iter() {
89            inserted |= self.insert(left, right);
90        }
91        inserted
92    }
93
94    /// Same as [`Self::insert_all`] but avoids cloning expressions if they are
95    /// already owned
96    ///
97    /// returns true if any of the pairs were inserted
98    pub fn insert_all_owned(
99        &mut self,
100        iter: impl IntoIterator<Item = (Expr, Expr)>,
101    ) -> bool {
102        let mut inserted = false;
103        for (left, right) in iter.into_iter() {
104            inserted |= self.insert_owned(left, right);
105        }
106        inserted
107    }
108
109    /// Inserts any join keys that are common to both `s1` and `s2` into self
110    pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) {
111        // note can't use inner.intersection as we need to consider both (l, r)
112        // and (r, l) in equality
113        for (left, right) in s1.inner.iter() {
114            if s2.contains(left, right) {
115                self.insert(left, right);
116            }
117        }
118    }
119
120    /// returns true if this set is empty
121    pub fn is_empty(&self) -> bool {
122        self.inner.is_empty()
123    }
124
125    /// Return the length of this set
126    #[cfg(test)]
127    pub fn len(&self) -> usize {
128        self.inner.len()
129    }
130
131    /// Return an iterator over the join keys in this set
132    pub fn iter(&self) -> impl Iterator<Item = (&Expr, &Expr)> {
133        self.inner.iter().map(|(l, r)| (l, r))
134    }
135}
136
137/// Custom comparison operation to avoid copying owned values
138///
139/// This behaves like a `(Expr, Expr)` tuple for hashing and  comparison, but
140/// avoids copying the values simply to comparing them.
141
142#[derive(Debug, Eq, PartialEq, Hash)]
143struct ExprPair<'a>(&'a Expr, &'a Expr);
144
145impl<'a> ExprPair<'a> {
146    fn new(left: &'a Expr, right: &'a Expr) -> Self {
147        Self(left, right)
148    }
149}
150
151impl Equivalent<(Expr, Expr)> for ExprPair<'_> {
152    fn equivalent(&self, other: &(Expr, Expr)) -> bool {
153        self.0 == &other.0 && self.1 == &other.1
154    }
155}
156
#[cfg(test)]
mod test {
    use crate::join_key_set::JoinKeySet;
    use datafusion_expr::{col, Expr};

    #[test]
    fn test_insert() {
        let mut set = JoinKeySet::new();
        // new sets should be empty
        assert!(set.is_empty());

        // insert (a = b)
        assert!(set.insert(&col("a"), &col("b")));
        assert!(!set.is_empty());

        // insert (a = b) again returns false
        assert!(!set.insert(&col("a"), &col("b")));
        assert_eq!(set.len(), 1);

        // insert (b = a), should be considered equivalent
        assert!(!set.insert(&col("b"), &col("a")));
        assert_eq!(set.len(), 1);

        // insert (a = c) should be considered different
        assert!(set.insert(&col("a"), &col("c")));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_insert_owned() {
        let mut set = JoinKeySet::new();
        assert!(set.insert_owned(col("a"), col("b")));
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("a")));
        assert!(!set.contains(&col("a"), &col("c")));

        // re-inserting the same key in either orientation is a no-op
        assert!(!set.insert_owned(col("a"), col("b")));
        assert!(!set.insert_owned(col("b"), col("a")));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_contains() {
        let mut set = JoinKeySet::new();
        assert!(set.insert(&col("a"), &col("b")));
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("a")));
        assert!(!set.contains(&col("a"), &col("c")));

        assert!(set.insert(&col("a"), &col("c")));
        assert!(set.contains(&col("a"), &col("c")));
        assert!(set.contains(&col("c"), &col("a")));
    }

    #[test]
    fn test_iterator() {
        let mut set = JoinKeySet::new();
        // put in c = a, b = c, and a = c and expect to get only the first 2:
        // a = c is the reverse of c = a, so it is not a new key
        set.insert(&col("c"), &col("a"));
        set.insert(&col("b"), &col("c"));
        set.insert(&col("a"), &col("c"));
        assert_contents(&set, vec![(&col("c"), &col("a")), (&col("b"), &col("c"))]);
    }

    #[test]
    fn test_insert_intersection() {
        // a = b, b = c, c = d
        let mut set1 = JoinKeySet::new();
        set1.insert(&col("a"), &col("b"));
        set1.insert(&col("b"), &col("c"));
        set1.insert(&col("c"), &col("d"));

        // a = a, b = b, b = c, d = c
        // should only intersect on b = c and c = d
        let mut set2 = JoinKeySet::new();
        set2.insert(&col("a"), &col("a"));
        set2.insert(&col("b"), &col("b"));
        set2.insert(&col("b"), &col("c"));
        set2.insert(&col("d"), &col("c"));

        let mut set = JoinKeySet::new();
        // put something in there already
        set.insert(&col("x"), &col("y"));
        set.insert_intersection(&set1, &set2);

        assert_contents(
            &set,
            vec![
                (&col("x"), &col("y")),
                (&col("b"), &col("c")),
                (&col("c"), &col("d")),
            ],
        );
    }

    #[test]
    fn test_insert_intersection_disjoint() {
        let mut set1 = JoinKeySet::new();
        set1.insert(&col("a"), &col("b"));

        let mut set2 = JoinKeySet::new();
        set2.insert(&col("a"), &col("c"));

        // no key is shared, so nothing is added
        let mut set = JoinKeySet::new();
        set.insert_intersection(&set1, &set2);
        assert!(set.is_empty());
    }

    /// Asserts that iterating `set` yields exactly `expected`, in order.
    fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) {
        let contents: Vec<_> = set.iter().collect();
        assert_eq!(contents, expected);
    }

    #[test]
    fn test_insert_all() {
        let mut set = JoinKeySet::new();

        // insert (a=b), (b=c), (b=a); at least one pair is new
        assert!(set.insert_all(vec![
            &(col("a"), col("b")),
            &(col("b"), col("c")),
            &(col("b"), col("a")),
        ]));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("c")));
        assert!(set.contains(&col("b"), &col("a")));

        // should not contain (a=c)
        assert!(!set.contains(&col("a"), &col("c")));

        // inserting only keys that are already present reports no change
        assert!(!set.insert_all(vec![&(col("c"), col("b"))]));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_insert_all_owned() {
        let mut set = JoinKeySet::new();

        // insert (a=b), (b=c), (b=a); at least one pair is new
        assert!(set.insert_all_owned(vec![
            (col("a"), col("b")),
            (col("b"), col("c")),
            (col("b"), col("a")),
        ]));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&col("a"), &col("b")));
        assert!(set.contains(&col("b"), &col("c")));
        assert!(set.contains(&col("b"), &col("a")));

        // should not contain (a=c)
        assert!(!set.contains(&col("a"), &col("c")));

        // inserting only keys that are already present reports no change
        assert!(!set.insert_all_owned(vec![(col("c"), col("b"))]));
        assert_eq!(set.len(), 2);
    }
}