Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.annotation.Experimental
import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification}
import org.apache.spark.udf.worker.{ProcessCallable, UDFProtoCommunicationPattern,
UDFWorkerDataFormat, UDFWorkerSpecification}
import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher,
WorkerLogger, WorkerSecurityScope, WorkerSession}
import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult,
Expand Down Expand Up @@ -68,6 +69,7 @@ abstract class DirectWorkerDispatcher(
// TODO: Connection pooling -- reuse idle workers across sessions.
// TODO: Security scope isolation -- partition pool by WorkerSecurityScope.

validateWorkerSpec()
validateTransportSupport()
validateEnvironmentCallables()

Expand Down Expand Up @@ -467,6 +469,41 @@ abstract class DirectWorkerDispatcher(
require(!env.hasEnvironmentVerification || env.hasInstallation,
"WorkerEnvironment.environment_verification requires installation to be set")
}

/**
 * Validates the configured [[UDFWorkerSpecification]] before the dispatcher
 * accepts any sessions.
 *
 * Enforces, in order: a direct worker is set, its runner is set, and the
 * declared capabilities advertise at least the ARROW data format and the
 * BIDIRECTIONAL_STREAMING communication pattern with no UNSPECIFIED enum
 * entries. Fails fast with `IllegalArgumentException` (via `require`) on the
 * first violation; check order is relied upon by the test suite's expected
 * messages.
 */
private def validateWorkerSpec(): Unit = {
  require(workerSpec.hasDirect,
    "UDFWorkerSpecification.worker must be set")
  require(workerSpec.getDirect.hasRunner,
    "DirectWorker.runner must be set")

  val caps = workerSpec.getCapabilities
  val formats = caps.getSupportedDataFormatsList.asScala
  val patterns = caps.getSupportedCommunicationPatternsList.asScala

  require(formats.nonEmpty,
    "WorkerCapabilities.supported_data_formats must contain at least ARROW")
  require(
    !formats.contains(UDFWorkerDataFormat.UDF_WORKER_DATA_FORMAT_UNSPECIFIED),
    "WorkerCapabilities.supported_data_formats must not contain UNSPECIFIED")
  require(formats.contains(UDFWorkerDataFormat.ARROW),
    "WorkerCapabilities.supported_data_formats must contain ARROW")

  require(patterns.nonEmpty,
    "WorkerCapabilities.supported_communication_patterns must contain " +
      "BIDIRECTIONAL_STREAMING")
  require(
    !patterns.contains(
      UDFProtoCommunicationPattern.UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED),
    "WorkerCapabilities.supported_communication_patterns must not contain UNSPECIFIED")
  require(
    patterns.contains(UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING),
    "WorkerCapabilities.supported_communication_patterns must contain " +
      "BIDIRECTIONAL_STREAMING")
}
}

private[direct] object DirectWorkerDispatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.udf.worker.{
DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec,
DirectWorker, LocalTcpConnection, ProcessCallable, UDFProtoCommunicationPattern,
UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, UnixDomainSocket,
WorkerCapabilities, WorkerConnectionSpec,
WorkerEnvironment}
import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
DirectWorkerException, DirectWorkerProcess, DirectWorkerSession,
Expand Down Expand Up @@ -121,8 +122,16 @@ class DirectWorkerDispatcherSuite
/** Wraps `runner` in a DirectWorker proto using the suite's shared UDS
 * connection properties. */
private def directWorker(runner: ProcessCallable): DirectWorker = {
  val builder = DirectWorker.newBuilder()
  builder.setRunner(runner)
  builder.setProperties(udsProperties)
  builder.build()
}

/** Minimal capabilities accepted by dispatcher validation: the ARROW data
 * format plus the bidirectional-streaming communication pattern. */
private def defaultCapabilities: WorkerCapabilities = {
  val builder = WorkerCapabilities.newBuilder()
  builder.addSupportedCommunicationPatterns(
    UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING)
  builder.addSupportedDataFormats(UDFWorkerDataFormat.ARROW)
  builder.build()
}

/** Assembles a validation-passing worker spec (default capabilities plus a
 * direct worker) around the given runner. */
private def specWithRunner(runner: ProcessCallable): UDFWorkerSpecification = {
  val worker = directWorker(runner)
  UDFWorkerSpecification.newBuilder()
    .setDirect(worker)
    .setCapabilities(defaultCapabilities)
    .build()
}

Expand All @@ -131,6 +140,7 @@ class DirectWorkerDispatcherSuite
env: WorkerEnvironment): UDFWorkerSpecification =
UDFWorkerSpecification.newBuilder()
.setEnvironment(env)
.setCapabilities(defaultCapabilities)
.setDirect(directWorker(runner))
.build()

Expand Down Expand Up @@ -177,6 +187,33 @@ class DirectWorkerDispatcherSuite
assert(worker.activeSessions == 0, "should have 0 sessions after close")
}

test("rejects spec without worker capabilities") {
  // Deliberately omits setCapabilities: proto3 then reports an empty
  // capabilities message, so the data-format requirement is the first to fire.
  val specMissingCaps = UDFWorkerSpecification.newBuilder()
    .setDirect(directWorker(defaultRunner))
    .build()

  val thrown = intercept[IllegalArgumentException] {
    new TestDirectWorkerDispatcher(specMissingCaps)
  }
  assert(thrown.getMessage.contains("supported_data_formats must contain at least ARROW"))
}

test("rejects unspecified capability enum values") {
  // Capabilities are non-empty but list only UNSPECIFIED enum entries, which
  // the dispatcher must reject.
  val unspecifiedCaps = WorkerCapabilities.newBuilder()
    .addSupportedDataFormats(UDFWorkerDataFormat.UDF_WORKER_DATA_FORMAT_UNSPECIFIED)
    .addSupportedCommunicationPatterns(
      UDFProtoCommunicationPattern.UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED)
    .build()
  val specWithUnspecified = UDFWorkerSpecification.newBuilder()
    .setCapabilities(unspecifiedCaps)
    .setDirect(directWorker(defaultRunner))
    .build()

  val thrown = intercept[IllegalArgumentException] {
    new TestDirectWorkerDispatcher(specWithUnspecified)
  }
  assert(thrown.getMessage.contains("must not contain UNSPECIFIED"))
}

test("concurrent createSession calls produce distinct workers") {
dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner))

Expand Down Expand Up @@ -267,6 +304,7 @@ class DirectWorkerDispatcherSuite
.setGracefulTerminationTimeoutMs(500)
.build()
val spec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder()
.setRunner(runner).setProperties(shortGracefulProps).build())
.build()
Expand Down Expand Up @@ -456,6 +494,7 @@ class DirectWorkerDispatcherSuite
.setGracefulTerminationTimeoutMs(60000)
.build()
val spec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder()
.setRunner(defaultRunner).setProperties(oversizedProps).build())
.build()
Expand All @@ -475,6 +514,7 @@ class DirectWorkerDispatcherSuite
.setInitializationTimeoutMs(60000)
.build()
val spec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder()
.setRunner(defaultRunner).setProperties(oversizedProps).build())
.build()
Expand Down Expand Up @@ -566,6 +606,7 @@ class DirectWorkerDispatcherSuite

test("DirectWorker without a connection is rejected") {
val badSpec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder().setRunner(defaultRunner).build())
.build()
val ex = intercept[IllegalArgumentException] {
Expand All @@ -581,6 +622,7 @@ class DirectWorkerDispatcherSuite
.setTcp(LocalTcpConnection.getDefaultInstance).build())
.build()
val badSpec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder()
.setRunner(defaultRunner).setProperties(tcpProperties).build())
.build()
Expand Down Expand Up @@ -658,6 +700,7 @@ class DirectWorkerDispatcherSuite
.setInitializationTimeoutMs(500)
.build()
val spec = UDFWorkerSpecification.newBuilder()
.setCapabilities(defaultCapabilities)
.setDirect(DirectWorker.newBuilder()
.setRunner(hangingRunner).setProperties(shortInitProps).build())
.build()
Expand Down