diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 2edac5b0179bb..0415a33e2d6dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -28,8 +28,9 @@ import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.CommandExecutionMode import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.ArrayImplicits._ /** @@ -79,7 +80,9 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo bucketSpec = table.bucketSpec, options = table.storage.properties ++ pathOption, // As discussed in SPARK-19583, we don't check if the location is existed - catalogTable = Some(tableWithDefaultOptions)).resolveRelation(checkFilesExist = false) + catalogTable = Some(tableWithDefaultOptions)) + .resolveRelation(checkFilesExist = false, + forceNullable = !sessionState.conf.getConf(SQLConf.FILE_SOURCE_INSERT_ENFORCE_NOT_NULL)) val partitionColumnNames = if (table.schema.nonEmpty) { table.partitionColumnNames @@ -107,17 +110,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo table.copy(schema = new StructType(), partitionColumnNames = Nil) case _ => - // Merge nullability from the user-specified schema into the resolved schema. - // DataSource.resolveRelation() calls dataSchema.asNullable which strips NOT NULL - // constraints. We restore nullability from the original user schema while keeping - // the resolved data types (which may include CharVarchar normalization, metadata, etc.) - val resolvedSchema = if (table.schema.nonEmpty) { - restoreNullability(dataSource.schema, table.schema) - } else { - dataSource.schema - } table.copy( - schema = resolvedSchema, + schema = dataSource.schema, partitionColumnNames = partitionColumnNames, // If metastore partition management for file source tables is enabled, we start off with // partition provider hive, but no partitions in the metastore. The user has to call @@ -132,38 +126,6 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo Seq.empty[Row] } - /** - * Recursively restores nullability from the original user-specified schema into - * the resolved schema. The resolved schema's data types are preserved (they may - * contain CharVarchar normalization, metadata, etc.), but nullability flags - * (top-level and nested) are taken from the original schema. - */ - private def restoreNullability(resolved: StructType, original: StructType): StructType = { - val originalFields = original.fields.map(f => f.name -> f).toMap - StructType(resolved.fields.map { resolvedField => - originalFields.get(resolvedField.name) match { - case Some(origField) => - resolvedField.copy( - nullable = origField.nullable, - dataType = restoreDataTypeNullability(resolvedField.dataType, origField.dataType)) - case None => resolvedField - } - }) - } - - private def restoreDataTypeNullability(resolved: DataType, original: DataType): DataType = { - (resolved, original) match { - case (r: StructType, o: StructType) => restoreNullability(r, o) - case (ArrayType(rElem, _), ArrayType(oElem, oNull)) => - ArrayType(restoreDataTypeNullability(rElem, oElem), oNull) - case (MapType(rKey, rVal, _), MapType(oKey, oVal, oValNull)) => - MapType( - restoreDataTypeNullability(rKey, oKey), - restoreDataTypeNullability(rVal, oVal), - oValNull) - case _ => resolved - } - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d2ec3f7ff486b..be1f05da308f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -362,7 +362,10 @@ case class DataSource( * is considered as a non-streaming file based data source. Since we know * that files already exist, we don't need to check them again. */ - def resolveRelation(checkFilesExist: Boolean = true, readOnly: Boolean = false): BaseRelation = { + def resolveRelation( + checkFilesExist: Boolean = true, + readOnly: Boolean = false, + forceNullable: Boolean = true): BaseRelation = { val relation = (providingInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => @@ -436,7 +439,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, partitionSchema = partitionSchema, - dataSchema = dataSchema.asNullable, + dataSchema = if (forceNullable) dataSchema.asNullable else dataSchema, bucketSpec = bucketSpec, format, caseInsensitiveOptions)(sparkSession) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala index 95b539e58ac6b..e65bf1c72bb62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowCreateTableSuite.scala @@ -185,7 +185,7 @@ trait ShowCreateTableSuiteBase extends command.ShowCreateTableSuiteBase val showDDL = getShowCreateDDL(t) assert(showDDL === Array( s"CREATE TABLE $fullName (", - "a BIGINT NOT NULL,", + "a BIGINT,", "b BIGINT DEFAULT 42,", "c STRING COLLATE UTF8_BINARY DEFAULT 'abc, \"def\"' COMMENT 'comment')", "USING parquet", diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala index c2a5ca1023e90..b96d00ee43e35 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -340,12 +340,10 @@ class SparkMetadataOperationSuite extends HiveThriftServer2TestBase { case _ => assert(radix === 0) // nulls } - val expectedNullable = if (schema(pos).nullable) 1 else 0 - assert(rowSet.getInt("NULLABLE") === expectedNullable) + assert(rowSet.getInt("NULLABLE") === 1) assert(rowSet.getString("REMARKS") === pos.toString) assert(rowSet.getInt("ORDINAL_POSITION") === pos + 1) - val expectedIsNullable = if (schema(pos).nullable) "YES" else "NO" - assert(rowSet.getString("IS_NULLABLE") === expectedIsNullable) + assert(rowSet.getString("IS_NULLABLE") === "YES") assert(rowSet.getString("IS_AUTO_INCREMENT") === "NO") pos += 1 }