1use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
24use itertools::{izip, Itertools as _};
25
26use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
27
28use crate::analyzer::AnalyzerRule;
29use crate::utils::NamePreserver;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::Transformed;
32use datafusion_common::{
33 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
34 plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
35 TableReference,
36};
37use datafusion_expr::expr::{
38 self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39 InSubquery, Like, ScalarFunction, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::{
46 data_types_with_scalar_udf, fields_with_aggregate_udf,
47};
48use datafusion_expr::type_coercion::other::{
49 get_coerce_type_for_case_expression, get_coerce_type_for_list,
50};
51use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
52use datafusion_expr::utils::merge_schema;
53use datafusion_expr::{
54 is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
55 AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
56 ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
57};
58
/// Analyzer rule that inserts casts so that the operand types of the
/// expressions in a logical plan agree with what their operators, functions
/// and plan nodes expect. Registered under the name `"type_coercion"`.
#[derive(Default, Debug)]
pub struct TypeCoercion {}

impl TypeCoercion {
    /// Creates the rule. It carries no configuration of its own; behavior is
    /// driven by the `ConfigOptions` passed to `analyze`.
    pub fn new() -> Self {
        Self::default()
    }
}
69
70fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
72 if !config.optimizer.expand_views_at_output {
73 return Ok(plan);
74 }
75
76 let outer_refs = plan.expressions();
77 if outer_refs.is_empty() {
78 return Ok(plan);
79 }
80
81 if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
82 coerce_plan_expr_for_schema(plan, &dfschema?)
83 } else {
84 Ok(plan)
85 }
86}
87
88impl AnalyzerRule for TypeCoercion {
89 fn name(&self) -> &str {
90 "type_coercion"
91 }
92
93 fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
94 let empty_schema = DFSchema::empty();
95
96 let transformed_plan = plan
98 .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
99 .data;
100
101 coerce_output(transformed_plan, config)
103 }
104}
105
106fn analyze_internal(
110 external_schema: &DFSchema,
111 plan: LogicalPlan,
112) -> Result<Transformed<LogicalPlan>> {
113 let mut schema = merge_schema(&plan.inputs());
116
117 if let LogicalPlan::TableScan(ts) = &plan {
118 let source_schema = DFSchema::try_from_qualified_schema(
119 ts.table_name.clone(),
120 &ts.source.schema(),
121 )?;
122 schema.merge(&source_schema);
123 }
124
125 schema.merge(external_schema);
129
130 let plan = if let LogicalPlan::Filter(mut filter) = plan {
132 filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
133 LogicalPlan::Filter(filter)
134 } else {
135 plan
136 };
137
138 let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
139
140 let name_preserver = NamePreserver::new(&plan);
141 plan.map_expressions(|expr| {
143 let original_name = name_preserver.save(&expr);
144 expr.rewrite_with_schema(&schema, &mut expr_rewrite)
145 .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
146 })?
147 .map_data(|plan| expr_rewrite.coerce_plan(plan))?
149 .map_data(|plan| plan.recompute_schema())
151}
152
153pub struct TypeCoercionRewriter<'a> {
155 pub(crate) schema: &'a DFSchema,
156}
157
158impl<'a> TypeCoercionRewriter<'a> {
159 pub fn new(schema: &'a DFSchema) -> Self {
162 Self { schema }
163 }
164
165 pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
170 match plan {
171 LogicalPlan::Join(join) => self.coerce_join(join),
172 LogicalPlan::Union(union) => Self::coerce_union(union),
173 LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
174 _ => Ok(plan),
175 }
176 }
177
178 pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
187 join.on = join
188 .on
189 .into_iter()
190 .map(|(lhs, rhs)| {
191 let left_schema = join.left.schema();
194 let right_schema = join.right.schema();
195 let (lhs, rhs) = self.coerce_binary_op(
196 lhs,
197 left_schema,
198 Operator::Eq,
199 rhs,
200 right_schema,
201 )?;
202 Ok((lhs, rhs))
203 })
204 .collect::<Result<Vec<_>>>()?;
205
206 join.filter = join
208 .filter
209 .map(|expr| self.coerce_join_filter(expr))
210 .transpose()?;
211
212 Ok(LogicalPlan::Join(join))
213 }
214
215 pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
218 let union_schema = Arc::new(coerce_union_schema_with_schema(
219 &union_plan.inputs,
220 &union_plan.schema,
221 )?);
222 let new_inputs = union_plan
223 .inputs
224 .into_iter()
225 .map(|p| {
226 let plan =
227 coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
228 match plan {
229 LogicalPlan::Projection(Projection { expr, input, .. }) => {
230 Ok(Arc::new(project_with_column_index(
231 expr,
232 input,
233 Arc::clone(&union_schema),
234 )?))
235 }
236 other_plan => Ok(Arc::new(other_plan)),
237 }
238 })
239 .collect::<Result<Vec<_>>>()?;
240 Ok(LogicalPlan::Union(Union {
241 inputs: new_inputs,
242 schema: union_schema,
243 }))
244 }
245
246 fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
248 fn coerce_limit_expr(
249 expr: Expr,
250 schema: &DFSchema,
251 expr_name: &str,
252 ) -> Result<Expr> {
253 let dt = expr.get_type(schema)?;
254 if dt.is_integer() || dt.is_null() {
255 expr.cast_to(&DataType::Int64, schema)
256 } else {
257 plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
258 }
259 }
260
261 let empty_schema = DFSchema::empty();
262 let new_fetch = limit
263 .fetch
264 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
265 .transpose()?;
266 let new_skip = limit
267 .skip
268 .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
269 .transpose()?;
270 Ok(LogicalPlan::Limit(Limit {
271 input: limit.input,
272 fetch: new_fetch.map(Box::new),
273 skip: new_skip.map(Box::new),
274 }))
275 }
276
277 fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
278 let expr_type = expr.get_type(self.schema)?;
279 match expr_type {
280 DataType::Boolean => Ok(expr),
281 DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
282 other => plan_err!("Join condition must be boolean type, but got {other:?}"),
283 }
284 }
285
286 fn coerce_binary_op(
287 &self,
288 left: Expr,
289 left_schema: &DFSchema,
290 op: Operator,
291 right: Expr,
292 right_schema: &DFSchema,
293 ) -> Result<(Expr, Expr)> {
294 let (left_type, right_type) = BinaryTypeCoercer::new(
295 &left.get_type(left_schema)?,
296 &op,
297 &right.get_type(right_schema)?,
298 )
299 .get_input_types()?;
300
301 Ok((
302 left.cast_to(&left_type, left_schema)?,
303 right.cast_to(&right_type, right_schema)?,
304 ))
305 }
306}
307
308impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> {
309 type Node = Expr;
310 type Payload<'a> = &'a DFSchema;
311
312 fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result<Transformed<Expr>> {
313 match expr {
314 Expr::Unnest(_) => not_impl_err!(
315 "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
316 ),
317 Expr::ScalarSubquery(Subquery {
318 subquery,
319 outer_ref_columns,
320 spans,
321 }) => {
322 let new_plan =
323 analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data;
324 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
325 subquery: Arc::new(new_plan),
326 outer_ref_columns,
327 spans,
328 })))
329 }
330 Expr::Exists(Exists { subquery, negated }) => {
331 let new_plan = analyze_internal(
332 schema,
333 Arc::unwrap_or_clone(subquery.subquery),
334 )?
335 .data;
336 Ok(Transformed::yes(Expr::Exists(Exists {
337 subquery: Subquery {
338 subquery: Arc::new(new_plan),
339 outer_ref_columns: subquery.outer_ref_columns,
340 spans: subquery.spans,
341 },
342 negated,
343 })))
344 }
345 Expr::InSubquery(InSubquery {
346 expr,
347 subquery,
348 negated,
349 }) => {
350 let new_plan = analyze_internal(
351 schema,
352 Arc::unwrap_or_clone(subquery.subquery),
353 )?
354 .data;
355 let expr_type = expr.get_type(schema)?;
356 let subquery_type = new_plan.schema().field(0).data_type();
357 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
358 plan_datafusion_err!(
359 "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
360 ),
361 )?;
362 let new_subquery = Subquery {
363 subquery: Arc::new(new_plan),
364 outer_ref_columns: subquery.outer_ref_columns,
365 spans: subquery.spans,
366 };
367 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
368 Box::new(expr.cast_to(&common_type, schema)?),
369 cast_subquery(new_subquery, &common_type)?,
370 negated,
371 ))))
372 }
373 Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
374 *expr,
375 schema,
376 )?))),
377 Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
378 get_casted_expr_for_bool_op(*expr, schema)?,
379 ))),
380 Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
381 get_casted_expr_for_bool_op(*expr, schema)?,
382 ))),
383 Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
384 get_casted_expr_for_bool_op(*expr, schema)?,
385 ))),
386 Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
387 get_casted_expr_for_bool_op(*expr, schema)?,
388 ))),
389 Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
390 get_casted_expr_for_bool_op(*expr, schema)?,
391 ))),
392 Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
393 get_casted_expr_for_bool_op(*expr, schema)?,
394 ))),
395 Expr::Like(Like {
396 negated,
397 expr,
398 pattern,
399 escape_char,
400 case_insensitive,
401 }) => {
402 let left_type = expr.get_type(schema)?;
403 let right_type = pattern.get_type(schema)?;
404 let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
405 let op_name = if case_insensitive {
406 "ILIKE"
407 } else {
408 "LIKE"
409 };
410 plan_datafusion_err!(
411 "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
412 )
413 })?;
414 let expr = match left_type {
415 DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
416 _ => Box::new(expr.cast_to(&coerced_type, schema)?),
417 };
418 let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?);
419 Ok(Transformed::yes(Expr::Like(Like::new(
420 negated,
421 expr,
422 pattern,
423 escape_char,
424 case_insensitive,
425 ))))
426 }
427 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
428 let (left, right) =
429 self.coerce_binary_op(*left, schema, op, *right, schema)?;
430 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
431 Box::new(left),
432 op,
433 Box::new(right),
434 ))))
435 }
436 Expr::Between(Between {
437 expr,
438 negated,
439 low,
440 high,
441 }) => {
442 let expr_type = expr.get_type(schema)?;
443 let low_type = low.get_type(schema)?;
444 let low_coerced_type = comparison_coercion(&expr_type, &low_type)
445 .ok_or_else(|| {
446 internal_datafusion_err!(
447 "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
448 )
449 })?;
450 let high_type = high.get_type(schema)?;
451 let high_coerced_type = comparison_coercion(&expr_type, &high_type)
452 .ok_or_else(|| {
453 internal_datafusion_err!(
454 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
455 )
456 })?;
457 let coercion_type =
458 comparison_coercion(&low_coerced_type, &high_coerced_type)
459 .ok_or_else(|| {
460 internal_datafusion_err!(
461 "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
462 )
463 })?;
464 Ok(Transformed::yes(Expr::Between(Between::new(
465 Box::new(expr.cast_to(&coercion_type, schema)?),
466 negated,
467 Box::new(low.cast_to(&coercion_type, schema)?),
468 Box::new(high.cast_to(&coercion_type, schema)?),
469 ))))
470 }
471 Expr::InList(InList {
472 expr,
473 list,
474 negated,
475 }) => {
476 let expr_data_type = expr.get_type(schema)?;
477 let list_data_types = list
478 .iter()
479 .map(|list_expr| list_expr.get_type(schema))
480 .collect::<Result<Vec<_>>>()?;
481 let result_type =
482 get_coerce_type_for_list(&expr_data_type, &list_data_types);
483 match result_type {
484 None => plan_err!(
485 "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ")
486 ),
487 Some(coerced_type) => {
488 let cast_expr = expr.cast_to(&coerced_type, schema)?;
490 let cast_list_expr = list
491 .into_iter()
492 .map(|list_expr| {
493 list_expr.cast_to(&coerced_type, schema)
494 })
495 .collect::<Result<Vec<_>>>()?;
496 Ok(Transformed::yes(Expr::InList(InList ::new(
497 Box::new(cast_expr),
498 cast_list_expr,
499 negated,
500 ))))
501 }
502 }
503 }
504 Expr::Case(case) => {
505 let case = coerce_case_expression(case, schema)?;
506 Ok(Transformed::yes(Expr::Case(case)))
507 }
508 Expr::ScalarFunction(ScalarFunction { func, args }) => {
509 let new_expr = coerce_arguments_for_signature_with_scalar_udf(
510 args,
511 schema,
512 &func,
513 )?;
514 Ok(Transformed::yes(Expr::ScalarFunction(
515 ScalarFunction::new_udf(func, new_expr),
516 )))
517 }
518 Expr::AggregateFunction(expr::AggregateFunction {
519 func,
520 params:
521 AggregateFunctionParams {
522 args,
523 distinct,
524 filter,
525 order_by,
526 null_treatment,
527 },
528 }) => {
529 let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
530 args,
531 schema,
532 &func,
533 )?;
534 Ok(Transformed::yes(Expr::AggregateFunction(
535 expr::AggregateFunction::new_udf(
536 func,
537 new_expr,
538 distinct,
539 filter,
540 order_by,
541 null_treatment,
542 ),
543 )))
544 }
545 Expr::WindowFunction(window_fun) => {
546 let WindowFunction {
547 fun,
548 params:
549 expr::WindowFunctionParams {
550 args,
551 partition_by,
552 order_by,
553 window_frame,
554 filter,
555 null_treatment,
556 distinct,
557 },
558 } = *window_fun;
559 let window_frame =
560 coerce_window_frame(window_frame, schema, &order_by)?;
561
562 let args = match &fun {
563 expr::WindowFunctionDefinition::AggregateUDF(udf) => {
564 coerce_arguments_for_signature_with_aggregate_udf(
565 args,
566 schema,
567 udf,
568 )?
569 }
570 _ => args,
571 };
572
573 let new_expr = Expr::from(WindowFunction {
574 fun,
575 params: expr::WindowFunctionParams {
576 args,
577 partition_by,
578 order_by,
579 window_frame,
580 filter,
581 null_treatment,
582 distinct,
583 },
584 });
585 Ok(Transformed::yes(new_expr))
586 }
587 #[expect(deprecated)]
589 Expr::Alias(_)
590 | Expr::Column(_)
591 | Expr::ScalarVariable(_, _)
592 | Expr::Literal(_, _)
593 | Expr::SimilarTo(_)
594 | Expr::IsNotNull(_)
595 | Expr::IsNull(_)
596 | Expr::Negative(_)
597 | Expr::Cast(_)
598 | Expr::TryCast(_)
599 | Expr::Wildcard { .. }
600 | Expr::GroupingSet(_)
601 | Expr::Placeholder(_)
602 | Expr::OuterReferenceColumn(_, _)
603 | Expr::Lambda { .. } => Ok(Transformed::no(expr)),
604 }
605 }
606}
607
608fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
610 let metadata = dfschema.as_arrow().metadata.clone();
611 let mut transformed = false;
612
613 let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
614 dfschema
615 .iter()
616 .map(|(qualifier, field)| match field.data_type() {
617 DataType::Utf8View => {
618 transformed = true;
619 (
620 qualifier.cloned() as Option<TableReference>,
621 Arc::new(Field::new(
622 field.name(),
623 DataType::LargeUtf8,
624 field.is_nullable(),
625 )),
626 )
627 }
628 DataType::BinaryView => {
629 transformed = true;
630 (
631 qualifier.cloned() as Option<TableReference>,
632 Arc::new(Field::new(
633 field.name(),
634 DataType::LargeBinary,
635 field.is_nullable(),
636 )),
637 )
638 }
639 _ => (
640 qualifier.cloned() as Option<TableReference>,
641 Arc::clone(field),
642 ),
643 })
644 .unzip();
645
646 if !transformed {
647 return None;
648 }
649
650 let schema = Schema::new_with_metadata(transformed_fields, metadata);
651 Some(DFSchema::from_field_specific_qualified_schema(
652 qualifiers,
653 &Arc::new(schema),
654 ))
655}
656
657fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
660 match value {
661 ScalarValue::Utf8(Some(val)) => {
663 ScalarValue::try_from_string(val.clone(), target_type)
664 }
665 s => {
666 if s.is_null() {
667 ScalarValue::try_from(target_type)
669 } else {
670 Ok(s.clone())
674 }
675 }
676 }
677}
678
679fn coerce_scalar_range_aware(
686 target_type: &DataType,
687 value: &ScalarValue,
688) -> Result<ScalarValue> {
689 coerce_scalar(target_type, value).or_else(|err| {
690 if let Some(largest_type) = get_widest_type_in_family(target_type) {
692 coerce_scalar(largest_type, value).map_or_else(
693 |_| exec_err!("Cannot cast {value:?} to {target_type}"),
694 |_| ScalarValue::try_from(target_type),
695 )
696 } else {
697 Err(err)
698 }
699 })
700}
701
702fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
706 match given_type {
707 DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
708 DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
709 DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
710 _ => None,
711 }
712}
713
714fn coerce_frame_bound(
716 target_type: &DataType,
717 bound: WindowFrameBound,
718) -> Result<WindowFrameBound> {
719 match bound {
720 WindowFrameBound::Preceding(v) => {
721 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
722 }
723 WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
724 WindowFrameBound::Following(v) => {
725 coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
726 }
727 }
728}
729
730fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
731 if col_type.is_numeric()
732 || is_utf8_or_utf8view_or_large_utf8(col_type)
733 || matches!(col_type, DataType::List(_))
734 || matches!(col_type, DataType::LargeList(_))
735 || matches!(col_type, DataType::FixedSizeList(_, _))
736 || matches!(col_type, DataType::Null)
737 || matches!(col_type, DataType::Boolean)
738 {
739 Ok(col_type.clone())
740 } else if is_datetime(col_type) {
741 Ok(DataType::Interval(IntervalUnit::MonthDayNano))
742 } else if let DataType::Dictionary(_, value_type) = col_type {
743 extract_window_frame_target_type(value_type)
744 } else {
745 internal_err!("Cannot run range queries on datatype: {col_type}")
746 }
747}
748
749fn coerce_window_frame(
752 window_frame: WindowFrame,
753 schema: &DFSchema,
754 expressions: &[Sort],
755) -> Result<WindowFrame> {
756 let mut window_frame = window_frame;
757 let target_type = match window_frame.units {
758 WindowFrameUnits::Range => {
759 let current_types = expressions
760 .first()
761 .map(|s| s.expr.get_type(schema))
762 .transpose()?;
763 if let Some(col_type) = current_types {
764 extract_window_frame_target_type(&col_type)?
765 } else {
766 return internal_err!("ORDER BY column cannot be empty");
767 }
768 }
769 WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
770 };
771 window_frame.start_bound =
772 coerce_frame_bound(&target_type, window_frame.start_bound)?;
773 window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
774 Ok(window_frame)
775}
776
777fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
780 let left_type = expr.get_type(schema)?;
781 BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
782 .get_input_types()?;
783 expr.cast_to(&DataType::Boolean, schema)
784}
785
786fn coerce_arguments_for_signature_with_scalar_udf(
791 expressions: Vec<Expr>,
792 schema: &DFSchema,
793 func: &ScalarUDF,
794) -> Result<Vec<Expr>> {
795 if expressions.is_empty() {
796 return Ok(expressions);
797 }
798
799 let current_types = expressions.iter()
800 .map(|e| match e {
801 Expr::Lambda { .. } => Ok(DataType::Null),
802 _ => e.get_type(schema),
803 })
804 .collect::<Result<Vec<_>>>()?;
805
806 let new_types = data_types_with_scalar_udf(¤t_types, func)?;
807
808 expressions
809 .into_iter()
810 .enumerate()
811 .map(|(i, expr)| match expr {
812 lambda @ Expr::Lambda { .. } => Ok(lambda),
813 _ => expr.cast_to(&new_types[i], schema),
814 })
815 .collect()
816}
817
818fn coerce_arguments_for_signature_with_aggregate_udf(
823 expressions: Vec<Expr>,
824 schema: &DFSchema,
825 func: &AggregateUDF,
826) -> Result<Vec<Expr>> {
827 if expressions.is_empty() {
828 return Ok(expressions);
829 }
830
831 let current_fields = expressions
832 .iter()
833 .map(|e| e.to_field(schema).map(|(_, f)| f))
834 .collect::<Result<Vec<_>>>()?;
835
836 let new_types = fields_with_aggregate_udf(¤t_fields, func)?
837 .into_iter()
838 .map(|f| f.data_type().clone())
839 .collect::<Vec<_>>();
840
841 expressions
842 .into_iter()
843 .enumerate()
844 .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
845 .collect()
846}
847
848fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
849 let case_type = case
881 .expr
882 .as_ref()
883 .map(|expr| expr.get_type(schema))
884 .transpose()?;
885 let then_types = case
886 .when_then_expr
887 .iter()
888 .map(|(_when, then)| then.get_type(schema))
889 .collect::<Result<Vec<_>>>()?;
890 let else_type = case
891 .else_expr
892 .as_ref()
893 .map(|expr| expr.get_type(schema))
894 .transpose()?;
895
896 let case_when_coerce_type = case_type
898 .as_ref()
899 .map(|case_type| {
900 let when_types = case
901 .when_then_expr
902 .iter()
903 .map(|(when, _then)| when.get_type(schema))
904 .collect::<Result<Vec<_>>>()?;
905 let coerced_type =
906 get_coerce_type_for_case_expression(&when_types, Some(case_type));
907 coerced_type.ok_or_else(|| {
908 plan_datafusion_err!(
909 "Failed to coerce case ({case_type}) and when ({}) \
910 to common types in CASE WHEN expression",
911 when_types.iter().join(", ")
912 )
913 })
914 })
915 .transpose()?;
916 let then_else_coerce_type =
917 get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
918 || {
919 if let Some(else_type) = else_type {
920 plan_datafusion_err!(
921 "Failed to coerce then ({}) and else ({else_type}) \
922 to common types in CASE WHEN expression",
923 then_types.iter().join(", ")
924 )
925 } else {
926 plan_datafusion_err!(
927 "Failed to coerce then ({}) and else (None) \
928 to common types in CASE WHEN expression",
929 then_types.iter().join(", ")
930 )
931 }
932 },
933 )?;
934
935 let case_expr = case
937 .expr
938 .zip(case_when_coerce_type.as_ref())
939 .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
940 .transpose()?
941 .map(Box::new);
942 let when_then = case
943 .when_then_expr
944 .into_iter()
945 .map(|(when, then)| {
946 let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
947 let when = when.cast_to(when_type, schema).map_err(|e| {
948 DataFusionError::Context(
949 format!(
950 "WHEN expressions in CASE couldn't be \
951 converted to common type ({when_type})"
952 ),
953 Box::new(e),
954 )
955 })?;
956 let then = then.cast_to(&then_else_coerce_type, schema)?;
957 Ok((Box::new(when), Box::new(then)))
958 })
959 .collect::<Result<Vec<_>>>()?;
960 let else_expr = case
961 .else_expr
962 .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
963 .transpose()?
964 .map(Box::new);
965
966 Ok(Case::new(case_expr, when_then, else_expr))
967}
968
969pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1011 coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1012}
1013fn coerce_union_schema_with_schema(
1014 inputs: &[Arc<LogicalPlan>],
1015 base_schema: &DFSchemaRef,
1016) -> Result<DFSchema> {
1017 let mut union_datatypes = base_schema
1018 .fields()
1019 .iter()
1020 .map(|f| f.data_type().clone())
1021 .collect::<Vec<_>>();
1022 let mut union_nullabilities = base_schema
1023 .fields()
1024 .iter()
1025 .map(|f| f.is_nullable())
1026 .collect::<Vec<_>>();
1027 let mut union_field_meta = base_schema
1028 .fields()
1029 .iter()
1030 .map(|f| f.metadata().clone())
1031 .collect::<Vec<_>>();
1032
1033 let mut metadata = base_schema.metadata().clone();
1034
1035 for (i, plan) in inputs.iter().enumerate() {
1036 let plan_schema = plan.schema();
1037 metadata.extend(plan_schema.metadata().clone());
1038
1039 if plan_schema.fields().len() != base_schema.fields().len() {
1040 return plan_err!(
1041 "Union schemas have different number of fields: \
1042 query 1 has {} fields whereas query {} has {} fields",
1043 base_schema.fields().len(),
1044 i + 1,
1045 plan_schema.fields().len()
1046 );
1047 }
1048
1049 for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1051 union_datatypes.iter_mut(),
1052 union_nullabilities.iter_mut(),
1053 union_field_meta.iter_mut(),
1054 plan_schema.fields().iter()
1055 ) {
1056 let coerced_type =
1057 comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1058 || {
1059 plan_datafusion_err!(
1060 "Incompatible inputs for Union: Previous inputs were \
1061 of type {}, but got incompatible type {} on column '{}'",
1062 union_datatype,
1063 plan_field.data_type(),
1064 plan_field.name()
1065 )
1066 },
1067 )?;
1068
1069 *union_datatype = coerced_type;
1070 *union_nullable = *union_nullable || plan_field.is_nullable();
1071 union_field_map.extend(plan_field.metadata().clone());
1072 }
1073 }
1074 let union_qualified_fields = izip!(
1075 base_schema.fields(),
1076 union_datatypes.into_iter(),
1077 union_nullabilities,
1078 union_field_meta.into_iter()
1079 )
1080 .map(|(field, datatype, nullable, metadata)| {
1081 let mut field = Field::new(field.name().clone(), datatype, nullable);
1082 field.set_metadata(metadata);
1083 (None, field.into())
1084 })
1085 .collect::<Vec<_>>();
1086
1087 DFSchema::new_with_metadata(union_qualified_fields, metadata)
1088}
1089
1090fn project_with_column_index(
1092 expr: Vec<Expr>,
1093 input: Arc<LogicalPlan>,
1094 schema: DFSchemaRef,
1095) -> Result<LogicalPlan> {
1096 let alias_expr = expr
1097 .into_iter()
1098 .enumerate()
1099 .map(|(i, e)| match e {
1100 Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1101 Ok(e.unalias().alias(schema.field(i).name()))
1102 }
1103 Expr::Column(Column {
1104 relation: _,
1105 ref name,
1106 spans: _,
1107 }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1108 Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1109 #[expect(deprecated)]
1110 Expr::Wildcard { .. } => {
1111 plan_err!("Wildcard should be expanded before type coercion")
1112 }
1113 _ => Ok(e.alias(schema.field(i).name())),
1114 })
1115 .collect::<Result<Vec<_>>>()?;
1116
1117 Projection::try_new_with_schema(alias_expr, input, schema)
1118 .map(LogicalPlan::Projection)
1119}
1120
1121#[cfg(test)]
1122mod test {
1123 use std::any::Any;
1124 use std::sync::Arc;
1125
1126 use arrow::datatypes::DataType::Utf8;
1127 use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1128 use insta::assert_snapshot;
1129
1130 use crate::analyzer::type_coercion::{
1131 coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1132 };
1133 use crate::analyzer::Analyzer;
1134 use crate::assert_analyzed_plan_with_config_eq_snapshot;
1135 use datafusion_common::config::ConfigOptions;
1136 use datafusion_common::tree_node::{TransformedResult};
1137 use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1138 use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1139 use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1140 use datafusion_expr::test::function_stub::avg_udaf;
1141 use datafusion_expr::{
1142 cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1143 BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1144 Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1145 SimpleAggregateUDF, Subquery, Union, Volatility,
1146 };
1147 use datafusion_functions_aggregate::average::AvgAccumulator;
1148 use datafusion_sql::TableReference;
1149
1150 fn empty() -> Arc<LogicalPlan> {
1151 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1152 produce_one_row: false,
1153 schema: Arc::new(DFSchema::empty()),
1154 }))
1155 }
1156
1157 fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1158 Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1159 produce_one_row: false,
1160 schema: Arc::new(
1161 DFSchema::from_unqualified_fields(
1162 vec![Field::new("a", data_type, true)].into(),
1163 std::collections::HashMap::new(),
1164 )
1165 .unwrap(),
1166 ),
1167 }))
1168 }
1169
1170 macro_rules! assert_analyzed_plan_eq {
1171 (
1172 $plan: expr,
1173 @ $expected: literal $(,)?
1174 ) => {{
1175 let options = ConfigOptions::default();
1176 let rule = Arc::new(TypeCoercion::new());
1177 assert_analyzed_plan_with_config_eq_snapshot!(
1178 options,
1179 rule,
1180 $plan,
1181 @ $expected,
1182 )
1183 }};
1184 }
1185
1186 macro_rules! coerce_on_output_if_viewtype {
1187 (
1188 $is_viewtype: expr,
1189 $plan: expr,
1190 @ $expected: literal $(,)?
1191 ) => {{
1192 let mut options = ConfigOptions::default();
1193 if $is_viewtype {options.optimizer.expand_views_at_output = true;}
1195 let rule = Arc::new(TypeCoercion::new());
1196
1197 assert_analyzed_plan_with_config_eq_snapshot!(
1198 options,
1199 rule,
1200 $plan,
1201 @ $expected,
1202 )
1203 }};
1204 }
1205
1206 fn assert_type_coercion_error(
1207 plan: LogicalPlan,
1208 expected_substr: &str,
1209 ) -> Result<()> {
1210 let options = ConfigOptions::default();
1211 let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1212
1213 match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1214 Ok(succeeded_plan) => {
1215 panic!(
1216 "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1217 );
1218 }
1219 Err(e) => {
1220 let msg = e.to_string();
1221 assert!(
1222 msg.contains(expected_substr),
1223 "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`"
1224 );
1225 }
1226 }
1227
1228 Ok(())
1229 }
1230
1231 #[test]
1232 fn simple_case() -> Result<()> {
1233 let expr = col("a").lt(lit(2_u32));
1234 let empty = empty_with_type(DataType::Float64);
1235 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1236
1237 assert_analyzed_plan_eq!(
1238 plan,
1239 @r"
1240 Projection: a < CAST(UInt32(2) AS Float64)
1241 EmptyRelation: rows=0
1242 "
1243 )
1244 }
1245
1246 #[test]
1247 fn test_coerce_union() -> Result<()> {
1248 let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1249 produce_one_row: false,
1250 schema: Arc::new(
1251 DFSchema::try_from_qualified_schema(
1252 TableReference::full("datafusion", "test", "foo"),
1253 &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1254 )
1255 .unwrap(),
1256 ),
1257 }));
1258 let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1259 produce_one_row: false,
1260 schema: Arc::new(
1261 DFSchema::try_from_qualified_schema(
1262 TableReference::full("datafusion", "test", "foo"),
1263 &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1264 )
1265 .unwrap(),
1266 ),
1267 }));
1268 let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1269 left_plan, right_plan,
1270 ])?);
1271 let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1272 .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1273 let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1274 vec![col("a")],
1275 Arc::new(analyzed_union),
1276 )?);
1277
1278 assert_analyzed_plan_eq!(
1279 top_level_plan,
1280 @r"
1281 Projection: a
1282 Union
1283 Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1284 EmptyRelation: rows=0
1285 EmptyRelation: rows=0
1286 "
1287 )
1288 }
1289
1290 #[test]
1291 fn coerce_utf8view_output() -> Result<()> {
1292 let expr = col("a");
1295 let empty = empty_with_type(DataType::Utf8View);
1296 let plan = LogicalPlan::Projection(Projection::try_new(
1297 vec![expr.clone()],
1298 Arc::clone(&empty),
1299 )?);
1300
1301 coerce_on_output_if_viewtype!(
1303 false,
1304 plan.clone(),
1305 @r"
1306 Projection: a
1307 EmptyRelation: rows=0
1308 "
1309 )?;
1310
1311 coerce_on_output_if_viewtype!(
1313 true,
1314 plan.clone(),
1315 @r"
1316 Projection: CAST(a AS LargeUtf8)
1317 EmptyRelation: rows=0
1318 "
1319 )?;
1320
1321 let bool_expr = col("a").lt(lit("foo"));
1324 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1325 vec![bool_expr],
1326 Arc::clone(&empty),
1327 )?);
1328 coerce_on_output_if_viewtype!(
1330 false,
1331 bool_plan.clone(),
1332 @r#"
1333 Projection: a < CAST(Utf8("foo") AS Utf8View)
1334 EmptyRelation: rows=0
1335 "#
1336 )?;
1337
1338 coerce_on_output_if_viewtype!(
1339 false,
1340 plan.clone(),
1341 @r"
1342 Projection: a
1343 EmptyRelation: rows=0
1344 "
1345 )?;
1346
1347 coerce_on_output_if_viewtype!(
1349 true,
1350 plan.clone(),
1351 @r"
1352 Projection: CAST(a AS LargeUtf8)
1353 EmptyRelation: rows=0
1354 "
1355 )?;
1356
1357 let sort_expr = expr.sort(true, true);
1360 let sort_plan = LogicalPlan::Sort(Sort {
1361 expr: vec![sort_expr],
1362 input: Arc::new(plan),
1363 fetch: None,
1364 });
1365
1366 coerce_on_output_if_viewtype!(
1368 false,
1369 sort_plan.clone(),
1370 @r"
1371 Sort: a ASC NULLS FIRST
1372 Projection: a
1373 EmptyRelation: rows=0
1374 "
1375 )?;
1376
1377 coerce_on_output_if_viewtype!(
1379 true,
1380 sort_plan.clone(),
1381 @r"
1382 Projection: CAST(a AS LargeUtf8)
1383 Sort: a ASC NULLS FIRST
1384 Projection: a
1385 EmptyRelation: rows=0
1386 "
1387 )?;
1388
1389 let plan = LogicalPlan::Projection(Projection::try_new(
1392 vec![col("a")],
1393 Arc::new(sort_plan),
1394 )?);
1395 coerce_on_output_if_viewtype!(
1397 false,
1398 plan.clone(),
1399 @r"
1400 Projection: a
1401 Sort: a ASC NULLS FIRST
1402 Projection: a
1403 EmptyRelation: rows=0
1404 "
1405 )?;
1406 coerce_on_output_if_viewtype!(
1408 true,
1409 plan.clone(),
1410 @r"
1411 Projection: CAST(a AS LargeUtf8)
1412 Sort: a ASC NULLS FIRST
1413 Projection: a
1414 EmptyRelation: rows=0
1415 "
1416 )?;
1417
1418 Ok(())
1419 }
1420
1421 #[test]
1422 fn coerce_binaryview_output() -> Result<()> {
1423 let expr = col("a");
1426 let empty = empty_with_type(DataType::BinaryView);
1427 let plan = LogicalPlan::Projection(Projection::try_new(
1428 vec![expr.clone()],
1429 Arc::clone(&empty),
1430 )?);
1431
1432 coerce_on_output_if_viewtype!(
1434 false,
1435 plan.clone(),
1436 @r"
1437 Projection: a
1438 EmptyRelation: rows=0
1439 "
1440 )?;
1441
1442 coerce_on_output_if_viewtype!(
1444 true,
1445 plan.clone(),
1446 @r"
1447 Projection: CAST(a AS LargeBinary)
1448 EmptyRelation: rows=0
1449 "
1450 )?;
1451
1452 let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
1455 let bool_plan = LogicalPlan::Projection(Projection::try_new(
1456 vec![bool_expr],
1457 Arc::clone(&empty),
1458 )?);
1459
1460 coerce_on_output_if_viewtype!(
1462 false,
1463 bool_plan.clone(),
1464 @r#"
1465 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1466 EmptyRelation: rows=0
1467 "#
1468 )?;
1469
1470 coerce_on_output_if_viewtype!(
1472 true,
1473 bool_plan.clone(),
1474 @r#"
1475 Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
1476 EmptyRelation: rows=0
1477 "#
1478 )?;
1479
1480 let sort_expr = expr.sort(true, true);
1483 let sort_plan = LogicalPlan::Sort(Sort {
1484 expr: vec![sort_expr],
1485 input: Arc::new(plan),
1486 fetch: None,
1487 });
1488
1489 coerce_on_output_if_viewtype!(
1491 false,
1492 sort_plan.clone(),
1493 @r"
1494 Sort: a ASC NULLS FIRST
1495 Projection: a
1496 EmptyRelation: rows=0
1497 "
1498 )?;
1499 coerce_on_output_if_viewtype!(
1501 true,
1502 sort_plan.clone(),
1503 @r"
1504 Projection: CAST(a AS LargeBinary)
1505 Sort: a ASC NULLS FIRST
1506 Projection: a
1507 EmptyRelation: rows=0
1508 "
1509 )?;
1510
1511 let plan = LogicalPlan::Projection(Projection::try_new(
1514 vec![col("a")],
1515 Arc::new(sort_plan),
1516 )?);
1517
1518 coerce_on_output_if_viewtype!(
1520 false,
1521 plan.clone(),
1522 @r"
1523 Projection: a
1524 Sort: a ASC NULLS FIRST
1525 Projection: a
1526 EmptyRelation: rows=0
1527 "
1528 )?;
1529
1530 coerce_on_output_if_viewtype!(
1532 true,
1533 plan.clone(),
1534 @r"
1535 Projection: CAST(a AS LargeBinary)
1536 Sort: a ASC NULLS FIRST
1537 Projection: a
1538 EmptyRelation: rows=0
1539 "
1540 )?;
1541
1542 Ok(())
1543 }
1544
1545 #[test]
1546 fn nested_case() -> Result<()> {
1547 let expr = col("a").lt(lit(2_u32));
1548 let empty = empty_with_type(DataType::Float64);
1549
1550 let plan = LogicalPlan::Projection(Projection::try_new(
1551 vec![expr.clone().or(expr)],
1552 empty,
1553 )?);
1554
1555 assert_analyzed_plan_eq!(
1556 plan,
1557 @r"
1558 Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1559 EmptyRelation: rows=0
1560 "
1561 )
1562 }
1563
1564 #[derive(Debug, PartialEq, Eq, Hash)]
1565 struct TestScalarUDF {
1566 signature: Signature,
1567 }
1568
1569 impl ScalarUDFImpl for TestScalarUDF {
1570 fn as_any(&self) -> &dyn Any {
1571 self
1572 }
1573
1574 fn name(&self) -> &str {
1575 "TestScalarUDF"
1576 }
1577
1578 fn signature(&self) -> &Signature {
1579 &self.signature
1580 }
1581
1582 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1583 Ok(Utf8)
1584 }
1585
1586 fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1587 Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1588 }
1589 }
1590
1591 #[test]
1592 fn scalar_udf() -> Result<()> {
1593 let empty = empty();
1594
1595 let udf = ScalarUDF::from(TestScalarUDF {
1596 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1597 })
1598 .call(vec![lit(123_i32)]);
1599 let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1600
1601 assert_analyzed_plan_eq!(
1602 plan,
1603 @r"
1604 Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1605 EmptyRelation: rows=0
1606 "
1607 )
1608 }
1609
1610 #[test]
1611 fn scalar_udf_invalid_input() -> Result<()> {
1612 let empty = empty();
1613 let udf = ScalarUDF::from(TestScalarUDF {
1614 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1615 })
1616 .call(vec![lit("Apple")]);
1617 Projection::try_new(vec![udf], empty)
1618 .expect_err("Expected an error due to incorrect function input");
1619
1620 Ok(())
1621 }
1622
1623 #[test]
1624 fn scalar_function() -> Result<()> {
1625 let empty = empty();
1627 let lit_expr = lit(10i64);
1628 let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1629 signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1630 });
1631 let scalar_function_expr =
1632 Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1633 let plan = LogicalPlan::Projection(Projection::try_new(
1634 vec![scalar_function_expr],
1635 empty,
1636 )?);
1637
1638 assert_analyzed_plan_eq!(
1639 plan,
1640 @r"
1641 Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1642 EmptyRelation: rows=0
1643 "
1644 )
1645 }
1646
1647 #[test]
1648 fn agg_udaf() -> Result<()> {
1649 let empty = empty();
1650 let my_avg = create_udaf(
1651 "MY_AVG",
1652 vec![DataType::Float64],
1653 Arc::new(DataType::Float64),
1654 Volatility::Immutable,
1655 Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1656 Arc::new(vec![DataType::UInt64, DataType::Float64]),
1657 );
1658 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1659 Arc::new(my_avg),
1660 vec![lit(10i64)],
1661 false,
1662 None,
1663 vec![],
1664 None,
1665 ));
1666 let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1667
1668 assert_analyzed_plan_eq!(
1669 plan,
1670 @r"
1671 Projection: MY_AVG(CAST(Int64(10) AS Float64))
1672 EmptyRelation: rows=0
1673 "
1674 )
1675 }
1676
1677 #[test]
1678 fn agg_udaf_invalid_input() -> Result<()> {
1679 let empty = empty();
1680 let return_type = DataType::Float64;
1681 let accumulator: AccumulatorFactoryFunction =
1682 Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1683 let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1684 "MY_AVG",
1685 Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1686 return_type,
1687 accumulator,
1688 vec![
1689 Field::new("count", DataType::UInt64, true).into(),
1690 Field::new("avg", DataType::Float64, true).into(),
1691 ],
1692 ));
1693 let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1694 Arc::new(my_avg),
1695 vec![lit("10")],
1696 false,
1697 None,
1698 vec![],
1699 None,
1700 ));
1701
1702 let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1703 assert!(
1704 err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1705 );
1706 Ok(())
1707 }
1708
1709 #[test]
1710 fn agg_function_case() -> Result<()> {
1711 let empty = empty();
1712 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1713 avg_udaf(),
1714 vec![lit(12f64)],
1715 false,
1716 None,
1717 vec![],
1718 None,
1719 ));
1720 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1721
1722 assert_analyzed_plan_eq!(
1723 plan,
1724 @r"
1725 Projection: avg(Float64(12))
1726 EmptyRelation: rows=0
1727 "
1728 )?;
1729
1730 let empty = empty_with_type(DataType::Int32);
1731 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1732 avg_udaf(),
1733 vec![cast(col("a"), DataType::Float64)],
1734 false,
1735 None,
1736 vec![],
1737 None,
1738 ));
1739 let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1740
1741 assert_analyzed_plan_eq!(
1742 plan,
1743 @r"
1744 Projection: avg(CAST(a AS Float64))
1745 EmptyRelation: rows=0
1746 "
1747 )
1748 }
1749
1750 #[test]
1751 fn agg_function_invalid_input_avg() -> Result<()> {
1752 let empty = empty();
1753 let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1754 avg_udaf(),
1755 vec![lit("1")],
1756 false,
1757 None,
1758 vec![],
1759 None,
1760 ));
1761 let err = Projection::try_new(vec![agg_expr], empty)
1762 .err()
1763 .unwrap()
1764 .strip_backtrace();
1765 assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1766 Ok(())
1767 }
1768
1769 #[test]
1770 fn binary_op_date32_op_interval() -> Result<()> {
1771 let expr = cast(lit("1998-03-18"), DataType::Date32)
1773 + lit(ScalarValue::new_interval_dt(123, 456));
1774 let empty = empty();
1775 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1776
1777 assert_analyzed_plan_eq!(
1778 plan,
1779 @r#"
1780 Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1781 EmptyRelation: rows=0
1782 "#
1783 )
1784 }
1785
1786 #[test]
1787 fn inlist_case() -> Result<()> {
1788 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1790 let empty = empty_with_type(DataType::Int64);
1791 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1792 assert_analyzed_plan_eq!(
1793 plan,
1794 @r"
1795 Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1796 EmptyRelation: rows=0
1797 ")?;
1798
1799 let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1801 let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1802 produce_one_row: false,
1803 schema: Arc::new(DFSchema::from_unqualified_fields(
1804 vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1805 std::collections::HashMap::new(),
1806 )?),
1807 }));
1808 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1809 assert_analyzed_plan_eq!(
1810 plan,
1811 @r"
1812 Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1813 EmptyRelation: rows=0
1814 ")
1815 }
1816
1817 #[test]
1818 fn between_case() -> Result<()> {
1819 let expr = col("a").between(
1820 lit("2002-05-08"),
1821 cast(lit("2002-05-08"), DataType::Date32)
1823 + lit(ScalarValue::new_interval_ym(0, 1)),
1824 );
1825 let empty = empty_with_type(Utf8);
1826 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1827
1828 assert_analyzed_plan_eq!(
1829 plan,
1830 @r#"
1831 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1832 EmptyRelation: rows=0
1833 "#
1834 )
1835 }
1836
1837 #[test]
1838 fn between_infer_cheap_type() -> Result<()> {
1839 let expr = col("a").between(
1840 cast(lit("2002-05-08"), DataType::Date32)
1842 + lit(ScalarValue::new_interval_ym(0, 1)),
1843 lit("2002-12-08"),
1844 );
1845 let empty = empty_with_type(Utf8);
1846 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1847
1848 assert_analyzed_plan_eq!(
1850 plan,
1851 @r#"
1852 Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1853 EmptyRelation: rows=0
1854 "#
1855 )
1856 }
1857
1858 #[test]
1859 fn between_null() -> Result<()> {
1860 let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1861 let empty = empty();
1862 let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1863
1864 assert_analyzed_plan_eq!(
1865 plan,
1866 @r"
1867 Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1868 EmptyRelation: rows=0
1869 "
1870 )
1871 }
1872
1873 #[test]
1874 fn is_bool_for_type_coercion() -> Result<()> {
1875 let expr = col("a").is_true();
1877 let empty = empty_with_type(DataType::Boolean);
1878 let plan =
1879 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1880
1881 assert_analyzed_plan_eq!(
1882 plan,
1883 @r"
1884 Projection: a IS TRUE
1885 EmptyRelation: rows=0
1886 "
1887 )?;
1888
1889 let empty = empty_with_type(DataType::Int64);
1890 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1891 assert_type_coercion_error(
1892 plan,
1893 "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
1894 )?;
1895
1896 let expr = col("a").is_not_true();
1898 let empty = empty_with_type(DataType::Boolean);
1899 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1900
1901 assert_analyzed_plan_eq!(
1902 plan,
1903 @r"
1904 Projection: a IS NOT TRUE
1905 EmptyRelation: rows=0
1906 "
1907 )?;
1908
1909 let expr = col("a").is_false();
1911 let empty = empty_with_type(DataType::Boolean);
1912 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1913
1914 assert_analyzed_plan_eq!(
1915 plan,
1916 @r"
1917 Projection: a IS FALSE
1918 EmptyRelation: rows=0
1919 "
1920 )?;
1921
1922 let expr = col("a").is_not_false();
1924 let empty = empty_with_type(DataType::Boolean);
1925 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1926
1927 assert_analyzed_plan_eq!(
1928 plan,
1929 @r"
1930 Projection: a IS NOT FALSE
1931 EmptyRelation: rows=0
1932 "
1933 )
1934 }
1935
1936 #[test]
1937 fn like_for_type_coercion() -> Result<()> {
1938 let expr = Box::new(col("a"));
1940 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1941 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1942 let empty = empty_with_type(Utf8);
1943 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1944
1945 assert_analyzed_plan_eq!(
1946 plan,
1947 @r#"
1948 Projection: a LIKE Utf8("abc")
1949 EmptyRelation: rows=0
1950 "#
1951 )?;
1952
1953 let expr = Box::new(col("a"));
1954 let pattern = Box::new(lit(ScalarValue::Null));
1955 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1956 let empty = empty_with_type(Utf8);
1957 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1958
1959 assert_analyzed_plan_eq!(
1960 plan,
1961 @r"
1962 Projection: a LIKE CAST(NULL AS Utf8)
1963 EmptyRelation: rows=0
1964 "
1965 )?;
1966
1967 let expr = Box::new(col("a"));
1968 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1969 let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1970 let empty = empty_with_type(DataType::Int64);
1971 let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1972 assert_type_coercion_error(
1973 plan,
1974 "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
1975 )?;
1976
1977 let expr = Box::new(col("a"));
1979 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1980 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1981 let empty = empty_with_type(Utf8);
1982 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1983
1984 assert_analyzed_plan_eq!(
1985 plan,
1986 @r#"
1987 Projection: a ILIKE Utf8("abc")
1988 EmptyRelation: rows=0
1989 "#
1990 )?;
1991
1992 let expr = Box::new(col("a"));
1993 let pattern = Box::new(lit(ScalarValue::Null));
1994 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1995 let empty = empty_with_type(Utf8);
1996 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1997
1998 assert_analyzed_plan_eq!(
1999 plan,
2000 @r"
2001 Projection: a ILIKE CAST(NULL AS Utf8)
2002 EmptyRelation: rows=0
2003 "
2004 )?;
2005
2006 let expr = Box::new(col("a"));
2007 let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2008 let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2009 let empty = empty_with_type(DataType::Int64);
2010 let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2011 assert_type_coercion_error(
2012 plan,
2013 "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2014 )?;
2015
2016 Ok(())
2017 }
2018
2019 #[test]
2020 fn unknown_for_type_coercion() -> Result<()> {
2021 let expr = col("a").is_unknown();
2023 let empty = empty_with_type(DataType::Boolean);
2024 let plan =
2025 LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2026
2027 assert_analyzed_plan_eq!(
2028 plan,
2029 @r"
2030 Projection: a IS UNKNOWN
2031 EmptyRelation: rows=0
2032 "
2033 )?;
2034
2035 let empty = empty_with_type(Utf8);
2036 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2037 assert_type_coercion_error(
2038 plan,
2039 "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
2040 )?;
2041
2042 let expr = col("a").is_not_unknown();
2044 let empty = empty_with_type(DataType::Boolean);
2045 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2046
2047 assert_analyzed_plan_eq!(
2048 plan,
2049 @r"
2050 Projection: a IS NOT UNKNOWN
2051 EmptyRelation: rows=0
2052 "
2053 )
2054 }
2055
2056 #[test]
2057 fn concat_for_type_coercion() -> Result<()> {
2058 let empty = empty_with_type(Utf8);
2059 let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2060
2061 let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2063 signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2064 })
2065 .call(args.to_vec());
2066 let plan =
2067 LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2068 assert_analyzed_plan_eq!(
2069 plan,
2070 @r#"
2071 Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2072 EmptyRelation: rows=0
2073 "#
2074 )
2075 }
2076
2077 #[test]
2078 fn test_type_coercion_rewrite() -> Result<()> {
2079 let schema = Arc::new(DFSchema::from_unqualified_fields(
2081 vec![Field::new("a", DataType::Int64, true)].into(),
2082 std::collections::HashMap::new(),
2083 )?);
2084 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2085 let expr = is_true(lit(12i32).gt(lit(13i64)));
2086 let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2087 let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2088 assert_eq!(expected, result);
2089
2090 let schema = Arc::new(DFSchema::from_unqualified_fields(
2092 vec![Field::new("a", DataType::Int64, true)].into(),
2093 std::collections::HashMap::new(),
2094 )?);
2095 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2096 let expr = is_true(lit(12i32).eq(lit(13i64)));
2097 let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2098 let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2099 assert_eq!(expected, result);
2100
2101 let schema = Arc::new(DFSchema::from_unqualified_fields(
2103 vec![Field::new("a", DataType::Int64, true)].into(),
2104 std::collections::HashMap::new(),
2105 )?);
2106 let mut rewriter = TypeCoercionRewriter { schema: &schema };
2107 let expr = is_true(lit(12i32).lt(lit(13i64)));
2108 let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2109 let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2110 assert_eq!(expected, result);
2111
2112 Ok(())
2113 }
2114
2115 #[test]
2116 fn binary_op_date32_eq_ts() -> Result<()> {
2117 let expr = cast(
2118 lit("1998-03-18"),
2119 DataType::Timestamp(TimeUnit::Nanosecond, None),
2120 )
2121 .eq(cast(lit("1998-03-18"), DataType::Date32));
2122 let empty = empty();
2123 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2124
2125 assert_analyzed_plan_eq!(
2126 plan,
2127 @r#"
2128 Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2129 EmptyRelation: rows=0
2130 "#
2131 )
2132 }
2133
2134 fn cast_if_not_same_type(
2135 expr: Box<Expr>,
2136 data_type: &DataType,
2137 schema: &DFSchemaRef,
2138 ) -> Box<Expr> {
2139 if &expr.get_type(schema).unwrap() != data_type {
2140 Box::new(cast(*expr, data_type.clone()))
2141 } else {
2142 expr
2143 }
2144 }
2145
2146 fn cast_helper(
2147 case: Case,
2148 case_when_type: &DataType,
2149 then_else_type: &DataType,
2150 schema: &DFSchemaRef,
2151 ) -> Case {
2152 let expr = case
2153 .expr
2154 .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2155 let when_then_expr = case
2156 .when_then_expr
2157 .into_iter()
2158 .map(|(when, then)| {
2159 (
2160 cast_if_not_same_type(when, case_when_type, schema),
2161 cast_if_not_same_type(then, then_else_type, schema),
2162 )
2163 })
2164 .collect::<Vec<_>>();
2165 let else_expr = case
2166 .else_expr
2167 .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2168
2169 Case {
2170 expr,
2171 when_then_expr,
2172 else_expr,
2173 }
2174 }
2175
2176 #[test]
2177 fn test_case_expression_coercion() -> Result<()> {
2178 let schema = Arc::new(DFSchema::from_unqualified_fields(
2179 vec![
2180 Field::new("boolean", DataType::Boolean, true),
2181 Field::new("integer", DataType::Int32, true),
2182 Field::new("float", DataType::Float32, true),
2183 Field::new(
2184 "timestamp",
2185 DataType::Timestamp(TimeUnit::Nanosecond, None),
2186 true,
2187 ),
2188 Field::new("date", DataType::Date32, true),
2189 Field::new(
2190 "interval",
2191 DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2192 true,
2193 ),
2194 Field::new("binary", DataType::Binary, true),
2195 Field::new("string", Utf8, true),
2196 Field::new("decimal", DataType::Decimal128(10, 10), true),
2197 ]
2198 .into(),
2199 std::collections::HashMap::new(),
2200 )?);
2201
2202 let case = Case {
2203 expr: None,
2204 when_then_expr: vec![
2205 (Box::new(col("boolean")), Box::new(col("integer"))),
2206 (Box::new(col("integer")), Box::new(col("float"))),
2207 (Box::new(col("string")), Box::new(col("string"))),
2208 ],
2209 else_expr: None,
2210 };
2211 let case_when_common_type = DataType::Boolean;
2212 let then_else_common_type = Utf8;
2213 let expected = cast_helper(
2214 case.clone(),
2215 &case_when_common_type,
2216 &then_else_common_type,
2217 &schema,
2218 );
2219 let actual = coerce_case_expression(case, &schema)?;
2220 assert_eq!(expected, actual);
2221
2222 let case = Case {
2223 expr: Some(Box::new(col("string"))),
2224 when_then_expr: vec![
2225 (Box::new(col("float")), Box::new(col("integer"))),
2226 (Box::new(col("integer")), Box::new(col("float"))),
2227 (Box::new(col("string")), Box::new(col("string"))),
2228 ],
2229 else_expr: Some(Box::new(col("string"))),
2230 };
2231 let case_when_common_type = Utf8;
2232 let then_else_common_type = Utf8;
2233 let expected = cast_helper(
2234 case.clone(),
2235 &case_when_common_type,
2236 &then_else_common_type,
2237 &schema,
2238 );
2239 let actual = coerce_case_expression(case, &schema)?;
2240 assert_eq!(expected, actual);
2241
2242 let case = Case {
2243 expr: Some(Box::new(col("interval"))),
2244 when_then_expr: vec![
2245 (Box::new(col("float")), Box::new(col("integer"))),
2246 (Box::new(col("binary")), Box::new(col("float"))),
2247 (Box::new(col("string")), Box::new(col("string"))),
2248 ],
2249 else_expr: Some(Box::new(col("string"))),
2250 };
2251 let err = coerce_case_expression(case, &schema).unwrap_err();
2252 assert_snapshot!(
2253 err.strip_backtrace(),
2254 @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2255 );
2256
2257 let case = Case {
2258 expr: Some(Box::new(col("string"))),
2259 when_then_expr: vec![
2260 (Box::new(col("float")), Box::new(col("date"))),
2261 (Box::new(col("string")), Box::new(col("float"))),
2262 (Box::new(col("string")), Box::new(col("binary"))),
2263 ],
2264 else_expr: Some(Box::new(col("timestamp"))),
2265 };
2266 let err = coerce_case_expression(case, &schema).unwrap_err();
2267 assert_snapshot!(
2268 err.strip_backtrace(),
2269 @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2270 );
2271
2272 Ok(())
2273 }
2274
2275 macro_rules! test_case_expression {
2276 ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2277 let case = Case {
2278 expr: $expr.map(|e| Box::new(col(e))),
2279 when_then_expr: $when_then,
2280 else_expr: None,
2281 };
2282
2283 let expected =
2284 cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2285
2286 let actual = coerce_case_expression(case, &$schema)?;
2287 assert_eq!(expected, actual);
2288 };
2289 }
2290
2291 #[test]
2292 fn tes_case_when_list() -> Result<()> {
2293 let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2294 let schema = Arc::new(DFSchema::from_unqualified_fields(
2295 vec![
2296 Field::new(
2297 "large_list",
2298 DataType::LargeList(Arc::clone(&inner_field)),
2299 true,
2300 ),
2301 Field::new(
2302 "fixed_list",
2303 DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2304 true,
2305 ),
2306 Field::new("list", DataType::List(inner_field), true),
2307 ]
2308 .into(),
2309 std::collections::HashMap::new(),
2310 )?);
2311
2312 test_case_expression!(
2313 Some("list"),
2314 vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2315 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2316 Utf8,
2317 schema
2318 );
2319
2320 test_case_expression!(
2321 Some("large_list"),
2322 vec![(Box::new(col("list")), Box::new(lit("1")))],
2323 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2324 Utf8,
2325 schema
2326 );
2327
2328 test_case_expression!(
2329 Some("list"),
2330 vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2331 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2332 Utf8,
2333 schema
2334 );
2335
2336 test_case_expression!(
2337 Some("fixed_list"),
2338 vec![(Box::new(col("list")), Box::new(lit("1")))],
2339 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2340 Utf8,
2341 schema
2342 );
2343
2344 test_case_expression!(
2345 Some("fixed_list"),
2346 vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2347 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2348 Utf8,
2349 schema
2350 );
2351
2352 test_case_expression!(
2353 Some("large_list"),
2354 vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2355 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2356 Utf8,
2357 schema
2358 );
2359 Ok(())
2360 }
2361
2362 #[test]
2363 fn test_then_else_list() -> Result<()> {
2364 let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2365 let schema = Arc::new(DFSchema::from_unqualified_fields(
2366 vec![
2367 Field::new("boolean", DataType::Boolean, true),
2368 Field::new(
2369 "large_list",
2370 DataType::LargeList(Arc::clone(&inner_field)),
2371 true,
2372 ),
2373 Field::new(
2374 "fixed_list",
2375 DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2376 true,
2377 ),
2378 Field::new("list", DataType::List(inner_field), true),
2379 ]
2380 .into(),
2381 std::collections::HashMap::new(),
2382 )?);
2383
2384 test_case_expression!(
2386 None::<String>,
2387 vec![
2388 (Box::new(col("boolean")), Box::new(col("large_list"))),
2389 (Box::new(col("boolean")), Box::new(col("list")))
2390 ],
2391 DataType::Boolean,
2392 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2393 schema
2394 );
2395
2396 test_case_expression!(
2397 None::<String>,
2398 vec![
2399 (Box::new(col("boolean")), Box::new(col("list"))),
2400 (Box::new(col("boolean")), Box::new(col("large_list")))
2401 ],
2402 DataType::Boolean,
2403 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2404 schema
2405 );
2406
2407 test_case_expression!(
2409 None::<String>,
2410 vec![
2411 (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2412 (Box::new(col("boolean")), Box::new(col("list")))
2413 ],
2414 DataType::Boolean,
2415 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2416 schema
2417 );
2418
2419 test_case_expression!(
2420 None::<String>,
2421 vec![
2422 (Box::new(col("boolean")), Box::new(col("list"))),
2423 (Box::new(col("boolean")), Box::new(col("fixed_list")))
2424 ],
2425 DataType::Boolean,
2426 DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2427 schema
2428 );
2429
2430 test_case_expression!(
2432 None::<String>,
2433 vec![
2434 (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2435 (Box::new(col("boolean")), Box::new(col("large_list")))
2436 ],
2437 DataType::Boolean,
2438 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2439 schema
2440 );
2441
2442 test_case_expression!(
2443 None::<String>,
2444 vec![
2445 (Box::new(col("boolean")), Box::new(col("large_list"))),
2446 (Box::new(col("boolean")), Box::new(col("fixed_list")))
2447 ],
2448 DataType::Boolean,
2449 DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2450 schema
2451 );
2452 Ok(())
2453 }
2454
2455 #[test]
2456 fn test_map_with_diff_name() -> Result<()> {
2457 let mut builder = SchemaBuilder::new();
2458 builder.push(Field::new("key", Utf8, false));
2459 builder.push(Field::new("value", DataType::Float64, true));
2460 let struct_fields = builder.finish().fields;
2461
2462 let fields =
2463 Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2464 let map_type_entries = DataType::Map(Arc::new(fields), false);
2465
2466 let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2467 let may_type_custom = DataType::Map(Arc::new(fields), false);
2468
2469 let expr = col("a").eq(cast(col("a"), may_type_custom));
2470 let empty = empty_with_type(map_type_entries);
2471 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2472
2473 assert_analyzed_plan_eq!(
2474 plan,
2475 @r#"
2476 Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted))
2477 EmptyRelation: rows=0
2478 "#
2479 )
2480 }
2481
2482 #[test]
2483 fn interval_plus_timestamp() -> Result<()> {
2484 let expr = Expr::BinaryExpr(BinaryExpr::new(
2486 Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2487 Operator::Plus,
2488 Box::new(cast(
2489 lit("2000-01-01T00:00:00"),
2490 DataType::Timestamp(TimeUnit::Nanosecond, None),
2491 )),
2492 ));
2493 let empty = empty();
2494 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2495
2496 assert_analyzed_plan_eq!(
2497 plan,
2498 @r#"
2499 Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
2500 EmptyRelation: rows=0
2501 "#
2502 )
2503 }
2504
2505 #[test]
2506 fn timestamp_subtract_timestamp() -> Result<()> {
2507 let expr = Expr::BinaryExpr(BinaryExpr::new(
2508 Box::new(cast(
2509 lit("1998-03-18"),
2510 DataType::Timestamp(TimeUnit::Nanosecond, None),
2511 )),
2512 Operator::Minus,
2513 Box::new(cast(
2514 lit("1998-03-18"),
2515 DataType::Timestamp(TimeUnit::Nanosecond, None),
2516 )),
2517 ));
2518 let empty = empty();
2519 let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2520
2521 assert_analyzed_plan_eq!(
2522 plan,
2523 @r#"
2524 Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
2525 EmptyRelation: rows=0
2526 "#
2527 )
2528 }
2529
2530 #[test]
2531 fn in_subquery_cast_subquery() -> Result<()> {
2532 let empty_int32 = empty_with_type(DataType::Int32);
2533 let empty_int64 = empty_with_type(DataType::Int64);
2534
2535 let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2536 Box::new(col("a")),
2537 Subquery {
2538 subquery: empty_int32,
2539 outer_ref_columns: vec![],
2540 spans: Spans::new(),
2541 },
2542 false,
2543 ));
2544 let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2545 assert_analyzed_plan_eq!(
2548 plan,
2549 @r"
2550 Filter: a IN (<subquery>)
2551 Subquery:
2552 Projection: CAST(a AS Int64)
2553 EmptyRelation: rows=0
2554 EmptyRelation: rows=0
2555 "
2556 )
2557 }
2558
2559 #[test]
2560 fn in_subquery_cast_expr() -> Result<()> {
2561 let empty_int32 = empty_with_type(DataType::Int32);
2562 let empty_int64 = empty_with_type(DataType::Int64);
2563
2564 let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2565 Box::new(col("a")),
2566 Subquery {
2567 subquery: empty_int64,
2568 outer_ref_columns: vec![],
2569 spans: Spans::new(),
2570 },
2571 false,
2572 ));
2573 let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2574
2575 assert_analyzed_plan_eq!(
2577 plan,
2578 @r"
2579 Filter: CAST(a AS Int64) IN (<subquery>)
2580 Subquery:
2581 EmptyRelation: rows=0
2582 EmptyRelation: rows=0
2583 "
2584 )
2585 }
2586
2587 #[test]
2588 fn in_subquery_cast_all() -> Result<()> {
2589 let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2590 let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2591
2592 let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2593 Box::new(col("a")),
2594 Subquery {
2595 subquery: empty_inside,
2596 outer_ref_columns: vec![],
2597 spans: Spans::new(),
2598 },
2599 false,
2600 ));
2601 let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2602
2603 assert_analyzed_plan_eq!(
2605 plan,
2606 @r"
2607 Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2608 Subquery:
2609 Projection: CAST(a AS Decimal128(13, 8))
2610 EmptyRelation: rows=0
2611 EmptyRelation: rows=0
2612 "
2613 )
2614 }
2615}