datafusion_functions_aggregate/utils.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::RecordBatch;
21use arrow::datatypes::Schema;
22use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue};
23use datafusion_expr::ColumnarValue;
24use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
25
26/// Evaluates a physical expression to extract its scalar value.
27///
28/// This is used to extract constant values from expressions (like percentile parameters)
29/// by evaluating them against an empty record batch.
30pub(crate) fn get_scalar_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
31 let empty_schema = Arc::new(Schema::empty());
32 let batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
33 if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? {
34 Ok(s)
35 } else {
36 internal_err!("Didn't expect ColumnarValue::Array")
37 }
38}
39
40/// Validates that a percentile expression is a literal float value between 0.0 and 1.0.
41///
42/// Used by both `percentile_cont` and `approx_percentile_cont` to validate their
43/// percentile parameters.
44pub(crate) fn validate_percentile_expr(
45 expr: &Arc<dyn PhysicalExpr>,
46 fn_name: &str,
47) -> Result<f64> {
48 let scalar_value = get_scalar_value(expr).map_err(|_e| {
49 DataFusionError::Plan(format!(
50 "Percentile value for '{fn_name}' must be a literal"
51 ))
52 })?;
53
54 let percentile = match scalar_value {
55 ScalarValue::Float32(Some(value)) => value as f64,
56 ScalarValue::Float64(Some(value)) => value,
57 sv => {
58 return plan_err!(
59 "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got data type {})",
60 sv.data_type()
61 )
62 }
63 };
64
65 // Ensure the percentile is between 0 and 1.
66 if !(0.0..=1.0).contains(&percentile) {
67 return plan_err!(
68 "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
69 );
70 }
71 Ok(percentile)
72}