datafusion_physical_expr/expressions/
lambda.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Physical lambda expression: [`LambdaExpr`]
19
20use std::hash::Hash;
21use std::sync::Arc;
22use std::{any::Any, sync::OnceLock};
23
24use crate::expressions::Column;
25use crate::physical_expr::PhysicalExpr;
26use crate::PhysicalExprExt;
27use arrow::{
28    datatypes::{DataType, Schema},
29    record_batch::RecordBatch,
30};
31use datafusion_common::tree_node::TreeNodeRecursion;
32use datafusion_common::{internal_err, HashSet, Result};
33use datafusion_expr::ColumnarValue;
34
/// Represents a lambda with the given parameter names and body.
///
/// Equality and hashing (implemented manually in this file) take only
/// `params` and `body` into account; `captures` is a cache and is ignored.
#[derive(Debug, Eq, Clone)]
pub struct LambdaExpr {
    /// Names of the lambda's parameters, in the order supplied to `new`.
    params: Vec<String>,
    /// Expression that forms the lambda's body.
    body: Arc<dyn PhysicalExpr>,
    /// Lazily initialized cache backing `captures()`; starts out empty and is
    /// filled on the first call to that method.
    captures: OnceLock<HashSet<usize>>,
}
42
43// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196]
44impl PartialEq for LambdaExpr {
45    fn eq(&self, other: &Self) -> bool {
46        self.params.eq(&other.params) && self.body.eq(&other.body)
47    }
48}
49
50impl Hash for LambdaExpr {
51    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
52        self.params.hash(state);
53        self.body.hash(state);
54    }
55}
56
57impl LambdaExpr {
58    /// Create a new lambda expression with the given parameters and body
59    pub fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
60        Self {
61            params,
62            body,
63            captures: OnceLock::new(),
64        }
65    }
66
67    /// Get the lambda's params names
68    pub fn params(&self) -> &[String] {
69        &self.params
70    }
71
72    /// Get the lambda's body
73    pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
74        &self.body
75    }
76
77    pub fn captures(&self) -> &HashSet<usize> {
78        self.captures.get_or_init(|| {
79            let mut indices = HashSet::new();
80
81            self.body
82                .apply_with_lambdas_params(|expr, lambdas_params| {
83                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
84                        if !lambdas_params.contains(column.name()) {
85                            indices.insert(column.index());
86                        }
87                    }
88
89                    Ok(TreeNodeRecursion::Continue)
90                })
91                .unwrap();
92
93            indices
94        })
95    }
96}
97
98impl std::fmt::Display for LambdaExpr {
99    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
100        write!(f, "({}) -> {}", self.params.join(", "), self.body)
101    }
102}
103
104impl PhysicalExpr for LambdaExpr {
105    fn as_any(&self) -> &dyn Any {
106        self
107    }
108
109    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
110        Ok(DataType::Null)
111    }
112
113    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
114        Ok(true)
115    }
116
117    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
118        internal_err!("Lambda::evaluate() should not be called")
119    }
120
121    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
122        vec![&self.body]
123    }
124
125    fn with_new_children(
126        self: Arc<Self>,
127        children: Vec<Arc<dyn PhysicalExpr>>,
128    ) -> Result<Arc<dyn PhysicalExpr>> {
129        Ok(Arc::new(Self {
130            params: self.params.clone(),
131            body: Arc::clone(&children[0]),
132            captures: OnceLock::new(),
133        }))
134    }
135
136    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        write!(f, "({}) -> {}", self.params.join(", "), self.body)
138    }
139}