use std::sync::Arc;

use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use datafusion_common::HashSet;
use datafusion_common::{
    exec_err,
    tree_node::{Transformed, TransformedResult},
    Result, ScalarValue,
};
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::PhysicalExprExt;
use datafusion_physical_expr::{
    expressions::{self, CastExpr, Column},
    ScalarFunctionExpr,
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

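/// Adapts a [`PhysicalExpr`] written in terms of a table's logical schema so that it
/// can be evaluated against the physical schema of a single file.
///
/// Implementations typically insert casts for columns whose physical type differs from
/// the logical type, replace columns missing from the file with typed nulls, and
/// substitute literal partition values for partition columns.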
pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug {
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>;

    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter>;
}

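/// Factory for creating [`PhysicalExprAdapter`] instances bound to a specific pair of
/// logical and physical file schemas.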
pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter>;
}

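/// Default [`PhysicalExprAdapterFactory`] that produces [`DefaultPhysicalExprAdapter`]s.
///
/// A minimal usage sketch; `logical_file_schema`, `physical_file_schema`, and `expr`
/// are placeholders supplied by the caller:
///
/// ```ignore
/// let factory = DefaultPhysicalExprAdapterFactory;
/// let adapter = factory.create(logical_file_schema, physical_file_schema);
/// // Rewrite an expression built against the logical schema so it can be
/// // evaluated against the file's physical schema.
/// let rewritten = adapter.rewrite(expr)?;
/// ```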
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapterFactory;

impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory {
    fn create(
        &self,
        logical_file_schema: SchemaRef,
        physical_file_schema: SchemaRef,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(DefaultPhysicalExprAdapter {
            logical_file_schema,
            physical_file_schema,
            partition_values: Vec::new(),
        })
    }
}

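/// Default [`PhysicalExprAdapter`] implementation: casts mismatched column types,
/// replaces missing nullable columns (including struct sub-fields accessed via
/// `get_field`) with typed null literals, and inlines partition values as literals.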
#[derive(Debug, Clone)]
pub struct DefaultPhysicalExprAdapter {
    logical_file_schema: SchemaRef,
    physical_file_schema: SchemaRef,
    partition_values: Vec<(FieldRef, ScalarValue)>,
}

impl DefaultPhysicalExprAdapter {
    pub fn new(logical_file_schema: SchemaRef, physical_file_schema: SchemaRef) -> Self {
        Self {
            logical_file_schema,
            physical_file_schema,
            partition_values: Vec::new(),
        }
    }
}

impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
    fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &self.logical_file_schema,
            physical_file_schema: &self.physical_file_schema,
            partition_fields: &self.partition_values,
        };
        expr.transform_with_lambdas_params(|expr, lambdas_params| {
            rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params)
        })
        .data()
    }

    fn with_partition_values(
        &self,
        partition_values: Vec<(FieldRef, ScalarValue)>,
    ) -> Arc<dyn PhysicalExprAdapter> {
        Arc::new(DefaultPhysicalExprAdapter {
            partition_values,
            ..self.clone()
        })
    }
}

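/// Internal rewriter that walks an expression tree and adapts individual nodes against
/// the logical and physical file schemas plus any known partition values.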
struct DefaultPhysicalExprAdapterRewriter<'a> {
    logical_file_schema: &'a Schema,
    physical_file_schema: &'a Schema,
    partition_fields: &'a [(FieldRef, ScalarValue)],
}

impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
    fn rewrite_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        lambdas_params: &HashSet<String>,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        if let Some(transformed) =
            self.try_rewrite_struct_field_access(&expr, lambdas_params)?
        {
            return Ok(Transformed::yes(transformed));
        }

        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
            if !lambdas_params.contains(column.name()) {
                return self.rewrite_column(Arc::clone(&expr), column);
            }
        }

        Ok(Transformed::no(expr))
    }

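    /// Handle `get_field(struct_col, "name")` where the requested sub-field exists in
    /// the logical struct type but is missing from the file's physical struct type: the
    /// whole access is replaced with a typed null literal. Returns `Ok(None)` when the
    /// expression is not such an access, so the normal column rules apply instead.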
    fn try_rewrite_struct_field_access(
        &self,
        expr: &Arc<dyn PhysicalExpr>,
        lambdas_params: &HashSet<String>,
    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
        let get_field_expr =
            match ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(expr.as_ref()) {
                Some(expr) => expr,
                None => return Ok(None),
            };

        let source_expr = match get_field_expr.args().first() {
            Some(expr) => expr,
            None => return Ok(None),
        };

        let field_name_expr = match get_field_expr.args().get(1) {
            Some(expr) => expr,
            None => return Ok(None),
        };

        let lit = match field_name_expr
            .as_any()
            .downcast_ref::<expressions::Literal>()
        {
            Some(lit) => lit,
            None => return Ok(None),
        };

        let field_name = match lit.value().try_as_str().flatten() {
            Some(name) => name,
            None => return Ok(None),
        };

        let column = match source_expr.as_any().downcast_ref::<Column>() {
            Some(column) if !lambdas_params.contains(column.name()) => column,
            _ => return Ok(None),
        };

        let physical_field =
            match self.physical_file_schema.field_with_name(column.name()) {
                Ok(field) => field,
                Err(_) => return Ok(None),
            };

        let physical_struct_fields = match physical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        // If the physical struct already contains the requested sub-field, the normal
        // rewrite of the source column (including any struct cast) handles it.
        if physical_struct_fields
            .iter()
            .any(|f| f.name() == field_name)
        {
            return Ok(None);
        }

        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(_) => return Ok(None),
        };

        let logical_struct_fields = match logical_field.data_type() {
            DataType::Struct(fields) => fields,
            _ => return Ok(None),
        };

        let logical_struct_field = match logical_struct_fields
            .iter()
            .find(|f| f.name() == field_name)
        {
            Some(field) => field,
            None => return Ok(None),
        };

        // The sub-field exists logically but not physically: replace the access with a
        // typed null literal.
        let null_value = ScalarValue::Null.cast_to(logical_struct_field.data_type())?;
        Ok(Some(expressions::lit(null_value)))
    }

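    /// Rewrite a bare column reference against the physical file schema: partition
    /// columns become literal partition values, nullable columns missing from the file
    /// become typed null literals, column indices are remapped to the physical schema,
    /// and a cast is inserted when the physical and logical data types differ
    /// (erroring if no such cast exists or a non-nullable column is missing).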
    fn rewrite_column(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        column: &Column,
    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
            Err(e) => {
                // Columns not in the logical file schema may still be partition columns
                // with a known literal value.
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
                // Otherwise fall back to the physical field if the file happens to
                // contain the column; if neither schema knows it, surface the error.
                if let Ok(physical_field) =
                    self.physical_file_schema.field_with_name(column.name())
                {
                    physical_field
                } else {
                    return Err(e.into());
                }
            }
        };

        let physical_column_index =
            match self.physical_file_schema.index_of(column.name()) {
                Ok(index) => index,
                Err(_) => {
                    if !logical_field.is_nullable() {
                        return exec_err!(
                            "Non-nullable column '{}' is missing from the physical schema",
                            column.name()
                        );
                    }
                    // The column is missing from the file but nullable: fill it with a
                    // null of the logical data type.
                    let null_value =
                        ScalarValue::Null.cast_to(logical_field.data_type())?;
                    return Ok(Transformed::yes(expressions::lit(null_value)));
                }
            };
        let physical_field = self.physical_file_schema.field(physical_column_index);

        let column = match (
            column.index() == physical_column_index,
            logical_field.data_type() == physical_field.data_type(),
        ) {
            // Index and type both match: no rewrite needed.
            (true, true) => return Ok(Transformed::no(expr)),
            // Index matches but the type does not: keep the column, cast below.
            (true, _) => column.clone(),
            // Index differs: re-resolve the column against the physical schema.
            (false, _) => {
                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
            }
        };

        if logical_field.data_type() == physical_field.data_type() {
            return Ok(Transformed::yes(Arc::new(column)));
        }

        let is_compatible =
            can_cast_types(physical_field.data_type(), logical_field.data_type());
        if !is_compatible {
            return exec_err!(
                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
                column.name(),
                physical_field.data_type(),
                logical_field.data_type()
            );
        }

        let cast_expr = Arc::new(CastExpr::new(
            Arc::new(column),
            logical_field.data_type().clone(),
            None,
        ));

        Ok(Transformed::yes(cast_expr))
    }

    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
        self.partition_fields
            .iter()
            .find(|(field, _)| field.name() == column_name)
            .map(|(_, value)| value.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{RecordBatch, RecordBatchOptions};
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion_common::hashbrown::HashSet;
    use datafusion_common::{assert_contains, record_batch, Result, ScalarValue};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use itertools::Itertools;
    use std::sync::Arc;

    fn create_test_schema() -> (Schema, Schema) {
        let physical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]);

        (physical_schema, logical_schema)
    }

    #[test]
    fn test_rewrite_column_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("a", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
    }

    #[test]
    fn test_rewrite_multi_column_expr_with_type_cast() {
        let (physical_schema, logical_schema) = create_test_schema();
        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));

        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
        let expr = expressions::BinaryExpr::new(
            Arc::clone(&column_a),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expr = expressions::BinaryExpr::new(
            Arc::new(expr),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                Arc::clone(&column_c),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        );

        let result = adapter.rewrite(Arc::new(expr)).unwrap();
        println!("Rewritten expression: {result}");

        let expected = expressions::BinaryExpr::new(
            Arc::new(CastExpr::new(
                Arc::new(Column::new("a", 0)),
                DataType::Int64,
                None,
            )),
            Operator::Plus,
            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
        );
        let expected = Arc::new(expressions::BinaryExpr::new(
            Arc::new(expected),
            Operator::Or,
            Arc::new(expressions::BinaryExpr::new(
                lit(ScalarValue::Float64(None)),
                Operator::Gt,
                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
            )),
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(
            result.to_string(),
            expected.to_string(),
            "The rewritten expression did not match the expected output"
        );
    }

    #[test]
    fn test_rewrite_struct_column_incompatible() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Binary, true)].into()),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(vec![Field::new("field1", DataType::Int32, true)].into()),
            true,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("data", 0));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(error_msg, "Cannot cast column 'data'");
    }

    #[test]
    fn test_rewrite_struct_compatible_cast() {
        let physical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int32, false),
                    Field::new("name", DataType::Utf8, true),
                ]
                .into(),
            ),
            false,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "data",
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            false,
        )]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("data", 0));

        let result = adapter.rewrite(column_expr).unwrap();

        let expected = Arc::new(CastExpr::new(
            Arc::new(Column::new("data", 0)),
            DataType::Struct(
                vec![
                    Field::new("id", DataType::Int64, false),
                    Field::new("name", DataType::Utf8View, true),
                ]
                .into(),
            ),
            None,
        )) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_missing_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("c", 2));

        let result = adapter.rewrite(column_expr)?;

        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(*literal.value(), ScalarValue::Float64(None));
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

    #[test]
    fn test_rewrite_missing_column_non_nullable_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let error_msg = adapter.rewrite(column_expr).unwrap_err().to_string();
        assert_contains!(error_msg, "Non-nullable column 'b' is missing");
    }

    #[test]
    fn test_rewrite_missing_column_nullable() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr).unwrap();

        let expected =
            Arc::new(Literal::new(ScalarValue::Utf8(None))) as Arc<dyn PhysicalExpr>;

        assert_eq!(result.to_string(), expected.to_string());
    }

    #[test]
    fn test_rewrite_partition_column() -> Result<()> {
        let (physical_schema, logical_schema) = create_test_schema();

        let partition_field =
            Arc::new(Field::new("partition_col", DataType::Utf8, false));
        let partition_value = ScalarValue::Utf8(Some("test_value".to_string()));
        let partition_values = vec![(partition_field, partition_value)];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let adapter = adapter.with_partition_values(partition_values);

        let column_expr = Arc::new(Column::new("partition_col", 0));
        let result = adapter.rewrite(column_expr)?;

        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
            assert_eq!(
                *literal.value(),
                ScalarValue::Utf8(Some("test_value".to_string()))
            );
        } else {
            panic!("Expected literal expression");
        }

        Ok(())
    }

702
703 #[test]
704 fn test_rewrite_no_change_needed() -> Result<()> {
705 let (physical_schema, logical_schema) = create_test_schema();
706
707 let factory = DefaultPhysicalExprAdapterFactory;
708 let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
709 let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
710
711 let result = adapter.rewrite(Arc::clone(&column_expr))?;
712
713 assert!(std::ptr::eq(
716 column_expr.as_ref() as *const dyn PhysicalExpr,
717 result.as_ref() as *const dyn PhysicalExpr
718 ));
719
720 Ok(())
721 }
722
    #[test]
    fn test_non_nullable_missing_column_error() {
        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema));
        let column_expr = Arc::new(Column::new("b", 1));

        let result = adapter.rewrite(column_expr);
        assert!(result.is_err());
        assert_contains!(
            result.unwrap_err().to_string(),
            "Non-nullable column 'b' is missing from the physical schema"
        );
    }

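    /// Test helper that evaluates each expression against `batch` and assembles the
    /// resulting arrays into a new `RecordBatch` with the given schema, mirroring how
    /// a projection would be applied.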
    fn batch_project(
        expr: Vec<Arc<dyn PhysicalExpr>>,
        batch: &RecordBatch,
        schema: SchemaRef,
    ) -> Result<RecordBatch> {
        let arrays = expr
            .iter()
            .map(|expr| {
                expr.evaluate(batch)
                    .and_then(|v| v.into_array(batch.num_rows()))
            })
            .collect::<Result<Vec<_>>>()?;

        if arrays.is_empty() {
            let options =
                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
                .map_err(Into::into)
        } else {
            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
        }
    }

    #[test]
    fn test_adapt_batches() {
        let physical_batch = record_batch!(
            ("a", Int32, vec![Some(1), None, Some(3)]),
            ("extra", Utf8, vec![Some("x"), Some("y"), None])
        )
        .unwrap();

        let physical_schema = physical_batch.schema();

        let logical_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]));

        let projection = vec![
            col("b", &logical_schema).unwrap(),
            col("a", &logical_schema).unwrap(),
        ];

        let factory = DefaultPhysicalExprAdapterFactory;
        let adapter =
            factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema));

        let adapted_projection = projection
            .into_iter()
            .map(|expr| adapter.rewrite(expr).unwrap())
            .collect_vec();

        let adapted_schema = Arc::new(Schema::new(
            adapted_projection
                .iter()
                .map(|expr| expr.return_field(&physical_schema).unwrap())
                .collect_vec(),
        ));

        let res = batch_project(
            adapted_projection,
            &physical_batch,
            Arc::clone(&adapted_schema),
        )
        .unwrap();

        // Column "b" is missing from the file and is filled with nulls; column "a" is
        // cast from Int32 to Int64.
        assert_eq!(res.num_columns(), 2);
        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
        assert_eq!(res.column(1).data_type(), &DataType::Int64);
        assert_eq!(
            res.column(0)
                .as_any()
                .downcast_ref::<arrow::array::StringArray>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![None, None, None]
        );
        assert_eq!(
            res.column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .unwrap()
                .iter()
                .collect_vec(),
            vec![Some(1), None, Some(3)]
        );
    }

    #[test]
    fn test_try_rewrite_struct_field_access() {
        let physical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![Field::new("existing_field", DataType::Int32, true)].into(),
            ),
            true,
        )]);

        let logical_schema = Schema::new(vec![Field::new(
            "struct_col",
            DataType::Struct(
                vec![
                    Field::new("existing_field", DataType::Int32, true),
                    Field::new("missing_field", DataType::Utf8, true),
                ]
                .into(),
            ),
            true,
        )]);

        let rewriter = DefaultPhysicalExprAdapterRewriter {
            logical_file_schema: &logical_schema,
            physical_file_schema: &physical_schema,
            partition_fields: &[],
        };

        // A plain column reference is not a `get_field` call, so no struct-field
        // rewrite should be applied.
        let column = Arc::new(Column::new("struct_col", 0)) as Arc<dyn PhysicalExpr>;
        let result = rewriter
            .try_rewrite_struct_field_access(&column, &HashSet::new())
            .unwrap();
        assert!(result.is_none());
    }
}