From 2c2886b3f1ee5e17e3e398cceff09a141f952348 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 15:27:46 +0530 Subject: [PATCH 01/15] graceful shutdown for all UDFs Signed-off-by: Sreekanth --- .../pynumaflow/accumulator/async_server.py | 61 +++++++-- .../accumulator/servicer/async_servicer.py | 31 ++++- .../pynumaflow/batchmapper/async_server.py | 69 +++++++--- .../batchmapper/servicer/async_servicer.py | 16 ++- .../mapper/_servicer/_async_servicer.py | 23 +++- .../mapper/_servicer/_sync_servicer.py | 44 +++++-- .../pynumaflow/mapper/async_server.py | 71 +++++++--- .../pynumaflow/mapper/multiproc_server.py | 14 ++ .../pynumaflow/mapper/sync_server.py | 7 + .../pynumaflow/mapstreamer/async_server.py | 69 +++++++--- .../mapstreamer/servicer/async_servicer.py | 25 +++- .../pynumaflow/reducer/async_server.py | 65 ++++++++-- .../reducer/servicer/async_servicer.py | 29 ++++- .../reducer/servicer/task_manager.py | 15 +-- .../pynumaflow/pynumaflow/shared/server.py | 49 +------ .../pynumaflow/pynumaflow/sideinput/server.py | 6 + .../pynumaflow/sideinput/servicer/servicer.py | 10 +- .../pynumaflow/sourcer/async_server.py | 67 ++++++++-- .../sourcer/servicer/async_servicer.py | 56 ++++++-- .../sourcetransformer/async_server.py | 71 +++++++--- .../sourcetransformer/multiproc_server.py | 14 ++ .../pynumaflow/sourcetransformer/server.py | 6 + .../servicer/_async_servicer.py | 23 +++- .../sourcetransformer/servicer/_servicer.py | 42 ++++-- .../tests/map/test_sync_map_shutdown.py | 121 ++++++++++++++++++ 25 files changed, 795 insertions(+), 209 deletions(-) create mode 100644 packages/pynumaflow/tests/map/test_sync_map_shutdown.py diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index 200e4422..6a2f165d 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -1,9 +1,13 @@ +import asyncio +import 
contextlib import inspect +import sys import aiorun import grpc from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.proto.accumulator import accumulator_pb2_grpc @@ -15,6 +19,7 @@ MAX_NUM_THREADS, ACCUMULATOR_SOCK_PATH, ACCUMULATOR_SERVER_INFO_FILE_PATH, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.accumulator._dtypes import ( @@ -23,7 +28,7 @@ Accumulator, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -157,6 +162,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncAccumulatorServicer(self.accumulator_handler) + self._error: BaseException | None = None def start(self): """ @@ -167,6 +173,9 @@ def start(self): "Starting Async Accumulator Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -176,18 +185,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator] - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It remains stuck for 5 minutes until liveness and readiness probe + fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py index 16be7911..cd35962d 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py @@ -3,7 +3,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc from pynumaflow.accumulator._dtypes import ( Datum, @@ -13,7 +13,7 @@ KeyedWindow, ) from pynumaflow.accumulator.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -57,6 +57,12 @@ def __init__( ): # The accumulator handler can be a function or a builder class instance. 
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def AccumulateFn( self, @@ -104,20 +110,35 @@ async def AccumulateFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window EOF response or Window result response # back to the client else: yield msg except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return # Wait for the process_input_stream task to finish for a clean exit try: await producer except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 1078e012..4fa3221b 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -1,3 +1,7 
@@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -8,9 +12,11 @@ BATCH_MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.batchmapper._dtypes import BatchMapCallable from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -19,7 +25,7 @@ ContainerType, ) from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class BatchMapAsyncServer(NumaflowServer): @@ -92,6 +98,7 @@ async def handler( ] self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance) + self._error: BaseException | None = None def start(self): """ @@ -99,6 +106,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -108,25 +118,54 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Batch Map Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It remains stuck for 5 minutes until liveness and readiness probe + fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py index 523a4ad4..b6d866d3 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py @@ -7,7 +7,7 @@ from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING @@ -26,6 +26,12 @@ def __init__( ): self.background_tasks = set() self.__batch_map_handler: BatchMapCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -97,8 +103,12 @@ async def MapFn( await req_queue.put(datum) except BaseException as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py 
b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py index 90a55b7b..0cbf18f2 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError, Message, Messages from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -25,6 +25,12 @@ def __init__( ): self.background_tasks = set() self.__map_handler: MapAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -57,7 +63,12 @@ async def MapFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: @@ -65,8 +76,12 @@ async def MapFn( # wait for the producer task to complete await producer except BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: 
+ self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index cb757e3c..17895992 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -2,8 +2,9 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterator +import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING @@ -26,6 +27,10 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def MapFn( self, @@ -36,6 +41,7 @@ def MapFn( Applies a function to each datum element. The pascal case function name comes from the proto map_pb2_grpc.py file. 
""" + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -57,10 +63,19 @@ def MapFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. + _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -69,12 +84,23 @@ def MapFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + # Client disconnected — not a UDF error, but we still need to + # shut down the server so the process can exit cleanly. 
+ _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 5bba75d7..7cddfc57 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,7 +11,10 @@ MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,10 +25,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer class MapAsyncServer(NumaflowServer): @@ -92,6 +96,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncMapServicer(handler=mapper_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -99,32 +104,66 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), 
use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. """ - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server 
gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py index 5d68a96b..de08f075 100644 --- a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py @@ -1,4 +1,8 @@ +import multiprocessing +import sys + from pynumaflow._constants import ( + _LOGGER, NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, @@ -104,6 +108,11 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: self._process_count = min(server_count, 2 * _PROCESS_COUNT) self.servicer = SyncMapServicer(handler=mapper_instance, multiproc=True) + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. + self._shutdown_event = multiprocessing.Event() + def start(self) -> None: """ Starts the N grpc servers gRPC serves on the with given max threads. 
@@ -129,4 +138,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=server_info, + shutdown_event=self._shutdown_event, ) + + if self._shutdown_event.is_set(): + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapper/sync_server.py b/packages/pynumaflow/pynumaflow/mapper/sync_server.py index 9c2431b6..a96ceb7f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/sync_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/sync_server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -112,4 +114,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index 187c720d..b6b0fb23 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,11 +23,12 @@ _LOGGER, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.mapstreamer._dtypes import MapStreamCallable -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class MapStreamAsyncServer(NumaflowServer): @@ -111,6 +117,7 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag ] self.servicer = AsyncMapStreamServicer(handler=self.map_stream_instance) + self._error: 
BaseException | None = None def start(self): """ @@ -118,6 +125,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -127,25 +137,54 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Map Stream Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.StreamMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It remains stuck for 5 minutes until liveness and readiness probe + fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py index f5a9a999..c5aa3545 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapstreamer import Datum from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2 -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -22,6 +22,12 @@ class AsyncMapStreamServicer(map_pb2_grpc.MapServicer): def __init__(self, handler: MapStreamCallable): self.__map_stream_handler: MapStreamCallable = handler self._background_tasks: set[asyncio.Task] = set() + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -51,7 +57,12 @@ async def MapFn( # Consume results as they arrive and stream them to the client async for msg in global_result_queue.read_iterator(): if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return else: # msg is a map_pb2.MapResponse, already formed @@ -61,8 +72,12 @@ async def MapFn( await producer except 
BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( @@ -124,7 +139,7 @@ async def _invoke_map_stream( except BaseException as err: _LOGGER.critical("MapFn handler error", exc_info=True) # Surface handler error to the main producer; - # it will call handle_async_error and end the RPC + # it will set the shutdown event and end the RPC await result_queue.put(err) async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index aee4d355..ff5f9d8e 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -1,8 +1,12 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.proto.reducer import reduce_pb2_grpc @@ -15,6 +19,7 @@ _LOGGER, REDUCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.reducer._dtypes import ( @@ -23,7 +28,7 @@ Reducer, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -143,6 +148,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncReduceServicer(self.reducer_handler) + self._error: BaseException | None = None def start(self): """ @@ -153,6 +159,9 @@ def start(self): "Starting Async Reduce Server", ) aiorun.run(self.aexec(), 
use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -162,20 +171,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - reduce_servicer = self.servicer - reduce_pb2_grpc.add_ReduceServicer_to_server(reduce_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + reduce_pb2_grpc.add_ReduceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Reducer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await 
server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py index 44db8077..3ea646e3 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncIterable from google.protobuf import empty_pb2 as _empty_pb2 @@ -12,7 +13,7 @@ WindowOperation, ) from pynumaflow.reducer.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -53,6 +54,12 @@ def __init__( ): # The Reduce handler can be a function or a builder class instance. 
self.__reduce_handler: ReduceAsyncCallable | _ReduceBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def ReduceFn( self, @@ -103,9 +110,13 @@ async def ReduceFn( await task_manager.append_task(request) except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return # send EOF to all the tasks once the request iterator is exhausted # This will signal the tasks to stop reading the data on their @@ -134,9 +145,13 @@ async def ReduceFn( yield reduce_pb2.ReduceResponse(window=window, EOF=True) except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py index 2e21d60a..bfc802a7 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py +++ 
b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from collections.abc import AsyncIterable -import grpc - from pynumaflow.exceptions import UDFError from pynumaflow.proto.reducer import reduce_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator @@ -21,7 +19,7 @@ ReduceAsyncCallable, ReduceWindow, ) -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -169,14 +167,9 @@ async def __invoke_reduce( msgs = await new_instance(keys, request_iterator, md) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await asyncio.gather( - self.context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), - return_exceptions=True, - ) - exit_on_error(err=repr(err), parent=False, context=self.context, update_context=False) - return + err_msg = f"ReduceError: {repr(err)}" + update_context_err(self.context, err, err_msg) + raise datum_responses = [] for msg in msgs: diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 3986e0dc..64a2fd03 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import io import multiprocessing @@ -14,7 +13,6 @@ from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor import grpc -import psutil from pynumaflow._constants import ( _LOGGER, @@ -90,7 +88,7 @@ def _run_server( udf_type: str, server_info_file: str | None = None, server_info: ServerInfo | None = None, - shutdown_event: threading.Event | None = None, + shutdown_event: threading.Event | multiprocessing.Event | None = None, ) -> None: """ Starts the 
Synchronous server instance on the given UNIX socket @@ -151,6 +149,7 @@ def start_multiproc_server( server_info: ServerInfo | None = None, server_options=None, udf_type: str = UDFType.Map, + shutdown_event: multiprocessing.Event | None = None, ): """ Start N grpc servers in different processes where N = The number of CPUs or the @@ -179,6 +178,7 @@ def start_multiproc_server( worker = multiprocessing.Process( target=_run_server, args=(servicer, bind_address, max_threads, server_options, udf_type), + kwargs={"shutdown_event": shutdown_event} if shutdown_event else {}, ) worker.start() workers.append(worker) @@ -278,37 +278,6 @@ def get_grpc_status(err: str, detail: str | None = None): return rpc_status.to_status(status) -def exit_on_error( - context: NumaflowServicerContext, err: str, parent: bool = False, update_context=True -): - """ - Exit the current/parent process on an error. - - Args: - context (NumaflowServicerContext): The gRPC context. - err (str): The error message. - parent (bool, optional): Whether this is the parent process. - Defaults to False. - update_context(bool, optional) : Is there a need to update - the context with the error codes - """ - if update_context: - # Create a status object with the error details - grpc_status = get_grpc_status(err) - - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(err) - context.set_trailing_metadata(grpc_status.trailing_metadata) - - p = psutil.Process(os.getpid()) - # If the parent flag is true, we exit from the parent process - # Use this for Multiproc right now to exit from the parent fork - if parent: - p = psutil.Process(os.getppid()) - _LOGGER.info("Killing process: Got exception %s", err) - p.kill() - - def update_context_err(context: NumaflowServicerContext, e: BaseException, err_msg: str): """ Update the context with the error and log the exception. 
@@ -330,15 +299,3 @@ def get_exception_traceback_str(exc) -> str: return file.getvalue().rstrip() -async def handle_async_error( - context: NumaflowServicerContext, exception: BaseException, exception_type: str -): - """ - Handle exceptions for async servers by updating the context and exiting. - """ - err_msg = f"{exception_type}: {repr(exception)}" - update_context_err(context, exception, err_msg) - await asyncio.gather( - context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True - ) - exit_on_error(err=err_msg, parent=False, context=context, update_context=False) diff --git a/packages/pynumaflow/pynumaflow/sideinput/server.py b/packages/pynumaflow/pynumaflow/sideinput/server.py index 7bb27b86..445e23e1 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/server.py +++ b/packages/pynumaflow/pynumaflow/sideinput/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.shared import NumaflowServer from pynumaflow.shared.server import sync_server_start @@ -99,4 +101,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SideInput, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py index 7f46bf68..da836948 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py +++ b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py @@ -1,3 +1,5 @@ +import threading + from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow._constants import ( @@ -5,7 +7,7 @@ ERR_UDF_EXCEPTION_STRING, ) from pynumaflow.proto.sideinput import sideinput_pb2_grpc, sideinput_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server 
import update_context_err from pynumaflow.sideinput._dtypes import RetrieverCallable from pynumaflow.types import NumaflowServicerContext @@ -16,6 +18,8 @@ def __init__( handler: RetrieverCallable, ): self.__retrieve_handler: RetrieverCallable = handler + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def RetrieveSideInput( self, request: _empty_pb2.Empty, context: NumaflowServicerContext @@ -30,7 +34,9 @@ def RetrieveSideInput( except BaseException as err: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) + update_context_err(context, err, err_msg) + self.error = err + self.shutdown_event.set() return return sideinput_pb2.SideInputResponse(value=rspn.value, no_broadcast=rspn.no_broadcast) diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 3bca9dfb..2f54b158 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer @@ -10,10 +15,12 @@ NUM_THREADS_DEFAULT, SOURCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.proto.sourcer import source_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer from pynumaflow.sourcer._dtypes import SourceCallable @@ -153,6 +160,7 @@ async def partitions_handler(self) -> PartitionsResponse: ] self.servicer = AsyncSourceServicer(source_handler=sourcer_instance) + self._error: BaseException | None = None 
def start(self): """ @@ -160,6 +168,9 @@ def start(self): so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -168,20 +179,52 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - source_servicer = self.servicer - source_pb2_grpc.add_SourceServicer_to_server(source_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + source_pb2_grpc.add_SourceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sourcer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 3e0839c4..0f8a4db9 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -5,7 +5,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcer import ReadRequest, Offset, NackRequest, AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 from pynumaflow.proto.sourcer import source_pb2_grpc @@ -71,6 +71,12 @@ def __init__(self, source_handler: SourceCallable): self.source_handler = source_handler self.__initialize_handlers() self.cleanup_coroutines = [] + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event def __initialize_handlers(self): """Initialize handler methods from the provided source handler.""" @@ -110,7 +116,12 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(resp)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, resp, err_msg) + self._error = resp + if self._shutdown_event is not None: + self._shutdown_event.set() return yield _create_read_response(resp) @@ -121,7 +132,12 @@ async def ReadFn( yield _create_eot_response() except 
BaseException as err: _LOGGER.critical("User-Defined Source ReadFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def __invoke_read( self, req: source_pb2.ReadRequest, niter: NonBlockingIterator[Message | Exception] @@ -169,7 +185,12 @@ async def AckFn( yield _create_ack_response() except BaseException as err: _LOGGER.critical("User-Defined Source AckFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def NackFn( self, @@ -186,7 +207,12 @@ async def NackFn( await self.__source_nack_handler(NackRequest(offsets=offsets)) except BaseException as err: _LOGGER.critical("User-Defined Source NackFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return return source_pb2.NackResponse( result=source_pb2.NackResponse.Result(success=_empty_pb2.Empty()) ) @@ -211,8 +237,14 @@ async def PendingFn( count = await self.__source_pending_handler() except BaseException as err: _LOGGER.critical("PendingFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PendingResponse( + result=source_pb2.PendingResponse.Result(count=0) + ) 
resp = source_pb2.PendingResponse.Result(count=count.count) return source_pb2.PendingResponse(result=resp) @@ -226,8 +258,14 @@ async def PartitionsFn( partitions = await self.__source_partitions_handler() except BaseException as err: _LOGGER.critical("PartitionsFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PartitionsResponse( + result=source_pb2.PartitionsResponse.Result(partitions=[]) + ) resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) return source_pb2.PartitionsResponse(result=resp) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 990e4587..16ce1496 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,17 +11,17 @@ MAX_NUM_THREADS, SOURCE_TRANSFORMER_SOCK_PATH, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType, ) from pynumaflow.proto.sourcetransformer import transform_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer @@ -115,6 +119,7 @@ def __init__( ("grpc.max_receive_message_length", self.max_message_size), ] 
self.servicer = SourceTransformAsyncServicer(handler=source_transform_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -122,32 +127,66 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. """ - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ ContainerType.Sourcetransformer ] - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + await server.wait_for_termination() + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_event_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py index dbc8b7b5..f1ff372e 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py @@ -1,9 +1,13 @@ +import multiprocessing +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer from pynumaflow.shared.server import start_multiproc_server from pynumaflow._constants import ( + _LOGGER, MAX_MESSAGE_SIZE, SOURCE_TRANSFORMER_SOCK_PATH, NUM_THREADS_DEFAULT, @@ -129,6 +133,11 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: self._process_count = min(server_count, 2 * _PROCESS_COUNT) self.servicer = SourceTransformServicer(handler=source_transform_instance, multiproc=True) + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. + self._shutdown_event = multiprocessing.Event() + def start(self): """ Starts the N gRPC servers on the given socket path with given max threads. 
@@ -148,4 +157,9 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self._shutdown_event, ) + + if self._shutdown_event.is_set(): + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py index 7069e2b6..c410adce 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ContainerType, MINIMUM_NUMAFLOW_VERSION, ServerInfo from pynumaflow._constants import ( MAX_MESSAGE_SIZE, @@ -128,4 +130,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py index da7384c2..d85fcebd 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.types import NumaflowServicerContext @@ -28,6 +28,12 @@ def __init__( ): self.background_tasks = 
set() self.__transform_handler: SourceTransformAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def SourceTransformFn( self, @@ -61,7 +67,12 @@ async def SourceTransformFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: @@ -69,8 +80,12 @@ async def SourceTransformFn( # wait for the producer task to complete await producer except BaseException as e: - _LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 2091bf47..1b93b96c 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -2,10 +2,11 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterable +import grpc from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from 
pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.shared.synciter import SyncIterator from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable @@ -46,6 +47,10 @@ def __init__(self, handler: SourceTransformCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def SourceTransformFn( self, @@ -56,6 +61,8 @@ def SourceTransformFn( Applies a function to each datum element. The pascal case function name comes from the generated transform_pb2_grpc.py file. """ + # Initialize before try so it's accessible in except blocks + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -78,10 +85,18 @@ def SourceTransformFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. 
+ _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -90,12 +105,21 @@ def SourceTransformFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py new file mode 100644 index 00000000..d8df92f0 --- /dev/null +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -0,0 +1,121 @@ +""" +Shutdown-event tests for the synchronous Map servicer. + +Mirrors the sinker shutdown test pattern (tests/sink/test_server.py lines 345-461). +Each test verifies that the servicer sets shutdown_event (and optionally captures the +error) under a specific failure mode, enabling graceful server stop via the watcher +thread in _run_server() instead of a hard process kill. 
+""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.mapper import Datum, Messages, Message +from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SyncMapServicer(handler=err_map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SyncMapServicer(handler=map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + # Send a data message without a handshake first + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + 
break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "MapFn: expected handshake as the first message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + servicer = SyncMapServicer(handler=map_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + servicer = SyncMapServicer(handler=map_handler) + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + # Should have at least the handshake response + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None From aa9496f301997d2ae8982be72138f6cccfc83842 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 15:39:37 +0530 Subject: [PATCH 02/15] file formatting Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/shared/server.py | 2 -- .../pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py | 4 +--- packages/pynumaflow/tests/map/test_sync_map_shutdown.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py 
b/packages/pynumaflow/pynumaflow/shared/server.py index 64a2fd03..d2ba610a 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -297,5 +297,3 @@ def get_exception_traceback_str(exc) -> str: file = io.StringIO() traceback.print_exception(exc, value=exc, tb=exc.__traceback__, file=file) return file.getvalue().rstrip() - - diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 0f8a4db9..1929e536 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -242,9 +242,7 @@ async def PendingFn( self._error = err if self._shutdown_event is not None: self._shutdown_event.set() - return source_pb2.PendingResponse( - result=source_pb2.PendingResponse.Result(count=0) - ) + return source_pb2.PendingResponse(result=source_pb2.PendingResponse.Result(count=0)) resp = source_pb2.PendingResponse.Result(count=count.count) return source_pb2.PendingResponse(result=resp) diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py index d8df92f0..cf8523c1 100644 --- a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -13,7 +13,6 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.mapper import Datum, Messages, Message from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, get_test_datums From 6c37cb5e1591bdb25b3cf46c0f7fb541d765cf13 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Mon, 16 Mar 2026 20:33:15 +0530 Subject: [PATCH 03/15] fix CI Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/shared/server.py | 5 
+++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index d2ba610a..e5aec76c 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -1,6 +1,7 @@ import contextlib import io import multiprocessing +import multiprocessing.synchronize import os import socket import threading @@ -88,7 +89,7 @@ def _run_server( udf_type: str, server_info_file: str | None = None, server_info: ServerInfo | None = None, - shutdown_event: threading.Event | multiprocessing.Event | None = None, + shutdown_event: threading.Event | None = None, ) -> None: """ Starts the Synchronous server instance on the given UNIX socket @@ -149,7 +150,7 @@ def start_multiproc_server( server_info: ServerInfo | None = None, server_options=None, udf_type: str = UDFType.Map, - shutdown_event: multiprocessing.Event | None = None, + shutdown_event: multiprocessing.synchronize.Event | None = None, ): """ Start N grpc servers in different processes where N = The number of CPUs or the From 72849eceb521bd343ff81191c3180d4709bea5c0 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 05:42:54 +0530 Subject: [PATCH 04/15] tested graceful shutdown for source Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/accumulator/async_server.py | 9 ++++++++- .../pynumaflow/pynumaflow/batchmapper/async_server.py | 9 ++++++++- packages/pynumaflow/pynumaflow/mapper/async_server.py | 9 ++++++++- .../pynumaflow/pynumaflow/mapstreamer/async_server.py | 9 ++++++++- packages/pynumaflow/pynumaflow/reducer/async_server.py | 9 ++++++++- .../pynumaflow/pynumaflow/reducestreamer/async_server.py | 9 ++++++++- packages/pynumaflow/pynumaflow/sinker/async_server.py | 9 ++++++++- packages/pynumaflow/pynumaflow/sourcer/async_server.py | 9 ++++++++- .../pynumaflow/sourcetransformer/async_server.py | 9 ++++++++- 9 files changed, 72 insertions(+), 9 deletions(-) 
diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index 6a2f165d..36146e68 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -217,7 +217,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 4fa3221b..78d9f12e 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -152,7 +152,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 7cddfc57..9078e5f4 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -150,7 +150,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index b6b0fb23..2c833d43 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -171,7 +171,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index ff5f9d8e..cc52d15e 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -203,7 +203,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 63123c88..5ac14455 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -220,7 +220,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/sinker/async_server.py b/packages/pynumaflow/pynumaflow/sinker/async_server.py index 516fbb82..d2c3c0b8 100644 --- a/packages/pynumaflow/pynumaflow/sinker/async_server.py +++ b/packages/pynumaflow/pynumaflow/sinker/async_server.py @@ -174,7 +174,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 2f54b158..b8ef4965 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -211,7 +211,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 16ce1496..ee25c0af 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -173,7 +173,14 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. We must stop + # the gRPC server explicitly so its __del__ doesn't try to + # schedule a coroutine on the already-closed event loop. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error From 3fd5398387d240d75a992279a1fa5858f9b741a1 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 05:49:49 +0530 Subject: [PATCH 05/15] tested graceful shutdown from sync map Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py | 5 +++++ .../pynumaflow/sourcetransformer/servicer/_servicer.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index 17895992..9415f0a4 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -118,6 +118,11 @@ def _process_requests( self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process result_queue.put(STREAM_EOF) + except grpc.RpcError as e: + # Client disconnected — expected during pod shutdown. + # Surface to the consumer which will trigger graceful shutdown. 
+ _LOGGER.warning("gRPC stream closed in reader thread") + result_queue.put(e) except BaseException as e: _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) # Surface the error to the consumer; MapFn will handle and exit diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 1b93b96c..7508e41f 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -137,6 +137,11 @@ def _process_requests( self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process result_queue.put(STREAM_EOF) + except grpc.RpcError as e: + # Client disconnected — expected during pod shutdown. + # Surface to the consumer which will trigger graceful shutdown. + _LOGGER.warning("gRPC stream closed in reader thread") + result_queue.put(e) except BaseException as e: _LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True) # Surface the error to the consumer; SourceTransformFn will handle and exit From 3de2701c1c5dc1929beb8317007f46b909027000 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 05:59:56 +0530 Subject: [PATCH 06/15] fixes Signed-off-by: Sreekanth --- .../accumulator/servicer/async_servicer.py | 14 +++++++ .../batchmapper/servicer/async_servicer.py | 7 ++++ .../mapper/_servicer/_async_servicer.py | 7 ++++ .../mapstreamer/servicer/async_servicer.py | 7 ++++ .../reducer/servicer/async_servicer.py | 14 +++++++ .../sinker/servicer/async_servicer.py | 7 ++++ .../sourcer/servicer/async_servicer.py | 37 +++++++++++++++++++ .../servicer/_async_servicer.py | 7 ++++ 8 files changed, 100 insertions(+) diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py index cd35962d..886f0e66 100644 --- 
a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py @@ -121,6 +121,13 @@ async def AccumulateFn( # back to the client else: yield msg + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) @@ -132,6 +139,13 @@ async def AccumulateFn( # Wait for the process_input_stream task to finish for a clean exit try: await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py index b6d866d3..63733961 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py @@ -102,6 +102,13 @@ async def MapFn( ) await req_queue.put(datum) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py index 0cbf18f2..df0265e4 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py @@ -75,6 +75,13 @@ async def MapFn( yield msg # wait for the producer task to complete await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py index c5aa3545..77942b71 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -71,6 +71,13 @@ async def MapFn( # Ensure producer has finished (covers graceful shutdown) await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py index 3ea646e3..7b147064 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py @@ -108,6 +108,13 @@ async def ReduceFn( # append the task data to the existing task # if the task does not exist, it will create a new task await task_manager.append_task(request) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" @@ -143,6 +150,13 @@ async def ReduceFn( for window in current_window.values(): # yield the EOF response once the task is completed for a keyed window yield reduce_pb2.ReduceResponse(window=window, EOF=True) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py index 908aa126..5eade634 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py @@ -87,6 +87,13 @@ async def SinkFn( # if we have a valid message, we will add it to the request queue for processing. datum = datum_from_sink_req(d) await req_queue.put(datum) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 1929e536..51cd74c2 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -130,6 +130,13 @@ async def ReadFn( await task # send an eot to signal all messages have been processed. yield _create_eot_response() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source ReadFn error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" @@ -183,6 +190,13 @@ async def AckFn( ] await self.__source_ack_handler(AckRequest(offsets=offsets)) yield _create_ack_response() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source AckFn error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" @@ -205,6 +219,13 @@ async def NackFn( Offset(offset.offset, offset.partition_id) for offset in request.request.offsets ] await self.__source_nack_handler(NackRequest(offsets=offsets)) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source NackFn error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" @@ -235,6 +256,13 @@ async def PendingFn( """ try: count = await self.__source_pending_handler() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PendingResponse(result=source_pb2.PendingResponse.Result(count=0)) + except BaseException as err: _LOGGER.critical("PendingFn Error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" @@ -254,6 +282,15 @@ async def PartitionsFn( """ try: partitions = await self.__source_partitions_handler() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PartitionsResponse( + result=source_pb2.PartitionsResponse.Result(partitions=[]) + ) + except BaseException as err: _LOGGER.critical("PartitionsFn Error", exc_info=True) err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py index d85fcebd..819c27c3 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py @@ -79,6 +79,13 @@ async def SourceTransformFn( yield msg # wait for the producer task to complete await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) From d8c1eace9976cbd176e5e388dc6704d4a0c62b19 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 06:12:25 +0530 Subject: [PATCH 07/15] multiproc map clean shutdown Signed-off-by: Sreekanth --- .../pynumaflow/accumulator/servicer/task_manager.py | 9 +++++++++ .../pynumaflow/reducer/servicer/task_manager.py | 3 +++ packages/pynumaflow/pynumaflow/shared/server.py | 12 ++++++++++++ 3 files changed, 24 insertions(+) diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py index a9758bf7..ee14e3ea 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py @@ -213,6 +213,9 @@ async def __invoke_accumulator( _ = await new_instance(request_iterator, output) # send EOF to the output stream await output.put(STREAM_EOF) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + return # If there is an error in the accumulator operation, log and # then send the error to the result queue except BaseException as err: @@ -243,6 +246,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator case _: _LOGGER.debug(f"No operation matched for request: {request}", exc_info=True) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ return # If there is an error in the accumulator operation, log and # then send the error to the result queue except BaseException as e: @@ -274,6 +280,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator # Now send STREAM_EOF to terminate the global result queue iterator await self.global_result_queue.put(STREAM_EOF) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + return except BaseException as e: err_msg = f"Accumulator Streaming Error: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py index bfc802a7..3023a706 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py @@ -165,6 +165,9 @@ async def __invoke_reduce( new_instance = self.__reduce_handler.create() try: msgs = await new_instance(keys, request_iterator, md) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ raise except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) err_msg = f"ReduceError: {repr(err)}" diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index e5aec76c..848cae14 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -3,6 +3,7 @@ import multiprocessing import multiprocessing.synchronize import os +import signal import socket import threading import traceback @@ -191,6 +192,17 @@ def start_multiproc_server( server_info.metadata[MULTIPROC_KEY] = str(process_count) info_server_write(server_info=server_info, info_file=server_info_file) + # Register a SIGTERM handler so that kubectl delete triggers graceful + # shutdown of all child workers via the shared multiprocessing.Event, + # instead of the default abrupt kill. + if shutdown_event is not None: + + def _sigterm_handler(signum, frame): + _LOGGER.info("SIGTERM received, signalling workers to shut down...") + shutdown_event.set() + + signal.signal(signal.SIGTERM, _sigterm_handler) + for worker in workers: worker.join() From 18b3d879a5868221b2e5ebdbc24bee4dc5ea86f7 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 06:20:03 +0530 Subject: [PATCH 08/15] tested all for graceful shutdown on kubectl delete Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/mapper/multiproc_server.py | 4 ++-- packages/pynumaflow/pynumaflow/shared/server.py | 4 ++++ .../pynumaflow/sourcetransformer/multiproc_server.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py index de08f075..4055666f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py @@ -130,7 +130,7 @@ def start(self) -> None: server_info.metadata[MAP_MODE_KEY] = 
MapMode.UnaryMap # Start the multiproc server - start_multiproc_server( + has_error = start_multiproc_server( max_threads=self.max_threads, servicer=self.servicer, process_count=self._process_count, @@ -141,6 +141,6 @@ def start(self) -> None: shutdown_event=self._shutdown_event, ) - if self._shutdown_event.is_set(): + if has_error: _LOGGER.critical("Server exiting due to worker error") sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 848cae14..0c29f1e0 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -206,6 +206,10 @@ def _sigterm_handler(signum, frame): for worker in workers: worker.join() + # Return True if any worker exited with a non-zero code (i.e. a real error, + # not a clean SIGTERM shutdown). + return any(w.exitcode != 0 for w in workers) + async def start_async_server( server_async: grpc.aio.Server, diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py index f1ff372e..fba85017 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py @@ -149,7 +149,7 @@ def start(self): serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ ContainerType.Sourcetransformer ] - start_multiproc_server( + has_error = start_multiproc_server( max_threads=self.max_threads, servicer=self.servicer, process_count=self._process_count, @@ -160,6 +160,6 @@ def start(self): shutdown_event=self._shutdown_event, ) - if self._shutdown_event.is_set(): + if has_error: _LOGGER.critical("Server exiting due to worker error") sys.exit(1) From 000324314f70189ade18d32407e3aee212b613ed Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 11:01:27 +0530 Subject: [PATCH 09/15] minor changes Signed-off-by: Sreekanth --- 
packages/pynumaflow/pynumaflow/shared/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 0c29f1e0..943f3cee 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -187,7 +187,9 @@ def start_multiproc_server( if server_info is None: server_info = ServerInfo.get_default_server_info() - server_info.metadata = get_metadata_env(envs=METADATA_ENVS) + # Merge env metadata into existing metadata (preserving caller-set keys + # like MAP_MODE_KEY) rather than overwriting the entire dict. + server_info.metadata.update(get_metadata_env(envs=METADATA_ENVS)) # Add the MULTIPROC metadata using the number of servers to use server_info.metadata[MULTIPROC_KEY] = str(process_count) info_server_write(server_info=server_info, info_file=server_info_file) From 4f412c29458f51883c2404e61b45ce68157711a5 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 11:16:00 +0530 Subject: [PATCH 10/15] unit tests Signed-off-by: Sreekanth --- .../test_async_accumulator_shutdown.py | 199 +++++++++++++++++ .../batchmap/test_async_batch_map_shutdown.py | 83 +++++++ .../tests/map/test_async_map_shutdown.py | 107 +++++++++ .../test_async_map_stream_shutdown.py | 84 +++++++ .../reduce/test_async_reduce_shutdown.py | 111 ++++++++++ .../tests/sink/test_async_sink_shutdown.py | 85 ++++++++ .../source/test_async_source_shutdown.py | 205 ++++++++++++++++++ .../sourcetransform/test_sync_shutdown.py | 128 +++++++++++ 8 files changed, 1002 insertions(+) create mode 100644 packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py create mode 100644 packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py create mode 100644 packages/pynumaflow/tests/map/test_async_map_shutdown.py create mode 100644 packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py create mode 100644 
packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py create mode 100644 packages/pynumaflow/tests/sink/test_async_sink_shutdown.py create mode 100644 packages/pynumaflow/tests/source/test_async_source_shutdown.py create mode 100644 packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py new file mode 100644 index 00000000..5a7020fa --- /dev/null +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py @@ -0,0 +1,199 @@ +""" +Shutdown-event tests for the async Accumulator servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during consumer iteration (SIGTERM scenario) +2. BaseException during consumer iteration (unexpected error) +3. CancelledError during producer await (SIGTERM scenario) +4. BaseException during producer await (producer task error) +5. Exception object yielded from the result queue (task_manager error) +""" + +import asyncio +from unittest import mock + +from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer +from pynumaflow.shared.asynciter import NonBlockingIterator + + +async def noop_handler(datums, output): + async for _ in datums: + pass + + +async def _empty_request_iter(): + return + yield # make it an async generator + + +async def _collect(async_gen): + """Collect all items from an async generator.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_event_on_consumer_cancelled_error(): + """CancelledError while reading the result queue (e.g. 
SIGTERM) should + set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _cancelled_reader(): + raise asyncio.CancelledError() + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _cancelled_reader() + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_event_on_consumer_base_exception(): + """BaseException on the result queue should set shutdown_event AND store the error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _error_reader(): + raise RuntimeError("unexpected consumer error") + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _error_reader() + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "unexpected consumer error" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_event_on_producer_cancelled_error(): + """CancelledError when awaiting the producer task should set shutdown_event + but NOT store an error.""" + + async def _run(): + servicer = 
AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _empty_reader(): + return + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _empty_reader() + + async def _cancelled_producer(_): + raise asyncio.CancelledError() + + mock_task_manager.process_input_stream = _cancelled_producer + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_event_on_producer_base_exception(): + """BaseException from the producer task should set shutdown_event AND store the error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _empty_reader(): + return + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _empty_reader() + + async def _error_producer(_): + raise RuntimeError("producer blew up") + + mock_task_manager.process_input_stream = _error_producer + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "producer blew up" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_event_on_result_queue_exception_message(): + """When the result queue yields a BaseException (from task_manager), + shutdown_event should be set and the error stored.""" + + async def _run(): + servicer = 
AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _exception_in_queue(): + yield RuntimeError("handler error from task manager") + + mock_task_manager.global_result_queue.read_iterator.return_value = ( + _exception_in_queue() + ) + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler error from task manager" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py new file mode 100644 index 00000000..b759f194 --- /dev/null +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py @@ -0,0 +1,83 @@ +""" +Shutdown-event tests for the async BatchMap servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during MapFn iteration (SIGTERM scenario) +2. 
Handler RuntimeError caught by the outer BaseException handler +""" + +import asyncio +from unittest import mock + +from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.batchmapper import BatchResponses +from pynumaflow.proto.mapper import map_pb2 +from tests.batchmap.utils import request_generator + + +async def _noop_handler(datums) -> BatchResponses: + async for _ in datums: + pass + return BatchResponses() + + +async def _err_handler(datums) -> BatchResponses: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncBatchMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError caught by the outer BaseException handler; shutdown_event + is set and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncBatchMapServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + EOT + sync_datums = list(request_generator(count=2, session=1, handshake=True)) + + async def _request_iter(): + for d in sync_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + 
assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/map/test_async_map_shutdown.py b/packages/pynumaflow/tests/map/test_async_map_shutdown.py new file mode 100644 index 00000000..af77a76f --- /dev/null +++ b/packages/pynumaflow/tests/map/test_async_map_shutdown.py @@ -0,0 +1,107 @@ +""" +Shutdown-event tests for the async Map servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during MapFn iteration (SIGTERM scenario) +2. Handler RuntimeError surfaced via result queue +3. Bad handshake raising MapError +""" + +import asyncio +from unittest import mock + +from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer +from pynumaflow.mapper._dtypes import Messages, Message, Datum +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import get_test_datums + + +async def _noop_handler(keys: list[str], datum: Datum) -> Messages: + return Messages(Message(b"ok", keys=keys)) + + +async def _err_handler(keys: list[str], datum: Datum) -> Messages: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def 
test_shutdown_on_handler_error(): + """Handler RuntimeError surfaces via the result queue; shutdown_event is set + and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + one datum + test_datums = get_test_datums(handshake=True) + + async def _request_iter(): + for d in test_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_on_handshake_error(): + """Missing handshake raises MapError; shutdown_event is set and the error is stored.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Send data messages without a handshake first + test_datums = get_test_datums(handshake=False) + + async def _request_iter(): + for d in test_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "expected handshake" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py new file mode 100644 index 00000000..317f2581 --- /dev/null +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py @@ -0,0 +1,84 @@ +""" +Shutdown-event tests for the async MapStream servicer. + +Tests verify that the servicer correctly handles: +1. 
CancelledError during MapFn iteration (SIGTERM scenario) +2. Handler RuntimeError surfaced via result queue +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer +from pynumaflow.mapstreamer._dtypes import Message +from pynumaflow.mapstreamer import Datum +from pynumaflow.proto.mapper import map_pb2 +from tests.mapstream.utils import request_generator + + +async def _noop_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: + yield Message(b"ok", keys=keys) + + +async def _err_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: + raise RuntimeError("handler blew up") + yield # make it an async generator + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncMapStreamServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError surfaces via the result queue; shutdown_event is set + and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncMapStreamServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + sync_datums = list(request_generator(count=2, handshake=True)) + + async def _request_iter(): 
+ for d in sync_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py new file mode 100644 index 00000000..b13b6502 --- /dev/null +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py @@ -0,0 +1,111 @@ +""" +Shutdown-event tests for the async Reduce servicer. + +The ReduceFn method has two try/except blocks where we added shutdown handling: + 1. During request iteration (datum_generator loop) + 2. During task result collection (awaiting futures) + +Tests verify that: + - CancelledError sets shutdown_event but does NOT store an error + - Handler errors set shutdown_event AND store the error +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.reducer.servicer.async_servicer import AsyncReduceServicer +from pynumaflow.reducer._dtypes import ( + Datum, + Messages, + Message, + Metadata, + Reducer, +) +from pynumaflow.proto.reducer import reduce_pb2 + + +# --------------------------------------------------------------------------- +# Minimal handler — never raises on its own. 
+# --------------------------------------------------------------------------- + +class _StubReducer(Reducer): + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + async for _ in datums: + pass + return Messages(Message(b"done", keys=keys)) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _collect(async_gen): + """Drain an async generator and return the collected items.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +# --------------------------------------------------------------------------- +# Test 1: CancelledError during request iteration +# +# We feed the servicer a request iterator that raises CancelledError. +# This exercises the first except block in ReduceFn. +# --------------------------------------------------------------------------- + +def test_shutdown_on_cancelled_error(): + """CancelledError during request iteration should set shutdown_event + but NOT store an error.""" + + async def _run(): + servicer = AsyncReduceServicer(handler=_StubReducer) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # An async iterator that immediately raises CancelledError, + # simulating SIGTERM arriving while reading from the gRPC stream. + async def _cancelled_request_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.ReduceFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Test 2: Handler raises a real error during request iteration +# +# We feed a request iterator that raises RuntimeError. +# This exercises the BaseException except block in the first try. 
+# --------------------------------------------------------------------------- + +def test_shutdown_on_handler_error(): + """A real exception during request iteration should set shutdown_event + AND store the error.""" + + async def _run(): + servicer = AsyncReduceServicer(handler=_StubReducer) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _error_request_iter(): + raise RuntimeError("reduce iteration blew up") + yield + + ctx = mock.MagicMock() + await _collect(servicer.ReduceFn(_error_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "reduce iteration blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py new file mode 100644 index 00000000..3b57b179 --- /dev/null +++ b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py @@ -0,0 +1,85 @@ +""" +Shutdown-event tests for the async Sink servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during SinkFn iteration (SIGTERM scenario) +2. 
Handler RuntimeError caught by the outer BaseException handler +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.sinker.servicer.async_servicer import AsyncSinkServicer +from pynumaflow.sinker import Datum, Responses, Response +from pynumaflow.proto.sinker import sink_pb2 +from tests.sink.test_async_sink import request_generator + + +async def _noop_handler(datums: AsyncIterable[Datum]) -> Responses: + responses = Responses() + async for msg in datums: + responses.append(Response.as_success(msg.id)) + return responses + + +async def _err_handler(datums: AsyncIterable[Datum]) -> Responses: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during SinkFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSinkServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + responses = await _collect(servicer.SinkFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError caught by the outer BaseException handler; shutdown_event + is set and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncSinkServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + EOT + sync_datums = list(request_generator(count=2, req_type="success", session=1, handshake=True)) + + async def _request_iter(): + for d in sync_datums: 
+ yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.SinkFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/source/test_async_source_shutdown.py b/packages/pynumaflow/tests/source/test_async_source_shutdown.py new file mode 100644 index 00000000..d05dad7d --- /dev/null +++ b/packages/pynumaflow/tests/source/test_async_source_shutdown.py @@ -0,0 +1,205 @@ +""" +Shutdown-event tests for the async Source servicer. + +Each test verifies that when a CancelledError is raised during an RPC +(simulating a SIGTERM-triggered task cancellation), the servicer: + - sets the shutdown_event + - does NOT store an error (CancelledError is not a UDF fault) +""" + +import asyncio +from unittest import mock + +from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer +from pynumaflow.sourcer._dtypes import ( + Sourcer, + ReadRequest, + AckRequest, + NackRequest, + PendingResponse, + PartitionsResponse, +) +from pynumaflow.shared.asynciter import NonBlockingIterator +from pynumaflow.proto.sourcer import source_pb2 + + +# --------------------------------------------------------------------------- +# Minimal handler that never raises on its own — individual tests will +# override specific methods or inject CancelledError via the request stream. 
+# --------------------------------------------------------------------------- + +class _StubSource(Sourcer): + async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): + pass + + async def ack_handler(self, ack_request: AckRequest): + pass + + async def nack_handler(self, nack_request: NackRequest): + pass + + async def pending_handler(self) -> PendingResponse: + return PendingResponse(count=0) + + async def partitions_handler(self) -> PartitionsResponse: + return PartitionsResponse(partitions=[]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _collect(async_gen): + """Drain an async generator and return the collected items.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +# --------------------------------------------------------------------------- +# ReadFn — streaming RPC. CancelledError raised from the request iterator +# after the handshake has been sent. +# --------------------------------------------------------------------------- + +def test_shutdown_on_read_cancelled_error(): + """CancelledError during ReadFn request iteration should set + shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSourceServicer(source_handler=_StubSource()) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build a request iterator that sends the handshake then raises + # CancelledError on the next iteration (simulating SIGTERM). 
+ async def _cancelled_request_iter(): + yield source_pb2.ReadRequest(handshake=source_pb2.Handshake(sot=True)) + raise asyncio.CancelledError() + + ctx = mock.MagicMock() + await _collect(servicer.ReadFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# AckFn — streaming RPC. Same pattern as ReadFn. +# --------------------------------------------------------------------------- + +def test_shutdown_on_ack_cancelled_error(): + """CancelledError during AckFn request iteration should set + shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSourceServicer(source_handler=_StubSource()) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_request_iter(): + yield source_pb2.AckRequest(handshake=source_pb2.Handshake(sot=True)) + raise asyncio.CancelledError() + + ctx = mock.MagicMock() + await _collect(servicer.AckFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# NackFn — unary RPC. We make the handler raise CancelledError. +# --------------------------------------------------------------------------- + +def test_shutdown_on_nack_cancelled_error(): + """CancelledError during NackFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_nack(nack_request): + raise asyncio.CancelledError() + + handler.nack_handler = _cancelled_nack + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build a valid NackRequest proto with at least one offset. 
+ offset = source_pb2.Offset(offset=b"test", partition_id=0) + request = source_pb2.NackRequest( + request=source_pb2.NackRequest.Request(offsets=[offset]) + ) + ctx = mock.MagicMock() + # NackFn is a coroutine (not an async generator), so we await it. + await servicer.NackFn(request, ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# PendingFn — unary RPC. Handler raises CancelledError. +# --------------------------------------------------------------------------- + +def test_shutdown_on_pending_cancelled_error(): + """CancelledError during PendingFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_pending(): + raise asyncio.CancelledError() + + handler.pending_handler = _cancelled_pending + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + ctx = mock.MagicMock() + await servicer.PendingFn(mock.MagicMock(), ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# PartitionsFn — unary RPC. Handler raises CancelledError. 
+# --------------------------------------------------------------------------- + +def test_shutdown_on_partitions_cancelled_error(): + """CancelledError during PartitionsFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_partitions(): + raise asyncio.CancelledError() + + handler.partitions_handler = _cancelled_partitions + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + ctx = mock.MagicMock() + await servicer.PartitionsFn(mock.MagicMock(), ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py new file mode 100644 index 00000000..6045cf6a --- /dev/null +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py @@ -0,0 +1,128 @@ +""" +Shutdown-event tests for the synchronous SourceTransform servicer. + +Mirrors the mapper shutdown test pattern (tests/map/test_sync_map_shutdown.py). +Each test verifies that the servicer sets shutdown_event (and optionally captures the +error) under a specific failure mode, enabling graceful server stop via the watcher +thread in _run_server() instead of a hard process kill. 
+""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer +from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SourceTransformServicer(handler=err_transform_handler) + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SourceTransformServicer(handler=transform_handler) + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + # Send a data message without a handshake first + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=1, + ) + 
+ for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "SourceTransformFn: expected handshake message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + servicer = SourceTransformServicer(handler=transform_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + servicer = SourceTransformServicer(handler=transform_handler) + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + # Should have at least the handshake response + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None From 1b4db2180fa14b2f064314b4b55d757e96be6384 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 11:20:21 +0530 Subject: [PATCH 11/15] fix lints Signed-off-by: Sreekanth --- .../accumulator/test_async_accumulator_shutdown.py | 5 +---- .../tests/batchmap/test_async_batch_map_shutdown.py | 3 +-- 
.../pynumaflow/tests/map/test_async_map_shutdown.py | 5 ++--- .../mapstream/test_async_map_stream_shutdown.py | 3 +-- .../tests/reduce/test_async_reduce_shutdown.py | 6 ++++-- .../tests/sink/test_async_sink_shutdown.py | 7 ++++--- .../tests/source/test_async_source_shutdown.py | 12 ++++++++---- 7 files changed, 21 insertions(+), 20 deletions(-) diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py index 5a7020fa..6d84eb5b 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py @@ -13,7 +13,6 @@ from unittest import mock from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer -from pynumaflow.shared.asynciter import NonBlockingIterator async def noop_handler(datums, output): @@ -180,9 +179,7 @@ async def _run(): async def _exception_in_queue(): yield RuntimeError("handler error from task manager") - mock_task_manager.global_result_queue.read_iterator.return_value = ( - _exception_in_queue() - ) + mock_task_manager.global_result_queue.read_iterator.return_value = _exception_in_queue() mock_task_manager.process_input_stream = mock.AsyncMock() with mock.patch( diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py index b759f194..c22ce7ee 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py @@ -11,7 +11,6 @@ from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer from pynumaflow.batchmapper import BatchResponses -from pynumaflow.proto.mapper import map_pb2 from tests.batchmap.utils import request_generator @@ -46,7 +45,7 @@ async def _cancelled_iter(): yield # make it an async generator ctx = mock.MagicMock() - 
responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) assert shutdown_event.is_set() assert servicer._error is None diff --git a/packages/pynumaflow/tests/map/test_async_map_shutdown.py b/packages/pynumaflow/tests/map/test_async_map_shutdown.py index af77a76f..cfbf3e50 100644 --- a/packages/pynumaflow/tests/map/test_async_map_shutdown.py +++ b/packages/pynumaflow/tests/map/test_async_map_shutdown.py @@ -12,7 +12,6 @@ from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer from pynumaflow.mapper._dtypes import Messages, Message, Datum -from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import get_test_datums @@ -45,7 +44,7 @@ async def _cancelled_iter(): yield # make it an async generator ctx = mock.MagicMock() - responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) assert shutdown_event.is_set() assert servicer._error is None @@ -98,7 +97,7 @@ async def _request_iter(): yield d ctx = mock.MagicMock() - responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + await _collect(servicer.MapFn(_request_iter(), ctx)) assert shutdown_event.is_set() assert servicer._error is not None diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py index 317f2581..697132f5 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py @@ -13,7 +13,6 @@ from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer from pynumaflow.mapstreamer._dtypes import Message from pynumaflow.mapstreamer import Datum -from pynumaflow.proto.mapper import map_pb2 from tests.mapstream.utils import request_generator @@ -47,7 +46,7 @@ async def _cancelled_iter(): yield # make it an async generator ctx = 
mock.MagicMock() - responses = await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) assert shutdown_event.is_set() assert servicer._error is None diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py index b13b6502..a999de5b 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py @@ -22,13 +22,12 @@ Metadata, Reducer, ) -from pynumaflow.proto.reducer import reduce_pb2 - # --------------------------------------------------------------------------- # Minimal handler — never raises on its own. # --------------------------------------------------------------------------- + class _StubReducer(Reducer): async def handler( self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata @@ -42,6 +41,7 @@ async def handler( # Helpers # --------------------------------------------------------------------------- + async def _collect(async_gen): """Drain an async generator and return the collected items.""" results = [] @@ -57,6 +57,7 @@ async def _collect(async_gen): # This exercises the first except block in ReduceFn. # --------------------------------------------------------------------------- + def test_shutdown_on_cancelled_error(): """CancelledError during request iteration should set shutdown_event but NOT store an error.""" @@ -88,6 +89,7 @@ async def _cancelled_request_iter(): # This exercises the BaseException except block in the first try. 
# --------------------------------------------------------------------------- + def test_shutdown_on_handler_error(): """A real exception during request iteration should set shutdown_event AND store the error.""" diff --git a/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py index 3b57b179..037c3144 100644 --- a/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py +++ b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py @@ -12,7 +12,6 @@ from pynumaflow.sinker.servicer.async_servicer import AsyncSinkServicer from pynumaflow.sinker import Datum, Responses, Response -from pynumaflow.proto.sinker import sink_pb2 from tests.sink.test_async_sink import request_generator @@ -48,7 +47,7 @@ async def _cancelled_iter(): yield # make it an async generator ctx = mock.MagicMock() - responses = await _collect(servicer.SinkFn(_cancelled_iter(), ctx)) + await _collect(servicer.SinkFn(_cancelled_iter(), ctx)) assert shutdown_event.is_set() assert servicer._error is None @@ -66,7 +65,9 @@ async def _run(): servicer.set_shutdown_event(shutdown_event) # Build an async request iterator with handshake + data + EOT - sync_datums = list(request_generator(count=2, req_type="success", session=1, handshake=True)) + sync_datums = list( + request_generator(count=2, req_type="success", session=1, handshake=True) + ) async def _request_iter(): for d in sync_datums: diff --git a/packages/pynumaflow/tests/source/test_async_source_shutdown.py b/packages/pynumaflow/tests/source/test_async_source_shutdown.py index d05dad7d..b6e537a5 100644 --- a/packages/pynumaflow/tests/source/test_async_source_shutdown.py +++ b/packages/pynumaflow/tests/source/test_async_source_shutdown.py @@ -22,12 +22,12 @@ from pynumaflow.shared.asynciter import NonBlockingIterator from pynumaflow.proto.sourcer import source_pb2 - # --------------------------------------------------------------------------- # Minimal handler that never raises on 
its own — individual tests will # override specific methods or inject CancelledError via the request stream. # --------------------------------------------------------------------------- + class _StubSource(Sourcer): async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): pass @@ -49,6 +49,7 @@ async def partitions_handler(self) -> PartitionsResponse: # Helpers # --------------------------------------------------------------------------- + async def _collect(async_gen): """Drain an async generator and return the collected items.""" results = [] @@ -62,6 +63,7 @@ async def _collect(async_gen): # after the handshake has been sent. # --------------------------------------------------------------------------- + def test_shutdown_on_read_cancelled_error(): """CancelledError during ReadFn request iteration should set shutdown_event but NOT store an error.""" @@ -90,6 +92,7 @@ async def _cancelled_request_iter(): # AckFn — streaming RPC. Same pattern as ReadFn. # --------------------------------------------------------------------------- + def test_shutdown_on_ack_cancelled_error(): """CancelledError during AckFn request iteration should set shutdown_event but NOT store an error.""" @@ -116,6 +119,7 @@ async def _cancelled_request_iter(): # NackFn — unary RPC. We make the handler raise CancelledError. # --------------------------------------------------------------------------- + def test_shutdown_on_nack_cancelled_error(): """CancelledError during NackFn handler should set shutdown_event but NOT store an error.""" @@ -134,9 +138,7 @@ async def _cancelled_nack(nack_request): # Build a valid NackRequest proto with at least one offset. 
offset = source_pb2.Offset(offset=b"test", partition_id=0) - request = source_pb2.NackRequest( - request=source_pb2.NackRequest.Request(offsets=[offset]) - ) + request = source_pb2.NackRequest(request=source_pb2.NackRequest.Request(offsets=[offset])) ctx = mock.MagicMock() # NackFn is a coroutine (not an async generator), so we await it. await servicer.NackFn(request, ctx) @@ -151,6 +153,7 @@ async def _cancelled_nack(nack_request): # PendingFn — unary RPC. Handler raises CancelledError. # --------------------------------------------------------------------------- + def test_shutdown_on_pending_cancelled_error(): """CancelledError during PendingFn handler should set shutdown_event but NOT store an error.""" @@ -180,6 +183,7 @@ async def _cancelled_pending(): # PartitionsFn — unary RPC. Handler raises CancelledError. # --------------------------------------------------------------------------- + def test_shutdown_on_partitions_cancelled_error(): """CancelledError during PartitionsFn handler should set shutdown_event but NOT store an error.""" From abecb268c2c685f7d3024e513cd0aced4625c3e0 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 17 Mar 2026 14:27:40 +0530 Subject: [PATCH 12/15] more tests Signed-off-by: Sreekanth --- .../sourcetransform/test_async_shutdown.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py diff --git a/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py new file mode 100644 index 00000000..74530e99 --- /dev/null +++ b/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py @@ -0,0 +1,66 @@ +""" +Shutdown-event tests for the async SourceTransform servicer. + +Covers the CancelledError and BaseException handlers in SourceTransformFn. 
+""" + +import asyncio +from unittest import mock + +from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer +from pynumaflow.sourcetransformer import Datum, Messages, Message +from tests.testing_utils import mock_new_event_time + + +async def async_transform_handler(keys: list[str], datum: Datum) -> Messages: + return Messages(Message(datum.value, mock_new_event_time(), keys=keys)) + + +async def _collect(async_gen): + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during SourceTransformFn should set shutdown_event, no error stored.""" + + async def _run(): + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield + + ctx = mock.MagicMock() + await _collect(servicer.SourceTransformFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """BaseException in SourceTransformFn should set shutdown_event and store error.""" + + async def _run(): + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _error_iter(): + raise RuntimeError("unexpected error") + yield + + ctx = mock.MagicMock() + await _collect(servicer.SourceTransformFn(_error_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "unexpected error" in repr(servicer._error) + + asyncio.run(_run()) From f5f73ce3d02a6184117de761f49590fa0b2ee8ae Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 18 Mar 2026 05:03:23 +0530 Subject: [PATCH 13/15] more tests, remove unused variables Signed-off-by: Sreekanth --- 
.../mapper/_servicer/_sync_servicer.py | 4 +- .../pynumaflow/mapper/multiproc_server.py | 2 +- .../sourcetransformer/multiproc_server.py | 2 +- .../sourcetransformer/servicer/_servicer.py | 4 +- .../tests/map/test_multiproc_map_shutdown.py | 119 ++++++++++++++++ .../tests/sideinput/test_shutdown.py | 72 ++++++++++ .../test_multiproc_shutdown.py | 128 ++++++++++++++++++ 7 files changed, 323 insertions(+), 8 deletions(-) create mode 100644 packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py create mode 100644 packages/pynumaflow/tests/sideinput/test_shutdown.py create mode 100644 packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index 9415f0a4..49d0898f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -21,10 +21,8 @@ class SyncMapServicer(map_pb2_grpc.MapServicer): Provides the functionality for the required rpc methods. 
""" - def __init__(self, handler: MapSyncCallable, multiproc: bool = False): + def __init__(self, handler: MapSyncCallable): self.__map_handler: MapSyncCallable = handler - # This indicates whether the grpc server attached is multiproc or not - self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) # Graceful shutdown: when set, a watcher thread in _run_server() calls diff --git a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py index 4055666f..5f7c6567 100644 --- a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py @@ -106,7 +106,7 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: # Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) - self.servicer = SyncMapServicer(handler=mapper_instance, multiproc=True) + self.servicer = SyncMapServicer(handler=mapper_instance) # Shared event across all worker processes for coordinated shutdown. 
# When any worker's servicer sets this event, all workers' watcher diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py index fba85017..3e7b150f 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py @@ -131,7 +131,7 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: # Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) - self.servicer = SourceTransformServicer(handler=source_transform_instance, multiproc=True) + self.servicer = SourceTransformServicer(handler=source_transform_instance) # Shared event across all worker processes for coordinated shutdown. # When any worker's servicer sets this event, all workers' watcher diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 7508e41f..3945e194 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -41,10 +41,8 @@ class SourceTransformServicer(transform_pb2_grpc.SourceTransformServicer): Provides the functionality for the required rpc methods. 
""" - def __init__(self, handler: SourceTransformCallable, multiproc: bool = False): + def __init__(self, handler: SourceTransformCallable): self.__transform_handler: SourceTransformCallable = handler - # This indicates whether the grpc server attached is multiproc or not - self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) # Graceful shutdown: when set, a watcher thread in _run_server() calls diff --git a/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py new file mode 100644 index 00000000..669fd77d --- /dev/null +++ b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py @@ -0,0 +1,119 @@ +""" +Shutdown-event tests for the multiproc Map servicer. + +These tests verify that the SyncMapServicer (as used by MapMultiprocServer) +correctly sets shutdown_event on error, enabling coordinated graceful shutdown +across all worker processes via the shared multiprocessing.Event. 
+""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.mapper import MapMultiprocServer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + server = MapMultiprocServer(mapper_instance=err_map_handler) + servicer = server.servicer + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + 
assert code == StatusCode.INTERNAL + assert "MapFn: expected handshake as the first message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + assert servicer.error is None diff --git a/packages/pynumaflow/tests/sideinput/test_shutdown.py b/packages/pynumaflow/tests/sideinput/test_shutdown.py new file mode 100644 index 00000000..66f55160 --- /dev/null +++ b/packages/pynumaflow/tests/sideinput/test_shutdown.py @@ -0,0 +1,72 @@ +""" +Shutdown-event tests for the SideInput servicer. + +Verifies that the servicer sets shutdown_event and captures the error when the +UDF handler raises, enabling graceful server stop via the watcher thread in +_run_server() instead of a hard process kill. 
+""" + +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.sideinput.servicer.servicer import SideInputServicer +from pynumaflow.proto.sideinput import sideinput_pb2 + + +def _ok_handler(): + from pynumaflow.sideinput import Response + + return Response.broadcast_message(b"test") + + +def _err_handler(): + raise RuntimeError("Something is fishy!") + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SideInputServicer(handler=_err_handler) + + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ + "RetrieveSideInput" + ] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + _, _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_not_set_on_success(): + """On a successful call, shutdown_event must remain unset.""" + servicer = SideInputServicer(handler=_ok_handler) + + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ + "RetrieveSideInput" + ] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + _, _, code, _ = method.termination() + assert code == StatusCode.OK + assert not servicer.shutdown_event.is_set() + assert servicer.error is None diff --git 
a/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py new file mode 100644 index 00000000..bd4a78e4 --- /dev/null +++ b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py @@ -0,0 +1,128 @@ +""" +Shutdown-event tests for the multiproc SourceTransform servicer. + +These tests verify that the SourceTransformServicer (as used by +SourceTransformMultiProcServer) correctly sets shutdown_event on error, +enabling coordinated graceful shutdown across all worker processes via +the shared multiprocessing.Event. +""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer +from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler) + servicer = server.servicer + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not 
None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "SourceTransformFn: expected handshake message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + test_datums = get_test_datums(handshake=True) + + def 
_cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + assert servicer.error is None From 6b360ac3ac871e3ff34f666dadda53efa129740a Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 18 Mar 2026 05:28:35 +0530 Subject: [PATCH 14/15] Use get_running_loop instead of get_event_loop Signed-off-by: Sreekanth --- packages/pynumaflow/pynumaflow/accumulator/async_server.py | 2 +- packages/pynumaflow/pynumaflow/batchmapper/async_server.py | 2 +- packages/pynumaflow/pynumaflow/mapper/async_server.py | 2 +- packages/pynumaflow/pynumaflow/mapstreamer/async_server.py | 2 +- packages/pynumaflow/pynumaflow/reducer/async_server.py | 2 +- packages/pynumaflow/pynumaflow/reducestreamer/async_server.py | 2 +- packages/pynumaflow/pynumaflow/sinker/async_server.py | 2 +- packages/pynumaflow/pynumaflow/sourcer/async_server.py | 2 +- .../pynumaflow/pynumaflow/sourcetransformer/async_server.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index 36146e68..f15cf03f 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -239,5 +239,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. 
# It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 78d9f12e..3b399730 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -174,5 +174,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. # It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 9078e5f4..bb685232 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -172,5 +172,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. # It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index 2c833d43..6e6af9ce 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -193,5 +193,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. 
# It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index cc52d15e..33800120 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -225,5 +225,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. # It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 5ac14455..9200c925 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -237,5 +237,5 @@ async def _watch_for_shutdown(): await shutdown_task _LOGGER.info("Stopping event loop...") - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sinker/async_server.py b/packages/pynumaflow/pynumaflow/sinker/async_server.py index d2c3c0b8..129bf6b9 100644 --- a/packages/pynumaflow/pynumaflow/sinker/async_server.py +++ b/packages/pynumaflow/pynumaflow/sinker/async_server.py @@ -196,5 +196,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. 
# It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index b8ef4965..3eea4e65 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -233,5 +233,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. # It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index ee25c0af..05939eed 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -195,5 +195,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. 
# It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") From e0e2624dc44139b360072e74cd6624ff966054ce Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 18 Mar 2026 08:56:33 +0530 Subject: [PATCH 15/15] More documentation comments Signed-off-by: Sreekanth --- .../pynumaflow/pynumaflow/accumulator/async_server.py | 9 ++++++--- .../pynumaflow/pynumaflow/batchmapper/async_server.py | 9 ++++++--- packages/pynumaflow/pynumaflow/mapper/async_server.py | 9 ++++++--- .../pynumaflow/pynumaflow/mapstreamer/async_server.py | 9 ++++++--- packages/pynumaflow/pynumaflow/reducer/async_server.py | 9 ++++++--- .../pynumaflow/reducer/servicer/async_servicer.py | 4 +--- .../pynumaflow/pynumaflow/reducestreamer/async_server.py | 9 ++++++--- packages/pynumaflow/pynumaflow/sinker/async_server.py | 9 ++++++--- packages/pynumaflow/pynumaflow/sourcer/async_server.py | 9 ++++++--- .../pynumaflow/sourcetransformer/async_server.py | 9 ++++++--- 10 files changed, 55 insertions(+), 30 deletions(-) diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index f15cf03f..63191151 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -220,9 +220,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. 
Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 3b399730..0bdcdaac 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -155,9 +155,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index bb685232..a48031e6 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -153,9 +153,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. 
We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index 6e6af9ce..f670e156 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -174,9 +174,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
_LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index 33800120..3dadd67f 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -206,9 +206,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
_LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py index 7b147064..d027e999 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py @@ -159,9 +159,7 @@ async def ReduceFn( except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" - _LOGGER.critical(err_msg, exc_info=True) - update_context_err(context, e, err_msg) + update_context_err(context, e, f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}") self._error = e if self._shutdown_event is not None: self._shutdown_event.set() diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 9200c925..07d02cba 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -223,9 +223,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
_LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/sinker/async_server.py b/packages/pynumaflow/pynumaflow/sinker/async_server.py index 129bf6b9..03f2f3fe 100644 --- a/packages/pynumaflow/pynumaflow/sinker/async_server.py +++ b/packages/pynumaflow/pynumaflow/sinker/async_server.py @@ -177,9 +177,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 3eea4e65..9c9e6072 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -214,9 +214,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. 
Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 05939eed..28623f47 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -176,9 +176,12 @@ async def _watch_for_shutdown(): try: await server.wait_for_termination() except asyncio.CancelledError: - # SIGTERM received — aiorun cancels all tasks. We must stop - # the gRPC server explicitly so its __del__ doesn't try to - # schedule a coroutine on the already-closed event loop. + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. _LOGGER.info("Received cancellation, stopping server gracefully...") await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)