datafusion_sql/expr/value.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
19use arrow::compute::kernels::cast_utils::{
20    parse_interval_month_day_nano_config, IntervalParseConfig, IntervalUnit,
21};
22use arrow::datatypes::{
23    i256, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
24};
25use bigdecimal::num_bigint::BigInt;
26use bigdecimal::{BigDecimal, Signed, ToPrimitive};
27use datafusion_common::{
28    internal_datafusion_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result,
29    ScalarValue,
30};
31use datafusion_expr::expr::{BinaryExpr, Placeholder};
32use datafusion_expr::planner::PlannerResult;
33use datafusion_expr::{lit, Expr, Operator};
34use log::debug;
35use sqlparser::ast::{
36    BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value, ValueWithSpan,
37};
38use sqlparser::parser::ParserError::ParserError;
39use std::borrow::Cow;
40use std::cmp::Ordering;
41use std::ops::Neg;
42use std::str::FromStr;
43
44impl<S: ContextProvider> SqlToRel<'_, S> {
45    pub(crate) fn parse_value(
46        &self,
47        value: Value,
48        param_data_types: &[FieldRef],
49    ) -> Result<Expr> {
50        match value {
51            Value::Number(n, _) => self.parse_sql_number(&n, false),
52            Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)),
53            Value::Null => Ok(Expr::Literal(ScalarValue::Null, None)),
54            Value::Boolean(n) => Ok(lit(n)),
55            Value::Placeholder(param) => {
56                Self::create_placeholder_expr(param, param_data_types)
57            }
58            Value::HexStringLiteral(s) => {
59                if let Some(v) = try_decode_hex_literal(&s) {
60                    Ok(lit(v))
61                } else {
62                    plan_err!("Invalid HexStringLiteral '{s}'")
63                }
64            }
65            Value::DollarQuotedString(s) => Ok(lit(s.value)),
66            Value::EscapedStringLiteral(s) => Ok(lit(s)),
67            _ => plan_err!("Unsupported Value '{value:?}'"),
68        }
69    }
70
71    /// Parse number in sql string, convert to Expr::Literal
72    pub(super) fn parse_sql_number(
73        &self,
74        unsigned_number: &str,
75        negative: bool,
76    ) -> Result<Expr> {
77        let signed_number: Cow<str> = if negative {
78            Cow::Owned(format!("-{unsigned_number}"))
79        } else {
80            Cow::Borrowed(unsigned_number)
81        };
82
83        // Try to parse as i64 first, then u64 if negative is false, then decimal or f64
84
85        if let Ok(n) = signed_number.parse::<i64>() {
86            return Ok(lit(n));
87        }
88
89        if !negative {
90            if let Ok(n) = unsigned_number.parse::<u64>() {
91                return Ok(lit(n));
92            }
93        }
94
95        if self.options.parse_float_as_decimal {
96            parse_decimal(unsigned_number, negative)
97        } else {
98            signed_number.parse::<f64>().map(lit).map_err(|_| {
99                DataFusionError::from(ParserError(format!(
100                    "Cannot parse {signed_number} as f64"
101                )))
102            })
103        }
104    }
105
106    /// Create a placeholder expression
107    /// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then
108    /// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
109    fn create_placeholder_expr(
110        param: String,
111        param_data_types: &[FieldRef],
112    ) -> Result<Expr> {
113        // Parse the placeholder as a number because it is the only support from sqlparser and postgres
114        let index = param[1..].parse::<usize>();
115        let idx = match index {
116            Ok(0) => {
117                return plan_err!(
118                    "Invalid placeholder, zero is not a valid index: {param}"
119                );
120            }
121            Ok(index) => index - 1,
122            Err(_) => {
123                return if param_data_types.is_empty() {
124                    Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
125                } else {
126                    // when PREPARE Statement, param_data_types length is always 0
127                    plan_err!("Invalid placeholder, not a number: {param}")
128                };
129            }
130        };
131        // Check if the placeholder is in the parameter list
132        let param_type = param_data_types.get(idx);
133        // Data type of the parameter
134        debug!("type of param {param} param_data_types[idx]: {param_type:?}");
135
136        Ok(Expr::Placeholder(Placeholder::new_with_field(
137            param,
138            param_type.cloned(),
139        )))
140    }
141
142    pub(super) fn sql_array_literal(
143        &self,
144        elements: Vec<SQLExpr>,
145        schema: &DFSchema,
146    ) -> Result<Expr> {
147        let values = elements
148            .into_iter()
149            .map(|element| {
150                self.sql_expr_to_logical_expr(element, schema, &mut PlannerContext::new())
151            })
152            .collect::<Result<Vec<_>>>()?;
153
154        self.try_plan_array_literal(values, schema)
155    }
156
157    fn try_plan_array_literal(
158        &self,
159        values: Vec<Expr>,
160        schema: &DFSchema,
161    ) -> Result<Expr> {
162        let mut exprs = values;
163        for planner in self.context_provider.get_expr_planners() {
164            match planner.plan_array_literal(exprs, schema)? {
165                PlannerResult::Planned(expr) => {
166                    return Ok(expr);
167                }
168                PlannerResult::Original(values) => exprs = values,
169            }
170        }
171
172        not_impl_err!("Could not plan array literal. Hint: Please try with `nested_expressions` DataFusion feature enabled")
173    }
174
175    /// Convert a SQL interval expression to a DataFusion logical plan
176    /// expression
177    #[allow(clippy::only_used_in_recursion)]
178    pub(super) fn sql_interval_to_expr(
179        &self,
180        negative: bool,
181        interval: Interval,
182    ) -> Result<Expr> {
183        if interval.leading_precision.is_some() {
184            return not_impl_err!(
185                "Unsupported Interval Expression with leading_precision {:?}",
186                interval.leading_precision
187            );
188        }
189
190        if interval.last_field.is_some() {
191            return not_impl_err!(
192                "Unsupported Interval Expression with last_field {:?}",
193                interval.last_field
194            );
195        }
196
197        if interval.fractional_seconds_precision.is_some() {
198            return not_impl_err!(
199                "Unsupported Interval Expression with fractional_seconds_precision {:?}",
200                interval.fractional_seconds_precision
201            );
202        }
203
204        if let SQLExpr::BinaryOp { left, op, right } = *interval.value {
205            let df_op = match op {
206                BinaryOperator::Plus => Operator::Plus,
207                BinaryOperator::Minus => Operator::Minus,
208                _ => {
209                    return not_impl_err!("Unsupported interval operator: {op:?}");
210                }
211            };
212            let left_expr = self.sql_interval_to_expr(
213                negative,
214                Interval {
215                    value: left,
216                    leading_field: interval.leading_field.clone(),
217                    leading_precision: None,
218                    last_field: None,
219                    fractional_seconds_precision: None,
220                },
221            )?;
222            let right_expr = self.sql_interval_to_expr(
223                false,
224                Interval {
225                    value: right,
226                    leading_field: interval.leading_field,
227                    leading_precision: None,
228                    last_field: None,
229                    fractional_seconds_precision: None,
230                },
231            )?;
232            return Ok(Expr::BinaryExpr(BinaryExpr::new(
233                Box::new(left_expr),
234                df_op,
235                Box::new(right_expr),
236            )));
237        }
238
239        let value = interval_literal(*interval.value, negative)?;
240
241        // leading_field really means the unit if specified
242        // For example, "month" in  `INTERVAL '5' month`
243        let value = match interval.leading_field.as_ref() {
244            Some(leading_field) => format!("{value} {leading_field}"),
245            None => value,
246        };
247
248        let config = IntervalParseConfig::new(IntervalUnit::Second);
249        let val = parse_interval_month_day_nano_config(&value, config)?;
250        Ok(lit(ScalarValue::IntervalMonthDayNano(Some(val))))
251    }
252}
253
254fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result<String> {
255    let s = match interval_value {
256        SQLExpr::Value(ValueWithSpan {
257            value: Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
258            span: _,
259        }) => s,
260        SQLExpr::Value(ValueWithSpan {
261            value: Value::Number(ref v, long),
262            span: _,
263        }) => {
264            if long {
265                return not_impl_err!(
266                    "Unsupported interval argument. Long number not supported: {interval_value:?}"
267                );
268            } else {
269                v.to_string()
270            }
271        }
272        SQLExpr::UnaryOp { op, expr } => {
273            let negative = match op {
274                UnaryOperator::Minus => !negative,
275                UnaryOperator::Plus => negative,
276                _ => {
277                    return not_impl_err!(
278                        "Unsupported SQL unary operator in interval {op:?}"
279                    );
280                }
281            };
282            interval_literal(*expr, negative)?
283        }
284        _ => {
285            return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}");
286        }
287    };
288    if negative {
289        Ok(format!("-{s}"))
290    } else {
291        Ok(s)
292    }
293}
294
295/// Try to decode bytes from hex literal string.
296///
297/// None will be returned if the input literal is hex-invalid.
298fn try_decode_hex_literal(s: &str) -> Option<Vec<u8>> {
299    let hex_bytes = s.as_bytes();
300
301    let mut decoded_bytes = Vec::with_capacity(hex_bytes.len().div_ceil(2));
302
303    let start_idx = hex_bytes.len() % 2;
304    if start_idx > 0 {
305        // The first byte is formed of only one char.
306        decoded_bytes.push(try_decode_hex_char(hex_bytes[0])?);
307    }
308
309    for i in (start_idx..hex_bytes.len()).step_by(2) {
310        let high = try_decode_hex_char(hex_bytes[i])?;
311        let low = try_decode_hex_char(hex_bytes[i + 1])?;
312        decoded_bytes.push((high << 4) | low);
313    }
314
315    Some(decoded_bytes)
316}
317
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Returns `None` for any other byte.
const fn try_decode_hex_char(c: u8) -> Option<u8> {
    let value = match c {
        b'0'..=b'9' => c - b'0',
        b'a'..=b'f' => 10 + (c - b'a'),
        b'A'..=b'F' => 10 + (c - b'A'),
        _ => return None,
    };
    Some(value)
}
329
330/// Returns None if the value can't be converted to i256.
331/// Modified from <https://github.com/apache/arrow-rs/blob/c4dbf0d8af6ca5a19b8b2ea777da3c276807fc5e/arrow-buffer/src/bigint/mod.rs#L303>
332fn bigint_to_i256(v: &BigInt) -> Option<i256> {
333    let v_bytes = v.to_signed_bytes_le();
334    match v_bytes.len().cmp(&32) {
335        Ordering::Less => {
336            let mut bytes = if v.is_negative() {
337                [255_u8; 32]
338            } else {
339                [0; 32]
340            };
341            bytes[0..v_bytes.len()].copy_from_slice(&v_bytes[..v_bytes.len()]);
342            Some(i256::from_le_bytes(bytes))
343        }
344        Ordering::Equal => Some(i256::from_le_bytes(v_bytes.try_into().unwrap())),
345        Ordering::Greater => None,
346    }
347}
348
349fn parse_decimal(unsigned_number: &str, negative: bool) -> Result<Expr> {
350    let mut dec = BigDecimal::from_str(unsigned_number).map_err(|e| {
351        DataFusionError::from(ParserError(format!(
352            "Cannot parse {unsigned_number} as BigDecimal: {e}"
353        )))
354    })?;
355    if negative {
356        dec = dec.neg();
357    }
358
359    let digits = dec.digits();
360    let (int_val, scale) = dec.into_bigint_and_exponent();
361    if scale < i8::MIN as i64 {
362        return not_impl_err!(
363            "Decimal scale {} exceeds the minimum supported scale: {}",
364            scale,
365            i8::MIN
366        );
367    }
368    let precision = if scale > 0 {
369        // arrow-rs requires the precision to include the positive scale.
370        // See <https://github.com/apache/arrow-rs/blob/123045cc766d42d1eb06ee8bb3f09e39ea995ddc/arrow-array/src/types.rs#L1230>
371        std::cmp::max(digits, scale.unsigned_abs())
372    } else {
373        digits
374    };
375    if precision <= DECIMAL128_MAX_PRECISION as u64 {
376        let val = int_val.to_i128().ok_or_else(|| {
377            // Failures are unexpected here as we have already checked the precision
378            internal_datafusion_err!(
379                "Unexpected overflow when converting {} to i128",
380                int_val
381            )
382        })?;
383        Ok(Expr::Literal(
384            ScalarValue::Decimal128(Some(val), precision as u8, scale as i8),
385            None,
386        ))
387    } else if precision <= DECIMAL256_MAX_PRECISION as u64 {
388        let val = bigint_to_i256(&int_val).ok_or_else(|| {
389            // Failures are unexpected here as we have already checked the precision
390            internal_datafusion_err!(
391                "Unexpected overflow when converting {} to i256",
392                int_val
393            )
394        })?;
395        Ok(Expr::Literal(
396            ScalarValue::Decimal256(Some(val), precision as u8, scale as i8),
397            None,
398        ))
399    } else {
400        not_impl_err!(
401            "Decimal precision {} exceeds the maximum supported precision: {}",
402            precision,
403            DECIMAL256_MAX_PRECISION
404        )
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_hex_literal() {
        let cases = [
            ("", Some(vec![])),
            ("FF00", Some(vec![255, 0])),
            ("a00a", Some(vec![160, 10])),
            ("FF0", Some(vec![15, 240])),
            ("f", Some(vec![15])),
            ("FF0X", None),
            ("X0", None),
            ("XX", None),
            ("x", None),
        ];

        for (input, expect) in cases {
            let output = try_decode_hex_literal(input);
            assert_eq!(output, expect, "input: {input:?}");
        }
    }

    #[test]
    fn test_decode_hex_char() {
        // Exhaustively compare against the standard library's base-16 digit parser.
        for b in 0_u8..=255 {
            let expected = (b as char).to_digit(16).map(|d| d as u8);
            assert_eq!(try_decode_hex_char(b), expected, "byte: {b:#04x}");
        }
    }

    #[test]
    fn test_bigint_to_i256() {
        let cases = [
            (BigInt::from(0), Some(i256::from(0))),
            (BigInt::from(1), Some(i256::from(1))),
            (BigInt::from(-1), Some(i256::from(-1))),
            (
                BigInt::from_str(i256::MAX.to_string().as_str()).unwrap(),
                Some(i256::MAX),
            ),
            (
                BigInt::from_str(i256::MIN.to_string().as_str()).unwrap(),
                Some(i256::MIN),
            ),
            (
                // Can't fit into i256
                BigInt::from_str((i256::MAX.to_string() + "1").as_str()).unwrap(),
                None,
            ),
        ];

        for (input, expect) in cases {
            let output = bigint_to_i256(&input);
            assert_eq!(output, expect, "input: {input}");
        }
    }

    #[test]
    fn test_parse_decimal() {
        // Supported cases
        let cases = [
            ("0", ScalarValue::Decimal128(Some(0), 1, 0)),
            ("1", ScalarValue::Decimal128(Some(1), 1, 0)),
            ("123.45", ScalarValue::Decimal128(Some(12345), 5, 2)),
            // Digit count is less than scale
            ("0.001", ScalarValue::Decimal128(Some(1), 3, 3)),
            // Scientific notation
            ("123.456e-2", ScalarValue::Decimal128(Some(123456), 6, 5)),
            // Negative scale
            ("123456e128", ScalarValue::Decimal128(Some(123456), 6, -128)),
            // Largest precision that still fits in Decimal128
            (
                &("9".repeat(37) + ".9"),
                ScalarValue::Decimal128(
                    Some("9".repeat(38).parse::<i128>().unwrap()),
                    38,
                    1,
                ),
            ),
            // Decimal256
            (
                &("9".repeat(39) + "." + "99999"),
                ScalarValue::Decimal256(
                    Some(i256::from_string(&"9".repeat(44)).unwrap()),
                    44,
                    5,
                ),
            ),
        ];
        for (input, expect) in cases {
            let output = parse_decimal(input, true).unwrap();
            assert_eq!(
                output,
                Expr::Literal(expect.arithmetic_negate().unwrap(), None),
                "negated input: {input}"
            );

            let output = parse_decimal(input, false).unwrap();
            assert_eq!(output, Expr::Literal(expect, None), "input: {input}");
        }

        // Not a number at all
        assert!(parse_decimal("not a number", false).is_err());

        // scale < i8::MIN
        assert_eq!(
            parse_decimal("1e129", false)
                .unwrap_err()
                .strip_backtrace(),
            "This feature is not implemented: Decimal scale -129 exceeds the minimum supported scale: -128"
        );

        // Unsupported precision
        assert_eq!(
            parse_decimal(&"1".repeat(77), false)
                .unwrap_err()
                .strip_backtrace(),
            "This feature is not implemented: Decimal precision 77 exceeds the maximum supported precision: 76"
        );

        // A positive scale larger than the digit count drives the precision
        // past the supported maximum.
        assert_eq!(
            parse_decimal("1e-77", false)
                .unwrap_err()
                .strip_backtrace(),
            "This feature is not implemented: Decimal precision 77 exceeds the maximum supported precision: 76"
        );
    }
}