@@ -691,6 +691,12 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
* `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
* in this collection do not need to be equivalent, which is useful for
* Outer Join operators.
*
* [[KeyedPartitioning]]s within a `PartitioningCollection` describe the same physical partitioning
* and so must share the same `partitionKeys` reference, differing only in their `expressions` (with
* matching arity). Use [[PartitioningCollection.fromPartitionings]] to build a collection from
* independently-computed partitionings (e.g. join `outputPartitioning`); it interns `partitionKeys`
* references (including across nested collections) so the invariant holds.
*/
case class PartitioningCollection(partitionings: Seq[Partitioning])
extends Expression with Partitioning with Unevaluable {
@@ -699,6 +705,26 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
partitionings.map(_.numPartitions).distinct.length == 1,
s"PartitioningCollection requires all of its partitionings have the same numPartitions.")

Member

Style nit (no need to change): would extracting this block to a private def checkInvariant(): Unit; checkInvariant() read closer to the other Distribution/Partitioning case-class init blocks in this file (e.g. ClusteredDistribution L90-94, OrderedDistribution L168-172) which use single-line require(...)? Just a question — current form works.

Contributor Author

Good idea. Extracted in a4329cc.

checkKeyedPartitioningInvariant()

private def checkKeyedPartitioningInvariant(): Unit = {
var first: KeyedPartitioning = null
foreach {
case k: KeyedPartitioning =>
if (first == null) {
first = k
} else {
require(k.expressions.length == first.expressions.length,
"All KeyedPartitionings in a PartitioningCollection must have matching expression " +
"arity")
require(k.partitionKeys eq first.partitionKeys,
"All KeyedPartitionings in a PartitioningCollection must share the same " +
"partitionKeys reference")
}
case _ =>
}
}
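(Editorial sketch, not part of the diff: it reuses the x / y / yAlias / keys2d setup from ProjectedOrderingAndPartitioningSuite further down and assumes the KeyedPartitioning(expressions, keys) construction shown there; imports omitted. It only illustrates what the two requires above accept and reject.)

val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val yAlias = AttributeReference("y_alias", IntegerType)()
val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2))

val kp1 = KeyedPartitioning(Seq(x, y), keys2d)
val kp2 = kp1.copy(expressions = Seq(x, yAlias))  // same partitionKeys reference, same arity
PartitioningCollection(Seq(kp1, kp2))             // passes both requires above

val kp3 = kp1.copy(expressions = Seq(x))          // same reference, but arity 1 vs 2
PartitioningCollection(Seq(kp1, kp3))             // fails the matching-arity require

Independently built KeyedPartitionings are not guaranteed to share the reference; that case is what the fromPartitionings factory below is for.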

override def children: Seq[Expression] = partitionings.collect {
case expr: Expression => expr
}
@@ -730,6 +756,36 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
super.legacyWithNewChildren(newChildren).asInstanceOf[PartitioningCollection]
}

object PartitioningCollection {
/**
* Builds a [[PartitioningCollection]], unifying the `partitionKeys` reference across all
* [[KeyedPartitioning]]s (including those in nested collections). Use this when combining
* independently-computed partitionings (e.g. join `outputPartitioning`) where
* `KeyedPartitioning.partitionKeys` are structurally equal but may not be reference-equal.
*
* Note: this can't be implemented with `TreeNode.transform`: a rewritten node that is
* value-equal to the original is discarded (`fastEquals`), so a change that only re-points
* the `partitionKeys` reference would be lost.
*/
def fromPartitionings(partitionings: Seq[Partitioning]): PartitioningCollection = {
var canonicalKeys: Seq[InternalRowComparableWrapper] = null
def intern(p: Partitioning): Partitioning = p match {
case k: KeyedPartitioning =>
if (canonicalKeys == null) {
canonicalKeys = k.partitionKeys
k
} else if (k.partitionKeys ne canonicalKeys) {
require(k.partitionKeys == canonicalKeys,
"All KeyedPartitionings in a PartitioningCollection must have equal partitionKeys")
k.copy(partitionKeys = canonicalKeys)
} else {
k
}
case pc: PartitioningCollection => new PartitioningCollection(pc.partitionings.map(intern))
case other => other
}
new PartitioningCollection(partitionings.map(intern))
}
}
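(Editorial usage sketch, not part of the diff: attribute names and key values are made up, the KeyedPartitioning(expressions, keys) construction mirrors the suites below, imports omitted. It shows the interning the Scaladoc above describes.)

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
// Two sides of a join compute their partition keys independently: equal values,
// distinct Seq instances.
val leftKeys = Seq(InternalRow(1), InternalRow(2))
val rightKeys = Seq(InternalRow(1), InternalRow(2))

val pc = PartitioningCollection.fromPartitionings(Seq(
  KeyedPartitioning(Seq(a), leftKeys),
  KeyedPartitioning(Seq(b), rightKeys)))
// intern() keeps the first partitioning's partitionKeys as the canonical reference and copies
// the second onto it, so the constructor's `eq` invariant check passes. Keys that are not
// structurally equal (say Seq(InternalRow(3))) would fail the `==` require instead.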

/**
* Represents a partitioning where rows are collected, transformed and broadcasted to each
* node in the cluster.
@@ -40,7 +40,7 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
(projectedKPs ++ projectedOthers).take(aliasCandidateLimit) match {
case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
case Seq(p) => p
case ps => PartitioningCollection(ps)
case ps => PartitioningCollection.fromPartitionings(ps)
}
}

@@ -88,22 +88,15 @@ trait PartitioningPreservingUnaryExecNode extends UnaryExecNode
*
* The resulting [[KeyedPartitioning]]s are the cross-product of the per-position alternatives
* restricted to the projectable positions. All share the same `partitionKeys` object (projected
* to the same subset of positions), preserving the invariant required by [[GroupPartitionsExec]].
* to the same subset of positions), preserving the invariant required by
* [[PartitioningCollection]].
*/
private def projectKeyedPartitionings(
kps: Seq[KeyedPartitioning]): LazyList[KeyedPartitioning] = {
if (kps.isEmpty) return LazyList.empty
// All input KPs share the same `partitionKeys` reference and have matching arity, per the
// [[PartitioningCollection]] invariant (the only producer of multi-KP inputs here).
val numPositions = kps.head.expressions.length
// The function assumes all input KPs share the same `partitionKeys`, which implies matching
// expression arity. This invariant is asserted by [[GroupPartitionsExec]] and is established
// by the constructors of [[PartitioningCollection]] feeding this method (a join's
// `PartitioningCollection(left.outputPartitioning, right.outputPartitioning)` combines KPs
// that have been aligned by [[EnsureRequirements]] to the same join keys). If the invariant
// is ever violated upstream, fail early with a clear message instead of throwing an opaque
// `IndexOutOfBoundsException` from `kp.expressions(i)` below.
assert(kps.tail.forall(_.expressions.length == numPositions),
s"All input KeyedPartitionings must share the same expression arity, " +
s"but got: ${kps.map(_.expressions.length).mkString(", ")}.")

val alternativesPerPosition: IndexedSeq[LazyList[Expression]] =
if (hasAlias) {
@@ -67,24 +67,14 @@ case class GroupPartitionsExec(
override def outputPartitioning: Partitioning = {
child.outputPartitioning match {
case p: Partitioning with Expression =>
// There can be multiple `KeyedPartitioning` in an output partitioning of a join, but they
// can only differ in `expressions`. `partitionKeys` must match so we can calculate it only
// once via `groupedPartitions`.

val keyedPartitionings = p.collect { case k: KeyedPartitioning => k }
if (keyedPartitionings.size > 1) {
val first = keyedPartitionings.head
keyedPartitionings.tail.foreach { k =>
assert(k.partitionKeys == first.partitionKeys,
"All KeyedPartitioning nodes must have identical partition keys")
}
}

// There can be multiple `KeyedPartitioning`s in an output partitioning of a join, but they
// can only differ in `expressions`; their `partitionKeys` reference is shared (enforced by
// `PartitioningCollection`), so `groupedPartitions` is computed only once.
val partitionKeys = groupedPartitions.map(_._1)
p.transform {
case k: KeyedPartitioning =>
val projectedExpressions = joinKeyPositions.fold(k.expressions)(_.map(k.expressions))
KeyedPartitioning(projectedExpressions, groupedPartitions.map(_._1),
isGrouped = isGrouped)
KeyedPartitioning(projectedExpressions, partitionKeys, isGrouped = isGrouped)
}.asInstanceOf[Partitioning]
case o => o
}
@@ -84,7 +84,7 @@ case class BroadcastHashJoinExec private(
// constructor prevents that.

case p :: Nil => p
case ps => PartitioningCollection(ps)
case ps => PartitioningCollection.fromPartitionings(ps)
}
case _ => streamedPlan.outputPartitioning
}
@@ -46,7 +46,8 @@ trait ShuffledJoin extends JoinCodegenSupport {

override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
PartitioningCollection.fromPartitionings(
Seq(left.outputPartitioning, right.outputPartitioning))
Member

Should BroadcastHashJoinExec.scala:87 (case ps => PartitioningCollection(ps) in expandOutputPartitioning) and AliasAwareOutputExpression.scala:43 (case ps => PartitioningCollection(ps) in outputPartitioning) also use this new fromPartitionings factory for consistency? Both build a PartitioningCollection from independently-computed partitionings that may contain KeyedPartitionings, which looks like the same scenario fromPartitionings was introduced for. Or are these intentionally left as direct construction because the inputs already satisfy the partitionKeys eq invariant by other means?

Contributor Author

Technically, .fromPartitionings() is not necessary in BroadcastHashJoinExec and in AliasAwareOutputExpression, but it is effectively a no-op when the input already satisfies the invariant, so we can make them consistent. Changed in a4329cc.
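(Editorial sketch of the "no-op" point above, not part of the change: identifiers mirror GroupPartitionsExecSuite below, imports omitted.)

val exprA = AttributeReference("a", IntegerType)()
val exprB = AttributeReference("b", IntegerType)()
val partitionKeys = Seq(InternalRow(1), InternalRow(2))
val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys)
val rightKP = leftKP.copy(expressions = Seq(exprB))  // already shares leftKP's partitionKeys
val pc = PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP))
// intern() returns both partitionings unchanged (same object identity), so the result is
// the same collection that PartitioningCollection(Seq(leftKP, rightKP)) would build.
assert((pc.partitionings.head eq leftKP) && (pc.partitionings(1) eq rightKP))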

case LeftOuter | LeftSingle => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -242,7 +242,8 @@ case class StreamingSymmetricHashJoinExec(

override def outputPartitioning: Partitioning = joinType match {
case _: InnerLike =>
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
PartitioningCollection.fromPartitionings(
Seq(left.outputPartitioning, right.outputPartitioning))
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -387,7 +387,7 @@ class ProjectedOrderingAndPartitioningSuite
val y = AttributeReference("y", IntegerType)()
val yAlias = AttributeReference("y_alias", IntegerType)()
val keys2d = Seq(InternalRow(1, 1), InternalRow(1, 2), InternalRow(2, 1), InternalRow(2, 2))
val childPartitioning = PartitioningCollection(Seq(
val childPartitioning = PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(Seq(x, y), keys2d),
KeyedPartitioning(Seq(x, yAlias), keys2d)))
val child = DummyLeafExecWithPartitioning(
@@ -587,27 +587,20 @@ class ProjectedOrderingAndPartitioningSuite
}
}

test("SPARK-46367: mixed-arity KeyedPartitionings in input fail with a clear assertion") {
// The function assumes all input KPs share the same arity (the invariant asserted by
// `GroupPartitionsExec`). Without the assert below, indexing `kp.expressions(i)` for
// `i >= kp.expressions.length` would throw an opaque `IndexOutOfBoundsException`. The assert
// surfaces the real cause -- an upstream node violated the invariant -- so the bug can be
// fixed at the producer.
test("SPARK-46367: mixed-arity KeyedPartitionings rejected by PartitioningCollection") {
// PartitioningCollection enforces matching expression arity (and shared partitionKeys
// references) across all its KeyedPartitionings, so the invariant required by
// `AliasAwareOutputExpression` cannot be violated by the input.
val x = AttributeReference("x", IntegerType)()
val y = AttributeReference("y", IntegerType)()
val keys2d = Seq(InternalRow(1, 1), InternalRow(2, 2))
val keys1d = Seq(InternalRow(1), InternalRow(2))
val child = DummyLeafExecWithPartitioning(
output = Seq(x, y),
partitioning = PartitioningCollection(Seq(
val e = intercept[IllegalArgumentException] {
PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(Seq(x, y), keys2d),
KeyedPartitioning(Seq(x), keys1d))))
val project = ProjectExec(Seq(x), child)
val e = intercept[AssertionError] {
project.outputPartitioning
KeyedPartitioning(Seq(x), keys1d)))
}
assert(e.getMessage.contains("All input KeyedPartitionings must share the same expression " +
"arity"))
assert(e.getMessage.contains("partitionKeys"))
}
}

@@ -97,7 +97,7 @@ class GroupPartitionsExecSuite extends SharedSparkSession {
val leftKP = KeyedPartitioning(Seq(exprA), partitionKeys)
val rightKP = KeyedPartitioning(Seq(exprB), partitionKeys)
val child = DummySparkPlan(
outputPartitioning = PartitioningCollection(Seq(leftKP, rightKP)),
outputPartitioning = PartitioningCollection.fromPartitionings(Seq(leftKP, rightKP)),
outputOrdering = Seq(SortOrder(exprA, Ascending, sameOrderExpressions = Seq(exprB))))
val gpe = GroupPartitionsExec(child)

@@ -821,7 +821,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprB) :: Nil, Seq.empty)
)
plan2 = new DummySparkPlanWithBatchScanChild(
outputPartitioning = PartitioningCollection(Seq(
outputPartitioning = PartitioningCollection.fromPartitionings(Seq(
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty),
KeyedPartitioning(bucket(4, exprA) :: bucket(16, exprC) :: Nil, Seq.empty))
)
@@ -1050,7 +1050,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {

// With partition collections
plan1 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
PartitioningCollection(
PartitioningCollection.fromPartitionings(
Seq(KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues),
KeyedPartitioning(bucket(4, exprB) :: bucket(8, exprC) :: Nil, leftPartValues))
)
@@ -1077,13 +1077,13 @@

// Nested partition collections
plan2 = new DummySparkPlanWithBatchScanChild(outputPartitioning =
PartitioningCollection(
PartitioningCollection.fromPartitionings(
Seq(
PartitioningCollection(
PartitioningCollection.fromPartitionings(
Seq(
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues),
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues))),
PartitioningCollection(
PartitioningCollection.fromPartitionings(
Seq(
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues),
KeyedPartitioning(bucket(4, exprC) :: bucket(8, exprB) :: Nil, rightPartValues)))
@@ -1539,7 +1539,7 @@ private case class DummyBothKPBinaryExec(left: SparkPlan, right: SparkPlan)
override def output: Seq[Attribute] = left.output ++ right.output
override def outputOrdering: Seq[SortOrder] = left.outputOrdering
override def outputPartitioning: Partitioning =
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
PartitioningCollection.fromPartitionings(Seq(left.outputPartitioning, right.outputPartitioning))
override protected def doExecute(): RDD[InternalRow] = null
override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =