1use std::collections::HashSet;
21use std::sync::Arc;
22
23use crate::operator::Operator;
24
25use arrow::array::{new_empty_array, Array};
26use arrow::compute::can_cast_types;
27use arrow::datatypes::{
28 DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
29 DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
30 DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION,
31 DECIMAL64_MAX_SCALE,
32};
33use datafusion_common::types::NativeType;
34use datafusion_common::{
35 exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Diagnostic,
36 Result, Span, Spans,
37};
38use itertools::Itertools;
39
40struct Signature {
46 lhs: DataType,
48 rhs: DataType,
50 ret: DataType,
52}
53
54impl Signature {
55 fn uniform(t: DataType) -> Self {
57 Self {
58 lhs: t.clone(),
59 rhs: t.clone(),
60 ret: t,
61 }
62 }
63
64 fn comparison(t: DataType) -> Self {
66 Self {
67 lhs: t.clone(),
68 rhs: t,
69 ret: DataType::Boolean,
70 }
71 }
72}
73
/// Resolves the operand types and the result type of a binary expression
/// `lhs op rhs`.
///
/// Source spans for each side can be attached so that coercion errors carry
/// a [`Diagnostic`] pointing at the offending operands.
pub struct BinaryTypeCoercer<'a> {
    /// Type of the left-hand operand.
    lhs: &'a DataType,
    /// The binary operator.
    op: &'a Operator,
    /// Type of the right-hand operand.
    rhs: &'a DataType,

    /// Source spans of the left-hand operand (empty unless set).
    lhs_spans: Spans,
    /// Source spans of the operator (empty unless set).
    op_spans: Spans,
    /// Source spans of the right-hand operand (empty unless set).
    rhs_spans: Spans,
}
85
86impl<'a> BinaryTypeCoercer<'a> {
87 pub fn new(lhs: &'a DataType, op: &'a Operator, rhs: &'a DataType) -> Self {
90 Self {
91 lhs,
92 op,
93 rhs,
94 lhs_spans: Spans::new(),
95 op_spans: Spans::new(),
96 rhs_spans: Spans::new(),
97 }
98 }
99
100 pub fn set_lhs_spans(&mut self, spans: Spans) {
103 self.lhs_spans = spans;
104 }
105
106 pub fn set_op_spans(&mut self, spans: Spans) {
109 self.op_spans = spans;
110 }
111
112 pub fn set_rhs_spans(&mut self, spans: Spans) {
115 self.rhs_spans = spans;
116 }
117
118 fn span(&self) -> Option<Span> {
119 Span::union_iter(
120 [self.lhs_spans.first(), self.rhs_spans.first()]
121 .iter()
122 .copied()
123 .flatten(),
124 )
125 }
126
127 fn signature(&'a self) -> Result<Signature> {
129 if matches!((self.lhs, self.rhs), (DataType::Null, DataType::Null))
134 && self.op.is_numerical_operators()
135 {
136 return Ok(Signature::uniform(DataType::Int64));
137 }
138
139 if let Some(coerced) = null_coercion(self.lhs, self.rhs) {
140 if self.op.is_numerical_operators() && !coerced.is_temporal() {
147 let ret = self.get_result(&coerced, &coerced).map_err(|e| {
148 plan_datafusion_err!(
149 "Cannot get result type for arithmetic operation {coerced} {} {coerced}: {e}",
150 self.op
151 )
152 })?;
153
154 return Ok(Signature {
155 lhs: coerced.clone(),
156 rhs: coerced,
157 ret,
158 });
159 }
160 return self.signature_inner(&coerced, &coerced);
161 }
162 self.signature_inner(self.lhs, self.rhs)
163 }
164
165 fn get_result(
167 &self,
168 lhs: &DataType,
169 rhs: &DataType,
170 ) -> arrow::error::Result<DataType> {
171 use arrow::compute::kernels::numeric::*;
172 let l = new_empty_array(lhs);
173 let r = new_empty_array(rhs);
174
175 let result = match self.op {
176 Operator::Plus => add_wrapping(&l, &r),
177 Operator::Minus => sub_wrapping(&l, &r),
178 Operator::Multiply => mul_wrapping(&l, &r),
179 Operator::Divide => div(&l, &r),
180 Operator::Modulo => rem(&l, &r),
181 _ => unreachable!(),
182 };
183 result.map(|x| x.data_type().clone())
184 }
185
186 fn signature_inner(&'a self, lhs: &DataType, rhs: &DataType) -> Result<Signature> {
187 use arrow::datatypes::DataType::*;
188 use Operator::*;
189 let result = match self.op {
190 Eq |
191 NotEq |
192 Lt |
193 LtEq |
194 Gt |
195 GtEq |
196 IsDistinctFrom |
197 IsNotDistinctFrom => {
198 comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
199 plan_datafusion_err!(
200 "Cannot infer common argument type for comparison operation {} {} {}",
201 self.lhs,
202 self.op,
203 self.rhs
204 )
205 })
206 }
207 And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) {
208 Ok(Signature::uniform(Boolean))
211 } else {
212 plan_err!(
213 "Cannot infer common argument type for logical boolean operation {} {} {}", self.lhs, self.op, self.rhs
214 )
215 }
216 RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
217 regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
218 plan_datafusion_err!(
219 "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
220 )
221 })
222 }
223 LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
224 regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
225 plan_datafusion_err!(
226 "Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
227 )
228 })
229 }
230 BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
231 bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
232 plan_datafusion_err!(
233 "Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs
234 )
235 })
236 }
237 StringConcat => {
238 string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
239 plan_datafusion_err!(
240 "Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs
241 )
242 })
243 }
244 AtArrow | ArrowAt => {
245 array_coercion(lhs, rhs)
247 .or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| {
248 plan_datafusion_err!(
249 "Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs
250 )
251 })
252 }
253 AtAt => {
254 like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
256 plan_datafusion_err!(
257 "Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs
258 )
259 })
260 }
261 Plus | Minus | Multiply | Divide | Modulo => {
262 if let Ok(ret) = self.get_result(lhs, rhs) {
263 Ok(Signature{
265 lhs: lhs.clone(),
266 rhs: rhs.clone(),
267 ret,
268 })
269 } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
270 let ret = self.get_result(&coerced, &coerced).map_err(|e| {
273 plan_datafusion_err!(
274 "Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op
275 )
276 })?;
277 Ok(Signature{
278 lhs: coerced.clone(),
279 rhs: coerced,
280 ret,
281 })
282 } else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
283 let ret = self.get_result(&lhs, &rhs).map_err(|e| {
285 plan_datafusion_err!(
286 "Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
287 )
288 })?;
289 Ok(Signature{
290 lhs,
291 rhs,
292 ret,
293 })
294 } else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
295 Ok(Signature::uniform(numeric))
297 } else {
298 plan_err!(
299 "Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs
300 )
301 }
302 },
303 IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow
304 | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => {
305 not_impl_err!("Operator {} is not yet supported", self.op)
306 }
307 };
308 result.map_err(|err| {
309 let diagnostic =
310 Diagnostic::new_error("expressions have incompatible types", self.span())
311 .with_note(format!("has type {}", self.lhs), self.lhs_spans.first())
312 .with_note(format!("has type {}", self.rhs), self.rhs_spans.first());
313 err.with_diagnostic(diagnostic)
314 })
315 }
316
317 pub fn get_result_type(&'a self) -> Result<DataType> {
319 self.signature().map(|sig| sig.ret)
320 }
321
322 pub fn get_input_types(&'a self) -> Result<(DataType, DataType)> {
324 self.signature().map(|sig| (sig.lhs, sig.rhs))
325 }
326}
327
328fn is_decimal(data_type: &DataType) -> bool {
331 matches!(
332 data_type,
333 DataType::Decimal32(..)
334 | DataType::Decimal64(..)
335 | DataType::Decimal128(..)
336 | DataType::Decimal256(..)
337 )
338}
339
340fn math_decimal_coercion(
342 lhs_type: &DataType,
343 rhs_type: &DataType,
344) -> Option<(DataType, DataType)> {
345 use arrow::datatypes::DataType::*;
346
347 match (lhs_type, rhs_type) {
348 (Dictionary(_, value_type), _) => {
349 let (value_type, rhs_type) = math_decimal_coercion(value_type, rhs_type)?;
350 Some((value_type, rhs_type))
351 }
352 (_, Dictionary(_, value_type)) => {
353 let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?;
354 Some((lhs_type, value_type))
355 }
356 (
357 Null,
358 Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
359 ) => Some((rhs_type.clone(), rhs_type.clone())),
360 (
361 Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
362 Null,
363 ) => Some((lhs_type.clone(), lhs_type.clone())),
364 (Decimal32(_, _), Decimal32(_, _))
365 | (Decimal64(_, _), Decimal64(_, _))
366 | (Decimal128(_, _), Decimal128(_, _))
367 | (Decimal256(_, _), Decimal256(_, _)) => {
368 Some((lhs_type.clone(), rhs_type.clone()))
369 }
370 (lhs, rhs)
372 if is_decimal(lhs)
373 && is_decimal(rhs)
374 && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) =>
375 {
376 let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
377 Some((coerced_type.clone(), coerced_type))
378 }
379 (
382 Decimal32(_, _),
383 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
384 ) => Some((
385 lhs_type.clone(),
386 coerce_numeric_type_to_decimal32(rhs_type)?,
387 )),
388 (
389 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
390 Decimal32(_, _),
391 ) => Some((
392 coerce_numeric_type_to_decimal32(lhs_type)?,
393 rhs_type.clone(),
394 )),
395 (
396 Decimal64(_, _),
397 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
398 ) => Some((
399 lhs_type.clone(),
400 coerce_numeric_type_to_decimal64(rhs_type)?,
401 )),
402 (
403 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
404 Decimal64(_, _),
405 ) => Some((
406 coerce_numeric_type_to_decimal64(lhs_type)?,
407 rhs_type.clone(),
408 )),
409 (
410 Decimal128(_, _),
411 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
412 ) => Some((
413 lhs_type.clone(),
414 coerce_numeric_type_to_decimal128(rhs_type)?,
415 )),
416 (
417 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
418 Decimal128(_, _),
419 ) => Some((
420 coerce_numeric_type_to_decimal128(lhs_type)?,
421 rhs_type.clone(),
422 )),
423 (
424 Decimal256(_, _),
425 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
426 ) => Some((
427 lhs_type.clone(),
428 coerce_numeric_type_to_decimal256(rhs_type)?,
429 )),
430 (
431 Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64,
432 Decimal256(_, _),
433 ) => Some((
434 coerce_numeric_type_to_decimal256(lhs_type)?,
435 rhs_type.clone(),
436 )),
437 _ => None,
438 }
439}
440
441fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataType> {
444 use arrow::datatypes::DataType::*;
445
446 if !both_numeric_or_null_and_numeric(left_type, right_type) {
447 return None;
448 }
449
450 if left_type == right_type {
451 return Some(left_type.clone());
452 }
453
454 match (left_type, right_type) {
455 (UInt64, _) | (_, UInt64) => Some(UInt64),
456 (Int64, _)
457 | (_, Int64)
458 | (UInt32, Int8)
459 | (Int8, UInt32)
460 | (UInt32, Int16)
461 | (Int16, UInt32)
462 | (UInt32, Int32)
463 | (Int32, UInt32) => Some(Int64),
464 (Int32, _)
465 | (_, Int32)
466 | (UInt16, Int16)
467 | (Int16, UInt16)
468 | (UInt16, Int8)
469 | (Int8, UInt16) => Some(Int32),
470 (UInt32, _) | (_, UInt32) => Some(UInt32),
471 (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
472 (UInt16, _) | (_, UInt16) => Some(UInt16),
473 (Int8, _) | (_, Int8) => Some(Int8),
474 (UInt8, _) | (_, UInt8) => Some(UInt8),
475 _ => None,
476 }
477}
478
/// Coarse grouping of [`DataType`]s used by [`type_union_resolution`] to
/// reject unions that mix incompatible kinds of types before attempting
/// pairwise coercion. Dictionary types are categorized by their value type.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum TypeCategory {
    /// `List`, `FixedSizeList` and `LargeList`.
    Array,
    /// `Boolean`.
    Boolean,
    /// Any type for which `DataType::is_numeric` is true.
    Numeric,
    /// Dates, times, timestamps, intervals and durations.
    DateTime,
    /// `Map`, `Struct` and `Union`.
    Composite,
    /// String types and `Null`; ignored by the "single category" check in
    /// `type_union_resolution`, so they may mix with any other category.
    Unknown,
    /// Every other type; any union containing one is rejected.
    NotSupported,
}
490
491impl From<&DataType> for TypeCategory {
492 fn from(data_type: &DataType) -> Self {
493 match data_type {
494 DataType::Dictionary(_, v) => {
496 let v = v.as_ref();
497 TypeCategory::from(v)
498 }
499 _ => {
500 if data_type.is_numeric() {
501 return TypeCategory::Numeric;
502 }
503
504 if matches!(data_type, DataType::Boolean) {
505 return TypeCategory::Boolean;
506 }
507
508 if matches!(
509 data_type,
510 DataType::List(_)
511 | DataType::FixedSizeList(_, _)
512 | DataType::LargeList(_)
513 ) {
514 return TypeCategory::Array;
515 }
516
517 if matches!(
519 data_type,
520 DataType::Utf8
521 | DataType::LargeUtf8
522 | DataType::Utf8View
523 | DataType::Null
524 ) {
525 return TypeCategory::Unknown;
526 }
527
528 if matches!(
529 data_type,
530 DataType::Date32
531 | DataType::Date64
532 | DataType::Time32(_)
533 | DataType::Time64(_)
534 | DataType::Timestamp(_, _)
535 | DataType::Interval(_)
536 | DataType::Duration(_)
537 ) {
538 return TypeCategory::DateTime;
539 }
540
541 if matches!(
542 data_type,
543 DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _)
544 ) {
545 return TypeCategory::Composite;
546 }
547
548 TypeCategory::NotSupported
549 }
550 }
551 }
552}
553
554pub fn type_union_resolution(data_types: &[DataType]) -> Option<DataType> {
567 if data_types.is_empty() {
568 return None;
569 }
570
571 if data_types.iter().all(|t| t == &data_types[0]) {
573 return Some(data_types[0].clone());
574 }
575
576 if data_types.iter().all(|t| t == &DataType::Null) {
578 return Some(DataType::Utf8View);
579 }
580
581 let data_types_category: Vec<TypeCategory> = data_types
583 .iter()
584 .filter(|&t| t != &DataType::Null)
585 .map(|t| t.into())
586 .collect();
587
588 if data_types_category
589 .iter()
590 .any(|t| t == &TypeCategory::NotSupported)
591 {
592 return None;
593 }
594
595 let categories: HashSet<TypeCategory> = HashSet::from_iter(
597 data_types_category
598 .iter()
599 .filter(|&c| c != &TypeCategory::Unknown)
600 .cloned(),
601 );
602 if categories.len() > 1 {
603 return None;
604 }
605
606 let mut candidate_type: Option<DataType> = None;
608 for data_type in data_types.iter() {
609 if data_type == &DataType::Null {
610 continue;
611 }
612 if let Some(ref candidate_t) = candidate_type {
613 if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) {
620 candidate_type = Some(t);
621 } else {
622 return None;
623 }
624 } else {
625 candidate_type = Some(data_type.clone());
626 }
627 }
628
629 candidate_type
630}
631
632fn type_union_resolution_coercion(
635 lhs_type: &DataType,
636 rhs_type: &DataType,
637) -> Option<DataType> {
638 if lhs_type == rhs_type {
639 return Some(lhs_type.clone());
640 }
641
642 match (lhs_type, rhs_type) {
643 (
644 DataType::Dictionary(lhs_index_type, lhs_value_type),
645 DataType::Dictionary(rhs_index_type, rhs_value_type),
646 ) => {
647 let new_index_type =
648 type_union_resolution_coercion(lhs_index_type, rhs_index_type);
649 let new_value_type =
650 type_union_resolution_coercion(lhs_value_type, rhs_value_type);
651 if let (Some(new_index_type), Some(new_value_type)) =
652 (new_index_type, new_value_type)
653 {
654 Some(DataType::Dictionary(
655 Box::new(new_index_type),
656 Box::new(new_value_type),
657 ))
658 } else {
659 None
660 }
661 }
662 (DataType::Dictionary(index_type, value_type), other_type)
663 | (other_type, DataType::Dictionary(index_type, value_type)) => {
664 match type_union_resolution_coercion(value_type, other_type) {
665 Some(DataType::Utf8View) => Some(DataType::Utf8View),
668 Some(new_value_type) => Some(DataType::Dictionary(
669 index_type.clone(),
670 Box::new(new_value_type),
671 )),
672 None => None,
673 }
674 }
675 (DataType::Struct(lhs), DataType::Struct(rhs)) => {
676 if lhs.len() != rhs.len() {
677 return None;
678 }
679
680 fn search_corresponding_coerced_type(
682 lhs_field: &FieldRef,
683 rhs: &Fields,
684 ) -> Option<DataType> {
685 for rhs_field in rhs.iter() {
686 if lhs_field.name() == rhs_field.name() {
687 if let Some(t) = type_union_resolution_coercion(
688 lhs_field.data_type(),
689 rhs_field.data_type(),
690 ) {
691 return Some(t);
692 } else {
693 return None;
694 }
695 }
696 }
697
698 None
699 }
700
701 let coerced_types = lhs
702 .iter()
703 .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs))
704 .collect::<Option<Vec<_>>>()?;
705
706 let orig_fields = std::iter::zip(lhs.iter(), rhs.iter());
708
709 let fields: Vec<FieldRef> = coerced_types
710 .into_iter()
711 .zip(orig_fields)
712 .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs))
713 .collect();
714 Some(DataType::Struct(fields.into()))
715 }
716 _ => {
717 binary_numeric_coercion(lhs_type, rhs_type)
720 .or_else(|| list_coercion(lhs_type, rhs_type))
721 .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
722 .or_else(|| string_coercion(lhs_type, rhs_type))
723 .or_else(|| numeric_string_coercion(lhs_type, rhs_type))
724 .or_else(|| binary_coercion(lhs_type, rhs_type))
725 }
726 }
727}
728
729pub fn try_type_union_resolution(data_types: &[DataType]) -> Result<Vec<DataType>> {
731 let err = match try_type_union_resolution_with_struct(data_types) {
732 Ok(struct_types) => return Ok(struct_types),
733 Err(e) => Some(e),
734 };
735
736 if let Some(new_type) = type_union_resolution(data_types) {
737 Ok(vec![new_type; data_types.len()])
738 } else {
739 exec_err!("Fail to find the coerced type, errors: {:?}", err)
740 }
741}
742
743pub fn try_type_union_resolution_with_struct(
746 data_types: &[DataType],
747) -> Result<Vec<DataType>> {
748 let mut keys_string: Option<String> = None;
749 for data_type in data_types {
750 if let DataType::Struct(fields) = data_type {
751 let keys = fields.iter().map(|f| f.name().to_owned()).join(",");
752 if let Some(ref k) = keys_string {
753 if *k != keys {
754 return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys);
755 }
756 } else {
757 keys_string = Some(keys);
758 }
759 } else {
760 return exec_err!("Expect to get struct but got {data_type}");
761 }
762 }
763
764 let mut struct_types: Vec<DataType> = if let DataType::Struct(fields) = &data_types[0]
765 {
766 fields.iter().map(|f| f.data_type().to_owned()).collect()
767 } else {
768 return internal_err!("Struct type is checked is the previous function, so this should be unreachable");
769 };
770
771 for data_type in data_types.iter().skip(1) {
772 if let DataType::Struct(fields) = data_type {
773 let incoming_struct_types: Vec<DataType> =
774 fields.iter().map(|f| f.data_type().to_owned()).collect();
775 for (lhs_type, rhs_type) in
777 struct_types.iter_mut().zip(incoming_struct_types.iter())
778 {
779 if let Some(coerced_type) =
780 type_union_resolution_coercion(lhs_type, rhs_type)
781 {
782 *lhs_type = coerced_type;
783 } else {
784 return exec_err!(
785 "Fail to find the coerced type for {} and {}",
786 lhs_type,
787 rhs_type
788 );
789 }
790 }
791 } else {
792 return exec_err!("Expect to get struct but got {data_type}");
793 }
794 }
795
796 let mut final_struct_types = vec![];
797 for s in data_types {
798 let mut new_fields = vec![];
799 if let DataType::Struct(fields) = s {
800 for (i, f) in fields.iter().enumerate() {
801 let field = Arc::unwrap_or_clone(Arc::clone(f))
802 .with_data_type(struct_types[i].to_owned());
803 new_fields.push(Arc::new(field));
804 }
805 }
806 final_struct_types.push(DataType::Struct(new_fields.into()))
807 }
808
809 Ok(final_struct_types)
810}
811
812pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
835 if lhs_type.equals_datatype(rhs_type) {
836 return Some(lhs_type.clone());
838 }
839 binary_numeric_coercion(lhs_type, rhs_type)
840 .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true))
841 .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
842 .or_else(|| string_coercion(lhs_type, rhs_type))
843 .or_else(|| list_coercion(lhs_type, rhs_type))
844 .or_else(|| null_coercion(lhs_type, rhs_type))
845 .or_else(|| string_numeric_coercion(lhs_type, rhs_type))
846 .or_else(|| string_temporal_coercion(lhs_type, rhs_type))
847 .or_else(|| binary_coercion(lhs_type, rhs_type))
848 .or_else(|| struct_coercion(lhs_type, rhs_type))
849 .or_else(|| map_coercion(lhs_type, rhs_type))
850}
851
852pub fn comparison_coercion_numeric(
861 lhs_type: &DataType,
862 rhs_type: &DataType,
863) -> Option<DataType> {
864 if lhs_type == rhs_type {
865 return Some(lhs_type.clone());
867 }
868 binary_numeric_coercion(lhs_type, rhs_type)
869 .or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true))
870 .or_else(|| string_coercion(lhs_type, rhs_type))
871 .or_else(|| null_coercion(lhs_type, rhs_type))
872 .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type))
873}
874
875fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
878 use arrow::datatypes::DataType::*;
879 match (lhs_type, rhs_type) {
880 (Utf8, _) if rhs_type.is_numeric() => Some(Utf8),
881 (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8),
882 (Utf8View, _) if rhs_type.is_numeric() => Some(Utf8View),
883 (_, Utf8) if lhs_type.is_numeric() => Some(Utf8),
884 (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8),
885 (_, Utf8View) if lhs_type.is_numeric() => Some(Utf8View),
886 _ => None,
887 }
888}
889
890fn string_numeric_coercion_as_numeric(
893 lhs_type: &DataType,
894 rhs_type: &DataType,
895) -> Option<DataType> {
896 let lhs_logical_type = NativeType::from(lhs_type);
897 let rhs_logical_type = NativeType::from(rhs_type);
898 if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String {
899 return Some(lhs_type.to_owned());
900 }
901 if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String {
902 return Some(rhs_type.to_owned());
903 }
904
905 None
906}
907
908fn string_temporal_coercion(
922 lhs_type: &DataType,
923 rhs_type: &DataType,
924) -> Option<DataType> {
925 use arrow::datatypes::DataType::*;
926
927 fn match_rule(l: &DataType, r: &DataType) -> Option<DataType> {
928 match (l, r) {
929 (Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => {
931 match temporal {
932 Date32 | Date64 => Some(temporal.clone()),
933 Time32(_) | Time64(_) => {
934 if is_time_with_valid_unit(temporal.to_owned()) {
935 Some(temporal.to_owned())
936 } else {
937 None
938 }
939 }
940 Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
941 _ => None,
942 }
943 }
944 _ => None,
945 }
946 }
947
948 match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type))
949}
950
951pub fn binary_numeric_coercion(
953 lhs_type: &DataType,
954 rhs_type: &DataType,
955) -> Option<DataType> {
956 if !lhs_type.is_numeric() || !rhs_type.is_numeric() {
957 return None;
958 };
959
960 if lhs_type == rhs_type {
962 return Some(lhs_type.clone());
963 }
964
965 if let Some(t) = decimal_coercion(lhs_type, rhs_type) {
966 return Some(t);
967 }
968
969 numerical_coercion(lhs_type, rhs_type)
970}
971
972pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
974 use arrow::datatypes::DataType::*;
975
976 match (lhs_type, rhs_type) {
978 (lhs_type, rhs_type)
980 if is_decimal(lhs_type)
981 && is_decimal(rhs_type)
982 && std::mem::discriminant(lhs_type)
983 == std::mem::discriminant(rhs_type) =>
984 {
985 get_wider_decimal_type(lhs_type, rhs_type)
986 }
987 (lhs_type, rhs_type)
989 if is_decimal(lhs_type)
990 && is_decimal(rhs_type)
991 && std::mem::discriminant(lhs_type)
992 != std::mem::discriminant(rhs_type) =>
993 {
994 get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
995 }
996 (Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), _) => {
998 get_common_decimal_type(lhs_type, rhs_type)
999 }
1000 (_, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
1001 get_common_decimal_type(rhs_type, lhs_type)
1002 }
1003 (_, _) => None,
1004 }
1005}
1006fn get_wider_decimal_type_cross_variant(
1008 lhs_type: &DataType,
1009 rhs_type: &DataType,
1010) -> Option<DataType> {
1011 use arrow::datatypes::DataType::*;
1012
1013 let (p1, s1) = match lhs_type {
1014 Decimal32(p, s) => (*p, *s),
1015 Decimal64(p, s) => (*p, *s),
1016 Decimal128(p, s) => (*p, *s),
1017 Decimal256(p, s) => (*p, *s),
1018 _ => return None,
1019 };
1020
1021 let (p2, s2) = match rhs_type {
1022 Decimal32(p, s) => (*p, *s),
1023 Decimal64(p, s) => (*p, *s),
1024 Decimal128(p, s) => (*p, *s),
1025 Decimal256(p, s) => (*p, *s),
1026 _ => return None,
1027 };
1028
1029 let s = s1.max(s2);
1031 let range = (p1 as i8 - s1).max(p2 as i8 - s2);
1032 let required_precision = (range + s) as u8;
1033
1034 match (lhs_type, rhs_type) {
1036 (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _))
1037 if required_precision <= DECIMAL64_MAX_PRECISION =>
1038 {
1039 Some(Decimal64(required_precision, s))
1040 }
1041 (Decimal32(_, _), Decimal128(_, _))
1042 | (Decimal128(_, _), Decimal32(_, _))
1043 | (Decimal64(_, _), Decimal128(_, _))
1044 | (Decimal128(_, _), Decimal64(_, _))
1045 if required_precision <= DECIMAL128_MAX_PRECISION =>
1046 {
1047 Some(Decimal128(required_precision, s))
1048 }
1049 (Decimal32(_, _), Decimal256(_, _))
1050 | (Decimal256(_, _), Decimal32(_, _))
1051 | (Decimal64(_, _), Decimal256(_, _))
1052 | (Decimal256(_, _), Decimal64(_, _))
1053 | (Decimal128(_, _), Decimal256(_, _))
1054 | (Decimal256(_, _), Decimal128(_, _))
1055 if required_precision <= DECIMAL256_MAX_PRECISION =>
1056 {
1057 Some(Decimal256(required_precision, s))
1058 }
1059 _ => None,
1060 }
1061}
1062
1063fn get_common_decimal_type(
1065 decimal_type: &DataType,
1066 other_type: &DataType,
1067) -> Option<DataType> {
1068 use arrow::datatypes::DataType::*;
1069 match decimal_type {
1070 Decimal32(_, _) => {
1071 let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
1072 get_wider_decimal_type(decimal_type, &other_decimal_type)
1073 }
1074 Decimal64(_, _) => {
1075 let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
1076 get_wider_decimal_type(decimal_type, &other_decimal_type)
1077 }
1078 Decimal128(_, _) => {
1079 let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
1080 get_wider_decimal_type(decimal_type, &other_decimal_type)
1081 }
1082 Decimal256(_, _) => {
1083 let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?;
1084 get_wider_decimal_type(decimal_type, &other_decimal_type)
1085 }
1086 _ => None,
1087 }
1088}
1089
1090fn get_wider_decimal_type(
1095 lhs_decimal_type: &DataType,
1096 rhs_type: &DataType,
1097) -> Option<DataType> {
1098 match (lhs_decimal_type, rhs_type) {
1099 (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => {
1100 let s = *s1.max(s2);
1102 let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
1103 Some(create_decimal32_type((range + s) as u8, s))
1104 }
1105 (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => {
1106 let s = *s1.max(s2);
1108 let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
1109 Some(create_decimal64_type((range + s) as u8, s))
1110 }
1111 (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => {
1112 let s = *s1.max(s2);
1114 let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
1115 Some(create_decimal128_type((range + s) as u8, s))
1116 }
1117 (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
1118 let s = *s1.max(s2);
1120 let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
1121 Some(create_decimal256_type((range + s) as u8, s))
1122 }
1123 (_, _) => None,
1124 }
1125}
1126
1127fn coerce_numeric_type_to_decimal32(numeric_type: &DataType) -> Option<DataType> {
1130 use arrow::datatypes::DataType::*;
1131 match numeric_type {
1134 Int8 | UInt8 => Some(Decimal32(3, 0)),
1135 Int16 | UInt16 => Some(Decimal32(5, 0)),
1136 Float16 => Some(Decimal32(6, 3)),
1138 _ => None,
1139 }
1140}
1141
1142fn coerce_numeric_type_to_decimal64(numeric_type: &DataType) -> Option<DataType> {
1145 use arrow::datatypes::DataType::*;
1146 match numeric_type {
1149 Int8 | UInt8 => Some(Decimal64(3, 0)),
1150 Int16 | UInt16 => Some(Decimal64(5, 0)),
1151 Int32 | UInt32 => Some(Decimal64(10, 0)),
1152 Float16 => Some(Decimal64(6, 3)),
1154 Float32 => Some(Decimal64(14, 7)),
1155 _ => None,
1156 }
1157}
1158
1159fn coerce_numeric_type_to_decimal128(numeric_type: &DataType) -> Option<DataType> {
1162 use arrow::datatypes::DataType::*;
1163 match numeric_type {
1166 Int8 | UInt8 => Some(Decimal128(3, 0)),
1167 Int16 | UInt16 => Some(Decimal128(5, 0)),
1168 Int32 | UInt32 => Some(Decimal128(10, 0)),
1169 Int64 | UInt64 => Some(Decimal128(20, 0)),
1170 Float16 => Some(Decimal128(6, 3)),
1172 Float32 => Some(Decimal128(14, 7)),
1173 Float64 => Some(Decimal128(30, 15)),
1174 _ => None,
1175 }
1176}
1177
1178fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
1181 use arrow::datatypes::DataType::*;
1182 match numeric_type {
1185 Int8 | UInt8 => Some(Decimal256(3, 0)),
1186 Int16 | UInt16 => Some(Decimal256(5, 0)),
1187 Int32 | UInt32 => Some(Decimal256(10, 0)),
1188 Int64 | UInt64 => Some(Decimal256(20, 0)),
1189 Float16 => Some(Decimal256(6, 3)),
1191 Float32 => Some(Decimal256(14, 7)),
1192 Float64 => Some(Decimal256(30, 15)),
1193 _ => None,
1194 }
1195}
1196
1197fn struct_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1198 use arrow::datatypes::DataType::*;
1199 match (lhs_type, rhs_type) {
1200 (Struct(lhs_fields), Struct(rhs_fields)) => {
1201 if lhs_fields.len() != rhs_fields.len() {
1202 return None;
1203 }
1204
1205 let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter())
1206 .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type()))
1207 .collect::<Option<Vec<DataType>>>()?;
1208
1209 let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter());
1211
1212 let fields: Vec<FieldRef> = coerced_types
1213 .into_iter()
1214 .zip(orig_fields)
1215 .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs))
1216 .collect();
1217 Some(Struct(fields.into()))
1218 }
1219 _ => None,
1220 }
1221}
1222
1223fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef {
1225 let is_nullable = lhs.is_nullable() || rhs.is_nullable();
1226 let name = lhs.name(); Arc::new(Field::new(name, common_type, is_nullable))
1228}
1229
1230fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1233 use arrow::datatypes::DataType::*;
1234 match (lhs_type, rhs_type) {
1235 (Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => {
1236 struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map(
1237 |key_value_type| {
1238 Map(
1239 Arc::new((**lhs_field).clone().with_data_type(key_value_type)),
1240 *lhs_ordered && *rhs_ordered,
1241 )
1242 },
1243 )
1244 }
1245 _ => None,
1246 }
1247}
1248
1249fn mathematics_numerical_coercion(
1252 lhs_type: &DataType,
1253 rhs_type: &DataType,
1254) -> Option<DataType> {
1255 use arrow::datatypes::DataType::*;
1256
1257 if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) {
1259 return None;
1260 };
1261
1262 match (lhs_type, rhs_type) {
1265 (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
1266 mathematics_numerical_coercion(lhs_value_type, rhs_value_type)
1267 }
1268 (Dictionary(_, value_type), _) => {
1269 mathematics_numerical_coercion(value_type, rhs_type)
1270 }
1271 (_, Dictionary(_, value_type)) => {
1272 mathematics_numerical_coercion(lhs_type, value_type)
1273 }
1274 _ => numerical_coercion(lhs_type, rhs_type),
1275 }
1276}
1277
/// Coerces two numeric types to a common numeric type.
///
/// The match arms are order-sensitive: the first matching arm wins.
/// * If either side is a float, the widest float present is chosen
///   (`Float64`, then `Float32`, then `Float16`).
/// * `UInt64` paired with a signed integer becomes `Decimal128(20, 0)`.
/// * Otherwise the integer arms pick a type; mixed-sign pairs move to a
///   wider signed type (e.g. `UInt32` with `Int32` gives `Int64`).
///
/// The `_` in many arms matches any partner type (e.g. a decimal paired with
/// `UInt64` yields `UInt64`); `binary_numeric_coercion` tries
/// `decimal_coercion` before calling this. Returns `None` for pairs no arm
/// covers.
fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;

    match (lhs_type, rhs_type) {
        (Float64, _) | (_, Float64) => Some(Float64),
        (_, Float32) | (Float32, _) => Some(Float32),
        (_, Float16) | (Float16, _) => Some(Float16),
        // No integer type holds both the UInt64 range and negative values;
        // Decimal128 with 20 digits covers both.
        (UInt64, Int64 | Int32 | Int16 | Int8)
        | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)),
        (UInt64, _) | (_, UInt64) => Some(UInt64),
        (Int64, _)
        | (_, Int64)
        | (UInt32, Int32 | Int16 | Int8)
        | (Int32 | Int16 | Int8, UInt32) => Some(Int64),
        (UInt32, _) | (_, UInt32) => Some(UInt32),
        (Int32, _) | (_, Int32) | (UInt16, Int16 | Int8) | (Int16 | Int8, UInt16) => {
            Some(Int32)
        }
        (UInt16, _) | (_, UInt16) => Some(UInt16),
        (Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
        (Int8, _) | (_, Int8) => Some(Int8),
        (UInt8, _) | (_, UInt8) => Some(UInt8),
        _ => None,
    }
}
1309
1310fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
1311 DataType::Decimal32(
1312 DECIMAL32_MAX_PRECISION.min(precision),
1313 DECIMAL32_MAX_SCALE.min(scale),
1314 )
1315}
1316
1317fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
1318 DataType::Decimal64(
1319 DECIMAL64_MAX_PRECISION.min(precision),
1320 DECIMAL64_MAX_SCALE.min(scale),
1321 )
1322}
1323
1324fn create_decimal128_type(precision: u8, scale: i8) -> DataType {
1325 DataType::Decimal128(
1326 DECIMAL128_MAX_PRECISION.min(precision),
1327 DECIMAL128_MAX_SCALE.min(scale),
1328 )
1329}
1330
1331fn create_decimal256_type(precision: u8, scale: i8) -> DataType {
1332 DataType::Decimal256(
1333 DECIMAL256_MAX_PRECISION.min(precision),
1334 DECIMAL256_MAX_SCALE.min(scale),
1335 )
1336}
1337
1338fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
1340 use arrow::datatypes::DataType::*;
1341 match (lhs_type, rhs_type) {
1342 (_, Null) => lhs_type.is_numeric(),
1343 (Null, _) => rhs_type.is_numeric(),
1344 (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
1345 lhs_value_type.is_numeric() && rhs_value_type.is_numeric()
1346 }
1347 (Dictionary(_, value_type), _) => {
1348 value_type.is_numeric() && rhs_type.is_numeric()
1349 }
1350 (_, Dictionary(_, value_type)) => {
1351 lhs_type.is_numeric() && value_type.is_numeric()
1352 }
1353 _ => lhs_type.is_numeric() && rhs_type.is_numeric(),
1354 }
1355}
1356
1357fn dictionary_comparison_coercion_generic(
1366 lhs_type: &DataType,
1367 rhs_type: &DataType,
1368 preserve_dictionaries: bool,
1369 coerce_fn: fn(&DataType, &DataType) -> Option<DataType>,
1370) -> Option<DataType> {
1371 use arrow::datatypes::DataType::*;
1372 match (lhs_type, rhs_type) {
1373 (
1374 Dictionary(_lhs_index_type, lhs_value_type),
1375 Dictionary(_rhs_index_type, rhs_value_type),
1376 ) => coerce_fn(lhs_value_type, rhs_value_type),
1377 (d @ Dictionary(_, value_type), other_type)
1378 | (other_type, d @ Dictionary(_, value_type))
1379 if preserve_dictionaries && value_type.as_ref() == other_type =>
1380 {
1381 Some(d.clone())
1382 }
1383 (Dictionary(_index_type, value_type), _) => coerce_fn(value_type, rhs_type),
1384 (_, Dictionary(_index_type, value_type)) => coerce_fn(lhs_type, value_type),
1385 _ => None,
1386 }
1387}
1388
1389fn dictionary_comparison_coercion(
1395 lhs_type: &DataType,
1396 rhs_type: &DataType,
1397 preserve_dictionaries: bool,
1398) -> Option<DataType> {
1399 dictionary_comparison_coercion_generic(
1400 lhs_type,
1401 rhs_type,
1402 preserve_dictionaries,
1403 comparison_coercion,
1404 )
1405}
1406
1407fn dictionary_comparison_coercion_numeric(
1414 lhs_type: &DataType,
1415 rhs_type: &DataType,
1416 preserve_dictionaries: bool,
1417) -> Option<DataType> {
1418 dictionary_comparison_coercion_generic(
1419 lhs_type,
1420 rhs_type,
1421 preserve_dictionaries,
1422 comparison_coercion_numeric,
1423 )
1424}
1425
1426fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1431 use arrow::datatypes::DataType::*;
1432 string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) {
1433 (Utf8View, from_type) | (from_type, Utf8View) => {
1434 string_concat_internal_coercion(from_type, &Utf8View)
1435 }
1436 (Utf8, from_type) | (from_type, Utf8) => {
1437 string_concat_internal_coercion(from_type, &Utf8)
1438 }
1439 (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
1440 string_concat_internal_coercion(from_type, &LargeUtf8)
1441 }
1442 (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
1443 string_coercion(lhs_value_type, rhs_value_type).or(None)
1444 }
1445 _ => None,
1446 })
1447}
1448
1449fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1450 if lhs_type.equals_datatype(rhs_type) {
1451 Some(lhs_type.to_owned())
1452 } else {
1453 None
1454 }
1455}
1456
1457fn string_concat_internal_coercion(
1460 from_type: &DataType,
1461 to_type: &DataType,
1462) -> Option<DataType> {
1463 if can_cast_types(from_type, to_type) {
1464 Some(to_type.to_owned())
1465 } else {
1466 None
1467 }
1468}
1469
1470pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1476 use arrow::datatypes::DataType::*;
1477 match (lhs_type, rhs_type) {
1478 (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
1480 Some(Utf8View)
1481 }
1482 (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
1484 (Utf8, Utf8) => Some(Utf8),
1486 _ => None,
1487 }
1488}
1489
1490fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1491 use arrow::datatypes::DataType::*;
1492 match (lhs_type, rhs_type) {
1493 (Utf8 | LargeUtf8 | Utf8View, other_type)
1494 | (other_type, Utf8 | LargeUtf8 | Utf8View)
1495 if other_type.is_numeric() =>
1496 {
1497 Some(other_type.clone())
1498 }
1499 _ => None,
1500 }
1501}
1502
1503fn coerce_list_children(lhs_field: &FieldRef, rhs_field: &FieldRef) -> Option<FieldRef> {
1505 let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()];
1506 Some(Arc::new(
1507 (**lhs_field)
1508 .clone()
1509 .with_data_type(type_union_resolution(&data_types)?)
1510 .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()),
1511 ))
1512}
1513
1514fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1516 use arrow::datatypes::DataType::*;
1517 match (lhs_type, rhs_type) {
1518 (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => {
1521 if ls == rs {
1522 Some(FixedSizeList(
1523 coerce_list_children(lhs_field, rhs_field)?,
1524 *rs,
1525 ))
1526 } else {
1527 Some(List(coerce_list_children(lhs_field, rhs_field)?))
1528 }
1529 }
1530 (
1532 LargeList(lhs_field),
1533 List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _),
1534 )
1535 | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => {
1536 Some(LargeList(coerce_list_children(lhs_field, rhs_field)?))
1537 }
1538 (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _))
1540 | (FixedSizeList(lhs_field, _), List(rhs_field)) => {
1541 Some(List(coerce_list_children(lhs_field, rhs_field)?))
1542 }
1543 _ => None,
1544 }
1545}
1546
1547pub fn binary_to_string_coercion(
1551 lhs_type: &DataType,
1552 rhs_type: &DataType,
1553) -> Option<DataType> {
1554 use arrow::datatypes::DataType::*;
1555 match (lhs_type, rhs_type) {
1556 (Binary, Utf8) => Some(Utf8),
1557 (Binary, LargeUtf8) => Some(LargeUtf8),
1558 (BinaryView, Utf8) => Some(Utf8View),
1559 (BinaryView, LargeUtf8) => Some(LargeUtf8),
1560 (LargeBinary, Utf8) => Some(LargeUtf8),
1561 (LargeBinary, LargeUtf8) => Some(LargeUtf8),
1562 (Utf8, Binary) => Some(Utf8),
1563 (Utf8, LargeBinary) => Some(LargeUtf8),
1564 (Utf8, BinaryView) => Some(Utf8View),
1565 (LargeUtf8, Binary) => Some(LargeUtf8),
1566 (LargeUtf8, LargeBinary) => Some(LargeUtf8),
1567 (LargeUtf8, BinaryView) => Some(LargeUtf8),
1568 _ => None,
1569 }
1570}
1571
/// Coerces binary and string operands to a common binary type.
///
/// Arms are checked in order, so the first match wins:
/// * `BinaryView` paired with any binary or string type (either side) yields
///   `BinaryView`.
/// * Otherwise `LargeBinary` paired with `Binary`, `LargeBinary` or a string
///   type yields `LargeBinary`.
/// * `Binary` paired with `Utf8View` or `LargeUtf8` yields `LargeBinary`.
/// * `Binary` paired with `Utf8` yields `Binary`.
/// * `FixedSizeBinary` paired with `Binary` yields `Binary`, and paired with
///   `BinaryView` yields `BinaryView`.
///
/// Every other combination returns `None`. That includes `Binary` with
/// `Binary`, two string types, and `FixedSizeBinary` with `LargeBinary`.
fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    use arrow::datatypes::DataType::*;
    match (lhs_type, rhs_type) {
        // BinaryView on either side takes precedence over everything else.
        (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View)
        | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => {
            Some(BinaryView)
        }
        // Next, LargeBinary on either side.
        (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary)
        | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary),

        // Binary paired with a large/view string is widened to LargeBinary.
        (Utf8View | LargeUtf8, Binary) | (Binary, Utf8View | LargeUtf8) => {
            Some(LargeBinary)
        }
        (Binary, Utf8) | (Utf8, Binary) => Some(Binary),

        // FixedSizeBinary is widened to the variable-length binary on the other side.
        (FixedSizeBinary(_), Binary) | (Binary, FixedSizeBinary(_)) => Some(Binary),
        (FixedSizeBinary(_), BinaryView) | (BinaryView, FixedSizeBinary(_)) => {
            Some(BinaryView)
        }

        _ => None,
    }
}
1602
1603pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1606 string_coercion(lhs_type, rhs_type)
1607 .or_else(|| list_coercion(lhs_type, rhs_type))
1608 .or_else(|| binary_to_string_coercion(lhs_type, rhs_type))
1609 .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false))
1610 .or_else(|| regex_null_coercion(lhs_type, rhs_type))
1611 .or_else(|| null_coercion(lhs_type, rhs_type))
1612}
1613
1614fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1616 use arrow::datatypes::DataType::*;
1617 match (lhs_type, rhs_type) {
1618 (Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()),
1619 (Utf8View | Utf8 | LargeUtf8, Null) => Some(lhs_type.clone()),
1620 (Null, Null) => Some(Utf8),
1621 _ => None,
1622 }
1623}
1624
1625pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1628 string_coercion(lhs_type, rhs_type)
1629 .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false))
1630 .or_else(|| regex_null_coercion(lhs_type, rhs_type))
1631}
1632
1633fn is_time_with_valid_unit(datatype: DataType) -> bool {
1637 matches!(
1638 datatype,
1639 DataType::Time32(TimeUnit::Second)
1640 | DataType::Time32(TimeUnit::Millisecond)
1641 | DataType::Time64(TimeUnit::Microsecond)
1642 | DataType::Time64(TimeUnit::Nanosecond)
1643 )
1644}
1645
1646fn temporal_coercion_nonstrict_timezone(
1658 lhs_type: &DataType,
1659 rhs_type: &DataType,
1660) -> Option<DataType> {
1661 use arrow::datatypes::DataType::*;
1662
1663 match (lhs_type, rhs_type) {
1664 (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
1665 let tz = match (lhs_tz, rhs_tz) {
1666 (Some(lhs_tz), Some(_rhs_tz)) => Some(Arc::clone(lhs_tz)),
1668 (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)),
1669 (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)),
1670 (None, None) => None,
1671 };
1672
1673 let unit = timeunit_coercion(lhs_unit, rhs_unit);
1674
1675 Some(Timestamp(unit, tz))
1676 }
1677 _ => temporal_coercion(lhs_type, rhs_type),
1678 }
1679}
1680
1681fn temporal_coercion_strict_timezone(
1695 lhs_type: &DataType,
1696 rhs_type: &DataType,
1697) -> Option<DataType> {
1698 use arrow::datatypes::DataType::*;
1699
1700 match (lhs_type, rhs_type) {
1701 (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => {
1702 let tz = match (lhs_tz, rhs_tz) {
1703 (Some(lhs_tz), Some(rhs_tz)) => {
1704 match (lhs_tz.as_ref(), rhs_tz.as_ref()) {
1705 ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)),
1708 (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)),
1709 _ => {
1711 return None;
1712 }
1713 }
1714 }
1715 (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)),
1716 (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)),
1717 (None, None) => None,
1718 };
1719
1720 let unit = timeunit_coercion(lhs_unit, rhs_unit);
1721
1722 Some(Timestamp(unit, tz))
1723 }
1724 _ => temporal_coercion(lhs_type, rhs_type),
1725 }
1726}
1727
1728fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1729 use arrow::datatypes::DataType::*;
1730 use arrow::datatypes::IntervalUnit::*;
1731 use arrow::datatypes::TimeUnit::*;
1732
1733 match (lhs_type, rhs_type) {
1734 (Interval(_) | Duration(_), Interval(_) | Duration(_)) => {
1735 Some(Interval(MonthDayNano))
1736 }
1737 (Date64, Date32) | (Date32, Date64) => Some(Date64),
1738 (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => {
1739 Some(Timestamp(Nanosecond, None))
1740 }
1741 (Timestamp(_, _tz), Date64) | (Date64, Timestamp(_, _tz)) => {
1742 Some(Timestamp(Nanosecond, None))
1743 }
1744 (Timestamp(_, None), Date32) | (Date32, Timestamp(_, None)) => {
1745 Some(Timestamp(Nanosecond, None))
1746 }
1747 (Timestamp(_, _tz), Date32) | (Date32, Timestamp(_, _tz)) => {
1748 Some(Timestamp(Nanosecond, None))
1749 }
1750 _ => None,
1751 }
1752}
1753
1754fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit {
1755 use arrow::datatypes::TimeUnit::*;
1756 match (lhs_unit, rhs_unit) {
1757 (Second, Millisecond) => Second,
1758 (Second, Microsecond) => Second,
1759 (Second, Nanosecond) => Second,
1760 (Millisecond, Second) => Second,
1761 (Millisecond, Microsecond) => Millisecond,
1762 (Millisecond, Nanosecond) => Millisecond,
1763 (Microsecond, Second) => Second,
1764 (Microsecond, Millisecond) => Millisecond,
1765 (Microsecond, Nanosecond) => Microsecond,
1766 (Nanosecond, Second) => Second,
1767 (Nanosecond, Millisecond) => Millisecond,
1768 (Nanosecond, Microsecond) => Microsecond,
1769 (l, r) => {
1770 assert_eq!(l, r);
1771 *l
1772 }
1773 }
1774}
1775
1776fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1779 match (lhs_type, rhs_type) {
1780 (DataType::Null, other_type) | (other_type, DataType::Null) => {
1781 if can_cast_types(&DataType::Null, other_type) {
1782 Some(other_type.clone())
1783 } else {
1784 None
1785 }
1786 }
1787 _ => None,
1788 }
1789}
1790
// Unit tests for the coercion rules live in the separate `tests` module file.
#[cfg(test)]
mod tests;