datafusion_physical_expr/utils/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};

use std::borrow::Borrow;
use std::sync::Arc;

use crate::expressions::{BinaryExpr, Column};
use crate::scalar_function::PhysicalExprExt;
use crate::tree_node::ExprContext;
use crate::PhysicalExpr;
use crate::PhysicalSortExpr;

use arrow::datatypes::Schema;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{HashMap, HashSet, Result};
use datafusion_expr::Operator;

use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;

/// Assuming the predicate is in conjunctive normal form (CNF), splits it into a `Vec`
/// of its conjunct [`PhysicalExpr`]s.
///
/// For example, "a1 = a2 AND b1 <= b2 AND c1 != c2" is split into
/// ["a1 = a2", "b1 <= b2", "c1 != c2"].
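///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming a schema with `Int32`
/// columns `a` and `b` and the `col`, `lit` and `binary` helpers from
/// [`crate::expressions`]:
///
/// ```ignore
/// // Build `a > 1 AND b > 2`, then split it back into its two conjuncts.
/// let a_gt_1 = binary(col("a", &schema)?, Operator::Gt, lit(ScalarValue::from(1)), &schema)?;
/// let b_gt_2 = binary(col("b", &schema)?, Operator::Gt, lit(ScalarValue::from(2)), &schema)?;
/// let predicate = binary(Arc::clone(&a_gt_1), Operator::And, Arc::clone(&b_gt_2), &schema)?;
///
/// let conjuncts = split_conjunction(&predicate);
/// assert_eq!(conjuncts.len(), 2);
/// assert!(conjuncts[0].eq(&a_gt_1) && conjuncts[1].eq(&b_gt_2));
/// ```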
pub fn split_conjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::And, predicate, vec![])
}

/// Create a conjunction of the given predicates.
/// If the input is empty, return a literal `true`.
/// If the input contains a single predicate, return the predicate.
/// Otherwise, return a conjunction of the predicates (e.g. `a AND b AND c`).
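///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); `a_gt_1` and `b_gt_2` are assumed
/// to be arbitrary boolean [`PhysicalExpr`]s, e.g. the predicates built in the
/// [`split_conjunction`] example:
///
/// ```ignore
/// // Two predicates fold into `a > 1 AND b > 2` ...
/// let and_expr = conjunction([Arc::clone(&a_gt_1), Arc::clone(&b_gt_2)]);
/// assert_eq!(split_conjunction(&and_expr).len(), 2);
///
/// // ... while an empty input collapses to the literal `true`.
/// let always_true = conjunction(Vec::<Arc<dyn PhysicalExpr>>::new());
/// assert_eq!(always_true.as_ref(), lit(true).as_ref());
/// ```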
pub fn conjunction(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Arc<dyn PhysicalExpr> {
    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
}

/// Create a conjunction of the given predicates.
/// If the input is empty, return `None`.
/// If the input contains a single predicate, return `Some(predicate)`.
/// Otherwise, return `Some(..)` wrapping the conjunction of the predicates (e.g. `Some(a AND b AND c)`).
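///
/// # Example
///
/// A minimal sketch (not compiled as a doctest); `a_gt_1` is assumed to be an
/// arbitrary boolean [`PhysicalExpr`]:
///
/// ```ignore
/// // Unlike `conjunction`, an empty input yields `None` rather than `true`.
/// assert!(conjunction_opt(Vec::<Arc<dyn PhysicalExpr>>::new()).is_none());
/// assert!(conjunction_opt([Arc::clone(&a_gt_1)]).is_some());
/// ```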
pub fn conjunction_opt(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn PhysicalExpr>> {
    predicates
        .into_iter()
        .fold(None, |acc, predicate| match acc {
            None => Some(predicate),
            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
        })
}

/// Assuming the predicate is in disjunctive normal form (DNF), splits it into a `Vec`
/// of its disjunct [`PhysicalExpr`]s.
///
/// For example, "a1 = a2 OR b1 <= b2 OR c1 != c2" is split into
/// ["a1 = a2", "b1 <= b2", "c1 != c2"].
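///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), mirroring the [`split_conjunction`]
/// example with `OR` in place of `AND`:
///
/// ```ignore
/// // `a > 1 OR b > 2` splits into its two disjuncts.
/// let predicate = binary(Arc::clone(&a_gt_1), Operator::Or, Arc::clone(&b_gt_2), &schema)?;
/// assert_eq!(split_disjunction(&predicate).len(), 2);
/// ```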
pub fn split_disjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::Or, predicate, vec![])
}

fn split_impl<'a>(
    operator: Operator,
    predicate: &'a Arc<dyn PhysicalExpr>,
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
    match predicate.as_any().downcast_ref::<BinaryExpr>() {
        Some(binary) if binary.op() == &operator => {
            let exprs = split_impl(operator, binary.left(), exprs);
            split_impl(operator, binary.right(), exprs)
        }
        Some(_) | None => {
            exprs.push(predicate);
            exprs
        }
    }
}

/// Maps a requirement expressed against the output of a `ProjectionExec` back to the
/// executor that feeds the projection's input.
// Specifically, `ProjectionExec` changes the indices of the `Column`s in the schema of
// its input executor. This function rewrites a requirement given in terms of the
// `ProjectionExec` output schema into the equivalent requirement in terms of the input
// executor's schema. For instance, Column{"a", 0} may turn into Column{"a", 1}.
// Note that this function assumes column names are unique: if a requirement contains
// both Column{"a", 0} and Column{"a", 1}, it produces an incorrect result (only a
// single column is emitted).
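///
/// # Example
///
/// A minimal sketch (not compiled as a doctest). The projection computes `SELECT b, a`
/// over an input where `a` is at index 0 and `b` is at index 1, so its output exposes
/// `b` at index 0 and `a` at index 1:
///
/// ```ignore
/// let proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
///     (Arc::new(Column::new("b", 1)) as _, "b".to_string()),
///     (Arc::new(Column::new("a", 0)) as _, "a".to_string()),
/// ];
/// // A requirement on `a` phrased against the projection output (index 1) ...
/// let parent_required: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("a", 1))];
/// // ... maps back to `a` at index 0 of the projection input.
/// let mapped = map_columns_before_projection(&parent_required, &proj_exprs);
/// assert!(mapped[0].eq(&(Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>)));
/// ```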
pub fn map_columns_before_projection(
    parent_required: &[Arc<dyn PhysicalExpr>],
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Vec<Arc<dyn PhysicalExpr>> {
    if parent_required.is_empty() {
        // No need to build mapping.
        return vec![];
    }
    let column_mapping = proj_exprs
        .iter()
        .filter_map(|(expr, name)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|column| (name.clone(), column.clone()))
        })
        .collect::<HashMap<_, _>>();
    parent_required
        .iter()
        .filter_map(|r| {
            r.as_any()
                .downcast_ref::<Column>()
                .and_then(|c| column_mapping.get(c.name()))
        })
        .map(|e| Arc::new(e.clone()) as _)
        .collect()
}

/// This function returns all `Arc<dyn PhysicalExpr>`s inside the given
/// `PhysicalSortExpr` sequence.
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
    sequence: impl IntoIterator<Item = T>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    sequence
        .into_iter()
        .map(|elem| Arc::clone(&elem.borrow().expr))
        .collect()
}

/// This function finds the indices of `targets` within `items` using strict
/// equality.
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
    targets: impl IntoIterator<Item = T>,
    items: &[Arc<dyn PhysicalExpr>],
) -> Vec<usize> {
    targets
        .into_iter()
        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
        .collect()
}

pub type ExprTreeNode<T> = ExprContext<Option<T>>;

/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression
/// DAG) by collecting identical expressions in one node. The caller specifies the node
/// type in the DAEG via the `constructor` argument, which constructs nodes in the DAEG
/// from the [`ExprTreeNode`] ancillary object.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
    // The resulting DAEG (expression DAG).
    graph: StableGraph<T, usize>,
    // A vector of visited expression nodes and their corresponding node indices.
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
    // A function to convert an input expression node to T.
    constructor: &'a F,
}

impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
    // This method adds an expression node to the graph, reusing the existing graph node
    // if an identical expression has been visited before. It returns the input node with
    // its `data` field set to the corresponding graph node index.
    fn mutate(
        &mut self,
        mut node: ExprTreeNode<NodeIndex>,
    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
        // Get the expression associated with the input expression node.
        let expr = &node.expr;

        // Check if the expression has already been visited.
        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
            // If the expression has been visited, return the corresponding node index.
            Some((_, idx)) => *idx,
            // If the expression has not been visited, add a new node to the graph and
            // add edges to its child nodes. Add the visited expression to the vector
            // of visited expressions and return the newly created node index.
            None => {
                let node_idx = self.graph.add_node((self.constructor)(&node)?);
                for expr_node in node.children.iter() {
                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
                }
                self.visited_plans.push((Arc::clone(expr), node_idx));
                node_idx
            }
        };
        // Set the data field of the input expression node to the corresponding node index.
        node.data = Some(node_idx);
        // Return the mutated expression node.
        Ok(Transformed::yes(node))
    }
}

/// Builds a directed acyclic expression graph (DAEG) from the given physical
/// expression tree, returning the root node index together with the graph.
pub fn build_dag<T, F>(
    expr: Arc<dyn PhysicalExpr>,
    constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
{
    // Create a new expression tree node from the input expression.
    let init = ExprTreeNode::new_default(expr);
    // Create a new `PhysicalExprDAEGBuilder` instance.
    let mut builder = PhysicalExprDAEGBuilder {
        graph: StableGraph::<T, usize>::new(),
        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
        constructor,
    };
    // Use the builder to transform the expression tree node into a DAG.
    let root = init.transform_up(|node| builder.mutate(node)).data()?;
    // Return a tuple containing the root node index and the DAG.
    Ok((root.data.unwrap(), builder.graph))
}

/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
    let mut columns = HashSet::<Column>::new();
    expr.apply_with_lambdas_params(|expr, lambdas_params| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if !lambdas_params.contains(column.name()) {
                columns.get_or_insert_owned(column);
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // The closure above always returns `Ok`, so the traversal itself cannot fail.
    .expect("no way to return error during recursion");
    columns
}

/// Re-assign indices of [`Column`]s within the given [`PhysicalExpr`] according to
/// the provided [`Schema`].
///
/// This can be useful when attempting to map an expression onto a different schema.
///
/// # Errors
///
/// This function will return an error if any column in the expression cannot be found
/// in the provided schema.
pub fn reassign_expr_columns(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
    expr.transform_down_with_lambdas_params(|expr, lambdas_params| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if !lambdas_params.contains(column.name()) {
                let index = schema.index_of(column.name())?;

                return Ok(Transformed::yes(Arc::new(Column::new(
                    column.name(),
                    index,
                ))));
            }
        }
        Ok(Transformed::no(expr))
    })
    .data()
}

#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// This is a dummy node in the DAEG; it stores a reference to the actual
    /// [PhysicalExpr] as well as a dummy property.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
            "Binary"
        } else if expr.as_any().is::<Column>() {
            "Column"
        } else if expr.as_any().is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

    #[test]
    fn test_reassign_expr_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_expr_columns(pred, &schema_small).unwrap();

        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
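
    // Illustrative round-trip test for the conjunction/disjunction helpers above;
    // the schema and predicates are arbitrary choices for the example.
    #[test]
    fn test_split_and_rebuild_conjunction() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);
        let a_gt_1 = binary(
            col("a", &schema)?,
            Operator::Gt,
            lit(ScalarValue::Int32(Some(1))),
            &schema,
        )?;
        let b_gt_2 = binary(
            col("b", &schema)?,
            Operator::Gt,
            lit(ScalarValue::Int32(Some(2))),
            &schema,
        )?;

        // `a > 1 AND b > 2` splits into its two conjuncts, and rebuilding the
        // conjunction from the parts yields an equivalent AND chain.
        let and_expr = binary(
            Arc::clone(&a_gt_1),
            Operator::And,
            Arc::clone(&b_gt_2),
            &schema,
        )?;
        let conjuncts = split_conjunction(&and_expr);
        assert_eq!(conjuncts.len(), 2);
        let rebuilt = conjunction(conjuncts.into_iter().cloned());
        assert_eq!(split_conjunction(&rebuilt).len(), 2);

        // `a > 1 OR b > 2` splits into its two disjuncts.
        let or_expr = binary(a_gt_1, Operator::Or, b_gt_2, &schema)?;
        assert_eq!(split_disjunction(&or_expr).len(), 2);

        // An empty input yields `None` from `conjunction_opt` and the literal
        // `true` from `conjunction`.
        assert!(conjunction_opt(Vec::<Arc<dyn PhysicalExpr>>::new()).is_none());
        assert_eq!(
            conjunction(Vec::<Arc<dyn PhysicalExpr>>::new()).as_ref(),
            lit(true).as_ref()
        );
        Ok(())
    }

    // Illustrative test for `map_columns_before_projection`: a requirement phrased
    // against the projection output is rewritten in terms of the projection input.
    #[test]
    fn test_map_columns_before_projection() {
        // Projection `SELECT b, a` over an input with `a` at index 0 and `b` at
        // index 1; the projection output has `b` at index 0 and `a` at index 1.
        let proj_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![
            (Arc::new(Column::new("b", 1)) as _, "b".to_string()),
            (Arc::new(Column::new("a", 0)) as _, "a".to_string()),
        ];
        // A requirement on `a` given against the projection output (index 1)
        // maps back to `a` at index 0 of the projection input.
        let parent_required: Vec<Arc<dyn PhysicalExpr>> =
            vec![Arc::new(Column::new("a", 1))];
        let mapped = map_columns_before_projection(&parent_required, &proj_exprs);
        assert_eq!(mapped.len(), 1);
        assert!(mapped[0].eq(&(Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>)));
    }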
}