diff --git a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py index 77942b71..d1a581f3 100644 --- a/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -133,10 +133,34 @@ async def _invoke_map_stream( headers=dict(req.request.headers), ) - # Stream results from the user handler as they are produced + # Stream results from the user handler as they are produced. + # The asyncio.sleep(0) after each put yields control to the event loop, + # allowing MapFn to consume and stream the response to gRPC immediately. + # Without it, await Queue.put() on an unbounded queue completes without + # suspending (Queue.full() is always False), starving other tasks. + # The starvation can happen if the UDF code yields messages using a regular + # for-loop (non-async). See the sample code in https://github.com/numaproj/numaflow-python/issues/342 + # With asyncio.sleep(0), this makes our below 'async for' loop equivalent to: + # + # while True: + # msg = await handler.__anext__() # await point + # await result_queue.put(...) + # + # The "await result_queue.put()" isn't a real await point yielding control back to + # the event loop in the case of an unbounded queue. When the queue is not full, it simply calls + # a non-async function https://github.com/python/cpython/blob/f4c9bc899b982b9742b45cff0643fa34de3dc84d/Lib/asyncio/queues.py#L125-L154 + # Or you can refer to the source code with: + # python -c "import asyncio, inspect; print(inspect.getsource(asyncio.Queue.put))" + # This results in a tight loop, blocking other tasks on the event loop from proceeding. + # Like in the issue linked here, if the user yields 10 messages 1 second apart, + # the task that reads from the queue can only proceed after this 'async for' loop ends, as + # it never yields control back to the event loop. 
So you will see all 10 messages at the + # same time in the next vertex instead of in a true streaming fashion. + # The asyncio.sleep(0) will yield control back to the event loop, avoiding starvation. async for msg in self.__map_stream_handler(list(req.request.keys), datum): res = map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags) await result_queue.put(map_pb2.MapResponse(results=[res], id=req.id)) + await asyncio.sleep(0) # Emit EOT for this request id await result_queue.put( diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_streaming.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_streaming.py new file mode 100644 index 00000000..cdc3a5ae --- /dev/null +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_streaming.py @@ -0,0 +1,129 @@ +""" +Test that MapStreamAsyncServer streams messages incrementally even when the +user handler yields via a regular for-loop (no await between yields). + +Regression test for https://github.com/numaproj/numaflow-python/issues/342 + +Root cause: asyncio.Queue.put() on an unbounded queue never suspends, so the +MapFn consumer task was starved and couldn't stream responses to gRPC until +the handler completed. Fix: asyncio.sleep(0) after each put in the servicer. 
+""" + +import logging +import threading +import time +from collections import deque +from collections.abc import AsyncIterable + +import grpc +import pytest + +from pynumaflow import setup_logging +from pynumaflow.mapstreamer import Datum, MapStreamAsyncServer, Message +from pynumaflow.proto.mapper import map_pb2_grpc +from tests.conftest import create_async_loop, start_async_server, teardown_async_server +from tests.mapstream.utils import request_generator + +LOGGER = setup_logging(__name__) + +pytestmark = pytest.mark.integration + +SOCK_PATH = "unix:///tmp/async_map_stream_streaming.sock" + +NUM_MESSAGES = 5 +PRODUCE_INTERVAL_SECS = 0.2 + + +async def slow_streaming_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: + """ + Handler that produces messages from a background thread with a delay + between each, and yields them via a tight for-loop with NO await. + This is the pattern from issue #342. + """ + messages: deque[Message] = deque() + + def _produce(): + for i in range(NUM_MESSAGES): + messages.append(Message(f"msg-{i}".encode(), keys=keys)) + time.sleep(PRODUCE_INTERVAL_SECS) + + thread = threading.Thread(target=_produce) + thread.start() + + while thread.is_alive(): + # Tight loop: regular for, no await — the pattern that triggers #342 + while messages: + yield messages.popleft() + + thread.join() + while messages: + yield messages.popleft() + + +async def _start_server(udfs): + server = grpc.aio.server() + map_pb2_grpc.add_MapServicer_to_server(udfs, server) + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) + await server.start() + return server, SOCK_PATH + + +@pytest.fixture(scope="module") +def streaming_server(): + loop = create_async_loop() + server_obj = MapStreamAsyncServer(map_stream_instance=slow_streaming_handler) + udfs = server_obj.servicer + server = start_async_server(loop, _start_server(udfs)) + yield loop + teardown_async_server(loop, server) + + +@pytest.fixture() +def 
streaming_stub(streaming_server): + return map_pb2_grpc.MapStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_messages_stream_incrementally(streaming_stub): + """ + Verify that messages are streamed to the client as they are produced, + not batched until the handler completes. + + The handler produces NUM_MESSAGES messages with PRODUCE_INTERVAL_SECS between + each. If streaming works, the first message should arrive well before the + last one is produced (total production time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS). + """ + generator_response = streaming_stub.MapFn( + request_iterator=request_generator(count=1, session=1) + ) + + # Consume handshake + handshake = next(generator_response) + assert handshake.handshake.sot + + # Collect messages with their arrival timestamps + arrival_times = [] + result_count = 0 + for msg in generator_response: + if hasattr(msg, "status") and msg.status.eot: + continue + arrival_times.append(time.monotonic()) + result_count += 1 + + assert result_count == NUM_MESSAGES, f"Expected {NUM_MESSAGES} messages, got {result_count}" + + # If messages streamed incrementally, the time span between the first and + # last arrival should be a significant portion of the total production time. + # If they were batched, they'd all arrive within a few milliseconds of each other. + total_production_time = NUM_MESSAGES * PRODUCE_INTERVAL_SECS + first_to_last = arrival_times[-1] - arrival_times[0] + + # The spread should be at least 40% of production time if streaming works. + # If batched, the spread would be near zero (~1-5ms). + min_expected_spread = total_production_time * 0.4 + assert first_to_last >= min_expected_spread, ( + f"Messages arrived too close together ({first_to_last:.3f}s spread), " + f"expected at least {min_expected_spread:.3f}s. " + f"This indicates messages were batched instead of streamed. " + f"Arrival gaps: {[f'{b - a:.3f}s' for a, b in zip(arrival_times, arrival_times[1:])]}" + )