1use crate::optimizer::ApplyOrder;
21use crate::push_down_filter::on_lr_is_preserved;
22use crate::{OptimizerConfig, OptimizerRule};
23use datafusion_common::tree_node::Transformed;
24use datafusion_common::{NullEquality, Result};
25use datafusion_expr::utils::conjunction;
26use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
27use std::sync::Arc;
28
/// The `filter_null_join_keys` optimizer rule.
///
/// When an equi-join compares keys with [`NullEquality::NullEqualsNothing`], a row whose
/// join key is `NULL` can never find a partner. The rule therefore places
/// `key IS NOT NULL` filters beneath the join inputs that `on_lr_is_preserved` allows
/// to be filtered, so those rows are dropped before the join rather than inside it.
#[derive(Debug, Default)]
pub struct FilterNullJoinKeys {}
34
35impl OptimizerRule for FilterNullJoinKeys {
36 fn supports_rewrite(&self) -> bool {
37 true
38 }
39
40 fn apply_order(&self) -> Option<ApplyOrder> {
41 Some(ApplyOrder::BottomUp)
42 }
43
44 fn rewrite(
45 &self,
46 plan: LogicalPlan,
47 config: &dyn OptimizerConfig,
48 ) -> Result<Transformed<LogicalPlan>> {
49 if !config.options().optimizer.filter_null_join_keys {
50 return Ok(Transformed::no(plan));
51 }
52 match plan {
53 LogicalPlan::Join(mut join)
54 if !join.on.is_empty()
55 && join.null_equality == NullEquality::NullEqualsNothing =>
56 {
57 let (left_preserved, right_preserved) =
58 on_lr_is_preserved(join.join_type);
59
60 let left_schema = join.left.schema();
61 let right_schema = join.right.schema();
62
63 let mut left_filters = vec![];
64 let mut right_filters = vec![];
65
66 for (l, r) in &join.on {
67 if left_preserved && l.nullable(left_schema)? {
68 left_filters.push(l.clone());
69 }
70
71 if right_preserved && r.nullable(right_schema)? {
72 right_filters.push(r.clone());
73 }
74 }
75
76 if !left_filters.is_empty() {
77 let predicate = create_not_null_predicate(left_filters);
78 join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(
79 predicate, join.left,
80 )?));
81 }
82 if !right_filters.is_empty() {
83 let predicate = create_not_null_predicate(right_filters);
84 join.right = Arc::new(LogicalPlan::Filter(Filter::try_new(
85 predicate, join.right,
86 )?));
87 }
88 Ok(Transformed::yes(LogicalPlan::Join(join)))
89 }
90 _ => Ok(Transformed::no(plan)),
91 }
92 }
93 fn name(&self) -> &str {
94 "filter_null_join_keys"
95 }
96}
97
98fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
99 let not_null_exprs: Vec<Expr> = filters
100 .into_iter()
101 .map(|c| Expr::IsNotNull(Box::new(c)))
102 .collect();
103
104 conjunction(not_null_exprs).unwrap()
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_optimized_plan_eq_snapshot;
    use crate::OptimizerContext;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Column;
    use datafusion_expr::logical_plan::table_scan;
    use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};

    /// Runs only `FilterNullJoinKeys` for a single pass and compares the result with an
    /// inline snapshot.
    macro_rules! assert_optimized_plan_equal {
        (
            $plan:expr,
            @ $expected:literal $(,)?
        ) => {{
            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
            let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(FilterNullJoinKeys {})];
            assert_optimized_plan_eq_snapshot!(
                optimizer_ctx,
                rules,
                $plan,
                @ $expected,
            )
        }};
    }

    #[test]
    fn left_nullable() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id = t2.id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    #[test]
    fn left_nullable_left_join() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?;

        assert_optimized_plan_equal!(plan, @r"
        Left Join: t1.optional_id = t2.id
          TableScan: t1
          TableScan: t2
        ")
    }

    #[test]
    fn left_nullable_left_join_reordered() -> Result<()> {
        let (t_left, t_right) = test_tables()?;
        let plan =
            build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?;

        assert_optimized_plan_equal!(plan, @r"
        Left Join: t2.id = t1.optional_id
          TableScan: t2
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
        ")
    }

    #[test]
    fn left_nullable_on_condition_reversed() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id = t2.id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    /// In a Right join only the left input may be filtered on its nullable key.
    #[test]
    fn right_join_filters_only_left_input() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(
            t1,
            t2,
            "t1.optional_id",
            "t2.optional_id",
            JoinType::Right,
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Right Join: t1.optional_id = t2.optional_id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    /// In a Full join neither input may be filtered, even with nullable keys.
    #[test]
    fn full_join_is_not_filtered() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(
            t1,
            t2,
            "t1.optional_id",
            "t2.optional_id",
            JoinType::Full,
        )?;

        assert_optimized_plan_equal!(plan, @r"
        Full Join: t1.optional_id = t2.optional_id
          TableScan: t1
          TableScan: t2
        ")
    }

    /// Non-nullable keys on both sides leave nothing to filter.
    #[test]
    fn non_nullable_keys_are_not_filtered() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.id", "t2.id", JoinType::Inner)?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.id = t2.id
          TableScan: t1
          TableScan: t2
        ")
    }

    #[test]
    fn nested_join_multiple_filter_expr() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
        let schema = Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("t1_id", DataType::UInt32, true),
            Field::new("t2_id", DataType::UInt32, true),
        ]);
        let t3 = table_scan(Some("t3"), &schema, None)?.build()?;
        let plan = LogicalPlanBuilder::from(t3)
            .join(
                plan,
                JoinType::Inner,
                (
                    vec![
                        Column::from_qualified_name("t3.t1_id"),
                        Column::from_qualified_name("t3.t2_id"),
                    ],
                    vec![
                        Column::from_qualified_name("t1.id"),
                        Column::from_qualified_name("t2.id"),
                    ],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id
          Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL
            TableScan: t3
          Inner Join: t1.optional_id = t2.id
            Filter: t1.optional_id IS NOT NULL
              TableScan: t1
            TableScan: t2
        ")
    }

    #[test]
    fn left_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.optional_id") + lit(1u32)],
                    vec![col("t2.id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)
          Filter: t1.optional_id + UInt32(1) IS NOT NULL
            TableScan: t1
          TableScan: t2
        ")
    }

    #[test]
    fn right_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.id") + lit(1u32)],
                    vec![col("t2.optional_id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)
          TableScan: t1
          Filter: t2.optional_id + UInt32(1) IS NOT NULL
            TableScan: t2
        ")
    }

    #[test]
    fn both_side_nullable_expr_key() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan = LogicalPlanBuilder::from(t1)
            .join_with_expr_keys(
                t2,
                JoinType::Inner,
                (
                    vec![col("t1.optional_id") + lit(1u32)],
                    vec![col("t2.optional_id") + lit(1u32)],
                ),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan, @r"
        Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)
          Filter: t1.optional_id + UInt32(1) IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id + UInt32(1) IS NOT NULL
            TableScan: t2
        ")
    }

    #[test]
    fn one_side_unqualified() -> Result<()> {
        let (t1, t2) = test_tables()?;
        let plan_from_exprs = LogicalPlanBuilder::from(t1.clone())
            .join_with_expr_keys(
                t2.clone(),
                JoinType::Inner,
                (vec![col("optional_id")], vec![col("t2.optional_id")]),
                None,
            )?
            .build()?;
        let plan_from_cols = LogicalPlanBuilder::from(t1)
            .join(
                t2,
                JoinType::Inner,
                (vec!["optional_id"], vec!["t2.optional_id"]),
                None,
            )?
            .build()?;

        assert_optimized_plan_equal!(plan_from_cols, @r"
        Inner Join: t1.optional_id = t2.optional_id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id IS NOT NULL
            TableScan: t2
        ")?;

        assert_optimized_plan_equal!(plan_from_exprs, @r"
        Inner Join: t1.optional_id = t2.optional_id
          Filter: t1.optional_id IS NOT NULL
            TableScan: t1
          Filter: t2.optional_id IS NOT NULL
            TableScan: t2
        ")
    }

    /// Builds `left_table JOIN right_table ON left_key = right_key` with the given join type.
    fn build_plan(
        left_table: LogicalPlan,
        right_table: LogicalPlan,
        left_key: &str,
        right_key: &str,
        join_type: JoinType,
    ) -> Result<LogicalPlan> {
        LogicalPlanBuilder::from(left_table)
            .join(
                right_table,
                join_type,
                (
                    vec![Column::from_qualified_name(left_key)],
                    vec![Column::from_qualified_name(right_key)],
                ),
                None,
            )?
            .build()
    }

    /// Two scans `t1` and `t2` sharing one schema: a non-nullable `id` and a nullable
    /// `optional_id`.
    fn test_tables() -> Result<(LogicalPlan, LogicalPlan)> {
        let schema = Schema::new(vec![
            Field::new("id", DataType::UInt32, false),
            Field::new("optional_id", DataType::UInt32, true),
        ]);
        let t1 = table_scan(Some("t1"), &schema, None)?.build()?;
        let t2 = table_scan(Some("t2"), &schema, None)?.build()?;
        Ok((t1, t2))
    }
}