diff --git a/datafusion/spark/src/function/math/ceil.rs b/datafusion/spark/src/function/math/ceil.rs
new file mode 100644
index 0000000000000..9683b54cb1ef0
--- /dev/null
+++ b/datafusion/spark/src/function/math/ceil.rs
@@ -0,0 +1,438 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::cast::AsArray;
+use arrow::array::types::{
+    Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int8Type,
+};
+use arrow::array::{Decimal128Array, Int64Array};
+use arrow::compute::kernels::arity::unary;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::{DataFusionError, ScalarValue, exec_err, internal_err};
+use datafusion_expr::{
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-compatible `ceil` function.
+///
+/// Unlike standard DataFusion ceil, this returns Int64 for float/integer (per Spark spec).
+/// Supported types:
+/// - Float32, Float64 -> Int64
+/// - Int8, Int16, Int32, Int64 -> Int64
+/// - Decimal128(p, s) -> Decimal128(p, s) (preserves precision and scale)
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCeil {
+    signature: Signature,
+}
+
+impl Default for SparkCeil {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCeil {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::numeric(1, Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkCeil {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ceil"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(
+        &self,
+        _arg_types: &[DataType],
+    ) -> datafusion_common::Result<DataType> {
+        // The return type depends on the input field (Decimal128 preserves
+        // precision/scale), so it must be computed in `return_field_from_args`.
+        internal_err!("return_field_from_args should be called instead")
+    }
+
+    fn return_field_from_args(
+        &self,
+        args: ReturnFieldArgs,
+    ) -> datafusion_common::Result<FieldRef> {
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        let return_type = match args.arg_fields[0].data_type() {
+            DataType::Decimal128(p, s) => DataType::Decimal128(*p, *s),
+            _ => DataType::Int64,
+        };
+        Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
+    }
+
+    fn invoke_with_args(
+        &self,
+        args: ScalarFunctionArgs,
+    ) -> datafusion_common::Result<ColumnarValue> {
+        spark_ceil(&args.args)
+    }
+}
+
+pub fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    let value = &args[0];
+    match value {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Float32 => {
+                let input = array.as_primitive::<Float32Type>();
+                // `as i64` is a saturating cast in Rust; out-of-range floats
+                // clamp to i64::MIN/i64::MAX rather than wrapping.
+                let result: Int64Array = unary(input, |x| x.ceil() as i64);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Float64 => {
+                let input = array.as_primitive::<Float64Type>();
+                let result: Int64Array = unary(input, |x| x.ceil() as i64);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int8 => {
+                let input = array.as_primitive::<Int8Type>();
+                let result: Int64Array = unary(input, |x| x as i64);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int16 => {
+                let input = array.as_primitive::<Int16Type>();
+                let result: Int64Array = unary(input, |x| x as i64);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int32 => {
+                let input = array.as_primitive::<Int32Type>();
+                let result: Int64Array = unary(input, |x| x as i64);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int64 => Ok(ColumnarValue::Array(Arc::clone(array))),
+            DataType::Decimal128(precision, scale) => {
+                if *scale <= 0 {
+                    // No fractional digits: ceil is the identity.
+                    Ok(ColumnarValue::Array(Arc::clone(array)))
+                } else {
+                    let f = decimal_ceil_helper(*scale);
+                    let input = array.as_primitive::<Decimal128Type>();
+                    let result: Decimal128Array = unary(input, &f);
+                    let result =
+                        result.with_data_type(DataType::Decimal128(*precision, *scale));
+                    Ok(ColumnarValue::Array(Arc::new(result)))
+                }
+            }
+            other => {
+                exec_err!("Unsupported data type {other:?} for function ceil")
+            }
+        },
+        ColumnarValue::Scalar(scalar) => match scalar {
+            ScalarValue::Float32(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                v.map(|x| x.ceil() as i64),
+            ))),
+            ScalarValue::Float64(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                v.map(|x| x.ceil() as i64),
+            ))),
+            ScalarValue::Int8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                v.map(|x| x as i64),
+            ))),
+            ScalarValue::Int16(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                v.map(|x| x as i64),
+            ))),
+            ScalarValue::Int32(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                v.map(|x| x as i64),
+            ))),
+            ScalarValue::Int64(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(*v))),
+            ScalarValue::Decimal128(v, precision, scale) => {
+                if *scale <= 0 {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        *v, *precision, *scale,
+                    )))
+                } else {
+                    let f = decimal_ceil_helper(*scale);
+                    let result = v.map(f);
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        result, *precision, *scale,
+                    )))
+                }
+            }
+            other => {
+                exec_err!(
+                    "Unsupported data type {:?} for function ceil",
+                    other.data_type()
+                )
+            }
+        },
+    }
+}
+
+/// Returns a closure that rounds a scaled Decimal128 value up to the next
+/// whole unit while keeping the same scale (e.g. scale 2: 12345 -> 12400).
+#[inline]
+fn decimal_ceil_helper(scale: i8) -> impl Fn(i128) -> i128 {
+    let divisor = 10_i128.pow(scale as u32);
+    move |x: i128| {
+        let quotient = x / divisor;
+        let remainder = x % divisor;
+        // Rust's `%` keeps the dividend's sign, so `remainder > 0` only for
+        // positive values with a fractional part; negatives truncate toward
+        // zero, which already is their ceiling.
+        if remainder > 0 {
+            (quotient + 1) * divisor
+        } else {
+            quotient * divisor
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{
+        Array, Decimal128Array, Float32Array, Float64Array, Int8Array, Int64Array,
+    };
+    use datafusion_common::Result;
+    use datafusion_common::cast::{as_decimal128_array, as_int64_array};
+
+    #[test]
+    fn test_ceil_float32_array() -> Result<()> {
+        let array = Float32Array::from(vec![
+            Some(1.1),
+            Some(1.9),
+            Some(-1.1),
+            Some(-1.9),
+            Some(0.0),
+            None,
+        ]);
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let result = as_int64_array(&result)?;
+        assert_eq!(result.value(0), 2);
+        assert_eq!(result.value(1), 2);
+        assert_eq!(result.value(2), -1);
+        assert_eq!(result.value(3), -1);
+        assert_eq!(result.value(4), 0);
+        assert!(result.is_null(5));
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_float32_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.5)))];
+        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) = spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, 2);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_float64_array() -> Result<()> {
+        let array = Float64Array::from(vec![
+            Some(1.1),
+            Some(1.9),
+            Some(-1.1),
+            Some(-1.9),
+            Some(0.0),
+            Some(123.0),
+            None,
+        ]);
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let result = as_int64_array(&result)?;
+        assert_eq!(result.value(0), 2);
+        assert_eq!(result.value(1), 2);
+        assert_eq!(result.value(2), -1);
+        assert_eq!(result.value(3), -1);
+        assert_eq!(result.value(4), 0);
+        assert_eq!(result.value(5), 123);
+        assert!(result.is_null(6));
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_float64_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-1.5)))];
+        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) = spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, -1);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_float64_null_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Float64(None))];
+        let ColumnarValue::Scalar(ScalarValue::Int64(result)) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        assert_eq!(result, None);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_int8_array() -> Result<()> {
+        let array = Int8Array::from(vec![Some(1), Some(-1), Some(127), Some(-128), None]);
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let result = as_int64_array(&result)?;
+        assert_eq!(result.value(0), 1);
+        assert_eq!(result.value(1), -1);
+        assert_eq!(result.value(2), 127);
+        assert_eq!(result.value(3), -128);
+        assert!(result.is_null(4));
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_int16_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Int16(Some(100)))];
+        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) = spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, 100);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_int32_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(-500)))];
+        let ColumnarValue::Scalar(ScalarValue::Int64(Some(result))) = spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, -500);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_int64_array() -> Result<()> {
+        let array = Int64Array::from(vec![
+            Some(1),
+            Some(-1),
+            Some(i64::MAX),
+            Some(i64::MIN),
+            None,
+        ]);
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let result = as_int64_array(&result)?;
+        assert_eq!(result.value(0), 1);
+        assert_eq!(result.value(1), -1);
+        assert_eq!(result.value(2), i64::MAX);
+        assert_eq!(result.value(3), i64::MIN);
+        assert!(result.is_null(4));
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_array() -> Result<()> {
+        let array =
+            Decimal128Array::from(vec![Some(12345), Some(12500), Some(-12999), None])
+                .with_precision_and_scale(5, 2)?;
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let expected =
+            Decimal128Array::from(vec![Some(12400), Some(12500), Some(-12900), None])
+                .with_precision_and_scale(5, 2)?;
+        let actual = as_decimal128_array(&result)?;
+        assert_eq!(actual, &expected);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
+            Some(567),
+            3,
+            1,
+        ))];
+        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 1)) =
+            spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, 570);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_negative_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
+            Some(-567),
+            3,
+            1,
+        ))];
+        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 1)) =
+            spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, -560);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_null_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(None, 5, 2))];
+        let ColumnarValue::Scalar(ScalarValue::Decimal128(result, 5, 2)) =
+            spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, None);
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_scale_zero() -> Result<()> {
+        let array = Decimal128Array::from(vec![Some(123), Some(-456), None])
+            .with_precision_and_scale(10, 0)?;
+        let args = vec![ColumnarValue::Array(Arc::new(array))];
+        let ColumnarValue::Array(result) = spark_ceil(&args)? else {
+            unreachable!()
+        };
+        let result = as_decimal128_array(&result)?;
+        assert_eq!(result.value(0), 123);
+        assert_eq!(result.value(1), -456);
+        assert!(result.is_null(2));
+        Ok(())
+    }
+
+    #[test]
+    fn test_ceil_decimal128_scale_zero_scalar() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
+            Some(12345),
+            10,
+            0,
+        ))];
+        let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 10, 0)) =
+            spark_ceil(&args)?
+        else {
+            unreachable!()
+        };
+        assert_eq!(result, 12345);
+        Ok(())
+    }
+}
diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 7f7d04e06b0be..c9d702f8efb7f 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -17,6 +17,7 @@
 pub mod abs;
 pub mod bin;
+pub mod ceil;
 pub mod expm1;
 pub mod factorial;
 pub mod hex;
@@ -43,6 +44,7 @@ make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
 make_udf_function!(trigonometry::SparkCsc, csc);
 make_udf_function!(trigonometry::SparkSec, sec);
 make_udf_function!(negative::SparkNegative, negative);
+make_udf_function!(ceil::SparkCeil, ceil);
 make_udf_function!(bin::SparkBin, bin);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -72,6 +74,7 @@ pub mod expr_fn {
         "Returns the negation of expr (unary minus).",
         arg1
     ));
+    export_functions!((ceil, "Returns the ceiling of expr.", arg1));
     export_functions!((
         bin,
         "Returns the string representation of the long value represented in binary.",
@@ -93,6 +96,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         csc(),
         sec(),
         negative(),
+        ceil(),
         bin(),
     ]
 }
diff --git a/datafusion/sqllogictest/test_files/spark/math/ceil.slt b/datafusion/sqllogictest/test_files/spark/math/ceil.slt
index c87a29b61fd49..b02ad8bcd35c8 100644
--- a/datafusion/sqllogictest/test_files/spark/math/ceil.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/ceil.slt
@@ -38,5 +38,211 @@
 ## Original Query: SELECT ceil(5);
 ## PySpark 3.5.5 Result: {'CEIL(5)': 5, 'typeof(CEIL(5))': 'bigint', 'typeof(5)': 'int'}
-#query
-#SELECT ceil(5::int);
+query I
+SELECT ceil(5::int);
+----
+5
+
+# Additional tests for all supported numeric types
+
+# Test Float32
+query I
+SELECT ceil(125.2345::float);
+----
+126
+
+query I
+SELECT ceil(15.0001::float);
+----
+16
+
+query I
+SELECT ceil(0.1::float);
+----
+1
+
+query I
+SELECT ceil(-0.9::float);
+----
+0
+
+query I
+SELECT ceil(-1.1::float);
+----
+-1
+
+query I
+SELECT ceil(123.0::float);
+----
+123
+
+query I
+SELECT ceil(0.0::float);
+----
+0
+
+# Test Float64
+query I
+SELECT ceil(125.2345::double);
+----
+126
+
+query I
+SELECT ceil(15.0001::double);
+----
+16
+
+query I
+SELECT ceil(0.1::double);
+----
+1
+
+query I
+SELECT ceil(-0.9::double);
+----
+0
+
+query I
+SELECT ceil(-1.1::double);
+----
+-1
+
+query I
+SELECT ceil(123.0::double);
+----
+123
+
+query I
+SELECT ceil(1.9999::double);
+----
+2
+
+query I
+SELECT ceil(-1.9999::double);
+----
+-1
+
+# Test Int8 (tinyint)
+query I
+SELECT ceil(5::tinyint);
+----
+5
+
+query I
+SELECT ceil(-1::tinyint);
+----
+-1
+
+query I
+SELECT ceil(0::tinyint);
+----
+0
+
+query I
+SELECT ceil(127::tinyint);
+----
+127
+
+query I
+SELECT ceil(CAST(-128 AS tinyint));
+----
+-128
+
+# Test Int16 (smallint)
+query I
+SELECT ceil(100::smallint);
+----
+100
+
+query I
+SELECT ceil(-50::smallint);
+----
+-50
+
+query I
+SELECT ceil(0::smallint);
+----
+0
+
+query I
+SELECT ceil(32767::smallint);
+----
+32767
+
+query I
+SELECT ceil(CAST(-32768 AS smallint));
+----
+-32768
+
+# Test Int32 (int)
+query I
+SELECT ceil(1000::int);
+----
+1000
+
+query I
+SELECT ceil(-500::int);
+----
+-500
+
+query I
+SELECT ceil(0::int);
+----
+0
+
+# Test Int64 (bigint)
+query I
+SELECT ceil(48::bigint);
+----
+48
+
+query I
+SELECT ceil(-1::bigint);
+----
+-1
+
+query I
+SELECT ceil(0::bigint);
+----
+0
+
+query I
+SELECT ceil(9223372036854775807::bigint);
+----
+9223372036854775807
+
+query I
+SELECT ceil(CAST(-9223372036854775808 AS bigint));
+----
+-9223372036854775808
+
+# Test NULL values
+query I
+SELECT ceil(NULL::float);
+----
+NULL
+
+query I
+SELECT ceil(NULL::double);
+----
+NULL
+
+query I
+SELECT ceil(NULL::tinyint);
+----
+NULL
+
+query I
+SELECT ceil(NULL::smallint);
+----
+NULL
+
+query I
+SELECT ceil(NULL::int);
+----
+NULL
+
+query I
+SELECT ceil(NULL::bigint);
+----
+NULL