diff --git a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala index d778212392..62710c28dc 100644 --- a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala +++ b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala @@ -77,6 +77,29 @@ object IcebergReflection extends Logging { val UNKNOWN = "unknown" } + /** + * Loads a class using the thread context classloader first, then falls back to the system + * classloader. + * + * @param className + * Fully qualified class name to load + * @return + * The loaded Class object + */ + def loadClass(className: String): Class[_] = { + val classLoader = Thread.currentThread().getContextClassLoader + if (classLoader != null) { + // scalastyle:off classforname + Class.forName(className, true, classLoader) + // scalastyle:on classforname + } else { + // Fallback to default classloader if context classloader is null + // scalastyle:off classforname + Class.forName(className) + // scalastyle:on classforname + } + } + /** * Searches through class hierarchy to find a method (including protected methods). */ @@ -124,9 +147,7 @@ object IcebergReflection extends Logging { */ def extractFileLocation(file: Any): Option[String] = { try { - // scalastyle:off classforname - val contentFileClass = Class.forName(ClassNames.CONTENT_FILE) - // scalastyle:on classforname + val contentFileClass = loadClass(ClassNames.CONTENT_FILE) extractFileLocation(contentFileClass, file) } catch { case _: Exception => None @@ -387,9 +408,7 @@ object IcebergReflection extends Logging { */ def getEqualityFieldIds(deleteFile: Any): java.util.List[_] = { try { - // scalastyle:off classforname - val deleteFileClass = Class.forName(ClassNames.DELETE_FILE) - // scalastyle:on classforname + val deleteFileClass = loadClass(ClassNames.DELETE_FILE) val equalityFieldIdsMethod = deleteFileClass.getMethod("equalityFieldIds") val ids = equalityFieldIdsMethod.invoke(deleteFile).asInstanceOf[java.util.List[_]] if (ids == null) new java.util.ArrayList[Any]() else ids @@ -515,9 +534,7 @@ object IcebergReflection extends Logging { val fieldsMethod = partitionSpec.getClass.getMethod("fields") val fields = fieldsMethod.invoke(partitionSpec).asInstanceOf[java.util.List[_]] - // scalastyle:off classforname - val partitionFieldClass = Class.forName(ClassNames.PARTITION_FIELD) - // scalastyle:on classforname + val partitionFieldClass = loadClass(ClassNames.PARTITION_FIELD) val sourceIdMethod = partitionFieldClass.getMethod("sourceId") val findFieldMethod = schema.getClass.getMethod("findField", classOf[Int]) diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala index 458bc52fb8..3f240b11f8 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala @@ -227,9 +227,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit fileScanTaskClass: Class[_], fileIO: Option[Any]): Seq[OperatorOuterClass.IcebergDeleteFile] = { try { - // scalastyle:off classforname - val deleteFileClass = Class.forName(IcebergReflection.ClassNames.DELETE_FILE) - // scalastyle:on classforname + val deleteFileClass = IcebergReflection.loadClass(IcebergReflection.ClassNames.DELETE_FILE) val deletes = IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass) @@ -336,13 +334,11 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit if (spec != null) { // Deduplicate partition spec try { - // scalastyle:off classforname val partitionSpecParserClass = - Class.forName(IcebergReflection.ClassNames.PARTITION_SPEC_PARSER) + IcebergReflection.loadClass(IcebergReflection.ClassNames.PARTITION_SPEC_PARSER) val toJsonMethod = partitionSpecParserClass.getMethod( "toJson", - Class.forName(IcebergReflection.ClassNames.PARTITION_SPEC)) - // scalastyle:on classforname + IcebergReflection.loadClass(IcebergReflection.ClassNames.PARTITION_SPEC)) val partitionSpecJson = toJsonMethod .invoke(null, spec) .asInstanceOf[String] @@ -685,9 +681,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit */ private def convertIcebergLiteral(icebergLiteral: Any, sparkType: DataType): Literal = { // Load Literal interface to get value() method (use interface to avoid package-private issues) - // scalastyle:off classforname - val literalClass = Class.forName(IcebergReflection.ClassNames.LITERAL) - // scalastyle:on classforname + val literalClass = IcebergReflection.loadClass(IcebergReflection.ClassNames.LITERAL) val valueMethod = literalClass.getMethod("value") val value = valueMethod.invoke(icebergLiteral) @@ -790,13 +784,16 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } // Load Iceberg classes once (avoid repeated class loading in loop) - // scalastyle:off classforname - val contentScanTaskClass = Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) - val fileScanTaskClass = Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK) - val contentFileClass = Class.forName(IcebergReflection.ClassNames.CONTENT_FILE) - val schemaParserClass = Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER) - val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA) - // scalastyle:on classforname + val contentScanTaskClass = + IcebergReflection.loadClass(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) + val fileScanTaskClass = + IcebergReflection.loadClass(IcebergReflection.ClassNames.FILE_SCAN_TASK) + val contentFileClass = + IcebergReflection.loadClass(IcebergReflection.ClassNames.CONTENT_FILE) + val schemaParserClass = + IcebergReflection.loadClass(IcebergReflection.ClassNames.SCHEMA_PARSER) + val schemaClass = + IcebergReflection.loadClass(IcebergReflection.ClassNames.SCHEMA) // Cache method lookups (avoid repeated getMethod in loop) val fileMethod = contentScanTaskClass.getMethod("file")