diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 79e19313699cb..f196870e97228 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -245,11 +245,10 @@ macro_rules! make_math_unary_udf { impl $UDF { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::uniform( 1, - vec![Float64, Float32], + vec![DataType::Float64, DataType::Float32], Volatility::Immutable, ), } @@ -270,7 +269,6 @@ macro_rules! make_math_unary_udf { match arg_type { DataType::Float32 => Ok(DataType::Float32), - // For other types (possible values float64/null/int), use Float64 _ => Ok(DataType::Float64), } } @@ -345,8 +343,12 @@ macro_rules! make_math_unary_udf { /// Macro to create a binary math UDF. /// -/// A binary math function takes two arguments of types Float32 or Float64, -/// applies a binary floating function to the argument, and returns a value of the same type. +/// A binary math function takes two numeric arguments. When both arguments are +/// Float32 the function is evaluated in single precision and returns Float32. +/// Any other combination of numeric (or null) argument types is coerced to +/// Float64 and returns Float64; in particular integers are widened to Float64 +/// rather than Float32 so that values needing more than 24 bits of mantissa are +/// not silently rounded. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` /// $NAME: the name of the function @@ -365,7 +367,6 @@ macro_rules! make_math_binary_udf { use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, internal_err}; - use datafusion_expr::TypeSignature; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, @@ -379,13 +380,18 @@ macro_rules! make_math_binary_udf { impl $UDF { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32, Float32]), - TypeSignature::Exact(vec![Float64, Float64]), - ], + // Float64 is listed first so that integer (and other + // non-float) arguments coerce to Float64 rather than + // Float32; genuine Float32 arguments still match + // exactly and stay in single precision. Coercing + // integers to Float64 matters for correctness: Float32 + // has only a 24-bit mantissa, so widening a large + // integer to Float32 would round it before the function + // is ever applied. + signature: Signature::uniform( + 2, + vec![DataType::Float64, DataType::Float32], Volatility::Immutable, ), } @@ -402,11 +408,8 @@ macro_rules! make_math_binary_udf { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_type = &arg_types[0]; - - match arg_type { - DataType::Float32 => Ok(DataType::Float32), - // For other types (possible values float64/null/int), use Float64 + match (&arg_types[0], &arg_types[1]) { + (DataType::Float32, DataType::Float32) => Ok(DataType::Float32), _ => Ok(DataType::Float64), } } diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 4a0db9ef0cf7a..52449f9c9e0b9 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -262,11 +262,11 @@ Can be a constant, column, or function, and any combination of arithmetic operat ) .with_sql_example(r#"```sql > SELECT atan2(1, 1); -+------------+ -| atan2(1,1) | -+------------+ -| 0.7853982 | -+------------+ ++--------------------+ +| atan2(1,1) | ++--------------------+ +| 0.7853981633974483 | ++--------------------+ ```"#) .build() }); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 38f76f13151bc..9dbf8f16d85ab 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -234,7 +234,26 @@ select round(atanh(a), 5), round(atanh(b), 5), round(atanh(c), 5) from small_flo query RRR rowsort select atan2(0, 1), atan2(1, 2), atan2(2, 2); ---- -0 0.4636476 0.7853982 +0 0.463647609001 0.785398163397 + +# atan2 returns Float32 only when both arguments are Float32; every other +# numeric combination (integers, Float64, mixed, NULL) is computed in Float64 +query TTTTTT +select + arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float32'))), + arrow_typeof(atan2(1, 1)), + arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float64'))), + arrow_typeof(atan2(arrow_cast(1.0, 'Float64'), arrow_cast(1.0, 'Float32'))), + arrow_typeof(atan2(null, null)), + arrow_typeof(atan2(null, 64)); +---- +Float32 Float64 Float64 Float64 Float64 Float64 + +# atan2 with integer inputs is computed in double precision +query B +select atan2(1, 1000000) = atan2(1.0, 1000000.0); +---- +true # atan2 scalar nulls query R rowsort diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 6bf61391eb10e..8ff0032723f90 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -227,11 +227,11 @@ atan2(expression_y, expression_x) ```sql > SELECT atan2(1, 1); -+------------+ -| atan2(1,1) | -+------------+ -| 0.7853982 | -+------------+ ++--------------------+ +| atan2(1,1) | ++--------------------+ +| 0.7853981633974483 | ++--------------------+ ``` ### `atanh`