-
Notifications
You must be signed in to change notification settings - Fork 321
feat: add native support for pmod expression #4277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2ed72c6
c90162d
5a31b2c
9c4c367
8fd6aa2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -399,7 +399,7 @@ | |
| - [x] mod | ||
| - [x] negative | ||
| - [x] pi | ||
| - [ ] pmod | ||
| - [x] pmod | ||
| - [x] positive | ||
| - [x] pow | ||
| - [x] power | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| use crate::{create_comet_physical_fun, IfExpr}; | ||||||||||||||||||||||
| use crate::{remainder_by_zero_error, Cast, EvalMode, SparkCastOptions}; | ||||||||||||||||||||||
| use arrow::array::ArrayRef; | ||||||||||||||||||||||
| use arrow::compute::kernels::numeric::rem; | ||||||||||||||||||||||
| use arrow::datatypes::*; | ||||||||||||||||||||||
| use datafusion::common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; | ||||||||||||||||||||||
|
|
@@ -63,6 +64,130 @@ pub fn spark_modulo(args: &[ColumnarValue], fail_on_error: bool) -> Result<Colum | |||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| /// Spark-compliant pmod (positive modulo) function. Returns the positive remainder of division. | ||||||||||||||||||||||
| /// If `fail_on_error` is true, returns an error on division by zero; otherwise returns `NULL`. | ||||||||||||||||||||||
| pub fn spark_pmod(args: &[ColumnarValue], fail_on_error: bool) -> Result<ColumnarValue> { | ||||||||||||||||||||||
| if args.len() != 2 { | ||||||||||||||||||||||
| return exec_err!("pmod expects exactly two arguments"); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let left_data_type = args[0].data_type(); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if left_data_type.is_nested() { | ||||||||||||||||||||||
| return internal_err!("spark_pmod does not support nested types"); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let arrays = ColumnarValue::values_to_arrays(args)?; | ||||||||||||||||||||||
| let left = &arrays[0]; | ||||||||||||||||||||||
| let right = &arrays[1]; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Mirror Spark's `if (r < 0) (r + n) % n else r` so that the non-negative | ||||||||||||||||||||||
| // branch (including `-0.0`) is returned untouched. The right-hand side is | ||||||||||||||||||||||
| // masked to zero on non-negative rows so that `add(result, masked_right)` | ||||||||||||||||||||||
| // never overflows on values whose adjusted branch is discarded by the zip. | ||||||||||||||||||||||
| // | ||||||||||||||||||||||
| // Arrow's `cmp::lt` uses total ordering, where `lt(-0.0, 0.0)` is `true`. | ||||||||||||||||||||||
| // For floats we therefore compare against `-0.0`, which makes the | ||||||||||||||||||||||
| // total-order `lt` align with IEEE 754 (`-0.0` is not strictly less than | ||||||||||||||||||||||
| // `-0.0`, and `NaN` sorts above any number so the comparison is `false`). | ||||||||||||||||||||||
| match (|| -> std::result::Result<ArrayRef, arrow::error::ArrowError> { | ||||||||||||||||||||||
| let to_array = |sv: ScalarValue| { | ||||||||||||||||||||||
| sv.to_array_of_size(left.len()) | ||||||||||||||||||||||
| .map_err(|e| arrow::error::ArrowError::ComputeError(e.to_string())) | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| let new_zero = || { | ||||||||||||||||||||||
| ScalarValue::new_zero(&left_data_type) | ||||||||||||||||||||||
| .map_err(|e| arrow::error::ArrowError::ComputeError(e.to_string())) | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| let zero = to_array(new_zero()?)?; | ||||||||||||||||||||||
| let lt_threshold = match left_data_type { | ||||||||||||||||||||||
| DataType::Float32 => to_array(ScalarValue::Float32(Some(-0.0)))?, | ||||||||||||||||||||||
| DataType::Float64 => to_array(ScalarValue::Float64(Some(-0.0)))?, | ||||||||||||||||||||||
| _ => Arc::clone(&zero), | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| let result = rem(left, right)?; | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could the negative-zero incompatibility be eliminated by branching with The current shape let result = rem(left, right)?;
let neg = arrow::compute::kernels::cmp::lt(&result, &zero)?;
let plus = arrow::compute::kernels::zip::zip(&neg, right, &zero)?;
let result = arrow::compute::kernels::numeric::add(&plus, &result)?;
rem(&result, right)normalizes A direct vectorized translation: let result = rem(left, right)?;
let neg = arrow::compute::kernels::cmp::lt(&result, &zero)?;
let adjusted = rem(&arrow::compute::kernels::numeric::add(&result, right)?, right)?;
arrow::compute::kernels::zip::zip(&neg, &adjusted, &result)That preserves the original
If this works, |
||||||||||||||||||||||
| let neg = arrow::compute::kernels::cmp::lt(&result, <_threshold)?; | ||||||||||||||||||||||
| let masked_right = arrow::compute::kernels::zip::zip(&neg, right, &zero)?; | ||||||||||||||||||||||
| let adjusted = rem( | ||||||||||||||||||||||
| &arrow::compute::kernels::numeric::add(&result, &masked_right)?, | ||||||||||||||||||||||
| right, | ||||||||||||||||||||||
| )?; | ||||||||||||||||||||||
| arrow::compute::kernels::zip::zip(&neg, &adjusted, &result) | ||||||||||||||||||||||
| })() { | ||||||||||||||||||||||
| Ok(result) => Ok(ColumnarValue::Array(result)), | ||||||||||||||||||||||
| Err(e) if e.to_string().contains("Divide by zero") && fail_on_error => { | ||||||||||||||||||||||
| Err(remainder_by_zero_error().into()) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| Err(e) => Err(DataFusionError::ArrowError(Box::new(e), None)), | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pub fn create_pmod_expr( | ||||||||||||||||||||||
| left: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
| right: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
| data_type: DataType, | ||||||||||||||||||||||
| input_schema: SchemaRef, | ||||||||||||||||||||||
| fail_on_error: bool, | ||||||||||||||||||||||
| registry: &dyn FunctionRegistry, | ||||||||||||||||||||||
| ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> { | ||||||||||||||||||||||
| let right_non_ansi_safe = if !fail_on_error { | ||||||||||||||||||||||
| null_if_zero_primitive(right, &input_schema)? | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| right | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| match ( | ||||||||||||||||||||||
| left.data_type(&input_schema), | ||||||||||||||||||||||
| right_non_ansi_safe.data_type(&input_schema), | ||||||||||||||||||||||
| ) { | ||||||||||||||||||||||
| (Ok(DataType::Decimal128(p1, s1)), Ok(DataType::Decimal128(p2, s2))) | ||||||||||||||||||||||
| if max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) > DECIMAL128_MAX_PRECISION => | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| let left_256 = Arc::new(Cast::new( | ||||||||||||||||||||||
| left, | ||||||||||||||||||||||
| DataType::Decimal256(p1, s1), | ||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| )); | ||||||||||||||||||||||
| let right_256 = Arc::new(Cast::new( | ||||||||||||||||||||||
| right_non_ansi_safe, | ||||||||||||||||||||||
| DataType::Decimal256(p2, s2), | ||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| )); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let decimal256_return_type = match &data_type { | ||||||||||||||||||||||
| DataType::Decimal128(p, s) => DataType::Decimal256(*p, *s), | ||||||||||||||||||||||
| other => other.clone(), | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| let pmod_scalar_func = create_pmod_scalar_function( | ||||||||||||||||||||||
| left_256, | ||||||||||||||||||||||
| right_256, | ||||||||||||||||||||||
| &decimal256_return_type, | ||||||||||||||||||||||
| registry, | ||||||||||||||||||||||
| fail_on_error, | ||||||||||||||||||||||
| )?; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Ok(Arc::new(Cast::new( | ||||||||||||||||||||||
| pmod_scalar_func, | ||||||||||||||||||||||
| data_type, | ||||||||||||||||||||||
| SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| None, | ||||||||||||||||||||||
| ))) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| _ => create_pmod_scalar_function( | ||||||||||||||||||||||
| left, | ||||||||||||||||||||||
| right_non_ansi_safe, | ||||||||||||||||||||||
| &data_type, | ||||||||||||||||||||||
| registry, | ||||||||||||||||||||||
| fail_on_error, | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pub fn create_modulo_expr( | ||||||||||||||||||||||
| left: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
| right: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
|
|
@@ -212,6 +337,25 @@ fn create_modulo_scalar_function( | |||||||||||||||||||||
| ))) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| fn create_pmod_scalar_function( | ||||||||||||||||||||||
| left: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
| right: Arc<dyn PhysicalExpr>, | ||||||||||||||||||||||
| data_type: &DataType, | ||||||||||||||||||||||
| registry: &dyn FunctionRegistry, | ||||||||||||||||||||||
| fail_on_error: bool, | ||||||||||||||||||||||
| ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> { | ||||||||||||||||||||||
| let func_name = "spark_pmod"; | ||||||||||||||||||||||
| let pmod_expr = | ||||||||||||||||||||||
| create_comet_physical_fun(func_name, data_type.clone(), registry, Some(fail_on_error))?; | ||||||||||||||||||||||
| Ok(Arc::new(ScalarFunctionExpr::new( | ||||||||||||||||||||||
| func_name, | ||||||||||||||||||||||
| pmod_expr, | ||||||||||||||||||||||
| vec![left, right], | ||||||||||||||||||||||
| Arc::new(Field::new(func_name, data_type.clone(), true)), | ||||||||||||||||||||||
| Arc::new(ConfigOptions::default()), | ||||||||||||||||||||||
| ))) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| #[cfg(test)] | ||||||||||||||||||||||
| mod tests { | ||||||||||||||||||||||
| use super::*; | ||||||||||||||||||||||
|
|
@@ -223,6 +367,28 @@ mod tests { | |||||||||||||||||||||
| use datafusion::physical_expr::expressions::{Column, Literal}; | ||||||||||||||||||||||
| use datafusion::prelude::SessionContext; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| #[test] | ||||||||||||||||||||||
| fn spark_pmod_preserves_negative_zero_f64() { | ||||||||||||||||||||||
| use arrow::array::Float64Array; | ||||||||||||||||||||||
| let a: ArrayRef = Arc::new(Float64Array::from(vec![-0.0_f64])); | ||||||||||||||||||||||
| let b: ArrayRef = Arc::new(Float64Array::from(vec![3.0_f64])); | ||||||||||||||||||||||
| let out = spark_pmod(&[ColumnarValue::Array(a), ColumnarValue::Array(b)], false).unwrap(); | ||||||||||||||||||||||
| let arr = out.into_array(1).unwrap(); | ||||||||||||||||||||||
| let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap(); | ||||||||||||||||||||||
| assert_eq!(arr.value(0).to_bits(), (-0.0_f64).to_bits()); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| #[test] | ||||||||||||||||||||||
| fn spark_pmod_preserves_negative_zero_f32() { | ||||||||||||||||||||||
| use arrow::array::Float32Array; | ||||||||||||||||||||||
| let a: ArrayRef = Arc::new(Float32Array::from(vec![-0.0_f32])); | ||||||||||||||||||||||
| let b: ArrayRef = Arc::new(Float32Array::from(vec![3.0_f32])); | ||||||||||||||||||||||
| let out = spark_pmod(&[ColumnarValue::Array(a), ColumnarValue::Array(b)], false).unwrap(); | ||||||||||||||||||||||
| let arr = out.into_array(1).unwrap(); | ||||||||||||||||||||||
| let arr = arr.as_any().downcast_ref::<Float32Array>().unwrap(); | ||||||||||||||||||||||
| assert_eq!(arr.value(0).to_bits(), (-0.0_f32).to_bits()); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| fn with_fail_on_error<F: Fn(bool)>(test_fn: F) { | ||||||||||||||||||||||
| for fail_on_error in [true, false] { | ||||||||||||||||||||||
| test_fn(fail_on_error); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,7 @@ package org.apache.comet.serde | |
|
|
||
| import scala.math.min | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract, UnaryMinus} | ||
| import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Pmod, Remainder, Round, Subtract, UnaryMinus} | ||
| import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType} | ||
|
|
||
| import org.apache.comet.CometSparkSessionExtensions.withInfo | ||
|
|
@@ -284,6 +284,33 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase { | |
| } | ||
| } | ||
|
|
||
| object CometPmod extends CometExpressionSerde[Pmod] with MathBase { | ||
|
|
||
| override def convert( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The PR description says "rejecting TRY eval mode", but There's no |
||
| expr: Pmod, | ||
| inputs: Seq[Attribute], | ||
| binding: Boolean): Option[ExprOuterClass.Expr] = { | ||
| if (!supportedDataType(expr.left.dataType)) { | ||
| withInfo(expr, s"Unsupported datatype ${expr.left.dataType}") | ||
| return None | ||
| } | ||
| if (expr.evalMode == EvalMode.TRY) { | ||
| withInfo(expr, s"Eval mode ${expr.evalMode} is not supported") | ||
| return None | ||
| } | ||
|
|
||
| createMathExpression( | ||
| expr, | ||
| expr.left, | ||
| expr.right, | ||
| inputs, | ||
| binding, | ||
| expr.dataType, | ||
| expr.evalMode, | ||
| (builder, mathExpr) => builder.setPmod(mathExpr)) | ||
| } | ||
| } | ||
|
|
||
| object CometRound extends CometExpressionSerde[Round] { | ||
|
|
||
| override def convert( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The nested-type branch dispatches to
apply_cmp_for_nested(Operator::Modulo, ...), but two things look off:Pmodis numeric-only in Spark, andCometPmod.supportedDataTypealready rejects non-numeric types. This branch should be unreachable.Suggest dropping the whole
if left_data_type.is_nested()block, or replacing the body withinternal_err!("spark_pmod does not support nested types")to make the invariant explicit.