datafusion_functions_aggregate/
utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::RecordBatch;
21use arrow::datatypes::Schema;
22use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue};
23use datafusion_expr::ColumnarValue;
24use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
25
26/// Evaluates a physical expression to extract its scalar value.
27///
28/// This is used to extract constant values from expressions (like percentile parameters)
29/// by evaluating them against an empty record batch.
30pub(crate) fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
31    let empty_schema = Arc::new(Schema::empty());
32    let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
33    if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
34        Ok(s)
35    } else {
36        internal_err!("Didn't expect ColumnarValue::Array")
37    }
38}
39
40/// Validates that a percentile expression is a literal float value between 0.0 and 1.0.
41///
42/// Used by both `percentile_cont` and `approx_percentile_cont` to validate their
43/// percentile parameters.
44pub(crate) fn validate_percentile_expr(
45    expr: &Arc<dyn PhysicalExpr>,
46    fn_name: &str,
47) -> Result<f64> {
48    let scalar_value = get_scalar_value(expr).map_err(|_e| {
49        DataFusionError::Plan(format!(
50            "Percentile value for '{fn_name}' must be a literal"
51        ))
52    })?;
53
54    let percentile = match scalar_value {
55        ScalarValue::Float32(Some(value)) => value as f64,
56        ScalarValue::Float64(Some(value)) => value,
57        sv => {
58            return plan_err!(
59                "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got data type {})",
60                sv.data_type()
61            )
62        }
63    };
64
65    // Ensure the percentile is between 0 and 1.
66    if !(0.0..=1.0).contains(&percentile) {
67        return plan_err!(
68            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
69        );
70    }
71    Ok(percentile)
72}