From d8dbe3bd6c7b3a169f5a332750022fbe0c3793eb Mon Sep 17 00:00:00 2001 From: Raghunandan Kumar <148339424+RaghunandanKumar@users.noreply.github.com> Date: Fri, 15 May 2026 12:48:45 -0400 Subject: [PATCH] [SPARK-56468][UDF] Validate required worker capabilities in direct dispatcher --- .../core/direct/DirectWorkerDispatcher.scala | 39 ++++++++++++++- .../core/DirectWorkerDispatcherSuite.scala | 47 ++++++++++++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index afaf23791d80f..cbc18af205a82 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -28,7 +28,8 @@ import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} +import org.apache.spark.udf.worker.{ProcessCallable, UDFProtoCommunicationPattern, + UDFWorkerDataFormat, UDFWorkerSpecification} import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, WorkerLogger, WorkerSecurityScope, WorkerSession} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, @@ -68,6 +69,7 @@ abstract class DirectWorkerDispatcher( // TODO: Connection pooling -- reuse idle workers across sessions. // TODO: Security scope isolation -- partition pool by WorkerSecurityScope. + validateWorkerSpec() validateTransportSupport() validateEnvironmentCallables() @@ -467,6 +469,41 @@ abstract class DirectWorkerDispatcher( require(!env.hasEnvironmentVerification || env.hasInstallation, "WorkerEnvironment.environment_verification requires installation to be set") } + + private def validateWorkerSpec(): Unit = { + require(workerSpec.hasDirect, + "UDFWorkerSpecification.worker must be set") + + val direct = workerSpec.getDirect + require(direct.hasRunner, + "DirectWorker.runner must be set") + + val capabilities = workerSpec.getCapabilities + require( + capabilities.getSupportedDataFormatsCount > 0, + "WorkerCapabilities.supported_data_formats must contain at least ARROW") + require( + !capabilities.getSupportedDataFormatsList.asScala.contains( + UDFWorkerDataFormat.UDF_WORKER_DATA_FORMAT_UNSPECIFIED), + "WorkerCapabilities.supported_data_formats must not contain UNSPECIFIED") + require( + capabilities.getSupportedDataFormatsList.asScala.contains(UDFWorkerDataFormat.ARROW), + "WorkerCapabilities.supported_data_formats must contain ARROW") + + require( + capabilities.getSupportedCommunicationPatternsCount > 0, + "WorkerCapabilities.supported_communication_patterns must contain " + + "BIDIRECTIONAL_STREAMING") + require( + !capabilities.getSupportedCommunicationPatternsList.asScala.contains( + UDFProtoCommunicationPattern.UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED), + "WorkerCapabilities.supported_communication_patterns must not contain UNSPECIFIED") + require( + capabilities.getSupportedCommunicationPatternsList.asScala.contains( + UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING), + "WorkerCapabilities.supported_communication_patterns must contain " + + "BIDIRECTIONAL_STREAMING") + } } private[direct] object DirectWorkerDispatcher { diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 60f5e2211b702..5d2dcdcadd19f 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -27,8 +27,9 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.udf.worker.{ - DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, - UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, + DirectWorker, LocalTcpConnection, ProcessCallable, UDFProtoCommunicationPattern, + UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, UnixDomainSocket, + WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment} import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher, DirectWorkerException, DirectWorkerProcess, DirectWorkerSession, @@ -121,8 +122,16 @@ class DirectWorkerDispatcherSuite private def directWorker(runner: ProcessCallable): DirectWorker = DirectWorker.newBuilder().setRunner(runner).setProperties(udsProperties).build() + private def defaultCapabilities: WorkerCapabilities = + WorkerCapabilities.newBuilder() + .addSupportedDataFormats(UDFWorkerDataFormat.ARROW) + .addSupportedCommunicationPatterns( + UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING) + .build() + private def specWithRunner(runner: ProcessCallable): UDFWorkerSpecification = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(directWorker(runner)) .build() @@ -131,6 +140,7 @@ class DirectWorkerDispatcherSuite env: WorkerEnvironment): UDFWorkerSpecification = UDFWorkerSpecification.newBuilder() .setEnvironment(env) + .setCapabilities(defaultCapabilities) .setDirect(directWorker(runner)) .build() @@ -177,6 +187,33 @@ class DirectWorkerDispatcherSuite assert(worker.activeSessions == 0, "should have 0 sessions after close") } + test("rejects spec without worker capabilities") { + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(directWorker(defaultRunner)) + .build() + + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("supported_data_formats must contain at least ARROW")) + } + + test("rejects unspecified capability enum values") { + val badSpec = UDFWorkerSpecification.newBuilder() + .setCapabilities(WorkerCapabilities.newBuilder() + .addSupportedDataFormats(UDFWorkerDataFormat.UDF_WORKER_DATA_FORMAT_UNSPECIFIED) + .addSupportedCommunicationPatterns( + UDFProtoCommunicationPattern.UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED) + .build()) + .setDirect(directWorker(defaultRunner)) + .build() + + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("must not contain UNSPECIFIED")) + } + test("concurrent createSession calls produce distinct workers") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) @@ -267,6 +304,7 @@ class DirectWorkerDispatcherSuite .setGracefulTerminationTimeoutMs(500) .build() val spec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder() .setRunner(runner).setProperties(shortGracefulProps).build()) .build() @@ -456,6 +494,7 @@ class DirectWorkerDispatcherSuite .setGracefulTerminationTimeoutMs(60000) .build() val spec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder() .setRunner(defaultRunner).setProperties(oversizedProps).build()) .build() @@ -475,6 +514,7 @@ class DirectWorkerDispatcherSuite .setInitializationTimeoutMs(60000) .build() val spec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder() .setRunner(defaultRunner).setProperties(oversizedProps).build()) .build() @@ -566,6 +606,7 @@ class DirectWorkerDispatcherSuite test("DirectWorker without a connection is rejected") { val badSpec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder().setRunner(defaultRunner).build()) .build() val ex = intercept[IllegalArgumentException] { @@ -581,6 +622,7 @@ class DirectWorkerDispatcherSuite .setTcp(LocalTcpConnection.getDefaultInstance).build()) .build() val badSpec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder() .setRunner(defaultRunner).setProperties(tcpProperties).build()) .build() @@ -658,6 +700,7 @@ class DirectWorkerDispatcherSuite .setInitializationTimeoutMs(500) .build() val spec = UDFWorkerSpecification.newBuilder() + .setCapabilities(defaultCapabilities) .setDirect(DirectWorker.newBuilder() .setRunner(hangingRunner).setProperties(shortInitProps).build()) .build()