Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2587,43 +2587,47 @@ object DecimalAggregates extends Rule[LogicalPlan] {
// Window arm: `ExtractWindowExpressions` hoists composite children
// (here the widening Cast) into a child Project, so widened-Cast
// peel is unreachable from this expression-level rule.
case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction =
ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e)))),
prec + 10, scale)

case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
case a @ Average(e @ DecimalExpression(prec, scale), _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
we.copy(windowFunction = ae.copy(aggregateFunction =
a.copy(child = UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))

case _ => we
}
case ae @ AggregateExpression(af, _, _, _, _) => af match {
case Sum(WidenedDecimalChild(inner, p, pPrime, s), _)
case s @ Sum(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this bug affect the above case Sum(e @ DecimalExpression(prec, scale), _) and case Average(e @ DecimalExpression(prec, scale), _) branches?

if p + 10 <= MAX_LONG_DIGITS =>
Cast(
MakeDecimal(
ae.copy(aggregateFunction = Sum(UnscaledValue(inner))),
p + 10, s),
DecimalType.bounded(pPrime + 10, s),
ae.copy(aggregateFunction = s.copy(child = UnscaledValue(inner))),
p + 10, s_scale),
DecimalType.bounded(pPrime + 10, s_scale),
Option(conf.sessionLocalTimeZone))

case Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
case s @ Sum(e @ DecimalExpression(prec, scale), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = s.copy(child = UnscaledValue(e))),
prec + 10, scale)

// Ordered before the un-widened Average arm: when pPrime in [8, 11],
// the outer Cast's DecimalType would otherwise match that arm first.
case Average(WidenedDecimalChild(inner, p, pPrime, s), _)
case a @ Average(WidenedDecimalChild(inner, p, pPrime, s_scale), _)
if p <= AVG_PEEL_MAX_INNER_PRECISION =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(inner)))
val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(inner)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, s), DoubleType)),
DecimalType.bounded(pPrime + 4, s + 4), Option(conf.sessionLocalTimeZone))
Divide(newAggExpr, Literal.create(math.pow(10.0, s_scale), DoubleType)),
DecimalType.bounded(pPrime + 4, s_scale + 4), Option(conf.sessionLocalTimeZone))

case Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
case a @ Average(e @ DecimalExpression(prec, scale), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = a.copy(child = UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -582,4 +582,97 @@ class DecimalAggregatesSuite extends PlanTest with ScalaCheckDrivenPropertyCheck
val correctAnswer = q.analyze
comparePlans(optimized, correctAnswer)
}

// ---------------------------------------------------------------------------
// SPARK-56949: DecimalAggregates must preserve evalMode / evalContext when
// rewriting Sum / Average through the fast-path. The pre-fix rule called the
// single-arg helper ctor `Sum(child)` / `Average(child)`, which re-reads
// EvalMode from SQLConf and silently drops EvalMode.TRY from try_sum /
// try_avg, breaking their "return NULL on overflow" semantics.
//
// Vanilla 3.5.3 ground-truth (rule OFF vs ON) recorded in todos repo:
// features/spark-decimal-aggregate-evalmode-preserve/docs/0001-idea.md (section 3)
// ---------------------------------------------------------------------------

private def findSum(plan: LogicalPlan): Seq[Sum] =
plan.collect { case n => n.expressions }.flatten
.flatMap(_.collect { case s: Sum => s })
private def findAverage(plan: LogicalPlan): Seq[Average] =
plan.collect { case n => n.expressions }.flatten
.flatMap(_.collect { case a: Average => a })

test("SPARK-56949: DecimalAggregates preserves Sum.evalContext for try_sum") {
val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
val q = testRelation.select(trySum.toAggregateExpression().as("ts"))
val optimized = Optimize.execute(q.analyze)
val sums = findSum(optimized)
assert(sums.nonEmpty, "DecimalAggregates fast path should fire for dec(2,1)")
assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
sums.map(_.evalContext.evalMode).mkString(","))
}

test("SPARK-56949: DecimalAggregates preserves Average.evalMode for try_avg") {
val tryAvg = Average($"a", EvalMode.TRY)
val q = testRelation.select(tryAvg.toAggregateExpression().as("ta"))
val optimized = Optimize.execute(q.analyze)
val avgs = findAverage(optimized)
assert(avgs.nonEmpty, "DecimalAggregates fast path should fire for dec(2,1)")
assert(avgs.forall(_.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
avgs.map(_.evalMode).mkString(","))
}

test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
"for try_sum on widened-cast peel arm") {
val trySum = Sum($"d7_2".cast(DecimalType(12, 2)),
NumericEvalContext(EvalMode.TRY))
val q = widenRel.select(trySum.toAggregateExpression().as("ts"))
val optimized = Optimize.execute(q.analyze)
val sums = findSum(optimized)
assert(sums.nonEmpty, "widened-cast SUM peel should fire for dec(7,2)->dec(12,2)")
assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
sums.map(_.evalContext.evalMode).mkString(","))
}

test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
"for try_avg on widened-cast peel arm") {
val tryAvg = Average($"d7_2".cast(DecimalType(12, 2)), EvalMode.TRY)
val q = widenRel.select(tryAvg.toAggregateExpression().as("ta"))
val optimized = Optimize.execute(q.analyze)
val avgs = findAverage(optimized)
assert(avgs.nonEmpty, "widened-cast AVG peel should fire for dec(7,2)->dec(12,2)")
assert(avgs.forall(_.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
avgs.map(_.evalMode).mkString(","))
}

test("SPARK-56949: DecimalAggregates preserves Sum.evalContext " +
"for try_sum over Window (un-widened arm)") {
val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
val trySum = Sum($"a", NumericEvalContext(EvalMode.TRY))
val q = testRelation.select(
windowExpr(trySum.toAggregateExpression(), spec).as("ts"))
val optimized = Optimize.execute(q.analyze)
val sums = findSum(optimized)
assert(sums.nonEmpty, "Window-arm SUM peel should fire for dec(2,1)")
assert(sums.forall(_.evalContext.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
sums.map(_.evalContext.evalMode).mkString(","))
}

test("SPARK-56949: DecimalAggregates preserves Average.evalMode " +
"for try_avg over Window (un-widened arm)") {
val spec = windowSpec(Seq($"a"), Nil, UnspecifiedFrame)
val tryAvg = Average($"a", EvalMode.TRY)
val q = testRelation.select(
windowExpr(tryAvg.toAggregateExpression(), spec).as("ta"))
val optimized = Optimize.execute(q.analyze)
val avgs = findAverage(optimized)
assert(avgs.nonEmpty, "Window-arm AVG peel should fire for dec(2,1)")
assert(avgs.forall(_.evalMode == EvalMode.TRY),
s"evalMode should be preserved as TRY after rewrite, got " +
avgs.map(_.evalMode).mkString(","))
}
}