diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala index 6c0bca0e1104f..17fd9486ad9aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE} import org.apache.spark.sql.catalyst.rules.Rule @@ -53,7 +54,7 @@ object NormalizeCTEIds extends Rule[LogicalPlan] { private def canonicalizeCTE( plan: LogicalPlan, defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = { - plan.transformDownWithSubqueries { + val normalizedPlan = plan match { // For nested WithCTE, if defIndex didn't contain the cteId, // means it's not current WithCTE's ref. case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) => @@ -62,6 +63,17 @@ object NormalizeCTEIds extends Rule[LogicalPlan] { unionLoop.copy(id = defIdToNewId(unionLoop.id)) case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) => unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId)) + case other => other } + + normalizedPlan + .withNewChildren(normalizedPlan.children.map { + case withCTE: WithCTE => withCTE + case child => canonicalizeCTE(child, defIdToNewId) + }) + .transformExpressionsDown { + case subqueryExpression: SubqueryExpression => + subqueryExpression.withNewPlan(canonicalizeCTE(subqueryExpression.plan, defIdToNewId)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 7562d5669cc2c..1371e59545fdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -261,6 +261,45 @@ abstract class CTEInlineSuiteBase } } + test("SPARK-56921: plan normalization handles nested CTEs under union") { + withTempView("input", "common") { + Seq((1, 1, 10), (1, 2, 20), (2, 1, 30)) + .toDF("a", "b", "value") + .createOrReplaceTempView("input") + + sql( + s"""with cte_common as ( + | select a, b, sum(value) as value + | from input + | group by a, b + |) + |select * from cte_common + """.stripMargin).createOrReplaceTempView("common") + + val left = sql( + s"""with cte_a as ( + | select a, sum(value) as value + | from common + | group by a + |) + |select a as id, value from cte_a + """.stripMargin) + + val right = sql( + s"""with cte_b as ( + | select b, sum(value) as value + | from common + | group by b + |) + |select b as id, value from cte_b + """.stripMargin) + + val df = left.union(right) + df.queryExecution.normalized + checkAnswer(df, Row(1, 30) :: Row(2, 30) :: Row(1, 40) :: Row(2, 20) :: Nil) + } + } + test("SPARK-36447: invalid nested CTEs") { withTempView("t") { Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")