datafusion_optimizer/
filter_null_join_keys.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! [`FilterNullJoinKeys`] adds `IS NOT NULL` filters to join inputs whose join keys are nullable
19
20use crate::optimizer::ApplyOrder;
21use crate::push_down_filter::on_lr_is_preserved;
22use crate::{OptimizerConfig, OptimizerRule};
23use datafusion_common::tree_node::Transformed;
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::utils::conjunction;
26use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
27use std::sync::Arc;
28
/// Optimizer rule that keeps null join keys out of equi-joins.
///
/// When an equi-join key is nullable, rows carrying a null in that key can
/// never match, so the rule inserts an `IsNotNull` filter on the nullable
/// side of the join to discard them before the join runs.
#[derive(Debug, Default)]
pub struct FilterNullJoinKeys {}
34
35impl OptimizerRule for FilterNullJoinKeys {
36    fn supports_rewrite(&self) -> bool {
37        true
38    }
39
40    fn apply_order(&self) -> Option<ApplyOrder> {
41        Some(ApplyOrder::BottomUp)
42    }
43
44    fn rewrite(
45        &self,
46        plan: LogicalPlan,
47        config: &dyn OptimizerConfig,
48    ) -> Result<Transformed<LogicalPlan>> {
49        if !config.options().optimizer.filter_null_join_keys {
50            return Ok(Transformed::no(plan));
51        }
52        match plan {
53            LogicalPlan::Join(mut join)
54                if !join.on.is_empty()
55                    && join.null_equality == NullEquality::NullEqualsNothing =>
56            {
57                let (left_preserved, right_preserved) =
58                    on_lr_is_preserved(join.join_type);
59
60                let left_schema = join.left.schema();
61                let right_schema = join.right.schema();
62
63                let mut left_filters = vec![];
64                let mut right_filters = vec![];
65
66                for (l, r) in &join.on {
67                    if left_preserved && l.nullable(left_schema)? {
68                        left_filters.push(l.clone());
69                    }
70
71                    if right_preserved && r.nullable(right_schema)? {
72                        right_filters.push(r.clone());
73                    }
74                }
75
76                if !left_filters.is_empty() {
77                    let predicate = create_not_null_predicate(left_filters);
78                    join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(
79                        predicate, join.left,
80                    )?));
81                }
82                if !right_filters.is_empty() {
83                    let predicate = create_not_null_predicate(right_filters);
84                    join.right = Arc::new(LogicalPlan::Filter(Filter::try_new(
85                        predicate, join.right,
86                    )?));
87                }
88                Ok(Transformed::yes(LogicalPlan::Join(join)))
89            }
90            _ => Ok(Transformed::no(plan)),
91        }
92    }
93    fn name(&self) -> &str {
94        "filter_null_join_keys"
95    }
96}
97
98fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
99    let not_null_exprs: Vec<Expr> = filters
100        .into_iter()
101        .map(|c| Expr::IsNotNull(Box::new(c)))
102        .collect();
103
104    // directly unwrap since it should always have a value
105    conjunction(not_null_exprs).unwrap()
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::OptimizerContext;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Column;
    use datafusion_expr::logical_plan::table_scan;
    use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};

    /// Optimizes `$plan` with only [`FilterNullJoinKeys`] registered and
    /// checks the result against the inline snapshot `$expected` by
    /// delegating to `assert_optimized_plan_eq_snapshot!`.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            // Limit the optimizer to a single pass for these tests.
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(FilterNullJoinKeys {})];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    /// Inner join with a nullable left key (`t1.optional_id`) and a
    /// non-nullable right key (`t2.id`): only the left input gets a filter.
    #[test]
    fn left_nullable() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id = t2.id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    /// Left join whose nullable key is on the left input: the expected plan
    /// contains no added filter.
    #[test]
    fn left_nullable_left_join() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?;

        assert_optimized_plan_equal!(plan, @r"
        Left Join: t1.optional_id = t2.id
          TableScan: t1
          TableScan: t2
        ")
    }

    /// Left join with the tables swapped so the nullable key
    /// (`t1.optional_id`) is on the right input: the right input is filtered.
    #[test]
    fn left_nullable_left_join_reordered() -> Result<()> {
        let (t_left, t_right) = test_tables()?;
        // Note: order of tables is reversed
        let plan =
            build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?;

        assert_optimized_plan_equal!(plan, @r"
        Left Join: t2.id = t1.optional_id
          TableScan: t2
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
        ")
    }

    /// Inner join where the keys are supplied right-table-first (`t2.id`,
    /// `t1.optional_id`) while `t1` stays the left input: the expected plan
    /// lists the condition as `t1.optional_id = t2.id` and filters `t1`.
    #[test]
    fn left_nullable_on_condition_reversed() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id = t2.id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    /// Joins `t3` against an existing `t1`/`t2` join on two key pairs. Both
    /// `t3` keys are nullable, so a single filter combining both
    /// `IS NOT NULL` checks is expected on `t3`, while the nested join keeps
    /// its own filter on `t1`.
    #[test]
    fn nested_join_multiple_filter_expr() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
        let schema = Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("t1_id", DataType::UInt32, true),
            Field::new("t2_id", DataType::UInt32, true),
        ]);
        let t3 = table_scan(Some("t3"), &schema, None)?.build()?;
        let plan = LogicalPlanBuilder::from(t3)
            .join(
                plan,
                JoinType::Inner,
                (
                    vec![
                        Column::from_qualified_name("t3.t1_id"),
                        Column::from_qualified_name("t3.t2_id"),
                    ],
                    vec![
                        Column::from_qualified_name("t1.id"),
                        Column::from_qualified_name("t2.id"),
                    ],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id
          Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL
            TableScan: t3
          Inner Join: t1.optional_id = t2.id
            Filter: t1.optional_id IS NOT NULL
              TableScan: t1
            TableScan: t2
        ")
    }

    /// Inner join on expression keys where only the left expression involves
    /// a nullable column: the whole left key expression is null-checked.
    #[test]
    fn left_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.optional_id") + lit(1u32)],
                    vec![col("t2.id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)
          Filter: t1.optional_id + UInt32(1) IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    /// Inner join on expression keys where only the right expression involves
    /// a nullable column: the whole right key expression is null-checked.
    #[test]
    fn right_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.id") + lit(1u32)],
                    vec![col("t2.optional_id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)
          TableScan: t1
          Filter: t2.optional_id + UInt32(1) IS NOT NULL
            TableScan: t2
        ")
    }

    /// Inner join on expression keys that are nullable on both sides: each
    /// input gets its own filter.
    #[test]
    fn both_side_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.optional_id") + lit(1u32)],
                    vec![col("t2.optional_id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)
          Filter: t1.optional_id + UInt32(1) IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id + UInt32(1) IS NOT NULL
            TableScan: t2
        ")
    }

    /// Builds the same join twice, once from expressions and once from column
    /// names, with the left key given unqualified (`optional_id`). Both plans
    /// are expected to optimize to the same output, with the left key shown
    /// as `t1.optional_id`.
    #[test]
    fn one_side_unqualified() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan_from_exprs = LogicalPlanBuilder::from(t1.clone())
            .join_with_expr_keys(
                t2.clone(),
                JoinType::Inner,
                (vec![col("optional_id")], vec![col("t2.optional_id")]),
                None,
            )?
            .build()?;
        let plan_from_cols = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Inner,
                (vec!["optional_id"], vec!["t2.optional_id"]),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan_from_cols, @r"
        Inner Join: t1.optional_id = t2.optional_id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id IS NOT NULL
            TableScan: t2
        ")?;

        assert_optimized_plan_equal!(plan_from_exprs, @r"
        Inner Join: t1.optional_id = t2.optional_id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id IS NOT NULL
            TableScan: t2
        ")
    }

    /// Builds a join of `left_table` and `right_table` of type `join_type` on
    /// a single key pair; `left_key` and `right_key` are parsed as qualified
    /// column names.
    fn build_plan(
        left_table: LogicalPlan,
        right_table: LogicalPlan,
        left_key: &str,
        right_key: &str,
        join_type: JoinType,
    ) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(left_table)
            .join(
                right_table,
                join_type,
                (
                    vec![Column::from_qualified_name(left_key)],
                    vec![Column::from_qualified_name(right_key)],
                ),
                // presumably the optional extra (non-equi) join filter; none is used here
                None,
            )?
            .build()
    }

    /// Returns table scans `t1` and `t2` over the same two-column schema:
    /// `id` (UInt32, non-nullable) and `optional_id` (UInt32, nullable).
    fn test_tables() -> Result<(LogicalPlan, LogicalPlan)> {
        let schema = Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("optional_id", DataType::UInt32, true),
        ]);
        let t1 = table_scan(Some("t1"), &schema, None)?.build()?;
        let t2 = table_scan(Some("t2"), &schema, None)?.build()?;
        Ok((t1, t2))
    }
}