datafusion_optimizer/analyzer/
type_coercion.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Optimizer rule for type validation and coercion
19
20use std::sync::Arc;
21
22use datafusion_expr::binary::BinaryTypeCoercer;
23use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
24use itertools::{izip, Itertools as _};
25
26use arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
27
28use crate::analyzer::AnalyzerRule;
29use crate::utils::NamePreserver;
30use datafusion_common::config::ConfigOptions;
31use datafusion_common::tree_node::Transformed;
32use datafusion_common::{
33    exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err,
34    plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
35    TableReference,
36};
37use datafusion_expr::expr::{
38    self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList,
39    InSubquery, Like, ScalarFunction, Sort, WindowFunction,
40};
41use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
42use datafusion_expr::expr_schema::cast_subquery;
43use datafusion_expr::logical_plan::Subquery;
44use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion};
45use datafusion_expr::type_coercion::functions::{
46    data_types_with_scalar_udf, fields_with_aggregate_udf,
47};
48use datafusion_expr::type_coercion::other::{
49    get_coerce_type_for_case_expression, get_coerce_type_for_list,
50};
51use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8};
52use datafusion_expr::utils::merge_schema;
53use datafusion_expr::{
54    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
55    AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection,
56    ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits,
57};
58
/// Performs type coercion by determining the schema
/// and performing the expression rewrites.
///
/// The struct has no fields, so an instance carries no state of its own;
/// it can be created with [`TypeCoercion::new`] or via [`Default`].
#[derive(Default, Debug)]
pub struct TypeCoercion {}
63
64impl TypeCoercion {
65    pub fn new() -> Self {
66        Self {}
67    }
68}
69
70/// Coerce output schema based upon optimizer config.
71fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
72    if !config.optimizer.expand_views_at_output {
73        return Ok(plan);
74    }
75
76    let outer_refs = plan.expressions();
77    if outer_refs.is_empty() {
78        return Ok(plan);
79    }
80
81    if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) {
82        coerce_plan_expr_for_schema(plan, &dfschema?)
83    } else {
84        Ok(plan)
85    }
86}
87
88impl AnalyzerRule for TypeCoercion {
89    fn name(&self) -> &str {
90        "type_coercion"
91    }
92
93    fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result<LogicalPlan> {
94        let empty_schema = DFSchema::empty();
95
96        // recurse
97        let transformed_plan = plan
98            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
99            .data;
100
101        // finish
102        coerce_output(transformed_plan, config)
103    }
104}
105
106/// use the external schema to handle the correlated subqueries case
107///
108/// Assumes that children have already been optimized
109fn analyze_internal(
110    external_schema: &DFSchema,
111    plan: LogicalPlan,
112) -> Result<Transformed<LogicalPlan>> {
113    // get schema representing all available input fields. This is used for data type
114    // resolution only, so order does not matter here
115    let mut schema = merge_schema(&plan.inputs());
116
117    if let LogicalPlan::TableScan(ts) = &plan {
118        let source_schema = DFSchema::try_from_qualified_schema(
119            ts.table_name.clone(),
120            &ts.source.schema(),
121        )?;
122        schema.merge(&source_schema);
123    }
124
125    // merge the outer schema for correlated subqueries
126    // like case:
127    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
128    schema.merge(external_schema);
129
130    // Coerce filter predicates to boolean (handles `WHERE NULL`)
131    let plan = if let LogicalPlan::Filter(mut filter) = plan {
132        filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?;
133        LogicalPlan::Filter(filter)
134    } else {
135        plan
136    };
137
138    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
139
140    let name_preserver = NamePreserver::new(&plan);
141    // apply coercion rewrite all expressions in the plan individually
142    plan.map_expressions(|expr| {
143        let original_name = name_preserver.save(&expr);
144        expr.rewrite_with_schema(&schema, &mut expr_rewrite)
145            .map(|transformed| transformed.update_data(|e| original_name.restore(e)))
146    })?
147    // some plans need extra coercion after their expressions are coerced
148    .map_data(|plan| expr_rewrite.coerce_plan(plan))?
149    // recompute the schema after the expressions have been rewritten as the types may have changed
150    .map_data(|plan| plan.recompute_schema())
151}
152
/// Rewrite expressions to apply type coercion.
///
/// The rewriter borrows the schema it resolves types against, so it cannot
/// outlive that schema (lifetime `'a`).
pub struct TypeCoercionRewriter<'a> {
    /// Schema used to resolve expression types (e.g. when checking that a
    /// join filter is boolean).
    pub(crate) schema: &'a DFSchema,
}
157
158impl<'a> TypeCoercionRewriter<'a> {
159    /// Create a new [`TypeCoercionRewriter`] with a provided schema
160    /// representing both the inputs and output of the [`LogicalPlan`] node.
161    pub fn new(schema: &'a DFSchema) -> Self {
162        Self { schema }
163    }
164
165    /// Coerce the [`LogicalPlan`].
166    ///
167    /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`]
168    /// for type-coercion approach.
169    pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
170        match plan {
171            LogicalPlan::Join(join) => self.coerce_join(join),
172            LogicalPlan::Union(union) => Self::coerce_union(union),
173            LogicalPlan::Limit(limit) => Self::coerce_limit(limit),
174            _ => Ok(plan),
175        }
176    }
177
178    /// Coerce join equality expressions and join filter
179    ///
180    /// Joins must be treated specially as their equality expressions are stored
181    /// as a parallel list of left and right expressions, rather than a single
182    /// equality expression
183    ///
184    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
185    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
186    pub fn coerce_join(&mut self, mut join: Join) -> Result<LogicalPlan> {
187        join.on = join
188            .on
189            .into_iter()
190            .map(|(lhs, rhs)| {
191                // coerce the arguments as though they were a single binary equality
192                // expression
193                let left_schema = join.left.schema();
194                let right_schema = join.right.schema();
195                let (lhs, rhs) = self.coerce_binary_op(
196                    lhs,
197                    left_schema,
198                    Operator::Eq,
199                    rhs,
200                    right_schema,
201                )?;
202                Ok((lhs, rhs))
203            })
204            .collect::<Result<Vec<_>>>()?;
205
206        // Join filter must be boolean
207        join.filter = join
208            .filter
209            .map(|expr| self.coerce_join_filter(expr))
210            .transpose()?;
211
212        Ok(LogicalPlan::Join(join))
213    }
214
215    /// Coerce the union’s inputs to a common schema compatible with all inputs.
216    /// This occurs after wildcard expansion and the coercion of the input expressions.
217    pub fn coerce_union(union_plan: Union) -> Result<LogicalPlan> {
218        let union_schema = Arc::new(coerce_union_schema_with_schema(
219            &union_plan.inputs,
220            &union_plan.schema,
221        )?);
222        let new_inputs = union_plan
223            .inputs
224            .into_iter()
225            .map(|p| {
226                let plan =
227                    coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?;
228                match plan {
229                    LogicalPlan::Projection(Projection { expr, input, .. }) => {
230                        Ok(Arc::new(project_with_column_index(
231                            expr,
232                            input,
233                            Arc::clone(&union_schema),
234                        )?))
235                    }
236                    other_plan => Ok(Arc::new(other_plan)),
237                }
238            })
239            .collect::<Result<Vec<_>>>()?;
240        Ok(LogicalPlan::Union(Union {
241            inputs: new_inputs,
242            schema: union_schema,
243        }))
244    }
245
246    /// Coerce the fetch and skip expression to Int64 type.
247    fn coerce_limit(limit: Limit) -> Result<LogicalPlan> {
248        fn coerce_limit_expr(
249            expr: Expr,
250            schema: &DFSchema,
251            expr_name: &str,
252        ) -> Result<Expr> {
253            let dt = expr.get_type(schema)?;
254            if dt.is_integer() || dt.is_null() {
255                expr.cast_to(&DataType::Int64, schema)
256            } else {
257                plan_err!("Expected {expr_name} to be an integer or null, but got {dt}")
258            }
259        }
260
261        let empty_schema = DFSchema::empty();
262        let new_fetch = limit
263            .fetch
264            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT"))
265            .transpose()?;
266        let new_skip = limit
267            .skip
268            .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET"))
269            .transpose()?;
270        Ok(LogicalPlan::Limit(Limit {
271            input: limit.input,
272            fetch: new_fetch.map(Box::new),
273            skip: new_skip.map(Box::new),
274        }))
275    }
276
277    fn coerce_join_filter(&self, expr: Expr) -> Result<Expr> {
278        let expr_type = expr.get_type(self.schema)?;
279        match expr_type {
280            DataType::Boolean => Ok(expr),
281            DataType::Null => expr.cast_to(&DataType::Boolean, self.schema),
282            other => plan_err!("Join condition must be boolean type, but got {other:?}"),
283        }
284    }
285
286    fn coerce_binary_op(
287        &self,
288        left: Expr,
289        left_schema: &DFSchema,
290        op: Operator,
291        right: Expr,
292        right_schema: &DFSchema,
293    ) -> Result<(Expr, Expr)> {
294        let (left_type, right_type) = BinaryTypeCoercer::new(
295            &left.get_type(left_schema)?,
296            &op,
297            &right.get_type(right_schema)?,
298        )
299        .get_input_types()?;
300
301        Ok((
302            left.cast_to(&left_type, left_schema)?,
303            right.cast_to(&right_type, right_schema)?,
304        ))
305    }
306}
307
308impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> {
309    type Node = Expr;
310    type Payload<'a> = &'a DFSchema;
311
312    fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result<Transformed<Expr>> {
313        match expr {
314            Expr::Unnest(_) => not_impl_err!(
315                "Unnest should be rewritten to LogicalPlan::Unnest before type coercion"
316            ),
317            Expr::ScalarSubquery(Subquery {
318                subquery,
319                outer_ref_columns,
320                spans,
321            }) => {
322                let new_plan =
323                    analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data;
324                Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
325                    subquery: Arc::new(new_plan),
326                    outer_ref_columns,
327                    spans,
328                })))
329            }
330            Expr::Exists(Exists { subquery, negated }) => {
331                let new_plan = analyze_internal(
332                    schema,
333                    Arc::unwrap_or_clone(subquery.subquery),
334                )?
335                .data;
336                Ok(Transformed::yes(Expr::Exists(Exists {
337                    subquery: Subquery {
338                        subquery: Arc::new(new_plan),
339                        outer_ref_columns: subquery.outer_ref_columns,
340                        spans: subquery.spans,
341                    },
342                    negated,
343                })))
344            }
345            Expr::InSubquery(InSubquery {
346                expr,
347                subquery,
348                negated,
349            }) => {
350                let new_plan = analyze_internal(
351                    schema,
352                    Arc::unwrap_or_clone(subquery.subquery),
353                )?
354                .data;
355                let expr_type = expr.get_type(schema)?;
356                let subquery_type = new_plan.schema().field(0).data_type();
357                let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
358                    plan_datafusion_err!(
359                    "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
360                ),
361                )?;
362                let new_subquery = Subquery {
363                    subquery: Arc::new(new_plan),
364                    outer_ref_columns: subquery.outer_ref_columns,
365                    spans: subquery.spans,
366                };
367                Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
368                    Box::new(expr.cast_to(&common_type, schema)?),
369                    cast_subquery(new_subquery, &common_type)?,
370                    negated,
371                ))))
372            }
373            Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
374                *expr,
375                schema,
376            )?))),
377            Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
378                get_casted_expr_for_bool_op(*expr, schema)?,
379            ))),
380            Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
381                get_casted_expr_for_bool_op(*expr, schema)?,
382            ))),
383            Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
384                get_casted_expr_for_bool_op(*expr, schema)?,
385            ))),
386            Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
387                get_casted_expr_for_bool_op(*expr, schema)?,
388            ))),
389            Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
390                get_casted_expr_for_bool_op(*expr, schema)?,
391            ))),
392            Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
393                get_casted_expr_for_bool_op(*expr, schema)?,
394            ))),
395            Expr::Like(Like {
396                negated,
397                expr,
398                pattern,
399                escape_char,
400                case_insensitive,
401            }) => {
402                let left_type = expr.get_type(schema)?;
403                let right_type = pattern.get_type(schema)?;
404                let coerced_type = like_coercion(&left_type,  &right_type).ok_or_else(|| {
405                    let op_name = if case_insensitive {
406                        "ILIKE"
407                    } else {
408                        "LIKE"
409                    };
410                    plan_datafusion_err!(
411                        "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
412                    )
413                })?;
414                let expr = match left_type {
415                    DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr,
416                    _ => Box::new(expr.cast_to(&coerced_type, schema)?),
417                };
418                let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?);
419                Ok(Transformed::yes(Expr::Like(Like::new(
420                    negated,
421                    expr,
422                    pattern,
423                    escape_char,
424                    case_insensitive,
425                ))))
426            }
427            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
428                let (left, right) =
429                    self.coerce_binary_op(*left, schema, op, *right, schema)?;
430                Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
431                    Box::new(left),
432                    op,
433                    Box::new(right),
434                ))))
435            }
436            Expr::Between(Between {
437                expr,
438                negated,
439                low,
440                high,
441            }) => {
442                let expr_type = expr.get_type(schema)?;
443                let low_type = low.get_type(schema)?;
444                let low_coerced_type = comparison_coercion(&expr_type, &low_type)
445                    .ok_or_else(|| {
446                        internal_datafusion_err!(
447                            "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
448                        )
449                    })?;
450                let high_type = high.get_type(schema)?;
451                let high_coerced_type = comparison_coercion(&expr_type, &high_type)
452                    .ok_or_else(|| {
453                        internal_datafusion_err!(
454                            "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
455                        )
456                    })?;
457                let coercion_type =
458                    comparison_coercion(&low_coerced_type, &high_coerced_type)
459                        .ok_or_else(|| {
460                            internal_datafusion_err!(
461                                "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression"
462                            )
463                        })?;
464                Ok(Transformed::yes(Expr::Between(Between::new(
465                    Box::new(expr.cast_to(&coercion_type, schema)?),
466                    negated,
467                    Box::new(low.cast_to(&coercion_type, schema)?),
468                    Box::new(high.cast_to(&coercion_type, schema)?),
469                ))))
470            }
471            Expr::InList(InList {
472                expr,
473                list,
474                negated,
475            }) => {
476                let expr_data_type = expr.get_type(schema)?;
477                let list_data_types = list
478                    .iter()
479                    .map(|list_expr| list_expr.get_type(schema))
480                    .collect::<Result<Vec<_>>>()?;
481                let result_type =
482                    get_coerce_type_for_list(&expr_data_type, &list_data_types);
483                match result_type {
484                    None => plan_err!(
485                        "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ")
486                    ),
487                    Some(coerced_type) => {
488                        // find the coerced type
489                        let cast_expr = expr.cast_to(&coerced_type, schema)?;
490                        let cast_list_expr = list
491                            .into_iter()
492                            .map(|list_expr| {
493                                list_expr.cast_to(&coerced_type, schema)
494                            })
495                            .collect::<Result<Vec<_>>>()?;
496                        Ok(Transformed::yes(Expr::InList(InList ::new(
497                             Box::new(cast_expr),
498                             cast_list_expr,
499                            negated,
500                        ))))
501                    }
502                }
503            }
504            Expr::Case(case) => {
505                let case = coerce_case_expression(case, schema)?;
506                Ok(Transformed::yes(Expr::Case(case)))
507            }
508            Expr::ScalarFunction(ScalarFunction { func, args }) => {
509                let new_expr = coerce_arguments_for_signature_with_scalar_udf(
510                    args,
511                    schema,
512                    &func,
513                )?;
514                Ok(Transformed::yes(Expr::ScalarFunction(
515                    ScalarFunction::new_udf(func, new_expr),
516                )))
517            }
518            Expr::AggregateFunction(expr::AggregateFunction {
519                func,
520                params:
521                    AggregateFunctionParams {
522                        args,
523                        distinct,
524                        filter,
525                        order_by,
526                        null_treatment,
527                    },
528            }) => {
529                let new_expr = coerce_arguments_for_signature_with_aggregate_udf(
530                    args,
531                    schema,
532                    &func,
533                )?;
534                Ok(Transformed::yes(Expr::AggregateFunction(
535                    expr::AggregateFunction::new_udf(
536                        func,
537                        new_expr,
538                        distinct,
539                        filter,
540                        order_by,
541                        null_treatment,
542                    ),
543                )))
544            }
545            Expr::WindowFunction(window_fun) => {
546                let WindowFunction {
547                    fun,
548                    params:
549                        expr::WindowFunctionParams {
550                            args,
551                            partition_by,
552                            order_by,
553                            window_frame,
554                            filter,
555                            null_treatment,
556                            distinct,
557                        },
558                } = *window_fun;
559                let window_frame =
560                    coerce_window_frame(window_frame, schema, &order_by)?;
561
562                let args = match &fun {
563                    expr::WindowFunctionDefinition::AggregateUDF(udf) => {
564                        coerce_arguments_for_signature_with_aggregate_udf(
565                            args,
566                            schema,
567                            udf,
568                        )?
569                    }
570                    _ => args,
571                };
572
573                let new_expr = Expr::from(WindowFunction {
574                    fun,
575                    params: expr::WindowFunctionParams {
576                        args,
577                        partition_by,
578                        order_by,
579                        window_frame,
580                        filter,
581                        null_treatment,
582                        distinct,
583                    },
584                });
585                Ok(Transformed::yes(new_expr))
586            }
587            // TODO: remove the next line after `Expr::Wildcard` is removed
588            #[expect(deprecated)]
589            Expr::Alias(_)
590            | Expr::Column(_)
591            | Expr::ScalarVariable(_, _)
592            | Expr::Literal(_, _)
593            | Expr::SimilarTo(_)
594            | Expr::IsNotNull(_)
595            | Expr::IsNull(_)
596            | Expr::Negative(_)
597            | Expr::Cast(_)
598            | Expr::TryCast(_)
599            | Expr::Wildcard { .. }
600            | Expr::GroupingSet(_)
601            | Expr::Placeholder(_)
602            | Expr::OuterReferenceColumn(_, _)
603            | Expr::Lambda { .. } => Ok(Transformed::no(expr)),
604        }
605    }
606}
607
608/// Transform a schema to use non-view types for Utf8View and BinaryView
609fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option<Result<DFSchema>> {
610    let metadata = dfschema.as_arrow().metadata.clone();
611    let mut transformed = false;
612
613    let (qualifiers, transformed_fields): (Vec<Option<TableReference>>, Vec<Arc<Field>>) =
614        dfschema
615            .iter()
616            .map(|(qualifier, field)| match field.data_type() {
617                DataType::Utf8View => {
618                    transformed = true;
619                    (
620                        qualifier.cloned() as Option<TableReference>,
621                        Arc::new(Field::new(
622                            field.name(),
623                            DataType::LargeUtf8,
624                            field.is_nullable(),
625                        )),
626                    )
627                }
628                DataType::BinaryView => {
629                    transformed = true;
630                    (
631                        qualifier.cloned() as Option<TableReference>,
632                        Arc::new(Field::new(
633                            field.name(),
634                            DataType::LargeBinary,
635                            field.is_nullable(),
636                        )),
637                    )
638                }
639                _ => (
640                    qualifier.cloned() as Option<TableReference>,
641                    Arc::clone(field),
642                ),
643            })
644            .unzip();
645
646    if !transformed {
647        return None;
648    }
649
650    let schema = Schema::new_with_metadata(transformed_fields, metadata);
651    Some(DFSchema::from_field_specific_qualified_schema(
652        qualifiers,
653        &Arc::new(schema),
654    ))
655}
656
657/// Casts the given `value` to `target_type`. Note that this function
658/// only considers `Null` or `Utf8` values.
659fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarValue> {
660    match value {
661        // Coerce Utf8 values:
662        ScalarValue::Utf8(Some(val)) => {
663            ScalarValue::try_from_string(val.clone(), target_type)
664        }
665        s => {
666            if s.is_null() {
667                // Coerce `Null` values:
668                ScalarValue::try_from(target_type)
669            } else {
670                // Values except `Utf8`/`Null` variants already have the right type
671                // (casted before) since we convert `sqlparser` outputs to `Utf8`
672                // for all possible cases. Therefore, we return a clone here.
673                Ok(s.clone())
674            }
675        }
676    }
677}
678
679/// This function coerces `value` to `target_type` in a range-aware fashion.
680/// If the coercion is successful, we return an `Ok` value with the result.
681/// If the coercion fails because `target_type` is not wide enough (i.e. we
682/// can not coerce to `target_type`, but we can to a wider type in the same
683/// family), we return a `Null` value of this type to signal this situation.
684/// Downstream code uses this signal to treat these values as *unbounded*.
685fn coerce_scalar_range_aware(
686    target_type: &DataType,
687    value: &ScalarValue,
688) -> Result<ScalarValue> {
689    coerce_scalar(target_type, value).or_else(|err| {
690        // If type coercion fails, check if the largest type in family works:
691        if let Some(largest_type) = get_widest_type_in_family(target_type) {
692            coerce_scalar(largest_type, value).map_or_else(
693                |_| exec_err!("Cannot cast {value:?} to {target_type}"),
694                |_| ScalarValue::try_from(target_type),
695            )
696        } else {
697            Err(err)
698        }
699    })
700}
701
702/// This function returns the widest type in the family of `given_type`.
703/// If the given type is already the widest type, it returns `None`.
704/// For example, if `given_type` is `Int8`, it returns `Int64`.
705fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
706    match given_type {
707        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => Some(&DataType::UInt64),
708        DataType::Int8 | DataType::Int16 | DataType::Int32 => Some(&DataType::Int64),
709        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
710        _ => None,
711    }
712}
713
714/// Coerces the given (window frame) `bound` to `target_type`.
715fn coerce_frame_bound(
716    target_type: &DataType,
717    bound: WindowFrameBound,
718) -> Result<WindowFrameBound> {
719    match bound {
720        WindowFrameBound::Preceding(v) => {
721            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding)
722        }
723        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
724        WindowFrameBound::Following(v) => {
725            coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following)
726        }
727    }
728}
729
730fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
731    if col_type.is_numeric()
732        || is_utf8_or_utf8view_or_large_utf8(col_type)
733        || matches!(col_type, DataType::List(_))
734        || matches!(col_type, DataType::LargeList(_))
735        || matches!(col_type, DataType::FixedSizeList(_, _))
736        || matches!(col_type, DataType::Null)
737        || matches!(col_type, DataType::Boolean)
738    {
739        Ok(col_type.clone())
740    } else if is_datetime(col_type) {
741        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
742    } else if let DataType::Dictionary(_, value_type) = col_type {
743        extract_window_frame_target_type(value_type)
744    } else {
745        internal_err!("Cannot run range queries on datatype: {col_type}")
746    }
747}
748
749// Coerces the given `window_frame` to use appropriate natural types.
750// For example, ROWS and GROUPS frames use `UInt64` during calculations.
751fn coerce_window_frame(
752    window_frame: WindowFrame,
753    schema: &DFSchema,
754    expressions: &[Sort],
755) -> Result<WindowFrame> {
756    let mut window_frame = window_frame;
757    let target_type = match window_frame.units {
758        WindowFrameUnits::Range => {
759            let current_types = expressions
760                .first()
761                .map(|s| s.expr.get_type(schema))
762                .transpose()?;
763            if let Some(col_type) = current_types {
764                extract_window_frame_target_type(&col_type)?
765            } else {
766                return internal_err!("ORDER BY column cannot be empty");
767            }
768        }
769        WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64,
770    };
771    window_frame.start_bound =
772        coerce_frame_bound(&target_type, window_frame.start_bound)?;
773    window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?;
774    Ok(window_frame)
775}
776
777// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
778// The above op will be rewrite to the binary op when creating the physical op.
779fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
780    let left_type = expr.get_type(schema)?;
781    BinaryTypeCoercer::new(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)
782        .get_input_types()?;
783    expr.cast_to(&DataType::Boolean, schema)
784}
785
786/// Returns `expressions` coerced to types compatible with
787/// `signature`, if possible.
788///
789/// See the module level documentation for more detail on coercion.
790fn coerce_arguments_for_signature_with_scalar_udf(
791    expressions: Vec<Expr>,
792    schema: &DFSchema,
793    func: &ScalarUDF,
794) -> Result<Vec<Expr>> {
795    if expressions.is_empty() {
796        return Ok(expressions);
797    }
798
799    let current_types = expressions.iter()
800        .map(|e| match e {
801            Expr::Lambda { .. } => Ok(DataType::Null),
802            _ => e.get_type(schema),
803        })
804        .collect::<Result<Vec<_>>>()?;
805
806    let new_types = data_types_with_scalar_udf(&current_types, func)?;
807
808    expressions
809        .into_iter()
810        .enumerate()
811        .map(|(i, expr)| match expr {
812            lambda @ Expr::Lambda { .. } => Ok(lambda),
813            _ => expr.cast_to(&new_types[i], schema),
814        })
815        .collect()
816}
817
818/// Returns `expressions` coerced to types compatible with
819/// `signature`, if possible.
820///
821/// See the module level documentation for more detail on coercion.
822fn coerce_arguments_for_signature_with_aggregate_udf(
823    expressions: Vec<Expr>,
824    schema: &DFSchema,
825    func: &AggregateUDF,
826) -> Result<Vec<Expr>> {
827    if expressions.is_empty() {
828        return Ok(expressions);
829    }
830
831    let current_fields = expressions
832        .iter()
833        .map(|e| e.to_field(schema).map(|(_, f)| f))
834        .collect::<Result<Vec<_>>>()?;
835
836    let new_types = fields_with_aggregate_udf(&current_fields, func)?
837        .into_iter()
838        .map(|f| f.data_type().clone())
839        .collect::<Vec<_>>();
840
841    expressions
842        .into_iter()
843        .enumerate()
844        .map(|(i, expr)| expr.cast_to(&new_types[i], schema))
845        .collect()
846}
847
848fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
849    // Given expressions like:
850    //
851    // CASE a1
852    //   WHEN a2 THEN b1
853    //   WHEN a3 THEN b2
854    //   ELSE b3
855    // END
856    //
857    // or:
858    //
859    // CASE
860    //   WHEN x1 THEN b1
861    //   WHEN x2 THEN b2
862    //   ELSE b3
863    // END
864    //
865    // Then all aN (a1, a2, a3) must be converted to a common data type in the first example
866    // (case-when expression coercion)
867    //
868    // All xN (x1, x2) must be converted to a boolean data type in the second example
869    // (when-boolean expression coercion)
870    //
871    // And all bN (b1, b2, b3) must be converted to a common data type in both examples
872    // (then-else expression coercion)
873    //
874    // If any fail to find and cast to a common/specific data type, will return error
875    //
876    // Note that case-when and when-boolean expression coercions are mutually exclusive
877    // Only one or the other can occur for a case expression, whilst then-else expression coercion will always occur
878
879    // prepare types
880    let case_type = case
881        .expr
882        .as_ref()
883        .map(|expr| expr.get_type(schema))
884        .transpose()?;
885    let then_types = case
886        .when_then_expr
887        .iter()
888        .map(|(_when, then)| then.get_type(schema))
889        .collect::<Result<Vec<_>>>()?;
890    let else_type = case
891        .else_expr
892        .as_ref()
893        .map(|expr| expr.get_type(schema))
894        .transpose()?;
895
896    // find common coercible types
897    let case_when_coerce_type = case_type
898        .as_ref()
899        .map(|case_type| {
900            let when_types = case
901                .when_then_expr
902                .iter()
903                .map(|(when, _then)| when.get_type(schema))
904                .collect::<Result<Vec<_>>>()?;
905            let coerced_type =
906                get_coerce_type_for_case_expression(&when_types, Some(case_type));
907            coerced_type.ok_or_else(|| {
908                plan_datafusion_err!(
909                    "Failed to coerce case ({case_type}) and when ({}) \
910                     to common types in CASE WHEN expression",
911                    when_types.iter().join(", ")
912                )
913            })
914        })
915        .transpose()?;
916    let then_else_coerce_type =
917        get_coerce_type_for_case_expression(&then_types, else_type.as_ref()).ok_or_else(
918            || {
919                if let Some(else_type) = else_type {
920                    plan_datafusion_err!(
921                        "Failed to coerce then ({}) and else ({else_type}) \
922                         to common types in CASE WHEN expression",
923                        then_types.iter().join(", ")
924                    )
925                } else {
926                    plan_datafusion_err!(
927                        "Failed to coerce then ({}) and else (None) \
928                         to common types in CASE WHEN expression",
929                        then_types.iter().join(", ")
930                    )
931                }
932            },
933        )?;
934
935    // do cast if found common coercible types
936    let case_expr = case
937        .expr
938        .zip(case_when_coerce_type.as_ref())
939        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema))
940        .transpose()?
941        .map(Box::new);
942    let when_then = case
943        .when_then_expr
944        .into_iter()
945        .map(|(when, then)| {
946            let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
947            let when = when.cast_to(when_type, schema).map_err(|e| {
948                DataFusionError::Context(
949                    format!(
950                        "WHEN expressions in CASE couldn't be \
951                         converted to common type ({when_type})"
952                    ),
953                    Box::new(e),
954                )
955            })?;
956            let then = then.cast_to(&then_else_coerce_type, schema)?;
957            Ok((Box::new(when), Box::new(then)))
958        })
959        .collect::<Result<Vec<_>>>()?;
960    let else_expr = case
961        .else_expr
962        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
963        .transpose()?
964        .map(Box::new);
965
966    Ok(Case::new(case_expr, when_then, else_expr))
967}
968
969/// Get a common schema that is compatible with all inputs of UNION.
970///
971/// This method presumes that the wildcard expansion is unneeded, or has already
972/// been applied.
973///
974/// ## Schema and Field Handling in Union Coercion
975///
976/// **Processing order**: The function starts with the base schema (first input) and then
977/// processes remaining inputs sequentially, with later inputs taking precedence in merging.
978///
979/// **Schema-level metadata merging**: Later schemas take precedence for duplicate keys.
980///
981/// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys.
982///
983/// **Type coercion precedence**: The coerced type is determined by iteratively applying
984/// `comparison_coercion()` between the accumulated type and each new input's type. The
985/// result depends on type coercion rules, not input order.
986///
987/// **Nullability merging**: Nullability is accumulated using logical OR (`||`).
988/// Once any input field is nullable, the result field becomes nullable permanently.
989/// Later inputs can make a field nullable but cannot make it non-nullable.
990///
991/// **Field precedence**: Field names come from the first (base) schema, but the field properties
992/// (nullability and field-level metadata) have later schemas taking precedence.
993///
994/// **Example**:
995/// ```sql
996/// SELECT a, b FROM table1  -- a: Int32, metadata {"source": "t1"}, nullable=false
997/// UNION
998/// SELECT a, b FROM table2  -- a: Int64, metadata {"source": "t2"}, nullable=true
999/// UNION
1000/// SELECT a, b FROM table3  -- a: Int32, metadata {"encoding": "utf8"}, nullable=false
1001/// -- Result:
1002/// -- a: Int64 (from type coercion), nullable=true (from table2),
1003/// -- metadata: {"source": "t2", "encoding": "utf8"} (later inputs take precedence)
1004/// ```
1005///
1006/// **Precedence Summary**:
1007/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order
1008/// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR)
1009/// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics)
1010pub fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> {
1011    coerce_union_schema_with_schema(&inputs[1..], inputs[0].schema())
1012}
1013fn coerce_union_schema_with_schema(
1014    inputs: &[Arc<LogicalPlan>],
1015    base_schema: &DFSchemaRef,
1016) -> Result<DFSchema> {
1017    let mut union_datatypes = base_schema
1018        .fields()
1019        .iter()
1020        .map(|f| f.data_type().clone())
1021        .collect::<Vec<_>>();
1022    let mut union_nullabilities = base_schema
1023        .fields()
1024        .iter()
1025        .map(|f| f.is_nullable())
1026        .collect::<Vec<_>>();
1027    let mut union_field_meta = base_schema
1028        .fields()
1029        .iter()
1030        .map(|f| f.metadata().clone())
1031        .collect::<Vec<_>>();
1032
1033    let mut metadata = base_schema.metadata().clone();
1034
1035    for (i, plan) in inputs.iter().enumerate() {
1036        let plan_schema = plan.schema();
1037        metadata.extend(plan_schema.metadata().clone());
1038
1039        if plan_schema.fields().len() != base_schema.fields().len() {
1040            return plan_err!(
1041                "Union schemas have different number of fields: \
1042                query 1 has {} fields whereas query {} has {} fields",
1043                base_schema.fields().len(),
1044                i + 1,
1045                plan_schema.fields().len()
1046            );
1047        }
1048
1049        // coerce data type and nullability for each field
1050        for (union_datatype, union_nullable, union_field_map, plan_field) in izip!(
1051            union_datatypes.iter_mut(),
1052            union_nullabilities.iter_mut(),
1053            union_field_meta.iter_mut(),
1054            plan_schema.fields().iter()
1055        ) {
1056            let coerced_type =
1057                comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else(
1058                    || {
1059                        plan_datafusion_err!(
1060                            "Incompatible inputs for Union: Previous inputs were \
1061                            of type {}, but got incompatible type {} on column '{}'",
1062                            union_datatype,
1063                            plan_field.data_type(),
1064                            plan_field.name()
1065                        )
1066                    },
1067                )?;
1068
1069            *union_datatype = coerced_type;
1070            *union_nullable = *union_nullable || plan_field.is_nullable();
1071            union_field_map.extend(plan_field.metadata().clone());
1072        }
1073    }
1074    let union_qualified_fields = izip!(
1075        base_schema.fields(),
1076        union_datatypes.into_iter(),
1077        union_nullabilities,
1078        union_field_meta.into_iter()
1079    )
1080    .map(|(field, datatype, nullable, metadata)| {
1081        let mut field = Field::new(field.name().clone(), datatype, nullable);
1082        field.set_metadata(metadata);
1083        (None, field.into())
1084    })
1085    .collect::<Vec<_>>();
1086
1087    DFSchema::new_with_metadata(union_qualified_fields, metadata)
1088}
1089
1090/// See `<https://github.com/apache/datafusion/pull/2108>`
1091fn project_with_column_index(
1092    expr: Vec<Expr>,
1093    input: Arc<LogicalPlan>,
1094    schema: DFSchemaRef,
1095) -> Result<LogicalPlan> {
1096    let alias_expr = expr
1097        .into_iter()
1098        .enumerate()
1099        .map(|(i, e)| match e {
1100            Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => {
1101                Ok(e.unalias().alias(schema.field(i).name()))
1102            }
1103            Expr::Column(Column {
1104                relation: _,
1105                ref name,
1106                spans: _,
1107            }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())),
1108            Expr::Alias { .. } | Expr::Column { .. } => Ok(e),
1109            #[expect(deprecated)]
1110            Expr::Wildcard { .. } => {
1111                plan_err!("Wildcard should be expanded before type coercion")
1112            }
1113            _ => Ok(e.alias(schema.field(i).name())),
1114        })
1115        .collect::<Result<Vec<_>>>()?;
1116
1117    Projection::try_new_with_schema(alias_expr, input, schema)
1118        .map(LogicalPlan::Projection)
1119}
1120
1121#[cfg(test)]
1122mod test {
1123    use std::any::Any;
1124    use std::sync::Arc;
1125
1126    use arrow::datatypes::DataType::Utf8;
1127    use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit};
1128    use insta::assert_snapshot;
1129
1130    use crate::analyzer::type_coercion::{
1131        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
1132    };
1133    use crate::analyzer::Analyzer;
1134    use crate::assert_analyzed_plan_with_config_eq_snapshot;
1135    use datafusion_common::config::ConfigOptions;
1136    use datafusion_common::tree_node::{TransformedResult};
1137    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans};
1138    use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
1139    use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort};
1140    use datafusion_expr::test::function_stub::avg_udaf;
1141    use datafusion_expr::{
1142        cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
1143        BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
1144        Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
1145        SimpleAggregateUDF, Subquery, Union, Volatility,
1146    };
1147    use datafusion_functions_aggregate::average::AvgAccumulator;
1148    use datafusion_sql::TableReference;
1149
1150    fn empty() -> Arc<LogicalPlan> {
1151        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1152            produce_one_row: false,
1153            schema: Arc::new(DFSchema::empty()),
1154        }))
1155    }
1156
1157    fn empty_with_type(data_type: DataType) -> Arc<LogicalPlan> {
1158        Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1159            produce_one_row: false,
1160            schema: Arc::new(
1161                DFSchema::from_unqualified_fields(
1162                    vec![Field::new("a", data_type, true)].into(),
1163                    std::collections::HashMap::new(),
1164                )
1165                .unwrap(),
1166            ),
1167        }))
1168    }
1169
    // Runs the `TypeCoercion` rule with default `ConfigOptions` on `$plan` and
    // compares the result with the inline snapshot `$expected` by delegating to
    // `assert_analyzed_plan_with_config_eq_snapshot!`.
    //
    // Usage: `assert_analyzed_plan_eq!(plan, @r"...")`
    macro_rules! assert_analyzed_plan_eq {
        (
            $plan: expr,
            @ $expected: literal $(,)?
        ) => {{
            let options = ConfigOptions::default();
            let rule = Arc::new(TypeCoercion::new());
            assert_analyzed_plan_with_config_eq_snapshot!(
                options,
                rule,
                $plan,
                @ $expected,
            )
            }};
    }
1185
    // Like `assert_analyzed_plan_eq!`, but when `$is_viewtype` is true the
    // `optimizer.expand_views_at_output` option is enabled before the
    // `TypeCoercion` rule runs, so coercion of the plan's output can be checked.
    //
    // Usage: `coerce_on_output_if_viewtype!(true, plan, @r"...")`
    macro_rules! coerce_on_output_if_viewtype {
        (
            $is_viewtype: expr,
            $plan: expr,
            @ $expected: literal $(,)?
        ) => {{
            let mut options = ConfigOptions::default();
            // coerce on output
            if $is_viewtype {options.optimizer.expand_views_at_output = true;}
            let rule = Arc::new(TypeCoercion::new());

            assert_analyzed_plan_with_config_eq_snapshot!(
                options,
                rule,
                $plan,
                @ $expected,
            )
        }};
    }
1205
1206    fn assert_type_coercion_error(
1207        plan: LogicalPlan,
1208        expected_substr: &str,
1209    ) -> Result<()> {
1210        let options = ConfigOptions::default();
1211        let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]);
1212
1213        match analyzer.execute_and_check(plan, &options, |_, _| {}) {
1214            Ok(succeeded_plan) => {
1215                panic!(
1216                    "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}"
1217                );
1218            }
1219            Err(e) => {
1220                let msg = e.to_string();
1221                assert!(
1222                    msg.contains(expected_substr),
1223                    "Error did not contain expected substring.\n  expected to find: `{expected_substr}`\n  actual error: `{msg}`"
1224                );
1225            }
1226        }
1227
1228        Ok(())
1229    }
1230
1231    #[test]
1232    fn simple_case() -> Result<()> {
1233        let expr = col("a").lt(lit(2_u32));
1234        let empty = empty_with_type(DataType::Float64);
1235        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1236
1237        assert_analyzed_plan_eq!(
1238            plan,
1239            @r"
1240        Projection: a < CAST(UInt32(2) AS Float64)
1241          EmptyRelation: rows=0
1242        "
1243        )
1244    }
1245
1246    #[test]
1247    fn test_coerce_union() -> Result<()> {
1248        let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1249            produce_one_row: false,
1250            schema: Arc::new(
1251                DFSchema::try_from_qualified_schema(
1252                    TableReference::full("datafusion", "test", "foo"),
1253                    &Schema::new(vec![Field::new("a", DataType::Int32, false)]),
1254                )
1255                .unwrap(),
1256            ),
1257        }));
1258        let right_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1259            produce_one_row: false,
1260            schema: Arc::new(
1261                DFSchema::try_from_qualified_schema(
1262                    TableReference::full("datafusion", "test", "foo"),
1263                    &Schema::new(vec![Field::new("a", DataType::Int64, false)]),
1264                )
1265                .unwrap(),
1266            ),
1267        }));
1268        let union = LogicalPlan::Union(Union::try_new_with_loose_types(vec![
1269            left_plan, right_plan,
1270        ])?);
1271        let analyzed_union = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())])
1272            .execute_and_check(union, &ConfigOptions::default(), |_, _| {})?;
1273        let top_level_plan = LogicalPlan::Projection(Projection::try_new(
1274            vec![col("a")],
1275            Arc::new(analyzed_union),
1276        )?);
1277
1278        assert_analyzed_plan_eq!(
1279            top_level_plan,
1280            @r"
1281        Projection: a
1282          Union
1283            Projection: CAST(datafusion.test.foo.a AS Int64) AS a
1284              EmptyRelation: rows=0
1285            EmptyRelation: rows=0
1286        "
1287        )
1288    }
1289
1290    #[test]
1291    fn coerce_utf8view_output() -> Result<()> {
1292        // Plan A
1293        // scenario: outermost utf8view projection
1294        let expr = col("a");
1295        let empty = empty_with_type(DataType::Utf8View);
1296        let plan = LogicalPlan::Projection(Projection::try_new(
1297            vec![expr.clone()],
1298            Arc::clone(&empty),
1299        )?);
1300
1301        // Plan A: no coerce
1302        coerce_on_output_if_viewtype!(
1303            false,
1304            plan.clone(),
1305            @r"
1306        Projection: a
1307          EmptyRelation: rows=0
1308        "
1309        )?;
1310
1311        // Plan A: coerce requested: Utf8View => LargeUtf8
1312        coerce_on_output_if_viewtype!(
1313            true,
1314            plan.clone(),
1315            @r"
1316        Projection: CAST(a AS LargeUtf8)
1317          EmptyRelation: rows=0
1318        "
1319        )?;
1320
1321        // Plan B
1322        // scenario: outermost bool projection
1323        let bool_expr = col("a").lt(lit("foo"));
1324        let bool_plan = LogicalPlan::Projection(Projection::try_new(
1325            vec![bool_expr],
1326            Arc::clone(&empty),
1327        )?);
1328        // Plan B: no coerce
1329        coerce_on_output_if_viewtype!(
1330            false,
1331            bool_plan.clone(),
1332            @r#"
1333        Projection: a < CAST(Utf8("foo") AS Utf8View)
1334          EmptyRelation: rows=0
1335        "#
1336        )?;
1337
1338        coerce_on_output_if_viewtype!(
1339            false,
1340            plan.clone(),
1341            @r"
1342        Projection: a
1343          EmptyRelation: rows=0
1344        "
1345        )?;
1346
1347        // Plan B: coerce requested: no coercion applied
1348        coerce_on_output_if_viewtype!(
1349            true,
1350            plan.clone(),
1351            @r"
1352        Projection: CAST(a AS LargeUtf8)
1353          EmptyRelation: rows=0
1354        "
1355        )?;
1356
1357        // Plan C
1358        // scenario: with a non-projection root logical plan node
1359        let sort_expr = expr.sort(true, true);
1360        let sort_plan = LogicalPlan::Sort(Sort {
1361            expr: vec![sort_expr],
1362            input: Arc::new(plan),
1363            fetch: None,
1364        });
1365
1366        // Plan C: no coerce
1367        coerce_on_output_if_viewtype!(
1368            false,
1369            sort_plan.clone(),
1370            @r"
1371        Sort: a ASC NULLS FIRST
1372          Projection: a
1373            EmptyRelation: rows=0
1374        "
1375        )?;
1376
1377        // Plan C: coerce requested: Utf8View => LargeUtf8
1378        coerce_on_output_if_viewtype!(
1379            true,
1380            sort_plan.clone(),
1381            @r"
1382        Projection: CAST(a AS LargeUtf8)
1383          Sort: a ASC NULLS FIRST
1384            Projection: a
1385              EmptyRelation: rows=0
1386        "
1387        )?;
1388
1389        // Plan D
1390        // scenario: two layers of projections with view types
1391        let plan = LogicalPlan::Projection(Projection::try_new(
1392            vec![col("a")],
1393            Arc::new(sort_plan),
1394        )?);
1395        // Plan D: no coerce
1396        coerce_on_output_if_viewtype!(
1397            false,
1398            plan.clone(),
1399            @r"
1400        Projection: a
1401          Sort: a ASC NULLS FIRST
1402            Projection: a
1403              EmptyRelation: rows=0
1404        "
1405        )?;
1406        // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost
1407        coerce_on_output_if_viewtype!(
1408            true,
1409            plan.clone(),
1410            @r"
1411        Projection: CAST(a AS LargeUtf8)
1412          Sort: a ASC NULLS FIRST
1413            Projection: a
1414              EmptyRelation: rows=0
1415        "
1416        )?;
1417
1418        Ok(())
1419    }
1420
    /// Checks how `expand_views_at_output` affects plans whose output column
    /// is `BinaryView`: the expected snapshots show a cast to `LargeBinary`
    /// only on the outermost output, and no change for a boolean output.
    #[test]
    fn coerce_binaryview_output() -> Result<()> {
        // Plan A
        // scenario: outermost binaryview projection
        let expr = col("a");
        let empty = empty_with_type(DataType::BinaryView);
        let plan = LogicalPlan::Projection(Projection::try_new(
            vec![expr.clone()],
            Arc::clone(&empty),
        )?);

        // Plan A: no coerce
        coerce_on_output_if_viewtype!(
            false,
            plan.clone(),
            @r"
        Projection: a
          EmptyRelation: rows=0
        "
        )?;

        // Plan A: coerce requested: BinaryView => LargeBinary
        coerce_on_output_if_viewtype!(
            true,
            plan.clone(),
            @r"
        Projection: CAST(a AS LargeBinary)
          EmptyRelation: rows=0
        "
        )?;

        // Plan B
        // scenario: outermost bool projection
        let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1]));
        let bool_plan = LogicalPlan::Projection(Projection::try_new(
            vec![bool_expr],
            Arc::clone(&empty),
        )?);

        // Plan B: no coerce
        coerce_on_output_if_viewtype!(
            false,
            bool_plan.clone(),
            @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
        )?;

        // Plan B: coerce requested: no coercion applied (same expected plan
        // as without the option)
        coerce_on_output_if_viewtype!(
            true,
            bool_plan.clone(),
            @r#"
        Projection: a < CAST(Binary("8,1,8,1") AS BinaryView)
          EmptyRelation: rows=0
        "#
        )?;

        // Plan C
        // scenario: with a non-projection root logical plan node
        let sort_expr = expr.sort(true, true);
        let sort_plan = LogicalPlan::Sort(Sort {
            expr: vec![sort_expr],
            input: Arc::new(plan),
            fetch: None,
        });

        // Plan C: no coerce
        coerce_on_output_if_viewtype!(
            false,
            sort_plan.clone(),
            @r"
        Sort: a ASC NULLS FIRST
          Projection: a
            EmptyRelation: rows=0
        "
        )?;
        // Plan C: coerce requested: BinaryView => LargeBinary
        coerce_on_output_if_viewtype!(
            true,
            sort_plan.clone(),
            @r"
        Projection: CAST(a AS LargeBinary)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
        )?;

        // Plan D
        // scenario: two layers of projections with view types
        let plan = LogicalPlan::Projection(Projection::try_new(
            vec![col("a")],
            Arc::new(sort_plan),
        )?);

        // Plan D: no coerce
        coerce_on_output_if_viewtype!(
            false,
            plan.clone(),
            @r"
        Projection: a
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
        )?;

        // Plan D: coerce requested: BinaryView => LargeBinary only on outermost
        coerce_on_output_if_viewtype!(
            true,
            plan.clone(),
            @r"
        Projection: CAST(a AS LargeBinary)
          Sort: a ASC NULLS FIRST
            Projection: a
              EmptyRelation: rows=0
        "
        )?;

        Ok(())
    }
1544
1545    #[test]
1546    fn nested_case() -> Result<()> {
1547        let expr = col("a").lt(lit(2_u32));
1548        let empty = empty_with_type(DataType::Float64);
1549
1550        let plan = LogicalPlan::Projection(Projection::try_new(
1551            vec![expr.clone().or(expr)],
1552            empty,
1553        )?);
1554
1555        assert_analyzed_plan_eq!(
1556            plan,
1557            @r"
1558        Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)
1559          EmptyRelation: rows=0
1560        "
1561        )
1562    }
1563
1564    #[derive(Debug, PartialEq, Eq, Hash)]
1565    struct TestScalarUDF {
1566        signature: Signature,
1567    }
1568
1569    impl ScalarUDFImpl for TestScalarUDF {
1570        fn as_any(&self) -> &dyn Any {
1571            self
1572        }
1573
1574        fn name(&self) -> &str {
1575            "TestScalarUDF"
1576        }
1577
1578        fn signature(&self) -> &Signature {
1579            &self.signature
1580        }
1581
1582        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
1583            Ok(Utf8)
1584        }
1585
1586        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
1587            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
1588        }
1589    }
1590
1591    #[test]
1592    fn scalar_udf() -> Result<()> {
1593        let empty = empty();
1594
1595        let udf = ScalarUDF::from(TestScalarUDF {
1596            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1597        })
1598        .call(vec![lit(123_i32)]);
1599        let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
1600
1601        assert_analyzed_plan_eq!(
1602            plan,
1603            @r"
1604        Projection: TestScalarUDF(CAST(Int32(123) AS Float32))
1605          EmptyRelation: rows=0
1606        "
1607        )
1608    }
1609
1610    #[test]
1611    fn scalar_udf_invalid_input() -> Result<()> {
1612        let empty = empty();
1613        let udf = ScalarUDF::from(TestScalarUDF {
1614            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1615        })
1616        .call(vec![lit("Apple")]);
1617        Projection::try_new(vec![udf], empty)
1618            .expect_err("Expected an error due to incorrect function input");
1619
1620        Ok(())
1621    }
1622
1623    #[test]
1624    fn scalar_function() -> Result<()> {
1625        // test that automatic argument type coercion for scalar functions work
1626        let empty = empty();
1627        let lit_expr = lit(10i64);
1628        let fun = ScalarUDF::new_from_impl(TestScalarUDF {
1629            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1630        });
1631        let scalar_function_expr =
1632            Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr]));
1633        let plan = LogicalPlan::Projection(Projection::try_new(
1634            vec![scalar_function_expr],
1635            empty,
1636        )?);
1637
1638        assert_analyzed_plan_eq!(
1639            plan,
1640            @r"
1641        Projection: TestScalarUDF(CAST(Int64(10) AS Float32))
1642          EmptyRelation: rows=0
1643        "
1644        )
1645    }
1646
1647    #[test]
1648    fn agg_udaf() -> Result<()> {
1649        let empty = empty();
1650        let my_avg = create_udaf(
1651            "MY_AVG",
1652            vec![DataType::Float64],
1653            Arc::new(DataType::Float64),
1654            Volatility::Immutable,
1655            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
1656            Arc::new(vec![DataType::UInt64, DataType::Float64]),
1657        );
1658        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1659            Arc::new(my_avg),
1660            vec![lit(10i64)],
1661            false,
1662            None,
1663            vec![],
1664            None,
1665        ));
1666        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?);
1667
1668        assert_analyzed_plan_eq!(
1669            plan,
1670            @r"
1671        Projection: MY_AVG(CAST(Int64(10) AS Float64))
1672          EmptyRelation: rows=0
1673        "
1674        )
1675    }
1676
1677    #[test]
1678    fn agg_udaf_invalid_input() -> Result<()> {
1679        let empty = empty();
1680        let return_type = DataType::Float64;
1681        let accumulator: AccumulatorFactoryFunction =
1682            Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
1683        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
1684            "MY_AVG",
1685            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
1686            return_type,
1687            accumulator,
1688            vec![
1689                Field::new("count", DataType::UInt64, true).into(),
1690                Field::new("avg", DataType::Float64, true).into(),
1691            ],
1692        ));
1693        let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1694            Arc::new(my_avg),
1695            vec![lit("10")],
1696            false,
1697            None,
1698            vec![],
1699            None,
1700        ));
1701
1702        let err = Projection::try_new(vec![udaf], empty).err().unwrap();
1703        assert!(
1704            err.strip_backtrace().starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'MY_AVG' function: coercion from Utf8 to the signature Uniform(1, [Float64]) failed")
1705        );
1706        Ok(())
1707    }
1708
1709    #[test]
1710    fn agg_function_case() -> Result<()> {
1711        let empty = empty();
1712        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1713            avg_udaf(),
1714            vec![lit(12f64)],
1715            false,
1716            None,
1717            vec![],
1718            None,
1719        ));
1720        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1721
1722        assert_analyzed_plan_eq!(
1723            plan,
1724            @r"
1725        Projection: avg(Float64(12))
1726          EmptyRelation: rows=0
1727        "
1728        )?;
1729
1730        let empty = empty_with_type(DataType::Int32);
1731        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1732            avg_udaf(),
1733            vec![cast(col("a"), DataType::Float64)],
1734            false,
1735            None,
1736            vec![],
1737            None,
1738        ));
1739        let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
1740
1741        assert_analyzed_plan_eq!(
1742            plan,
1743            @r"
1744        Projection: avg(CAST(a AS Float64))
1745          EmptyRelation: rows=0
1746        "
1747        )
1748    }
1749
1750    #[test]
1751    fn agg_function_invalid_input_avg() -> Result<()> {
1752        let empty = empty();
1753        let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
1754            avg_udaf(),
1755            vec![lit("1")],
1756            false,
1757            None,
1758            vec![],
1759            None,
1760        ));
1761        let err = Projection::try_new(vec![agg_expr], empty)
1762            .err()
1763            .unwrap()
1764            .strip_backtrace();
1765        assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed"));
1766        Ok(())
1767    }
1768
1769    #[test]
1770    fn binary_op_date32_op_interval() -> Result<()> {
1771        // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
1772        let expr = cast(lit("1998-03-18"), DataType::Date32)
1773            + lit(ScalarValue::new_interval_dt(123, 456));
1774        let empty = empty();
1775        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1776
1777        assert_analyzed_plan_eq!(
1778            plan,
1779            @r#"
1780        Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }")
1781          EmptyRelation: rows=0
1782        "#
1783        )
1784    }
1785
1786    #[test]
1787    fn inlist_case() -> Result<()> {
1788        // a in (1,4,8), a is int64
1789        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1790        let empty = empty_with_type(DataType::Int64);
1791        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1792        assert_analyzed_plan_eq!(
1793            plan,
1794            @r"
1795        Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])
1796          EmptyRelation: rows=0
1797        ")?;
1798
1799        // a in (1,4,8), a is decimal
1800        let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false);
1801        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
1802            produce_one_row: false,
1803            schema: Arc::new(DFSchema::from_unqualified_fields(
1804                vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(),
1805                std::collections::HashMap::new(),
1806            )?),
1807        }));
1808        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1809        assert_analyzed_plan_eq!(
1810            plan,
1811            @r"
1812        Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])
1813          EmptyRelation: rows=0
1814        ")
1815    }
1816
1817    #[test]
1818    fn between_case() -> Result<()> {
1819        let expr = col("a").between(
1820            lit("2002-05-08"),
1821            // (cast('2002-05-08' as date) + interval '1 months')
1822            cast(lit("2002-05-08"), DataType::Date32)
1823                + lit(ScalarValue::new_interval_ym(0, 1)),
1824        );
1825        let empty = empty_with_type(Utf8);
1826        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1827
1828        assert_analyzed_plan_eq!(
1829            plan,
1830            @r#"
1831        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1")
1832          EmptyRelation: rows=0
1833        "#
1834        )
1835    }
1836
1837    #[test]
1838    fn between_infer_cheap_type() -> Result<()> {
1839        let expr = col("a").between(
1840            // (cast('2002-05-08' as date) + interval '1 months')
1841            cast(lit("2002-05-08"), DataType::Date32)
1842                + lit(ScalarValue::new_interval_ym(0, 1)),
1843            lit("2002-12-08"),
1844        );
1845        let empty = empty_with_type(Utf8);
1846        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1847
1848        // TODO: we should cast col(a).
1849        assert_analyzed_plan_eq!(
1850            plan,
1851            @r#"
1852        Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32)
1853          EmptyRelation: rows=0
1854        "#
1855        )
1856    }
1857
1858    #[test]
1859    fn between_null() -> Result<()> {
1860        let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64));
1861        let empty = empty();
1862        let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?);
1863
1864        assert_analyzed_plan_eq!(
1865            plan,
1866            @r"
1867        Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)
1868          EmptyRelation: rows=0
1869        "
1870        )
1871    }
1872
1873    #[test]
1874    fn is_bool_for_type_coercion() -> Result<()> {
1875        // is true
1876        let expr = col("a").is_true();
1877        let empty = empty_with_type(DataType::Boolean);
1878        let plan =
1879            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
1880
1881        assert_analyzed_plan_eq!(
1882            plan,
1883            @r"
1884        Projection: a IS TRUE
1885          EmptyRelation: rows=0
1886        "
1887        )?;
1888
1889        let empty = empty_with_type(DataType::Int64);
1890        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1891        assert_type_coercion_error(
1892            plan,
1893            "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"
1894        )?;
1895
1896        // is not true
1897        let expr = col("a").is_not_true();
1898        let empty = empty_with_type(DataType::Boolean);
1899        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1900
1901        assert_analyzed_plan_eq!(
1902            plan,
1903            @r"
1904        Projection: a IS NOT TRUE
1905          EmptyRelation: rows=0
1906        "
1907        )?;
1908
1909        // is false
1910        let expr = col("a").is_false();
1911        let empty = empty_with_type(DataType::Boolean);
1912        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1913
1914        assert_analyzed_plan_eq!(
1915            plan,
1916            @r"
1917        Projection: a IS FALSE
1918          EmptyRelation: rows=0
1919        "
1920        )?;
1921
1922        // is not false
1923        let expr = col("a").is_not_false();
1924        let empty = empty_with_type(DataType::Boolean);
1925        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1926
1927        assert_analyzed_plan_eq!(
1928            plan,
1929            @r"
1930        Projection: a IS NOT FALSE
1931          EmptyRelation: rows=0
1932        "
1933        )
1934    }
1935
1936    #[test]
1937    fn like_for_type_coercion() -> Result<()> {
1938        // like : utf8 like "abc"
1939        let expr = Box::new(col("a"));
1940        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1941        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1942        let empty = empty_with_type(Utf8);
1943        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1944
1945        assert_analyzed_plan_eq!(
1946            plan,
1947            @r#"
1948        Projection: a LIKE Utf8("abc")
1949          EmptyRelation: rows=0
1950        "#
1951        )?;
1952
1953        let expr = Box::new(col("a"));
1954        let pattern = Box::new(lit(ScalarValue::Null));
1955        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1956        let empty = empty_with_type(Utf8);
1957        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1958
1959        assert_analyzed_plan_eq!(
1960            plan,
1961            @r"
1962        Projection: a LIKE CAST(NULL AS Utf8)
1963          EmptyRelation: rows=0
1964        "
1965        )?;
1966
1967        let expr = Box::new(col("a"));
1968        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1969        let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false));
1970        let empty = empty_with_type(DataType::Int64);
1971        let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
1972        assert_type_coercion_error(
1973            plan,
1974            "There isn't a common type to coerce Int64 and Utf8 in LIKE expression",
1975        )?;
1976
1977        // ilike
1978        let expr = Box::new(col("a"));
1979        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
1980        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1981        let empty = empty_with_type(Utf8);
1982        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1983
1984        assert_analyzed_plan_eq!(
1985            plan,
1986            @r#"
1987        Projection: a ILIKE Utf8("abc")
1988          EmptyRelation: rows=0
1989        "#
1990        )?;
1991
1992        let expr = Box::new(col("a"));
1993        let pattern = Box::new(lit(ScalarValue::Null));
1994        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
1995        let empty = empty_with_type(Utf8);
1996        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
1997
1998        assert_analyzed_plan_eq!(
1999            plan,
2000            @r"
2001        Projection: a ILIKE CAST(NULL AS Utf8)
2002          EmptyRelation: rows=0
2003        "
2004        )?;
2005
2006        let expr = Box::new(col("a"));
2007        let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
2008        let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true));
2009        let empty = empty_with_type(DataType::Int64);
2010        let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
2011        assert_type_coercion_error(
2012            plan,
2013            "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression",
2014        )?;
2015
2016        Ok(())
2017    }
2018
2019    #[test]
2020    fn unknown_for_type_coercion() -> Result<()> {
2021        // unknown
2022        let expr = col("a").is_unknown();
2023        let empty = empty_with_type(DataType::Boolean);
2024        let plan =
2025            LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?);
2026
2027        assert_analyzed_plan_eq!(
2028            plan,
2029            @r"
2030        Projection: a IS UNKNOWN
2031          EmptyRelation: rows=0
2032        "
2033        )?;
2034
2035        let empty = empty_with_type(Utf8);
2036        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2037        assert_type_coercion_error(
2038            plan,
2039            "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"
2040        )?;
2041
2042        // is not unknown
2043        let expr = col("a").is_not_unknown();
2044        let empty = empty_with_type(DataType::Boolean);
2045        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2046
2047        assert_analyzed_plan_eq!(
2048            plan,
2049            @r"
2050        Projection: a IS NOT UNKNOWN
2051          EmptyRelation: rows=0
2052        "
2053        )
2054    }
2055
2056    #[test]
2057    fn concat_for_type_coercion() -> Result<()> {
2058        let empty = empty_with_type(Utf8);
2059        let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)];
2060
2061        // concat-type signature
2062        let expr = ScalarUDF::new_from_impl(TestScalarUDF {
2063            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
2064        })
2065        .call(args.to_vec());
2066        let plan =
2067            LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?);
2068        assert_analyzed_plan_eq!(
2069            plan,
2070            @r#"
2071        Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))
2072          EmptyRelation: rows=0
2073        "#
2074        )
2075    }
2076
2077    #[test]
2078    fn test_type_coercion_rewrite() -> Result<()> {
2079        // gt
2080        let schema = Arc::new(DFSchema::from_unqualified_fields(
2081            vec![Field::new("a", DataType::Int64, true)].into(),
2082            std::collections::HashMap::new(),
2083        )?);
2084        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2085        let expr = is_true(lit(12i32).gt(lit(13i64)));
2086        let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
2087        let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2088        assert_eq!(expected, result);
2089
2090        // eq
2091        let schema = Arc::new(DFSchema::from_unqualified_fields(
2092            vec![Field::new("a", DataType::Int64, true)].into(),
2093            std::collections::HashMap::new(),
2094        )?);
2095        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2096        let expr = is_true(lit(12i32).eq(lit(13i64)));
2097        let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
2098        let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2099        assert_eq!(expected, result);
2100
2101        // lt
2102        let schema = Arc::new(DFSchema::from_unqualified_fields(
2103            vec![Field::new("a", DataType::Int64, true)].into(),
2104            std::collections::HashMap::new(),
2105        )?);
2106        let mut rewriter = TypeCoercionRewriter { schema: &schema };
2107        let expr = is_true(lit(12i32).lt(lit(13i64)));
2108        let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
2109        let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?;
2110        assert_eq!(expected, result);
2111
2112        Ok(())
2113    }
2114
2115    #[test]
2116    fn binary_op_date32_eq_ts() -> Result<()> {
2117        let expr = cast(
2118            lit("1998-03-18"),
2119            DataType::Timestamp(TimeUnit::Nanosecond, None),
2120        )
2121        .eq(cast(lit("1998-03-18"), DataType::Date32));
2122        let empty = empty();
2123        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2124
2125        assert_analyzed_plan_eq!(
2126            plan,
2127            @r#"
2128        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(ns))
2129          EmptyRelation: rows=0
2130        "#
2131        )
2132    }
2133
2134    fn cast_if_not_same_type(
2135        expr: Box<Expr>,
2136        data_type: &DataType,
2137        schema: &DFSchemaRef,
2138    ) -> Box<Expr> {
2139        if &expr.get_type(schema).unwrap() != data_type {
2140            Box::new(cast(*expr, data_type.clone()))
2141        } else {
2142            expr
2143        }
2144    }
2145
2146    fn cast_helper(
2147        case: Case,
2148        case_when_type: &DataType,
2149        then_else_type: &DataType,
2150        schema: &DFSchemaRef,
2151    ) -> Case {
2152        let expr = case
2153            .expr
2154            .map(|e| cast_if_not_same_type(e, case_when_type, schema));
2155        let when_then_expr = case
2156            .when_then_expr
2157            .into_iter()
2158            .map(|(when, then)| {
2159                (
2160                    cast_if_not_same_type(when, case_when_type, schema),
2161                    cast_if_not_same_type(then, then_else_type, schema),
2162                )
2163            })
2164            .collect::<Vec<_>>();
2165        let else_expr = case
2166            .else_expr
2167            .map(|e| cast_if_not_same_type(e, then_else_type, schema));
2168
2169        Case {
2170            expr,
2171            when_then_expr,
2172            else_expr,
2173        }
2174    }
2175
2176    #[test]
2177    fn test_case_expression_coercion() -> Result<()> {
2178        let schema = Arc::new(DFSchema::from_unqualified_fields(
2179            vec![
2180                Field::new("boolean", DataType::Boolean, true),
2181                Field::new("integer", DataType::Int32, true),
2182                Field::new("float", DataType::Float32, true),
2183                Field::new(
2184                    "timestamp",
2185                    DataType::Timestamp(TimeUnit::Nanosecond, None),
2186                    true,
2187                ),
2188                Field::new("date", DataType::Date32, true),
2189                Field::new(
2190                    "interval",
2191                    DataType::Interval(arrow::datatypes::IntervalUnit::MonthDayNano),
2192                    true,
2193                ),
2194                Field::new("binary", DataType::Binary, true),
2195                Field::new("string", Utf8, true),
2196                Field::new("decimal", DataType::Decimal128(10, 10), true),
2197            ]
2198            .into(),
2199            std::collections::HashMap::new(),
2200        )?);
2201
2202        let case = Case {
2203            expr: None,
2204            when_then_expr: vec![
2205                (Box::new(col("boolean")), Box::new(col("integer"))),
2206                (Box::new(col("integer")), Box::new(col("float"))),
2207                (Box::new(col("string")), Box::new(col("string"))),
2208            ],
2209            else_expr: None,
2210        };
2211        let case_when_common_type = DataType::Boolean;
2212        let then_else_common_type = Utf8;
2213        let expected = cast_helper(
2214            case.clone(),
2215            &case_when_common_type,
2216            &then_else_common_type,
2217            &schema,
2218        );
2219        let actual = coerce_case_expression(case, &schema)?;
2220        assert_eq!(expected, actual);
2221
2222        let case = Case {
2223            expr: Some(Box::new(col("string"))),
2224            when_then_expr: vec![
2225                (Box::new(col("float")), Box::new(col("integer"))),
2226                (Box::new(col("integer")), Box::new(col("float"))),
2227                (Box::new(col("string")), Box::new(col("string"))),
2228            ],
2229            else_expr: Some(Box::new(col("string"))),
2230        };
2231        let case_when_common_type = Utf8;
2232        let then_else_common_type = Utf8;
2233        let expected = cast_helper(
2234            case.clone(),
2235            &case_when_common_type,
2236            &then_else_common_type,
2237            &schema,
2238        );
2239        let actual = coerce_case_expression(case, &schema)?;
2240        assert_eq!(expected, actual);
2241
2242        let case = Case {
2243            expr: Some(Box::new(col("interval"))),
2244            when_then_expr: vec![
2245                (Box::new(col("float")), Box::new(col("integer"))),
2246                (Box::new(col("binary")), Box::new(col("float"))),
2247                (Box::new(col("string")), Box::new(col("string"))),
2248            ],
2249            else_expr: Some(Box::new(col("string"))),
2250        };
2251        let err = coerce_case_expression(case, &schema).unwrap_err();
2252        assert_snapshot!(
2253            err.strip_backtrace(),
2254            @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when (Float32, Binary, Utf8) to common types in CASE WHEN expression"
2255        );
2256
2257        let case = Case {
2258            expr: Some(Box::new(col("string"))),
2259            when_then_expr: vec![
2260                (Box::new(col("float")), Box::new(col("date"))),
2261                (Box::new(col("string")), Box::new(col("float"))),
2262                (Box::new(col("string")), Box::new(col("binary"))),
2263            ],
2264            else_expr: Some(Box::new(col("timestamp"))),
2265        };
2266        let err = coerce_case_expression(case, &schema).unwrap_err();
2267        assert_snapshot!(
2268            err.strip_backtrace(),
2269            @"Error during planning: Failed to coerce then (Date32, Float32, Binary) and else (Timestamp(ns)) to common types in CASE WHEN expression"
2270        );
2271
2272        Ok(())
2273    }
2274
2275    macro_rules! test_case_expression {
2276        ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => {
2277            let case = Case {
2278                expr: $expr.map(|e| Box::new(col(e))),
2279                when_then_expr: $when_then,
2280                else_expr: None,
2281            };
2282
2283            let expected =
2284                cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema);
2285
2286            let actual = coerce_case_expression(case, &$schema)?;
2287            assert_eq!(expected, actual);
2288        };
2289    }
2290
2291    #[test]
2292    fn tes_case_when_list() -> Result<()> {
2293        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2294        let schema = Arc::new(DFSchema::from_unqualified_fields(
2295            vec![
2296                Field::new(
2297                    "large_list",
2298                    DataType::LargeList(Arc::clone(&inner_field)),
2299                    true,
2300                ),
2301                Field::new(
2302                    "fixed_list",
2303                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2304                    true,
2305                ),
2306                Field::new("list", DataType::List(inner_field), true),
2307            ]
2308            .into(),
2309            std::collections::HashMap::new(),
2310        )?);
2311
2312        test_case_expression!(
2313            Some("list"),
2314            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2315            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2316            Utf8,
2317            schema
2318        );
2319
2320        test_case_expression!(
2321            Some("large_list"),
2322            vec![(Box::new(col("list")), Box::new(lit("1")))],
2323            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2324            Utf8,
2325            schema
2326        );
2327
2328        test_case_expression!(
2329            Some("list"),
2330            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2331            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2332            Utf8,
2333            schema
2334        );
2335
2336        test_case_expression!(
2337            Some("fixed_list"),
2338            vec![(Box::new(col("list")), Box::new(lit("1")))],
2339            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2340            Utf8,
2341            schema
2342        );
2343
2344        test_case_expression!(
2345            Some("fixed_list"),
2346            vec![(Box::new(col("large_list")), Box::new(lit("1")))],
2347            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2348            Utf8,
2349            schema
2350        );
2351
2352        test_case_expression!(
2353            Some("large_list"),
2354            vec![(Box::new(col("fixed_list")), Box::new(lit("1")))],
2355            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2356            Utf8,
2357            schema
2358        );
2359        Ok(())
2360    }
2361
2362    #[test]
2363    fn test_then_else_list() -> Result<()> {
2364        let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true));
2365        let schema = Arc::new(DFSchema::from_unqualified_fields(
2366            vec![
2367                Field::new("boolean", DataType::Boolean, true),
2368                Field::new(
2369                    "large_list",
2370                    DataType::LargeList(Arc::clone(&inner_field)),
2371                    true,
2372                ),
2373                Field::new(
2374                    "fixed_list",
2375                    DataType::FixedSizeList(Arc::clone(&inner_field), 3),
2376                    true,
2377                ),
2378                Field::new("list", DataType::List(inner_field), true),
2379            ]
2380            .into(),
2381            std::collections::HashMap::new(),
2382        )?);
2383
2384        // large list and list
2385        test_case_expression!(
2386            None::<String>,
2387            vec![
2388                (Box::new(col("boolean")), Box::new(col("large_list"))),
2389                (Box::new(col("boolean")), Box::new(col("list")))
2390            ],
2391            DataType::Boolean,
2392            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2393            schema
2394        );
2395
2396        test_case_expression!(
2397            None::<String>,
2398            vec![
2399                (Box::new(col("boolean")), Box::new(col("list"))),
2400                (Box::new(col("boolean")), Box::new(col("large_list")))
2401            ],
2402            DataType::Boolean,
2403            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2404            schema
2405        );
2406
2407        // fixed list and list
2408        test_case_expression!(
2409            None::<String>,
2410            vec![
2411                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2412                (Box::new(col("boolean")), Box::new(col("list")))
2413            ],
2414            DataType::Boolean,
2415            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2416            schema
2417        );
2418
2419        test_case_expression!(
2420            None::<String>,
2421            vec![
2422                (Box::new(col("boolean")), Box::new(col("list"))),
2423                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2424            ],
2425            DataType::Boolean,
2426            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
2427            schema
2428        );
2429
2430        // fixed list and large list
2431        test_case_expression!(
2432            None::<String>,
2433            vec![
2434                (Box::new(col("boolean")), Box::new(col("fixed_list"))),
2435                (Box::new(col("boolean")), Box::new(col("large_list")))
2436            ],
2437            DataType::Boolean,
2438            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2439            schema
2440        );
2441
2442        test_case_expression!(
2443            None::<String>,
2444            vec![
2445                (Box::new(col("boolean")), Box::new(col("large_list"))),
2446                (Box::new(col("boolean")), Box::new(col("fixed_list")))
2447            ],
2448            DataType::Boolean,
2449            DataType::LargeList(Arc::new(Field::new_list_field(DataType::Int64, true))),
2450            schema
2451        );
2452        Ok(())
2453    }
2454
2455    #[test]
2456    fn test_map_with_diff_name() -> Result<()> {
2457        let mut builder = SchemaBuilder::new();
2458        builder.push(Field::new("key", Utf8, false));
2459        builder.push(Field::new("value", DataType::Float64, true));
2460        let struct_fields = builder.finish().fields;
2461
2462        let fields =
2463            Field::new("entries", DataType::Struct(struct_fields.clone()), false);
2464        let map_type_entries = DataType::Map(Arc::new(fields), false);
2465
2466        let fields = Field::new("key_value", DataType::Struct(struct_fields), false);
2467        let may_type_custom = DataType::Map(Arc::new(fields), false);
2468
2469        let expr = col("a").eq(cast(col("a"), may_type_custom));
2470        let empty = empty_with_type(map_type_entries);
2471        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2472
2473        assert_analyzed_plan_eq!(
2474            plan,
2475            @r#"
2476        Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted))
2477          EmptyRelation: rows=0
2478        "#
2479        )
2480    }
2481
2482    #[test]
2483    fn interval_plus_timestamp() -> Result<()> {
2484        // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp;
2485        let expr = Expr::BinaryExpr(BinaryExpr::new(
2486            Box::new(lit(ScalarValue::IntervalYearMonth(Some(12)))),
2487            Operator::Plus,
2488            Box::new(cast(
2489                lit("2000-01-01T00:00:00"),
2490                DataType::Timestamp(TimeUnit::Nanosecond, None),
2491            )),
2492        ));
2493        let empty = empty();
2494        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2495
2496        assert_analyzed_plan_eq!(
2497            plan,
2498            @r#"
2499        Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(ns))
2500          EmptyRelation: rows=0
2501        "#
2502        )
2503    }
2504
2505    #[test]
2506    fn timestamp_subtract_timestamp() -> Result<()> {
2507        let expr = Expr::BinaryExpr(BinaryExpr::new(
2508            Box::new(cast(
2509                lit("1998-03-18"),
2510                DataType::Timestamp(TimeUnit::Nanosecond, None),
2511            )),
2512            Operator::Minus,
2513            Box::new(cast(
2514                lit("1998-03-18"),
2515                DataType::Timestamp(TimeUnit::Nanosecond, None),
2516            )),
2517        ));
2518        let empty = empty();
2519        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
2520
2521        assert_analyzed_plan_eq!(
2522            plan,
2523            @r#"
2524        Projection: CAST(Utf8("1998-03-18") AS Timestamp(ns)) - CAST(Utf8("1998-03-18") AS Timestamp(ns))
2525          EmptyRelation: rows=0
2526        "#
2527        )
2528    }
2529
2530    #[test]
2531    fn in_subquery_cast_subquery() -> Result<()> {
2532        let empty_int32 = empty_with_type(DataType::Int32);
2533        let empty_int64 = empty_with_type(DataType::Int64);
2534
2535        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2536            Box::new(col("a")),
2537            Subquery {
2538                subquery: empty_int32,
2539                outer_ref_columns: vec![],
2540                spans: Spans::new(),
2541            },
2542            false,
2543        ));
2544        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?);
2545        // add cast for subquery
2546
2547        assert_analyzed_plan_eq!(
2548            plan,
2549            @r"
2550        Filter: a IN (<subquery>)
2551          Subquery:
2552            Projection: CAST(a AS Int64)
2553              EmptyRelation: rows=0
2554          EmptyRelation: rows=0
2555        "
2556        )
2557    }
2558
2559    #[test]
2560    fn in_subquery_cast_expr() -> Result<()> {
2561        let empty_int32 = empty_with_type(DataType::Int32);
2562        let empty_int64 = empty_with_type(DataType::Int64);
2563
2564        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2565            Box::new(col("a")),
2566            Subquery {
2567                subquery: empty_int64,
2568                outer_ref_columns: vec![],
2569                spans: Spans::new(),
2570            },
2571            false,
2572        ));
2573        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?);
2574
2575        // add cast for subquery
2576        assert_analyzed_plan_eq!(
2577            plan,
2578            @r"
2579        Filter: CAST(a AS Int64) IN (<subquery>)
2580          Subquery:
2581            EmptyRelation: rows=0
2582          EmptyRelation: rows=0
2583        "
2584        )
2585    }
2586
2587    #[test]
2588    fn in_subquery_cast_all() -> Result<()> {
2589        let empty_inside = empty_with_type(DataType::Decimal128(10, 5));
2590        let empty_outside = empty_with_type(DataType::Decimal128(8, 8));
2591
2592        let in_subquery_expr = Expr::InSubquery(InSubquery::new(
2593            Box::new(col("a")),
2594            Subquery {
2595                subquery: empty_inside,
2596                outer_ref_columns: vec![],
2597                spans: Spans::new(),
2598            },
2599            false,
2600        ));
2601        let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?);
2602
2603        // add cast for subquery
2604        assert_analyzed_plan_eq!(
2605            plan,
2606            @r"
2607        Filter: CAST(a AS Decimal128(13, 8)) IN (<subquery>)
2608          Subquery:
2609            Projection: CAST(a AS Decimal128(13, 8))
2610              EmptyRelation: rows=0
2611          EmptyRelation: rows=0
2612        "
2613        )
2614    }
2615}