From 41d0a9e201d703448de0ffb2e34aee37d97b00a2 Mon Sep 17 00:00:00 2001
From: Kumar Ujjawal
Date: Thu, 19 Feb 2026 19:58:57 +0530
Subject: [PATCH] Fix decimal log precision for non-power values

---
 datafusion/functions/src/math/log.rs | 106 +++++++-----------
 .../sqllogictest/test_files/decimal.slt | 20 ++--
 2 files changed, 51 insertions(+), 75 deletions(-)

diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index d1906a4bf0e01..71e0281f8202a 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -109,80 +109,59 @@ fn is_valid_integer_base(base: f64) -> bool {
 }
 
 /// Calculate logarithm for Decimal32 values.
-/// For integer bases >= 2 with non-negative scale, uses the efficient u32 ilog algorithm.
-/// Otherwise falls back to f64 computation.
+/// For integer bases >= 2 with zero scale, return an exact integer log when the
+/// value is a perfect power of the base. Otherwise falls back to f64 computation.
 fn log_decimal32(value: i32, scale: i8, base: f64) -> Result {
-    if is_valid_integer_base(base)
-        && scale >= 0
-        && let Some(unscaled) = unscale_to_u32(value, scale)
+    if scale == 0
+        && is_valid_integer_base(base)
+        && let Ok(unscaled) = u32::try_from(value)
+        && unscaled > 0
     {
-        return if unscaled > 0 {
-            Ok(unscaled.ilog(base as u32) as f64)
-        } else {
-            Ok(f64::NAN)
-        };
+        let base_u32 = base as u32;
+        let int_log = unscaled.ilog(base_u32);
+        if base_u32.checked_pow(int_log) == Some(unscaled) {
+            return Ok(int_log as f64);
+        }
     }
     decimal_to_f64(value, scale).map(|v| v.log(base))
 }
 
 /// Calculate logarithm for Decimal64 values.
-/// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm.
-/// Otherwise falls back to f64 computation.
+/// For integer bases >= 2 with zero scale, return an exact integer log when the
+/// value is a perfect power of the base. Otherwise falls back to f64 computation.
 fn log_decimal64(value: i64, scale: i8, base: f64) -> Result {
-    if is_valid_integer_base(base)
-        && scale >= 0
-        && let Some(unscaled) = unscale_to_u64(value, scale)
+    if scale == 0
+        && is_valid_integer_base(base)
+        && let Ok(unscaled) = u64::try_from(value)
+        && unscaled > 0
     {
-        return if unscaled > 0 {
-            Ok(unscaled.ilog(base as u64) as f64)
-        } else {
-            Ok(f64::NAN)
-        };
+        let base_u64 = base as u64;
+        let int_log = unscaled.ilog(base_u64);
+        if base_u64.checked_pow(int_log) == Some(unscaled) {
+            return Ok(int_log as f64);
+        }
     }
     decimal_to_f64(value, scale).map(|v| v.log(base))
 }
 
 /// Calculate logarithm for Decimal128 values.
-/// For integer bases >= 2 with non-negative scale, uses the efficient u128 ilog algorithm.
-/// Otherwise falls back to f64 computation.
+/// For integer bases >= 2 with zero scale, return an exact integer log when the
+/// value is a perfect power of the base. Otherwise falls back to f64 computation.
 fn log_decimal128(value: i128, scale: i8, base: f64) -> Result {
-    if is_valid_integer_base(base)
-        && scale >= 0
-        && let Some(unscaled) = unscale_to_u128(value, scale)
+    if scale == 0
+        && is_valid_integer_base(base)
+        && let Ok(unscaled) = u128::try_from(value)
+        && unscaled > 0
     {
-        return if unscaled > 0 {
-            Ok(unscaled.ilog(base as u128) as f64)
-        } else {
-            Ok(f64::NAN)
-        };
+        let base_u128 = base as u128;
+        let int_log = unscaled.ilog(base_u128);
+        if base_u128.checked_pow(int_log) == Some(unscaled) {
+            return Ok(int_log as f64);
+        }
     }
     decimal_to_f64(value, scale).map(|v| v.log(base))
 }
 
-/// Unscale a Decimal32 value to u32.
-#[inline]
-fn unscale_to_u32(value: i32, scale: i8) -> Option {
-    let value_u32 = u32::try_from(value).ok()?;
-    let divisor = 10u32.checked_pow(scale as u32)?;
-    Some(value_u32 / divisor)
-}
-
-/// Unscale a Decimal64 value to u64.
-#[inline]
-fn unscale_to_u64(value: i64, scale: i8) -> Option {
-    let value_u64 = u64::try_from(value).ok()?;
-    let divisor = 10u64.checked_pow(scale as u32)?;
-    Some(value_u64 / divisor)
-}
-
-/// Unscale a Decimal128 value to u128.
-#[inline]
-fn unscale_to_u128(value: i128, scale: i8) -> Option {
-    let value_u128 = u128::try_from(value).ok()?;
-    let divisor = 10u128.checked_pow(scale as u32)?;
-    Some(value_u128 / divisor)
-}
-
 /// Convert a scaled decimal value to f64.
 #[inline]
 fn decimal_to_f64(value: T, scale: i8) -> Result {
@@ -444,13 +423,9 @@ mod tests {
     #[test]
     fn test_log_decimal_native() {
         let value = 10_i128.pow(35);
-        assert_eq!((value as f64).log2(), 116.26748332105768);
-        assert_eq!(
-            log_decimal128(value, 0, 2.0).unwrap(),
-            // TODO: see we're losing our decimal points compared to above
-            // https://github.com/apache/datafusion/issues/18524
-            116.0
-        );
+        let expected = (value as f64).log2();
+        let actual = log_decimal128(value, 0, 2.0).unwrap();
+        assert!((actual - expected).abs() < 1e-10);
     }
 
     #[test]
@@ -1012,7 +987,8 @@ mod tests {
                 assert!((floats.value(1) - 2.0).abs() < 1e-10);
                 assert!((floats.value(2) - 3.0).abs() < 1e-10);
                 assert!((floats.value(3) - 4.0).abs() < 1e-10);
-                assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding
+                let expected = 12600_f64.log(10.0);
+                assert!((floats.value(4) - expected).abs() < 1e-10);
                 assert!(floats.value(5).is_nan());
             }
             ColumnarValue::Scalar(_) => {
@@ -1147,8 +1123,10 @@ mod tests {
                 assert!((floats.value(1) - 2.0).abs() < 1e-10);
                 assert!((floats.value(2) - 3.0).abs() < 1e-10);
                 assert!((floats.value(3) - 4.0).abs() < 1e-10);
-                assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log
-                assert!((floats.value(5) - 38.0).abs() < 1e-10);
+                let expected = 12600_f64.log(10.0);
+                assert!((floats.value(4) - expected).abs() < 1e-10);
+                let expected = ((i128::MAX - 1000) as f64).log(10.0);
+                assert!((floats.value(5) - expected).abs() < 1e-10);
                 assert!(floats.value(6).is_nan());
             }
             ColumnarValue::Scalar(_) => {
diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt
index f53f4939299c5..cf6123f9d6766 100644
--- a/datafusion/sqllogictest/test_files/decimal.slt
+++ b/datafusion/sqllogictest/test_files/decimal.slt
@@ -804,7 +804,7 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
 query R
 select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
 ----
-13
+13.591717513272
 
 # log for small decimal64
 query R
@@ -820,7 +820,7 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
 query R
 select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
 ----
-13
+13.591718553311
 
 
 # log for small decimal128
@@ -896,15 +896,13 @@ select log(10::decimal(38, 0), 100000000000000000000000000000000000::decimal(38,
 query R
 select log(2, 100000000000000000000000000000000000::decimal(38,0));
 ----
-116
+116.267483321058
 
 # log(10^35) for decimal128 with another base
-# TODO: this should be 116.267483321058, error with native decimal log impl
-# https://github.com/apache/datafusion/issues/18524
 query R
 select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
 ----
-116
+116.267483321058
 
 # log with non-integer base (fallback to f64)
 query R
@@ -1036,7 +1034,7 @@ from (values (10.0), (2.0), (3.0)) as t(base);
 query R
 SELECT log(10, arrow_cast(0.5, 'Decimal32(5, 1)'))
 ----
-NaN
+-0.301029995664
 
 query R
 SELECT log(10, arrow_cast(1 , 'Decimal32(5, 1)'))
@@ -1195,18 +1193,18 @@ select 100000000000000000000000000000000000::decimal(38,0)
 99999999999999996863366107917975552
 
 # log(10^35) for decimal128 with explicit decimal base
-# Float parsing is rounding down
+# Float parsing is rounding down, but log uses float computation so result rounds to 35
 query R
 select log(10, 100000000000000000000000000000000000::decimal(38,0));
 ----
-34
+35
 
 # log(10^35) for large decimal128 if parsed as float
-# Float parsing is rounding down
+# Float parsing is rounding down, but log uses float computation so result rounds to 35
 query R
 select log(100000000000000000000000000000000000::decimal(38,0))
 ----
-34
+35
 
 # Result is decimal since argument is decimal regardless decimals-as-floats parsing
 query R