Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/contributor-guide/spark_expressions_support.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@
- [x] mod
- [x] negative
- [x] pi
- [ ] pmod
- [x] pmod
- [x] positive
- [x] pow
- [x] power
Expand Down
34 changes: 33 additions & 1 deletion native/core/src/execution/expressions/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ use std::sync::Arc;
use arrow::datatypes::SchemaRef;
use datafusion::logical_expr::Operator as DataFusionOperator;
use datafusion_comet_proto::spark_expression::Expr;
use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode};
use datafusion_comet_spark_expr::{
create_modulo_expr, create_negate_expr, create_pmod_expr, EvalMode,
};

use crate::execution::{
expressions::extract_expr,
Expand Down Expand Up @@ -255,6 +257,36 @@ impl ExpressionBuilder for RemainderBuilder {
}
}

/// Builder for Pmod expressions (uses special pmod function)
pub struct PmodBuilder;

impl ExpressionBuilder for PmodBuilder {
    /// Converts a serialized Spark `Pmod` node into a native physical expression.
    ///
    /// Both operands are planned against `input_schema` and handed to
    /// `create_pmod_expr` together with the declared return type. ANSI eval
    /// mode is forwarded as the `fail_on_error` flag.
    ///
    /// # Errors
    ///
    /// Returns an `ExecutionError` when the eval mode cannot be decoded, when
    /// the message lacks an operand or its return type, when planning an
    /// operand fails, or when `create_pmod_expr` fails.
    fn build(
        &self,
        spark_expr: &Expr,
        input_schema: SchemaRef,
        planner: &PhysicalPlanner,
    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
        let expr = extract_expr!(spark_expr, Pmod);
        let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;

        // The operand and return-type fields are optional in the decoded
        // message; report a malformed plan as an error instead of panicking.
        let missing = |what: &str| {
            ExecutionError::GeneralError(format!("Pmod expression is missing its {}", what))
        };

        let left_proto = expr
            .left
            .as_ref()
            .ok_or_else(|| missing("left operand"))?;
        let left = planner.create_expr(left_proto, Arc::clone(&input_schema))?;

        let right_proto = expr
            .right
            .as_ref()
            .ok_or_else(|| missing("right operand"))?;
        let right = planner.create_expr(right_proto, Arc::clone(&input_schema))?;

        let return_type = expr
            .return_type
            .as_ref()
            .map(crate::execution::serde::to_arrow_datatype)
            .ok_or_else(|| missing("return type"))?;

        create_pmod_expr(
            left,
            right,
            return_type,
            input_schema,
            eval_mode == EvalMode::Ansi,
            &planner.session_ctx().state(),
        )
        .map_err(|e| ExecutionError::GeneralError(e.to_string()))
    }
}

/// Builder for UnaryMinus expressions (uses special negate function)
pub struct UnaryMinusBuilder;

Expand Down
4 changes: 4 additions & 0 deletions native/core/src/execution/planner/expression_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub enum ExpressionType {
Divide,
IntegralDivide,
Remainder,
Pmod,
UnaryMinus,

// Comparison expressions
Expand Down Expand Up @@ -212,6 +213,8 @@ impl ExpressionRegistry {
);
self.builders
.insert(ExpressionType::Remainder, Box::new(RemainderBuilder));
self.builders
.insert(ExpressionType::Pmod, Box::new(PmodBuilder));
self.builders
.insert(ExpressionType::UnaryMinus, Box::new(UnaryMinusBuilder));
}
Expand Down Expand Up @@ -327,6 +330,7 @@ impl ExpressionRegistry {
Some(ExprStruct::Divide(_)) => Ok(ExpressionType::Divide),
Some(ExprStruct::IntegralDivide(_)) => Ok(ExpressionType::IntegralDivide),
Some(ExprStruct::Remainder(_)) => Ok(ExpressionType::Remainder),
Some(ExprStruct::Pmod(_)) => Ok(ExpressionType::Pmod),
Some(ExprStruct::UnaryMinus(_)) => Ok(ExpressionType::UnaryMinus),

Some(ExprStruct::Eq(_)) => Ok(ExpressionType::Eq),
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ message Expr {
HoursTransform hours_transform = 68;
ArraysZip arrays_zip = 69;
JvmScalarUdf jvm_scalar_udf = 70;
MathExpr pmod = 71;
}

// Optional QueryContext for error reporting (contains SQL text and position)
Expand Down
6 changes: 5 additions & 1 deletion native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::map_funcs::spark_map_sort;
use crate::math_funcs::abs::abs;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
use crate::math_funcs::log::spark_log;
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::math_funcs::modulo_expr::{spark_modulo, spark_pmod};
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
Expand Down Expand Up @@ -176,6 +176,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_modulo);
make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
}
"spark_pmod" => {
let func = Arc::new(spark_pmod);
make_comet_scalar_udf!("spark_pmod", func, without data_type, fail_on_error)
}
"abs" => {
let func = Arc::new(abs);
make_comet_scalar_udf!("abs", func, without data_type)
Expand Down
2 changes: 1 addition & 1 deletion native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, Spark
pub use hash_funcs::*;
pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
create_modulo_expr, create_negate_expr, create_pmod_expr, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal, spark_round,
spark_unhex, spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
Expand Down
1 change: 1 addition & 0 deletions native/spark-expr/src/math_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub use floor::spark_floor;
pub use internal::*;
pub use log::spark_log;
pub use modulo_expr::create_modulo_expr;
pub use modulo_expr::create_pmod_expr;
pub use negative::{create_negate_expr, NegativeExpr};
pub use round::spark_round;
pub use unhex::spark_unhex;
Expand Down
166 changes: 166 additions & 0 deletions native/spark-expr/src/math_funcs/modulo_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::{create_comet_physical_fun, IfExpr};
use crate::{remainder_by_zero_error, Cast, EvalMode, SparkCastOptions};
use arrow::array::ArrayRef;
use arrow::compute::kernels::numeric::rem;
use arrow::datatypes::*;
use datafusion::common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
Expand Down Expand Up @@ -63,6 +64,130 @@ pub fn spark_modulo(args: &[ColumnarValue], fail_on_error: bool) -> Result<Colum
}
}

/// Spark-compliant pmod (positive modulo) function. Returns the positive remainder of division.
/// If `fail_on_error` is true, returns an error on division by zero; otherwise returns `NULL`.
pub fn spark_pmod(args: &[ColumnarValue], fail_on_error: bool) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!("pmod expects exactly two arguments");
}

let left_data_type = args[0].data_type();

if left_data_type.is_nested() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested-type branch dispatches to apply_cmp_for_nested(Operator::Modulo, ...), but two things look off:

  1. Pmod is numeric-only in Spark, and CometPmod.supportedDataType already rejects non-numeric types. This branch should be unreachable.
  2. Even if it were reached, it would compute Modulo, not Pmod.

Suggest dropping the whole if left_data_type.is_nested() block, or replacing the body with internal_err!("spark_pmod does not support nested types") to make the invariant explicit.

return internal_err!("spark_pmod does not support nested types");
}

let arrays = ColumnarValue::values_to_arrays(args)?;
let left = &arrays[0];
let right = &arrays[1];

// Mirror Spark's `if (r < 0) (r + n) % n else r` so that the non-negative
// branch (including `-0.0`) is returned untouched. The right-hand side is
// masked to zero on non-negative rows so that `add(result, masked_right)`
// never overflows on values whose adjusted branch is discarded by the zip.
//
// Arrow's `cmp::lt` uses total ordering, where `lt(-0.0, 0.0)` is `true`.
// For floats we therefore compare against `-0.0`, which keeps `-0.0` on the
// untouched branch (`-0.0` is not strictly less than `-0.0`). A `NaN` with
// the sign bit clear sorts above every number, so the comparison is `false`;
// a `NaN` with the sign bit set sorts below every number and takes the
// adjusted branch, which still evaluates to `NaN`.
match (|| -> std::result::Result<ArrayRef, arrow::error::ArrowError> {
let to_array = |sv: ScalarValue| {
sv.to_array_of_size(left.len())
.map_err(|e| arrow::error::ArrowError::ComputeError(e.to_string()))
};
let new_zero = || {
ScalarValue::new_zero(&left_data_type)
.map_err(|e| arrow::error::ArrowError::ComputeError(e.to_string()))
};
let zero = to_array(new_zero()?)?;
let lt_threshold = match left_data_type {
DataType::Float32 => to_array(ScalarValue::Float32(Some(-0.0)))?,
DataType::Float64 => to_array(ScalarValue::Float64(Some(-0.0)))?,
_ => Arc::clone(&zero),
};
let result = rem(left, right)?;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the negative-zero incompatibility be eliminated by branching with zip instead of doing an unconditional add/rem?

The current shape

let result = rem(left, right)?;
let neg = arrow::compute::kernels::cmp::lt(&result, &zero)?;
let plus = arrow::compute::kernels::zip::zip(&neg, right, &zero)?;
let result = arrow::compute::kernels::numeric::add(&plus, &result)?;
rem(&result, right)

normalizes -0.0 through add(+0.0, -0.0) = +0.0, which is the source of the documented strict-floating-point caveat. Spark's reference implementation is if (r < 0) (r + n) % n else r, i.e. the non-negative branch returns r untouched.

A direct vectorized translation:

let result = rem(left, right)?;
let neg = arrow::compute::kernels::cmp::lt(&result, &zero)?;
let adjusted = rem(&arrow::compute::kernels::numeric::add(&result, right)?, right)?;
arrow::compute::kernels::zip::zip(&neg, &adjusted, &result)

That preserves the original r for r >= 0, including -0.0, and matches Spark exactly. Quick sanity check on the interesting inputs:

input r neg result
pmod(-0.0, 3.0) -0.0 false -0.0
pmod(NaN, 3.0) NaN false (NaN cmp) NaN
pmod(-10, 3) -1 true 2
pmod(-10, -3) -1 true -1

If this works, CometPmod can drop getIncompatibleReasons / getSupportLevel and just be Compatible().

let neg = arrow::compute::kernels::cmp::lt(&result, &lt_threshold)?;
let masked_right = arrow::compute::kernels::zip::zip(&neg, right, &zero)?;
let adjusted = rem(
&arrow::compute::kernels::numeric::add(&result, &masked_right)?,
right,
)?;
arrow::compute::kernels::zip::zip(&neg, &adjusted, &result)
})() {
Ok(result) => Ok(ColumnarValue::Array(result)),
Err(e) if e.to_string().contains("Divide by zero") && fail_on_error => {
Err(remainder_by_zero_error().into())
}
Err(e) => Err(DataFusionError::ArrowError(Box::new(e), None)),
}
}

/// Builds the physical expression that evaluates Spark's `pmod(left, right)`.
///
/// When `fail_on_error` is `false` the divisor is first passed through
/// `null_if_zero_primitive`; otherwise it is used as-is. If both operands are
/// `Decimal128` and `max(s1, s2) + max(p1 - s1, p2 - s2)` exceeds
/// `DECIMAL128_MAX_PRECISION`, the operands are cast to `Decimal256`, the pmod
/// function is evaluated at that width, and the result is cast back to
/// `data_type`. Every other operand combination invokes the pmod function
/// directly on the operands.
pub fn create_pmod_expr(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
    data_type: DataType,
    input_schema: SchemaRef,
    fail_on_error: bool,
    registry: &dyn FunctionRegistry,
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
    // Legacy-mode, timezone-less cast shared by the widening and narrowing steps.
    let legacy_cast = |child: Arc<dyn PhysicalExpr>, target: DataType| -> Arc<dyn PhysicalExpr> {
        Arc::new(Cast::new(
            child,
            target,
            SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
            None,
            None,
        ))
    };

    let divisor = if fail_on_error {
        right
    } else {
        null_if_zero_primitive(right, &input_schema)?
    };

    let operand_types = (
        left.data_type(&input_schema),
        divisor.data_type(&input_schema),
    );
    if let (Ok(DataType::Decimal128(p1, s1)), Ok(DataType::Decimal128(p2, s2))) = operand_types {
        let needs_widening =
            max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) > DECIMAL128_MAX_PRECISION;
        if needs_widening {
            // Evaluate at 256-bit width, keeping the declared precision and scale.
            let wide_return_type = match &data_type {
                DataType::Decimal128(p, s) => DataType::Decimal256(*p, *s),
                other => other.clone(),
            };
            let wide_pmod = create_pmod_scalar_function(
                legacy_cast(left, DataType::Decimal256(p1, s1)),
                legacy_cast(divisor, DataType::Decimal256(p2, s2)),
                &wide_return_type,
                registry,
                fail_on_error,
            )?;
            return Ok(legacy_cast(wide_pmod, data_type));
        }
    }

    create_pmod_scalar_function(left, divisor, &data_type, registry, fail_on_error)
}

pub fn create_modulo_expr(
left: Arc<dyn PhysicalExpr>,
right: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -212,6 +337,25 @@ fn create_modulo_scalar_function(
)))
}

/// Wraps the `spark_pmod` Comet function in a `ScalarFunctionExpr` over
/// `left` and `right`.
///
/// The function is looked up through `create_comet_physical_fun` with
/// `fail_on_error` attached, and the resulting expression reports a nullable
/// field of type `data_type`.
fn create_pmod_scalar_function(
    left: Arc<dyn PhysicalExpr>,
    right: Arc<dyn PhysicalExpr>,
    data_type: &DataType,
    registry: &dyn FunctionRegistry,
    fail_on_error: bool,
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
    const PMOD_FUNC_NAME: &str = "spark_pmod";

    let udf = create_comet_physical_fun(
        PMOD_FUNC_NAME,
        data_type.clone(),
        registry,
        Some(fail_on_error),
    )?;
    let return_field = Arc::new(Field::new(PMOD_FUNC_NAME, data_type.clone(), true));
    let config = Arc::new(ConfigOptions::default());

    let call = ScalarFunctionExpr::new(
        PMOD_FUNC_NAME,
        udf,
        vec![left, right],
        return_field,
        config,
    );
    Ok(Arc::new(call))
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -223,6 +367,28 @@ mod tests {
use datafusion::physical_expr::expressions::{Column, Literal};
use datafusion::prelude::SessionContext;

#[test]
fn spark_pmod_preserves_negative_zero_f64() {
    use arrow::array::Float64Array;

    // pmod(-0.0, 3.0) must keep the sign bit of the zero remainder.
    let dividend: ArrayRef = Arc::new(Float64Array::from(vec![-0.0_f64]));
    let divisor: ArrayRef = Arc::new(Float64Array::from(vec![3.0_f64]));
    let args = [
        ColumnarValue::Array(dividend),
        ColumnarValue::Array(divisor),
    ];

    let result = spark_pmod(&args, false).unwrap().into_array(1).unwrap();
    let values = result.as_any().downcast_ref::<Float64Array>().unwrap();

    // Compare bit patterns: `-0.0 == 0.0` would hide a lost sign.
    assert_eq!(values.value(0).to_bits(), (-0.0_f64).to_bits());
}

#[test]
fn spark_pmod_preserves_negative_zero_f32() {
    use arrow::array::Float32Array;

    // pmod(-0.0, 3.0) must keep the sign bit of the zero remainder.
    let dividend: ArrayRef = Arc::new(Float32Array::from(vec![-0.0_f32]));
    let divisor: ArrayRef = Arc::new(Float32Array::from(vec![3.0_f32]));
    let args = [
        ColumnarValue::Array(dividend),
        ColumnarValue::Array(divisor),
    ];

    let result = spark_pmod(&args, false).unwrap().into_array(1).unwrap();
    let values = result.as_any().downcast_ref::<Float32Array>().unwrap();

    // Compare bit patterns: `-0.0 == 0.0` would hide a lost sign.
    assert_eq!(values.value(0).to_bits(), (-0.0_f32).to_bits());
}

fn with_fail_on_error<F: Fn(bool)>(test_fn: F) {
for fail_on_error in [true, false] {
test_fn(fail_on_error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[Logarithm] -> CometLogarithm,
classOf[Multiply] -> CometMultiply,
classOf[Pi] -> CometScalarFunction("pi"),
classOf[Pmod] -> CometPmod,
classOf[Pow] -> CometScalarFunction("pow"),
classOf[Rand] -> CometRand,
classOf[Randn] -> CometRandn,
Expand Down
29 changes: 28 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.math.min

import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Round, Subtract, UnaryMinus}
import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Cast, Divide, EmptyRow, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Pmod, Remainder, Round, Subtract, UnaryMinus}
import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
Expand Down Expand Up @@ -284,6 +284,33 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
}
}

object CometPmod extends CometExpressionSerde[Pmod] with MathBase {

override def convert(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description says "rejecting TRY eval mode", but CometPmod.convert doesn't have the expr.evalMode == EvalMode.TRY guard that CometRemainder has at line 270.

There's no try_pmod in Spark's function registry, so the gap is unlikely to be reachable today. But matching CometRemainder's shape keeps the serde defensive in case Pmod ever gets constructed with a TRY context, and aligns the code with the PR description.

expr: Pmod,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
if (!supportedDataType(expr.left.dataType)) {
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
if (expr.evalMode == EvalMode.TRY) {
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
return None
}

createMathExpression(
expr,
expr.left,
expr.right,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setPmod(mathExpr))
}
}

object CometRound extends CometExpressionSerde[Round] {

override def convert(
Expand Down
Loading
Loading