From 37eff10bcb6b2103168f570fab3f1527a6c90787 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 15 May 2026 17:00:50 +0200 Subject: [PATCH 1/2] [SPARK-56877][SQL] Enforce `KeyedPartitioning` invariant in `PartitioningCollection` ### What changes were proposed in this pull request? - Add a `require` in `PartitioningCollection` that all `KeyedPartitioning`s reachable from the collection share the same `partitionKeys` reference (`eq`) and have matching expression arity. The check walks the partitioning tree via `foreach` so nested collections are covered. - Add a smart factory `PartitioningCollection.fromPartitionings` that interns `partitionKeys` references across `KeyedPartitioning`s. Use this at sites that combine independently-computed partitionings (joins) where keys are structurally equal but not reference-equal. The factory uses manual recursion rather than `transformWithPruning` because `KeyedPartitioning.equals` compares `partitionKeys` element-wise, which would make `transformWithPruning` discard the rule's replacement as structurally-equal-to-input. - In `GroupPartitionsExec.outputPartitioning`, hoist `val partitionKeys = groupedPartitions.map(_._1)` above the `transform` so every rebuilt `KeyedPartitioning` shares the same `partitionKeys` reference. Drop the ad-hoc consistency assert (now enforced by `PartitioningCollection`). - Switch `ShuffledJoin` and `StreamingSymmetricHashJoinExec` to `PartitioningCollection.fromPartitionings` for their inner-join `outputPartitioning`. - Update affected tests to construct collections via `fromPartitionings`. Rewrite the `SPARK-46367` arity-mismatch test in `ProjectedOrderingAndPartitioningSuite` since the scenario is now rejected at `PartitioningCollection` construction rather than inside `AliasAwareOutputExpression`. ### Why are the changes needed? 
The "all `KeyedPartitioning`s in a collection must agree on `partitionKeys`" invariant already existed informally -- `GroupPartitionsExec.outputPartitioning` had a runtime assert checking `==`, `AliasAwareOutputExpression.projectKeyedPartitionings` asserted matching arity, and various consumers relied on the invariant being upheld. Consolidating the check into the `PartitioningCollection` constructor makes it load-bearing: any future construction site that violates it fails immediately rather than producing silently divergent state. The stronger `eq` check enables cheap reference comparisons downstream and is naturally achievable when collections are built through the smart factory. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing test suites (`EnsureRequirementsSuite`, `GroupPartitionsExecSuite`, `ProjectedOrderingAndPartitioningSuite`) were updated to use `PartitioningCollection.fromPartitionings` where they previously constructed collections from independently-built `KeyedPartitioning`s. The `SPARK-46367` test was rewritten to assert that the invalid mixed-arity scenario is rejected at `PartitioningCollection` construction. ### Was this patch authored or co-authored using generative AI tooling? 
Generated-by: Claude Code 4.7 --- .../plans/physical/partitioning.scala | 54 +++++++++++++++++++ .../AliasAwareOutputExpression.scala | 15 ++---- .../datasources/v2/GroupPartitionsExec.scala | 20 ++----- .../sql/execution/joins/ShuffledJoin.scala | 3 +- .../join/StreamingSymmetricHashJoinExec.scala | 3 +- ...rojectedOrderingAndPartitioningSuite.scala | 25 ++++----- .../v2/GroupPartitionsExecSuite.scala | 2 +- .../exchange/EnsureRequirementsSuite.scala | 12 ++--- 8 files changed, 83 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cc50da1f17fdf..edac7e09b1dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -691,6 +691,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings` * in this collection do not need to be equivalent, which is useful for * Outer Join operators. + * + * [[KeyedPartitioning]]s within a `PartitioningCollection` describe the same physical partitioning + * and so must share the same `partitionKeys` reference, differing only in their `expressions` (with + * matching arity). Use [[PartitioningCollection.fromPartitionings]] to build a collection from + * independently-computed partitionings (e.g. join `outputPartitioning`); it interns `partitionKeys` + * references (including across nested collections) so the invariant holds. 
*/ case class PartitioningCollection(partitionings: Seq[Partitioning]) extends Expression with Partitioning with Unevaluable { @@ -699,6 +705,24 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.numPartitions).distinct.length == 1, s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + { + var first: KeyedPartitioning = null + foreach { + case k: KeyedPartitioning => + if (first == null) { + first = k + } else { + require(k.expressions.length == first.expressions.length, + "All KeyedPartitionings in a PartitioningCollection must have matching expression " + + "arity") + require(k.partitionKeys eq first.partitionKeys, + "All KeyedPartitionings in a PartitioningCollection must share the same " + + "partitionKeys reference") + } + case _ => + } + } + override def children: Seq[Expression] = partitionings.collect { case expr: Expression => expr } @@ -730,6 +754,36 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection] } +object PartitioningCollection { + /** + * Builds a [[PartitioningCollection]], unifying the `partitionKeys` reference across all + * [[KeyedPartitioning]]s (including those in nested collections). Use this when combining + * independently-computed partitionings (e.g. join `outputPartitioning`) where + * `KeyedPartitioning.partitionKeys` are structurally equal but may not be reference-equal. + * + * Note: this can't be implemented with `TreeNode.transform`. 
+ */ + def fromPartitionings(partitionings: Seq[Partitioning]): PartitioningCollection = { + var canonicalKeys: Seq[InternalRowComparableWrapper] = null + def intern(p: Partitioning): Partitioning = p match { + case k: KeyedPartitioning => + if (canonicalKeys == null) { + canonicalKeys = k.partitionKeys + k + } else if (k.partitionKeys ne canonicalKeys) { + require(k.partitionKeys == canonicalKeys, + "All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys") + k.copy(partitionKeys = canonicalKeys) + } else { + k + } + case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern)) + case other => other + } + new PartitioningCollection(partitionings.map(intern)) + } +} + /** * Represents a partitioning where rows are collected, transformed and broadcasted to each * node in the cluster. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 1f2b1d0a585d6..90598608aa2b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -88,22 +88,15 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode * * The resulting [[KeyedPartitioning]]s are the cross-product of the per-position alternatives * restricted to the projectable positions. All share the same `partitionKeys` object (projected - * to the same subset of positions), preserving the invariant required by [[GroupPartitionsExec]]. + * to the same subset of positions), preserving the invariant required by + * [[PartitioningCollection]]. 
*/ private def projectKeyedPartitionings( kps: Seq[KeyedPartitioning]): LazyList[KeyedPartitioning] = { if (kps.isEmpty) return LazyList.empty + // All input KPs share the same `partitionKeys` reference and matching arity by the + // [[PartitioningCollection]] invariant (the only producer of multi-KP inputs here). val numPositions = kps.head.expressions.length - // The function assumes all input KPs share the same `partitionKeys`, which implies matching - // expression arity. This invariant is asserted by [[GroupPartitionsExec]] and is established - // by the constructors of [[PartitioningCollection]] feeding this method (a join's - // `PartitioningCollection(left.outputPartitioning, right.outputPartitioning)` combines KPs - // that have been aligned by [[EnsureRequirements]] to the same join keys). If the invariant - // is ever violated upstream, fail early with a clear message instead of throwing an opaque - // `IndexOutOfBoundsException` from `kp.expressions(i)` below. - assert(kps.tail.forall(_.expressions.length == numPositions), - s"All input KeyedPartitionings must share the same expression arity, " + - s"but got: ${kps.map(_.expressions.length).mkString(", ")}.") val alternativesPerPosition: IndexedSeq[LazyList[Expression]] = if (hasAlias) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala index 264a0e954936f..4d87be6622939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExec.scala @@ -67,24 +67,14 @@ case class GroupPartitionsExec( override def outputPartitioning: Partitioning = { child.outputPartitioning match { case p: Partitioning with Expression => - // There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they - // can 
only differ in `expressions`. `partitionKeys` must match so we can calculate it only - // once via `groupedPartitions`. - - val keyedPartitionings = p.collect { case k: KeyedPartitioning => k } - if (keyedPartitionings.size > 1) { - val first = keyedPartitionings.head - keyedPartitionings.tail.foreach { k => - assert(k.partitionKeys == first.partitionKeys, - "All KeyedPartitioning nodes must have identical partition keys") - } - } - + // There can be multiple `KeyedPartitioning`s in an output partitioning of a join, but they + // can only differ in `expressions`; their `partitionKeys` reference is shared (enforced by + // `PartitioningCollection`), so `groupedPartitions` is computed only once. + val partitionKeys = groupedPartitions.map(_._1) p.transform { case k: KeyedPartitioning => val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions)) - KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1), - isGrouped = isGrouped) + KeyedPartitioning(projectedExpressions, partitionKeys, isGrouped = isGrouped) }.asInstanceOf[Partitioning] case o => o } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index f363156c81e54..3fb968bfea7a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -46,7 +46,8 @@ trait ShuffledJoin extends JoinCodegenSupport { override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings( + Seq(left.outputPartitioning, right.outputPartitioning)) case LeftOuter | LeftSingle => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => 
UnknownPartitioning(left.outputPartitioning.numPartitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala index 71a7d4cf56e13..9eca04c985913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala @@ -242,7 +242,8 @@ case class StreamingSymmetricHashJoinExec( override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings( + Seq(left.outputPartitioning, right.outputPartitioning)) case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala index a38570924620a..a70baece77844 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala @@ -387,7 +387,7 @@ class ProjectedOrderingAndPartitioningSuite val y = AttributeReference("y", IntegerType)() val yAlias = AttributeReference("y_alias", IntegerType)() val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2), InternalRow(2, 1), InternalRow(2, 2)) - val childPartitioning = PartitioningCollection(Seq( + val childPartitioning = PartitioningCollection.fromPartitionings(Seq( 
KeyedPartitioning(Seq(x, y), keys2d), KeyedPartitioning(Seq(x, yAlias), keys2d))) val child = DummyLeafExecWithPartitioning( @@ -587,27 +587,20 @@ class ProjectedOrderingAndPartitioningSuite } } - test("SPARK-46367: mixed-arity KeyedPartitionings in input fail with a clear assertion") { - // The function assumes all input KPs share the same arity (the invariant asserted by - // `GroupPartitionsExec`). Without the assert below, indexing `kp.expressions(i)` for - // `i >= kp.expressions.length` would throw an opaque `IndexOutOfBoundsException`. The assert - // surfaces the real cause -- an upstream node violated the invariant -- so the bug can be - // fixed at the producer. + test("SPARK-46367: mixed-arity KeyedPartitionings rejected by PartitioningCollection") { + // PartitioningCollection enforces matching expression arity (and shared partitionKeys + // references) across all its KeyedPartitionings, so the invariant required by + // `AliasAwareOutputExpression` cannot be violated by the input. 
val x = AttributeReference("x", IntegerType)() val y = AttributeReference("y", IntegerType)() val keys2d = Seq(InternalRow(1, 1), InternalRow(2, 2)) val keys1d = Seq(InternalRow(1), InternalRow(2)) - val child = DummyLeafExecWithPartitioning( - output = Seq(x, y), - partitioning = PartitioningCollection(Seq( + val e = intercept[IllegalArgumentException] { + PartitioningCollection.fromPartitionings(Seq( KeyedPartitioning(Seq(x, y), keys2d), - KeyedPartitioning(Seq(x), keys1d)))) - val project = ProjectExec(Seq(x), child) - val e = intercept[AssertionError] { - project.outputPartitioning + KeyedPartitioning(Seq(x), keys1d))) } - assert(e.getMessage.contains("All input KeyedPartitionings must share the same expression " + - "arity")) + assert(e.getMessage.contains("partitionKeys")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala index 5d2adeb0c00af..51951d68cc606 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/GroupPartitionsExecSuite.scala @@ -97,7 +97,7 @@ class GroupPartitionsExecSuite extends SharedSparkSession { val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys) val rightKP = KeyedPartitioning(Seq(exprB), partitionKeys) val child = DummySparkPlan( - outputPartitioning = PartitioningCollection(Seq(leftKP, rightKP)), + outputPartitioning = PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP)), outputOrdering = Seq(SortOrder(exprA, Ascending, sameOrderExpressions = Seq(exprB)))) val gpe = GroupPartitionsExec(child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 
1e35985f50491..74b706bce34f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -821,7 +821,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty) ) plan2 = new DummySparkPlanWithBatchScanChild( - outputPartitioning = PartitioningCollection(Seq( + outputPartitioning = PartitioningCollection.fromPartitionings(Seq( KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty), KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty)) ) @@ -1050,7 +1050,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { // With partition collections plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning = - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues), KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues)) ) @@ -1077,13 +1077,13 @@ class EnsureRequirementsSuite extends SharedSparkSession { // Nested partition collections plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning = - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))), - PartitioningCollection( + PartitioningCollection.fromPartitionings( Seq( KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues), KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))) @@ -1539,7 +1539,7 @@ private case class DummyBothKPBinaryExec(left: SparkPlan, right: SparkPlan) override def output: Seq[Attribute] = 
left.output ++ right.output override def outputOrdering: Seq[SortOrder] = left.outputOrdering override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + PartitioningCollection.fromPartitionings(Seq(left.outputPartitioning, right.outputPartitioning)) override protected def doExecute(): RDD[InternalRow] = null override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = From a4329cc0cab770dd466e528d9a42b348f1ece5f1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 16 May 2026 14:38:04 +0200 Subject: [PATCH 2/2] address review findings --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 4 +++- .../spark/sql/execution/AliasAwareOutputExpression.scala | 2 +- .../spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index edac7e09b1dc4..f331cd124759f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -705,7 +705,9 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.numPartitions).distinct.length == 1, s"PartitioningCollection requires all of its partitionings have the same numPartitions.") - { + checkKeyedPartitioningInvariant() + + private def checkKeyedPartitioningInvariant(): Unit = { var first: KeyedPartitioning = null foreach { case k: KeyedPartitioning => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala index 90598608aa2b9..b37e1b258e9bb 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala @@ -40,7 +40,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode (projectedKPs ++ projectedOthers).take(aliasCandidateLimit) match { case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions) case Seq(p) => p - case ps => PartitioningCollection(ps) + case ps => PartitioningCollection.fromPartitionings(ps) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index e4f18c9144dda..2881aeac55d89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -84,7 +84,7 @@ case class BroadcastHashJoinExec private( // constructor prevents that. case p :: Nil => p - case ps => PartitioningCollection(ps) + case ps => PartitioningCollection.fromPartitionings(ps) } case _ => streamedPlan.outputPartitioning }