diff --git a/packages/pynumaflow/pynumaflow/accumulator/async_server.py b/packages/pynumaflow/pynumaflow/accumulator/async_server.py index 200e4422..63191151 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/async_server.py +++ b/packages/pynumaflow/pynumaflow/accumulator/async_server.py @@ -1,9 +1,13 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.proto.accumulator import accumulator_pb2_grpc @@ -15,6 +19,7 @@ MAX_NUM_THREADS, ACCUMULATOR_SOCK_PATH, ACCUMULATOR_SERVER_INFO_FILE_PATH, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.accumulator._dtypes import ( @@ -23,7 +28,7 @@ Accumulator, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -157,6 +162,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncAccumulatorServicer(self.accumulator_handler) + self._error: BaseException | None = None def start(self): """ @@ -167,6 +173,9 @@ def start(self): "Starting Async Accumulator Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -176,18 +185,62 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) + + # The asyncio.Event must be 
created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator] - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It remains stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py index 16be7911..886f0e66 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py @@ -3,7 +3,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc from pynumaflow.accumulator._dtypes import ( Datum, @@ -13,7 +13,7 @@ KeyedWindow, ) from pynumaflow.accumulator.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -57,6 +57,12 @@ def __init__( ): # The accumulator handler can be a function or a builder class instance. 
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def AccumulateFn( self, @@ -104,20 +110,49 @@ async def AccumulateFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window EOF response or Window result response # back to the client else: yield msg + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return # Wait for the process_input_stream task to finish for a clean exit try: await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py index a9758bf7..ee14e3ea 100644 --- a/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py @@ -213,6 +213,9 @@ async def __invoke_accumulator( _ = await new_instance(request_iterator, output) # send EOF to the output stream await output.put(STREAM_EOF) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + return # If there is an error in the accumulator operation, log and # then send the error to the result queue except BaseException as err: @@ -243,6 +246,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator case _: _LOGGER.debug(f"No operation matched for request: {request}", exc_info=True) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + return # If there is an error in the accumulator operation, log and # then send the error to the result queue except BaseException as e: @@ -274,6 +280,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator # Now send STREAM_EOF to terminate the global result queue iterator await self.global_result_queue.put(STREAM_EOF) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ return except BaseException as e: err_msg = f"Accumulator Streaming Error: {repr(e)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py index 1078e012..0bdcdaac 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -8,9 +12,11 @@ BATCH_MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.batchmapper._dtypes import BatchMapCallable from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -19,7 +25,7 @@ ContainerType, ) from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class BatchMapAsyncServer(NumaflowServer): @@ -92,6 +98,7 @@ async def handler( ] self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance) + self._error: BaseException | None = None def start(self): """ @@ -99,6 +106,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -108,25 +118,64 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = 
grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Batch Map Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. 
Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It remains stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py index 523a4ad4..63733961 100644 --- a/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/batchmapper/servicer/async_servicer.py @@ -7,7 +7,7 @@ from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext from pynumaflow._constants import _LOGGER, STREAM_EOF, ERR_UDF_EXCEPTION_STRING @@ -26,6 +26,12 @@ def __init__( ): self.background_tasks = set() self.__batch_map_handler: 
BatchMapCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -96,9 +102,20 @@ async def MapFn( ) await req_queue.put(datum) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() return async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py index 90a55b7b..df0265e4 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable, Datum, MapError, Message, Messages from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -25,6 +25,12 @@ def __init__( ): self.background_tasks = set() self.__map_handler: MapAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def 
set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -57,16 +63,32 @@ async def MapFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: yield msg # wait for the producer task to complete await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index cb757e3c..49d0898f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -2,8 +2,9 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterator +import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import 
update_context_err from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING @@ -20,12 +21,14 @@ class SyncMapServicer(map_pb2_grpc.MapServicer): Provides the functionality for the required rpc methods. """ - def __init__(self, handler: MapSyncCallable, multiproc: bool = False): + def __init__(self, handler: MapSyncCallable): self.__map_handler: MapSyncCallable = handler - # This indicates whether the grpc server attached is multiproc or not - self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def MapFn( self, @@ -36,6 +39,7 @@ def MapFn( Applies a function to each datum element. The pascal case function name comes from the proto map_pb2_grpc.py file. """ + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -57,10 +61,19 @@ def MapFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. 
+ _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -69,12 +82,23 @@ def MapFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + # Client disconnected — not a UDF error, but we still need to + # shut down the server so the process can exit cleanly. + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( @@ -92,6 +116,11 @@ def _process_requests( self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process result_queue.put(STREAM_EOF) + except grpc.RpcError as e: + # Client disconnected — expected during pod shutdown. + # Surface to the consumer which will trigger graceful shutdown. 
+ _LOGGER.warning("gRPC stream closed in reader thread") + result_queue.put(e) except BaseException as e: _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) # Surface the error to the consumer; MapFn will handle and exit diff --git a/packages/pynumaflow/pynumaflow/mapper/async_server.py b/packages/pynumaflow/pynumaflow/mapper/async_server.py index 5bba75d7..a48031e6 100644 --- a/packages/pynumaflow/pynumaflow/mapper/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,7 +11,10 @@ MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,10 +25,7 @@ from pynumaflow.mapper._dtypes import MapAsyncCallable from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer from pynumaflow.proto.mapper import map_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer class MapAsyncServer(NumaflowServer): @@ -92,6 +96,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncMapServicer(handler=mapper_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -99,32 +104,76 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. 
""" - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = 
asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. 
+ # It remains stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py index 5d68a96b..5f7c6567 100644 --- a/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/multiproc_server.py @@ -1,4 +1,8 @@ +import multiprocessing +import sys + from pynumaflow._constants import ( + _LOGGER, NUM_THREADS_DEFAULT, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, @@ -102,7 +106,12 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: # Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) - self.servicer = SyncMapServicer(handler=mapper_instance, multiproc=True) + self.servicer = SyncMapServicer(handler=mapper_instance) + + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. 
+ self._shutdown_event = multiprocessing.Event() def start(self) -> None: """ @@ -121,7 +130,7 @@ def start(self) -> None: server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap # Start the multiproc server - start_multiproc_server( + has_error = start_multiproc_server( max_threads=self.max_threads, servicer=self.servicer, process_count=self._process_count, @@ -129,4 +138,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=server_info, + shutdown_event=self._shutdown_event, ) + + if has_error: + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapper/sync_server.py b/packages/pynumaflow/pynumaflow/mapper/sync_server.py index 9c2431b6..a96ceb7f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/sync_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/sync_server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -112,4 +114,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py index 187c720d..f670e156 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -18,11 +23,12 @@ _LOGGER, MAP_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.mapstreamer._dtypes import MapStreamCallable -from 
pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import NumaflowServer class MapStreamAsyncServer(NumaflowServer): @@ -111,6 +117,7 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag ] self.servicer = AsyncMapStreamServicer(handler=self.map_stream_instance) + self._error: BaseException | None = None def start(self): """ @@ -118,6 +125,9 @@ def start(self): to the aexec so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -127,25 +137,64 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - map_pb2_grpc.add_MapServicer_to_server( - self.servicer, - server, - ) - _LOGGER.info("Starting Map Stream Server") + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper] # Add the MAP_MODE metadata to the server info for the correct map mode serv_info.metadata[MAP_MODE_KEY] = MapMode.StreamMap - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It remains stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py index f5a9a999..77942b71 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow.mapstreamer import Datum from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2 -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -22,6 +22,12 @@ class AsyncMapStreamServicer(map_pb2_grpc.MapServicer): def __init__(self, handler: MapStreamCallable): self.__map_stream_handler: MapStreamCallable = handler self._background_tasks: set[asyncio.Task] = set() + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def MapFn( self, @@ -51,7 +57,12 @@ async 
def MapFn( # Consume results as they arrive and stream them to the client async for msg in global_result_queue.read_iterator(): if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return else: # msg is a map_pb2.MapResponse, already formed @@ -60,9 +71,20 @@ async def MapFn( # Ensure producer has finished (covers graceful shutdown) await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( @@ -124,7 +146,7 @@ async def _invoke_map_stream( except BaseException as err: _LOGGER.critical("MapFn handler error", exc_info=True) # Surface handler error to the main producer; - # it will call handle_async_error and end the RPC + # it will set the shutdown event and end the RPC await result_queue.put(err) async def IsReady( diff --git a/packages/pynumaflow/pynumaflow/reducer/async_server.py b/packages/pynumaflow/pynumaflow/reducer/async_server.py index aee4d355..3dadd67f 100644 --- a/packages/pynumaflow/pynumaflow/reducer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducer/async_server.py @@ -1,8 +1,12 @@ +import asyncio +import contextlib import inspect +import sys import aiorun import grpc 
+from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.proto.reducer import reduce_pb2_grpc @@ -15,6 +19,7 @@ _LOGGER, REDUCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.reducer._dtypes import ( @@ -23,7 +28,7 @@ Reducer, ) -from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server +from pynumaflow.shared.server import NumaflowServer, check_instance def get_handler( @@ -143,6 +148,7 @@ def __init__( ] # Get the servicer instance for the async server self.servicer = AsyncReduceServicer(self.reducer_handler) + self._error: BaseException | None = None def start(self): """ @@ -153,6 +159,9 @@ def start(self): "Starting Async Reduce Server", ) aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -162,20 +171,62 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - reduce_servicer = self.servicer - reduce_pb2_grpc.add_ReduceServicer_to_server(reduce_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + reduce_pb2_grpc.add_ReduceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Reducer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py index 44db8077..d027e999 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/async_servicer.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import AsyncIterable from google.protobuf import empty_pb2 as _empty_pb2 @@ -12,7 +13,7 @@ WindowOperation, ) from pynumaflow.reducer.servicer.task_manager import TaskManager -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -53,6 +54,12 @@ def __init__( ): # The Reduce handler can be a function or a builder class instance. 
self.__reduce_handler: ReduceAsyncCallable | _ReduceBuilderClass = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def ReduceFn( self, @@ -101,11 +108,22 @@ async def ReduceFn( # append the task data to the existing task # if the task does not exist, it will create a new task await task_manager.append_task(request) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return # send EOF to all the tasks once the request iterator is exhausted # This will signal the tasks to stop reading the data on their @@ -132,11 +150,20 @@ async def ReduceFn( for window in current_window.values(): # yield the EOF response once the task is completed for a keyed window yield reduce_pb2.ReduceResponse(window=window, EOF=True) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: _LOGGER.critical("Reduce Error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + update_context_err(context, e, f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}") + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def IsReady( self, request: _empty_pb2.Empty, context: NumaflowServicerContext diff --git a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py index 2e21d60a..3023a706 100644 --- a/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py +++ b/packages/pynumaflow/pynumaflow/reducer/servicer/task_manager.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from collections.abc import AsyncIterable -import grpc - from pynumaflow.exceptions import UDFError from pynumaflow.proto.reducer import reduce_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator @@ -21,7 +19,7 @@ ReduceAsyncCallable, ReduceWindow, ) -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.types import NumaflowServicerContext @@ -167,16 +165,14 @@ async def __invoke_reduce( new_instance = self.__reduce_handler.create() try: msgs = await new_instance(keys, request_iterator, md) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ raise except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Send a context abort signal for the rpc, this is required for numa container to get - # the correct grpc error - await asyncio.gather( - self.context.abort(grpc.StatusCode.UNKNOWN, details=repr(err)), - return_exceptions=True, - ) - exit_on_error(err=repr(err), parent=False, context=self.context, update_context=False) - return + err_msg = f"ReduceError: {repr(err)}" + update_context_err(self.context, err, err_msg) + raise datum_responses = [] for msg in msgs: diff --git a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py index 63123c88..07d02cba 100644 --- a/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py +++ b/packages/pynumaflow/pynumaflow/reducestreamer/async_server.py @@ -220,7 +220,17 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error @@ -230,5 +240,5 @@ async def _watch_for_shutdown(): await shutdown_task _LOGGER.info("Stopping event loop...") - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/shared/server.py b/packages/pynumaflow/pynumaflow/shared/server.py index 3986e0dc..943f3cee 100644 --- a/packages/pynumaflow/pynumaflow/shared/server.py +++ b/packages/pynumaflow/pynumaflow/shared/server.py @@ -1,8 +1,9 @@ -import asyncio import contextlib import io import multiprocessing +import multiprocessing.synchronize import os +import signal import socket import threading import traceback @@ -14,7 +15,6 @@ from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor import grpc -import psutil from pynumaflow._constants import ( _LOGGER, @@ -151,6 +151,7 @@ def start_multiproc_server( server_info: ServerInfo | None = None, server_options=None, udf_type: str = UDFType.Map, + shutdown_event: multiprocessing.synchronize.Event | None = None, ): """ Start N grpc servers in different processes where N = The number of CPUs or the @@ -179,20 +180,38 @@ def start_multiproc_server( worker = multiprocessing.Process( target=_run_server, args=(servicer, bind_address, max_threads, server_options, udf_type), + kwargs={"shutdown_event": shutdown_event} if shutdown_event else {}, ) worker.start() workers.append(worker) if server_info is None: server_info = ServerInfo.get_default_server_info() - server_info.metadata = get_metadata_env(envs=METADATA_ENVS) + # Merge env metadata into existing metadata (preserving caller-set keys + # like MAP_MODE_KEY) rather than overwriting the entire dict. 
+ server_info.metadata.update(get_metadata_env(envs=METADATA_ENVS)) # Add the MULTIPROC metadata using the number of servers to use server_info.metadata[MULTIPROC_KEY] = str(process_count) info_server_write(server_info=server_info, info_file=server_info_file) + # Register a SIGTERM handler so that kubectl delete triggers graceful + # shutdown of all child workers via the shared multiprocessing.Event, + # instead of the default abrupt kill. + if shutdown_event is not None: + + def _sigterm_handler(signum, frame): + _LOGGER.info("SIGTERM received, signalling workers to shut down...") + shutdown_event.set() + + signal.signal(signal.SIGTERM, _sigterm_handler) + for worker in workers: worker.join() + # Return True if any worker exited with a non-zero code (i.e. a real error, + # not a clean SIGTERM shutdown). + return any(w.exitcode != 0 for w in workers) + async def start_async_server( server_async: grpc.aio.Server, @@ -278,37 +297,6 @@ def get_grpc_status(err: str, detail: str | None = None): return rpc_status.to_status(status) -def exit_on_error( - context: NumaflowServicerContext, err: str, parent: bool = False, update_context=True -): - """ - Exit the current/parent process on an error. - - Args: - context (NumaflowServicerContext): The gRPC context. - err (str): The error message. - parent (bool, optional): Whether this is the parent process. - Defaults to False. 
- update_context(bool, optional) : Is there a need to update - the context with the error codes - """ - if update_context: - # Create a status object with the error details - grpc_status = get_grpc_status(err) - - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(err) - context.set_trailing_metadata(grpc_status.trailing_metadata) - - p = psutil.Process(os.getpid()) - # If the parent flag is true, we exit from the parent process - # Use this for Multiproc right now to exit from the parent fork - if parent: - p = psutil.Process(os.getppid()) - _LOGGER.info("Killing process: Got exception %s", err) - p.kill() - - def update_context_err(context: NumaflowServicerContext, e: BaseException, err_msg: str): """ Update the context with the error and log the exception. @@ -328,17 +316,3 @@ def get_exception_traceback_str(exc) -> str: file = io.StringIO() traceback.print_exception(exc, value=exc, tb=exc.__traceback__, file=file) return file.getvalue().rstrip() - - -async def handle_async_error( - context: NumaflowServicerContext, exception: BaseException, exception_type: str -): - """ - Handle exceptions for async servers by updating the context and exiting. 
- """ - err_msg = f"{exception_type}: {repr(exception)}" - update_context_err(context, exception, err_msg) - await asyncio.gather( - context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True - ) - exit_on_error(err=err_msg, parent=False, context=context, update_context=False) diff --git a/packages/pynumaflow/pynumaflow/sideinput/server.py b/packages/pynumaflow/pynumaflow/sideinput/server.py index 7bb27b86..445e23e1 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/server.py +++ b/packages/pynumaflow/pynumaflow/sideinput/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.shared import NumaflowServer from pynumaflow.shared.server import sync_server_start @@ -99,4 +101,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SideInput, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py index 7f46bf68..da836948 100644 --- a/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py +++ b/packages/pynumaflow/pynumaflow/sideinput/servicer/servicer.py @@ -1,3 +1,5 @@ +import threading + from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow._constants import ( @@ -5,7 +7,7 @@ ERR_UDF_EXCEPTION_STRING, ) from pynumaflow.proto.sideinput import sideinput_pb2_grpc, sideinput_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sideinput._dtypes import RetrieverCallable from pynumaflow.types import NumaflowServicerContext @@ -16,6 +18,8 @@ def __init__( handler: RetrieverCallable, ): self.__retrieve_handler: RetrieverCallable = handler + self.shutdown_event: threading.Event = 
threading.Event() + self.error: BaseException | None = None def RetrieveSideInput( self, request: _empty_pb2.Empty, context: NumaflowServicerContext @@ -30,7 +34,9 @@ def RetrieveSideInput( except BaseException as err: err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) - exit_on_error(context, err_msg) + update_context_err(context, err, err_msg) + self.error = err + self.shutdown_event.set() return return sideinput_pb2.SideInputResponse(value=rspn.value, no_broadcast=rspn.no_broadcast) diff --git a/packages/pynumaflow/pynumaflow/sinker/async_server.py b/packages/pynumaflow/pynumaflow/sinker/async_server.py index 516fbb82..03f2f3fe 100644 --- a/packages/pynumaflow/pynumaflow/sinker/async_server.py +++ b/packages/pynumaflow/pynumaflow/sinker/async_server.py @@ -174,7 +174,17 @@ async def _watch_for_shutdown(): await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) shutdown_task = asyncio.create_task(_watch_for_shutdown()) - await server.wait_for_termination() + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) # Propagate error so start() can exit with a non-zero code self._error = self.servicer._error @@ -189,5 +199,5 @@ async def _watch_for_shutdown(): # event loop explicitly here, the python process will not exit. 
# It reamins stuck for 5 minutes until liveness and readiness probe # fails enough times and k8s sends a SIGTERM - asyncio.get_event_loop().stop() + asyncio.get_running_loop().stop() _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py index 908aa126..5eade634 100644 --- a/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sinker/servicer/async_servicer.py @@ -87,6 +87,13 @@ async def SinkFn( # if we have a valid message, we will add it to the request queue for processing. datum = datum_from_sink_req(d) await req_queue.put(datum) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: err_msg = f"UDSinkError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" _LOGGER.critical(err_msg, exc_info=True) diff --git a/packages/pynumaflow/pynumaflow/sourcer/async_server.py b/packages/pynumaflow/pynumaflow/sourcer/async_server.py index 3bca9dfb..9c9e6072 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcer/async_server.py @@ -1,6 +1,11 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer @@ -10,10 +15,12 @@ NUM_THREADS_DEFAULT, SOURCE_SERVER_INFO_FILE_PATH, MAX_NUM_THREADS, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) from pynumaflow.proto.sourcer import source_pb2_grpc -from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.shared.server import 
NumaflowServer from pynumaflow.sourcer._dtypes import SourceCallable @@ -153,6 +160,7 @@ async def partitions_handler(self) -> PartitionsResponse: ] self.servicer = AsyncSourceServicer(source_handler=sourcer_instance) + self._error: BaseException | None = None def start(self): """ @@ -160,6 +168,9 @@ def start(self): so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self): """ @@ -168,20 +179,62 @@ async def aexec(self): # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context - # Create a new async server instance and add the servicer to it server = grpc.aio.server(options=self._server_options) server.add_insecure_port(self.sock_path) - source_servicer = self.servicer - source_pb2_grpc.add_SourceServicer_to_server(source_servicer, server) + + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. 
+ shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + source_pb2_grpc.add_SourceServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sourcer] - # Start the async server - await start_async_server( - server_async=server, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. 
+ _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It reamins stuck for 5 minutes until liveness and readiness probe + # fails enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py index 3e0839c4..51cd74c2 100644 --- a/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py @@ -5,7 +5,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcer import ReadRequest, Offset, NackRequest, AckRequest, SourceCallable from pynumaflow.proto.sourcer import source_pb2 from pynumaflow.proto.sourcer import source_pb2_grpc @@ -71,6 +71,12 @@ def __init__(self, source_handler: SourceCallable): self.source_handler = source_handler self.__initialize_handlers() self.cleanup_coroutines = [] + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event def __initialize_handlers(self): """Initialize 
handler methods from the provided source handler.""" @@ -110,7 +116,12 @@ async def ReadFn( async for resp in riter: if isinstance(resp, BaseException): - await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(resp)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, resp, err_msg) + self._error = resp + if self._shutdown_event is not None: + self._shutdown_event.set() return yield _create_read_response(resp) @@ -119,9 +130,21 @@ async def ReadFn( await task # send an eot to signal all messages have been processed. yield _create_eot_response() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source ReadFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def __invoke_read( self, req: source_pb2.ReadRequest, niter: NonBlockingIterator[Message | Exception] @@ -167,9 +190,21 @@ async def AckFn( ] await self.__source_ack_handler(AckRequest(offsets=offsets)) yield _create_ack_response() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source AckFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return async def NackFn( self, @@ -184,9 +219,21 @@ async def NackFn( Offset(offset.offset, offset.partition_id) for offset in request.request.offsets ] await self.__source_nack_handler(NackRequest(offsets=offsets)) + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("User-Defined Source NackFn error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return return source_pb2.NackResponse( result=source_pb2.NackResponse.Result(success=_empty_pb2.Empty()) ) @@ -209,10 +256,21 @@ async def PendingFn( """ try: count = await self.__source_pending_handler() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PendingResponse(result=source_pb2.PendingResponse.Result(count=0)) + except BaseException as err: _LOGGER.critical("PendingFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PendingResponse(result=source_pb2.PendingResponse.Result(count=0)) resp = source_pb2.PendingResponse.Result(count=count.count) return source_pb2.PendingResponse(result=resp) @@ -224,10 +282,25 @@ async def PartitionsFn( """ try: partitions = await self.__source_partitions_handler() + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. + _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PartitionsResponse( + result=source_pb2.PartitionsResponse.Result(partitions=[]) + ) + except BaseException as err: _LOGGER.critical("PartitionsFn Error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) - raise + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + self._error = err + if self._shutdown_event is not None: + self._shutdown_event.set() + return source_pb2.PartitionsResponse( + result=source_pb2.PartitionsResponse.Result(partitions=[]) + ) resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) return source_pb2.PartitionsResponse(result=resp) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py index 990e4587..28623f47 100644 --- 
a/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/async_server.py @@ -1,3 +1,7 @@ +import asyncio +import contextlib +import sys + import aiorun import grpc @@ -7,17 +11,17 @@ MAX_NUM_THREADS, SOURCE_TRANSFORMER_SOCK_PATH, SOURCE_TRANSFORMER_SERVER_INFO_FILE_PATH, + _LOGGER, + NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS, ) +from pynumaflow.info.server import write as info_server_write from pynumaflow.info.types import ( ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType, ) from pynumaflow.proto.sourcetransformer import transform_pb2_grpc -from pynumaflow.shared.server import ( - NumaflowServer, - start_async_server, -) +from pynumaflow.shared.server import NumaflowServer from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer @@ -115,6 +119,7 @@ def __init__( ("grpc.max_receive_message_length", self.max_message_size), ] self.servicer = SourceTransformAsyncServicer(handler=source_transform_instance) + self._error: BaseException | None = None def start(self) -> None: """ @@ -122,32 +127,76 @@ def start(self) -> None: so that all the async coroutines can be started from a single context """ aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback) + if self._error: + _LOGGER.critical("Server exiting due to UDF error: %s", self._error) + sys.exit(1) async def aexec(self) -> None: """ Starts the Async gRPC server on the given UNIX socket with given max threads. 
""" - # As the server is async, we need to create a new server instance in the # same thread as the event loop so that all the async calls are made in the # same context + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) - server_new = grpc.aio.server(options=self._server_options) - server_new.add_insecure_port(self.sock_path) - transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server_new) + # The asyncio.Event must be created here (inside aexec) rather than in __init__, + # because it must be bound to the running event loop that aiorun creates. + # At __init__ time no event loop exists yet. + shutdown_event = asyncio.Event() + self.servicer.set_shutdown_event(shutdown_event) + + transform_pb2_grpc.add_SourceTransformServicer_to_server(self.servicer, server) serv_info = ServerInfo.get_default_server_info() serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ ContainerType.Sourcetransformer ] - # Start the async server - await start_async_server( - server_async=server_new, - sock_path=self.sock_path, - max_threads=self.max_threads, - cleanup_coroutines=list(), - server_info_file=self.server_info_file, - server_info=serv_info, + await server.start() + info_server_write(server_info=serv_info, info_file=self.server_info_file) + + _LOGGER.info( + "Async GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, ) + + async def _watch_for_shutdown(): + """Wait for the shutdown event and stop the server with a grace period.""" + await shutdown_event.wait() + _LOGGER.info("Shutdown signal received, stopping server gracefully...") + # Stop accepting new requests and wait for a maximum of + # NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + shutdown_task = asyncio.create_task(_watch_for_shutdown()) + try: + await server.wait_for_termination() + except 
asyncio.CancelledError: + # SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error + # path (where _watch_for_shutdown calls server.stop()), this path + # must stop the gRPC server explicitly. Without this, the server + # object is never stopped and when it is garbage-collected, its + # __del__ tries to schedule a cleanup coroutine on an event loop + # that is already closed, causing errors/warnings. + _LOGGER.info("Received cancellation, stopping server gracefully...") + await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS) + + # Propagate error so start() can exit with a non-zero code + self._error = self.servicer._error + + shutdown_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await shutdown_task + + _LOGGER.info("Stopping event loop...") + # We use aiorun to manage the event loop. The aiorun.run() runs + # forever until loop.stop() is called. If we don't stop the + # event loop explicitly here, the python process will not exit. + # It remains stuck for 5 minutes until liveness and readiness probes + # fail enough times and k8s sends a SIGTERM + asyncio.get_running_loop().stop() + _LOGGER.info("Event loop stopped") diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py index dbc8b7b5..3e7b150f 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/multiproc_server.py @@ -1,9 +1,13 @@ +import multiprocessing +import sys + from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer from pynumaflow.shared.server import start_multiproc_server from pynumaflow._constants import ( + _LOGGER, MAX_MESSAGE_SIZE, SOURCE_TRANSFORMER_SOCK_PATH, NUM_THREADS_DEFAULT, @@ -127,7 +131,12 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: # 
Setting the max value to 2 * CPU count # Used for multiproc server self._process_count = min(server_count, 2 * _PROCESS_COUNT) - self.servicer = SourceTransformServicer(handler=source_transform_instance, multiproc=True) + self.servicer = SourceTransformServicer(handler=source_transform_instance) + + # Shared event across all worker processes for coordinated shutdown. + # When any worker's servicer sets this event, all workers' watcher + # threads trigger server.stop() for a graceful coordinated exit. + self._shutdown_event = multiprocessing.Event() def start(self): """ @@ -140,7 +149,7 @@ def start(self): serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ ContainerType.Sourcetransformer ] - start_multiproc_server( + has_error = start_multiproc_server( max_threads=self.max_threads, servicer=self.servicer, process_count=self._process_count, @@ -148,4 +157,9 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self._shutdown_event, ) + + if has_error: + _LOGGER.critical("Server exiting due to worker error") + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py index 7069e2b6..c410adce 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/server.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ContainerType, MINIMUM_NUMAFLOW_VERSION, ServerInfo from pynumaflow._constants import ( MAX_MESSAGE_SIZE, @@ -128,4 +130,8 @@ def start(self): server_options=self._server_options, udf_type=UDFType.SourceTransformer, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1) diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py 
b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py index da7384c2..819c27c3 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_async_servicer.py @@ -8,7 +8,7 @@ from pynumaflow._metadata import _user_and_system_metadata_from_proto from pynumaflow.proto.sourcetransformer import transform_pb2, transform_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator -from pynumaflow.shared.server import handle_async_error +from pynumaflow.shared.server import update_context_err from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformAsyncCallable from pynumaflow.types import NumaflowServicerContext @@ -28,6 +28,12 @@ def __init__( ): self.background_tasks = set() self.__transform_handler: SourceTransformAsyncCallable = handler + self._shutdown_event: asyncio.Event | None = None + self._error: BaseException | None = None + + def set_shutdown_event(self, event: asyncio.Event): + """Wire up the shutdown event created by the server's aexec() coroutine.""" + self._shutdown_event = event async def SourceTransformFn( self, @@ -61,16 +67,32 @@ async def SourceTransformFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, msg, err_msg) + self._error = msg + if self._shutdown_event is not None: + self._shutdown_event.set() return # Send window response back to the client else: yield msg # wait for the producer task to complete await producer + except asyncio.CancelledError: + # Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault. 
+ _LOGGER.info("Server shutting down, cancelling RPC.") + if self._shutdown_event is not None: + self._shutdown_event.set() + return + except BaseException as e: - _LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, e, err_msg) + self._error = e + if self._shutdown_event is not None: + self._shutdown_event.set() return async def _process_inputs( diff --git a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py index 2091bf47..3945e194 100644 --- a/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py +++ b/packages/pynumaflow/pynumaflow/sourcetransformer/servicer/_servicer.py @@ -2,10 +2,11 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterable +import grpc from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow.shared.synciter import SyncIterator from pynumaflow.sourcetransformer import Datum from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable @@ -40,12 +41,14 @@ class SourceTransformServicer(transform_pb2_grpc.SourceTransformServicer): Provides the functionality for the required rpc methods. 
""" - def __init__(self, handler: SourceTransformCallable, multiproc: bool = False): + def __init__(self, handler: SourceTransformCallable): self.__transform_handler: SourceTransformCallable = handler - # This indicates whether the grpc server attached is multiproc or not - self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + # Graceful shutdown: when set, a watcher thread in _run_server() calls + # server.stop() instead of hard-killing the process via psutil. + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def SourceTransformFn( self, @@ -56,6 +59,8 @@ def SourceTransformFn( Applies a function to each datum element. The pascal case function name comes from the generated transform_pb2_grpc.py file. """ + # Initialize before try so it's accessible in except blocks + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -78,10 +83,18 @@ def SourceTransformFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + if isinstance(res, grpc.RpcError): + # Client disconnected mid-stream — the reader thread + # surfaced the error via the queue. Not a UDF fault. 
+ _LOGGER.warning("gRPC stream closed, shutting down the server.") + result_queue.close() + self.shutdown_event.set() + return + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + update_context_err(context, res, err_msg) + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -90,12 +103,21 @@ def SourceTransformFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + update_context_err(context, err, err_msg) + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( @@ -113,6 +135,11 @@ def _process_requests( self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process result_queue.put(STREAM_EOF) + except grpc.RpcError as e: + # Client disconnected — expected during pod shutdown. + # Surface to the consumer which will trigger graceful shutdown. 
+ _LOGGER.warning("gRPC stream closed in reader thread") + result_queue.put(e) except BaseException as e: _LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True) # Surface the error to the consumer; SourceTransformFn will handle and exit diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py new file mode 100644 index 00000000..6d84eb5b --- /dev/null +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_shutdown.py @@ -0,0 +1,196 @@ +""" +Shutdown-event tests for the async Accumulator servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during consumer iteration (SIGTERM scenario) +2. BaseException during consumer iteration (unexpected error) +3. CancelledError during producer await (SIGTERM scenario) +4. BaseException during producer await (producer task error) +5. Exception object yielded from the result queue (task_manager error) +""" + +import asyncio +from unittest import mock + +from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer + + +async def noop_handler(datums, output): + async for _ in datums: + pass + + +async def _empty_request_iter(): + return + yield # make it an async generator + + +async def _collect(async_gen): + """Collect all items from an async generator.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_event_on_consumer_cancelled_error(): + """CancelledError while reading the result queue (e.g. 
SIGTERM) should + set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _cancelled_reader(): + raise asyncio.CancelledError() + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _cancelled_reader() + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_event_on_consumer_base_exception(): + """BaseException on the result queue should set shutdown_event AND store the error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _error_reader(): + raise RuntimeError("unexpected consumer error") + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _error_reader() + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "unexpected consumer error" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_event_on_producer_cancelled_error(): + """CancelledError when awaiting the producer task should set shutdown_event + but NOT store an error.""" + + async def _run(): + servicer = 
AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _empty_reader(): + return + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _empty_reader() + + async def _cancelled_producer(_): + raise asyncio.CancelledError() + + mock_task_manager.process_input_stream = _cancelled_producer + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_event_on_producer_base_exception(): + """BaseException from the producer task should set shutdown_event AND store the error.""" + + async def _run(): + servicer = AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _empty_reader(): + return + yield + + mock_task_manager.global_result_queue.read_iterator.return_value = _empty_reader() + + async def _error_producer(_): + raise RuntimeError("producer blew up") + + mock_task_manager.process_input_stream = _error_producer + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "producer blew up" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_event_on_result_queue_exception_message(): + """When the result queue yields a BaseException (from task_manager), + shutdown_event should be set and the error stored.""" + + async def _run(): + servicer = 
AsyncAccumulatorServicer(handler=noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + mock_task_manager = mock.MagicMock() + + async def _exception_in_queue(): + yield RuntimeError("handler error from task manager") + + mock_task_manager.global_result_queue.read_iterator.return_value = _exception_in_queue() + mock_task_manager.process_input_stream = mock.AsyncMock() + + with mock.patch( + "pynumaflow.accumulator.servicer.async_servicer.TaskManager", + return_value=mock_task_manager, + ): + ctx = mock.MagicMock() + await _collect(servicer.AccumulateFn(_empty_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler error from task manager" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py new file mode 100644 index 00000000..c22ce7ee --- /dev/null +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_shutdown.py @@ -0,0 +1,82 @@ +""" +Shutdown-event tests for the async BatchMap servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during MapFn iteration (SIGTERM scenario) +2. 
Handler RuntimeError caught by the outer BaseException handler +""" + +import asyncio +from unittest import mock + +from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer +from pynumaflow.batchmapper import BatchResponses +from tests.batchmap.utils import request_generator + + +async def _noop_handler(datums) -> BatchResponses: + async for _ in datums: + pass + return BatchResponses() + + +async def _err_handler(datums) -> BatchResponses: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncBatchMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError caught by the outer BaseException handler; shutdown_event + is set and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncBatchMapServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + EOT + sync_datums = list(request_generator(count=2, session=1, handshake=True)) + + async def _request_iter(): + for d in sync_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert 
shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/map/test_async_map_shutdown.py b/packages/pynumaflow/tests/map/test_async_map_shutdown.py new file mode 100644 index 00000000..cfbf3e50 --- /dev/null +++ b/packages/pynumaflow/tests/map/test_async_map_shutdown.py @@ -0,0 +1,106 @@ +""" +Shutdown-event tests for the async Map servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during MapFn iteration (SIGTERM scenario) +2. Handler RuntimeError surfaced via result queue +3. Bad handshake raising MapError +""" + +import asyncio +from unittest import mock + +from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer +from pynumaflow.mapper._dtypes import Messages, Message, Datum +from tests.map.utils import get_test_datums + + +async def _noop_handler(keys: list[str], datum: Datum) -> Messages: + return Messages(Message(b"ok", keys=keys)) + + +async def _err_handler(keys: list[str], datum: Datum) -> Messages: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError surfaces via the result queue; shutdown_event is set + and the error 
is stored on the servicer.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + one datum + test_datums = get_test_datums(handshake=True) + + async def _request_iter(): + for d in test_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) + + +def test_shutdown_on_handshake_error(): + """Missing handshake raises MapError; shutdown_event is set and the error is stored.""" + + async def _run(): + servicer = AsyncMapServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Send data messages without a handshake first + test_datums = get_test_datums(handshake=False) + + async def _request_iter(): + for d in test_datums: + yield d + + ctx = mock.MagicMock() + await _collect(servicer.MapFn(_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "expected handshake" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py new file mode 100644 index 00000000..669fd77d --- /dev/null +++ b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py @@ -0,0 +1,119 @@ +""" +Shutdown-event tests for the multiproc Map servicer. + +These tests verify that the SyncMapServicer (as used by MapMultiprocServer) +correctly sets shutdown_event on error, enabling coordinated graceful shutdown +across all worker processes via the shared multiprocessing.Event. 
+""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.mapper import MapMultiprocServer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + server = MapMultiprocServer(mapper_instance=err_map_handler) + servicer = server.servicer + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + 
assert code == StatusCode.INTERNAL + assert "MapFn: expected handshake as the first message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + server = MapMultiprocServer(mapper_instance=map_handler) + servicer = server.servicer + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + assert servicer.error is None diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py new file mode 100644 index 00000000..cf8523c1 --- /dev/null +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -0,0 +1,120 @@ +""" +Shutdown-event tests for the synchronous Map servicer. + +Mirrors the sinker shutdown test pattern (tests/sink/test_server.py lines 345-461). 
+Each test verifies that the servicer sets shutdown_event (and optionally captures the +error) under a specific failure mode, enabling graceful server stop via the watcher +thread in _run_server() instead of a hard process kill. +""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SyncMapServicer(handler=err_map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SyncMapServicer(handler=map_handler) + + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + # Send a data message without a handshake first + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + 
invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "MapFn: expected handshake as the first message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + servicer = SyncMapServicer(handler=map_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + servicer = SyncMapServicer(handler=map_handler) + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.MapFn(_cancelled_iter(), mock.MagicMock())) + + # Should have at least the handshake response + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py new file mode 100644 index 00000000..697132f5 --- /dev/null +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_shutdown.py @@ -0,0 +1,83 @@ +""" +Shutdown-event tests for the 
async MapStream servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during MapFn iteration (SIGTERM scenario) +2. Handler RuntimeError surfaced via result queue +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer +from pynumaflow.mapstreamer._dtypes import Message +from pynumaflow.mapstreamer import Datum +from tests.mapstream.utils import request_generator + + +async def _noop_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: + yield Message(b"ok", keys=keys) + + +async def _err_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: + raise RuntimeError("handler blew up") + yield # make it an async generator + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during MapFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncMapStreamServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.MapFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError surfaces via the result queue; shutdown_event is set + and the error is stored on the servicer.""" + + async def _run(): + servicer = AsyncMapStreamServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + sync_datums = list(request_generator(count=2, handshake=True)) + + 
async def _request_iter(): + for d in sync_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.MapFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py new file mode 100644 index 00000000..a999de5b --- /dev/null +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_shutdown.py @@ -0,0 +1,113 @@ +""" +Shutdown-event tests for the async Reduce servicer. + +The ReduceFn method has two try/except blocks where we added shutdown handling: + 1. During request iteration (datum_generator loop) + 2. During task result collection (awaiting futures) + +Tests verify that: + - CancelledError sets shutdown_event but does NOT store an error + - Handler errors set shutdown_event AND store the error +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.reducer.servicer.async_servicer import AsyncReduceServicer +from pynumaflow.reducer._dtypes import ( + Datum, + Messages, + Message, + Metadata, + Reducer, +) + +# --------------------------------------------------------------------------- +# Minimal handler — never raises on its own. 
+# --------------------------------------------------------------------------- + + +class _StubReducer(Reducer): + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + async for _ in datums: + pass + return Messages(Message(b"done", keys=keys)) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _collect(async_gen): + """Drain an async generator and return the collected items.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +# --------------------------------------------------------------------------- +# Test 1: CancelledError during request iteration +# +# We feed the servicer a request iterator that raises CancelledError. +# This exercises the first except block in ReduceFn. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_cancelled_error(): + """CancelledError during request iteration should set shutdown_event + but NOT store an error.""" + + async def _run(): + servicer = AsyncReduceServicer(handler=_StubReducer) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # An async iterator that immediately raises CancelledError, + # simulating SIGTERM arriving while reading from the gRPC stream. + async def _cancelled_request_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.ReduceFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Test 2: Handler raises a real error during request iteration +# +# We feed a request iterator that raises RuntimeError. 
+# This exercises the BaseException except block in the first try. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_handler_error(): + """A real exception during request iteration should set shutdown_event + AND store the error.""" + + async def _run(): + servicer = AsyncReduceServicer(handler=_StubReducer) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _error_request_iter(): + raise RuntimeError("reduce iteration blew up") + yield + + ctx = mock.MagicMock() + await _collect(servicer.ReduceFn(_error_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "reduce iteration blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/sideinput/test_shutdown.py b/packages/pynumaflow/tests/sideinput/test_shutdown.py new file mode 100644 index 00000000..66f55160 --- /dev/null +++ b/packages/pynumaflow/tests/sideinput/test_shutdown.py @@ -0,0 +1,72 @@ +""" +Shutdown-event tests for the SideInput servicer. + +Verifies that the servicer sets shutdown_event and captures the error when the +UDF handler raises, enabling graceful server stop via the watcher thread in +_run_server() instead of a hard process kill. 
+""" + +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.sideinput.servicer.servicer import SideInputServicer +from pynumaflow.proto.sideinput import sideinput_pb2 + + +def _ok_handler(): + from pynumaflow.sideinput import Response + + return Response.broadcast_message(b"test") + + +def _err_handler(): + raise RuntimeError("Something is fishy!") + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SideInputServicer(handler=_err_handler) + + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ + "RetrieveSideInput" + ] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + _, _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_not_set_on_success(): + """On a successful call, shutdown_event must remain unset.""" + servicer = SideInputServicer(handler=_ok_handler) + + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ + "RetrieveSideInput" + ] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + _, _, code, _ = method.termination() + assert code == StatusCode.OK + assert not servicer.shutdown_event.is_set() + assert servicer.error is None diff --git a/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py 
b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py new file mode 100644 index 00000000..037c3144 --- /dev/null +++ b/packages/pynumaflow/tests/sink/test_async_sink_shutdown.py @@ -0,0 +1,86 @@ +""" +Shutdown-event tests for the async Sink servicer. + +Tests verify that the servicer correctly handles: +1. CancelledError during SinkFn iteration (SIGTERM scenario) +2. Handler RuntimeError caught by the outer BaseException handler +""" + +import asyncio +from collections.abc import AsyncIterable +from unittest import mock + +from pynumaflow.sinker.servicer.async_servicer import AsyncSinkServicer +from pynumaflow.sinker import Datum, Responses, Response +from tests.sink.test_async_sink import request_generator + + +async def _noop_handler(datums: AsyncIterable[Datum]) -> Responses: + responses = Responses() + async for msg in datums: + responses.append(Response.as_success(msg.id)) + return responses + + +async def _err_handler(datums: AsyncIterable[Datum]) -> Responses: + raise RuntimeError("handler blew up") + + +async def _collect(async_gen): + """Drain an async generator into a list.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during SinkFn should set shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSinkServicer(handler=_noop_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield # make it an async generator + + ctx = mock.MagicMock() + await _collect(servicer.SinkFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """Handler RuntimeError caught by the outer BaseException handler; shutdown_event + is set and the error is stored on the servicer.""" + + async def _run(): + servicer = 
AsyncSinkServicer(handler=_err_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build an async request iterator with handshake + data + EOT + sync_datums = list( + request_generator(count=2, req_type="success", session=1, handshake=True) + ) + + async def _request_iter(): + for d in sync_datums: + yield d + + ctx = mock.MagicMock() + responses = await _collect(servicer.SinkFn(_request_iter(), ctx)) + + # First response should be the handshake ack + assert responses[0].handshake.sot + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "handler blew up" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/source/test_async_source_shutdown.py b/packages/pynumaflow/tests/source/test_async_source_shutdown.py new file mode 100644 index 00000000..b6e537a5 --- /dev/null +++ b/packages/pynumaflow/tests/source/test_async_source_shutdown.py @@ -0,0 +1,209 @@ +""" +Shutdown-event tests for the async Source servicer. + +Each test verifies that when a CancelledError is raised during an RPC +(simulating a SIGTERM-triggered task cancellation), the servicer: + - sets the shutdown_event + - does NOT store an error (CancelledError is not a UDF fault) +""" + +import asyncio +from unittest import mock + +from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer +from pynumaflow.sourcer._dtypes import ( + Sourcer, + ReadRequest, + AckRequest, + NackRequest, + PendingResponse, + PartitionsResponse, +) +from pynumaflow.shared.asynciter import NonBlockingIterator +from pynumaflow.proto.sourcer import source_pb2 + +# --------------------------------------------------------------------------- +# Minimal handler that never raises on its own — individual tests will +# override specific methods or inject CancelledError via the request stream. 
+# --------------------------------------------------------------------------- + + +class _StubSource(Sourcer): + async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator): + pass + + async def ack_handler(self, ack_request: AckRequest): + pass + + async def nack_handler(self, nack_request: NackRequest): + pass + + async def pending_handler(self) -> PendingResponse: + return PendingResponse(count=0) + + async def partitions_handler(self) -> PartitionsResponse: + return PartitionsResponse(partitions=[]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _collect(async_gen): + """Drain an async generator and return the collected items.""" + results = [] + async for item in async_gen: + results.append(item) + return results + + +# --------------------------------------------------------------------------- +# ReadFn — streaming RPC. CancelledError raised from the request iterator +# after the handshake has been sent. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_read_cancelled_error(): + """CancelledError during ReadFn request iteration should set + shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSourceServicer(source_handler=_StubSource()) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build a request iterator that sends the handshake then raises + # CancelledError on the next iteration (simulating SIGTERM). 
+ async def _cancelled_request_iter(): + yield source_pb2.ReadRequest(handshake=source_pb2.Handshake(sot=True)) + raise asyncio.CancelledError() + + ctx = mock.MagicMock() + await _collect(servicer.ReadFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# AckFn — streaming RPC. Same pattern as ReadFn. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_ack_cancelled_error(): + """CancelledError during AckFn request iteration should set + shutdown_event but NOT store an error.""" + + async def _run(): + servicer = AsyncSourceServicer(source_handler=_StubSource()) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_request_iter(): + yield source_pb2.AckRequest(handshake=source_pb2.Handshake(sot=True)) + raise asyncio.CancelledError() + + ctx = mock.MagicMock() + await _collect(servicer.AckFn(_cancelled_request_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# NackFn — unary RPC. We make the handler raise CancelledError. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_nack_cancelled_error(): + """CancelledError during NackFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_nack(nack_request): + raise asyncio.CancelledError() + + handler.nack_handler = _cancelled_nack + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + # Build a valid NackRequest proto with at least one offset. 
+ offset = source_pb2.Offset(offset=b"test", partition_id=0) + request = source_pb2.NackRequest(request=source_pb2.NackRequest.Request(offsets=[offset])) + ctx = mock.MagicMock() + # NackFn is a coroutine (not an async generator), so we await it. + await servicer.NackFn(request, ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# PendingFn — unary RPC. Handler raises CancelledError. +# --------------------------------------------------------------------------- + + +def test_shutdown_on_pending_cancelled_error(): + """CancelledError during PendingFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_pending(): + raise asyncio.CancelledError() + + handler.pending_handler = _cancelled_pending + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + ctx = mock.MagicMock() + await servicer.PendingFn(mock.MagicMock(), ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# PartitionsFn — unary RPC. Handler raises CancelledError. 
+# --------------------------------------------------------------------------- + + +def test_shutdown_on_partitions_cancelled_error(): + """CancelledError during PartitionsFn handler should set + shutdown_event but NOT store an error.""" + + async def _run(): + handler = _StubSource() + + async def _cancelled_partitions(): + raise asyncio.CancelledError() + + handler.partitions_handler = _cancelled_partitions + + servicer = AsyncSourceServicer(source_handler=handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + ctx = mock.MagicMock() + await servicer.PartitionsFn(mock.MagicMock(), ctx) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py new file mode 100644 index 00000000..74530e99 --- /dev/null +++ b/packages/pynumaflow/tests/sourcetransform/test_async_shutdown.py @@ -0,0 +1,66 @@ +""" +Shutdown-event tests for the async SourceTransform servicer. + +Covers the CancelledError and BaseException handlers in SourceTransformFn. 
+""" + +import asyncio +from unittest import mock + +from pynumaflow.sourcetransformer.servicer._async_servicer import SourceTransformAsyncServicer +from pynumaflow.sourcetransformer import Datum, Messages, Message +from tests.testing_utils import mock_new_event_time + + +async def async_transform_handler(keys: list[str], datum: Datum) -> Messages: + return Messages(Message(datum.value, mock_new_event_time(), keys=keys)) + + +async def _collect(async_gen): + results = [] + async for item in async_gen: + results.append(item) + return results + + +def test_shutdown_on_cancelled_error(): + """CancelledError during SourceTransformFn should set shutdown_event, no error stored.""" + + async def _run(): + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _cancelled_iter(): + raise asyncio.CancelledError() + yield + + ctx = mock.MagicMock() + await _collect(servicer.SourceTransformFn(_cancelled_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is None + + asyncio.run(_run()) + + +def test_shutdown_on_handler_error(): + """BaseException in SourceTransformFn should set shutdown_event and store error.""" + + async def _run(): + servicer = SourceTransformAsyncServicer(handler=async_transform_handler) + shutdown_event = asyncio.Event() + servicer.set_shutdown_event(shutdown_event) + + async def _error_iter(): + raise RuntimeError("unexpected error") + yield + + ctx = mock.MagicMock() + await _collect(servicer.SourceTransformFn(_error_iter(), ctx)) + + assert shutdown_event.is_set() + assert servicer._error is not None + assert "unexpected error" in repr(servicer._error) + + asyncio.run(_run()) diff --git a/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py new file mode 100644 index 00000000..bd4a78e4 --- /dev/null +++ 
b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py @@ -0,0 +1,128 @@ +""" +Shutdown-event tests for the multiproc SourceTransform servicer. + +These tests verify that the SourceTransformServicer (as used by +SourceTransformMultiProcServer) correctly sets shutdown_event on error, +enabling coordinated graceful shutdown across all worker processes via +the shared multiprocessing.Event. +""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer +from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler) + servicer = server.servicer + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + server = 
SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "SourceTransformFn: expected handshake message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is set.""" + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + servicer = server.servicer + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = 
list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + assert servicer.error is None diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py new file mode 100644 index 00000000..6045cf6a --- /dev/null +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py @@ -0,0 +1,128 @@ +""" +Shutdown-event tests for the synchronous SourceTransform servicer. + +Mirrors the mapper shutdown test pattern (tests/map/test_sync_map_shutdown.py). +Each test verifies that the servicer sets shutdown_event (and optionally captures the +error) under a specific failure mode, enabling graceful server stop via the watcher +thread in _run_server() instead of a hard process kill. +""" + +from unittest import mock + +import grpc +from grpc import StatusCode +from grpc_testing import server_from_dictionary, strict_real_time + +from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer +from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums + + +def test_shutdown_event_set_on_handler_error(): + """When the UDF handler raises, the servicer must signal the shutdown event.""" + servicer = SourceTransformServicer(handler=err_transform_handler) + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=2, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + 
method.take_response() + except ValueError: + break + + _, code, _ = method.termination() + assert code == StatusCode.INTERNAL + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_handshake_error(): + """Missing handshake must also signal the shutdown event.""" + servicer = SourceTransformServicer(handler=transform_handler) + + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + # Send a data message without a handshake first + test_datums = get_test_datums(handshake=False) + + method = test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=1, + ) + + for d in test_datums: + method.send_request(d) + method.requests_closed() + + while True: + try: + method.take_response() + except ValueError: + break + + _, code, details = method.termination() + assert code == StatusCode.INTERNAL + assert "SourceTransformFn: expected handshake message" in details + assert servicer.shutdown_event.is_set() + assert servicer.error is not None + + +def test_shutdown_event_set_on_stream_close_before_handshake(): + """grpc.RpcError on the first read (before handshake): shutdown_event set, + result_queue is None so close is skipped.""" + servicer = SourceTransformServicer(handler=transform_handler) + + def _cancelled_iter(): + raise grpc.RpcError() + yield # make it a generator + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + assert responses == [] + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None + + +def test_shutdown_event_set_on_stream_close_mid_processing(): + """grpc.RpcError mid-processing: result_queue is closed (unblocking the handler + thread) and shutdown_event is 
set.""" + servicer = SourceTransformServicer(handler=transform_handler) + + test_datums = get_test_datums(handshake=True) + + def _cancelled_iter(): + yield test_datums[0] # handshake + yield test_datums[1] # first data message + raise grpc.RpcError() + + responses = list(servicer.SourceTransformFn(_cancelled_iter(), mock.MagicMock())) + + # Should have at least the handshake response + assert responses[0].handshake.sot + assert servicer.shutdown_event.is_set() + # Not a UDF error — error stays None + assert servicer.error is None