datafusion_physical_expr/expressions/
lambda.rs1use std::hash::Hash;
21use std::sync::Arc;
22use std::{any::Any, sync::OnceLock};
23
24use crate::expressions::Column;
25use crate::physical_expr::PhysicalExpr;
26use crate::PhysicalExprExt;
27use arrow::{
28 datatypes::{DataType, Schema},
29 record_batch::RecordBatch,
30};
31use datafusion_common::tree_node::TreeNodeRecursion;
32use datafusion_common::{internal_err, HashSet, Result};
33use datafusion_expr::ColumnarValue;
34
/// Physical expression node for a lambda: a list of parameter names plus a
/// body expression.
///
/// Equality and hashing are implemented by hand and consider only `params`
/// and `body`; the `captures` cache is derived data and is left out.
#[derive(Debug, Eq, Clone)]
pub struct LambdaExpr {
    /// Names of the lambda's parameters.
    params: Vec<String>,
    /// Expression forming the lambda body.
    body: Arc<dyn PhysicalExpr>,
    /// Lazily initialised cache backing [`LambdaExpr::captures`]; filled on
    /// the first call and reused afterwards.
    captures: OnceLock<HashSet<usize>>,
}
42
43impl PartialEq for LambdaExpr {
45 fn eq(&self, other: &Self) -> bool {
46 self.params.eq(&other.params) && self.body.eq(&other.body)
47 }
48}
49
50impl Hash for LambdaExpr {
51 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
52 self.params.hash(state);
53 self.body.hash(state);
54 }
55}
56
57impl LambdaExpr {
58 pub fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
60 Self {
61 params,
62 body,
63 captures: OnceLock::new(),
64 }
65 }
66
67 pub fn params(&self) -> &[String] {
69 &self.params
70 }
71
72 pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
74 &self.body
75 }
76
77 pub fn captures(&self) -> &HashSet<usize> {
78 self.captures.get_or_init(|| {
79 let mut indices = HashSet::new();
80
81 self.body
82 .apply_with_lambdas_params(|expr, lambdas_params| {
83 if let Some(column) = expr.as_any().downcast_ref::<Column>() {
84 if !lambdas_params.contains(column.name()) {
85 indices.insert(column.index());
86 }
87 }
88
89 Ok(TreeNodeRecursion::Continue)
90 })
91 .unwrap();
92
93 indices
94 })
95 }
96}
97
98impl std::fmt::Display for LambdaExpr {
99 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
100 write!(f, "({}) -> {}", self.params.join(", "), self.body)
101 }
102}
103
104impl PhysicalExpr for LambdaExpr {
105 fn as_any(&self) -> &dyn Any {
106 self
107 }
108
109 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
110 Ok(DataType::Null)
111 }
112
113 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
114 Ok(true)
115 }
116
117 fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
118 internal_err!("Lambda::evaluate() should not be called")
119 }
120
121 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
122 vec![&self.body]
123 }
124
125 fn with_new_children(
126 self: Arc<Self>,
127 children: Vec<Arc<dyn PhysicalExpr>>,
128 ) -> Result<Arc<dyn PhysicalExpr>> {
129 Ok(Arc::new(Self {
130 params: self.params.clone(),
131 body: Arc::clone(&children[0]),
132 captures: OnceLock::new(),
133 }))
134 }
135
136 fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137 write!(f, "({}) -> {}", self.params.join(", "), self.body)
138 }
139}