Skip to content
Merged
71 changes: 62 additions & 9 deletions packages/pynumaflow/pynumaflow/accumulator/async_server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
import contextlib
import inspect
import sys

import aiorun
import grpc

from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer
from pynumaflow.info.server import write as info_server_write
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
from pynumaflow.proto.accumulator import accumulator_pb2_grpc

Expand All @@ -15,6 +19,7 @@
MAX_NUM_THREADS,
ACCUMULATOR_SOCK_PATH,
ACCUMULATOR_SERVER_INFO_FILE_PATH,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
)

from pynumaflow.accumulator._dtypes import (
Expand All @@ -23,7 +28,7 @@
Accumulator,
)

from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server
from pynumaflow.shared.server import NumaflowServer, check_instance


def get_handler(
Expand Down Expand Up @@ -157,6 +162,7 @@ def __init__(
]
# Get the servicer instance for the async server
self.servicer = AsyncAccumulatorServicer(self.accumulator_handler)
self._error: BaseException | None = None

def start(self):
"""
Expand All @@ -167,6 +173,9 @@ def start(self):
"Starting Async Accumulator Server",
)
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)

async def aexec(self):
"""
Expand All @@ -176,18 +185,62 @@ async def aexec(self):
# As the server is async, we need to create a new server instance in the
# same thread as the event loop so that all the async calls are made in the
# same context
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)

# The asyncio.Event must be created here (inside aexec) rather than in __init__,
# because it must be bound to the running event loop that aiorun creates.
# At __init__ time no event loop exists yet.
shutdown_event = asyncio.Event()
self.servicer.set_shutdown_event(shutdown_event)

accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server)

serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator]
await start_async_server(
server_async=server,
sock_path=self.sock_path,
max_threads=self.max_threads,
cleanup_coroutines=list(),
server_info_file=self.server_info_file,
server_info=serv_info,

await server.start()
info_server_write(server_info=serv_info, info_file=self.server_info_file)

_LOGGER.info(
"Async GRPC Server listening on: %s with max threads: %s",
self.sock_path,
self.max_threads,
)

async def _watch_for_shutdown():
"""Wait for the shutdown event and stop the server with a grace period."""
await shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
# Stop accepting new requests and wait for a maximum of
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

shutdown_task = asyncio.create_task(_watch_for_shutdown())
try:
await server.wait_for_termination()
except asyncio.CancelledError:
# SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error
# path (where _watch_for_shutdown calls server.stop()), this path
# must stop the gRPC server explicitly. Without this, the server
# object is never stopped and when it is garbage-collected, its
# __del__ tries to schedule a cleanup coroutine on an event loop
# that is already closed, causing errors/warnings.
_LOGGER.info("Received cancellation, stopping server gracefully...")
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

# Propagate error so start() can exit with a non-zero code
self._error = self.servicer._error

shutdown_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await shutdown_task

_LOGGER.info("Stopping event loop...")
# We use aiorun to manage the event loop. The aiorun.run() runs
# forever until loop.stop() is called. If we don't stop the
# event loop explicitly here, the python process will not exit.
# It remains stuck for 5 minutes until the liveness and readiness probes
# fail enough times and k8s sends a SIGTERM
asyncio.get_running_loop().stop()
_LOGGER.info("Event loop stopped")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc
from pynumaflow.accumulator._dtypes import (
Datum,
Expand All @@ -13,7 +13,7 @@
KeyedWindow,
)
from pynumaflow.accumulator.servicer.task_manager import TaskManager
from pynumaflow.shared.server import handle_async_error
from pynumaflow.shared.server import update_context_err
from pynumaflow.types import NumaflowServicerContext


Expand Down Expand Up @@ -57,6 +57,12 @@ def __init__(
):
# The accumulator handler can be a function or a builder class instance.
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler
self._shutdown_event: asyncio.Event | None = None
self._error: BaseException | None = None

def set_shutdown_event(self, event: asyncio.Event):
"""Wire up the shutdown event created by the server's aexec() coroutine."""
self._shutdown_event = event

async def AccumulateFn(
self,
Expand Down Expand Up @@ -104,20 +110,49 @@ async def AccumulateFn(
async for msg in consumer:
# If the message is an exception, we raise the exception
if isinstance(msg, BaseException):
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, msg, err_msg)
self._error = msg
if self._shutdown_event is not None:
self._shutdown_event.set()
return
# Send window EOF response or Window result response
# back to the client
else:
yield msg
except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
_LOGGER.info("Server shutting down, cancelling RPC.")
if self._shutdown_event is not None:
self._shutdown_event.set()
return

except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, e, err_msg)
self._error = e
if self._shutdown_event is not None:
self._shutdown_event.set()
return
# Wait for the process_input_stream task to finish for a clean exit
try:
await producer
except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
_LOGGER.info("Server shutting down, cancelling RPC.")
if self._shutdown_event is not None:
self._shutdown_event.set()
return

except BaseException as e:
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
update_context_err(context, e, err_msg)
self._error = e
if self._shutdown_event is not None:
self._shutdown_event.set()
return

async def IsReady(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ async def __invoke_accumulator(
_ = await new_instance(request_iterator, output)
# send EOF to the output stream
await output.put(STREAM_EOF)
except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
return
# If there is an error in the accumulator operation, log and
# then send the error to the result queue
except BaseException as err:
Expand Down Expand Up @@ -243,6 +246,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator
case _:
_LOGGER.debug(f"No operation matched for request: {request}", exc_info=True)

except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
return
# If there is an error in the accumulator operation, log and
# then send the error to the result queue
except BaseException as e:
Expand Down Expand Up @@ -274,6 +280,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator

# Now send STREAM_EOF to terminate the global result queue iterator
await self.global_result_queue.put(STREAM_EOF)
except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
return
except BaseException as e:
err_msg = f"Accumulator Streaming Error: {repr(e)}"
_LOGGER.critical(err_msg, exc_info=True)
Expand Down
79 changes: 64 additions & 15 deletions packages/pynumaflow/pynumaflow/batchmapper/async_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import asyncio
import contextlib
import sys

import aiorun
import grpc

Expand All @@ -8,9 +12,11 @@
BATCH_MAP_SOCK_PATH,
MAP_SERVER_INFO_FILE_PATH,
MAX_NUM_THREADS,
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
)
from pynumaflow.batchmapper._dtypes import BatchMapCallable
from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer
from pynumaflow.info.server import write as info_server_write
from pynumaflow.info.types import (
ServerInfo,
MAP_MODE_KEY,
Expand All @@ -19,7 +25,7 @@
ContainerType,
)
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import NumaflowServer, start_async_server
from pynumaflow.shared.server import NumaflowServer


class BatchMapAsyncServer(NumaflowServer):
Expand Down Expand Up @@ -92,13 +98,17 @@ async def handler(
]

self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance)
self._error: BaseException | None = None

def start(self):
"""
Starter function for the Async Batch Map server, we need a separate caller
to the aexec so that all the async coroutines can be started from a single context
"""
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
if self._error:
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
sys.exit(1)

async def aexec(self):
"""
Expand All @@ -108,25 +118,64 @@ async def aexec(self):
# As the server is async, we need to create a new server instance in the
# same thread as the event loop so that all the async calls are made in the
# same context
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
_LOGGER.info("Starting Batch Map Server")

# The asyncio.Event must be created here (inside aexec) rather than in __init__,
# because it must be bound to the running event loop that aiorun creates.
# At __init__ time no event loop exists yet.
shutdown_event = asyncio.Event()
self.servicer.set_shutdown_event(shutdown_event)

map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)

serv_info = ServerInfo.get_default_server_info()
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper]
# Add the MAP_MODE metadata to the server info for the correct map mode
serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap

# Start the async server
await start_async_server(
server_async=server,
sock_path=self.sock_path,
max_threads=self.max_threads,
cleanup_coroutines=list(),
server_info_file=self.server_info_file,
server_info=serv_info,
await server.start()
info_server_write(server_info=serv_info, info_file=self.server_info_file)

_LOGGER.info(
"Async GRPC Server listening on: %s with max threads: %s",
self.sock_path,
self.max_threads,
)

async def _watch_for_shutdown():
"""Wait for the shutdown event and stop the server with a grace period."""
await shutdown_event.wait()
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
# Stop accepting new requests and wait for a maximum of
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

shutdown_task = asyncio.create_task(_watch_for_shutdown())
try:
await server.wait_for_termination()
except asyncio.CancelledError:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this happens, then the rest of the code after the except won't be invoked, correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The asyncio.CancelledError will be raised when the event loop shuts down or when the task is cancelled explicitly. This will cause the block of code under except asyncio.CancelledError to execute. We want to ignore this exception.
All other exceptions will be caught in the BaseException catching blocks, which are categorized as critical and mostly indicate a UDF error, which we should propagate to numa.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a comment why we are doing so for posterity?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked this part of the code again. There are no BaseException catching here. The reason is added under CancelledError exception block.

# SIGTERM received — aiorun cancels all tasks. We must stop
# the gRPC server explicitly so its __del__ doesn't try to
# schedule a coroutine on the already-closed event loop. 

I was seeing something like below due to Python's GC during shutdown of the server:

  Exception ignored in: <function _Server.__del__ at 0x...>
  Traceback (most recent call last):
    File ".../grpc/aio/_server.py", line ..., in __del__
      self._loop.call_soon_threadsafe(...)
  RuntimeError: Event loop is closed

RuntimeError: cannot schedule new futures after shutdown

I will update the comment with more details.

# SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error
# path (where _watch_for_shutdown calls server.stop()), this path
# must stop the gRPC server explicitly. Without this, the server
# object is never stopped and when it is garbage-collected, its
# __del__ tries to schedule a cleanup coroutine on an event loop
# that is already closed, causing errors/warnings.
_LOGGER.info("Received cancellation, stopping server gracefully...")
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)

# Propagate error so start() can exit with a non-zero code
self._error = self.servicer._error

shutdown_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await shutdown_task

_LOGGER.info("Stopping event loop...")
# We use aiorun to manage the event loop. The aiorun.run() runs
# forever until loop.stop() is called. If we don't stop the
# event loop explicitly here, the python process will not exit.
# It reamins stuck for 5 minutes until liveness and readiness probe
# fails enough times and k8s sends a SIGTERM
asyncio.get_running_loop().stop()
_LOGGER.info("Event loop stopped")
Loading
Loading