mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};

use std::borrow::Borrow;
use std::sync::Arc;

use crate::expressions::{BinaryExpr, Column};
use crate::scalar_function::PhysicalExprExt;
use crate::tree_node::ExprContext;
use crate::PhysicalExpr;
use crate::PhysicalSortExpr;

use arrow::datatypes::Schema;
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{HashMap, HashSet, Result};
use datafusion_expr::Operator;

use petgraph::graph::NodeIndex;
use petgraph::stable_graph::StableGraph;

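/// Splits a predicate on `AND` into its conjunctive terms, returning
/// references to the leaf sub-expressions in left-to-right order. A predicate
/// that is not an `AND` comes back as a single-element vector. A runnable
/// sketch that round-trips a predicate through this function and
/// [`conjunction`] is included in the `tests` module at the bottom of this
/// file.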
pub fn split_conjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::And, predicate, vec![])
}

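/// Combines the given predicates into a single predicate by chaining them
/// with `AND`. An empty input yields the literal `true`, so the result is
/// always a usable filter expression.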
pub fn conjunction(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Arc<dyn PhysicalExpr> {
    conjunction_opt(predicates).unwrap_or_else(|| crate::expressions::lit(true))
}

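/// Like [`conjunction`], but returns `None` for an empty input instead of the
/// literal `true`, so "no predicate" stays distinguishable from "predicate
/// that always passes".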
pub fn conjunction_opt(
    predicates: impl IntoIterator<Item = Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn PhysicalExpr>> {
    predicates
        .into_iter()
        .fold(None, |acc, predicate| match acc {
            None => Some(predicate),
            Some(acc) => Some(Arc::new(BinaryExpr::new(acc, Operator::And, predicate))),
        })
}

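/// Splits a predicate on `OR` into its disjunctive terms, the counterpart of
/// [`split_conjunction`]. A predicate that is not an `OR` comes back as a
/// single-element vector.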
pub fn split_disjunction(
    predicate: &Arc<dyn PhysicalExpr>,
) -> Vec<&Arc<dyn PhysicalExpr>> {
    split_impl(Operator::Or, predicate, vec![])
}

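/// Recursive helper shared by [`split_conjunction`] and [`split_disjunction`]:
/// walks `predicate` and accumulates references to every leaf that is not a
/// `BinaryExpr` with the given `operator`.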
fn split_impl<'a>(
    operator: Operator,
    predicate: &'a Arc<dyn PhysicalExpr>,
    mut exprs: Vec<&'a Arc<dyn PhysicalExpr>>,
) -> Vec<&'a Arc<dyn PhysicalExpr>> {
    match predicate.as_any().downcast_ref::<BinaryExpr>() {
        Some(binary) if binary.op() == &operator => {
            let exprs = split_impl(operator, binary.left(), exprs);
            split_impl(operator, binary.right(), exprs)
        }
        Some(_) | None => {
            exprs.push(predicate);
            exprs
        }
    }
}

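/// Rewrites the columns required by a parent node in terms of the columns
/// feeding a projection: a required column whose name matches the output name
/// of a plain-column projection expression is replaced by the source column
/// as it appears below the projection. Requirements that do not resolve to a
/// plain column are dropped, so the result may be shorter than the input. A
/// runnable sketch is included in the `tests` module at the bottom of this
/// file.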
pub fn map_columns_before_projection(
    parent_required: &[Arc<dyn PhysicalExpr>],
    proj_exprs: &[(Arc<dyn PhysicalExpr>, String)],
) -> Vec<Arc<dyn PhysicalExpr>> {
    if parent_required.is_empty() {
        return vec![];
    }
    let column_mapping = proj_exprs
        .iter()
        .filter_map(|(expr, name)| {
            expr.as_any()
                .downcast_ref::<Column>()
                .map(|column| (name.clone(), column.clone()))
        })
        .collect::<HashMap<_, _>>();
    parent_required
        .iter()
        .filter_map(|r| {
            r.as_any()
                .downcast_ref::<Column>()
                .and_then(|c| column_mapping.get(c.name()))
        })
        .map(|e| Arc::new(e.clone()) as _)
        .collect()
}

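/// Extracts the bare expressions from a sequence of [`PhysicalSortExpr`]s,
/// discarding the sort options.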
pub fn convert_to_expr<T: Borrow<PhysicalSortExpr>>(
    sequence: impl IntoIterator<Item = T>,
) -> Vec<Arc<dyn PhysicalExpr>> {
    sequence
        .into_iter()
        .map(|elem| Arc::clone(&elem.borrow().expr))
        .collect()
}

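/// Finds the position of each target expression in `items` using exact
/// expression equality, returning the indices of the targets that were found
/// and skipping the ones that were not.
///
/// A minimal sketch (assumes `items` and `targets` are collections of
/// `Arc<dyn PhysicalExpr>` built elsewhere):
///
/// ```ignore
/// // With items `[a, b, c]`, the targets `[c, a]` resolve to `[2, 0]`;
/// // targets not present in `items` are simply omitted from the result.
/// let indices = get_indices_of_exprs_strict(&targets, &items);
/// ```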
pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
    targets: impl IntoIterator<Item = T>,
    items: &[Arc<dyn PhysicalExpr>],
) -> Vec<usize> {
    targets
        .into_iter()
        .filter_map(|target| items.iter().position(|e| e.eq(target.borrow())))
        .collect()
}

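/// An expression tree node paired with optional auxiliary data of type `T`,
/// such as the [`NodeIndex`] assigned while building a DAG.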
pub type ExprTreeNode<T> = ExprContext<Option<T>>;

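/// Incrementally builds a DAG of physical expressions: previously seen
/// sub-expressions are looked up in `visited_plans` and reused rather than
/// added to the graph a second time.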
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
    /// The resulting DAG of expressions.
    graph: StableGraph<T, usize>,
    /// Expressions already added to the graph, with their node indices.
    visited_plans: Vec<(Arc<dyn PhysicalExpr>, NodeIndex)>,
    /// A function converting an expression tree node into a graph payload `T`.
    constructor: &'a F,
}

impl<T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> PhysicalExprDAEGBuilder<'_, T, F> {
    // Adds the expression of `node` to the graph (reusing an existing node if the
    // same expression was seen before) and records the resulting node index in
    // the tree node's `data` field.
    fn mutate(
        &mut self,
        mut node: ExprTreeNode<NodeIndex>,
    ) -> Result<Transformed<ExprTreeNode<NodeIndex>>> {
        let expr = &node.expr;

        let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
            // The expression was already visited: reuse its node index.
            Some((_, idx)) => *idx,
            // Otherwise, add a new node, connect it to its (already processed)
            // children, and remember it for future lookups.
            None => {
                let node_idx = self.graph.add_node((self.constructor)(&node)?);
                for expr_node in node.children.iter() {
                    self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
                }
                self.visited_plans.push((Arc::clone(expr), node_idx));
                node_idx
            }
        };
        node.data = Some(node_idx);
        Ok(Transformed::yes(node))
    }
}

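/// Builds a directed acyclic graph (DAG) from the expression tree rooted at
/// `expr`: identical sub-expressions become a single graph node, edges point
/// from parents to children, and `constructor` converts each visited
/// [`ExprTreeNode`] into the node payload `T`. Returns the root node index
/// together with the graph.
///
/// A minimal sketch of the call shape (the `constructor` is assumed to be
/// defined by the caller, as in `make_dummy_node` in the tests below):
///
/// ```ignore
/// let (root, dag) = build_dag(expr, &constructor)?;
/// // `dag` is a `StableGraph<T, usize>`; traverse it starting from `root`.
/// ```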
pub fn build_dag<T, F>(
    expr: Arc<dyn PhysicalExpr>,
    constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
    F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
{
    // Wrap the input expression in a tree node with empty auxiliary data.
    let init = ExprTreeNode::new_default(expr);
    let mut builder = PhysicalExprDAEGBuilder {
        graph: StableGraph::<T, usize>::new(),
        visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
        constructor,
    };
    // Transform the tree bottom-up so that children are added before parents.
    let root = init.transform_up(|node| builder.mutate(node)).data()?;
    Ok((root.data.unwrap(), builder.graph))
}

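/// Collects every [`Column`] referenced by the expression into a `HashSet`,
/// skipping names that are bound as lambda parameters inside the expression.
///
/// A minimal sketch (assumes `expr` is an `Arc<dyn PhysicalExpr>` such as
/// `a + b`):
///
/// ```ignore
/// // Contains the `Column`s `a` and `b` exactly once each, no matter how
/// // often they appear in the expression.
/// let columns: HashSet<Column> = collect_columns(&expr);
/// ```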
pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
    let mut columns = HashSet::<Column>::new();
    expr.apply_with_lambdas_params(|expr, lambdas_params| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if !lambdas_params.contains(column.name()) {
                columns.get_or_insert_owned(column);
            }
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("no way to return error during recursion");
    columns
}

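/// Rewrites every [`Column`] in the expression so that its index matches the
/// position of the same column name in `schema`, leaving lambda parameters
/// untouched. Returns an error if a referenced column name is missing from
/// the schema.
///
/// A minimal sketch (assumes `expr` and `schema` are built elsewhere):
///
/// ```ignore
/// // `expr` was planned against a wider schema; after reassignment its
/// // column indices refer to positions in `schema` instead.
/// let rewritten = reassign_expr_columns(expr, &schema)?;
/// ```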
pub fn reassign_expr_columns(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
    expr.transform_down_with_lambdas_params(|expr, lambdas_params| {
        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if !lambdas_params.contains(column.name()) {
                let index = schema.index_of(column.name())?;

                return Ok(Transformed::yes(Arc::new(Column::new(
                    column.name(),
                    index,
                ))));
            }
        }
        Ok(Transformed::no(expr))
    })
    .data()
}

#[cfg(test)]
pub(crate) mod tests {
    use std::any::Any;
    use std::fmt::{Display, Formatter};

    use super::*;
    use crate::expressions::{binary, cast, col, in_list, lit, Literal};

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue};
    use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
    use datafusion_expr::{
        ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
    };

    use petgraph::visit::Bfs;

    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct TestScalarUDF {
        pub(crate) signature: Signature,
    }

    impl TestScalarUDF {
        pub fn new() -> Self {
            use DataType::*;
            Self {
                signature: Signature::uniform(
                    1,
                    vec![Float64, Float32],
                    Volatility::Immutable,
                ),
            }
        }
    }

    impl ScalarUDFImpl for TestScalarUDF {
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn name(&self) -> &str {
            "test-scalar-udf"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
            let arg_type = &arg_types[0];

            match arg_type {
                DataType::Float32 => Ok(DataType::Float32),
                _ => Ok(DataType::Float64),
            }
        }

        fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
            Ok(input[0].sort_properties)
        }

        fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
            let args = ColumnarValue::values_to_arrays(&args.args)?;

            let arr: ArrayRef = match args[0].data_type() {
                DataType::Float64 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float64Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float64Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f64::floor))
                        .collect::<Float64Array>()
                }),
                DataType::Float32 => Arc::new({
                    let arg = &args[0]
                        .as_any()
                        .downcast_ref::<Float32Array>()
                        .ok_or_else(|| {
                            internal_datafusion_err!(
                                "could not cast {} to {}",
                                self.name(),
                                std::any::type_name::<Float32Array>()
                            )
                        })?;

                    arg.iter()
                        .map(|a| a.map(f32::floor))
                        .collect::<Float32Array>()
                }),
                other => {
                    return exec_err!(
                        "Unsupported data type {other:?} for function {}",
                        self.name()
                    );
                }
            };
            Ok(ColumnarValue::Array(arr))
        }
    }

    #[derive(Clone)]
    struct DummyProperty {
        expr_type: String,
    }

    /// A dummy node in the DAG; it stores the physical expression together
    /// with a dummy property used by the tests below.
    #[derive(Clone)]
    struct PhysicalExprDummyNode {
        pub expr: Arc<dyn PhysicalExpr>,
        pub property: DummyProperty,
    }

    impl Display for PhysicalExprDummyNode {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.expr)
        }
    }

    fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
        let expr = Arc::clone(&node.expr);
        let dummy_property = if expr.as_any().is::<BinaryExpr>() {
            "Binary"
        } else if expr.as_any().is::<Column>() {
            "Column"
        } else if expr.as_any().is::<Literal>() {
            "Literal"
        } else {
            "Other"
        }
        .to_owned();
        Ok(PhysicalExprDummyNode {
            expr,
            property: DummyProperty {
                expr_type: dummy_property,
            },
        })
    }

    #[test]
    fn test_build_dag() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("0", DataType::Int32, true),
            Field::new("1", DataType::Int32, true),
            Field::new("2", DataType::Int32, true),
        ]);
        let expr = binary(
            cast(
                binary(
                    col("0", &schema)?,
                    Operator::Plus,
                    col("1", &schema)?,
                    &schema,
                )?,
                &schema,
                DataType::Int64,
            )?,
            Operator::Gt,
            binary(
                cast(col("2", &schema)?, &schema, DataType::Int64)?,
                Operator::Plus,
                lit(ScalarValue::Int64(Some(10))),
                &schema,
            )?,
            &schema,
        )?;
        let mut vector_dummy_props = vec![];
        let (root, graph) = build_dag(expr, &make_dummy_node)?;
        let mut bfs = Bfs::new(&graph, root);
        while let Some(node_index) = bfs.next(&graph) {
            let node = &graph[node_index];
            vector_dummy_props.push(node.property.clone());
        }

        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Binary")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Column")
                .count(),
            3
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Literal")
                .count(),
            1
        );
        assert_eq!(
            vector_dummy_props
                .iter()
                .filter(|property| property.expr_type == "Other")
                .count(),
            2
        );
        Ok(())
    }

    #[test]
    fn test_convert_to_expr() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::UInt64, false)]);
        let sort_expr = vec![PhysicalSortExpr {
            expr: col("a", &schema)?,
            options: Default::default(),
        }];
        assert!(convert_to_expr(&sort_expr)[0].eq(&sort_expr[0].expr));
        Ok(())
    }

    #[test]
    fn test_get_indices_of_exprs_strict() {
        let list1: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("a", 0)),
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("d", 3)),
        ];
        let list2: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(Column::new("b", 1)),
            Arc::new(Column::new("c", 2)),
            Arc::new(Column::new("a", 0)),
        ];
        assert_eq!(get_indices_of_exprs_strict(&list1, &list2), vec![2, 0, 1]);
        assert_eq!(get_indices_of_exprs_strict(&list2, &list1), vec![1, 2, 0]);
    }

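    // Illustrative sketch (not part of the original test suite): round-trips a
    // predicate through `split_conjunction` and `conjunction`. Column names and
    // indices are arbitrary.
    #[test]
    fn test_split_and_rebuild_conjunction_sketch() {
        let a: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let b: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
        let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::clone(&a),
            Operator::And,
            Arc::new(BinaryExpr::new(Arc::clone(&b), Operator::And, lit(true))),
        ));

        // `a AND (b AND true)` splits into its three leaf terms, left to right.
        let parts = split_conjunction(&predicate);
        assert_eq!(parts.len(), 3);
        assert!(parts[0].eq(&a));
        assert!(parts[1].eq(&b));

        // Rebuilding folds the terms back into a left-deep `AND` chain, while an
        // empty input collapses to the literal `true`.
        let rebuilt = conjunction(parts.into_iter().map(Arc::clone));
        assert_eq!(split_conjunction(&rebuilt).len(), 3);
        let empty: Vec<Arc<dyn PhysicalExpr>> = vec![];
        assert!(conjunction(empty).eq(&lit(true)));
    }
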
    #[test]
    fn test_reassign_expr_columns_in_list() {
        let int_field = Field::new("should_not_matter", DataType::Int64, true);
        let dict_field = Field::new(
            "id",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            true,
        );
        let schema_small = Arc::new(Schema::new(vec![dict_field.clone()]));
        let schema_big = Arc::new(Schema::new(vec![int_field, dict_field]));
        let pred = in_list(
            Arc::new(Column::new_with_schema("id", &schema_big).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_big,
        )
        .unwrap();

        let actual = reassign_expr_columns(pred, &schema_small).unwrap();

        let expected = in_list(
            Arc::new(Column::new_with_schema("id", &schema_small).unwrap()),
            vec![lit(ScalarValue::Dictionary(
                Box::new(DataType::Int32),
                Box::new(ScalarValue::from("2")),
            ))],
            &false,
            &schema_small,
        )
        .unwrap();

        assert_eq!(actual.as_ref(), expected.as_ref());
    }

    #[test]
    fn test_collect_columns() -> Result<()> {
        let expr1 = Arc::new(Column::new("col1", 2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        assert_eq!(collect_columns(&expr1), expected);

        let expr2 = Arc::new(Column::new("col2", 5)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr2), expected);

        let expr3 = Arc::new(BinaryExpr::new(expr1, Operator::Plus, expr2)) as _;
        let mut expected = HashSet::new();
        expected.insert(Column::new("col1", 2));
        expected.insert(Column::new("col2", 5));
        assert_eq!(collect_columns(&expr3), expected);
        Ok(())
    }
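
    // Illustrative sketch (not part of the original test suite): maps a parent
    // requirement through a projection that renames column `a` to `a_alias`.
    // Names and indices are arbitrary.
    #[test]
    fn test_map_columns_before_projection_sketch() {
        let a_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
        let b_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 1));
        // The projection exposes `a` under the name `a_alias` and `b` unchanged.
        let proj_exprs = vec![
            (Arc::clone(&a_col), "a_alias".to_string()),
            (b_col, "b".to_string()),
        ];
        // The parent requires the projected (aliased) column ...
        let parent_required: Vec<Arc<dyn PhysicalExpr>> =
            vec![Arc::new(Column::new("a_alias", 0))];

        let mapped = map_columns_before_projection(&parent_required, &proj_exprs);

        // ... and receives the column as it appears below the projection.
        assert_eq!(mapped.len(), 1);
        assert!(mapped[0].eq(&a_col));
    }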
}