Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

// https://github.com/apache/datafusion-comet/issues/3906
ignore("cast nested ArrayType to nested ArrayType") {
test("cast nested ArrayType to nested ArrayType") {
val types = Seq(
BooleanType,
StringType,
Expand All @@ -1600,14 +1600,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
FloatType,
DoubleType,
DecimalType(10, 2),
DecimalType(38, 18),
// DecimalType(38, 18) is excluded for the same reason as the one-dimensional array
// matrix: decimal-to-float/double casts can differ by ~1 ULP from Spark.
DateType,
TimestampType,
BinaryType)
testArrayCastMatrix(
types,
dt => ArrayType(ArrayType(dt)),
dt => generateArrays(100, ArrayType(dt)))
testArrayCastMatrix(types, dt => ArrayType(ArrayType(dt)), dt => generateNestedArrays(20, dt))
}

// CAST from TimestampNTZType
Expand Down Expand Up @@ -1715,7 +1713,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.collection.JavaConverters._
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
def buildRows(values: Seq[Any]): Seq[Row] = {
Range(0, rowNum).map { i =>
Expand Down Expand Up @@ -1769,6 +1767,37 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

private def generateNestedArrays(rowNum: Int, elementType: DataType): DataFrame = {
import scala.jdk.CollectionConverters._
val schema = StructType(Seq(StructField("a", ArrayType(ArrayType(elementType)), true)))
val innerArrays = generateArrays(rowNum, elementType)
.collect()
.map { row =>
if (row.isNullAt(0)) {
null
} else {
row.getSeq[Any](0)
}
}
.toSeq

def buildRows(values: Seq[Seq[Any]]): Seq[Row] = {
Range(0, rowNum).map { i =>
Row(
Seq[Any](
values(i % values.length),
// Keep every third row's middle nested-array element null.
if (i % 3 == 0) null else values((i + 1) % values.length),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe add a comment to let users know every 3rd element is null?

values((i + 2) % values.length)))
}
}

val sampleValue = innerArrays.find(_ != null).orNull
val rows = Seq(Row(Seq(sampleValue, null, sampleValue)), Row(Seq.empty[Any]), Row(null)) ++
buildRows(innerArrays)
spark.createDataFrame(rows.asJava, schema)
}

// https://github.com/apache/datafusion-comet/issues/2038
test("test implicit cast to dictionary with case when and dictionary type") {
withSQLConf("parquet.enable.dictionary" -> "true") {
Expand Down
Loading