Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 42 additions & 64 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,80 +109,59 @@ fn is_valid_integer_base(base: f64) -> bool {
}

/// Calculate logarithm for Decimal32 values.
/// For integer bases >= 2 with non-negative scale, uses the efficient u32 ilog algorithm.
/// Otherwise falls back to f64 computation.
/// For integer bases >= 2 with zero scale, return an exact integer log when the
/// value is a perfect power of the base. Otherwise falls back to f64 computation.
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
if is_valid_integer_base(base)
&& scale >= 0
&& let Some(unscaled) = unscale_to_u32(value, scale)
if scale == 0
&& is_valid_integer_base(base)
&& let Ok(unscaled) = u32::try_from(value)
&& unscaled > 0
{
return if unscaled > 0 {
Ok(unscaled.ilog(base as u32) as f64)
} else {
Ok(f64::NAN)
};
let base_u32 = base as u32;
let int_log = unscaled.ilog(base_u32);
if base_u32.checked_pow(int_log) == Some(unscaled) {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
}

/// Calculate logarithm for Decimal64 values.
/// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm.
/// Otherwise falls back to f64 computation.
/// For integer bases >= 2 with zero scale, return an exact integer log when the
/// value is a perfect power of the base. Otherwise falls back to f64 computation.
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
if is_valid_integer_base(base)
&& scale >= 0
&& let Some(unscaled) = unscale_to_u64(value, scale)
if scale == 0
&& is_valid_integer_base(base)
&& let Ok(unscaled) = u64::try_from(value)
&& unscaled > 0
{
return if unscaled > 0 {
Ok(unscaled.ilog(base as u64) as f64)
} else {
Ok(f64::NAN)
};
let base_u64 = base as u64;
let int_log = unscaled.ilog(base_u64);
if base_u64.checked_pow(int_log) == Some(unscaled) {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
}

/// Calculate logarithm for Decimal128 values.
/// For integer bases >= 2 with non-negative scale, uses the efficient u128 ilog algorithm.
/// Otherwise falls back to f64 computation.
/// For integer bases >= 2 with zero scale, return an exact integer log when the
/// value is a perfect power of the base. Otherwise falls back to f64 computation.
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
if is_valid_integer_base(base)
&& scale >= 0
&& let Some(unscaled) = unscale_to_u128(value, scale)
if scale == 0
&& is_valid_integer_base(base)
&& let Ok(unscaled) = u128::try_from(value)
&& unscaled > 0
{
return if unscaled > 0 {
Ok(unscaled.ilog(base as u128) as f64)
} else {
Ok(f64::NAN)
};
let base_u128 = base as u128;
let int_log = unscaled.ilog(base_u128);
if base_u128.checked_pow(int_log) == Some(unscaled) {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
}

/// Unscale a Decimal32 value to u32.
#[inline]
fn unscale_to_u32(value: i32, scale: i8) -> Option<u32> {
let value_u32 = u32::try_from(value).ok()?;
let divisor = 10u32.checked_pow(scale as u32)?;
Some(value_u32 / divisor)
}

/// Unscale a Decimal64 value to u64.
#[inline]
fn unscale_to_u64(value: i64, scale: i8) -> Option<u64> {
let value_u64 = u64::try_from(value).ok()?;
let divisor = 10u64.checked_pow(scale as u32)?;
Some(value_u64 / divisor)
}

/// Unscale a Decimal128 value to u128.
#[inline]
fn unscale_to_u128(value: i128, scale: i8) -> Option<u128> {
let value_u128 = u128::try_from(value).ok()?;
let divisor = 10u128.checked_pow(scale as u32)?;
Some(value_u128 / divisor)
}

/// Convert a scaled decimal value to f64.
#[inline]
fn decimal_to_f64<T: ToPrimitive + Copy>(value: T, scale: i8) -> Result<f64, ArrowError> {
Expand Down Expand Up @@ -444,13 +423,9 @@ mod tests {
#[test]
fn test_log_decimal_native() {
let value = 10_i128.pow(35);
assert_eq!((value as f64).log2(), 116.26748332105768);
assert_eq!(
log_decimal128(value, 0, 2.0).unwrap(),
// TODO: see we're losing our decimal points compared to above
// https://github.com/apache/datafusion/issues/18524
116.0
);
let expected = (value as f64).log2();
let actual = log_decimal128(value, 0, 2.0).unwrap();
assert!((actual - expected).abs() < 1e-10);
}

#[test]
Expand Down Expand Up @@ -1012,7 +987,8 @@ mod tests {
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding
let expected = 12600_f64.log(10.0);
assert!((floats.value(4) - expected).abs() < 1e-10);
assert!(floats.value(5).is_nan());
}
ColumnarValue::Scalar(_) => {
Expand Down Expand Up @@ -1147,8 +1123,10 @@ mod tests {
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log
assert!((floats.value(5) - 38.0).abs() < 1e-10);
let expected = 12600_f64.log(10.0);
assert!((floats.value(4) - expected).abs() < 1e-10);
let expected = ((i128::MAX - 1000) as f64).log(10.0);
assert!((floats.value(5) - expected).abs() < 1e-10);
assert!(floats.value(6).is_nan());
}
ColumnarValue::Scalar(_) => {
Expand Down
20 changes: 9 additions & 11 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ select log(arrow_cast(100, 'Decimal32(9, 2)'));
query R
select log(2.0, arrow_cast(12345.67, 'Decimal32(9, 2)'));
----
13
13.591717513272

# log for small decimal64
query R
Expand All @@ -820,7 +820,7 @@ select log(arrow_cast(100, 'Decimal64(18, 2)'));
query R
select log(2.0, arrow_cast(12345.6789, 'Decimal64(15, 4)'));
----
13
13.591718553311


# log for small decimal128
Expand Down Expand Up @@ -896,15 +896,13 @@ select log(10::decimal(38, 0), 100000000000000000000000000000000000::decimal(38,
query R
select log(2, 100000000000000000000000000000000000::decimal(38,0));
----
116
116.267483321058

# log(10^35) for decimal128 with another base
# TODO: this should be 116.267483321058, error with native decimal log impl
# https://github.com/apache/datafusion/issues/18524
query R
select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
----
116
116.267483321058

# log with non-integer base (fallback to f64)
query R
Expand Down Expand Up @@ -1036,7 +1034,7 @@ from (values (10.0), (2.0), (3.0)) as t(base);
query R
SELECT log(10, arrow_cast(0.5, 'Decimal32(5, 1)'))
----
NaN
-0.301029995664

query R
SELECT log(10, arrow_cast(1 , 'Decimal32(5, 1)'))
Expand Down Expand Up @@ -1195,18 +1193,18 @@ select 100000000000000000000000000000000000::decimal(38,0)
99999999999999996863366107917975552

# log(10^35) for decimal128 with explicit decimal base
# Float parsing is rounding down
# Float parsing is rounding down, but log uses float computation so result rounds to 35
query R
select log(10, 100000000000000000000000000000000000::decimal(38,0));
----
34
35

# log(10^35) for large decimal128 if parsed as float
# Float parsing is rounding down
# Float parsing is rounding down, but log uses float computation so result rounds to 35
query R
select log(100000000000000000000000000000000000::decimal(38,0))
----
34
35

# Result is decimal since argument is decimal regardless decimals-as-floats parsing
query R
Expand Down