datafusion_expr_common/type_coercion/aggregates.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::signature::TypeSignature;
19use arrow::datatypes::{DataType, FieldRef};
20
21use datafusion_common::{internal_err, plan_err, Result};
22
23// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures
24//       see https://github.com/apache/datafusion/issues/18092
25pub static INTEGERS: &[DataType] = &[
26    DataType::Int8,
27    DataType::Int16,
28    DataType::Int32,
29    DataType::Int64,
30    DataType::UInt8,
31    DataType::UInt16,
32    DataType::UInt32,
33    DataType::UInt64,
34];
35
36pub static NUMERICS: &[DataType] = &[
37    DataType::Int8,
38    DataType::Int16,
39    DataType::Int32,
40    DataType::Int64,
41    DataType::UInt8,
42    DataType::UInt16,
43    DataType::UInt32,
44    DataType::UInt64,
45    DataType::Float32,
46    DataType::Float64,
47];
48
49/// Validate the length of `input_fields` matches the `signature` for `agg_fun`.
50///
51/// This method DOES NOT validate the argument fields - only that (at least one,
52/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
53/// number of input types.
54pub fn check_arg_count(
55    func_name: &str,
56    input_fields: &[FieldRef],
57    signature: &TypeSignature,
58) -> Result<()> {
59    match signature {
60        TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
61            if input_fields.len() != *agg_count {
62                return plan_err!(
63                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
64                    agg_count,
65                    input_fields.len()
66                );
67            }
68        }
69        TypeSignature::Exact(types) => {
70            if types.len() != input_fields.len() {
71                return plan_err!(
72                    "The function {func_name} expects {:?} arguments, but {:?} were provided",
73                    types.len(),
74                    input_fields.len()
75                );
76            }
77        }
78        TypeSignature::OneOf(variants) => {
79            let ok = variants
80                .iter()
81                .any(|v| check_arg_count(func_name, input_fields, v).is_ok());
82            if !ok {
83                return plan_err!(
84                    "The function {func_name} does not accept {:?} function arguments.",
85                    input_fields.len()
86                );
87            }
88        }
89        TypeSignature::VariadicAny => {
90            if input_fields.is_empty() {
91                return plan_err!(
92                    "The function {func_name} expects at least one argument"
93                );
94            }
95        }
96        TypeSignature::UserDefined
97        | TypeSignature::Numeric(_)
98        | TypeSignature::Coercible(_) => {
99            // User-defined signature is validated in `coerce_types`
100            // Numeric and Coercible signature is validated in `get_valid_types`
101        }
102        _ => {
103            return internal_err!(
104                "Aggregate functions do not support this {signature:?}"
105            );
106        }
107    }
108    Ok(())
109}