datafusion_optimizer/simplify_expressions/
regex.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion_common::{DataFusionError, Result, ScalarValue};
19use datafusion_expr::{lit, BinaryExpr, Expr, Like, Operator};
20use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look};
21
/// Upper bound on the number of top-level regex alternatives (`foo|bar|...`)
/// that will be expanded into a chain of simpler expressions. Patterns with
/// more alternatives than this are left as a regex match.
const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4;

/// The `.*` regex pattern, which is special-cased by [`simplify_regex_expr`]
/// before the pattern is parsed.
const ANY_CHAR_REGEX_PATTERN: &str = ".*";
26
27/// Tries to convert a regexp expression to a `LIKE` or `Eq`/`NotEq` expression.
28///
29/// This function also validates the regex pattern. And will return error if the
30/// pattern is invalid.
31///
32/// Typical cases this function can simplify:
33/// - empty regex pattern to `LIKE '%'`
34/// - literal regex patterns to `LIKE '%foo%'`
35/// - full anchored regex patterns (e.g. `^foo$`) to `= 'foo'`
36/// - partial anchored regex patterns (e.g. `^foo`) to `LIKE 'foo%'`
37/// - combinations (alternatives) of the above, will be concatenated with `OR` or `AND`
38/// - `EQ .*` to NotNull
39/// - `NE .*` means IS EMPTY
40///
41/// Dev note: unit tests of this function are in `expr_simplifier.rs`, case `test_simplify_regex`.
42pub fn simplify_regex_expr(
43    left: Box<Expr>,
44    op: Operator,
45    right: Box<Expr>,
46) -> Result<Expr> {
47    let mode = OperatorMode::new(&op);
48
49    if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() {
50        // Handle the special case for ".*" pattern
51        if pattern == ANY_CHAR_REGEX_PATTERN {
52            let new_expr = if mode.not {
53                // not empty
54                let empty_lit = Box::new(lit(""));
55                Expr::BinaryExpr(BinaryExpr {
56                    left,
57                    op: Operator::Eq,
58                    right: empty_lit,
59                })
60            } else {
61                // not null
62                left.is_not_null()
63            };
64            return Ok(new_expr);
65        }
66
67        match regex_syntax::Parser::new().parse(pattern) {
68            Ok(hir) => {
69                let kind = hir.kind();
70                if let HirKind::Alternation(alts) = kind {
71                    if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION {
72                        if let Some(expr) = lower_alt(&mode, &left, alts) {
73                            return Ok(expr);
74                        }
75                    }
76                } else if let Some(expr) = lower_simple(&mode, &left, &hir) {
77                    return Ok(expr);
78                }
79            }
80            Err(e) => {
81                // error out early since the execution may fail anyways
82                return Err(DataFusionError::Context(
83                    "Invalid regex".to_owned(),
84                    Box::new(DataFusionError::External(Box::new(e))),
85                ));
86            }
87        }
88    }
89
90    // Leave untouched if optimization didn't work
91    Ok(Expr::BinaryExpr(BinaryExpr { left, op, right }))
92}
93
94#[derive(Debug)]
95struct OperatorMode {
96    /// Negative match.
97    not: bool,
98    /// Ignore case (`true` for case-insensitive).
99    i: bool,
100}
101
102impl OperatorMode {
103    fn new(op: &Operator) -> Self {
104        let not = match op {
105            Operator::RegexMatch | Operator::RegexIMatch => false,
106            Operator::RegexNotMatch | Operator::RegexNotIMatch => true,
107            _ => unreachable!(),
108        };
109
110        let i = match op {
111            Operator::RegexMatch | Operator::RegexNotMatch => false,
112            Operator::RegexIMatch | Operator::RegexNotIMatch => true,
113            _ => unreachable!(),
114        };
115
116        Self { not, i }
117    }
118
119    /// Creates an [`LIKE`](Expr::Like) from the given `LIKE` pattern.
120    fn expr(&self, expr: Box<Expr>, pattern: String) -> Expr {
121        let like = Like {
122            negated: self.not,
123            expr,
124            pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)),
125            escape_char: None,
126            case_insensitive: self.i,
127        };
128
129        Expr::Like(like)
130    }
131
132    /// Creates an [`Expr::BinaryExpr`] of "`left` = `right`" or "`left` != `right`".
133    fn expr_matches_literal(&self, left: Box<Expr>, right: Box<Expr>) -> Expr {
134        let op = if self.not {
135            Operator::NotEq
136        } else {
137            Operator::Eq
138        };
139        Expr::BinaryExpr(BinaryExpr { left, op, right })
140    }
141}
142
143fn collect_concat_to_like_string(parts: &[Hir]) -> Option<String> {
144    let mut s = String::with_capacity(parts.len() + 2);
145    s.push('%');
146
147    for sub in parts {
148        if let HirKind::Literal(l) = sub.kind() {
149            s.push_str(like_str_from_literal(l)?);
150        } else {
151            return None;
152        }
153    }
154
155    s.push('%');
156    Some(s)
157}
158
159/// Returns a str represented by `Literal` if it contains a valid utf8
160/// sequence and is safe for like (has no '%' and '_')
161fn like_str_from_literal(l: &Literal) -> Option<&str> {
162    // if not utf8, no good
163    let s = std::str::from_utf8(&l.0).ok()?;
164
165    if s.chars().all(is_safe_for_like) {
166        Some(s)
167    } else {
168        None
169    }
170}
171
172/// Returns a str represented by `Literal` if it contains a valid utf8
173fn str_from_literal(l: &Literal) -> Option<&str> {
174    // if not utf8, no good
175    let s = std::str::from_utf8(&l.0).ok()?;
176
177    Some(s)
178}
179
/// Returns `true` if `c` can be embedded verbatim in a generated `LIKE`
/// pattern without changing its meaning.
///
/// Rejected characters:
/// - `%` and `_`, the `LIKE` wildcards;
/// - `\`, because the generated patterns set no explicit escape character and
///   `LIKE` implementations commonly treat backslash as the default escape.
///   A literal ending in `\` followed by an appended `%` would otherwise turn
///   that wildcard into a literal percent sign.
fn is_safe_for_like(c: char) -> bool {
    !matches!(c, '%' | '_' | '\\')
}
183
184/// Returns true if the elements in a `Concat` pattern are:
185/// - `[Look::Start, Look::End]`
186/// - `[Look::Start, Literal(_), Look::End]`
187fn is_anchored_literal(v: &[Hir]) -> bool {
188    match v.len() {
189        2..=3 => (),
190        _ => return false,
191    };
192
193    let first_last = (
194        v.first().expect("length checked"),
195        v.last().expect("length checked"),
196    );
197    if !matches!(first_last,
198        (s, e) if s.kind() == &HirKind::Look(Look::Start)
199        && e.kind() == &HirKind::Look(Look::End)
200    ) {
201        return false;
202    }
203
204    v.iter()
205        .skip(1)
206        .take(v.len() - 2)
207        .all(|h| matches!(h.kind(), HirKind::Literal(_)))
208}
209
210/// Returns true if the elements in a `Concat` pattern are:
211/// - `[Look::Start, Capture(Alternation(Literals...)), Look::End]`
212fn is_anchored_capture(v: &[Hir]) -> bool {
213    if v.len() != 3
214        || !matches!(
215            (v.first().unwrap().kind(), v.last().unwrap().kind()),
216            (&HirKind::Look(Look::Start), &HirKind::Look(Look::End))
217        )
218    {
219        return false;
220    }
221
222    if let HirKind::Capture(cap, ..) = v[1].kind() {
223        let Capture { sub, .. } = cap;
224        if let HirKind::Alternation(alters) = sub.kind() {
225            let has_non_literal = alters
226                .iter()
227                .any(|v| !matches!(v.kind(), &HirKind::Literal(_)));
228            if has_non_literal {
229                return false;
230            }
231        }
232    }
233
234    true
235}
236
237/// Returns the `LIKE` pattern if the `Concat` pattern is partial anchored:
238/// - `[Look::Start, Literal(_)]`
239/// - `[Literal(_), Look::End]`
240///
241/// Full anchored patterns are handled by [`anchored_literal_to_expr`].
242fn partial_anchored_literal_to_like(v: &[Hir]) -> Option<String> {
243    if v.len() != 2 {
244        return None;
245    }
246
247    let (lit, match_begin) = match (&v[0].kind(), &v[1].kind()) {
248        (HirKind::Look(Look::Start), HirKind::Literal(l)) => {
249            (like_str_from_literal(l)?, true)
250        }
251        (HirKind::Literal(l), HirKind::Look(Look::End)) => {
252            (like_str_from_literal(l)?, false)
253        }
254        _ => return None,
255    };
256
257    if match_begin {
258        Some(format!("{lit}%"))
259    } else {
260        Some(format!("%{lit}"))
261    }
262}
263
264/// Extracts a string literal expression assuming that [`is_anchored_literal`]
265/// returned true.
266fn anchored_literal_to_expr(v: &[Hir]) -> Option<Expr> {
267    match v.len() {
268        2 => Some(lit("")),
269        3 => {
270            let HirKind::Literal(l) = v[1].kind() else {
271                return None;
272            };
273            like_str_from_literal(l).map(lit)
274        }
275        _ => None,
276    }
277}
278
279fn anchored_alternation_to_exprs(v: &[Hir]) -> Option<Vec<Expr>> {
280    if 3 != v.len() {
281        return None;
282    }
283
284    if let HirKind::Capture(cap, ..) = v[1].kind() {
285        let Capture { sub, .. } = cap;
286        if let HirKind::Alternation(alters) = sub.kind() {
287            let mut literals = Vec::with_capacity(alters.len());
288            for hir in alters {
289                let mut is_safe = false;
290                if let HirKind::Literal(l) = hir.kind() {
291                    if let Some(safe_literal) = str_from_literal(l).map(lit) {
292                        literals.push(safe_literal);
293                        is_safe = true;
294                    }
295                }
296
297                if !is_safe {
298                    return None;
299                }
300            }
301
302            return Some(literals);
303        } else if let HirKind::Literal(l) = sub.kind() {
304            if let Some(safe_literal) = str_from_literal(l).map(lit) {
305                return Some(vec![safe_literal]);
306            }
307            return None;
308        }
309    }
310    None
311}
312
313/// Tries to lower (transform) a simple regex pattern to a LIKE expression.
314fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option<Expr> {
315    match hir.kind() {
316        HirKind::Empty => {
317            return Some(mode.expr(Box::new(left.clone()), "%".to_owned()));
318        }
319        HirKind::Literal(l) => {
320            let s = like_str_from_literal(l)?;
321            return Some(mode.expr(Box::new(left.clone()), format!("%{s}%")));
322        }
323        HirKind::Concat(inner) if is_anchored_literal(inner) => {
324            return anchored_literal_to_expr(inner).map(|right| {
325                mode.expr_matches_literal(Box::new(left.clone()), Box::new(right))
326            });
327        }
328        HirKind::Concat(inner) if is_anchored_capture(inner) => {
329            return anchored_alternation_to_exprs(inner)
330                .map(|right| left.clone().in_list(right, mode.not));
331        }
332        HirKind::Concat(inner) => {
333            if let Some(pattern) = partial_anchored_literal_to_like(inner)
334                .or_else(|| collect_concat_to_like_string(inner))
335            {
336                return Some(mode.expr(Box::new(left.clone()), pattern));
337            }
338        }
339        _ => {}
340    }
341    None
342}
343
344/// Calls [`lower_simple`] for each alternative and combine the results with `or` or `and`
345/// based on [`OperatorMode`]. Any fail attempt to lower an alternative will makes this
346/// function to return `None`.
347fn lower_alt(mode: &OperatorMode, left: &Expr, alts: &[Hir]) -> Option<Expr> {
348    let mut accu: Option<Expr> = None;
349
350    for part in alts {
351        if let Some(expr) = lower_simple(mode, left, part) {
352            accu = match accu {
353                Some(accu) => {
354                    if mode.not {
355                        Some(accu.and(expr))
356                    } else {
357                        Some(accu.or(expr))
358                    }
359                }
360                None => Some(expr),
361            };
362        } else {
363            return None;
364        }
365    }
366
367    Some(accu.expect("at least two alts"))
368}