datafusion_expr_common/type_coercion/aggregates.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::signature::TypeSignature;
19use arrow::datatypes::{DataType, FieldRef};
20
21use datafusion_common::{internal_err, plan_err, Result};
22
23// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
24// see https://github.com/apache/datafusion/issues/18092
25pub static INTEGERS: &[DataType] = &[
26 DataType::Int8,
27 DataType::Int16,
28 DataType::Int32,
29 DataType::Int64,
30 DataType::UInt8,
31 DataType::UInt16,
32 DataType::UInt32,
33 DataType::UInt64,
34];
35
36pub static NUMERICS: &[DataType] = &[
37 DataType::Int8,
38 DataType::Int16,
39 DataType::Int32,
40 DataType::Int64,
41 DataType::UInt8,
42 DataType::UInt16,
43 DataType::UInt32,
44 DataType::UInt64,
45 DataType::Float32,
46 DataType::Float64,
47];
48
49/// Validate the length of `input_fields` matches the `signature` for `agg_fun`.
50///
51/// This method DOES NOT validate the argument fields - only that (at least one,
52/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
53/// number of input types.
54pub fn check_arg_count(
55 func_name: &str,
56 input_fields: &[FieldRef],
57 signature: &TypeSignature,
58) -> Result<()> {
59 match signature {
60 TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
61 if input_fields.len() != *agg_count {
62 return plan_err!(
63 "The function {func_name} expects {:?} arguments, but {:?} were provided",
64 agg_count,
65 input_fields.len()
66 );
67 }
68 }
69 TypeSignature::Exact(types) => {
70 if types.len() != input_fields.len() {
71 return plan_err!(
72 "The function {func_name} expects {:?} arguments, but {:?} were provided",
73 types.len(),
74 input_fields.len()
75 );
76 }
77 }
78 TypeSignature::OneOf(variants) => {
79 let ok = variants
80 .iter()
81 .any(|v| check_arg_count(func_name, input_fields, v).is_ok());
82 if !ok {
83 return plan_err!(
84 "The function {func_name} does not accept {:?} function arguments.",
85 input_fields.len()
86 );
87 }
88 }
89 TypeSignature::VariadicAny => {
90 if input_fields.is_empty() {
91 return plan_err!(
92 "The function {func_name} expects at least one argument"
93 );
94 }
95 }
96 TypeSignature::UserDefined
97 | TypeSignature::Numeric(_)
98 | TypeSignature::Coercible(_) => {
99 // User-defined signature is validated in `coerce_types`
100 // Numeric and Coercible signature is validated in `get_valid_types`
101 }
102 _ => {
103 return internal_err!(
104 "Aggregate functions do not support this {signature:?}"
105 );
106 }
107 }
108 Ok(())
109}