From e1aa90f2341f3164b84480b133e5d9cab7f36c1f Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 21 Mar 2026 06:30:39 +0530 Subject: [PATCH 1/7] Update pytest to latest Signed-off-by: Sreekanth --- packages/pynumaflow/pyproject.toml | 2 +- packages/pynumaflow/uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/pynumaflow/pyproject.toml b/packages/pynumaflow/pyproject.toml index 59a04e5e..3d18ca8d 100644 --- a/packages/pynumaflow/pyproject.toml +++ b/packages/pynumaflow/pyproject.toml @@ -38,7 +38,7 @@ Repository = "https://github.com/numaproj/numaflow-python" [dependency-groups] dev = [ - "pytest>=7.2.1", + "pytest>=9.0.2", "pytest-cov>=3.0", "black>=23.1", "grpcio-testing>=1.48.1", diff --git a/packages/pynumaflow/uv.lock b/packages/pynumaflow/uv.lock index 146f58de..5fad94fd 100644 --- a/packages/pynumaflow/uv.lock +++ b/packages/pynumaflow/uv.lock @@ -1554,7 +1554,7 @@ dev = [ { name = "grpcio-testing", specifier = ">=1.48.1" }, { name = "mypy", specifier = ">=1.18.2" }, { name = "pre-commit", specifier = ">=3.3.1" }, - { name = "pytest", specifier = ">=7.2.1" }, + { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-cov", specifier = ">=3.0" }, { name = "ruff", specifier = ">=0.0.264" }, { name = "types-protobuf", specifier = ">=6.32.1.20250918" }, From adfad47c1a4a9b9163ee4990c5c5ccde732b6d28 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 21 Mar 2026 06:42:07 +0530 Subject: [PATCH 2/7] conftest.py with shared fixtures Signed-off-by: Sreekanth --- packages/pynumaflow/tests/conftest.py | 64 +++++++++++++++++++ .../tests/map/test_multiproc_map_shutdown.py | 23 ++----- .../tests/map/test_multiproc_mapper.py | 43 ++----------- .../tests/map/test_sync_map_shutdown.py | 23 ++----- .../pynumaflow/tests/map/test_sync_mapper.py | 43 ++----------- packages/pynumaflow/tests/sink/test_server.py | 40 ++---------- .../tests/sourcetransform/test_multiproc.py | 43 ++----------- .../test_multiproc_shutdown.py | 23 ++----- 
.../tests/sourcetransform/test_sync_server.py | 57 +++-------------- .../sourcetransform/test_sync_shutdown.py | 23 ++----- 10 files changed, 121 insertions(+), 261 deletions(-) create mode 100644 packages/pynumaflow/tests/conftest.py diff --git a/packages/pynumaflow/tests/conftest.py b/packages/pynumaflow/tests/conftest.py new file mode 100644 index 00000000..be529123 --- /dev/null +++ b/packages/pynumaflow/tests/conftest.py @@ -0,0 +1,64 @@ +""" +Root conftest.py — shared pytest fixtures and helpers for all test modules. + +Provides helpers for common gRPC testing patterns that are duplicated across +sync, multiproc, and async test files. +""" + + +def collect_responses(method): + """Collect all responses from a grpc_testing stream method until exhausted. + + Replaces the repeated pattern: + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in err.__str__(): + break + + Returns a list of response protos. + """ + responses = [] + while True: + try: + resp = method.take_response() + responses.append(resp) + except ValueError as err: + if "No more responses!" in str(err): + break + return responses + + +def drain_responses(method): + """Drain all responses from a grpc_testing stream method, discarding them. + + Replaces the repeated pattern: + while True: + try: + method.take_response() + except ValueError: + break + + Useful in shutdown tests where we only care about termination status. + """ + while True: + try: + method.take_response() + except ValueError: + break + + +def send_test_requests(method, datums): + """Send a list of test datums to a grpc_testing stream method and close. 
+ + Replaces the repeated pattern: + for d in test_datums: + method.send_request(d) + method.requests_closed() + """ + for d in datums: + method.send_request(d) + method.requests_closed() diff --git a/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py index 669fd77d..24c4a578 100644 --- a/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py +++ b/packages/pynumaflow/tests/map/test_multiproc_map_shutdown.py @@ -14,6 +14,7 @@ from pynumaflow.mapper import MapMultiprocServer from pynumaflow.proto.mapper import map_pb2 +from tests.conftest import drain_responses, send_test_requests from tests.map.utils import map_handler, err_map_handler, get_test_datums @@ -33,15 +34,8 @@ def test_shutdown_event_set_on_handler_error(): timeout=2, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, _ = method.termination() assert code == StatusCode.INTERNAL @@ -65,15 +59,8 @@ def test_shutdown_event_set_on_handshake_error(): timeout=1, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, details = method.termination() assert code == StatusCode.INTERNAL diff --git a/packages/pynumaflow/tests/map/test_multiproc_mapper.py b/packages/pynumaflow/tests/map/test_multiproc_mapper.py index 90eabb9b..7250ad0c 100644 --- a/packages/pynumaflow/tests/map/test_multiproc_mapper.py +++ b/packages/pynumaflow/tests/map/test_multiproc_mapper.py @@ -10,6 +10,7 @@ from pynumaflow.mapper import MapMultiprocServer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, get_test_datums +from tests.conftest import collect_responses, 
drain_responses, send_test_requests from tests.testing_utils import ( mock_terminate_on_stop, ) @@ -50,18 +51,8 @@ def test_udf_map_err_handshake(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("MapFn: expected handshake as the first message" in details) @@ -77,18 +68,8 @@ def test_udf_map_err(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("Something is fishy!" in details) @@ -116,18 +97,8 @@ def test_map_forward_message(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) metadata, code, details = method.termination() diff --git a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py index cf8523c1..70ae54b5 100644 --- a/packages/pynumaflow/tests/map/test_sync_map_shutdown.py +++ b/packages/pynumaflow/tests/map/test_sync_map_shutdown.py @@ -15,6 +15,7 @@ from pynumaflow.mapper._servicer._sync_servicer import SyncMapServicer from pynumaflow.proto.mapper import map_pb2 +from tests.conftest import drain_responses, send_test_requests from tests.map.utils import map_handler, err_map_handler, get_test_datums @@ -33,15 +34,8 @@ def test_shutdown_event_set_on_handler_error(): timeout=2, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, _ = method.termination() assert code == StatusCode.INTERNAL @@ -65,15 +59,8 @@ def test_shutdown_event_set_on_handshake_error(): timeout=1, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, details = method.termination() assert code == StatusCode.INTERNAL diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index edafa835..ce830440 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -9,6 +9,7 @@ from pynumaflow.mapper import MapServer from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, ExampleMap, get_test_datums +from tests.conftest import collect_responses, drain_responses, send_test_requests from 
tests.testing_utils import ( mock_terminate_on_stop, ) @@ -44,18 +45,8 @@ def test_udf_map_err_handshake(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("MapFn: expected handshake as the first message" in details) @@ -72,18 +63,8 @@ def test_udf_map_error_response(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("Something is fishy!" in details) @@ -111,18 +92,8 @@ def test_map_forward_message(self): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) metadata, code, details = method.termination() # 1 handshake + 3 data responses diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index 580dfb5e..65da9452 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -22,6 +22,7 @@ from pynumaflow.proto.sinker import sink_pb2 from pynumaflow.sinker import Responses, Datum, Response, SinkServer, Message, UserMetadata from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer +from tests.conftest import collect_responses, drain_responses, send_test_requests def mockenv(**envvars): @@ -201,18 +202,8 @@ def test_udsink_err(err_sink_test_server): timeout=1, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in str(err): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() assert code == StatusCode.INTERNAL @@ -267,18 +258,8 @@ def test_forward_message(sink_test_server): invocation_metadata={}, timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in str(err): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) # 1 handshake + 1 data messages + 1 EOT assert len(responses) == 3 @@ -370,15 +351,8 @@ def test_shutdown_event_set_on_handler_error(): timeout=2, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, _ = method.termination() assert code == StatusCode.INTERNAL diff --git a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py index b844231e..4136f9d3 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py +++ b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py @@ -11,6 +11,7 @@ from pynumaflow.proto.sourcetransformer import transform_pb2 from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums +from tests.conftest import collect_responses, drain_responses, send_test_requests from tests.testing_utils import ( mock_new_event_time, mock_terminate_on_stop, @@ -61,18 +62,8 @@ def test_udf_mapt_err_handshake(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("SourceTransformFn: expected handshake message" in details) @@ -95,18 +86,8 @@ def test_udf_mapt_err(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("Something is fishy" in details) @@ -142,18 +123,8 @@ def test_mapt_assign_new_event_time(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) metadata, code, details = method.termination() diff --git a/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py index bd4a78e4..92324060 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py +++ b/packages/pynumaflow/tests/sourcetransform/test_multiproc_shutdown.py @@ -15,6 +15,7 @@ from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.conftest import drain_responses, send_test_requests from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums @@ -38,15 +39,8 @@ def test_shutdown_event_set_on_handler_error(): timeout=2, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - 
method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, _ = method.termination() assert code == StatusCode.INTERNAL @@ -74,15 +68,8 @@ def test_shutdown_event_set_on_handshake_error(): timeout=1, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, details = method.termination() assert code == StatusCode.INTERNAL diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py index 2576c97f..1b6b4c89 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py @@ -11,6 +11,7 @@ from pynumaflow.proto.sourcetransformer import transform_pb2 from pynumaflow.sourcetransformer import SourceTransformServer, Datum, Messages, Message from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums +from tests.conftest import collect_responses, drain_responses, send_test_requests from tests.testing_utils import ( mock_terminate_on_stop, mock_new_event_time, @@ -53,18 +54,8 @@ def test_udf_mapt_err(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("Something is fishy" in details) @@ -104,18 +95,8 @@ def test_udf_mapt_err_handshake(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + drain_responses(method) metadata, code, details = method.termination() self.assertTrue("SourceTransformFn: expected handshake message" in details) @@ -134,18 +115,8 @@ def test_mapt_assign_new_event_time(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" in err.__str__(): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) metadata, code, details = method.termination() @@ -235,18 +206,8 @@ def test_source_transform_with_metadata(self): timeout=1, ) - for x in test_datums: - method.send_request(x) - method.requests_closed() - - responses = [] - while True: - try: - resp = method.take_response() - responses.append(resp) - except ValueError as err: - if "No more responses!" 
in err.__str__(): - break + send_test_requests(method, test_datums) + responses = collect_responses(method) metadata, code, details = method.termination() diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py index 6045cf6a..e4a811e6 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_shutdown.py @@ -15,6 +15,7 @@ from pynumaflow.sourcetransformer.servicer._servicer import SourceTransformServicer from pynumaflow.proto.sourcetransformer import transform_pb2 +from tests.conftest import drain_responses, send_test_requests from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums @@ -37,15 +38,8 @@ def test_shutdown_event_set_on_handler_error(): timeout=2, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, _ = method.termination() assert code == StatusCode.INTERNAL @@ -73,15 +67,8 @@ def test_shutdown_event_set_on_handshake_error(): timeout=1, ) - for d in test_datums: - method.send_request(d) - method.requests_closed() - - while True: - try: - method.take_response() - except ValueError: - break + send_test_requests(method, test_datums) + drain_responses(method) _, code, details = method.termination() assert code == StatusCode.INTERNAL From 27424baf9d281512d8cd11c3992716aef2198654 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 21 Mar 2026 08:47:14 +0530 Subject: [PATCH 3/7] Migrate unittest.TestCase to plain pytest functions Signed-off-by: Sreekanth --- .../accumulator/test_async_accumulator.py | 581 +++++++++-------- .../accumulator/test_async_accumulator_err.py | 137 ++-- .../tests/accumulator/test_datatypes.py | 593 +++++++++--------- 
.../tests/batchmap/test_async_batch_map.py | 184 +++--- .../batchmap/test_async_batch_map_err.py | 148 ++--- .../tests/batchmap/test_datatypes.py | 155 +++-- .../tests/batchmap/test_messages.py | 181 +++--- .../pynumaflow/tests/errors/test_dtypes.py | 110 ++-- .../errors/test_persist_critical_error.py | 242 ++++--- .../pynumaflow/tests/map/test_async_mapper.py | 354 +++++------ .../pynumaflow/tests/map/test_messages.py | 204 +++--- .../tests/map/test_multiproc_mapper.py | 228 ++++--- .../pynumaflow/tests/map/test_sync_mapper.py | 244 ++++--- .../tests/mapstream/test_async_map_stream.py | 245 ++++---- .../mapstream/test_async_map_stream_err.py | 167 +++-- .../tests/mapstream/test_messages.py | 176 +++--- .../tests/reduce/test_async_reduce.py | 313 +++++---- .../tests/reduce/test_async_reduce_err.py | 141 ++--- .../pynumaflow/tests/reduce/test_datatypes.py | 275 ++++---- .../pynumaflow/tests/reduce/test_messages.py | 176 +++--- .../tests/reducestreamer/test_async_reduce.py | 434 +++++++------ .../reducestreamer/test_async_reduce_err.py | 180 +++--- .../tests/reducestreamer/test_datatypes.py | 269 ++++---- .../tests/reducestreamer/test_messages.py | 79 ++- .../tests/sideinput/test_responses.py | 70 +-- .../tests/sideinput/test_side_input_server.py | 235 +++---- .../pynumaflow/tests/sink/test_async_sink.py | 488 +++++++------- .../pynumaflow/tests/sink/test_datatypes.py | 172 +++-- .../pynumaflow/tests/sink/test_responses.py | 153 +++-- .../tests/source/test_async_source.py | 348 +++++----- .../tests/source/test_async_source_err.py | 289 ++++----- .../pynumaflow/tests/source/test_message.py | 120 ++-- .../tests/sourcetransform/test_async.py | 489 +++++++-------- .../tests/sourcetransform/test_messages.py | 312 ++++----- .../tests/sourcetransform/test_multiproc.py | 291 ++++----- .../tests/sourcetransform/test_sync_server.py | 338 +++++----- packages/pynumaflow/tests/test_info_server.py | 88 +-- packages/pynumaflow/tests/testing_utils.py | 4 - 38 files changed, 4370 
insertions(+), 4843 deletions(-) diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py index 2afc1d64..46922b4e 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py @@ -1,12 +1,11 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio._server import Server from pynumaflow import setup_logging from pynumaflow.accumulator import ( @@ -26,6 +25,8 @@ LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/accumulator.sock" + def request_generator(count, request, resetkey: bool = False, send_close: bool = False): for i in range(count): @@ -137,11 +138,6 @@ def start_request_without_open() -> accumulator_pb2.AccumulatorRequest: return request -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/accumulator.sock") -_loop = None - - def startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() @@ -172,321 +168,306 @@ def NewAsyncAccumulator(): return udfs -async def start_server(udfs): +async def _start_server(udfs): server = grpc.aio.server() accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/accumulator.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -class TestAsyncAccumulator(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncAccumulator() - 
asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/accumulator.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: - try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: - LOGGER.error(e) + return server - def test_accumulate(self) -> None: - stub = self.__stub() - request = start_request() - generator_response = None - try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator(count=5, request=request) - ) - except grpc.RpcError as e: - logging.error(e) - # capture the output from the AccumulateFn generator and assert. - count = 0 - eof_count = 0 - for r in generator_response: - if hasattr(r, "payload") and r.payload.value: - count += 1 - # Each datum should increment the counter - expected_msg = f"counter:{count}" - self.assertEqual( - bytes(expected_msg, encoding="utf-8"), - r.payload.value, - ) - self.assertEqual(r.EOF, False) - # Check that keys are preserved - self.assertEqual(list(r.payload.keys), ["test_key"]) - else: - self.assertEqual(r.EOF, True) - eof_count += 1 - - # We should have received 5 messages (one for each datum) - self.assertEqual(5, count) - self.assertEqual(1, eof_count) - - def test_accumulate_with_multiple_keys(self) -> None: - stub = self.__stub() - request = start_request() - generator_response = None +@pytest.fixture(scope="module") +def async_accumulator_server(): + """Module-scoped fixture: starts an async gRPC accumulator server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = NewAsyncAccumulator() + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), 
loop=loop) + future.result(timeout=10) + + # Wait for the server to be ready + while True: try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator(count=10, request=request, resetkey=True), - ) - except grpc.RpcError as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - count = 0 - eof_count = 0 - key_counts = {} - - # capture the output from the AccumulateFn generator and assert. - for r in generator_response: - # Check for responses with values - if r.payload.value: - count += 1 - # Track count per key - key = r.payload.keys[0] if r.payload.keys else "no_key" - key_counts[key] = key_counts.get(key, 0) + 1 - - # Each key should have its own counter starting from 1 - expected_msg = f"counter:{key_counts[key]}" - self.assertEqual( - bytes(expected_msg, encoding="utf-8"), - r.payload.value, - ) - self.assertEqual(r.EOF, False) - else: - eof_count += 1 - self.assertEqual(r.EOF, True) - - # We should have 10 messages (one for each key) - self.assertEqual(10, count) - self.assertEqual(10, eof_count) # Each key/task sends its own EOF - # Each key should appear once - self.assertEqual(len(key_counts), 10) - - def test_accumulate_with_close(self) -> None: - stub = self.__stub() - request = start_request() - generator_response = None - try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator(count=5, request=request, send_close=True) - ) - except grpc.RpcError as e: - logging.error(e) + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def accumulator_stub(async_accumulator_server): + """Returns an AccumulatorStub connected to the running async server.""" + return accumulator_pb2_grpc.AccumulatorStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_accumulate(accumulator_stub) -> 
None: + request = start_request() + generator_response = None + try: + generator_response = accumulator_stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + assert bytes(expected_msg, encoding="utf-8") == r.payload.value + assert r.EOF is False + # Check that keys are preserved + assert list(r.payload.keys) == ["test_key"] + else: + assert r.EOF is True + eof_count += 1 + + # We should have received 5 messages (one for each datum) + assert 5 == count + assert 1 == eof_count - # capture the output from the AccumulateFn generator and assert. - count = 0 - eof_count = 0 - for r in generator_response: - if hasattr(r, "payload") and r.payload.value: - count += 1 - # Each datum should increment the counter - expected_msg = f"counter:{count}" - self.assertEqual( - bytes(expected_msg, encoding="utf-8"), - r.payload.value, - ) - self.assertEqual(r.EOF, False) - # Check that keys are preserved - self.assertEqual(list(r.payload.keys), ["test_key"]) - else: - self.assertEqual(r.EOF, True) - eof_count += 1 - - # We should have received 5 messages (one for each datum) - self.assertEqual(5, count) - self.assertEqual(1, eof_count) - - def test_accumulate_append_without_open(self) -> None: - stub = self.__stub() - request = start_request_without_open() - generator_response = None - try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator_append_only(count=5, request=request) - ) - except grpc.RpcError as e: - logging.error(e) - # capture the output from the AccumulateFn generator and assert. 
- count = 0 - eof_count = 0 - for r in generator_response: - if hasattr(r, "payload") and r.payload.value: - count += 1 - # Each datum should increment the counter - expected_msg = f"counter:{count}" - self.assertEqual( - bytes(expected_msg, encoding="utf-8"), - r.payload.value, - ) - self.assertEqual(r.EOF, False) - # Check that keys are preserved - self.assertEqual(list(r.payload.keys), ["test_key"]) - else: - self.assertEqual(r.EOF, True) - eof_count += 1 - - # We should have received 5 messages (one for each datum) - self.assertEqual(5, count) - self.assertEqual(1, eof_count) - - def test_accumulate_append_mixed(self) -> None: - stub = self.__stub() - request = start_request() - generator_response = None +def test_accumulate_with_multiple_keys(accumulator_stub) -> None: + request = start_request() + generator_response = None + try: + generator_response = accumulator_stub.AccumulateFn( + request_iterator=request_generator(count=10, request=request, resetkey=True), + ) + except grpc.RpcError as e: + LOGGER.error(e) + + count = 0 + eof_count = 0 + key_counts = {} + + # capture the output from the AccumulateFn generator and assert. 
+ for r in generator_response: + # Check for responses with values + if r.payload.value: + count += 1 + # Track count per key + key = r.payload.keys[0] if r.payload.keys else "no_key" + key_counts[key] = key_counts.get(key, 0) + 1 + + # Each key should have its own counter starting from 1 + expected_msg = f"counter:{key_counts[key]}" + assert bytes(expected_msg, encoding="utf-8") == r.payload.value + assert r.EOF is False + else: + eof_count += 1 + assert r.EOF is True + + # We should have 10 messages (one for each key) + assert 10 == count + assert 10 == eof_count # Each key/task sends its own EOF + # Each key should appear once + assert len(key_counts) == 10 + + +def test_accumulate_with_close(accumulator_stub) -> None: + request = start_request() + generator_response = None + try: + generator_response = accumulator_stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request, send_close=True) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + assert bytes(expected_msg, encoding="utf-8") == r.payload.value + assert r.EOF is False + # Check that keys are preserved + assert list(r.payload.keys) == ["test_key"] + else: + assert r.EOF is True + eof_count += 1 + + # We should have received 5 messages (one for each datum) + assert 5 == count + assert 1 == eof_count + + +def test_accumulate_append_without_open(accumulator_stub) -> None: + request = start_request_without_open() + generator_response = None + try: + generator_response = accumulator_stub.AccumulateFn( + request_iterator=request_generator_append_only(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. 
+ count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + assert bytes(expected_msg, encoding="utf-8") == r.payload.value + assert r.EOF is False + # Check that keys are preserved + assert list(r.payload.keys) == ["test_key"] + else: + assert r.EOF is True + eof_count += 1 + + # We should have received 5 messages (one for each datum) + assert 5 == count + assert 1 == eof_count + + +def test_accumulate_append_mixed(accumulator_stub) -> None: + request = start_request() + generator_response = None + try: + generator_response = accumulator_stub.AccumulateFn( + request_iterator=request_generator_mixed(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = "counter:1" + assert bytes(expected_msg, encoding="utf-8") == r.payload.value + assert r.EOF is False + # Check that keys are preserved + assert list(r.payload.keys) == ["test_key"] + else: + assert r.EOF is True + eof_count += 1 + + # We should have received 5 messages (one for each datum) + assert 3 == count + assert 3 == eof_count + + +def test_is_ready(async_accumulator_server) -> None: + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = accumulator_pb2_grpc.AccumulatorStub(channel) + + request = _empty_pb2.Empty() + response = None try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator_mixed(count=5, request=request) - ) + response = stub.IsReady(request=request) except grpc.RpcError as e: logging.error(e) - # capture the output from the AccumulateFn generator and assert. 
- count = 0 - eof_count = 0 - for r in generator_response: - if hasattr(r, "payload") and r.payload.value: - count += 1 - # Each datum should increment the counter - expected_msg = "counter:1" - self.assertEqual( - bytes(expected_msg, encoding="utf-8"), - r.payload.value, - ) - self.assertEqual(r.EOF, False) - # Check that keys are preserved - self.assertEqual(list(r.payload.keys), ["test_key"]) - else: - self.assertEqual(r.EOF, True) - eof_count += 1 - - # We should have received 5 messages (one for each datum) - self.assertEqual(3, count) - self.assertEqual(3, eof_count) - - def test_is_ready(self) -> None: - with grpc.insecure_channel("unix:///tmp/accumulator.sock") as channel: - stub = accumulator_pb2_grpc.AccumulatorStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def __stub(self): - return accumulator_pb2_grpc.AccumulatorStub(_channel) - - def test_error_init(self): - # Check that accumulator_instance is required - with self.assertRaises(TypeError): - AccumulatorAsyncServer() - # Check that the init_args and init_kwargs are passed - # only with an Accumulator class - with self.assertRaises(TypeError): - AccumulatorAsyncServer(accumulator_handler_func, init_args=(0, 1)) - # Check that an instance is not passed instead of the class - # signature - with self.assertRaises(TypeError): - AccumulatorAsyncServer(ExampleClass(0)) - - # Check that an invalid class is passed - class ExampleBadClass: - pass - - with self.assertRaises(TypeError): - AccumulatorAsyncServer(accumulator_instance=ExampleBadClass) - - def test_max_threads(self): - # max cap at 16 - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=5) - 
self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass) - self.assertEqual(server.max_threads, 4) - - # zero threads - server = AccumulatorAsyncServer(ExampleClass, max_threads=0) - self.assertEqual(server.max_threads, 0) - - # negative threads - server = AccumulatorAsyncServer(ExampleClass, max_threads=-5) - self.assertEqual(server.max_threads, -5) - - def test_server_info_file_path_handling(self): - """Test AccumulatorAsyncServer with custom server info file path.""" - - server = AccumulatorAsyncServer( - ExampleClass, init_args=(0,), server_info_file="/custom/path/server_info.json" - ) + assert response.ready - self.assertEqual(server.server_info_file, "/custom/path/server_info.json") - def test_init_kwargs_none_handling(self): - """Test init_kwargs None handling in AccumulatorAsyncServer.""" +def test_error_init(): + # Check that accumulator_instance is required + with pytest.raises(TypeError): + AccumulatorAsyncServer() + # Check that the init_args and init_kwargs are passed + # only with an Accumulator class + with pytest.raises(TypeError): + AccumulatorAsyncServer(accumulator_handler_func, init_args=(0, 1)) + # Check that an instance is not passed instead of the class + # signature + with pytest.raises(TypeError): + AccumulatorAsyncServer(ExampleClass(0)) - server = AccumulatorAsyncServer( - ExampleClass, init_args=(0,), init_kwargs=None # This should be converted to {} - ) + # Check that an invalid class is passed + class ExampleBadClass: + pass + + with pytest.raises(TypeError): + AccumulatorAsyncServer(accumulator_instance=ExampleBadClass) + + +def test_max_threads(): + # max cap at 16 + server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = 
AccumulatorAsyncServer(accumulator_instance=ExampleClass) + assert server.max_threads == 4 + + # zero threads + server = AccumulatorAsyncServer(ExampleClass, max_threads=0) + assert server.max_threads == 0 + + # negative threads + server = AccumulatorAsyncServer(ExampleClass, max_threads=-5) + assert server.max_threads == -5 - # Should not raise any errors and should work correctly - self.assertIsNotNone(server.accumulator_handler) - def test_server_start_method_logging(self): - """Test server start method includes proper logging.""" - from unittest.mock import patch +def test_server_info_file_path_handling(): + """Test AccumulatorAsyncServer with custom server info file path.""" + + server = AccumulatorAsyncServer( + ExampleClass, init_args=(0,), server_info_file="/custom/path/server_info.json" + ) + + assert server.server_info_file == "/custom/path/server_info.json" + + +def test_init_kwargs_none_handling(): + """Test init_kwargs None handling in AccumulatorAsyncServer.""" + + server = AccumulatorAsyncServer( + ExampleClass, init_args=(0,), init_kwargs=None # This should be converted to {} + ) + + # Should not raise any errors and should work correctly + assert server.accumulator_handler is not None - server = AccumulatorAsyncServer(ExampleClass) - # Mock aiorun.run to prevent actual server startup - with ( - patch("pynumaflow.accumulator.async_server.aiorun") as mock_aiorun, - patch("pynumaflow.accumulator.async_server._LOGGER") as mock_logger, - ): - server.start() +def test_server_start_method_logging(): + """Test server start method includes proper logging.""" + from unittest.mock import patch - # Verify logging was called - mock_logger.info.assert_called_once_with("Starting Async Accumulator Server") + server = AccumulatorAsyncServer(ExampleClass) - # Verify aiorun.run was called with correct parameters - mock_aiorun.run.assert_called_once() - self.assertTrue(mock_aiorun.run.call_args[1]["use_uvloop"]) + # Mock aiorun.run to prevent actual server startup + 
with ( + patch("pynumaflow.accumulator.async_server.aiorun") as mock_aiorun, + patch("pynumaflow.accumulator.async_server._LOGGER") as mock_logger, + ): + server.start() + # Verify logging was called + mock_logger.info.assert_called_once_with("Starting Async Accumulator Server") -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + # Verify aiorun.run was called with correct parameters + mock_aiorun.run.assert_called_once() + assert mock_aiorun.run.call_args[1]["use_uvloop"] diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py index a6b49b26..09eead23 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py @@ -1,12 +1,10 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable -from unittest.mock import patch import grpc -from grpc.aio._server import Server +import pytest from pynumaflow import setup_logging from pynumaflow.accumulator import ( @@ -20,11 +18,12 @@ from tests.testing_utils import ( mock_message, get_time_args, - mock_terminate_on_stop, ) LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/accumulator_err.sock" + def request_generator(count, request): for i in range(count): @@ -58,11 +57,6 @@ def start_request() -> accumulator_pb2.AccumulatorRequest: return request -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/accumulator_err.sock") -_loop = None - - def startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() @@ -99,77 +93,68 @@ def NewAsyncAccumulatorError(): return udfs -@patch("psutil.Process.kill", mock_terminate_on_stop) -async def start_server(udfs): +async def _start_server(udfs): server = grpc.aio.server() accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server) - listen_addr = 
"unix:///tmp/accumulator_err.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncAccumulatorError(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncAccumulatorError() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/accumulator_err.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_accumulator_err_server(): + """Module-scoped fixture: starts an async gRPC accumulator error server.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = NewAsyncAccumulatorError() + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + future.result(timeout=10) + + # Wait for the server to be ready + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - @patch("psutil.Process.kill", mock_terminate_on_stop) - def 
test_accumulate_partial_success(self) -> None: - """Test that the first datum is processed before error occurs""" - stub = self.__stub() - request = start_request() + yield loop - try: - generator_response = stub.AccumulateFn( - request_iterator=request_generator(count=5, request=request) - ) - - # Try to consume the generator - counter = 0 - for response in generator_response: - self.assertIsInstance(response, accumulator_pb2.AccumulatorResponse) - self.assertTrue(response.payload.value.startswith(b"counter:")) - counter += 1 - - self.assertEqual(counter, 1, "Expected only one successful response before error") - except BaseException as err: - self.assertTrue("Simulated error in accumulator handler" in str(err)) - return - self.fail("Expected an exception.") - - def __stub(self): - return accumulator_pb2_grpc.AccumulatorStub(_channel) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def accumulator_err_stub(async_accumulator_err_server): + """Returns an AccumulatorStub connected to the running async error server.""" + return accumulator_pb2_grpc.AccumulatorStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_accumulate_partial_success(accumulator_err_stub) -> None: + """Test that the first datum is processed before error occurs""" + request = start_request() + + try: + generator_response = accumulator_err_stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request) + ) + + # Try to consume the generator + counter = 0 + for response in generator_response: + assert isinstance(response, accumulator_pb2.AccumulatorResponse) + assert response.payload.value.startswith(b"counter:") + counter += 1 + + assert counter == 1, "Expected only one successful response before error" + except BaseException as err: + assert "Simulated error in accumulator handler" in str(err) + return + pytest.fail("Expected an exception.") diff --git 
a/packages/pynumaflow/tests/accumulator/test_datatypes.py b/packages/pynumaflow/tests/accumulator/test_datatypes.py index d82e4bf2..0f829024 100644 --- a/packages/pynumaflow/tests/accumulator/test_datatypes.py +++ b/packages/pynumaflow/tests/accumulator/test_datatypes.py @@ -1,4 +1,5 @@ -import unittest +import asyncio +import pytest from collections.abc import AsyncIterable from datetime import datetime, timezone @@ -28,312 +29,334 @@ TEST_HEADERS = {"key1": "value1", "key2": "value2"} -class TestDatum(unittest.TestCase): - def test_err_event_time(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - headers = {"key1": "value1", "key2": "value2"} - with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=ts, - watermark=mock_watermark(), - id_=TEST_ID, - headers=headers, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.event_time", - str(context.exception), - ) +# --- TestDatum --- - def test_err_watermark(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - headers = {"key1": "value1", "key2": "value2"} - with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=ts, - id_=TEST_ID, - headers=headers, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.watermark", - str(context.exception), - ) - def test_properties(self): - d = Datum( +def test_datum_err_event_time(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + headers = {"key1": "value1", "key2": "value2"} + with pytest.raises(Exception) as exc_info: + Datum( keys=TEST_KEYS, value=mock_message(), - event_time=mock_event_time(), + event_time=ts, watermark=mock_watermark(), id_=TEST_ID, - headers=TEST_HEADERS, + headers=headers, ) - self.assertEqual(mock_message(), d.value) - self.assertEqual(TEST_KEYS, d.keys) - self.assertEqual(mock_event_time(), d.event_time) - self.assertEqual(mock_watermark(), d.watermark) - 
self.assertEqual(TEST_HEADERS, d.headers) - self.assertEqual(TEST_ID, d.id) - - def test_default_values(self): - d = Datum( - keys=None, - value=None, + assert ( + "Wrong data type: " "for Datum.event_time" + ) == str(exc_info.value) + + +def test_datum_err_watermark(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + headers = {"key1": "value1", "key2": "value2"} + with pytest.raises(Exception) as exc_info: + Datum( + keys=TEST_KEYS, + value=mock_message(), event_time=mock_event_time(), - watermark=mock_watermark(), + watermark=ts, id_=TEST_ID, + headers=headers, ) - self.assertEqual([], d.keys) - self.assertEqual(b"", d.value) - self.assertEqual({}, d.headers) + assert ( + "Wrong data type: " "for Datum.watermark" + ) == str(exc_info.value) -class TestIntervalWindow(unittest.TestCase): - def test_start(self): - i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(mock_start_time(), i.start) +def test_datum_properties(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + headers=TEST_HEADERS, + ) + assert mock_message() == d.value + assert TEST_KEYS == d.keys + assert mock_event_time() == d.event_time + assert mock_watermark() == d.watermark + assert TEST_HEADERS == d.headers + assert TEST_ID == d.id - def test_end(self): - i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(mock_end_time(), i.end) +def test_datum_default_values(): + d = Datum( + keys=None, + value=None, + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + ) + assert [] == d.keys + assert b"" == d.value + assert {} == d.headers -class TestKeyedWindow(unittest.TestCase): - def test_create_window(self): - kw = KeyedWindow( - start=mock_start_time(), end=mock_end_time(), slot="slot-0", keys=["key1", "key2"] - ) - self.assertEqual(kw.start, mock_start_time()) - self.assertEqual(kw.end, mock_end_time()) - 
self.assertEqual(kw.slot, "slot-0") - self.assertEqual(kw.keys, ["key1", "key2"]) - - def test_default_values(self): - kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(kw.slot, "") - self.assertEqual(kw.keys, []) - - def test_window_property(self): - kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) - self.assertIsInstance(kw.window, IntervalWindow) - self.assertEqual(kw.window.start, mock_start_time()) - self.assertEqual(kw.window.end, mock_end_time()) - - -class TestAccumulatorResult(unittest.TestCase): - def test_create_result(self): - # Create mock objects - future = None # In real usage, this would be an asyncio.Task - iterator = NonBlockingIterator() - keys = ["key1", "key2"] - result_queue = NonBlockingIterator() - consumer_future = None # In real usage, this would be an asyncio.Task - watermark = datetime.fromtimestamp(1662998400, timezone.utc) - - result = AccumulatorResult(future, iterator, keys, result_queue, consumer_future, watermark) - - self.assertEqual(result.future, future) - self.assertEqual(result.iterator, iterator) - self.assertEqual(result.keys, keys) - self.assertEqual(result.result_queue, result_queue) - self.assertEqual(result.consumer_future, consumer_future) - self.assertEqual(result.latest_watermark, watermark) - - def test_update_watermark(self): - result = AccumulatorResult( - None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) - ) - new_watermark = datetime.fromtimestamp(1662998460, timezone.utc) - result.update_watermark(new_watermark) - self.assertEqual(result.latest_watermark, new_watermark) - def test_update_watermark_invalid_type(self): - result = AccumulatorResult( - None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) - ) - with self.assertRaises(TypeError): - result.update_watermark("not a datetime") +# --- TestIntervalWindow --- -class TestAccumulatorRequest(unittest.TestCase): - def test_create_request(self): - operation = 
WindowOperation.OPEN - keyed_window = KeyedWindow(start=mock_start_time(), end=mock_end_time()) - payload = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - id_=TEST_ID, - ) +def test_interval_window_start(): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + assert mock_start_time() == i.start - request = AccumulatorRequest(operation, keyed_window, payload) - self.assertEqual(request.operation, operation) - self.assertEqual(request.keyed_window, keyed_window) - self.assertEqual(request.payload, payload) - - -class TestWindowOperation(unittest.TestCase): - def test_enum_values(self): - self.assertEqual(WindowOperation.OPEN, 0) - self.assertEqual(WindowOperation.CLOSE, 1) - self.assertEqual(WindowOperation.APPEND, 2) - - -class TestMessage(unittest.TestCase): - def test_create_message(self): - value = b"test_value" - keys = ["key1", "key2"] - tags = ["tag1", "tag2"] - - msg = Message(value=value, keys=keys, tags=tags) - self.assertEqual(msg.value, value) - self.assertEqual(msg.keys, keys) - self.assertEqual(msg.tags, tags) - - def test_default_values(self): - msg = Message(value=b"test") - self.assertEqual(msg.keys, []) - self.assertEqual(msg.tags, []) - - def test_to_drop(self): - msg = Message.to_drop() - self.assertEqual(msg.value, b"") - self.assertEqual(msg.keys, []) - self.assertTrue("U+005C__DROP__" in msg.tags) - - def test_none_values(self): - msg = Message(value=None, keys=None, tags=None) - self.assertEqual(msg.value, b"") - self.assertEqual(msg.keys, []) - self.assertEqual(msg.tags, []) - - def test_from_datum(self): - """Test that Message.from_datum correctly creates a Message from a Datum""" - # Create a sample datum with all properties - test_keys = ["key1", "key2"] - test_value = b"test_message_value" - test_event_time = mock_event_time() - test_watermark = mock_watermark() - test_headers = {"header1": "value1", "header2": "value2"} - test_id = "test_datum_id" - - datum 
= Datum( - keys=test_keys, - value=test_value, - event_time=test_event_time, - watermark=test_watermark, - id_=test_id, - headers=test_headers, - ) - # Create message from datum - message = Message.from_datum(datum) - - # Verify all properties are correctly transferred - self.assertEqual(message.value, test_value) - self.assertEqual(message.keys, test_keys) - self.assertEqual(message.event_time, test_event_time) - self.assertEqual(message.watermark, test_watermark) - self.assertEqual(message.headers, test_headers) - self.assertEqual(message.id, test_id) - - # Verify that tags are empty (default for Message) - self.assertEqual(message.tags, []) - - def test_from_datum_minimal(self): - """Test from_datum with minimal Datum (no headers)""" - test_keys = ["minimal_key"] - test_value = b"minimal_value" - test_event_time = mock_event_time() - test_watermark = mock_watermark() - test_id = "minimal_id" - - datum = Datum( - keys=test_keys, - value=test_value, - event_time=test_event_time, - watermark=test_watermark, - id_=test_id, - # headers not provided (will default to {}) - ) +def test_interval_window_end(): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + assert mock_end_time() == i.end + + +# --- TestKeyedWindow --- + + +def test_keyed_window_create(): + kw = KeyedWindow( + start=mock_start_time(), end=mock_end_time(), slot="slot-0", keys=["key1", "key2"] + ) + assert kw.start == mock_start_time() + assert kw.end == mock_end_time() + assert kw.slot == "slot-0" + assert kw.keys == ["key1", "key2"] + + +def test_keyed_window_default_values(): + kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + assert kw.slot == "" + assert kw.keys == [] + + +def test_keyed_window_window_property(): + kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + assert isinstance(kw.window, IntervalWindow) + assert kw.window.start == mock_start_time() + assert kw.window.end == mock_end_time() + + +# --- TestAccumulatorResult --- + + +def 
test_accumulator_result_create(): + # Create mock objects + future = None # In real usage, this would be an asyncio.Task + iterator = NonBlockingIterator() + keys = ["key1", "key2"] + result_queue = NonBlockingIterator() + consumer_future = None # In real usage, this would be an asyncio.Task + watermark = datetime.fromtimestamp(1662998400, timezone.utc) + + result = AccumulatorResult(future, iterator, keys, result_queue, consumer_future, watermark) + + assert result.future == future + assert result.iterator == iterator + assert result.keys == keys + assert result.result_queue == result_queue + assert result.consumer_future == consumer_future + assert result.latest_watermark == watermark + + +def test_accumulator_result_update_watermark(): + result = AccumulatorResult( + None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) + ) + new_watermark = datetime.fromtimestamp(1662998460, timezone.utc) + result.update_watermark(new_watermark) + assert result.latest_watermark == new_watermark + + +def test_accumulator_result_update_watermark_invalid_type(): + result = AccumulatorResult( + None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) + ) + with pytest.raises(TypeError): + result.update_watermark("not a datetime") + + +# --- TestAccumulatorRequest --- + + +def test_accumulator_request_create(): + operation = WindowOperation.OPEN + keyed_window = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + payload = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + ) + + request = AccumulatorRequest(operation, keyed_window, payload) + assert request.operation == operation + assert request.keyed_window == keyed_window + assert request.payload == payload + + +# --- TestWindowOperation --- + + +def test_window_operation_enum_values(): + assert WindowOperation.OPEN == 0 + assert WindowOperation.CLOSE == 1 + assert WindowOperation.APPEND == 2 + + +# --- 
TestMessage --- + + +def test_message_create(): + value = b"test_value" + keys = ["key1", "key2"] + tags = ["tag1", "tag2"] + + msg = Message(value=value, keys=keys, tags=tags) + assert msg.value == value + assert msg.keys == keys + assert msg.tags == tags + + +def test_message_default_values(): + msg = Message(value=b"test") + assert msg.keys == [] + assert msg.tags == [] + + +def test_message_to_drop(): + msg = Message.to_drop() + assert msg.value == b"" + assert msg.keys == [] + assert "U+005C__DROP__" in msg.tags + + +def test_message_none_values(): + msg = Message(value=None, keys=None, tags=None) + assert msg.value == b"" + assert msg.keys == [] + assert msg.tags == [] + + +def test_message_from_datum(): + """Test that Message.from_datum correctly creates a Message from a Datum""" + # Create a sample datum with all properties + test_keys = ["key1", "key2"] + test_value = b"test_message_value" + test_event_time = mock_event_time() + test_watermark = mock_watermark() + test_headers = {"header1": "value1", "header2": "value2"} + test_id = "test_datum_id" + + datum = Datum( + keys=test_keys, + value=test_value, + event_time=test_event_time, + watermark=test_watermark, + id_=test_id, + headers=test_headers, + ) + + # Create message from datum + message = Message.from_datum(datum) + + # Verify all properties are correctly transferred + assert message.value == test_value + assert message.keys == test_keys + assert message.event_time == test_event_time + assert message.watermark == test_watermark + assert message.headers == test_headers + assert message.id == test_id + + # Verify that tags are empty (default for Message) + assert message.tags == [] + + +def test_message_from_datum_minimal(): + """Test from_datum with minimal Datum (no headers)""" + test_keys = ["minimal_key"] + test_value = b"minimal_value" + test_event_time = mock_event_time() + test_watermark = mock_watermark() + test_id = "minimal_id" + + datum = Datum( + keys=test_keys, + value=test_value, + 
event_time=test_event_time, + watermark=test_watermark, + id_=test_id, + # headers not provided (will default to {}) + ) + + message = Message.from_datum(datum) - message = Message.from_datum(datum) - - self.assertEqual(message.value, test_value) - self.assertEqual(message.keys, test_keys) - self.assertEqual(message.event_time, test_event_time) - self.assertEqual(message.watermark, test_watermark) - self.assertEqual(message.headers, {}) - self.assertEqual(message.id, test_id) - self.assertEqual(message.tags, []) - - def test_from_datum_empty_keys(self): - """Test from_datum with empty keys""" - datum = Datum( - keys=None, # Will default to [] - value=b"test_value", + assert message.value == test_value + assert message.keys == test_keys + assert message.event_time == test_event_time + assert message.watermark == test_watermark + assert message.headers == {} + assert message.id == test_id + assert message.tags == [] + + +def test_message_from_datum_empty_keys(): + """Test from_datum with empty keys""" + datum = Datum( + keys=None, # Will default to [] + value=b"test_value", + event_time=mock_event_time(), + watermark=mock_watermark(), + id_="test_id", + ) + + message = Message.from_datum(datum) + + assert message.keys == [] + assert message.value == b"test_value" + assert message.id == "test_id" + + +# --- TestAccumulatorClass --- + + +class _ExampleAccumulator(Accumulator): + async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator): + pass + + def __init__(self, test1, test2): + self.test1 = test1 + self.test2 = test2 + self.test3 = self.test1 + + +def test_accumulator_class_init(): + r = _ExampleAccumulator(test1=1, test2=2) + assert 1 == r.test1 + assert 2 == r.test2 + assert 1 == r.test3 + + +def test_accumulator_class_callable(): + """Test that accumulator instances can be called directly""" + r = _ExampleAccumulator(test1=1, test2=2) + # The __call__ method should be callable and delegate to the handler method + assert callable(r) + # 
__call__ should return the result of calling handler + # Since handler is an async method, __call__ should return a coroutine + + async def test_datums(): + yield Datum( + keys=["test"], + value=b"test", event_time=mock_event_time(), watermark=mock_watermark(), - id_="test_id", + id_="test", ) - message = Message.from_datum(datum) - - self.assertEqual(message.keys, []) - self.assertEqual(message.value, b"test_value") - self.assertEqual(message.id, "test_id") - - -class TestAccumulatorClass(unittest.TestCase): - class ExampleClass(Accumulator): - async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator): - pass - - def __init__(self, test1, test2): - self.test1 = test1 - self.test2 = test2 - self.test3 = self.test1 - - def test_init(self): - r = self.ExampleClass(test1=1, test2=2) - self.assertEqual(1, r.test1) - self.assertEqual(2, r.test2) - self.assertEqual(1, r.test3) - - def test_callable(self): - """Test that accumulator instances can be called directly""" - r = self.ExampleClass(test1=1, test2=2) - # The __call__ method should be callable and delegate to the handler method - self.assertTrue(callable(r)) - # __call__ should return the result of calling handler - # Since handler is an async method, __call__ should return a coroutine - import asyncio - from pynumaflow.shared.asynciter import NonBlockingIterator - - async def test_datums(): - yield Datum( - keys=["test"], - value=b"test", - event_time=mock_event_time(), - watermark=mock_watermark(), - id_="test", - ) - - output = NonBlockingIterator() - result = r(test_datums(), output) - self.assertTrue(asyncio.iscoroutine(result)) - # Clean up the coroutine - result.close() - - -if __name__ == "__main__": - unittest.main() + output = NonBlockingIterator() + result = r(test_datums(), output) + assert asyncio.iscoroutine(result) + # Clean up the coroutine + result.close() diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py 
b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py index 7ea32a07..cdfc2003 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py @@ -1,12 +1,11 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio._server import Server from pynumaflow import setup_logging from pynumaflow.batchmapper import ( @@ -24,10 +23,6 @@ listen_addr = "unix:///tmp/batch_map.sock" -_s: Server = None -_channel = grpc.insecure_channel(listen_addr) -_loop = None - def startup_callable(loop): asyncio.set_event_loop(loop) @@ -88,106 +83,101 @@ async def start_server(udfs): map_pb2_grpc.add_MapServicer_to_server(udfs, server) server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) - global _s - _s = server await server.start() await server.wait_for_termination() -class TestAsyncBatchMapper(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncBatchMapper() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel(listen_addr) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: +@pytest.fixture(scope="module") +def async_batch_map_server(): + """Module-scoped fixture: starts an async gRPC batch map server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = 
NewAsyncBatchMapper() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(listen_addr) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_batch_map(self) -> None: - stub = self.__stub() - generator_response = None + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def batch_map_stub(async_batch_map_server): + """Returns a MapStub connected to the running async batch map server.""" + return map_pb2_grpc.MapStub(grpc.insecure_channel(listen_addr)) + + +def test_batch_map(batch_map_stub) -> None: + generator_response = None + try: + generator_response = batch_map_stub.MapFn( + request_iterator=request_generator(count=10, session=1) + ) + except grpc.RpcError as e: + logging.error(e) + + handshake = next(generator_response) + # assert that handshake response is received. 
+ assert handshake.handshake.sot + data_resp = [] + for r in generator_response: + data_resp.append(r) + + idx = 0 + while idx < len(data_resp) - 1: + assert ( + bytes( + "test_mock_message", + encoding="utf-8", + ) + == data_resp[idx].results[0].value + ) + _id = data_resp[idx].id + assert _id == "test-id-" + str(idx) + idx += 1 + # EOT Response + assert data_resp[len(data_resp) - 1].status.eot is True + # 10 sink responses + 1 EOT response + assert 11 == len(data_resp) + + +def test_is_ready(async_batch_map_server) -> None: + with grpc.insecure_channel(listen_addr) as channel: + stub = map_pb2_grpc.MapStub(channel) + + request = _empty_pb2.Empty() + response = None try: - generator_response = stub.MapFn(request_iterator=request_generator(count=10, session=1)) + response = stub.IsReady(request=request) except grpc.RpcError as e: logging.error(e) - handshake = next(generator_response) - # assert that handshake response is received. - self.assertTrue(handshake.handshake.sot) - data_resp = [] - for r in generator_response: - data_resp.append(r) - - idx = 0 - while idx < len(data_resp) - 1: - self.assertEqual( - bytes( - "test_mock_message", - encoding="utf-8", - ), - data_resp[idx].results[0].value, - ) - _id = data_resp[idx].id - self.assertEqual(_id, "test-id-" + str(idx)) - # capture the output from the SinkFn generator and assert. 
- # self.assertEqual(data_resp[idx].result.status, sink_pb2.Status.SUCCESS) - idx += 1 - # EOT Response - self.assertEqual(data_resp[len(data_resp) - 1].status.eot, True) - # 10 sink responses + 1 EOT response - self.assertEqual(11, len(data_resp)) - - def test_is_ready(self) -> None: - with grpc.insecure_channel(listen_addr) as channel: - stub = map_pb2_grpc.MapStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def test_max_threads(self): - # max cap at 16 - server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = BatchMapAsyncServer(batch_mapper_instance=handler) - self.assertEqual(server.max_threads, 4) - - def __stub(self): - return map_pb2_grpc.MapStub(_channel) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + assert response.ready + + +def test_max_threads(): + # max cap at 16 + server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = BatchMapAsyncServer(batch_mapper_instance=handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py index 5f8162f9..8c6795d0 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py @@ -1,19 +1,15 @@ import asyncio import logging import threading -import unittest -from 
unittest.mock import patch import grpc - -from grpc.aio._server import Server +import pytest from pynumaflow import setup_logging from pynumaflow.batchmapper import BatchResponses from pynumaflow.batchmapper import BatchMapAsyncServer from pynumaflow.proto.mapper import map_pb2_grpc from tests.batchmap.utils import request_generator -from tests.testing_utils import mock_terminate_on_stop LOGGER = setup_logging(__name__) @@ -30,18 +26,12 @@ async def err_handler(datums) -> BatchResponses: listen_addr = "unix:///tmp/async_batch_map_err.sock" -_s: Server = None -_channel = grpc.insecure_channel(listen_addr) -_loop = None - def startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) async def start_server(): server = grpc.aio.server() server_instance = BatchMapAsyncServer(err_handler) @@ -49,82 +39,74 @@ async def start_server(): map_pb2_grpc.add_MapServicer_to_server(udfs, server) server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) - global _s - _s = server await server.start() await server.wait_for_termination() -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncServerErrorScenario(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - while True: - try: - with grpc.insecure_channel(listen_addr) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - 
@classmethod - def tearDownClass(cls) -> None: +@pytest.fixture(scope="module") +def async_batch_map_err_server(): + """Module-scoped fixture: starts an async gRPC batch map error server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + asyncio.run_coroutine_threadsafe(start_server(), loop=loop) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(listen_addr) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_batch_map_error(self) -> None: - global raise_error - raise_error = True - stub = self.__stub() - try: - generator_response = stub.MapFn( - request_iterator=request_generator(count=10, handshake=True, session=1) - ) - counter = 0 - for _ in generator_response: - counter += 1 - except Exception as err: - self.assertTrue("Got a runtime error from batch map handler." 
in err.__str__()) - return - self.fail("Expected an exception.") - - def test_batch_map_error_no_handshake(self) -> None: - global raise_error - raise_error = True - stub = self.__stub() - try: - generator_response = stub.MapFn( - request_iterator=request_generator(count=10, handshake=False, session=1) - ) - counter = 0 - for _ in generator_response: - counter += 1 - except Exception as err: - self.assertTrue("BatchMapFn: expected handshake as the first message" in err.__str__()) - return - self.fail("Expected an exception.") - - def __stub(self): - return map_pb2_grpc.MapStub(_channel) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - BatchMapAsyncServer() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def batch_map_err_stub(async_batch_map_err_server): + """Returns a MapStub connected to the running async batch map error server.""" + return map_pb2_grpc.MapStub(grpc.insecure_channel(listen_addr)) + + +def test_batch_map_error(batch_map_err_stub) -> None: + global raise_error + raise_error = True + try: + generator_response = batch_map_err_stub.MapFn( + request_iterator=request_generator(count=10, handshake=True, session=1) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + assert "Got a runtime error from batch map handler." 
in str(err) + return + pytest.fail("Expected an exception.") + + +def test_batch_map_error_no_handshake(batch_map_err_stub) -> None: + global raise_error + raise_error = True + try: + generator_response = batch_map_err_stub.MapFn( + request_iterator=request_generator(count=10, handshake=False, session=1) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + assert "BatchMapFn: expected handshake as the first message" in str(err) + return + pytest.fail("Expected an exception.") + + +def test_invalid_input(): + with pytest.raises(TypeError): + BatchMapAsyncServer() diff --git a/packages/pynumaflow/tests/batchmap/test_datatypes.py b/packages/pynumaflow/tests/batchmap/test_datatypes.py index 06fb3624..39a0ac55 100644 --- a/packages/pynumaflow/tests/batchmap/test_datatypes.py +++ b/packages/pynumaflow/tests/batchmap/test_datatypes.py @@ -1,4 +1,4 @@ -import unittest +import pytest from google.protobuf import timestamp_pb2 as _timestamp_pb2 @@ -16,98 +16,95 @@ TEST_HEADERS = {"key1": "value1", "key2": "value2"} -class TestDatum(unittest.TestCase): - def test_err_event_time(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=ts, - watermark=ts, - headers=TEST_HEADERS, - id=TEST_ID, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.event_time", - str(context.exception), - ) - - def test_err_watermark(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=ts, - headers=TEST_HEADERS, - id=TEST_ID, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.watermark", - str(context.exception), - ) - - def test_value(self): - d = Datum( +def test_datum_err_event_time(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as 
exc_info: + Datum( keys=TEST_KEYS, value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), + event_time=ts, + watermark=ts, headers=TEST_HEADERS, id=TEST_ID, ) - self.assertEqual(mock_message(), d.value) + assert ( + "Wrong data type: " "for Datum.event_time" + ) == str(exc_info.value) - def test_key(self): - d = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - id=TEST_ID, - ) - self.assertEqual(TEST_KEYS, d.keys) - def test_event_time(self): - d = Datum( +def test_datum_err_watermark(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as exc_info: + Datum( keys=TEST_KEYS, value=mock_message(), event_time=mock_event_time(), - watermark=mock_watermark(), + watermark=ts, headers=TEST_HEADERS, id=TEST_ID, ) - self.assertEqual(mock_event_time(), d.event_time) - self.assertEqual(TEST_HEADERS, d.headers) + assert ( + "Wrong data type: " "for Datum.watermark" + ) == str(exc_info.value) - def test_watermark(self): - d = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - id=TEST_ID, - ) - self.assertEqual(mock_watermark(), d.watermark) - self.assertEqual({}, d.headers) - def test_id(self): - d = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - id=TEST_ID, - ) - self.assertEqual(TEST_ID, d.id) - self.assertEqual({}, d.headers) +def test_datum_value(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + id=TEST_ID, + ) + assert mock_message() == d.value + + +def test_datum_key(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + assert TEST_KEYS == d.keys + + +def test_datum_event_time(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + 
event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + id=TEST_ID, + ) + assert mock_event_time() == d.event_time + assert TEST_HEADERS == d.headers + + +def test_datum_watermark(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + assert mock_watermark() == d.watermark + assert {} == d.headers -if __name__ == "__main__": - unittest.main() +def test_datum_id(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id=TEST_ID, + ) + assert TEST_ID == d.id + assert {} == d.headers diff --git a/packages/pynumaflow/tests/batchmap/test_messages.py b/packages/pynumaflow/tests/batchmap/test_messages.py index 351db3b5..2a147d0c 100644 --- a/packages/pynumaflow/tests/batchmap/test_messages.py +++ b/packages/pynumaflow/tests/batchmap/test_messages.py @@ -1,99 +1,94 @@ -import unittest +import pytest from pynumaflow.batchmapper import Message, DROP, BatchResponse, BatchResponses from tests.batchmap.test_datatypes import TEST_ID from tests.testing_utils import mock_message -class TestBatchResponses(unittest.TestCase): - @staticmethod - def mock_message_object(): - value = mock_message() - return Message(value=value) - - def test_init(self): - batch_responses = BatchResponses() - batch_response1 = BatchResponse.from_id(TEST_ID) - batch_response2 = BatchResponse.from_id(TEST_ID + "2") - batch_responses.append(batch_response1) - batch_responses.append(batch_response2) - self.assertEqual(2, len(batch_responses)) - # test indexing - self.assertEqual(batch_responses[0].id, TEST_ID) - self.assertEqual(batch_responses[1].id, TEST_ID + "2") - # test slicing - resp = batch_responses[0:1] - self.assertEqual(resp[0].id, TEST_ID) - - -class TestBatchResponse(unittest.TestCase): - @staticmethod - def mock_message_object(): - value = mock_message() - return Message(value=value) - - def test_init(self): - 
batch_response = BatchResponse.from_id(TEST_ID) - self.assertEqual(batch_response.id, TEST_ID) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - BatchResponse() - - def test_append(self): - batch_response = BatchResponse.from_id(TEST_ID) - self.assertEqual(0, len(batch_response.items())) - batch_response.append(self.mock_message_object()) - self.assertEqual(1, len(batch_response.items())) - batch_response.append(self.mock_message_object()) - self.assertEqual(2, len(batch_response.items())) - - def test_items(self): - mock_obj = [ - mock_message(), - mock_message(), - ] - msgs = BatchResponse.with_msgs(TEST_ID, mock_obj) - self.assertEqual(len(mock_obj), len(msgs.items())) - self.assertEqual(mock_obj[0], msgs.items()[0]) - - -class TestMessage(unittest.TestCase): - def test_key(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - print(msg) - self.assertEqual(mock_obj["Keys"], msg.keys) - - def test_value(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - self.assertEqual(mock_obj["Value"], msg.value) - - def test_message_to_all(self): - mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} - msg = Message(mock_obj["Value"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to_drop(self): - mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} - msg = Message(b"").to_drop() - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to(self): - mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], 
tags=mock_obj["Tags"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - -if __name__ == "__main__": - unittest.main() +def _mock_message_object(): + value = mock_message() + return Message(value=value) + + +def test_batch_responses_init(): + batch_responses = BatchResponses() + batch_response1 = BatchResponse.from_id(TEST_ID) + batch_response2 = BatchResponse.from_id(TEST_ID + "2") + batch_responses.append(batch_response1) + batch_responses.append(batch_response2) + assert 2 == len(batch_responses) + # test indexing + assert batch_responses[0].id == TEST_ID + assert batch_responses[1].id == TEST_ID + "2" + # test slicing + resp = batch_responses[0:1] + assert resp[0].id == TEST_ID + + +def test_batch_response_init(): + batch_response = BatchResponse.from_id(TEST_ID) + assert batch_response.id == TEST_ID + + +def test_batch_response_invalid_input(): + with pytest.raises(TypeError): + BatchResponse() + + +def test_batch_response_append(): + batch_response = BatchResponse.from_id(TEST_ID) + assert 0 == len(batch_response.items()) + batch_response.append(_mock_message_object()) + assert 1 == len(batch_response.items()) + batch_response.append(_mock_message_object()) + assert 2 == len(batch_response.items()) + + +def test_batch_response_items(): + mock_obj = [ + mock_message(), + mock_message(), + ] + msgs = BatchResponse.with_msgs(TEST_ID, mock_obj) + assert len(mock_obj) == len(msgs.items()) + assert mock_obj[0] == msgs.items()[0] + + +def test_message_key(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + print(msg) + assert mock_obj["Keys"] == msg.keys + + +def test_message_value(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + assert mock_obj["Value"] == msg.value + + +def 
test_message_to_all(): + mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} + msg = Message(mock_obj["Value"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to_drop(): + mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} + msg = Message(b"").to_drop() + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to(): + mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags diff --git a/packages/pynumaflow/tests/errors/test_dtypes.py b/packages/pynumaflow/tests/errors/test_dtypes.py index 90542f7c..902fcb6c 100644 --- a/packages/pynumaflow/tests/errors/test_dtypes.py +++ b/packages/pynumaflow/tests/errors/test_dtypes.py @@ -1,72 +1,68 @@ -import unittest from pynumaflow.errors._dtypes import _RuntimeErrorEntry -class TestRuntimeErrorEntry(unittest.TestCase): - def test_runtime_error_entry_initialization(self): - """ - Test that _RuntimeErrorEntry initializes correctly with given values. - """ - container = "test-container" - timestamp = 1680700000 - code = "500" - message = "Test error message" - details = "Test error details" +def test_runtime_error_entry_initialization(): + """ + Test that _RuntimeErrorEntry initializes correctly with given values. 
+ """ + container = "test-container" + timestamp = 1680700000 + code = "500" + message = "Test error message" + details = "Test error details" - error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) + error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) - self.assertEqual(error_entry.container, container) - self.assertEqual(error_entry.timestamp, timestamp) - self.assertEqual(error_entry.code, code) - self.assertEqual(error_entry.message, message) - self.assertEqual(error_entry.details, details) + assert error_entry.container == container + assert error_entry.timestamp == timestamp + assert error_entry.code == code + assert error_entry.message == message + assert error_entry.details == details - def test_runtime_error_entry_to_dict(self): - """ - Test that _RuntimeErrorEntry converts to a dictionary correctly. - """ - container = "test-container" - timestamp = 1680700000 - code = "500" - message = "Test error message" - details = "Test error details" - error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) - error_dict = error_entry.to_dict() +def test_runtime_error_entry_to_dict(): + """ + Test that _RuntimeErrorEntry converts to a dictionary correctly. + """ + container = "test-container" + timestamp = 1680700000 + code = "500" + message = "Test error message" + details = "Test error details" - expected_dict = { - "container": container, - "timestamp": timestamp, - "code": code, - "message": message, - "details": details, - } + error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) + error_dict = error_entry.to_dict() - self.assertEqual(error_dict, expected_dict) + expected_dict = { + "container": container, + "timestamp": timestamp, + "code": code, + "message": message, + "details": details, + } - def test_runtime_error_entry_empty_values(self): - """ - Test that _RuntimeErrorEntry handles empty values correctly. 
- """ - container = "" - timestamp = 0 - code = "" - message = "" - details = "" + assert error_dict == expected_dict - error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) - error_dict = error_entry.to_dict() - expected_dict = { - "container": container, - "timestamp": timestamp, - "code": code, - "message": message, - "details": details, - } +def test_runtime_error_entry_empty_values(): + """ + Test that _RuntimeErrorEntry handles empty values correctly. + """ + container = "" + timestamp = 0 + code = "" + message = "" + details = "" - self.assertEqual(error_dict, expected_dict) + error_entry = _RuntimeErrorEntry(container, timestamp, code, message, details) + error_dict = error_entry.to_dict() + expected_dict = { + "container": container, + "timestamp": timestamp, + "code": code, + "message": message, + "details": details, + } -if __name__ == "__main__": - unittest.main() + assert error_dict == expected_dict diff --git a/packages/pynumaflow/tests/errors/test_persist_critical_error.py b/packages/pynumaflow/tests/errors/test_persist_critical_error.py index f0787a85..0eb81c78 100644 --- a/packages/pynumaflow/tests/errors/test_persist_critical_error.py +++ b/packages/pynumaflow/tests/errors/test_persist_critical_error.py @@ -2,133 +2,125 @@ import json import shutil import threading -import unittest + +import pytest + from pynumaflow.errors.errors import persist_critical_error, _persist_error_once from pynumaflow.errors.errors import _persist_critical_error_to_file from pynumaflow._constants import CONTAINER_TYPE, INTERNAL_ERROR_CODE -class TestErrorPersistence(unittest.TestCase): - def setUp(self): - """ - Set up temporary directories for tests. - """ - self.test_dirs = ["/tmp/test_error_dir", "/tmp/test_dir"] - - def tearDown(self): - """ - Clean up temporary directories after tests. 
- """ - for dir_path in self.test_dirs: - if os.path.exists(dir_path): - shutil.rmtree(dir_path) - - # Writes error details to a JSON file - def test_writes_error_details_to_json_file(self): - """ - Test that _persist_critical_error_to_file writes error details to a JSON file. - """ - - dir_path = self.test_dirs[0] - - error_code = "500" - error_message = "Server Error" - error_details = "An unexpected error occurred." - - _persist_critical_error_to_file(error_code, error_message, error_details, dir_path) - - container_dir = os.path.join(dir_path, CONTAINER_TYPE) - self.assertTrue(os.path.exists(container_dir)) - - # Debug: Check directory after the function call - print(f"After: {os.listdir(container_dir)}") - - files = os.listdir(container_dir) - self.assertEqual(len(files), 1) - - final_file_name = files[0] - final_file_path = os.path.join(container_dir, final_file_name) - - with open(final_file_path) as f: - data = json.load(f) - - self.assertEqual(data["code"], error_code) - self.assertEqual(data["message"], error_message) - self.assertEqual(data["details"], error_details) - self.assertEqual(data["container"], CONTAINER_TYPE) - self.assertTrue(isinstance(data["timestamp"], int)) - - # Uses default error code if none provided - def test_uses_default_error_code_if_none_provided(self): - """ - Test that _persist_critical_error_to_file uses the default error code if none is provided. 
- """ - dir_path = self.test_dirs[1] - - _persist_critical_error_to_file("", "Error Message", "Error Details", dir_path) - - container_dir = os.path.join(dir_path, "unknown-container") - self.assertTrue(os.path.exists(container_dir)) - - files = os.listdir(container_dir) - self.assertEqual(len(files), 1) - - with open(os.path.join(container_dir, files[0])) as f: - error_data = json.load(f) - self.assertEqual(error_data["code"], INTERNAL_ERROR_CODE) - - def test_persist_critical_error_all_threads_fail(self): - """ - Test that all threads fail when persist_critical_error is executed after the first call. - """ - error_code = "testCode" - error_message = "testMessage" - error_details = "testDetails" - - # Set `done` to True to simulate that the critical error has already been persisted - _persist_error_once.done = True - - try: - # Set up threading - num_threads = 10 - errors = [] - lock = threading.Lock() - - def thread_func(): - nonlocal errors - result = persist_critical_error(error_code, error_message, error_details) - with lock: - errors.append(result) - - # Create and start threads - threads = [] - for _ in range(num_threads): - thread = threading.Thread(target=thread_func) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Count the number of failures - fail_count = sum( - 1 - for error in errors - if error is not None - and "Persist critical error function has already been executed" in str(error) - ) - - # Assert that all threads failed - self.assertEqual( - fail_count, - num_threads, - f"Expected all {num_threads} threads to fail, but only {fail_count} failed", - ) - finally: - # Revert `done` back to False after the test - _persist_error_once.done = False - - -if __name__ == "__main__": - unittest.main() +@pytest.fixture() +def test_dirs(): + """ + Provide temporary directories for tests and clean them up afterwards. 
+ """ + dirs = ["/tmp/test_error_dir", "/tmp/test_dir"] + yield dirs + for dir_path in dirs: + if os.path.exists(dir_path): + shutil.rmtree(dir_path) + + +def test_writes_error_details_to_json_file(test_dirs): + """ + Test that _persist_critical_error_to_file writes error details to a JSON file. + """ + dir_path = test_dirs[0] + + error_code = "500" + error_message = "Server Error" + error_details = "An unexpected error occurred." + + _persist_critical_error_to_file(error_code, error_message, error_details, dir_path) + + container_dir = os.path.join(dir_path, CONTAINER_TYPE) + assert os.path.exists(container_dir) + + # Debug: Check directory after the function call + print(f"After: {os.listdir(container_dir)}") + + files = os.listdir(container_dir) + assert len(files) == 1 + + final_file_name = files[0] + final_file_path = os.path.join(container_dir, final_file_name) + + with open(final_file_path) as f: + data = json.load(f) + + assert data["code"] == error_code + assert data["message"] == error_message + assert data["details"] == error_details + assert data["container"] == CONTAINER_TYPE + assert isinstance(data["timestamp"], int) + + +def test_uses_default_error_code_if_none_provided(test_dirs): + """ + Test that _persist_critical_error_to_file uses the default error code if none is provided. + """ + dir_path = test_dirs[1] + + _persist_critical_error_to_file("", "Error Message", "Error Details", dir_path) + + container_dir = os.path.join(dir_path, "unknown-container") + assert os.path.exists(container_dir) + + files = os.listdir(container_dir) + assert len(files) == 1 + + with open(os.path.join(container_dir, files[0])) as f: + error_data = json.load(f) + assert error_data["code"] == INTERNAL_ERROR_CODE + + +def test_persist_critical_error_all_threads_fail(): + """ + Test that all threads fail when persist_critical_error is executed after the first call. 
+ """ + error_code = "testCode" + error_message = "testMessage" + error_details = "testDetails" + + # Set `done` to True to simulate that the critical error has already been persisted + _persist_error_once.done = True + + try: + # Set up threading + num_threads = 10 + errors = [] + lock = threading.Lock() + + def thread_func(): + nonlocal errors + result = persist_critical_error(error_code, error_message, error_details) + with lock: + errors.append(result) + + # Create and start threads + threads = [] + for _ in range(num_threads): + thread = threading.Thread(target=thread_func) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Count the number of failures + fail_count = sum( + 1 + for error in errors + if error is not None + and "Persist critical error function has already been executed" in str(error) + ) + + # Assert that all threads failed + assert ( + fail_count == num_threads + ), f"Expected all {num_threads} threads to fail, but only {fail_count} failed" + finally: + # Revert `done` back to False after the test + _persist_error_once.done = False diff --git a/packages/pynumaflow/tests/map/test_async_mapper.py b/packages/pynumaflow/tests/map/test_async_mapper.py index 59367079..42e1bce6 100644 --- a/packages/pynumaflow/tests/map/test_async_mapper.py +++ b/packages/pynumaflow/tests/map/test_async_mapper.py @@ -2,12 +2,10 @@ import logging import threading from collections.abc import Iterator -import unittest -from unittest.mock import patch import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio import Server from pynumaflow import setup_logging from pynumaflow._constants import MAX_MESSAGE_SIZE @@ -20,15 +18,14 @@ from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from tests.map.utils import get_test_datums -from tests.testing_utils import ( - mock_terminate_on_stop, -) LOGGER = 
setup_logging(__name__) # if set to true, map handler will raise a `ValueError` exception. raise_error_from_map = False +SOCK_PATH = "unix:///tmp/async_map.sock" + def request_generator(req): yield from req @@ -43,7 +40,6 @@ async def async_map_handler(keys: list[str], datum: Datum) -> Messages: datum.event_time, datum.watermark, ) - val = bytes(msg, encoding="utf-8") messages = Messages() if datum.system_metadata.value("numaflow_version_info", "version") != b"1.0.0": raise ValueError("System metadata version mismatch") @@ -51,222 +47,174 @@ async def async_map_handler(keys: list[str], datum: Datum) -> Messages: return messages -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/async_map.sock") -_loop = None - - -def startup_callable(loop): +def _startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -def new_async_mapper(): - server = MapAsyncServer(mapper_instance=async_map_handler) - udfs = server.servicer - return udfs - - -async def start_server(udfs): +async def _start_server(udfs): _server_options = [ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), ] server = grpc.aio.server(options=_server_options) map_pb2_grpc.add_MapServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/async_map.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncMapper(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = 
new_async_mapper() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_map_server(): + """Module-scoped fixture: starts an async gRPC map server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() + + server_obj = MapAsyncServer(mapper_instance=async_map_handler) + udfs = server_obj.servicer + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + _server = future.result(timeout=10) + + # Wait for the server to be ready + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_run_server(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: - stub = map_pb2_grpc.MapStub(channel) - request = get_test_datums() - generator_response = None - try: - generator_response = stub.MapFn(request_iterator=request_generator(request)) - except grpc.RpcError as e: - logging.error(e) - - responses = [] - # capture the output from the ReadFn generator and assert. 
- for r in generator_response: - responses.append(r) - - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - idx = 1 - while idx < len(responses): - _id = "test-id-" + str(idx) - self.assertEqual(_id, responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - LOGGER.info("Successfully validated the server") - - def test_map(self) -> None: - stub = map_pb2_grpc.MapStub(_channel) + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def map_stub(async_map_server): + """Returns a MapStub connected to the running async server.""" + return map_pb2_grpc.MapStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_run_server(async_map_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = map_pb2_grpc.MapStub(channel) request = get_test_datums() - try: - generator_response: Iterator[map_pb2.MapResponse] = stub.MapFn( - request_iterator=request_generator(request) + generator_response = stub.MapFn(request_iterator=request_generator(request)) + + responses = list(generator_response) + + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + idx = 1 + while idx < len(responses): + assert responses[idx].id == "test-id-" + str(idx) + assert responses[idx].results[0].value == bytes( + "payload:test_mock_message " + "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", ) - except grpc.RpcError as e: - logging.error(e) - raise + assert len(responses[idx].results) == 1 + idx += 1 + LOGGER.info("Successfully validated the server") - responses: list[map_pb2.MapResponse] = [] - # capture the output from the ReadFn generator and assert. 
+ +def test_map(map_stub): + request = get_test_datums() + generator_response: Iterator[map_pb2.MapResponse] = map_stub.MapFn( + request_iterator=request_generator(request) + ) + + responses: list[map_pb2.MapResponse] = list(generator_response) + + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + for idx, resp in enumerate(responses[1:], 1): + assert resp.id == "test-id-" + str(idx) + assert resp.results[0].value == bytes( + "payload:test_mock_message " + "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", + ) + assert len(resp.results) == 1 + assert resp.results[0].metadata.user_metadata["custom_info"] == metadata_pb2.KeyValueGroup( + key_value={"version": f"{idx}.0.0".encode()} + ) + # System metadata will be empty for user responses + assert resp.results[0].metadata.sys_metadata == {} + + +def test_map_grpc_error_no_handshake(map_stub): + request = get_test_datums(handshake=False) + grpc_exception = None + + responses = [] + try: + generator_response = map_stub.MapFn(request_iterator=request_generator(request)) for r in generator_response: responses.append(r) + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + assert "MapFn: expected handshake as the first message" in str(e) - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - for idx, resp in enumerate(responses[1:], 1): - _id = "test-id-" + str(idx) - self.assertEqual(_id, resp.id) - self.assertEqual( - bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ), - resp.results[0].value, - ) - self.assertEqual(1, len(resp.results)) - self.assertEqual( - resp.results[0].metadata.user_metadata["custom_info"], - metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), - ) - # System metadata will be empty for user responses - 
self.assertEqual(resp.results[0].metadata.sys_metadata, {}) + assert len(responses) == 0 + assert grpc_exception is not None - def test_map_grpc_error_no_handshake(self) -> None: - stub = map_pb2_grpc.MapStub(_channel) - request = get_test_datums(handshake=False) - grpc_exception = None - responses = [] - try: - generator_response = stub.MapFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - self.assertTrue("MapFn: expected handshake as the first message" in e.__str__()) - - self.assertEqual(0, len(responses)) - self.assertIsNotNone(grpc_exception) - - def test_map_grpc_error(self) -> None: - stub = map_pb2_grpc.MapStub(_channel) - request = get_test_datums() - grpc_exception = None +def test_map_grpc_error(map_stub): + request = get_test_datums() + grpc_exception = None - responses = [] - try: - global raise_error_from_map - raise_error_from_map = True - generator_response = stub.MapFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. 
- for r in generator_response: - responses.append(r) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) - self.assertTrue("Exception thrown from map" in e.__str__()) - finally: - raise_error_from_map = False - # 1 handshake - self.assertEqual(1, len(responses)) - self.assertIsNotNone(grpc_exception) - - def test_is_ready(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: - stub = map_pb2_grpc.MapStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - MapAsyncServer() - - def __stub(self): - return map_pb2_grpc.MapStub(_channel) - - def test_max_threads(self): - # max cap at 16 - server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = MapAsyncServer(mapper_instance=async_map_handler) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + responses = [] + try: + global raise_error_from_map + raise_error_from_map = True + generator_response = map_stub.MapFn(request_iterator=request_generator(request)) + for r in generator_response: + responses.append(r) + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + assert e.code() == grpc.StatusCode.INTERNAL + assert "Exception thrown from map" in str(e) + finally: + raise_error_from_map = False + # 1 handshake + assert len(responses) == 1 + assert grpc_exception is not None + + +def test_is_ready(async_map_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = 
map_pb2_grpc.MapStub(channel) + response = stub.IsReady(request=_empty_pb2.Empty()) + assert response.ready + + +def test_invalid_input(): + with pytest.raises(TypeError): + MapAsyncServer() + + +def test_max_threads(): + # max cap at 16 + server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = MapAsyncServer(mapper_instance=async_map_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/map/test_messages.py b/packages/pynumaflow/tests/map/test_messages.py index b2edbad7..ba66769b 100644 --- a/packages/pynumaflow/tests/map/test_messages.py +++ b/packages/pynumaflow/tests/map/test_messages.py @@ -1,93 +1,97 @@ -import unittest +import pytest from pynumaflow.mapper import Messages, Message, DROP, Mapper, Datum from tests.testing_utils import mock_message -class TestMessage(unittest.TestCase): - def test_key(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) +def test_message_key(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + print(msg) + assert mock_obj["Keys"] == msg.keys + + +def test_message_value(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + assert mock_obj["Value"] == msg.value + + +def test_message_to_all(): + mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} + msg = Message(mock_obj["Value"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to_drop(): + mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} + msg = Message(b"").to_drop() + assert 
type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to(): + mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def _mock_message_object(): + value = mock_message() + return Message(value=value) + + +def test_messages_items(): + mock_obj = [ + {"Keys": ["test_key"], "Value": mock_message()}, + {"Keys": ["test_key"], "Value": mock_message()}, + ] + msgs = Messages(*mock_obj) + assert len(mock_obj) == len(msgs) + assert len(mock_obj) == len(msgs.items()) + assert mock_obj[0]["Keys"] == msgs[0]["Keys"] + assert mock_obj[0]["Value"] == msgs[0]["Value"] + assert ( + "[{'Keys': ['test_key'], 'Value': b'test_mock_message'}, " + "{'Keys': ['test_key'], 'Value': b'test_mock_message'}]" + ) == repr(msgs) + + +def test_messages_append(): + msgs = Messages() + assert 0 == len(msgs) + msgs.append(_mock_message_object()) + assert 1 == len(msgs) + msgs.append(_mock_message_object()) + assert 2 == len(msgs) + + +def test_messages_forward_to_drop(): + mock_obj = Messages() + mock_obj.append(Message(b"").to_drop()) + true_obj = Messages() + true_obj.append(mock_obj[0]) + assert type(mock_obj) is type(true_obj) + for i in range(len(true_obj)): + assert type(mock_obj[i]) is type(true_obj[i]) + assert mock_obj[i].keys == true_obj[i].keys + assert mock_obj[i].value == true_obj[i].value + for msg in true_obj: print(msg) - self.assertEqual(mock_obj["Keys"], msg.keys) - - def test_value(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - self.assertEqual(mock_obj["Value"], msg.value) - - def test_message_to_all(self): - mock_obj = {"Keys": [], 
"Value": mock_message(), "Tags": []} - msg = Message(mock_obj["Value"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to_drop(self): - mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} - msg = Message(b"").to_drop() - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to(self): - mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - -class TestMessages(unittest.TestCase): - @staticmethod - def mock_message_object(): - value = mock_message() - return Message(value=value) - - def test_items(self): - mock_obj = [ - {"Keys": ["test_key"], "Value": mock_message()}, - {"Keys": ["test_key"], "Value": mock_message()}, - ] - msgs = Messages(*mock_obj) - self.assertEqual(len(mock_obj), len(msgs)) - self.assertEqual(len(mock_obj), len(msgs.items())) - self.assertEqual(mock_obj[0]["Keys"], msgs[0]["Keys"]) - self.assertEqual(mock_obj[0]["Value"], msgs[0]["Value"]) - self.assertEqual( - "[{'Keys': ['test_key'], 'Value': b'test_mock_message'}, " - "{'Keys': ['test_key'], 'Value': b'test_mock_message'}]", - repr(msgs), - ) - - def test_append(self): - msgs = Messages() - self.assertEqual(0, len(msgs)) - msgs.append(self.mock_message_object()) - self.assertEqual(1, len(msgs)) - msgs.append(self.mock_message_object()) - self.assertEqual(2, len(msgs)) - - def test_message_forward_to_drop(self): - mock_obj = Messages() - mock_obj.append(Message(b"").to_drop()) - true_obj = Messages() - 
true_obj.append(mock_obj[0]) - self.assertEqual(type(mock_obj), type(true_obj)) - for i in range(len(true_obj)): - self.assertEqual(type(mock_obj[i]), type(true_obj[i])) - self.assertEqual(mock_obj[i].keys, true_obj[i].keys) - self.assertEqual(mock_obj[i].value, true_obj[i].value) - for msg in true_obj: - print(msg) - - def test_err(self): - msgts = Messages(self.mock_message_object(), self.mock_message_object()) - with self.assertRaises(TypeError): - msgts[:1] + + +def test_messages_err(): + msgts = Messages(_mock_message_object(), _mock_message_object()) + with pytest.raises(TypeError): + msgts[:1] class ExampleMapper(Mapper): @@ -97,23 +101,15 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: return messages -class TestMapClass(unittest.TestCase): - def setUp(self) -> None: - # Create a map class instance - self.mapper_instance = ExampleMapper() - - def test_map_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.mapper_instance([], None) - self.assertEqual(mock_message(), ret[0].value) - # make a call to the handler - ret_handler = self.mapper_instance.handler(keys=[], datum=None) - # - self.assertEqual(ret[0], ret_handler[0]) - - -if __name__ == "__main__": - unittest.main() +def test_map_class_call(): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + mapper_instance = ExampleMapper() + # make a call to the class directly + ret = mapper_instance([], None) + assert mock_message() == ret[0].value + # make a call to the handler + ret_handler = mapper_instance.handler(keys=[], datum=None) + # + assert ret[0] == ret_handler[0] diff --git a/packages/pynumaflow/tests/map/test_multiproc_mapper.py b/packages/pynumaflow/tests/map/test_multiproc_mapper.py index 7250ad0c..148bbcd9 100644 --- 
a/packages/pynumaflow/tests/map/test_multiproc_mapper.py +++ b/packages/pynumaflow/tests/map/test_multiproc_mapper.py @@ -1,8 +1,6 @@ import os -import unittest -from unittest.mock import patch -import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time @@ -11,123 +9,111 @@ from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, get_test_datums from tests.conftest import collect_responses, drain_responses, send_test_requests -from tests.testing_utils import ( - mock_terminate_on_stop, -) - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestMultiProcMethods(unittest.TestCase): - def setUp(self) -> None: - my_server = MapMultiprocServer(mapper_instance=map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_multiproc_init(self) -> None: - my_server = MapMultiprocServer(mapper_instance=map_handler, server_count=3) - self.assertEqual(my_server._process_count, 3) - - def test_multiproc_process_count(self) -> None: - default_val = os.cpu_count() - my_server = MapMultiprocServer(mapper_instance=map_handler) - self.assertEqual(my_server._process_count, default_val) - - def test_max_process_count(self) -> None: - """Max process count is capped at 2 * os.cpu_count, irrespective of what the user - provides as input""" - default_val = os.cpu_count() - server = MapMultiprocServer(mapper_instance=map_handler, server_count=100) - self.assertEqual(server._process_count, default_val * 2) - - def test_udf_map_err_handshake(self): - my_server = MapMultiprocServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = 
server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=False) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("MapFn: expected handshake as the first message" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_udf_map_err(self): - my_server = MapMultiprocServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("Something is fishy!" 
in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - method_descriptor=( - map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["IsReady"] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - expected = map_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_map_forward_message(self): - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, + +@pytest.fixture() +def multiproc_test_server(): + my_server = MapMultiprocServer(mapper_instance=map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +def _invoke_map_fn(test_server, timeout=1): + """Helper to invoke the MapFn stream method.""" + return test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=timeout, + ) + + +def test_multiproc_init(): + my_server = MapMultiprocServer(mapper_instance=map_handler, server_count=3) + assert my_server._process_count == 3 + + +def test_multiproc_process_count(): + default_val = os.cpu_count() + my_server = MapMultiprocServer(mapper_instance=map_handler) + assert my_server._process_count == default_val + + +def test_max_process_count(): + """Max process count is capped at 2 * os.cpu_count, irrespective of what the user + provides as input""" + default_val = os.cpu_count() + server = MapMultiprocServer(mapper_instance=map_handler, server_count=100) + assert server._process_count == default_val * 2 + + +def test_udf_map_err_handshake(): + my_server = 
MapMultiprocServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + method = _invoke_map_fn(test_server) + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "MapFn: expected handshake as the first message" in details + assert code == StatusCode.INTERNAL + + +def test_udf_map_err(): + my_server = MapMultiprocServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + test_datums = get_test_datums(handshake=True) + method = _invoke_map_fn(test_server) + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "Something is fishy!" in details + assert code == StatusCode.INTERNAL + + +def test_is_ready(multiproc_test_server): + method = multiproc_test_server.invoke_unary_unary( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["IsReady"]), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + assert response == map_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + +def test_map_forward_message(multiproc_test_server): + test_datums = get_test_datums(handshake=True) + method = _invoke_map_fn(multiproc_test_server) + send_test_requests(method, test_datums) + responses = collect_responses(method) + + metadata, code, details = method.termination() + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + result_ids = {f"test-id-{id}" for id in range(1, 4)} + idx = 1 + while idx < len(responses): + result_ids.remove(responses[idx].id) + assert 
responses[idx].results[0].value == bytes( + "payload:test_mock_message " + "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", ) - send_test_requests(method, test_datums) - responses = collect_responses(method) - - metadata, code, details = method.termination() - - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - result_ids = {f"test-id-{id}" for id in range(1, 4)} - idx = 1 - while idx < len(responses): - result_ids.remove(responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - self.assertEqual(len(result_ids), 0) - self.assertEqual(code, StatusCode.OK) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - MapMultiprocServer() - - -if __name__ == "__main__": - unittest.main() + assert len(responses[idx].results) == 1 + idx += 1 + assert len(result_ids) == 0 + assert code == StatusCode.OK + + +def test_invalid_input(): + with pytest.raises(TypeError): + MapMultiprocServer() diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index ce830440..f2d33df6 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -1,7 +1,4 @@ -import unittest -from unittest.mock import patch - -import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time @@ -10,131 +7,118 @@ from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler, ExampleMap, get_test_datums from tests.conftest import collect_responses, drain_responses, send_test_requests -from tests.testing_utils import ( - 
mock_terminate_on_stop, -) - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestSyncMapper(unittest.TestCase): - # @patch("psutil.Process.kill", mock_terminate_on_stop) - def setUp(self) -> None: - class_instance = ExampleMap() - my_server = MapServer(mapper_instance=class_instance) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_init_with_args(self) -> None: - my_servicer = MapServer( - mapper_instance=map_handler, - sock_path="/tmp/test.sock", - max_message_size=1024 * 1024 * 5, - ) - self.assertEqual(my_servicer.sock_path, "unix:///tmp/test.sock") - self.assertEqual(my_servicer.max_message_size, 1024 * 1024 * 5) - - def test_udf_map_err_handshake(self): - my_server = MapServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=False) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("MapFn: expected handshake as the first message" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_udf_map_error_response(self): - my_server = MapServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - 
method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("Something is fishy!" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - method_descriptor=( - map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["IsReady"] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - expected = map_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_map_forward_message(self): - test_datums = get_test_datums(handshake=True) - method = self.test_server.invoke_stream_stream( - method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), - invocation_metadata={}, - timeout=1, + +@pytest.fixture() +def map_test_server(): + class_instance = ExampleMap() + my_server = MapServer(mapper_instance=class_instance) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +def _invoke_map_fn(test_server, timeout=1): + """Helper to invoke the MapFn stream method.""" + return test_server.invoke_stream_stream( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["MapFn"]), + invocation_metadata={}, + timeout=timeout, + ) + + +def test_init_with_args(): + my_servicer = MapServer( + mapper_instance=map_handler, + sock_path="/tmp/test.sock", + max_message_size=1024 * 1024 * 5, + ) + assert my_servicer.sock_path == "unix:///tmp/test.sock" + assert my_servicer.max_message_size == 1024 * 1024 * 5 + + +def test_udf_map_err_handshake(): + my_server = MapServer(mapper_instance=err_map_handler) + services = 
{map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=False) + method = _invoke_map_fn(test_server) + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "MapFn: expected handshake as the first message" in details + assert code == StatusCode.INTERNAL + + +def test_udf_map_error_response(): + my_server = MapServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + test_datums = get_test_datums(handshake=True) + method = _invoke_map_fn(test_server) + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "Something is fishy!" in details + assert code == StatusCode.INTERNAL + + +def test_is_ready(map_test_server): + method = map_test_server.invoke_unary_unary( + method_descriptor=(map_pb2.DESCRIPTOR.services_by_name["Map"].methods_by_name["IsReady"]), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + assert response == map_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + +def test_map_forward_message(map_test_server): + test_datums = get_test_datums(handshake=True) + method = _invoke_map_fn(map_test_server) + send_test_requests(method, test_datums) + responses = collect_responses(method) + + metadata, code, details = method.termination() + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + result_ids = {f"test-id-{id}" for id in range(1, 4)} + idx = 1 + while idx < len(responses): + result_ids.remove(responses[idx].id) + assert responses[idx].results[0].value == bytes( + "payload:test_mock_message " + "event_time:2022-09-12 
16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", ) - send_test_requests(method, test_datums) - responses = collect_responses(method) - - metadata, code, details = method.termination() - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - result_ids = {f"test-id-{id}" for id in range(1, 4)} - idx = 1 - while idx < len(responses): - result_ids.remove(responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - self.assertEqual(len(result_ids), 0) - self.assertEqual(code, StatusCode.OK) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - MapServer() - - def test_max_threads(self): - # max cap at 16 - server = MapServer(mapper_instance=map_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = MapServer(mapper_instance=map_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = MapServer(mapper_instance=map_handler) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - unittest.main() + assert len(responses[idx].results) == 1 + idx += 1 + assert len(result_ids) == 0 + assert code == StatusCode.OK + + +def test_invalid_input(): + with pytest.raises(TypeError): + MapServer() + + +def test_max_threads(): + # max cap at 16 + server = MapServer(mapper_instance=map_handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = MapServer(mapper_instance=map_handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = MapServer(mapper_instance=map_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py 
b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py index 0beae35e..55df9178 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py @@ -1,12 +1,11 @@ import asyncio import logging import threading -import unittest +from collections import Counter from collections.abc import AsyncIterable import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio._server import Server from pynumaflow import setup_logging from pynumaflow.mapstreamer import ( @@ -16,12 +15,15 @@ ) from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator +import pytest LOGGER = setup_logging(__name__) # if set to true, map handler will raise a `ValueError` exception. raise_error_from_map = False +SOCK_PATH = "unix:///tmp/async_map_stream.sock" + async def async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: val = datum.value @@ -34,159 +36,132 @@ async def async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIterab yield Message(str.encode(msg), keys=keys) -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/async_map_stream.sock") -_loop = None - - -def startup_callable(loop): +def _startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -def NewAsyncMapStreamer( - map_stream_handler=async_map_stream_handler, -): - server = MapStreamAsyncServer(map_stream_instance=map_stream_handler) - udfs = server.servicer - return udfs - - -async def start_server(udfs): +async def _start_server(udfs): server = grpc.aio.server() map_pb2_grpc.add_MapServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/async_map_stream.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await 
server.wait_for_termination() - - -class TestAsyncMapStreamer(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncMapStreamer() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: - try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: - LOGGER.error(e) + return server - def test_map_stream(self) -> None: - stub = self.__stub() - # Send >1 requests - req_count = 3 - try: - generator_response = stub.MapFn( - request_iterator=request_generator(count=req_count, session=1) - ) - except grpc.RpcError as e: - logging.error(e) - self.fail(f"RPC failed: {e}") +@pytest.fixture(scope="module") +def async_map_stream_server(): + """Module-scoped fixture: starts an async gRPC map stream server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() - # First message must be the handshake - handshake = next(generator_response) - self.assertTrue(handshake.handshake.sot) + server_obj = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) + udfs = server_obj.servicer + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + future.result(timeout=10) - # Expected: 10 results per request + 1 EOT per request - expected_result_msgs = req_count * 10 - expected_eots = req_count + while True: + try: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = 
grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) - # Prepare expected payload - expected_payload = bytes( - "payload:test_mock_message " - "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", - encoding="utf-8", - ) + yield loop - from collections import Counter - - id_counter = Counter() - result_msg_count = 0 - eot_count = 0 - - for msg in generator_response: - # Count EOTs wherever they show up - if hasattr(msg, "status") and msg.status.eot: - eot_count += 1 - continue - - # Otherwise, it's a data/result message; validate payload and tally by id - self.assertTrue(msg.results, "Expected results in MapResponse.") - self.assertEqual(expected_payload, msg.results[0].value) - id_counter[msg.id] += 1 - result_msg_count += 1 - - # Validate totals - self.assertEqual( - expected_result_msgs, - result_msg_count, - f"Expected {expected_result_msgs} result messages, got {result_msg_count}", - ) - self.assertEqual( - expected_eots, eot_count, f"Expected {expected_eots} EOT messages, got {eot_count}" - ) + loop.stop() + LOGGER.info("stopped the event loop") - # Validate 10 messages per request id: test-id-0..test-id-(req_count-1) - for i in range(req_count): - self.assertEqual( - 10, - id_counter[f"test-id-{i}"], - f"Expected 10 results for test-id-{i}, got {id_counter[f'test-id-{i}']}", - ) - def test_is_ready(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel: - stub = map_pb2_grpc.MapStub(channel) +@pytest.fixture() +def map_stream_stub(async_map_stream_server): + """Returns a MapStub connected to the running async map stream server.""" + return map_pb2_grpc.MapStub(grpc.insecure_channel(SOCK_PATH)) - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - self.assertTrue(response.ready) 
+def test_map_stream(map_stream_stub): + # Send >1 requests + req_count = 3 + try: + generator_response = map_stream_stub.MapFn( + request_iterator=request_generator(count=req_count, session=1) + ) + except grpc.RpcError as e: + logging.error(e) + pytest.fail(f"RPC failed: {e}") + + # First message must be the handshake + handshake = next(generator_response) + assert handshake.handshake.sot + + # Expected: 10 results per request + 1 EOT per request + expected_result_msgs = req_count * 10 + expected_eots = req_count + + # Prepare expected payload + expected_payload = bytes( + "payload:test_mock_message " "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", + encoding="utf-8", + ) - def __stub(self): - return map_pb2_grpc.MapStub(_channel) + id_counter = Counter() + result_msg_count = 0 + eot_count = 0 + + for msg in generator_response: + # Count EOTs wherever they show up + if hasattr(msg, "status") and msg.status.eot: + eot_count += 1 + continue + + # Otherwise, it's a data/result message; validate payload and tally by id + assert msg.results, "Expected results in MapResponse." 
+ assert msg.results[0].value == expected_payload + id_counter[msg.id] += 1 + result_msg_count += 1 + + # Validate totals + assert ( + result_msg_count == expected_result_msgs + ), f"Expected {expected_result_msgs} result messages, got {result_msg_count}" + assert eot_count == expected_eots, f"Expected {expected_eots} EOT messages, got {eot_count}" + + # Validate 10 messages per request id: test-id-0..test-id-(req_count-1) + for i in range(req_count): + assert ( + id_counter[f"test-id-{i}"] == 10 + ), f"Expected 10 results for test-id-{i}, got {id_counter[f'test-id-{i}']}" + + +def test_is_ready(async_map_stream_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = map_pb2_grpc.MapStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) + except grpc.RpcError as e: + logging.error(e) - def test_max_threads(self): - # max cap at 16 - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) + assert response.ready - # use argument provided - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - # defaults to 4 - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) - self.assertEqual(server.max_threads, 4) +def test_max_threads(): + # max cap at 16 + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=32) + assert server.max_threads == 16 + # use argument provided + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=5) + assert server.max_threads == 5 -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + # defaults to 4 + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py 
b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py index cb7e0ef6..8f599136 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py @@ -1,22 +1,20 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable -from unittest.mock import patch import grpc - -from grpc.aio._server import Server +import pytest from pynumaflow import setup_logging from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator -from tests.testing_utils import mock_terminate_on_stop LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/async_map_stream_err.sock" + # This handler mimics the scenario where map stream UDF throws a runtime error. async def err_async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIterable[Message]: @@ -33,106 +31,91 @@ async def err_async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIt raise RuntimeError("Got a runtime error from map stream handler.") -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/async_map_stream_err.sock") -_loop = None - - -def startup_callable(loop): +def _startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -async def start_server(): +async def _start_server(): server = grpc.aio.server() server_instance = MapStreamAsyncServer(err_async_map_stream_handler) udfs = server_instance.servicer map_pb2_grpc.add_MapServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/async_map_stream_err.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -# We are mocking the terminate function from the 
psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncServerErrorScenario(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_map_stream_err.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_map_stream_err_server(): + """Module-scoped fixture: starts an async gRPC map stream error server.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() + + future = asyncio.run_coroutine_threadsafe(_start_server(), loop=loop) + future.result(timeout=10) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_map_stream_error(self) -> None: - try: - stub = self.__stub() - generator_response = None - try: - generator_response = stub.MapFn( - request_iterator=request_generator(count=1, session=1) - ) - except grpc.RpcError as e: - logging.error(e) - - handshake = next(generator_response) - # assert that handshake response is received. 
- self.assertTrue(handshake.handshake.sot) - data_resp = [] - for r in generator_response: - data_resp.append(r) - except Exception as err: - self.assertTrue("Got a runtime error from map stream handler." in err.__str__()) - return - self.fail("Expected an exception.") - - def test_map_stream_error_no_handshake(self) -> None: - global raise_error - raise_error = True - stub = self.__stub() - try: - generator_response = stub.MapFn( - request_iterator=request_generator(count=10, handshake=False, session=1) - ) - counter = 0 - for _ in generator_response: - counter += 1 - except Exception as err: - self.assertTrue("MapStreamFn: expected handshake as the first message" in err.__str__()) - return - self.fail("Expected an exception.") + yield loop - def __stub(self): - return map_pb2_grpc.MapStub(_channel) + loop.stop() + LOGGER.info("stopped the event loop") - def test_invalid_input(self): - with self.assertRaises(TypeError): - MapStreamAsyncServer() +@pytest.fixture() +def map_stream_err_stub(async_map_stream_err_server): + """Returns a MapStub connected to the running async error server.""" + return map_pb2_grpc.MapStub(grpc.insecure_channel(SOCK_PATH)) -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + +def test_map_stream_error(map_stream_err_stub): + try: + generator_response = None + try: + generator_response = map_stream_err_stub.MapFn( + request_iterator=request_generator(count=1, session=1) + ) + except grpc.RpcError as e: + logging.error(e) + + handshake = next(generator_response) + # assert that handshake response is received. + assert handshake.handshake.sot + data_resp = [] + for r in generator_response: + data_resp.append(r) + except Exception as err: + assert "Got a runtime error from map stream handler." 
in str(err) + return + pytest.fail("Expected an exception.") + + +def test_map_stream_error_no_handshake(map_stream_err_stub): + try: + generator_response = map_stream_err_stub.MapFn( + request_iterator=request_generator(count=10, handshake=False, session=1) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + assert "MapStreamFn: expected handshake as the first message" in str(err) + return + pytest.fail("Expected an exception.") + + +def test_invalid_input(): + with pytest.raises(TypeError): + MapStreamAsyncServer() diff --git a/packages/pynumaflow/tests/mapstream/test_messages.py b/packages/pynumaflow/tests/mapstream/test_messages.py index 6218deb2..bcf44705 100644 --- a/packages/pynumaflow/tests/mapstream/test_messages.py +++ b/packages/pynumaflow/tests/mapstream/test_messages.py @@ -1,94 +1,94 @@ -import unittest +import pytest from pynumaflow.mapstreamer import Messages, Message, DROP from tests.testing_utils import mock_message -class TestMessage(unittest.TestCase): - def test_key(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) +def test_message_key(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + print(msg) + assert mock_obj["Keys"] == msg.keys + + +def test_message_value(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + assert mock_obj["Value"] == msg.value + + +def test_message_to_all(): + mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} + msg = Message(mock_obj["Value"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to_drop(): + mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} + msg = Message(b"").to_drop() + assert type(msg) is Message + 
assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to(): + mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) + assert type(msg) is Message + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def _mock_message_object(): + value = mock_message() + return Message(value=value) + + +def test_messages_items(): + mock_obj = [ + {"Keys": ["test_key"], "Value": mock_message()}, + {"Keys": ["test_key"], "Value": mock_message()}, + ] + msgs = Messages(*mock_obj) + assert len(mock_obj) == len(msgs) + assert len(mock_obj) == len(msgs.items()) + assert mock_obj[0]["Keys"] == msgs[0]["Keys"] + assert mock_obj[0]["Value"] == msgs[0]["Value"] + assert ( + "[{'Keys': ['test_key'], 'Value': b'test_mock_message'}, " + "{'Keys': ['test_key'], 'Value': b'test_mock_message'}]" + ) == repr(msgs) + + +def test_messages_append(): + msgs = Messages() + assert 0 == len(msgs) + msgs.append(_mock_message_object()) + assert 1 == len(msgs) + msgs.append(_mock_message_object()) + assert 2 == len(msgs) + + +def test_messages_forward_to_drop(): + mock_obj = Messages() + mock_obj.append(Message(b"").to_drop()) + true_obj = Messages() + true_obj.append(mock_obj[0]) + assert type(mock_obj) is type(true_obj) + for i in range(len(true_obj)): + assert type(mock_obj[i]) is type(true_obj[i]) + assert mock_obj[i].keys == true_obj[i].keys + assert mock_obj[i].value == true_obj[i].value + for msg in true_obj: print(msg) - self.assertEqual(mock_obj["Keys"], msg.keys) - - def test_value(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - self.assertEqual(mock_obj["Value"], msg.value) - - def test_message_to_all(self): - mock_obj = {"Keys": [], "Value": 
mock_message(), "Tags": []} - msg = Message(mock_obj["Value"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to_drop(self): - mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} - msg = Message(b"").to_drop() - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to(self): - mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - -class TestMessages(unittest.TestCase): - @staticmethod - def mock_message_object(): - value = mock_message() - return Message(value=value) - - def test_items(self): - mock_obj = [ - {"Keys": ["test_key"], "Value": mock_message()}, - {"Keys": ["test_key"], "Value": mock_message()}, - ] - msgs = Messages(*mock_obj) - self.assertEqual(len(mock_obj), len(msgs)) - self.assertEqual(len(mock_obj), len(msgs.items())) - self.assertEqual(mock_obj[0]["Keys"], msgs[0]["Keys"]) - self.assertEqual(mock_obj[0]["Value"], msgs[0]["Value"]) - self.assertEqual( - "[{'Keys': ['test_key'], 'Value': b'test_mock_message'}, " - "{'Keys': ['test_key'], 'Value': b'test_mock_message'}]", - repr(msgs), - ) - - def test_append(self): - msgs = Messages() - self.assertEqual(0, len(msgs)) - msgs.append(self.mock_message_object()) - self.assertEqual(1, len(msgs)) - msgs.append(self.mock_message_object()) - self.assertEqual(2, len(msgs)) - - def test_message_forward_to_drop(self): - mock_obj = Messages() - mock_obj.append(Message(b"").to_drop()) - true_obj = Messages() - 
true_obj.append(mock_obj[0]) - self.assertEqual(type(mock_obj), type(true_obj)) - for i in range(len(true_obj)): - self.assertEqual(type(mock_obj[i]), type(true_obj[i])) - self.assertEqual(mock_obj[i].keys, true_obj[i].keys) - self.assertEqual(mock_obj[i].value, true_obj[i].value) - for msg in true_obj: - print(msg) - - def test_err(self): - msgts = Messages(self.mock_message_object(), self.mock_message_object()) - with self.assertRaises(TypeError): - msgts[:1] - - -if __name__ == "__main__": - unittest.main() + + +def test_messages_err(): + msgts = Messages(_mock_message_object(), _mock_message_object()) + with pytest.raises(TypeError): + msgts[:1] diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce.py b/packages/pynumaflow/tests/reduce/test_async_reduce.py index 3eda5854..febcbc0e 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce.py @@ -1,12 +1,11 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio._server import Server from pynumaflow._constants import WIN_START_TIME, WIN_END_TIME from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc @@ -28,6 +27,8 @@ logging.basicConfig(level=logging.DEBUG) LOGGER = logging.getLogger(__name__) +SOCK_PATH = "unix:///tmp/reduce.sock" + def request_generator(count, request, resetkey: bool = False): for i in range(count): @@ -63,11 +64,6 @@ def start_request() -> (Datum, tuple): return request, metadata -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/reduce.sock") -_loop = None - - def startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() @@ -104,166 +100,163 @@ def NewAsyncReducer(): return udfs -async def start_server(udfs): +async def _start_server(udfs): server = grpc.aio.server() reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server) - 
listen_addr = "unix:///tmp/reduce.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -class TestAsyncReducer(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncReducer() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/reduce.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_reduce_server(): + """Module-scoped fixture: starts an async gRPC reduce server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = NewAsyncReducer() + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + future.result(timeout=10) + + # Wait for the server to be ready + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_reduce(self) -> None: - stub = self.__stub() - request, metadata = start_request() - generator_response = None - try: - generator_response = stub.ReduceFn( - 
request_iterator=request_generator(count=10, request=request) + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def reduce_stub(async_reduce_server): + """Returns a ReduceStub connected to the running async server.""" + return reduce_pb2_grpc.ReduceStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_reduce(reduce_stub) -> None: + request, metadata = start_request() + generator_response = None + try: + generator_response = reduce_stub.ReduceFn( + request_iterator=request_generator(count=10, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the ReduceFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if r.result.value: + count += 1 + assert ( + bytes( + "counter:10", + encoding="utf-8", + ) + == r.result.value ) - except grpc.RpcError as e: - logging.error(e) - - # capture the output from the ReduceFn generator and assert. - count = 0 - eof_count = 0 - for r in generator_response: - if r.result.value: - count += 1 - self.assertEqual( - bytes( - "counter:10", - encoding="utf-8", - ), - r.result.value, + assert r.EOF is False + else: + assert r.EOF is True + eof_count += 1 + assert r.window.start.ToSeconds() == 1662998400 + assert r.window.end.ToSeconds() == 1662998460 + # since there is only one key, the output count is 1 + assert 1 == count + assert 1 == eof_count + + +def test_reduce_with_multiple_keys(reduce_stub) -> None: + request, metadata = start_request() + generator_response = None + try: + generator_response = reduce_stub.ReduceFn( + request_iterator=request_generator(count=100, request=request, resetkey=True), + ) + except grpc.RpcError as e: + print(e) + + count = 0 + eof_count = 0 + + # capture the output from the ReduceFn generator and assert. 
+ for r in generator_response: + # Check for responses with + if r.result.value: + count += 1 + assert ( + bytes( + "counter:1", + encoding="utf-8", ) - self.assertEqual(r.EOF, False) - else: - self.assertEqual(r.EOF, True) - eof_count += 1 - self.assertEqual(r.window.start.ToSeconds(), 1662998400) - self.assertEqual(r.window.end.ToSeconds(), 1662998460) - # since there is only one key, the output count is 1 - self.assertEqual(1, count) - self.assertEqual(1, eof_count) - - def test_reduce_with_multiple_keys(self) -> None: - stub = self.__stub() - request, metadata = start_request() - generator_response = None - try: - generator_response = stub.ReduceFn( - request_iterator=request_generator(count=100, request=request, resetkey=True), + == r.result.value ) + assert r.EOF is False + else: + eof_count += 1 + assert r.EOF is True + assert r.window.start.ToSeconds() == 1662998400 + assert r.window.end.ToSeconds() == 1662998460 + assert 100 == count + assert 1 == eof_count + + +def test_is_ready(async_reduce_server) -> None: + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = reduce_pb2_grpc.ReduceStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) except grpc.RpcError as e: - print(e) - - count = 0 - eof_count = 0 - - # capture the output from the ReduceFn generator and assert. 
- for r in generator_response: - # Check for responses with - if r.result.value: - count += 1 - self.assertEqual( - bytes( - "counter:1", - encoding="utf-8", - ), - r.result.value, - ) - self.assertEqual(r.EOF, False) - else: - eof_count += 1 - self.assertEqual(r.EOF, True) - self.assertEqual(r.window.start.ToSeconds(), 1662998400) - self.assertEqual(r.window.end.ToSeconds(), 1662998460) - self.assertEqual(100, count) - self.assertEqual(1, eof_count) - - def test_is_ready(self) -> None: - with grpc.insecure_channel("unix:///tmp/reduce.sock") as channel: - stub = reduce_pb2_grpc.ReduceStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def __stub(self): - return reduce_pb2_grpc.ReduceStub(_channel) - - def test_error_init(self): - # Check that reducer_instance in required - with self.assertRaises(TypeError): - ReduceAsyncServer() - # Check that the init_args and init_kwargs are passed - # only with a Reducer class - with self.assertRaises(TypeError): - ReduceAsyncServer(reduce_handler_func, init_args=(0, 1)) - # Check that an instance is not passed instead of the class - # signature - with self.assertRaises(TypeError): - ReduceAsyncServer(ExampleClass(0)) - - # Check that an invalid class is passed - class ExampleBadClass: - pass - - with self.assertRaises(TypeError): - ReduceAsyncServer(reducer_instance=ExampleBadClass) - - def test_max_threads(self): - # max cap at 16 - server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = ReduceAsyncServer(reducer_instance=ExampleClass) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - 
logging.basicConfig(level=logging.DEBUG) - unittest.main() + logging.error(e) + + assert response.ready + + +def test_error_init(): + # Check that reducer_instance in required + with pytest.raises(TypeError): + ReduceAsyncServer() + # Check that the init_args and init_kwargs are passed + # only with a Reducer class + with pytest.raises(TypeError): + ReduceAsyncServer(reduce_handler_func, init_args=(0, 1)) + # Check that an instance is not passed instead of the class + # signature + with pytest.raises(TypeError): + ReduceAsyncServer(ExampleClass(0)) + + # Check that an invalid class is passed + class ExampleBadClass: + pass + + with pytest.raises(TypeError): + ReduceAsyncServer(reducer_instance=ExampleBadClass) + + +def test_max_threads(): + # max cap at 16 + server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = ReduceAsyncServer(reducer_instance=ExampleClass) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py index 42bf2aab..53b3c91b 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py @@ -1,12 +1,10 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable -from unittest.mock import patch import grpc -from grpc.aio._server import Server +import pytest from pynumaflow import setup_logging from pynumaflow._constants import WIN_START_TIME, WIN_END_TIME @@ -23,11 +21,12 @@ mock_interval_window_start, mock_interval_window_end, get_time_args, - mock_terminate_on_stop, ) LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/reduce_err.sock" + def request_generator(count, request, resetkey: bool = False): for 
def startup_callable(loop):
    """Thread target: make *loop* current for this thread and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


async def _start_server(udfs):
    """Build, register and start an aio gRPC server exposing *udfs* on SOCK_PATH.

    Returns the started server so the caller controls its lifetime.
    """
    server = grpc.aio.server()
    reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server)
    server.add_insecure_port(SOCK_PATH)
    logging.info("Starting server on %s", SOCK_PATH)
    await server.start()
    return server


@pytest.fixture(scope="module")
def async_reduce_err_server():
    """Module-scoped fixture: run the erroring async reduce server on a background loop thread."""
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
    thread.start()

    udfs = NewAsyncReducer()
    future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop)
    future.result(timeout=10)

    # Poll until the UDS endpoint accepts connections so tests never race the server.
    while True:
        try:
            with grpc.insecure_channel(SOCK_PATH) as channel:
                ready = grpc.channel_ready_future(channel)
                ready.result(timeout=10)
                if ready.done():
                    break
        except grpc.FutureTimeoutError as err:
            LOGGER.error("error trying to connect to grpc server")
            LOGGER.error(err)

    yield loop

    # BUG FIX: loop.stop() is not thread-safe when called from the test thread;
    # schedule the stop on the loop's own thread instead.
    loop.call_soon_threadsafe(loop.stop)
    LOGGER.info("stopped the event loop")


@pytest.fixture()
def reduce_err_stub(async_reduce_err_server):
    """Yield a ReduceStub whose channel is closed after the test (fixes a channel leak)."""
    channel = grpc.insecure_channel(SOCK_PATH)
    try:
        yield reduce_pb2_grpc.ReduceStub(channel)
    finally:
        channel.close()


def test_reduce(async_reduce_err_server) -> None:
    """A runtime error raised by the reduce handler must surface to the client."""
    with grpc.insecure_channel(SOCK_PATH) as channel:
        stub = reduce_pb2_grpc.ReduceStub(channel)
        request, metadata = start_request(multiple_window=False)
        try:
            responses = stub.ReduceFn(
                request_iterator=request_generator(count=1, request=request)
            )
            for _ in responses:
                pass
        except BaseException as err:
            assert "Got a runtime error from reduce handler." in str(err)
            return
        pytest.fail("Expected an exception.")


def test_reduce_window_len(reduce_err_stub) -> None:
    """Sending more than one window in a reduce request must be rejected."""
    request, metadata = start_request(multiple_window=True)
    try:
        responses = reduce_err_stub.ReduceFn(
            request_iterator=request_generator(count=10, request=request)
        )
        for _ in responses:
            pass
    except BaseException as err:
        assert "reduce create operation error: invalid number of windows" in str(err)
        return
    pytest.fail("Expected an exception.")


# Shared constants for the Datum tests (test_datatypes module).
TEST_KEYS = ["test"]
TEST_ID = "test_id"
def test_datum_err_event_time():
    """Constructing a Datum with a raw protobuf Timestamp event_time must fail."""
    ts = _timestamp_pb2.Timestamp()
    ts.GetCurrentTime()
    headers = {"key1": "value1", "key2": "value2"}
    with pytest.raises(Exception) as exc_info:
        Datum(keys=TEST_KEYS, value=mock_message(), event_time=ts, watermark=ts, headers=headers)
    assert str(exc_info.value) == ("Wrong data type: " "for Datum.event_time")


def test_datum_err_watermark():
    """Constructing a Datum with a raw protobuf Timestamp watermark must fail."""
    ts = _timestamp_pb2.Timestamp()
    ts.GetCurrentTime()
    headers = {"key1": "value1", "key2": "value2"}
    with pytest.raises(Exception) as exc_info:
        Datum(
            keys=TEST_KEYS,
            value=mock_message(),
            event_time=mock_event_time(),
            watermark=ts,
            headers=headers,
        )
    assert str(exc_info.value) == ("Wrong data type: " "for Datum.watermark")


def test_datum_value():
    """value and headers round-trip through the Datum constructor."""
    expected_headers = {"key1": "value1", "key2": "value2"}
    datum = Datum(
        keys=TEST_KEYS,
        value=mock_message(),
        event_time=mock_event_time(),
        watermark=mock_watermark(),
        headers=expected_headers,
    )
    assert datum.value == mock_message()
    assert datum.headers == expected_headers


def test_datum_key():
    """keys round-trip through the Datum constructor."""
    datum = Datum(
        keys=TEST_KEYS,
        value=mock_message(),
        event_time=mock_event_time(),
        watermark=mock_watermark(),
    )
    assert datum.keys == TEST_KEYS


def test_datum_event_time():
    """event_time round-trips through the Datum constructor."""
    datum = Datum(
        keys=TEST_KEYS,
        value=mock_message(),
        event_time=mock_event_time(),
        watermark=mock_watermark(),
    )
    assert datum.event_time == mock_event_time()


def test_datum_watermark():
    """watermark round-trips through the Datum constructor."""
    datum = Datum(
        keys=TEST_KEYS,
        value=mock_message(),
        event_time=mock_event_time(),
        watermark=mock_watermark(),
    )
    assert datum.watermark == mock_watermark()


def test_interval_window_start():
    """IntervalWindow exposes the start timestamp it was built with."""
    window = IntervalWindow(start=mock_start_time(), end=mock_end_time())
    assert window.start == mock_start_time()


def test_interval_window_end():
    """IntervalWindow exposes the end timestamp it was built with."""
    window = IntervalWindow(start=mock_start_time(), end=mock_end_time())
    assert window.end == mock_end_time()


def test_metadata_interval_window():
    """Metadata stores the IntervalWindow unchanged."""
    window = IntervalWindow(start=mock_start_time(), end=mock_end_time())
    meta = Metadata(interval_window=window)
    assert type(meta.interval_window) is type(window)
    assert meta.interval_window == window


def test_reduce_window_create():
    """ReduceWindow derives window/start/end/slot from its constructor args."""
    reduce_window = ReduceWindow(start=mock_start_time(), end=mock_end_time(), slot="slot-0")
    expected = IntervalWindow(start=mock_start_time(), end=mock_end_time())
    assert reduce_window.window == expected
    assert reduce_window.start == mock_start_time()
    assert reduce_window.end == mock_end_time()
    assert reduce_window.slot == "slot-0"


class ExampleReducer(Reducer):
    """Minimal Reducer subclass used to exercise init and deepcopy semantics."""

    async def handler(
        self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata
    ) -> Messages:
        pass

    def __init__(self, test1, test2):
        self.test1 = test1
        self.test2 = test2
        self.test3 = self.test1


def test_reducer_class_init():
    """__init__ stores both arguments and mirrors test1 into test3."""
    reducer = ExampleReducer(test1=1, test2=2)
    assert reducer.test1 == 1
    assert reducer.test2 == 2
    assert reducer.test3 == 1


def test_reducer_class_deep_copy():
    """Test that the deepcopy works as expected."""
    reducer = ExampleReducer(test1=1, test2=2)
    clone = deepcopy(reducer)
    # The copy carries the same attribute values ...
    assert clone.test1 == 1
    assert clone.test2 == 2
    assert clone.test3 == 1
    # ... but is a distinct object.
    assert id(reducer) != id(clone)
    # Mutating the original must not leak into the copy.
    reducer.test1 = 5
    reducer.test3 = 6
    assert reducer.test1 != clone.test1
    assert reducer.test3 != clone.test3
    assert id(reducer.test3) != id(clone.test3)
    # The copy preserves both the concrete and the base type.
    assert isinstance(clone, ExampleReducer)
    assert isinstance(clone, Reducer)


def test_message_key():
    """Message keeps the keys it was constructed with."""
    mock_obj = {"Keys": ["test-key"], "Value": mock_message()}
    msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"])
    print(msg)
    assert msg.keys == mock_obj["Keys"]


def test_message_value():
    """Message keeps the value it was constructed with."""
    mock_obj = {"Keys": ["test-key"], "Value": mock_message()}
    msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"])
    assert msg.value == mock_obj["Value"]
def test_message_to_all():
    """A Message built from a bare value has empty keys and tags."""
    expected = {"Keys": [], "Value": mock_message(), "Tags": []}
    msg = Message(expected["Value"])
    assert type(msg) is Message
    assert msg.keys == expected["Keys"]
    assert msg.value == expected["Value"]
    assert msg.tags == expected["Tags"]


def test_message_to_drop():
    """to_drop() clears the value and tags the message with DROP."""
    expected = {"Keys": [], "Value": b"", "Tags": [DROP]}
    msg = Message(b"").to_drop()
    assert type(msg) is Message
    assert msg.keys == expected["Keys"]
    assert msg.value == expected["Value"]
    assert msg.tags == expected["Tags"]


def test_message_to():
    """Message keeps explicitly supplied keys, value and tags."""
    expected = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]}
    msg = Message(value=expected["Value"], keys=expected["Keys"], tags=expected["Tags"])
    assert type(msg) is Message
    assert msg.keys == expected["Keys"]
    assert msg.value == expected["Value"]
    assert msg.tags == expected["Tags"]


def _mock_message_object():
    """Build a Message wrapping the canonical mock payload."""
    return Message(value=mock_message())


def test_messages_items():
    """Messages behaves like a sequence: len, items(), indexing and repr()."""
    mock_obj = [
        {"Keys": ["test_key"], "Value": mock_message()},
        {"Keys": ["test_key"], "Value": mock_message()},
    ]
    msgs = Messages(*mock_obj)
    assert len(msgs) == len(mock_obj)
    assert len(msgs.items()) == len(mock_obj)
    assert msgs[0]["Keys"] == mock_obj[0]["Keys"]
    assert msgs[0]["Value"] == mock_obj[0]["Value"]
    assert repr(msgs) == (
        "[{'Keys': ['test_key'], 'Value': b'test_mock_message'}, "
        "{'Keys': ['test_key'], 'Value': b'test_mock_message'}]"
    )


def test_messages_append():
    """append() grows the collection one message at a time."""
    msgs = Messages()
    assert len(msgs) == 0
    msgs.append(_mock_message_object())
    assert len(msgs) == 1
    msgs.append(_mock_message_object())
    assert len(msgs) == 2


def test_messages_forward_to_drop():
    """A dropped message forwards unchanged through a Messages collection."""
    dropped = Messages()
    dropped.append(Message(b"").to_drop())
    forwarded = Messages()
    forwarded.append(dropped[0])
    assert type(dropped) is type(forwarded)
    for i in range(len(forwarded)):
        assert type(dropped[i]) is type(forwarded[i])
        assert dropped[i].keys == forwarded[i].keys
        assert dropped[i].value == forwarded[i].value
    for msg in forwarded:
        print(msg)


def test_messages_err():
    """Slicing a Messages collection is unsupported and raises TypeError."""
    msgs = Messages(_mock_message_object(), _mock_message_object())
    with pytest.raises(TypeError):
        msgs[:1]


# Unix-domain socket shared by the reduce streamer tests (test_async_reduce module).
SOCK_PATH = "unix:///tmp/reduce_stream.sock"
def startup_callable(loop):
    """Thread target: adopt *loop* in this thread and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


async def _start_server(udfs):
    """Build, register and start an aio gRPC reduce-stream server on SOCK_PATH.

    Returns the started server so the caller controls its lifetime.
    """
    server = grpc.aio.server()
    reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server)
    server.add_insecure_port(SOCK_PATH)
    logging.info("Starting server on %s", SOCK_PATH)
    await server.start()
    return server


@pytest.fixture(scope="module")
def async_reduce_stream_server():
    """Module-scoped fixture: run the async reduce-stream server on a background loop thread."""
    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
    thread.start()

    udfs = NewAsyncReduceStreamer()
    future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop)
    future.result(timeout=10)

    # Poll until the endpoint accepts connections so tests never race the server.
    while True:
        try:
            with grpc.insecure_channel(SOCK_PATH) as channel:
                ready = grpc.channel_ready_future(channel)
                ready.result(timeout=10)
                if ready.done():
                    break
        except grpc.FutureTimeoutError as err:
            LOGGER.error("error trying to connect to grpc server")
            LOGGER.error(err)

    yield loop

    # BUG FIX: loop.stop() is not thread-safe from the test thread; hand the
    # stop over to the loop's own thread.
    loop.call_soon_threadsafe(loop.stop)
    LOGGER.info("stopped the event loop")


@pytest.fixture()
def reduce_stream_stub(async_reduce_stream_server):
    """Yield a ReduceStub whose channel is closed after the test (fixes a channel leak)."""
    channel = grpc.insecure_channel(SOCK_PATH)
    try:
        yield reduce_pb2_grpc.ReduceStub(channel)
    finally:
        channel.close()


def test_reduce(reduce_stream_stub) -> None:
    """Early-streamed results arrive as counter:3 and the tail as counter:1."""
    request, metadata = start_request()
    generator_response = None
    try:
        generator_response = reduce_stream_stub.ReduceFn(
            request_iterator=request_generator(count=10, request=request)
        )
    except grpc.RpcError as e:
        logging.error(e)

    # Capture the output from the ReduceFn generator and assert.
    count = 0
    eof_count = 0
    for r in generator_response:
        if r.result.value:
            count += 1
            if count <= 3:
                assert r.result.value == bytes("counter:3", encoding="utf-8")
            else:
                assert r.result.value == bytes("counter:1", encoding="utf-8")
            assert r.EOF is False
        else:
            assert r.EOF is True
            eof_count += 1
            assert r.window.start.ToSeconds() == 1662998400
            assert r.window.end.ToSeconds() == 1662998460
    # In our example we should return 3 messages early with counter:3
    # and the last message with counter:1.
    assert count == 4
    assert eof_count == 1
def test_reduce_with_multiple_keys(reduce_stream_stub) -> None:
    """With key resets, every one of the 100 datums yields its own counter:1."""
    request, metadata = start_request()
    generator_response = None
    try:
        generator_response = reduce_stream_stub.ReduceFn(
            request_iterator=request_generator(count=100, request=request, resetkey=True),
        )
    except grpc.RpcError as e:
        print(e)

    count = 0
    eof_count = 0

    # Capture the output from the ReduceFn generator and assert.
    for r in generator_response:
        # Responses carrying a value are results; empty ones signal EOF.
        if r.result.value:
            count += 1
            assert r.result.value == bytes("counter:1", encoding="utf-8")
            assert r.EOF is False
        else:
            eof_count += 1
            assert r.EOF is True
            assert r.window.start.ToSeconds() == 1662998400
            assert r.window.end.ToSeconds() == 1662998460
    assert count == 100
    assert eof_count == 1


def test_is_ready(async_reduce_stream_server) -> None:
    """IsReady reports ready once the server is up."""
    with grpc.insecure_channel(SOCK_PATH) as channel:
        stub = reduce_pb2_grpc.ReduceStub(channel)

        request = _empty_pb2.Empty()
        response = None
        try:
            response = stub.IsReady(request=request)
        except grpc.RpcError as e:
            logging.error(e)

        assert response.ready


def test_error_init():
    """Constructor validation of ReduceStreamAsyncServer."""
    # Check that reduce_stream_instance is required.  (typo fixed: "in required")
    with pytest.raises(TypeError):
        ReduceStreamAsyncServer()
    # init_args/init_kwargs are only accepted together with a Reducer class.
    with pytest.raises(TypeError):
        ReduceStreamAsyncServer(reduce_handler_func, init_args=(0, 1))
    # An instance must not be passed where the class signature is expected.
    with pytest.raises(TypeError):
        ReduceStreamAsyncServer(ExampleClass(0))

    # An unrelated class is rejected as well.
    class ExampleBadClass:
        pass

    with pytest.raises(TypeError):
        ReduceStreamAsyncServer(reduce_stream_instance=ExampleBadClass)


def test_max_threads():
    """max_threads is capped at 16, honours explicit values, and defaults to 4."""
    # max cap at 16
    server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=32)
    assert server.max_threads == 16

    # use argument provided
    server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=5)
    assert server.max_threads == 5

    # defaults to 4
    server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
    assert server.max_threads == 4


def test_start_shutdown_handler_without_callback():
    """Test that _shutdown_handler logs and works when no shutdown_callback is set."""
    from unittest.mock import patch, MagicMock

    server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
    assert server.shutdown_callback is None

    def close_coro(coro, **kwargs):
        coro.close()  # prevent "coroutine never awaited" warnings

    with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun:
        mock_aiorun.run.side_effect = close_coro
        server.start()

        # Extract the shutdown_callback passed to aiorun.run and invoke it —
        # it must not raise even without a user callback.
        shutdown_handler = mock_aiorun.run.call_args[1]["shutdown_callback"]
        mock_loop = MagicMock()
        shutdown_handler(mock_loop)


def test_start_shutdown_handler_with_callback():
    """Test that _shutdown_handler invokes the user-provided shutdown_callback."""
    from unittest.mock import patch, MagicMock

    user_callback = MagicMock()
    server = ReduceStreamAsyncServer(
        reduce_stream_instance=ExampleClass, shutdown_callback=user_callback
    )

    def close_coro(coro, **kwargs):
        coro.close()

    with patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun:
        mock_aiorun.run.side_effect = close_coro
        server.start()

        shutdown_handler = mock_aiorun.run.call_args[1]["shutdown_callback"]
        mock_loop = MagicMock()
        shutdown_handler(mock_loop)

    user_callback.assert_called_once_with(mock_loop)


def test_start_exits_on_error():
    """Test that start() calls sys.exit(1) when the servicer reports an error."""
    from unittest.mock import patch

    server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)

    def fake_aiorun_run(coro, **kwargs):
        # Simulate aiorun completing after a UDF error was recorded.
        coro.close()  # prevent "coroutine never awaited" warning
        server._error = RuntimeError("UDF failure")

    with (
        patch("pynumaflow.reducestreamer.async_server.aiorun") as mock_aiorun,
        patch("pynumaflow.reducestreamer.async_server.sys") as mock_sys,
    ):
        mock_aiorun.run.side_effect = fake_aiorun_run
        server.start()

    mock_sys.exit.assert_called_once_with(1)
Server +import pytest from pynumaflow import setup_logging from pynumaflow._constants import WIN_START_TIME, WIN_END_TIME @@ -29,6 +29,8 @@ LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/reduce_stream_err.sock" + def request_generator(count, request, resetkey: bool = False): for i in range(count): @@ -70,11 +72,6 @@ def start_request(multiple_window: False) -> (Datum, tuple): return request, metadata -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/reduce_stream_err.sock") -_loop = None - - def startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() @@ -128,96 +125,94 @@ def NewAsyncReduceStreamer(): return udfs -async def start_server(udfs): +async def _start_server(udfs): server = grpc.aio.server() reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/reduce_stream_err.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -class TestAsyncReduceStreamerErr(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncReduceStreamer() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/reduce_stream_err.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_reduce_stream_err_server(): + """Module-scoped fixture: starts an async gRPC 
reduce stream error server.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = NewAsyncReduceStreamer() + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + future.result(timeout=10) + + # Wait for the server to be ready + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except BaseException as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_reduce(self) -> None: - stub = self.__stub() - request, metadata = start_request(multiple_window=False) - generator_response = None - try: - generator_response = stub.ReduceFn( - request_iterator=request_generator(count=10, request=request), - ) - counter = 0 - for _ in generator_response: - counter += 1 - except BaseException as err: - self.assertTrue("Got a runtime error from reduce handler." 
in err.__str__()) - return - self.fail("Expected an exception.") - - def test_reduce_window_len(self) -> None: - stub = self.__stub() - request, metadata = start_request(multiple_window=True) - generator_response = None - try: - generator_response = stub.ReduceFn( - request_iterator=request_generator(count=10, request=request) - ) - counter = 0 - for _ in generator_response: - counter += 1 - except Exception as err: - self.assertTrue( - "reduce append operation error: invalid number of windows" in err.__str__() - ) - return - try: - request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN - generator_response = stub.ReduceFn( - request_iterator=request_generator(count=10, request=request) - ) - counter = 0 - for _ in generator_response: - counter += 1 - except Exception as err: - self.assertTrue( - "reduce create operation error: invalid number of windows" in err.__str__() - ) - return - self.fail("Expected an exception.") + yield loop - def __stub(self): - return reduce_pb2_grpc.ReduceStub(_channel) + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def reduce_stream_err_stub(async_reduce_stream_err_server): + """Returns a ReduceStub connected to the running async error server.""" + return reduce_pb2_grpc.ReduceStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_reduce(reduce_stream_err_stub) -> None: + request, metadata = start_request(multiple_window=False) + generator_response = None + try: + generator_response = reduce_stream_err_stub.ReduceFn( + request_iterator=request_generator(count=10, request=request), + ) + counter = 0 + for _ in generator_response: + counter += 1 + except BaseException as err: + assert "Got a runtime error from reduce handler." 
in str(err) + return + pytest.fail("Expected an exception.") + + +def test_reduce_window_len(reduce_stream_err_stub) -> None: + request, metadata = start_request(multiple_window=True) + generator_response = None + try: + generator_response = reduce_stream_err_stub.ReduceFn( + request_iterator=request_generator(count=10, request=request) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + assert "reduce append operation error: invalid number of windows" in str(err) + return + try: + request.operation.event = reduce_pb2.ReduceRequest.WindowOperation.Event.OPEN + generator_response = reduce_stream_err_stub.ReduceFn( + request_iterator=request_generator(count=10, request=request) + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + assert "reduce create operation error: invalid number of windows" in str(err) + return + pytest.fail("Expected an exception.") + + +# --- Standalone test functions (not part of the TestCase) --- async def _emit_one_handler(keys, datums, output, md): @@ -460,8 +455,3 @@ async def send_eof_then_wait_and_raise(): assert task.consumer_future.done() asyncio.run(_run()) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() diff --git a/packages/pynumaflow/tests/reducestreamer/test_datatypes.py b/packages/pynumaflow/tests/reducestreamer/test_datatypes.py index 4d3ffc97..eeab6db2 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_datatypes.py +++ b/packages/pynumaflow/tests/reducestreamer/test_datatypes.py @@ -1,7 +1,7 @@ from copy import deepcopy -import unittest from collections.abc import AsyncIterable +import pytest from google.protobuf import timestamp_pb2 as _timestamp_pb2 from pynumaflow.reducer import Reducer, Messages @@ -24,150 +24,145 @@ TEST_HEADERS = {"key1": "value1", "key2": "value2"} -class TestDatum(unittest.TestCase): - def test_err_event_time(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() 
- with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=ts, - watermark=ts, - headers=TEST_HEADERS, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.event_time", - str(context.exception), - ) - - def test_err_watermark(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - with self.assertRaises(Exception) as context: - Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=ts, - headers=TEST_HEADERS, - ) - self.assertEqual( - "Wrong data type: " - "for Datum.watermark", - str(context.exception), - ) - - def test_value(self): - d = Datum( +def test_err_event_time(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as exc_info: + Datum( keys=TEST_KEYS, value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), + event_time=ts, + watermark=ts, headers=TEST_HEADERS, ) - self.assertEqual(mock_message(), d.value) + assert str(exc_info.value) == ( + "Wrong data type: " "for Datum.event_time" + ) - def test_key(self): - d = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - ) - self.assertEqual(TEST_KEYS, d.keys) - def test_event_time(self): - d = Datum( +def test_err_watermark(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as exc_info: + Datum( keys=TEST_KEYS, value=mock_message(), event_time=mock_event_time(), - watermark=mock_watermark(), + watermark=ts, headers=TEST_HEADERS, ) - self.assertEqual(mock_event_time(), d.event_time) - self.assertEqual(TEST_HEADERS, d.headers) - - def test_watermark(self): - d = Datum( - keys=TEST_KEYS, - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - ) - self.assertEqual(mock_watermark(), d.watermark) - self.assertEqual({}, d.headers) - - -class TestIntervalWindow(unittest.TestCase): - def test_start(self): - i = 
IntervalWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(mock_start_time(), i.start) - - def test_end(self): - i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(mock_end_time(), i.end) - - -class TestMetadata(unittest.TestCase): - def test_interval_window(self): - i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) - m = Metadata(interval_window=i) - self.assertEqual(type(i), type(m.interval_window)) - self.assertEqual(i, m.interval_window) - - -class TestReducerWindow(unittest.TestCase): - def test_create_window(self): - rw = ReduceWindow(start=mock_start_time(), end=mock_end_time(), slot="slot-0") - i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) - self.assertEqual(rw.window, i) - self.assertEqual(rw.start, mock_start_time()) - self.assertEqual(rw.end, mock_end_time()) - self.assertEqual(rw.slot, "slot-0") - - -class TestReducerClass(unittest.TestCase): - class ExampleClass(Reducer): - async def handler( - self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata - ) -> Messages: - pass - - def __init__(self, test1, test2): - self.test1 = test1 - self.test2 = test2 - self.test3 = self.test1 - - def test_init(self): - r = self.ExampleClass(test1=1, test2=2) - self.assertEqual(1, r.test1) - self.assertEqual(2, r.test2) - self.assertEqual(1, r.test3) - - def test_deep_copy(self): - """Test that the deepcopy works as expected""" - r = self.ExampleClass(test1=1, test2=2) - # Create a copy of r - r_copy = deepcopy(r) - # Check that the attributes are the same - self.assertEqual(1, r_copy.test1) - self.assertEqual(2, r_copy.test2) - self.assertEqual(1, r_copy.test3) - # Check that the objects are not the same - self.assertNotEqual(id(r), id(r_copy)) - # Update the attributes of r - r.test1 = 5 - r.test3 = 6 - # Check that the other object is not updated - self.assertNotEqual(r.test1, r_copy.test1) - self.assertNotEqual(r.test3, r_copy.test3) - self.assertNotEqual(id(r.test3), 
id(r_copy.test3)) - # Verify that the instance type is correct - self.assertTrue(isinstance(r_copy, self.ExampleClass)) - self.assertTrue(isinstance(r_copy, Reducer)) - - -if __name__ == "__main__": - unittest.main() + assert str(exc_info.value) == ( + "Wrong data type: " "for Datum.watermark" + ) + + +def test_datum_value(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + ) + assert mock_message() == d.value + + +def test_datum_key(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + ) + assert TEST_KEYS == d.keys + + +def test_datum_event_time(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=TEST_HEADERS, + ) + assert mock_event_time() == d.event_time + assert TEST_HEADERS == d.headers + + +def test_datum_watermark(): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + ) + assert mock_watermark() == d.watermark + assert {} == d.headers + + +def test_interval_window_start(): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + assert mock_start_time() == i.start + + +def test_interval_window_end(): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + assert mock_end_time() == i.end + + +def test_metadata_interval_window(): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + m = Metadata(interval_window=i) + assert type(i) is type(m.interval_window) + assert i == m.interval_window + + +def test_create_window(): + rw = ReduceWindow(start=mock_start_time(), end=mock_end_time(), slot="slot-0") + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + assert rw.window == i + assert rw.start == mock_start_time() + assert rw.end == mock_end_time() + assert rw.slot == "slot-0" + + +class 
ExampleClass(Reducer): + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + pass + + def __init__(self, test1, test2): + self.test1 = test1 + self.test2 = test2 + self.test3 = self.test1 + + +def test_reducer_init(): + r = ExampleClass(test1=1, test2=2) + assert 1 == r.test1 + assert 2 == r.test2 + assert 1 == r.test3 + + +def test_reducer_deep_copy(): + """Test that the deepcopy works as expected""" + r = ExampleClass(test1=1, test2=2) + # Create a copy of r + r_copy = deepcopy(r) + # Check that the attributes are the same + assert 1 == r_copy.test1 + assert 2 == r_copy.test2 + assert 1 == r_copy.test3 + # Check that the objects are not the same + assert id(r) != id(r_copy) + # Update the attributes of r + r.test1 = 5 + r.test3 = 6 + # Check that the other object is not updated + assert r.test1 != r_copy.test1 + assert r.test3 != r_copy.test3 + assert id(r.test3) != id(r_copy.test3) + # Verify that the instance type is correct + assert isinstance(r_copy, ExampleClass) + assert isinstance(r_copy, Reducer) diff --git a/packages/pynumaflow/tests/reducestreamer/test_messages.py b/packages/pynumaflow/tests/reducestreamer/test_messages.py index 01f7759c..d8609017 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_messages.py +++ b/packages/pynumaflow/tests/reducestreamer/test_messages.py @@ -1,45 +1,42 @@ -import unittest - from pynumaflow.reducestreamer import Message, DROP from tests.testing_utils import mock_message -class TestMessage(unittest.TestCase): - def test_key(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - print(msg) - self.assertEqual(mock_obj["Keys"], msg.keys) - - def test_value(self): - mock_obj = {"Keys": ["test-key"], "Value": mock_message()} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) - self.assertEqual(mock_obj["Value"], msg.value) - - def test_message_to_all(self): - mock_obj = 
{"Keys": [], "Value": mock_message(), "Tags": []} - msg = Message(mock_obj["Value"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to_drop(self): - mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} - msg = Message(b"").to_drop() - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - def test_message_to(self): - mock_obj = {"Keys": ["__KEY__"], "Value": mock_message(), "Tags": ["__TAG__"]} - msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) - self.assertEqual(Message, type(msg)) - self.assertEqual(mock_obj["Keys"], msg.keys) - self.assertEqual(mock_obj["Value"], msg.value) - self.assertEqual(mock_obj["Tags"], msg.tags) - - -if __name__ == "__main__": - unittest.main() +def test_key(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + print(msg) + assert mock_obj["Keys"] == msg.keys + + +def test_value(): + mock_obj = {"Keys": ["test-key"], "Value": mock_message()} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"]) + assert mock_obj["Value"] == msg.value + + +def test_message_to_all(): + mock_obj = {"Keys": [], "Value": mock_message(), "Tags": []} + msg = Message(mock_obj["Value"]) + assert isinstance(msg, Message) + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to_drop(): + mock_obj = {"Keys": [], "Value": b"", "Tags": [DROP]} + msg = Message(b"").to_drop() + assert isinstance(msg, Message) + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags + + +def test_message_to(): + mock_obj = {"Keys": ["__KEY__"], 
"Value": mock_message(), "Tags": ["__TAG__"]} + msg = Message(value=mock_obj["Value"], keys=mock_obj["Keys"], tags=mock_obj["Tags"]) + assert isinstance(msg, Message) + assert mock_obj["Keys"] == msg.keys + assert mock_obj["Value"] == msg.value + assert mock_obj["Tags"] == msg.tags diff --git a/packages/pynumaflow/tests/sideinput/test_responses.py b/packages/pynumaflow/tests/sideinput/test_responses.py index 859f4bb1..6c5329a2 100644 --- a/packages/pynumaflow/tests/sideinput/test_responses.py +++ b/packages/pynumaflow/tests/sideinput/test_responses.py @@ -1,53 +1,39 @@ -import unittest - from pynumaflow.sideinput import Response, SideInput -class TestResponse(unittest.TestCase): - """ - Test the Response class for SideInput - """ - - def test_broadcast_message(self): - """ - Test the broadcast_message method, - where we expect the no_broadcast flag to be False. - """ - succ_response = Response.broadcast_message(b"2") - self.assertFalse(succ_response.no_broadcast) - self.assertEqual(b"2", succ_response.value) - - def test_no_broadcast_message(self): - """ - Test the no_broadcast_message method, - where we expect the no_broadcast flag to be True. - """ - succ_response = Response.no_broadcast_message() - self.assertTrue(succ_response.no_broadcast) - - class ExampleSideInput(SideInput): def retrieve_handler(self) -> Response: return Response.broadcast_message(b"testMessage") -class TestSideInputClass(unittest.TestCase): - def setUp(self) -> None: - # Create a side input class instance - self.side_input_instance = ExampleSideInput() +def test_broadcast_message(): + """ + Test the broadcast_message method, + where we expect the no_broadcast flag to be False. 
+ """ + succ_response = Response.broadcast_message(b"2") + assert not succ_response.no_broadcast + assert b"2" == succ_response.value + - def test_side_input_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.side_input_instance() - self.assertEqual(b"testMessage", ret.value) - # make a call to the handler - ret_handler = self.side_input_instance.retrieve_handler() - # Both responses should be equal - self.assertEqual(ret, ret_handler) +def test_no_broadcast_message(): + """ + Test the no_broadcast_message method, + where we expect the no_broadcast flag to be True. + """ + succ_response = Response.no_broadcast_message() + assert succ_response.no_broadcast -if __name__ == "__main__": - unittest.main() +def test_side_input_class_call(): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + side_input_instance = ExampleSideInput() + # make a call to the class directly + ret = side_input_instance() + assert b"testMessage" == ret.value + # make a call to the handler + ret_handler = side_input_instance.retrieve_handler() + # Both responses should be equal + assert ret == ret_handler diff --git a/packages/pynumaflow/tests/sideinput/test_side_input_server.py b/packages/pynumaflow/tests/sideinput/test_side_input_server.py index aaec4c80..fe110eb3 100644 --- a/packages/pynumaflow/tests/sideinput/test_side_input_server.py +++ b/packages/pynumaflow/tests/sideinput/test_side_input_server.py @@ -1,14 +1,10 @@ -import unittest -from unittest.mock import patch - -import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.proto.sideinput import sideinput_pb2 +from pynumaflow.proto.sideinput import 
sideinput_pb2 from pynumaflow.sideinput import Response, SideInputServer -from tests.testing_utils import mock_terminate_on_stop def retrieve_side_input_handler() -> Response: @@ -29,135 +25,98 @@ def mock_message(): return msg -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestServer(unittest.TestCase): - """ - Test the SideInput grpc server - """ - - def setUp(self) -> None: - server = SideInputServer(retrieve_side_input_handler) - my_service = server.servicer - services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_service} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_init_with_args(self) -> None: - """ - Test the initialization of the SideInput class, - """ - my_servicer = SideInputServer( - side_input_instance=retrieve_side_input_handler, - sock_path="/tmp/test_side_input.sock", - max_message_size=1024 * 1024 * 5, - ) - self.assertEqual(my_servicer.sock_path, "unix:///tmp/test_side_input.sock") - self.assertEqual(my_servicer.max_message_size, 1024 * 1024 * 5) - - def test_side_input_err(self): - """ - Test the error case for the RetrieveSideInput method, - """ - server = SideInputServer(err_retrieve_handler) - my_service = server.servicer - services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_service} - self.test_server = server_from_dictionary(services, strict_real_time()) - - method = self.test_server.invoke_unary_unary( - method_descriptor=( - sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ - "RetrieveSideInput" - ] - ), - invocation_metadata={ - ("this_metadata_will_be_skipped", "test_ignore"), - }, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - 
method_descriptor=( - sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name["IsReady"] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - - response, metadata, code, details = method.termination() - expected = sideinput_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_side_input_message(self): - """ - Test the broadcast_message method, - where we expect the no_broadcast flag to be False and - the message value to be the mock_message. - """ - method = self.test_server.invoke_unary_unary( - method_descriptor=( - sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ - "RetrieveSideInput" - ] - ), - invocation_metadata={ - ("this_metadata_will_be_skipped", "test_ignore"), - }, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - self.assertEqual(mock_message(), response.value) - self.assertEqual(code, StatusCode.OK) - - def test_side_input_no_broadcast(self): - """ - Test the no_broadcast_message method, - where we expect the no_broadcast flag to be True. 
- """ - server = SideInputServer(side_input_instance=retrieve_no_broadcast_handler) - my_servicer = server.servicer - services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - method = self.test_server.invoke_unary_unary( - method_descriptor=( - sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ - "RetrieveSideInput" - ] - ), - invocation_metadata={ - ("this_metadata_will_be_skipped", "test_ignore"), - }, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = method.termination() - self.assertEqual(code, StatusCode.OK) - self.assertEqual(response.no_broadcast, True) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - SideInputServer() - - def test_max_threads(self): - # max cap at 16 - server = SideInputServer(retrieve_side_input_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = SideInputServer(retrieve_side_input_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SideInputServer(retrieve_side_input_handler) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - unittest.main() +@pytest.fixture() +def sideinput_test_server(): + server = SideInputServer(retrieve_side_input_handler) + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: server.servicer} + return server_from_dictionary(services, strict_real_time()) + + +def _invoke_retrieve(test_server, metadata_set=None): + """Helper to invoke RetrieveSideInput unary method.""" + if metadata_set is None: + metadata_set = {("this_metadata_will_be_skipped", "test_ignore")} + return test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name[ + "RetrieveSideInput" + ] + ), + invocation_metadata=metadata_set, + request=_empty_pb2.Empty(), + 
timeout=1, + ) + + +def test_init_with_args(): + my_servicer = SideInputServer( + side_input_instance=retrieve_side_input_handler, + sock_path="/tmp/test_side_input.sock", + max_message_size=1024 * 1024 * 5, + ) + assert my_servicer.sock_path == "unix:///tmp/test_side_input.sock" + assert my_servicer.max_message_size == 1024 * 1024 * 5 + + +def test_side_input_err(): + server = SideInputServer(err_retrieve_handler) + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = _invoke_retrieve(test_server) + response, metadata, code, details = method.termination() + assert code == StatusCode.INTERNAL + + +def test_is_ready(sideinput_test_server): + method = sideinput_test_server.invoke_unary_unary( + method_descriptor=( + sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"].methods_by_name["IsReady"] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + assert response == sideinput_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + +def test_side_input_message(sideinput_test_server): + """Broadcast message: no_broadcast flag is False and value is mock_message.""" + method = _invoke_retrieve(sideinput_test_server) + response, metadata, code, details = method.termination() + assert response.value == mock_message() + assert code == StatusCode.OK + + +def test_side_input_no_broadcast(): + """No-broadcast message: no_broadcast flag is True.""" + server = SideInputServer(side_input_instance=retrieve_no_broadcast_handler) + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: server.servicer} + test_server = server_from_dictionary(services, strict_real_time()) + + method = _invoke_retrieve(test_server) + response, metadata, code, details = method.termination() + assert code == StatusCode.OK + assert response.no_broadcast is True + + +def 
test_invalid_input(): + with pytest.raises(TypeError): + SideInputServer() + + +def test_max_threads(): + # max cap at 16 + server = SideInputServer(retrieve_side_input_handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = SideInputServer(retrieve_side_input_handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = SideInputServer(retrieve_side_input_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/sink/test_async_sink.py b/packages/pynumaflow/tests/sink/test_async_sink.py index c28326af..c0c23b3d 100644 --- a/packages/pynumaflow/tests/sink/test_async_sink.py +++ b/packages/pynumaflow/tests/sink/test_async_sink.py @@ -1,12 +1,11 @@ import asyncio import logging import threading -import unittest from collections.abc import AsyncIterable import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio import Server from pynumaflow import setup_logging from pynumaflow._constants import ( @@ -35,6 +34,8 @@ LOGGER = setup_logging(__name__) +SOCK_PATH = "unix:///tmp/async_sink.sock" + async def udsink_handler(datums: AsyncIterable[Datum]) -> Responses: responses = Responses() @@ -102,275 +103,264 @@ def request_generator(count, req_type="success", session=1, handshake=True): yield sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)) -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/async_sink.sock") -_loop = None - - -def startup_callable(loop): +def _startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -async def start_server(): +async def _start_server(): server = grpc.aio.server() server_instance = SinkAsyncServer(sinker_instance=udsink_handler) uds = server_instance.servicer sink_pb2_grpc.add_SinkServicer_to_server(uds, server) - listen_addr = "unix:///tmp/async_sink.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = 
server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -# We are mocking the terminate function from the psutil to not exit the program during testing -class TestAsyncSink(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_sink.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: - try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: - LOGGER.error(e) - - # - def test_run_server(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_sink.sock") as channel: - stub = sink_pb2_grpc.SinkStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def test_sink(self) -> None: - stub = self.__stub() - generator_response = None - grpc_exception = None - try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="success", session=1) - ) - handshake = next(generator_response) - # assert that handshake response is received. - self.assertTrue(handshake.handshake.sot) - - data_resp = [] - for r in generator_response: - data_resp.append(r) - - # 1 sink data response + 1 EOT response - self.assertEqual(2, len(data_resp)) - - idx = 0 - # capture the output from the SinkFn generator and assert. 
- for resp in data_resp[0].results: - self.assertEqual(resp.id, str(idx)) - self.assertEqual(resp.status, sink_pb2.Status.SUCCESS) - idx += 1 - # EOT Response - self.assertEqual(data_resp[1].status.eot, True) + return server - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - self.assertIsNone(grpc_exception) - - def test_sink_err(self) -> None: - stub = self.__stub() - grpc_exception = None - try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="err") - ) - for _ in generator_response: - pass - except BaseException as e: - self.assertTrue( - f"{ERR_UDF_EXCEPTION_STRING}: ValueError('test_mock_err_message')" in e.__str__() - ) - return - except grpc.RpcError as e: - grpc_exception = e - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - print(e.details()) +@pytest.fixture(scope="module") +def async_sink_server(): + """Module-scoped fixture: starts an async gRPC sink server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() - self.assertIsNotNone(grpc_exception) + future = asyncio.run_coroutine_threadsafe(_start_server(), loop=loop) + future.result(timeout=10) - def test_sink_err_handshake(self) -> None: - stub = self.__stub() - grpc_exception = None + while True: try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="success", handshake=False) - ) - for _ in generator_response: - pass - except BaseException as e: - self.assertTrue("ReadFn: expected handshake message" in e.__str__()) - return - except grpc.RpcError as e: - grpc_exception = e - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - print(e.details()) + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + 
LOGGER.error(e) - self.assertIsNotNone(grpc_exception) + yield loop - def test_sink_fallback(self) -> None: - stub = self.__stub() - try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="fallback", session=1) - ) - handshake = next(generator_response) - # assert that handshake response is received. - self.assertTrue(handshake.handshake.sot) - - data_resp = [] - for r in generator_response: - data_resp.append(r) - - # 1 sink data response + 1 EOT response - self.assertEqual(2, len(data_resp)) - - idx = 0 - # capture the output from the SinkFn generator and assert. - for resp in data_resp[0].results: - self.assertEqual(resp.id, str(idx)) - self.assertEqual(resp.status, sink_pb2.Status.FALLBACK) - idx += 1 - # EOT Response - self.assertEqual(data_resp[1].status.eot, True) + loop.stop() + LOGGER.info("stopped the event loop") - except grpc.RpcError as e: - logging.error(e) - def test_sink_on_success1(self) -> None: - stub = self.__stub() - grpc_exception = None - try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="on_success1", session=1) - ) - handshake = next(generator_response) - # assert that handshake response is received. - self.assertTrue(handshake.handshake.sot) - - data_resp = [] - for r in generator_response: - data_resp.append(r) - - # 1 sink data response + 1 EOT response - self.assertEqual(2, len(data_resp)) - - idx = 0 - # capture the output from the SinkFn generator and assert. 
- for resp in data_resp[0].results: - self.assertEqual(resp.id, str(idx)) - self.assertEqual(resp.status, sink_pb2.Status.ON_SUCCESS) - idx += 1 - # EOT Response - self.assertEqual(data_resp[1].status.eot, True) +@pytest.fixture() +def sink_stub(async_sink_server): + """Returns a SinkStub connected to the running async server.""" + return sink_pb2_grpc.SinkStub(grpc.insecure_channel(SOCK_PATH)) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - self.assertIsNone(grpc_exception) +def test_run_server(async_sink_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = sink_pb2_grpc.SinkStub(channel) - def test_sink_on_success2(self) -> None: - stub = self.__stub() - grpc_exception = None + request = _empty_pb2.Empty() + response = None try: - generator_response = stub.SinkFn( - request_iterator=request_generator(count=10, req_type="on_success2", session=1) - ) - handshake = next(generator_response) - # assert that handshake response is received. - self.assertTrue(handshake.handshake.sot) - - data_resp = [] - for r in generator_response: - data_resp.append(r) - - # 1 sink data response + 1 EOT response - self.assertEqual(2, len(data_resp)) - - idx = 0 - # capture the output from the SinkFn generator and assert. 
- for resp in data_resp[0].results: - self.assertEqual(resp.id, str(idx)) - self.assertEqual(resp.status, sink_pb2.Status.ON_SUCCESS) - idx += 1 - # EOT Response - self.assertEqual(data_resp[1].status.eot, True) - + response = stub.IsReady(request=request) except grpc.RpcError as e: logging.error(e) - grpc_exception = e - - self.assertIsNone(grpc_exception) - - def __stub(self): - return sink_pb2_grpc.SinkStub(_channel) - - def test_invalid_server_type(self) -> None: - with self.assertRaises(TypeError): - SinkAsyncServer() - - @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) - def test_start_fallback_sink(self): - server = SinkAsyncServer(sinker_instance=udsink_handler) - self.assertEqual(server.sock_path, f"unix://{FALLBACK_SINK_SOCK_PATH}") - self.assertEqual(server.server_info_file, FALLBACK_SINK_SERVER_INFO_FILE_PATH) - - @mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_ON_SUCCESS_SINK) - def test_start_on_success_sink(self): - server = SinkAsyncServer(sinker_instance=udsink_handler) - self.assertEqual(server.sock_path, f"unix://{ON_SUCCESS_SINK_SOCK_PATH}") - self.assertEqual(server.server_info_file, ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH) - - def test_max_threads(self): - # max cap at 16 - server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SinkAsyncServer(sinker_instance=udsink_handler) - self.assertEqual(server.max_threads, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + assert response.ready + + +def test_sink(sink_stub): + generator_response = None + grpc_exception = None + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="success", session=1) + ) + handshake = next(generator_response) + # assert that 
handshake response is received. + assert handshake.handshake.sot + + data_resp = [] + for r in generator_response: + data_resp.append(r) + + # 1 sink data response + 1 EOT response + assert len(data_resp) == 2 + + idx = 0 + # capture the output from the SinkFn generator and assert. + for resp in data_resp[0].results: + assert resp.id == str(idx) + assert resp.status == sink_pb2.Status.SUCCESS + idx += 1 + # EOT Response + assert data_resp[1].status.eot is True + + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + + assert grpc_exception is None + + +def test_sink_err(sink_stub): + grpc_exception = None + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="err") + ) + for _ in generator_response: + pass + except BaseException as e: + assert f"{ERR_UDF_EXCEPTION_STRING}: ValueError('test_mock_err_message')" in str(e) + return + except grpc.RpcError as e: + grpc_exception = e + assert e.code() == grpc.StatusCode.UNKNOWN + print(e.details()) + + assert grpc_exception is not None + + +def test_sink_err_handshake(sink_stub): + grpc_exception = None + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="success", handshake=False) + ) + for _ in generator_response: + pass + except BaseException as e: + assert "ReadFn: expected handshake message" in str(e) + return + except grpc.RpcError as e: + grpc_exception = e + assert e.code() == grpc.StatusCode.UNKNOWN + print(e.details()) + + assert grpc_exception is not None + + +def test_sink_fallback(sink_stub): + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="fallback", session=1) + ) + handshake = next(generator_response) + # assert that handshake response is received. 
+ assert handshake.handshake.sot + + data_resp = [] + for r in generator_response: + data_resp.append(r) + + # 1 sink data response + 1 EOT response + assert len(data_resp) == 2 + + idx = 0 + # capture the output from the SinkFn generator and assert. + for resp in data_resp[0].results: + assert resp.id == str(idx) + assert resp.status == sink_pb2.Status.FALLBACK + idx += 1 + # EOT Response + assert data_resp[1].status.eot is True + + except grpc.RpcError as e: + logging.error(e) + + +def test_sink_on_success1(sink_stub): + grpc_exception = None + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="on_success1", session=1) + ) + handshake = next(generator_response) + # assert that handshake response is received. + assert handshake.handshake.sot + + data_resp = [] + for r in generator_response: + data_resp.append(r) + + # 1 sink data response + 1 EOT response + assert len(data_resp) == 2 + + idx = 0 + # capture the output from the SinkFn generator and assert. + for resp in data_resp[0].results: + assert resp.id == str(idx) + assert resp.status == sink_pb2.Status.ON_SUCCESS + idx += 1 + # EOT Response + assert data_resp[1].status.eot is True + + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + + assert grpc_exception is None + + +def test_sink_on_success2(sink_stub): + grpc_exception = None + try: + generator_response = sink_stub.SinkFn( + request_iterator=request_generator(count=10, req_type="on_success2", session=1) + ) + handshake = next(generator_response) + # assert that handshake response is received. + assert handshake.handshake.sot + + data_resp = [] + for r in generator_response: + data_resp.append(r) + + # 1 sink data response + 1 EOT response + assert len(data_resp) == 2 + + idx = 0 + # capture the output from the SinkFn generator and assert. 
+ for resp in data_resp[0].results: + assert resp.id == str(idx) + assert resp.status == sink_pb2.Status.ON_SUCCESS + idx += 1 + # EOT Response + assert data_resp[1].status.eot is True + + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + + assert grpc_exception is None + + +def test_invalid_server_type(): + with pytest.raises(TypeError): + SinkAsyncServer() + + +@mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_FALLBACK_SINK) +def test_start_fallback_sink(): + server = SinkAsyncServer(sinker_instance=udsink_handler) + assert server.sock_path == f"unix://{FALLBACK_SINK_SOCK_PATH}" + assert server.server_info_file == FALLBACK_SINK_SERVER_INFO_FILE_PATH + + +@mockenv(NUMAFLOW_UD_CONTAINER_TYPE=UD_CONTAINER_ON_SUCCESS_SINK) +def test_start_on_success_sink(): + server = SinkAsyncServer(sinker_instance=udsink_handler) + assert server.sock_path == f"unix://{ON_SUCCESS_SINK_SOCK_PATH}" + assert server.server_info_file == ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH + + +def test_max_threads(): + # max cap at 16 + server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = SinkAsyncServer(sinker_instance=udsink_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/sink/test_datatypes.py b/packages/pynumaflow/tests/sink/test_datatypes.py index 07384683..60fe81de 100644 --- a/packages/pynumaflow/tests/sink/test_datatypes.py +++ b/packages/pynumaflow/tests/sink/test_datatypes.py @@ -1,6 +1,6 @@ -import unittest from datetime import datetime, timezone +import pytest from google.protobuf import timestamp_pb2 as _timestamp_pb2 from pynumaflow.sinker._dtypes import ( @@ -28,103 +28,99 @@ def mock_headers(): return headers -class TestDatum(unittest.TestCase): - def test_err_event_time(self): - ts = _timestamp_pb2.Timestamp() - 
ts.GetCurrentTime() - with self.assertRaises(Exception) as context: - Datum( - keys=["test_key"], - sink_msg_id="test_id_0", - value=mock_message(), - event_time=ts, - watermark=ts, - headers=mock_headers(), - ) - self.assertEqual( - "Wrong data type: " - "for Datum.event_time", - str(context.exception), - ) - - def test_err_watermark(self): - ts = _timestamp_pb2.Timestamp() - ts.GetCurrentTime() - with self.assertRaises(Exception) as context: - Datum( - keys=["test_key"], - sink_msg_id="test_id_0", - value=mock_message(), - event_time=mock_event_time(), - watermark=ts, - headers=mock_headers(), - ) - self.assertEqual( - "Wrong data type: " - "for Datum.watermark", - str(context.exception), - ) - - def test_value(self): - d = Datum( - keys=["test_key"], - sink_msg_id="test_id_0", - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - headers=mock_headers(), - ) - self.assertEqual(mock_message(), d.value) - self.assertEqual( - "keys: ['test_key'], " - "id: test_id_0, value: test_mock_message, " - "event_time: 2022-09-12 16:00:00+00:00, watermark: 2022-09-12 16:01:00+00:00, " - "headers: {'key1': 'value1', 'key2': 'value2'}", - str(d), - ) - self.assertEqual( - "keys: ['test_key'], " - "id: test_id_0, value: test_mock_message, " - "event_time: 2022-09-12 16:00:00+00:00, " - "watermark: 2022-09-12 16:01:00+00:00, " - "headers: {'key1': 'value1', 'key2': 'value2'}", - repr(d), - ) - self.assertEqual(mock_headers(), d.headers) - - def test_id(self): - d = Datum( +def test_err_event_time(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as exc_info: + Datum( keys=["test_key"], sink_msg_id="test_id_0", value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), + event_time=ts, + watermark=ts, headers=mock_headers(), ) - self.assertEqual("test_id_0", d.id) + assert ( + "Wrong data type: " + "for Datum.event_time" == str(exc_info.value) + ) - def test_event_time(self): - d = 
Datum( - keys=["test_key"], - sink_msg_id="test_id_0", - value=mock_message(), - event_time=mock_event_time(), - watermark=mock_watermark(), - headers=mock_headers(), - ) - self.assertEqual(mock_event_time(), d.event_time) - def test_watermark(self): - d = Datum( +def test_err_watermark(): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + with pytest.raises(Exception) as exc_info: + Datum( keys=["test_key"], sink_msg_id="test_id_0", value=mock_message(), event_time=mock_event_time(), - watermark=mock_watermark(), + watermark=ts, headers=mock_headers(), ) - self.assertEqual(mock_watermark(), d.watermark) - - -if __name__ == "__main__": - unittest.main() + assert ( + "Wrong data type: " + "for Datum.watermark" == str(exc_info.value) + ) + + +def test_value(): + d = Datum( + keys=["test_key"], + sink_msg_id="test_id_0", + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=mock_headers(), + ) + assert mock_message() == d.value + assert ( + "keys: ['test_key'], " + "id: test_id_0, value: test_mock_message, " + "event_time: 2022-09-12 16:00:00+00:00, watermark: 2022-09-12 16:01:00+00:00, " + "headers: {'key1': 'value1', 'key2': 'value2'}" == str(d) + ) + assert ( + "keys: ['test_key'], " + "id: test_id_0, value: test_mock_message, " + "event_time: 2022-09-12 16:00:00+00:00, " + "watermark: 2022-09-12 16:01:00+00:00, " + "headers: {'key1': 'value1', 'key2': 'value2'}" == repr(d) + ) + assert mock_headers() == d.headers + + +def test_id(): + d = Datum( + keys=["test_key"], + sink_msg_id="test_id_0", + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=mock_headers(), + ) + assert "test_id_0" == d.id + + +def test_event_time(): + d = Datum( + keys=["test_key"], + sink_msg_id="test_id_0", + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=mock_headers(), + ) + assert mock_event_time() == d.event_time + + +def test_watermark(): + d = 
Datum( + keys=["test_key"], + sink_msg_id="test_id_0", + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + headers=mock_headers(), + ) + assert mock_watermark() == d.watermark diff --git a/packages/pynumaflow/tests/sink/test_responses.py b/packages/pynumaflow/tests/sink/test_responses.py index 73a11cf7..54bc5682 100644 --- a/packages/pynumaflow/tests/sink/test_responses.py +++ b/packages/pynumaflow/tests/sink/test_responses.py @@ -1,70 +1,67 @@ -import unittest from collections.abc import Iterator from pynumaflow.sinker import Response, Responses, Sinker, Datum, Message, UserMetadata -class TestResponse(unittest.TestCase): - def test_as_success(self): - succ_response = Response.as_success("2") - self.assertTrue(succ_response.success) - - def test_as_failure(self): - _response = Response.as_failure("3", "RuntimeError encountered!") - self.assertFalse(_response.success) - - def test_as_fallback(self): - _response = Response.as_fallback("4") - self.assertFalse(_response.success) - self.assertTrue(_response.fallback) - - def test_as_on_success(self): - _response = Response.as_on_success("5", Message(b"value", ["key"], UserMetadata())) - self.assertFalse(_response.success) - self.assertFalse(_response.fallback) - self.assertTrue(_response.on_success) - - -class TestResponses(unittest.TestCase): - def setUp(self) -> None: - self.resps = Responses( - Response.as_success("2"), - Response.as_failure("3", "RuntimeError encountered!"), - Response.as_fallback("5"), - ) - - def test_responses(self): - self.resps.append(Response.as_success("4")) - self.resps.append(Response.as_on_success("6", Message(b"value", ["key"], UserMetadata()))) - self.resps.append(Response.as_on_success("7", None)) - self.assertEqual(6, len(self.resps)) - - for resp in self.resps: - self.assertIsInstance(resp, Response) - - self.assertEqual(self.resps[0].id, "2") - self.assertEqual(self.resps[1].id, "3") - self.assertEqual(self.resps[2].id, "5") - 
self.assertEqual(self.resps[3].id, "4") - self.assertEqual(self.resps[4].id, "6") - self.assertEqual(self.resps[5].id, "7") - - self.assertEqual( - "[Response(id='2', success=True, err=None, fallback=False, " - "on_success=False, on_success_msg=None), " - "Response(id='3', success=False, err='RuntimeError encountered!', " - "fallback=False, on_success=False, on_success_msg=None), " - "Response(id='5', success=False, err=None, fallback=True, " - "on_success=False, on_success_msg=None), " - "Response(id='4', success=True, err=None, fallback=False, " - "on_success=False, on_success_msg=None), " - "Response(id='6', success=False, err=None, fallback=False, " - "on_success=True, on_success_msg=Message(_keys=['key'], _value=b'value', " - "_user_metadata=UserMetadata(_data={}))), " - "Response(id='7', success=False, err=None, fallback=False, " - "on_success=True, on_success_msg=None)]", - repr(self.resps), - ) +def test_as_success(): + succ_response = Response.as_success("2") + assert succ_response.success + + +def test_as_failure(): + _response = Response.as_failure("3", "RuntimeError encountered!") + assert not _response.success + + +def test_as_fallback(): + _response = Response.as_fallback("4") + assert not _response.success + assert _response.fallback + + +def test_as_on_success(): + _response = Response.as_on_success("5", Message(b"value", ["key"], UserMetadata())) + assert not _response.success + assert not _response.fallback + assert _response.on_success + + +def test_responses(): + resps = Responses( + Response.as_success("2"), + Response.as_failure("3", "RuntimeError encountered!"), + Response.as_fallback("5"), + ) + resps.append(Response.as_success("4")) + resps.append(Response.as_on_success("6", Message(b"value", ["key"], UserMetadata()))) + resps.append(Response.as_on_success("7", None)) + assert 6 == len(resps) + + for resp in resps: + assert isinstance(resp, Response) + + assert resps[0].id == "2" + assert resps[1].id == "3" + assert resps[2].id == "5" + 
assert resps[3].id == "4" + assert resps[4].id == "6" + assert resps[5].id == "7" + + assert ( + "[Response(id='2', success=True, err=None, fallback=False, " + "on_success=False, on_success_msg=None), " + "Response(id='3', success=False, err='RuntimeError encountered!', " + "fallback=False, on_success=False, on_success_msg=None), " + "Response(id='5', success=False, err=None, fallback=True, " + "on_success=False, on_success_msg=None), " + "Response(id='4', success=True, err=None, fallback=False, " + "on_success=False, on_success_msg=None), " + "Response(id='6', success=False, err=None, fallback=False, " + "on_success=True, on_success_msg=Message(_keys=['key'], _value=b'value', " + "_user_metadata=UserMetadata(_data={}))), " + "Response(id='7', success=False, err=None, fallback=False, " + "on_success=True, on_success_msg=None)]" == repr(resps) + ) class ExampleSinkClass(Sinker): @@ -74,23 +71,15 @@ def handler(self, datums: Iterator[Datum]) -> Responses: return results -class TestSinkClass(unittest.TestCase): - def setUp(self) -> None: - # Create a map class instance - self.sinker_instance = ExampleSinkClass() - - def test_sink_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.sinker_instance(None) - self.assertEqual("test_message", ret[0].id) - # make a call to the handler - ret_handler = self.sinker_instance.handler(None) - # Both responses should be equal - self.assertEqual(ret[0], ret_handler[0]) - - -if __name__ == "__main__": - unittest.main() +def test_sink_class_call(): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + sinker_instance = ExampleSinkClass() + # make a call to the class directly + ret = sinker_instance(None) + assert "test_message" == ret[0].id + # make a call to the handler + 
ret_handler = sinker_instance.handler(None) + # Both responses should be equal + assert ret[0] == ret_handler[0] diff --git a/packages/pynumaflow/tests/source/test_async_source.py b/packages/pynumaflow/tests/source/test_async_source.py index 510a1216..06c371b6 100644 --- a/packages/pynumaflow/tests/source/test_async_source.py +++ b/packages/pynumaflow/tests/source/test_async_source.py @@ -2,11 +2,10 @@ from collections.abc import Iterator import logging import threading -import unittest import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio import Server from pynumaflow import setup_logging from pynumaflow._metadata import _user_and_system_metadata_from_proto @@ -27,10 +26,6 @@ server_port = "unix:///tmp/async_source.sock" -_s: Server = None -_channel = grpc.insecure_channel(server_port) -_loop = None - def startup_callable(loop): asyncio.set_event_loop(loop) @@ -50,8 +45,6 @@ async def start_server(udfs): listen_addr = "unix:///tmp/async_source.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) - global _s - _s = server await server.start() await server.wait_for_termination() @@ -68,184 +61,167 @@ def request_generator(count, request, req_type, send_handshake: bool = True): yield source_pb2.AckRequest(request=request) -class TestAsyncSourcer(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = NewAsyncSourcer() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel(server_port) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: 
+@pytest.fixture(scope="module") +def async_source_server(): + """Module-scoped fixture: starts an async gRPC source server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + udfs = NewAsyncSourcer() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(server_port) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_read_source(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - - request = read_req_source_fn() - try: - generator_response: Iterator[source_pb2.ReadResponse] = stub.ReadFn( - request_iterator=request_generator(1, request, "read") - ) - except grpc.RpcError as e: - logging.error(e) - raise - - counter = 0 - first = True - # capture the output from the ReadFn generator and assert. 
- for r in generator_response: - counter += 1 - if first: - self.assertEqual(True, r.handshake.sot) - first = False - continue - - if r.status.eot: - last = True - continue - - self.assertEqual( - bytes("payload:test_mock_message", encoding="utf-8"), - r.result.payload, - ) - self.assertEqual( - ["test_key"], - r.result.keys, - ) - self.assertEqual( - mock_offset().offset, - r.result.offset.offset, - ) - self.assertEqual( - mock_offset().partition_id, - r.result.offset.partition_id, - ) - - print(r.result) - user_metadata, sys_metadata = _user_and_system_metadata_from_proto( - r.result.metadata - ) - print(user_metadata) - - self.assertCountEqual(user_metadata.groups(), ["custom_info", "test_info"]) - self.assertCountEqual( - user_metadata.keys("custom_info"), ["custom_key", "custom_key2"] - ) - self.assertIsNone(user_metadata.value("custom_info", "test_key")) - self.assertEqual(user_metadata.value("custom_info", "custom_key"), b"custom_value") - self.assertEqual(user_metadata.value("test_info", "test_key"), b"test_value") - - self.assertFalse(first) - self.assertTrue(last) - - # Assert that the generator was called 12 - # (10 data messages + handshake + eot) times in the stream - self.assertEqual(12, counter) - - def test_is_ready(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def test_ack(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = ack_req_source_fn() - try: - response = stub.AckFn(request_iterator=request_generator(1, request, "ack")) - except grpc.RpcError as e: - print(e) - - count = 0 - first = True - for r in response: - count += 1 - if first: - self.assertEqual(True, r.handshake.sot) - first = False - continue - 
self.assertTrue(r.result.success) - - self.assertEqual(count, 2) - self.assertFalse(first) - - def test_nack(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = nack_req_source_fn() - response = stub.NackFn(request=request) - self.assertTrue(response.result.success) - - def test_pending(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = _empty_pb2.Empty() - response = None - try: - response = stub.PendingFn(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertEqual(response.result.count, 10) - - def test_partitions(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = _empty_pb2.Empty() - response = None - try: - response = stub.PartitionsFn(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertEqual(response.result.partitions, mock_partitions()) - - def __stub(self): - return source_pb2_grpc.SourceStub(_channel) - - def test_max_threads(self): - class_instance = AsyncSource() - # max cap at 16 - server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SourceAsyncServer(sourcer_instance=class_instance) - self.assertEqual(server.max_threads, 4) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +def test_read_source(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + + request = read_req_source_fn() + try: + generator_response: Iterator[source_pb2.ReadResponse] = 
stub.ReadFn( + request_iterator=request_generator(1, request, "read") + ) + except grpc.RpcError as e: + logging.error(e) + raise + + counter = 0 + first = True + last = False + # capture the output from the ReadFn generator and assert. + for r in generator_response: + counter += 1 + if first: + assert r.handshake.sot is True + first = False + continue + + if r.status.eot: + last = True + continue + + assert bytes("payload:test_mock_message", encoding="utf-8") == r.result.payload + assert ["test_key"] == r.result.keys + assert mock_offset().offset == r.result.offset.offset + assert mock_offset().partition_id == r.result.offset.partition_id + + print(r.result) + user_metadata, sys_metadata = _user_and_system_metadata_from_proto(r.result.metadata) + print(user_metadata) + + assert sorted(user_metadata.groups()) == sorted(["custom_info", "test_info"]) + assert sorted(user_metadata.keys("custom_info")) == sorted( + ["custom_key", "custom_key2"] + ) + assert user_metadata.value("custom_info", "test_key") is None + assert user_metadata.value("custom_info", "custom_key") == b"custom_value" + assert user_metadata.value("test_info", "test_key") == b"test_value" + + assert not first + assert last + + # Assert that the generator was called 12 + # (10 data messages + handshake + eot) times in the stream + assert 12 == counter + + +def test_is_ready(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) + except grpc.RpcError as e: + logging.error(e) + + assert response.ready + + +def test_ack(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = ack_req_source_fn() + try: + response = stub.AckFn(request_iterator=request_generator(1, request, "ack")) + except grpc.RpcError as e: + print(e) + + count = 0 + first = True + 
for r in response: + count += 1 + if first: + assert r.handshake.sot is True + first = False + continue + assert r.result.success + + assert count == 2 + assert not first + + +def test_nack(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = nack_req_source_fn() + response = stub.NackFn(request=request) + assert response.result.success + + +def test_pending(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = _empty_pb2.Empty() + response = None + try: + response = stub.PendingFn(request=request) + except grpc.RpcError as e: + logging.error(e) + + assert response.result.count == 10 + + +def test_partitions(async_source_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = _empty_pb2.Empty() + response = None + try: + response = stub.PartitionsFn(request=request) + except grpc.RpcError as e: + logging.error(e) + + assert response.result.partitions == mock_partitions() + + +def test_max_threads(): + class_instance = AsyncSource() + # max cap at 16 + server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = SourceAsyncServer(sourcer_instance=class_instance) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/source/test_async_source_err.py b/packages/pynumaflow/tests/source/test_async_source_err.py index 2c8d64d7..88fb4542 100644 --- a/packages/pynumaflow/tests/source/test_async_source_err.py +++ b/packages/pynumaflow/tests/source/test_async_source_err.py @@ -1,12 +1,9 @@ import asyncio import logging import threading -import unittest -from unittest.mock import patch import grpc - 
-from grpc.aio import Server +import pytest from pynumaflow import setup_logging from pynumaflow.proto.sourcer import source_pb2_grpc @@ -20,14 +17,10 @@ AsyncSourceError, nack_req_source_fn, ) -from tests.testing_utils import mock_terminate_on_stop LOGGER = setup_logging(__name__) -_s: Server = None server_port = "unix:///tmp/async_err_source.sock" -_channel = grpc.insecure_channel(server_port) -_loop = None def startup_callable(loop): @@ -44,154 +37,146 @@ async def start_server(): listen_addr = "unix:///tmp/async_err_source.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) - global _s - _s = server await server.start() await server.wait_for_termination() -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncServerErrorScenario(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_err_source.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: +@pytest.fixture(scope="module") +def async_source_err_server(): + """Module-scoped fixture: starts an async gRPC source error server in a background thread.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + thread.start() + + asyncio.run_coroutine_threadsafe(start_server(), loop=loop) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + 
with grpc.insecure_channel(server_port) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_read_error(self) -> None: - grpc_exception = None - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = read_req_source_fn() - generator_response = None - try: - generator_response = stub.ReadFn( - request_iterator=request_generator(1, request, "read") - ) - for _ in generator_response: - pass - except grpc.RpcError as e: - grpc_exception = e - self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) - print(e.details()) - - self.assertIsNotNone(grpc_exception) - - def test_read_handshake_error(self) -> None: - grpc_exception = None - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = read_req_source_fn() - generator_response = None - try: - generator_response = stub.ReadFn( - request_iterator=request_generator(1, request, "read", False) - ) - for _ in generator_response: - pass - except BaseException as e: - self.assertTrue("ReadFn: expected handshake message" in e.__str__()) - return - except grpc.RpcError as e: - grpc_exception = e - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - print(e.details()) - - self.assertIsNotNone(grpc_exception) - self.fail("Expected an exception.") - - def test_ack_error(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = ack_req_source_fn() - try: - resp = stub.AckFn(request_iterator=request_generator(1, request, "ack")) - for _ in resp: - pass - except BaseException as e: - self.assertTrue("Got a runtime error from ack handler." 
in e.__str__()) - return - except grpc.RpcError as e: - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - print(e.details()) - self.fail("Expected an exception.") - - def test_nack_error(self): - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = nack_req_source_fn() - with self.assertRaisesRegex( - grpc.RpcError, "Got a runtime error from nack handler." - ) as resp: - stub.NackFn(request=request) - - self.assertEqual(grpc.StatusCode.INTERNAL, resp.exception.code()) - - def test_ack_no_handshake_error(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = ack_req_source_fn() - try: - resp = stub.AckFn(request_iterator=request_generator(1, request, "ack", False)) - for _ in resp: - pass - except BaseException as e: - self.assertTrue("AckFn: expected handshake message" in e.__str__()) - return - except grpc.RpcError as e: - self.assertEqual(grpc.StatusCode.UNKNOWN, e.code()) - print(e.details()) - self.fail("Expected an exception.") - - def test_pending_error(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = _empty_pb2.Empty() - try: - stub.PendingFn(request=request) - except Exception as e: - self.assertTrue("Got a runtime error from pending handler." in e.__str__()) - return - self.fail("Expected an exception.") - - def test_partition_error(self) -> None: - with grpc.insecure_channel(server_port) as channel: - stub = source_pb2_grpc.SourceStub(channel) - request = _empty_pb2.Empty() - try: - stub.PartitionsFn(request=request) - except Exception as e: - self.assertTrue("Got a runtime error from partition handler." 
in e.__str__()) - return - self.fail("Expected an exception.") - - def test_invalid_server_type(self) -> None: - with self.assertRaises(TypeError): - SourceAsyncServer() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +def test_read_error(async_source_err_server) -> None: + grpc_exception = None + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = read_req_source_fn() + generator_response = None + try: + generator_response = stub.ReadFn(request_iterator=request_generator(1, request, "read")) + for _ in generator_response: + pass + except grpc.RpcError as e: + grpc_exception = e + assert grpc.StatusCode.INTERNAL == e.code() + print(e.details()) + + assert grpc_exception is not None + + +def test_read_handshake_error(async_source_err_server) -> None: + grpc_exception = None + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = read_req_source_fn() + generator_response = None + try: + generator_response = stub.ReadFn( + request_iterator=request_generator(1, request, "read", False) + ) + for _ in generator_response: + pass + except BaseException as e: + assert "ReadFn: expected handshake message" in str(e) + return + except grpc.RpcError as e: + grpc_exception = e + assert grpc.StatusCode.UNKNOWN == e.code() + print(e.details()) + + assert grpc_exception is not None + pytest.fail("Expected an exception.") + + +def test_ack_error(async_source_err_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = ack_req_source_fn() + try: + resp = stub.AckFn(request_iterator=request_generator(1, request, "ack")) + for _ in resp: + pass + except BaseException as e: + assert "Got a runtime error from ack handler." 
in str(e) + return + except grpc.RpcError as e: + assert grpc.StatusCode.UNKNOWN == e.code() + print(e.details()) + pytest.fail("Expected an exception.") + + +def test_nack_error(async_source_err_server): + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = nack_req_source_fn() + with pytest.raises(grpc.RpcError, match="Got a runtime error from nack handler.") as exc: + stub.NackFn(request=request) + + assert grpc.StatusCode.INTERNAL == exc.value.code() + + +def test_ack_no_handshake_error(async_source_err_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = ack_req_source_fn() + try: + resp = stub.AckFn(request_iterator=request_generator(1, request, "ack", False)) + for _ in resp: + pass + except BaseException as e: + assert "AckFn: expected handshake message" in str(e) + return + except grpc.RpcError as e: + assert grpc.StatusCode.UNKNOWN == e.code() + print(e.details()) + pytest.fail("Expected an exception.") + + +def test_pending_error(async_source_err_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = _empty_pb2.Empty() + try: + stub.PendingFn(request=request) + except Exception as e: + assert "Got a runtime error from pending handler." in str(e) + return + pytest.fail("Expected an exception.") + + +def test_partition_error(async_source_err_server) -> None: + with grpc.insecure_channel(server_port) as channel: + stub = source_pb2_grpc.SourceStub(channel) + request = _empty_pb2.Empty() + try: + stub.PartitionsFn(request=request) + except Exception as e: + assert "Got a runtime error from partition handler." 
in str(e) + return + pytest.fail("Expected an exception.") + + +def test_invalid_server_type() -> None: + with pytest.raises(TypeError): + SourceAsyncServer() diff --git a/packages/pynumaflow/tests/source/test_message.py b/packages/pynumaflow/tests/source/test_message.py index d3ca5a14..7739d37f 100644 --- a/packages/pynumaflow/tests/source/test_message.py +++ b/packages/pynumaflow/tests/source/test_message.py @@ -1,4 +1,4 @@ -import unittest +import pytest from pynumaflow.sourcer import ( Message, @@ -10,71 +10,53 @@ from tests.testing_utils import mock_event_time -class TestMessage(unittest.TestCase): - def test_message_creation(self): - payload = b"payload:test_mock_message" - keys = ["test_key"] - offset = mock_offset() - event_time = mock_event_time() - headers = {"key1": "value1", "key2": "value2"} - msg = Message( - payload=payload, offset=offset, keys=keys, event_time=event_time, headers=headers - ) - self.assertEqual(event_time, msg.event_time) - self.assertEqual(payload, msg.payload) - self.assertEqual(keys, msg.keys) - self.assertEqual(offset, msg.offset) - self.assertEqual(headers, msg.headers) - - -class TestOffset(unittest.TestCase): - def test_offset_creation(self): - msg = Offset(offset=mock_offset().offset, partition_id=mock_offset().partition_id) - self.assertEqual(msg.offset, mock_offset().offset) - self.assertEqual(msg.partition_id, mock_offset().partition_id) - - def test_default_offset_creation(self): - msg = Offset.offset_with_default_partition_id(mock_offset().offset) - self.assertEqual(msg.offset, mock_offset().offset) - self.assertEqual(msg.partition_id, 0) - - -class TestDatum(unittest.TestCase): - def test_datum_creation(self): - msg = ReadRequest(num_records=1, timeout_in_ms=1000) - self.assertEqual(msg.num_records, 1) - self.assertEqual(msg.timeout_in_ms, 1000) - - def test_err_num_record(self): - try: - ReadRequest(num_records="HEKKO", timeout_in_ms=1000) - except TypeError as e: - self.assertTrue("Wrong data type" in e.__str__()) - 
return - self.fail("Expected TypeError") - - def test_err_timeout(self): - try: - ReadRequest(num_records=1, timeout_in_ms="1000") - except TypeError as e: - self.assertTrue("Wrong data type" in e.__str__()) - return - self.fail("Expected TypeError") - - -class TestPartition(unittest.TestCase): - def test_partition_response(self): - msg = PartitionsResponse(partitions=[1, 2, 3]) - self.assertEqual(msg.partitions, [1, 2, 3]) - - def test_err_partition(self): - try: - PartitionsResponse(partitions="HEKKO") - except TypeError as e: - self.assertTrue("Wrong data type" in e.__str__()) - return - self.fail("Expected TypeError") - - -if __name__ == "__main__": - unittest.main() +def test_message_creation(): + payload = b"payload:test_mock_message" + keys = ["test_key"] + offset = mock_offset() + event_time = mock_event_time() + headers = {"key1": "value1", "key2": "value2"} + msg = Message(payload=payload, offset=offset, keys=keys, event_time=event_time, headers=headers) + assert event_time == msg.event_time + assert payload == msg.payload + assert keys == msg.keys + assert offset == msg.offset + assert headers == msg.headers + + +def test_offset_creation(): + msg = Offset(offset=mock_offset().offset, partition_id=mock_offset().partition_id) + assert msg.offset == mock_offset().offset + assert msg.partition_id == mock_offset().partition_id + + +def test_default_offset_creation(): + msg = Offset.offset_with_default_partition_id(mock_offset().offset) + assert msg.offset == mock_offset().offset + assert msg.partition_id == 0 + + +def test_datum_creation(): + msg = ReadRequest(num_records=1, timeout_in_ms=1000) + assert msg.num_records == 1 + assert msg.timeout_in_ms == 1000 + + +def test_err_num_record(): + with pytest.raises(TypeError, match="Wrong data type"): + ReadRequest(num_records="HEKKO", timeout_in_ms=1000) + + +def test_err_timeout(): + with pytest.raises(TypeError, match="Wrong data type"): + ReadRequest(num_records=1, timeout_in_ms="1000") + + +def 
test_partition_response(): + msg = PartitionsResponse(partitions=[1, 2, 3]) + assert msg.partitions == [1, 2, 3] + + +def test_err_partition(): + with pytest.raises(TypeError, match="Wrong data type"): + PartitionsResponse(partitions="HEKKO") diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index bcdab288..a6fe7a61 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -1,13 +1,11 @@ import asyncio import logging import threading -import unittest -from unittest.mock import patch -from google.protobuf import timestamp_pb2 as _timestamp_pb2 import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 -from grpc.aio._server import Server +from google.protobuf import timestamp_pb2 as _timestamp_pb2 from pynumaflow import setup_logging from pynumaflow._constants import MAX_MESSAGE_SIZE @@ -17,7 +15,6 @@ from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer from tests.sourcetransform.utils import get_test_datums from tests.testing_utils import ( - mock_terminate_on_stop, mock_new_event_time, ) @@ -26,6 +23,9 @@ # if set to true, transform handler will raise a `ValueError` exception. 
raise_error_from_st = False +SOCK_PATH = "unix:///tmp/async_st.sock" +METADATA_SOCK_PATH = "unix:///tmp/async_st_metadata.sock" + class SimpleAsyncSourceTrn(SourceTransformer): async def handler(self, keys: list[str], datum: Datum) -> Messages: @@ -46,110 +46,63 @@ def request_generator(req): yield from req -_s: Server = None -_channel = grpc.insecure_channel("unix:///tmp/async_st.sock") -_loop = None - - -def startup_callable(loop): +def _startup_callable(loop): asyncio.set_event_loop(loop) loop.run_forever() -def new_async_st(): - handle = SimpleAsyncSourceTrn() - server = SourceTransformAsyncServer(source_transform_instance=handle) - udfs = server.servicer - return udfs - - -async def start_server(udfs): +async def _start_server(udfs): _server_options = [ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), ] server = grpc.aio.server(options=_server_options) transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/async_st.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting server on %s", listen_addr) - global _s - _s = server + server.add_insecure_port(SOCK_PATH) + logging.info("Starting server on %s", SOCK_PATH) await server.start() - await server.wait_for_termination() - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncTransformer(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _loop - loop = asyncio.new_event_loop() - _loop = loop - _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = new_async_st() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_st.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except 
grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_st_server(): + """Module-scoped fixture: starts an async gRPC source transform server.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() + + handle = SimpleAsyncSourceTrn() + server_obj = SourceTransformAsyncServer(source_transform_instance=handle) + udfs = server_obj.servicer + future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) + future.result(timeout=10) + + while True: try: - _loop.stop() - LOGGER.info("stopped the event loop") - except Exception as e: + with grpc.insecure_channel(SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_run_server(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_st.sock") as channel: - stub = transform_pb2_grpc.SourceTransformStub(channel) - request = get_test_datums() - generator_response = None - try: - generator_response = stub.SourceTransformFn( - request_iterator=request_generator(request) - ) - except grpc.RpcError as e: - logging.error(e) - - responses = [] - # capture the output from the ReadFn generator and assert. 
- for r in generator_response: - responses.append(r) - - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - idx = 1 - while idx < len(responses): - _id = "test-id-" + str(idx) - self.assertEqual(_id, responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - - LOGGER.info("Successfully validated the server") - - def test_async_source_transformer(self) -> None: - stub = transform_pb2_grpc.SourceTransformStub(_channel) + yield loop + + loop.stop() + LOGGER.info("stopped the event loop") + + +@pytest.fixture() +def st_stub(async_st_server): + """Returns a SourceTransformStub connected to the running async server.""" + return transform_pb2_grpc.SourceTransformStub(grpc.insecure_channel(SOCK_PATH)) + + +def test_run_server(async_st_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = transform_pb2_grpc.SourceTransformStub(channel) request = get_test_datums() generator_response = None try: @@ -163,109 +116,137 @@ def test_async_source_transformer(self) -> None: responses.append(r) # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) + assert len(responses) == 4 - self.assertTrue(responses[0].handshake.sot) + assert responses[0].handshake.sot idx = 1 while idx < len(responses): _id = "test-id-" + str(idx) - self.assertEqual(_id, responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ), - responses[idx].results[0].value, + assert responses[idx].id == _id + assert responses[idx].results[0].value == bytes( + "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", + encoding="utf-8", ) - self.assertEqual(1, len(responses[idx].results)) + assert len(responses[idx].results) == 1 idx += 1 - # Verify new 
event time gets assigned. - updated_event_time_timestamp = _timestamp_pb2.Timestamp() - updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) - self.assertEqual( - updated_event_time_timestamp, - responses[1].results[0].event_time, + LOGGER.info("Successfully validated the server") + + +def test_async_source_transformer(st_stub): + request = get_test_datums() + generator_response = None + try: + generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) + except grpc.RpcError as e: + logging.error(e) + + responses = [] + # capture the output from the ReadFn generator and assert. + for r in generator_response: + responses.append(r) + + # 1 handshake + 3 data responses + assert len(responses) == 4 + + assert responses[0].handshake.sot + + idx = 1 + while idx < len(responses): + _id = "test-id-" + str(idx) + assert responses[idx].id == _id + assert responses[idx].results[0].value == bytes( + "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", + encoding="utf-8", ) - # self.assertEqual(code, grpc.StatusCode.OK) + assert len(responses[idx].results) == 1 + idx += 1 - def test_async_source_transformer_grpc_error_no_handshake(self) -> None: - stub = transform_pb2_grpc.SourceTransformStub(_channel) - request = get_test_datums(handshake=False) - grpc_exception = None + # Verify new event time gets assigned. + updated_event_time_timestamp = _timestamp_pb2.Timestamp() + updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) + assert responses[1].results[0].event_time == updated_event_time_timestamp - responses = [] - try: - generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. 
- for r in generator_response: - responses.append(r) - except grpc.RpcError as e: - logging.error(e) - grpc_exception = e - self.assertTrue("SourceTransformFn: expected handshake message" in e.__str__()) - self.assertEqual(0, len(responses)) - self.assertIsNotNone(grpc_exception) +def test_async_source_transformer_grpc_error_no_handshake(st_stub): + request = get_test_datums(handshake=False) + grpc_exception = None - def test_async_source_transformer_grpc_error(self) -> None: - stub = transform_pb2_grpc.SourceTransformStub(_channel) - request = get_test_datums() - grpc_exception = None + responses = [] + try: + generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) + # capture the output from the ReadFn generator and assert. + for r in generator_response: + responses.append(r) + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + assert "SourceTransformFn: expected handshake message" in str(e) - responses = [] + assert len(responses) == 0 + assert grpc_exception is not None + + +def test_async_source_transformer_grpc_error(st_stub): + request = get_test_datums() + grpc_exception = None + + responses = [] + try: + global raise_error_from_st + raise_error_from_st = True + generator_response = st_stub.SourceTransformFn(request_iterator=request_generator(request)) + # capture the output from the ReadFn generator and assert. 
+ for r in generator_response: + responses.append(r) + except grpc.RpcError as e: + logging.error(e) + grpc_exception = e + assert e.code() == grpc.StatusCode.INTERNAL + assert "Exception thrown from transform" in str(e) + finally: + raise_error_from_st = False + # 1 handshake + assert len(responses) == 1 + assert grpc_exception is not None + + +def test_is_ready(async_st_server): + with grpc.insecure_channel(SOCK_PATH) as channel: + stub = transform_pb2_grpc.SourceTransformStub(channel) + + request = _empty_pb2.Empty() + response = None try: - global raise_error_from_st - raise_error_from_st = True - generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) - # capture the output from the ReadFn generator and assert. - for r in generator_response: - responses.append(r) + response = stub.IsReady(request=request) except grpc.RpcError as e: logging.error(e) - grpc_exception = e - self.assertEqual(grpc.StatusCode.INTERNAL, e.code()) - self.assertTrue("Exception thrown from transform" in e.__str__()) - finally: - raise_error_from_st = False - # 1 handshake - self.assertEqual(1, len(responses)) - self.assertIsNotNone(grpc_exception) - - def test_is_ready(self) -> None: - with grpc.insecure_channel("unix:///tmp/async_st.sock") as channel: - stub = transform_pb2_grpc.SourceTransformStub(channel) - - request = _empty_pb2.Empty() - response = None - try: - response = stub.IsReady(request=request) - except grpc.RpcError as e: - logging.error(e) - - self.assertTrue(response.ready) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - SourceTransformAsyncServer() - - def __stub(self): - return transform_pb2_grpc.SourceTransformStub(_channel) - - def test_max_threads(self): - handle = SimpleAsyncSourceTrn() - # max cap at 16 - server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=32) - self.assertEqual(server.max_threads, 16) - - # use argument provided - server = 
SourceTransformAsyncServer(source_transform_instance=handle, max_threads=5) - self.assertEqual(server.max_threads, 5) - - # defaults to 4 - server = SourceTransformAsyncServer(source_transform_instance=handle) - self.assertEqual(server.max_threads, 4) + + assert response.ready + + +def test_invalid_input(): + with pytest.raises(TypeError): + SourceTransformAsyncServer() + + +def test_max_threads(): + handle = SimpleAsyncSourceTrn() + # max cap at 16 + server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=32) + assert server.max_threads == 16 + + # use argument provided + server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=5) + assert server.max_threads == 5 + + # defaults to 4 + server = SourceTransformAsyncServer(source_transform_instance=handle) + assert server.max_threads == 4 + + +# --- Metadata test class --- class MetadataAsyncSourceTransformer(SourceTransformer): @@ -290,100 +271,82 @@ async def handler(self, keys: list[str], datum: Datum) -> Messages: return messages -_metadata_s: Server = None -_metadata_channel = grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") -_metadata_loop = None - - -def metadata_startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - -def new_metadata_async_st(): - handle = MetadataAsyncSourceTransformer() - server = SourceTransformAsyncServer(source_transform_instance=handle) - return server.servicer - - -async def start_metadata_server(udfs): +async def _start_metadata_server(udfs): _server_options = [ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), ] server = grpc.aio.server(options=_server_options) transform_pb2_grpc.add_SourceTransformServicer_to_server(udfs, server) - listen_addr = "unix:///tmp/async_st_metadata.sock" - server.add_insecure_port(listen_addr) - logging.info("Starting metadata server on %s", listen_addr) - global _metadata_s - _metadata_s = server + 
server.add_insecure_port(METADATA_SOCK_PATH) + logging.info("Starting metadata server on %s", METADATA_SOCK_PATH) await server.start() - await server.wait_for_termination() - - -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestAsyncTransformerMetadata(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - global _metadata_loop - loop = asyncio.new_event_loop() - _metadata_loop = loop - _thread = threading.Thread(target=metadata_startup_callable, args=(loop,), daemon=True) - _thread.start() - udfs = new_metadata_async_st() - asyncio.run_coroutine_threadsafe(start_metadata_server(udfs), loop=loop) - while True: - try: - with grpc.insecure_channel("unix:///tmp/async_st_metadata.sock") as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - - @classmethod - def tearDownClass(cls) -> None: + return server + + +@pytest.fixture(scope="module") +def async_st_metadata_server(): + """Module-scoped fixture: starts an async gRPC metadata source transform server.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) + thread.start() + + handle = MetadataAsyncSourceTransformer() + server_obj = SourceTransformAsyncServer(source_transform_instance=handle) + udfs = server_obj.servicer + future = asyncio.run_coroutine_threadsafe(_start_metadata_server(udfs), loop=loop) + future.result(timeout=10) + + while True: try: - _metadata_loop.stop() - LOGGER.info("stopped the metadata event loop") - except Exception as e: + with grpc.insecure_channel(METADATA_SOCK_PATH) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") LOGGER.error(e) - def test_source_transformer_with_metadata(self) -> None: - stub = 
transform_pb2_grpc.SourceTransformStub(_metadata_channel) - request = get_test_datums(with_metadata=True) - generator_response = None - try: - generator_response = stub.SourceTransformFn(request_iterator=request_generator(request)) - except grpc.RpcError as e: - logging.error(e) - raise + yield loop - responses = [] - for r in generator_response: - responses.append(r) + loop.stop() + LOGGER.info("stopped the metadata event loop") - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - self.assertTrue(responses[0].handshake.sot) - # Verify metadata is passed through correctly - for idx, resp in enumerate(responses[1:], 1): - _id = "test-id-" + str(idx) - self.assertEqual(_id, resp.id) - self.assertEqual(1, len(resp.results)) - # Verify user metadata is returned - self.assertEqual( - resp.results[0].metadata.user_metadata["custom_info"], - metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), - ) - # System metadata should be empty in responses (user cannot set it) - self.assertEqual(resp.results[0].metadata.sys_metadata, {}) +@pytest.fixture() +def metadata_stub(async_st_metadata_server): + """Returns a SourceTransformStub connected to the metadata server.""" + return transform_pb2_grpc.SourceTransformStub(grpc.insecure_channel(METADATA_SOCK_PATH)) -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() +def test_source_transformer_with_metadata(metadata_stub): + request = get_test_datums(with_metadata=True) + generator_response = None + try: + generator_response = metadata_stub.SourceTransformFn( + request_iterator=request_generator(request) + ) + except grpc.RpcError as e: + logging.error(e) + raise + + responses = [] + for r in generator_response: + responses.append(r) + + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + # Verify metadata is passed through correctly + for idx, resp in enumerate(responses[1:], 1): + _id = "test-id-" + 
str(idx) + assert resp.id == _id + assert len(resp.results) == 1 + # Verify user metadata is returned + assert resp.results[0].metadata.user_metadata["custom_info"] == metadata_pb2.KeyValueGroup( + key_value={"version": f"{idx}.0.0".encode()} + ) + # System metadata should be empty in responses (user cannot set it) + assert resp.results[0].metadata.sys_metadata == {} diff --git a/packages/pynumaflow/tests/sourcetransform/test_messages.py b/packages/pynumaflow/tests/sourcetransform/test_messages.py index eb8124c5..90dc03da 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_messages.py +++ b/packages/pynumaflow/tests/sourcetransform/test_messages.py @@ -1,4 +1,4 @@ -import unittest +import pytest from datetime import datetime, timezone from pynumaflow.sourcetransformer import ( @@ -23,141 +23,155 @@ def mock_event_time(): return t -class TestMessage(unittest.TestCase): - def test_Message_creation(self): - mock_obj = { - "Keys": ["test_key"], +def _mock_message_object(): + value = mock_message_t() + event_time = mock_event_time() + return Message(value=value, event_time=event_time) + + +# --- TestMessage --- + + +def test_message_creation(): + mock_obj = { + "Keys": ["test_key"], + "Value": mock_message_t(), + "EventTime": mock_event_time(), + "Tags": ["test_tag"], + } + msgt = Message( + mock_obj["Value"], mock_obj["EventTime"], keys=mock_obj["Keys"], tags=mock_obj["Tags"] + ) + assert mock_obj["EventTime"] == msgt.event_time + assert mock_obj["Value"] == msgt.value + assert mock_obj["Keys"] == msgt.keys + assert mock_obj["Tags"] == msgt.tags + + +def test_message_to_drop(): + mock_obj = { + "Keys": [], + "Value": b"", + "Tags": [DROP], + "EventTime": mock_event_time(), + } + msgt = Message(b"", datetime(1, 1, 1, 0, 0)).to_drop(mock_event_time()) + assert isinstance(msgt, Message) + assert mock_obj["Keys"] == msgt.keys + assert mock_obj["Value"] == msgt.value + assert mock_obj["Tags"] == msgt.tags + assert mock_obj["EventTime"] == msgt.event_time + + +def 
test_message_with_user_metadata(): + user_meta = UserMetadata() + user_meta.add_key("group1", "key1", b"value1") + user_meta.add_key("group1", "key2", b"value2") + + msgt = Message( + mock_message_t(), + mock_event_time(), + keys=["test_key"], + user_metadata=user_meta, + ) + assert mock_message_t() == msgt.value + assert ["test_key"] == msgt.keys + assert b"value1" == msgt.user_metadata.value("group1", "key1") + assert b"value2" == msgt.user_metadata.value("group1", "key2") + assert ["group1"] == msgt.user_metadata.groups() + + +def test_message_default_user_metadata(): + msgt = Message(mock_message_t(), mock_event_time()) + assert msgt.user_metadata is not None + assert 0 == len(msgt.user_metadata) + + +# --- TestMessages --- + + +def test_messages_items(): + mock_obj = [ + { + "Keys": [b"U+005C__ALL__"], "Value": mock_message_t(), "EventTime": mock_event_time(), - "Tags": ["test_tag"], - } - msgt = Message( - mock_obj["Value"], mock_obj["EventTime"], keys=mock_obj["Keys"], tags=mock_obj["Tags"] - ) - self.assertEqual(mock_obj["EventTime"], msgt.event_time) - self.assertEqual(mock_obj["Value"], msgt.value) - self.assertEqual(mock_obj["Keys"], msgt.keys) - self.assertEqual(mock_obj["Tags"], msgt.tags) - - def test_message_to_drop(self): - mock_obj = { - "Keys": [], - "Value": b"", - "Tags": [DROP], + }, + { + "Keys": [b"U+005C__ALL__"], + "Value": mock_message_t(), "EventTime": mock_event_time(), - } - msgt = Message(b"", datetime(1, 1, 1, 0, 0)).to_drop(mock_event_time()) - self.assertEqual(Message, type(msgt)) - self.assertEqual(mock_obj["Keys"], msgt.keys) - self.assertEqual(mock_obj["Value"], msgt.value) - self.assertEqual(mock_obj["Tags"], msgt.tags) - self.assertEqual(mock_obj["EventTime"], msgt.event_time) - - def test_message_with_user_metadata(self): - user_meta = UserMetadata() - user_meta.add_key("group1", "key1", b"value1") - user_meta.add_key("group1", "key2", b"value2") - - msgt = Message( - mock_message_t(), - mock_event_time(), - keys=["test_key"], 
- user_metadata=user_meta, - ) - self.assertEqual(mock_message_t(), msgt.value) - self.assertEqual(["test_key"], msgt.keys) - self.assertEqual(b"value1", msgt.user_metadata.value("group1", "key1")) - self.assertEqual(b"value2", msgt.user_metadata.value("group1", "key2")) - self.assertEqual(["group1"], msgt.user_metadata.groups()) - - def test_message_default_user_metadata(self): - msgt = Message(mock_message_t(), mock_event_time()) - self.assertIsNotNone(msgt.user_metadata) - self.assertEqual(0, len(msgt.user_metadata)) - - -class TestMessages(unittest.TestCase): - @staticmethod - def mock_Message_object(): - value = mock_message_t() - event_time = mock_event_time() - return Message(value=value, event_time=event_time) - - def test_items(self): - mock_obj = [ - { - "Keys": [b"U+005C__ALL__"], - "Value": mock_message_t(), - "EventTime": mock_event_time(), - }, - { - "Keys": [b"U+005C__ALL__"], - "Value": mock_message_t(), - "EventTime": mock_event_time(), - }, - ] - msgts = Messages(*mock_obj) - self.assertEqual(len(mock_obj), len(msgts)) - self.assertEqual(len(mock_obj), len(msgts.items())) - self.assertEqual(mock_obj[0]["Keys"], msgts[0]["Keys"]) - self.assertEqual(mock_obj[0]["Value"], msgts[0]["Value"]) - self.assertEqual(mock_obj[0]["EventTime"], msgts[0]["EventTime"]) - self.assertEqual( - "[{'Keys': [b'U+005C__ALL__'], 'Value': b'test_mock_message_t', " - "'EventTime': datetime.datetime(2022, 9, 12, 16, 0, tzinfo=datetime.timezone.utc)}, " - "{'Keys': [b'U+005C__ALL__'], 'Value': b'test_mock_message_t', " - "'EventTime': datetime.datetime(2022, 9, 12, 16, 0, tzinfo=datetime.timezone.utc)}]", - repr(msgts), - ) - - def test_append(self): - msgts = Messages() - self.assertEqual(0, len(msgts)) - msgts.append(self.mock_Message_object()) - self.assertEqual(1, len(msgts)) - msgts.append(self.mock_Message_object()) - self.assertEqual(2, len(msgts)) - - def test_err(self): - msgts = Messages(self.mock_Message_object(), self.mock_Message_object()) - with 
self.assertRaises(TypeError): - msgts[:1] - - -class TestDatum(unittest.TestCase): - def test_datum_with_metadata(self): - user_meta = UserMetadata() - user_meta.add_key("group1", "key1", b"value1") - - sys_meta = SystemMetadata({"sys_group": {"sys_key": b"sys_value"}}) - - d = Datum( - keys=["test_key"], - value=mock_message_t(), - event_time=mock_event_time(), - watermark=mock_event_time(), - headers={"header1": "value1"}, - user_metadata=user_meta, - system_metadata=sys_meta, - ) - self.assertEqual(["test_key"], d.keys) - self.assertEqual(mock_message_t(), d.value) - self.assertEqual(mock_event_time(), d.event_time) - self.assertEqual({"header1": "value1"}, d.headers) - self.assertEqual(b"value1", d.user_metadata.value("group1", "key1")) - self.assertEqual(b"sys_value", d.system_metadata.value("sys_group", "sys_key")) - - def test_datum_default_metadata(self): - d = Datum( - keys=["test_key"], - value=mock_message_t(), - event_time=mock_event_time(), - watermark=mock_event_time(), - ) - self.assertIsNotNone(d.user_metadata) - self.assertIsNotNone(d.system_metadata) - self.assertEqual(0, len(d.user_metadata)) - self.assertEqual([], d.system_metadata.groups()) + }, + ] + msgts = Messages(*mock_obj) + assert len(mock_obj) == len(msgts) + assert len(mock_obj) == len(msgts.items()) + assert mock_obj[0]["Keys"] == msgts[0]["Keys"] + assert mock_obj[0]["Value"] == msgts[0]["Value"] + assert mock_obj[0]["EventTime"] == msgts[0]["EventTime"] + assert ( + "[{'Keys': [b'U+005C__ALL__'], 'Value': b'test_mock_message_t', " + "'EventTime': datetime.datetime(2022, 9, 12, 16, 0, tzinfo=datetime.timezone.utc)}, " + "{'Keys': [b'U+005C__ALL__'], 'Value': b'test_mock_message_t', " + "'EventTime': datetime.datetime(2022, 9, 12, 16, 0, tzinfo=datetime.timezone.utc)}]" + ) == repr(msgts) + + +def test_messages_append(): + msgts = Messages() + assert 0 == len(msgts) + msgts.append(_mock_message_object()) + assert 1 == len(msgts) + msgts.append(_mock_message_object()) + assert 2 == 
len(msgts) + + +def test_messages_err(): + msgts = Messages(_mock_message_object(), _mock_message_object()) + with pytest.raises(TypeError): + msgts[:1] + + +# --- TestDatum --- + + +def test_datum_with_metadata(): + user_meta = UserMetadata() + user_meta.add_key("group1", "key1", b"value1") + + sys_meta = SystemMetadata({"sys_group": {"sys_key": b"sys_value"}}) + + d = Datum( + keys=["test_key"], + value=mock_message_t(), + event_time=mock_event_time(), + watermark=mock_event_time(), + headers={"header1": "value1"}, + user_metadata=user_meta, + system_metadata=sys_meta, + ) + assert ["test_key"] == d.keys + assert mock_message_t() == d.value + assert mock_event_time() == d.event_time + assert {"header1": "value1"} == d.headers + assert b"value1" == d.user_metadata.value("group1", "key1") + assert b"sys_value" == d.system_metadata.value("sys_group", "sys_key") + + +def test_datum_default_metadata(): + d = Datum( + keys=["test_key"], + value=mock_message_t(), + event_time=mock_event_time(), + watermark=mock_event_time(), + ) + assert d.user_metadata is not None + assert d.system_metadata is not None + assert 0 == len(d.user_metadata) + assert [] == d.system_metadata.groups() + + +# --- TestSourceTransformClass --- class ExampleSourceTransformClass(SourceTransformer): @@ -167,23 +181,15 @@ def handler(self, keys: list[str], datum: Datum) -> Messages: return messages -class TestSourceTransformClass(unittest.TestCase): - def setUp(self) -> None: - # Create a map class instance - self.transform_instance = ExampleSourceTransformClass() - - def test_source_transform_class_call(self): - """Test that the __call__ functionality for the class works, - ie the class instance can be called directly to invoke the handler function - """ - # make a call to the class directly - ret = self.transform_instance([], None) - self.assertEqual(mock_message_t(), ret[0].value) - # make a call to the handler - ret_handler = self.transform_instance.handler([], None) - # Both responses should be 
equal - self.assertEqual(ret[0], ret_handler[0]) - - -if __name__ == "__main__": - unittest.main() +def test_source_transform_class_call(): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + transform_instance = ExampleSourceTransformClass() + # make a call to the class directly + ret = transform_instance([], None) + assert mock_message_t() == ret[0].value + # make a call to the handler + ret_handler = transform_instance.handler([], None) + # Both responses should be equal + assert ret[0] == ret_handler[0] diff --git a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py index 4136f9d3..4d56cd31 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py +++ b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py @@ -1,8 +1,6 @@ import os -import unittest -from unittest.mock import patch -import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2 from grpc import StatusCode @@ -12,172 +10,145 @@ from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums from tests.conftest import collect_responses, drain_responses, send_test_requests -from tests.testing_utils import ( - mock_new_event_time, - mock_terminate_on_stop, -) - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestMultiProcMethods(unittest.TestCase): - def setUp(self) -> None: - server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = 
server_from_dictionary(services, strict_real_time()) - - def test_multiproc_init(self) -> None: - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, server_count=3 - ) - self.assertEqual(server._process_count, 3) +from tests.testing_utils import mock_new_event_time - def test_multiproc_process_count(self) -> None: - default_value = os.cpu_count() - server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) - self.assertEqual(server._process_count, default_value) - def test_max_process_count(self) -> None: - default_value = os.cpu_count() - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, server_count=50 - ) - self.assertEqual(server._process_count, 2 * default_value) - - def test_udf_mapt_err_handshake(self): - server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=False) - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) +def _make_multiproc_server(handler): + server = SourceTransformMultiProcServer(source_transform_instance=handler) + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: server.servicer} + return server_from_dictionary(services, strict_real_time()) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("SourceTransformFn: expected handshake message" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_udf_mapt_err(self): - server = 
SourceTransformMultiProcServer(source_transform_instance=err_transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums() - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("Something is fishy" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "IsReady" - ] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) +def _invoke_transform_fn(test_server, timeout=1): + """Helper to invoke the SourceTransformFn stream method.""" + return test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=timeout, + ) - response, metadata, code, details = method.termination() - expected = transform_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_mapt_assign_new_event_time(self): - test_datums = get_test_datums() - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) - send_test_requests(method, test_datums) - responses = collect_responses(method) - - metadata, code, details = 
method.termination() - - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - result_ids = {f"test-id-{id}" for id in range(1, 4)} - idx = 1 - while idx < len(responses): - result_ids.remove(responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - self.assertEqual(len(result_ids), 0) - - # Verify new event time gets assigned. - updated_event_time_timestamp = _timestamp_pb2.Timestamp() - updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) - self.assertEqual( - updated_event_time_timestamp, - responses[1].results[0].event_time, - ) - self.assertEqual(code, StatusCode.OK) +@pytest.fixture() +def multiproc_test_server(): + return _make_multiproc_server(transform_handler) - def test_invalid_input(self): - with self.assertRaises(TypeError): - SourceTransformMultiProcServer() - def test_max_threads(self): - # max cap at 16 - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, max_threads=32 - ) - self.assertEqual(server.max_threads, 16) +def test_multiproc_init(): + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, server_count=3 + ) + assert server._process_count == 3 - # use argument provided - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, max_threads=5 - ) - self.assertEqual(server.max_threads, 5) - # defaults to 4 - server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) - self.assertEqual(server.max_threads, 4) +def test_multiproc_process_count(): + default_value = os.cpu_count() + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + assert server._process_count == default_value + + +def test_max_process_count(): + 
default_value = os.cpu_count() + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, server_count=50 + ) + assert server._process_count == 2 * default_value + + +def test_udf_mapt_err_handshake(): + test_server = _make_multiproc_server(err_transform_handler) + test_datums = get_test_datums(handshake=False) + method = _invoke_transform_fn(test_server) + + send_test_requests(method, test_datums) + drain_responses(method) + metadata, code, details = method.termination() + assert "SourceTransformFn: expected handshake message" in details + assert code == StatusCode.INTERNAL -if __name__ == "__main__": - unittest.main() + +def test_udf_mapt_err(): + test_server = _make_multiproc_server(err_transform_handler) + test_datums = get_test_datums() + method = _invoke_transform_fn(test_server) + + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "Something is fishy" in details + assert code == StatusCode.INTERNAL + + +def test_is_ready(multiproc_test_server): + method = multiproc_test_server.invoke_unary_unary( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name["IsReady"] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + assert response == transform_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + +def test_mapt_assign_new_event_time(multiproc_test_server): + test_datums = get_test_datums() + method = _invoke_transform_fn(multiproc_test_server) + + send_test_requests(method, test_datums) + responses = collect_responses(method) + + metadata, code, details = method.termination() + + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + result_ids = {f"test-id-{id}" for id in range(1, 4)} + idx = 1 + while idx < len(responses): + result_ids.remove(responses[idx].id) + 
assert responses[idx].results[0].value == bytes( + "payload:test_mock_message event_time:2022-09-12 16:00:00 ", + encoding="utf-8", + ) + assert len(responses[idx].results) == 1 + idx += 1 + assert len(result_ids) == 0 + + # Verify new event time gets assigned. + updated_event_time_timestamp = _timestamp_pb2.Timestamp() + updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) + assert responses[1].results[0].event_time == updated_event_time_timestamp + assert code == StatusCode.OK + + +def test_invalid_input(): + with pytest.raises(TypeError): + SourceTransformMultiProcServer() + + +def test_max_threads(): + # max cap at 16 + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, max_threads=32 + ) + assert server.max_threads == 16 + + # use argument provided + server = SourceTransformMultiProcServer( + source_transform_instance=transform_handler, max_threads=5 + ) + assert server.max_threads == 5 + + # defaults to 4 + server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) + assert server.max_threads == 4 diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py index 1b6b4c89..c01b1efe 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py @@ -1,7 +1,4 @@ -import unittest -from unittest.mock import patch - -import grpc +import pytest from google.protobuf import empty_pb2 as _empty_pb2 from google.protobuf import timestamp_pb2 as _timestamp_pb2 from grpc import StatusCode @@ -12,159 +9,143 @@ from pynumaflow.sourcetransformer import SourceTransformServer, Datum, Messages, Message from tests.sourcetransform.utils import transform_handler, err_transform_handler, get_test_datums from tests.conftest import collect_responses, drain_responses, send_test_requests -from tests.testing_utils import ( - mock_terminate_on_stop, - 
mock_new_event_time, -) - - -# We are mocking the terminate function from the psutil to not exit the program during testing -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestServer(unittest.TestCase): - def setUp(self) -> None: - server = SourceTransformServer(source_transform_instance=transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_init_with_args(self) -> None: - server = SourceTransformServer( - source_transform_instance=transform_handler, - sock_path="/tmp/test.sock", - max_message_size=1024 * 1024 * 5, - ) - self.assertEqual(server.sock_path, "unix:///tmp/test.sock") - self.assertEqual(server.max_message_size, 1024 * 1024 * 5) - - def test_udf_mapt_err(self): - server = SourceTransformServer(source_transform_instance=err_transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums() - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) +from tests.testing_utils import mock_new_event_time - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - self.assertTrue("Something is fishy" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) - - def test_is_ready(self): - method = self.test_server.invoke_unary_unary( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "IsReady" - ] - ), - invocation_metadata={}, - request=_empty_pb2.Empty(), - timeout=1, - ) - response, metadata, code, details = 
method.termination() - expected = transform_pb2.ReadyResponse(ready=True) - self.assertEqual(expected, response) - self.assertEqual(code, StatusCode.OK) - - def test_udf_mapt_err_handshake(self): - server = SourceTransformServer(source_transform_instance=err_transform_handler) - my_servicer = server.servicer - services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=False) - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) +def _make_transform_server(handler): + server = SourceTransformServer(source_transform_instance=handler) + services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: server.servicer} + return server_from_dictionary(services, strict_real_time()) - send_test_requests(method, test_datums) - drain_responses(method) - metadata, code, details = method.termination() - self.assertTrue("SourceTransformFn: expected handshake message" in details) - self.assertEqual(grpc.StatusCode.INTERNAL, code) +def _invoke_transform_fn(test_server, timeout=1): + """Helper to invoke the SourceTransformFn stream method.""" + return test_server.invoke_stream_stream( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ + "SourceTransformFn" + ] + ), + invocation_metadata={}, + timeout=timeout, + ) - def test_mapt_assign_new_event_time(self): - test_datums = get_test_datums() - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) +@pytest.fixture() +def transform_test_server(): + return _make_transform_server(transform_handler) 
+ + +# --------------------------------------------------------------------------- +# TestServer tests +# --------------------------------------------------------------------------- + + +def test_init_with_args(): + server = SourceTransformServer( + source_transform_instance=transform_handler, + sock_path="/tmp/test.sock", + max_message_size=1024 * 1024 * 5, + ) + assert server.sock_path == "unix:///tmp/test.sock" + assert server.max_message_size == 1024 * 1024 * 5 + + +def test_udf_mapt_err(): + test_server = _make_transform_server(err_transform_handler) + test_datums = get_test_datums() + method = _invoke_transform_fn(test_server) + + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "Something is fishy" in details + assert code == StatusCode.INTERNAL + + +def test_is_ready(transform_test_server): + method = transform_test_server.invoke_unary_unary( + method_descriptor=( + transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name["IsReady"] + ), + invocation_metadata={}, + request=_empty_pb2.Empty(), + timeout=1, + ) + + response, metadata, code, details = method.termination() + assert response == transform_pb2.ReadyResponse(ready=True) + assert code == StatusCode.OK + + +def test_udf_mapt_err_handshake(): + test_server = _make_transform_server(err_transform_handler) + test_datums = get_test_datums(handshake=False) + method = _invoke_transform_fn(test_server) + + send_test_requests(method, test_datums) + drain_responses(method) + + metadata, code, details = method.termination() + assert "SourceTransformFn: expected handshake message" in details + assert code == StatusCode.INTERNAL + + +def test_mapt_assign_new_event_time(transform_test_server): + test_datums = get_test_datums() + method = _invoke_transform_fn(transform_test_server) + + send_test_requests(method, test_datums) + responses = collect_responses(method) - send_test_requests(method, test_datums) - responses = 
collect_responses(method) - - metadata, code, details = method.termination() - - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - - self.assertTrue(responses[0].handshake.sot) - - result_ids = {f"test-id-{id}" for id in range(1, 4)} - idx = 1 - while idx < len(responses): - result_ids.remove(responses[idx].id) - self.assertEqual( - bytes( - "payload:test_mock_message " "event_time:2022-09-12 16:00:00 ", - encoding="utf-8", - ), - responses[idx].results[0].value, - ) - self.assertEqual(1, len(responses[idx].results)) - idx += 1 - self.assertEqual(len(result_ids), 0) - - # Verify new event time gets assigned. - updated_event_time_timestamp = _timestamp_pb2.Timestamp() - updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) - self.assertEqual( - updated_event_time_timestamp, - responses[1].results[0].event_time, + metadata, code, details = method.termination() + + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot + + result_ids = {f"test-id-{id}" for id in range(1, 4)} + idx = 1 + while idx < len(responses): + result_ids.remove(responses[idx].id) + assert responses[idx].results[0].value == bytes( + "payload:test_mock_message event_time:2022-09-12 16:00:00 ", + encoding="utf-8", ) - self.assertEqual(code, StatusCode.OK) + assert len(responses[idx].results) == 1 + idx += 1 + assert len(result_ids) == 0 + + # Verify new event time gets assigned. 
+ updated_event_time_timestamp = _timestamp_pb2.Timestamp() + updated_event_time_timestamp.FromDatetime(dt=mock_new_event_time()) + assert responses[1].results[0].event_time == updated_event_time_timestamp + assert code == StatusCode.OK + + +def test_invalid_input(): + with pytest.raises(TypeError): + SourceTransformServer() - def test_invalid_input(self): - with self.assertRaises(TypeError): - SourceTransformServer() - def test_max_threads(self): - # max cap at 16 - server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=32) - self.assertEqual(server.max_threads, 16) +def test_max_threads(): + # max cap at 16 + server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=32) + assert server.max_threads == 16 - # use argument provided - server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=5) - self.assertEqual(server.max_threads, 5) + # use argument provided + server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=5) + assert server.max_threads == 5 - # defaults to 4 - server = SourceTransformServer(source_transform_instance=transform_handler) - self.assertEqual(server.max_threads, 4) + # defaults to 4 + server = SourceTransformServer(source_transform_instance=transform_handler) + assert server.max_threads == 4 + + +# --------------------------------------------------------------------------- +# Metadata tests +# --------------------------------------------------------------------------- def metadata_transform_handler(keys: list[str], datum: Datum) -> Messages: @@ -185,53 +166,36 @@ def metadata_transform_handler(keys: list[str], datum: Datum) -> Messages: return messages -@patch("psutil.Process.kill", mock_terminate_on_stop) -class TestServerMetadata(unittest.TestCase): - def setUp(self) -> None: - server = SourceTransformServer(source_transform_instance=metadata_transform_handler) - my_servicer = server.servicer - services = 
{transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer} - self.test_server = server_from_dictionary(services, strict_real_time()) - - def test_source_transform_with_metadata(self): - test_datums = get_test_datums(with_metadata=True) - - method = self.test_server.invoke_stream_stream( - method_descriptor=( - transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"].methods_by_name[ - "SourceTransformFn" - ] - ), - invocation_metadata={}, - timeout=1, - ) +@pytest.fixture() +def metadata_test_server(): + return _make_transform_server(metadata_transform_handler) - send_test_requests(method, test_datums) - responses = collect_responses(method) - metadata, code, details = method.termination() +def test_source_transform_with_metadata(metadata_test_server): + test_datums = get_test_datums(with_metadata=True) + method = _invoke_transform_fn(metadata_test_server) - # 1 handshake + 3 data responses - self.assertEqual(4, len(responses)) - self.assertTrue(responses[0].handshake.sot) + send_test_requests(method, test_datums) + responses = collect_responses(method) - # Verify metadata is passed through correctly - result_metadata = {} - for resp in responses[1:]: - result_metadata[resp.id] = resp.results[0].metadata + metadata, code, details = method.termination() - for idx in range(1, 4): - _id = f"test-id-{idx}" - self.assertIn(_id, result_metadata) - self.assertEqual( - result_metadata[_id].user_metadata["custom_info"], - metadata_pb2.KeyValueGroup(key_value={"version": f"{idx}.0.0".encode()}), - ) - # System metadata should be empty in responses - self.assertEqual(result_metadata[_id].sys_metadata, {}) + # 1 handshake + 3 data responses + assert len(responses) == 4 + assert responses[0].handshake.sot - self.assertEqual(code, StatusCode.OK) + # Verify metadata is passed through correctly + result_metadata = {} + for resp in responses[1:]: + result_metadata[resp.id] = resp.results[0].metadata + for idx in range(1, 4): + _id = f"test-id-{idx}" + assert 
_id in result_metadata + assert result_metadata[_id].user_metadata["custom_info"] == metadata_pb2.KeyValueGroup( + key_value={"version": f"{idx}.0.0".encode()} + ) + # System metadata should be empty in responses + assert result_metadata[_id].sys_metadata == {} -if __name__ == "__main__": - unittest.main() + assert code == StatusCode.OK diff --git a/packages/pynumaflow/tests/test_info_server.py b/packages/pynumaflow/tests/test_info_server.py index d05b78ed..9490f991 100644 --- a/packages/pynumaflow/tests/test_info_server.py +++ b/packages/pynumaflow/tests/test_info_server.py @@ -1,5 +1,5 @@ import os -import unittest +import pytest from unittest import mock from tests.testing_utils import read_info_server @@ -15,45 +15,47 @@ ) -def mockenv(**envvars): - return mock.patch.dict(os.environ, envvars) - - -class TestInfoServer(unittest.TestCase): - @mockenv(NUMAFLOW_CPU_LIMIT="3") - def setUp(self) -> None: - self.serv_uds = ServerInfo.get_default_server_info() - self.serv_uds.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sourcer] - self.serv_uds.metadata = get_metadata_env(envs=METADATA_ENVS) - - def test_empty_write_info(self): - test_file = "/tmp/test_info_server" - with self.assertRaises(Exception): - info_server_write(server_info=None, info_file=test_file) - - def test_success_write_info(self): - test_file = "/tmp/test_info_server" - ret = info_server_write(server_info=self.serv_uds, info_file=test_file) - self.assertIsNone(ret) - file_data = read_info_server(info_file=test_file) - self.assertEqual(file_data["metadata"]["CPU_LIMIT"], "3") - self.assertEqual(file_data["protocol"], "uds") - self.assertEqual(file_data["language"], "python") - self.assertEqual(file_data["minimum_numaflow_version"], "1.4.0-z") - - def test_metadata_env(self): - test_file = "/tmp/test_info_server" - ret = info_server_write(server_info=self.serv_uds, info_file=test_file) - self.assertIsNone(ret) - - def test_invalid_input(self): - with self.assertRaises(TypeError): - 
ServerInfo() - - def test_file_new(self): - test_file = "/tmp/test_info_server" - exists = os.path.isfile(path=test_file) - if exists: - os.remove(test_file) - ret = info_server_write(server_info=self.serv_uds, info_file=test_file) - self.assertIsNone(ret) +@pytest.fixture() +def serv_uds(): + with mock.patch.dict(os.environ, {"NUMAFLOW_CPU_LIMIT": "3"}): + s = ServerInfo.get_default_server_info() + s.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Sourcer] + s.metadata = get_metadata_env(envs=METADATA_ENVS) + return s + + +def test_empty_write_info(): + test_file = "/tmp/test_info_server" + with pytest.raises(Exception): + info_server_write(server_info=None, info_file=test_file) + + +def test_success_write_info(serv_uds): + test_file = "/tmp/test_info_server" + ret = info_server_write(server_info=serv_uds, info_file=test_file) + assert ret is None + file_data = read_info_server(info_file=test_file) + assert file_data["metadata"]["CPU_LIMIT"] == "3" + assert file_data["protocol"] == "uds" + assert file_data["language"] == "python" + assert file_data["minimum_numaflow_version"] == "1.4.0-z" + + +def test_metadata_env(serv_uds): + test_file = "/tmp/test_info_server" + ret = info_server_write(server_info=serv_uds, info_file=test_file) + assert ret is None + + +def test_invalid_input(): + with pytest.raises(TypeError): + ServerInfo() + + +def test_file_new(serv_uds): + test_file = "/tmp/test_info_server" + exists = os.path.isfile(path=test_file) + if exists: + os.remove(test_file) + ret = info_server_write(server_info=serv_uds, info_file=test_file) + assert ret is None diff --git a/packages/pynumaflow/tests/testing_utils.py b/packages/pynumaflow/tests/testing_utils.py index 43c1a437..bc6cb2a3 100644 --- a/packages/pynumaflow/tests/testing_utils.py +++ b/packages/pynumaflow/tests/testing_utils.py @@ -93,7 +93,3 @@ def info_serv_is_ready(info_serv_data: str, eof: str = EOF): data = info_serv_data[:len_diff] return True, data return False, None - - -def 
mock_terminate_on_stop(process): - _LOGGER.info("Mock terminate %s", str(process)) From d6bd97e6bc99a1e49cdf096c3d6850492e3e701f Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 21 Mar 2026 09:39:50 +0530 Subject: [PATCH 4/7] Use @pytest.mark.parametrize to reduce test duplication Signed-off-by: Sreekanth --- .../accumulator/test_async_accumulator.py | 36 +++++------ .../tests/batchmap/test_async_batch_map.py | 26 ++++---- .../pynumaflow/tests/map/test_async_mapper.py | 26 ++++---- .../tests/map/test_multiproc_mapper.py | 62 ++++++++----------- .../pynumaflow/tests/map/test_sync_mapper.py | 54 +++++++--------- .../tests/mapstream/test_async_map_stream.py | 26 ++++---- .../tests/reduce/test_async_reduce.py | 26 ++++---- .../tests/reducestreamer/test_async_reduce.py | 26 ++++---- .../tests/sideinput/test_side_input_server.py | 26 ++++---- .../pynumaflow/tests/sink/test_async_sink.py | 26 ++++---- packages/pynumaflow/tests/sink/test_server.py | 26 ++++---- .../tests/source/test_async_source.py | 26 ++++---- .../tests/sourcetransform/test_async.py | 26 ++++---- .../tests/sourcetransform/test_multiproc.py | 56 +++++++---------- .../tests/sourcetransform/test_sync_server.py | 52 +++++++--------- 15 files changed, 254 insertions(+), 266 deletions(-) diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py index 46922b4e..8faff67c 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py @@ -409,26 +409,22 @@ class ExampleBadClass: AccumulatorAsyncServer(accumulator_instance=ExampleBadClass) -def test_max_threads(): - # max cap at 16 - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=5) - assert server.max_threads == 5 - - # 
defaults to 4 - server = AccumulatorAsyncServer(accumulator_instance=ExampleClass) - assert server.max_threads == 4 - - # zero threads - server = AccumulatorAsyncServer(ExampleClass, max_threads=0) - assert server.max_threads == 0 - - # negative threads - server = AccumulatorAsyncServer(ExampleClass, max_threads=-5) - assert server.max_threads == -5 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + (0, 0), # zero threads + (-5, -5), # negative threads + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"accumulator_instance": ExampleClass} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = AccumulatorAsyncServer(**kwargs) + assert server.max_threads == expected def test_server_info_file_path_handling(): diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py index cdfc2003..68e864a1 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py @@ -169,15 +169,17 @@ def test_is_ready(async_batch_map_server) -> None: assert response.ready -def test_max_threads(): - # max cap at 16 - server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = BatchMapAsyncServer(batch_mapper_instance=handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"batch_mapper_instance": handler} + if max_threads_arg is not None: + kwargs["max_threads"] = 
max_threads_arg + server = BatchMapAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/map/test_async_mapper.py b/packages/pynumaflow/tests/map/test_async_mapper.py index 42e1bce6..570ae70c 100644 --- a/packages/pynumaflow/tests/map/test_async_mapper.py +++ b/packages/pynumaflow/tests/map/test_async_mapper.py @@ -206,15 +206,17 @@ def test_invalid_input(): MapAsyncServer() -def test_max_threads(): - # max cap at 16 - server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = MapAsyncServer(mapper_instance=async_map_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"mapper_instance": async_map_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = MapAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/map/test_multiproc_mapper.py b/packages/pynumaflow/tests/map/test_multiproc_mapper.py index 148bbcd9..38aa9f43 100644 --- a/packages/pynumaflow/tests/map/test_multiproc_mapper.py +++ b/packages/pynumaflow/tests/map/test_multiproc_mapper.py @@ -27,51 +27,41 @@ def _invoke_map_fn(test_server, timeout=1): ) -def test_multiproc_init(): - my_server = MapMultiprocServer(mapper_instance=map_handler, server_count=3) - assert my_server._process_count == 3 - - -def test_multiproc_process_count(): - default_val = os.cpu_count() - my_server = MapMultiprocServer(mapper_instance=map_handler) - assert my_server._process_count == default_val - - -def test_max_process_count(): - """Max process count is capped at 2 * 
os.cpu_count, irrespective of what the user - provides as input""" - default_val = os.cpu_count() - server = MapMultiprocServer(mapper_instance=map_handler, server_count=100) - assert server._process_count == default_val * 2 - - -def test_udf_map_err_handshake(): +@pytest.mark.parametrize( + "server_count,expected", + [ + (3, 3), # explicit count + (None, os.cpu_count()), # default to cpu count + (100, os.cpu_count() * 2), # max cap at 2 * cpu count + ], +) +def test_process_count(server_count, expected): + kwargs = {"mapper_instance": map_handler} + if server_count is not None: + kwargs["server_count"] = server_count + server = MapMultiprocServer(**kwargs) + assert server._process_count == expected + + +@pytest.mark.parametrize( + "handshake,expected_msg", + [ + (False, "MapFn: expected handshake as the first message"), + (True, "Something is fishy!"), + ], +) +def test_udf_map_error(handshake, expected_msg): my_server = MapMultiprocServer(mapper_instance=err_map_handler) services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} test_server = server_from_dictionary(services, strict_real_time()) - test_datums = get_test_datums(handshake=False) - method = _invoke_map_fn(test_server) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - assert "MapFn: expected handshake as the first message" in details - assert code == StatusCode.INTERNAL - - -def test_udf_map_err(): - my_server = MapMultiprocServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - test_server = server_from_dictionary(services, strict_real_time()) - test_datums = get_test_datums(handshake=True) + test_datums = get_test_datums(handshake=handshake) method = _invoke_map_fn(test_server) send_test_requests(method, test_datums) drain_responses(method) metadata, code, details = method.termination() - assert "Something is fishy!" 
in details + assert expected_msg in details assert code == StatusCode.INTERNAL diff --git a/packages/pynumaflow/tests/map/test_sync_mapper.py b/packages/pynumaflow/tests/map/test_sync_mapper.py index f2d33df6..28942bc4 100644 --- a/packages/pynumaflow/tests/map/test_sync_mapper.py +++ b/packages/pynumaflow/tests/map/test_sync_mapper.py @@ -36,33 +36,25 @@ def test_init_with_args(): assert my_servicer.max_message_size == 1024 * 1024 * 5 -def test_udf_map_err_handshake(): +@pytest.mark.parametrize( + "handshake,expected_msg", + [ + (False, "MapFn: expected handshake as the first message"), + (True, "Something is fishy!"), + ], +) +def test_udf_map_error(handshake, expected_msg): my_server = MapServer(mapper_instance=err_map_handler) services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} test_server = server_from_dictionary(services, strict_real_time()) - test_datums = get_test_datums(handshake=False) + test_datums = get_test_datums(handshake=handshake) method = _invoke_map_fn(test_server) send_test_requests(method, test_datums) drain_responses(method) metadata, code, details = method.termination() - assert "MapFn: expected handshake as the first message" in details - assert code == StatusCode.INTERNAL - - -def test_udf_map_error_response(): - my_server = MapServer(mapper_instance=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} - test_server = server_from_dictionary(services, strict_real_time()) - - test_datums = get_test_datums(handshake=True) - method = _invoke_map_fn(test_server) - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - assert "Something is fishy!" 
in details + assert expected_msg in details assert code == StatusCode.INTERNAL @@ -110,15 +102,17 @@ def test_invalid_input(): MapServer() -def test_max_threads(): - # max cap at 16 - server = MapServer(mapper_instance=map_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = MapServer(mapper_instance=map_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = MapServer(mapper_instance=map_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"mapper_instance": map_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = MapServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py index 55df9178..ae634992 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py @@ -153,15 +153,17 @@ def test_is_ready(async_map_stream_server): assert response.ready -def test_max_threads(): - # max cap at 16 - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = 
{"map_stream_instance": async_map_stream_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = MapStreamAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce.py b/packages/pynumaflow/tests/reduce/test_async_reduce.py index febcbc0e..cb853b95 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce.py @@ -248,15 +248,17 @@ class ExampleBadClass: ReduceAsyncServer(reducer_instance=ExampleBadClass) -def test_max_threads(): - # max cap at 16 - server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = ReduceAsyncServer(reducer_instance=ExampleClass) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"reducer_instance": ExampleClass} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = ReduceAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index c0745e25..e949e06b 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -279,18 +279,20 @@ class ExampleBadClass: ReduceStreamAsyncServer(reduce_stream_instance=ExampleBadClass) -def test_max_threads(): - # max cap at 16 - server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=32) - assert server.max_threads == 16 - - # use argument 
provided - server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"reduce_stream_instance": ExampleClass} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = ReduceStreamAsyncServer(**kwargs) + assert server.max_threads == expected def test_start_shutdown_handler_without_callback(): diff --git a/packages/pynumaflow/tests/sideinput/test_side_input_server.py b/packages/pynumaflow/tests/sideinput/test_side_input_server.py index fe110eb3..3f0bb329 100644 --- a/packages/pynumaflow/tests/sideinput/test_side_input_server.py +++ b/packages/pynumaflow/tests/sideinput/test_side_input_server.py @@ -108,15 +108,17 @@ def test_invalid_input(): SideInputServer() -def test_max_threads(): - # max cap at 16 - server = SideInputServer(retrieve_side_input_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SideInputServer(retrieve_side_input_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SideInputServer(retrieve_side_input_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"side_input_instance": retrieve_side_input_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SideInputServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/sink/test_async_sink.py 
b/packages/pynumaflow/tests/sink/test_async_sink.py index c0c23b3d..4eb066cc 100644 --- a/packages/pynumaflow/tests/sink/test_async_sink.py +++ b/packages/pynumaflow/tests/sink/test_async_sink.py @@ -352,15 +352,17 @@ def test_start_on_success_sink(): assert server.server_info_file == ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH -def test_max_threads(): - # max cap at 16 - server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SinkAsyncServer(sinker_instance=udsink_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SinkAsyncServer(sinker_instance=udsink_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"sinker_instance": udsink_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SinkAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/sink/test_server.py b/packages/pynumaflow/tests/sink/test_server.py index 65da9452..57b8d524 100644 --- a/packages/pynumaflow/tests/sink/test_server.py +++ b/packages/pynumaflow/tests/sink/test_server.py @@ -304,18 +304,20 @@ def test_start_on_success_sink(): assert server.server_info_file == ON_SUCCESS_SINK_SERVER_INFO_FILE_PATH -def test_max_threads(): - # max cap at 16 - server = SinkServer(sinker_instance=udsink_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SinkServer(sinker_instance=udsink_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SinkServer(sinker_instance=udsink_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use 
argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"sinker_instance": udsink_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SinkServer(**kwargs) + assert server.max_threads == expected # --------------------------------------------------------------------------- diff --git a/packages/pynumaflow/tests/source/test_async_source.py b/packages/pynumaflow/tests/source/test_async_source.py index 06c371b6..18ebc106 100644 --- a/packages/pynumaflow/tests/source/test_async_source.py +++ b/packages/pynumaflow/tests/source/test_async_source.py @@ -212,16 +212,18 @@ def test_partitions(async_source_server) -> None: assert response.result.partitions == mock_partitions() -def test_max_threads(): +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): class_instance = AsyncSource() - # max cap at 16 - server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SourceAsyncServer(sourcer_instance=class_instance, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SourceAsyncServer(sourcer_instance=class_instance) - assert server.max_threads == 4 + kwargs = {"sourcer_instance": class_instance} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SourceAsyncServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index a6fe7a61..32092637 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -231,19 +231,21 @@ def test_invalid_input(): SourceTransformAsyncServer() -def 
test_max_threads(): +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): handle = SimpleAsyncSourceTrn() - # max cap at 16 - server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SourceTransformAsyncServer(source_transform_instance=handle, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SourceTransformAsyncServer(source_transform_instance=handle) - assert server.max_threads == 4 + kwargs = {"source_transform_instance": handle} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SourceTransformAsyncServer(**kwargs) + assert server.max_threads == expected # --- Metadata test class --- diff --git a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py index 4d56cd31..5a56b20a 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_multiproc.py +++ b/packages/pynumaflow/tests/sourcetransform/test_multiproc.py @@ -58,29 +58,23 @@ def test_max_process_count(): assert server._process_count == 2 * default_value -def test_udf_mapt_err_handshake(): +@pytest.mark.parametrize( + "handshake,expected_msg", + [ + (True, "Something is fishy"), + (False, "SourceTransformFn: expected handshake message"), + ], +) +def test_udf_mapt_error(handshake, expected_msg): test_server = _make_multiproc_server(err_transform_handler) - test_datums = get_test_datums(handshake=False) + test_datums = get_test_datums(handshake=handshake) method = _invoke_transform_fn(test_server) send_test_requests(method, test_datums) drain_responses(method) metadata, code, details = method.termination() - assert "SourceTransformFn: expected handshake message" in details - assert code == StatusCode.INTERNAL - - -def 
test_udf_mapt_err(): - test_server = _make_multiproc_server(err_transform_handler) - test_datums = get_test_datums() - method = _invoke_transform_fn(test_server) - - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - assert "Something is fishy" in details + assert expected_msg in details assert code == StatusCode.INTERNAL @@ -136,19 +130,17 @@ def test_invalid_input(): SourceTransformMultiProcServer() -def test_max_threads(): - # max cap at 16 - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, max_threads=32 - ) - assert server.max_threads == 16 - - # use argument provided - server = SourceTransformMultiProcServer( - source_transform_instance=transform_handler, max_threads=5 - ) - assert server.max_threads == 5 - - # defaults to 4 - server = SourceTransformMultiProcServer(source_transform_instance=transform_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"source_transform_instance": transform_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = SourceTransformMultiProcServer(**kwargs) + assert server.max_threads == expected diff --git a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py index c01b1efe..eb982418 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_sync_server.py +++ b/packages/pynumaflow/tests/sourcetransform/test_sync_server.py @@ -51,16 +51,23 @@ def test_init_with_args(): assert server.max_message_size == 1024 * 1024 * 5 -def test_udf_mapt_err(): +@pytest.mark.parametrize( + "handshake,expected_msg", + [ + (True, "Something is fishy"), + (False, "SourceTransformFn: expected handshake message"), + ], +) +def 
test_udf_mapt_error(handshake, expected_msg): test_server = _make_transform_server(err_transform_handler) - test_datums = get_test_datums() + test_datums = get_test_datums(handshake=handshake) method = _invoke_transform_fn(test_server) send_test_requests(method, test_datums) drain_responses(method) metadata, code, details = method.termination() - assert "Something is fishy" in details + assert expected_msg in details assert code == StatusCode.INTERNAL @@ -79,19 +86,6 @@ def test_is_ready(transform_test_server): assert code == StatusCode.OK -def test_udf_mapt_err_handshake(): - test_server = _make_transform_server(err_transform_handler) - test_datums = get_test_datums(handshake=False) - method = _invoke_transform_fn(test_server) - - send_test_requests(method, test_datums) - drain_responses(method) - - metadata, code, details = method.termination() - assert "SourceTransformFn: expected handshake message" in details - assert code == StatusCode.INTERNAL - - def test_mapt_assign_new_event_time(transform_test_server): test_datums = get_test_datums() method = _invoke_transform_fn(transform_test_server) @@ -129,18 +123,20 @@ def test_invalid_input(): SourceTransformServer() -def test_max_threads(): - # max cap at 16 - server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=32) - assert server.max_threads == 16 - - # use argument provided - server = SourceTransformServer(source_transform_instance=transform_handler, max_threads=5) - assert server.max_threads == 5 - - # defaults to 4 - server = SourceTransformServer(source_transform_instance=transform_handler) - assert server.max_threads == 4 +@pytest.mark.parametrize( + "max_threads_arg,expected", + [ + (32, 16), # max cap at 16 + (5, 5), # use argument provided + (None, 4), # defaults to 4 + ], +) +def test_max_threads(max_threads_arg, expected): + kwargs = {"source_transform_instance": transform_handler} + if max_threads_arg is not None: + kwargs["max_threads"] = max_threads_arg + server = 
SourceTransformServer(**kwargs)
+    assert server.max_threads == expected


 # ---------------------------------------------------------------------------

From aa02a8ae931413377dfb62ee1e2ee2ba15470d0a Mon Sep 17 00:00:00 2001
From: Sreekanth
Date: Sat, 21 Mar 2026 09:43:45 +0530
Subject: [PATCH 5/7] move pytest config to pyproject.toml

Signed-off-by: Sreekanth

---
 packages/pynumaflow/pyproject.toml | 11 +++++++++++
 packages/pynumaflow/pytest.ini | 6 ------
 2 files changed, 11 insertions(+), 6 deletions(-)
 delete mode 100644 packages/pynumaflow/pytest.ini

diff --git a/packages/pynumaflow/pyproject.toml b/packages/pynumaflow/pyproject.toml
index 3d18ca8d..c39c3ec2 100644
--- a/packages/pynumaflow/pyproject.toml
+++ b/packages/pynumaflow/pyproject.toml
@@ -84,6 +84,15 @@ exclude = '''
 )/
 '''
 
+[tool.pytest.ini_options]
+strict = true
+collect_imported_tests = false
+console_output_style = "times"
+log_cli = true
+log_cli_level = "DEBUG"
+log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
+log_cli_date_format = "%Y-%m-%d %H:%M:%S"
+
 [tool.ruff]
 line-length = 100
 src = ["pynumaflow", "tests", "examples"]
@@ -94,4 +103,6 @@ extend-exclude = [
     "*_pb2*.py",
     "*.pyi"
 ]
+
+[tool.ruff.lint]
 select = ["E", "F", "UP"]
diff --git a/packages/pynumaflow/pytest.ini b/packages/pynumaflow/pytest.ini
deleted file mode 100644
index 9d6adb92..00000000
--- a/packages/pynumaflow/pytest.ini
+++ /dev/null
@@ -1,6 +0,0 @@
-# pytest.ini
-[pytest]
-log_cli = 1
-log_cli_level = DEBUG
-log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)
-log_cli_date_format=%Y-%m-%d %H:%M:%S
\ No newline at end of file

From 27653c7610b75979ec41256aa9a08da45ec26249 Mon Sep 17 00:00:00 2001
From: Sreekanth
Date: Sat, 21 Mar 2026 11:16:41 +0530
Subject: [PATCH 6/7] pytest marker for tests that run full grpc server

Signed-off-by: Sreekanth

---
 packages/pynumaflow/Makefile | 6 ++++++
 packages/pynumaflow/pyproject.toml | 3 +++
.../pynumaflow/tests/accumulator/test_async_accumulator.py | 2 ++ .../tests/accumulator/test_async_accumulator_err.py | 2 ++ packages/pynumaflow/tests/batchmap/test_async_batch_map.py | 2 ++ .../pynumaflow/tests/batchmap/test_async_batch_map_err.py | 2 ++ packages/pynumaflow/tests/map/test_async_mapper.py | 2 ++ .../pynumaflow/tests/mapstream/test_async_map_stream.py | 3 +++ .../pynumaflow/tests/mapstream/test_async_map_stream_err.py | 2 ++ packages/pynumaflow/tests/reduce/test_async_reduce.py | 2 ++ packages/pynumaflow/tests/reduce/test_async_reduce_err.py | 2 ++ .../pynumaflow/tests/reducestreamer/test_async_reduce.py | 2 ++ .../tests/reducestreamer/test_async_reduce_err.py | 2 ++ packages/pynumaflow/tests/sink/test_async_sink.py | 2 ++ packages/pynumaflow/tests/source/test_async_source.py | 2 ++ packages/pynumaflow/tests/source/test_async_source_err.py | 2 ++ packages/pynumaflow/tests/sourcetransform/test_async.py | 2 ++ 17 files changed, 40 insertions(+) diff --git a/packages/pynumaflow/Makefile b/packages/pynumaflow/Makefile index ebb00145..1788ff80 100644 --- a/packages/pynumaflow/Makefile +++ b/packages/pynumaflow/Makefile @@ -17,6 +17,12 @@ lint: format test: uv run pytest tests/ -rA +test-unit: + uv run pytest tests/ -m "not integration" -rA + +test-integration: + uv run pytest tests/ -m integration -rA + setup: uv sync --all-groups diff --git a/packages/pynumaflow/pyproject.toml b/packages/pynumaflow/pyproject.toml index c39c3ec2..896c487f 100644 --- a/packages/pynumaflow/pyproject.toml +++ b/packages/pynumaflow/pyproject.toml @@ -92,6 +92,9 @@ log_cli = true log_cli_level = "DEBUG" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" +markers = [ + "integration: tests that start real gRPC servers on Unix sockets (slower)", +] [tool.ruff] line-length = 100 diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py 
b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py index 8faff67c..f4fd47ce 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py @@ -23,6 +23,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/accumulator.sock" diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py index 09eead23..0ab293b0 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py @@ -20,6 +20,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/accumulator_err.sock" diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py index 68e864a1..79c24a03 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py @@ -19,6 +19,8 @@ from pynumaflow.proto.mapper import map_pb2_grpc from tests.batchmap.utils import request_generator +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) listen_addr = "unix:///tmp/batch_map.sock" diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py index 8c6795d0..7b0cb247 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py @@ -11,6 +11,8 @@ from pynumaflow.proto.mapper import map_pb2_grpc from tests.batchmap.utils import request_generator +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) raise_error = False diff --git a/packages/pynumaflow/tests/map/test_async_mapper.py 
b/packages/pynumaflow/tests/map/test_async_mapper.py index 570ae70c..9c8d3f16 100644 --- a/packages/pynumaflow/tests/map/test_async_mapper.py +++ b/packages/pynumaflow/tests/map/test_async_mapper.py @@ -19,6 +19,8 @@ from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from tests.map.utils import get_test_datums +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) # if set to true, map handler will raise a `ValueError` exception. diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py index ae634992..b47a85ae 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py @@ -15,8 +15,11 @@ ) from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator + import pytest +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) # if set to true, map handler will raise a `ValueError` exception. 
diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py index 8f599136..2e8b933f 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py @@ -11,6 +11,8 @@ from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/async_map_stream_err.sock" diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce.py b/packages/pynumaflow/tests/reduce/test_async_reduce.py index cb853b95..e063de39 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce.py @@ -24,6 +24,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + logging.basicConfig(level=logging.DEBUG) LOGGER = logging.getLogger(__name__) diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py index 53b3c91b..d4875bf5 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py @@ -23,6 +23,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/reduce_err.sock" diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index e949e06b..f8e47eca 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -25,6 +25,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/reduce_stream.sock" diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py 
b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index d2ba2a50..45400620 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -27,6 +27,8 @@ get_time_args, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/reduce_stream_err.sock" diff --git a/packages/pynumaflow/tests/sink/test_async_sink.py b/packages/pynumaflow/tests/sink/test_async_sink.py index 4eb066cc..b92ca1f5 100644 --- a/packages/pynumaflow/tests/sink/test_async_sink.py +++ b/packages/pynumaflow/tests/sink/test_async_sink.py @@ -32,6 +32,8 @@ ) from tests.testing_utils import get_time_args +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) SOCK_PATH = "unix:///tmp/async_sink.sock" diff --git a/packages/pynumaflow/tests/source/test_async_source.py b/packages/pynumaflow/tests/source/test_async_source.py index 18ebc106..c9f30a6a 100644 --- a/packages/pynumaflow/tests/source/test_async_source.py +++ b/packages/pynumaflow/tests/source/test_async_source.py @@ -22,6 +22,8 @@ nack_req_source_fn, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) server_port = "unix:///tmp/async_source.sock" diff --git a/packages/pynumaflow/tests/source/test_async_source_err.py b/packages/pynumaflow/tests/source/test_async_source_err.py index 88fb4542..53360aea 100644 --- a/packages/pynumaflow/tests/source/test_async_source_err.py +++ b/packages/pynumaflow/tests/source/test_async_source_err.py @@ -18,6 +18,8 @@ nack_req_source_fn, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) server_port = "unix:///tmp/async_err_source.sock" diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index 32092637..e0bddc31 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ 
b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -18,6 +18,8 @@ mock_new_event_time, ) +pytestmark = pytest.mark.integration + LOGGER = setup_logging(__name__) # if set to true, transform handler will raise a `ValueError` exception. From 965d955e53e2c70064ec9c8804337ddc765ffb6c Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Sat, 21 Mar 2026 13:06:18 +0530 Subject: [PATCH 7/7] The logs for each test should be printed along with the test name itself Signed-off-by: Sreekanth --- packages/pynumaflow/Makefile | 3 + packages/pynumaflow/pyproject.toml | 2 - .../accumulator/test_async_accumulator.py | 35 ++------ .../accumulator/test_async_accumulator_err.py | 35 ++------ .../tests/batchmap/test_async_batch_map.py | 30 ++----- .../batchmap/test_async_batch_map_err.py | 33 ++----- packages/pynumaflow/tests/conftest.py | 87 +++++++++++++++++++ .../pynumaflow/tests/map/test_async_mapper.py | 32 ++----- .../tests/mapstream/test_async_map_stream.py | 34 ++------ .../mapstream/test_async_map_stream_err.py | 34 ++------ .../tests/reduce/test_async_reduce.py | 35 ++------ .../tests/reduce/test_async_reduce_err.py | 35 ++------ .../tests/reducestreamer/test_async_reduce.py | 35 ++------ .../reducestreamer/test_async_reduce_err.py | 34 ++------ .../pynumaflow/tests/sink/test_async_sink.py | 31 ++----- .../tests/source/test_async_source.py | 30 ++----- .../tests/source/test_async_source_err.py | 30 ++----- .../tests/sourcetransform/test_async.py | 60 ++----------- 18 files changed, 169 insertions(+), 446 deletions(-) diff --git a/packages/pynumaflow/Makefile b/packages/pynumaflow/Makefile index 1788ff80..b660b8e7 100644 --- a/packages/pynumaflow/Makefile +++ b/packages/pynumaflow/Makefile @@ -23,6 +23,9 @@ test-unit: test-integration: uv run pytest tests/ -m integration -rA +test-debug: + uv run pytest tests/ -rA --log-cli-level=DEBUG + setup: uv sync --all-groups diff --git a/packages/pynumaflow/pyproject.toml b/packages/pynumaflow/pyproject.toml index 896c487f..27c90cba 
100644 --- a/packages/pynumaflow/pyproject.toml +++ b/packages/pynumaflow/pyproject.toml @@ -88,8 +88,6 @@ exclude = ''' strict = true collect_imported_tests = false console_output_style = "times" -log_cli = true -log_cli_level = "DEBUG" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" markers = [ diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py index f4fd47ce..eb499cdf 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -16,6 +14,7 @@ ) from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -140,11 +139,6 @@ def start_request_without_open() -> accumulator_pb2.AccumulatorRequest: return request -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleClass(Accumulator): def __init__(self, counter): self.counter = counter @@ -176,36 +170,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_accumulator_server(): """Module-scoped fixture: starts an async gRPC accumulator server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncAccumulator() - future = 
asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py index 0ab293b0..07d10586 100644 --- a/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py +++ b/packages/pynumaflow/tests/accumulator/test_async_accumulator_err.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -15,6 +13,7 @@ ) from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, get_time_args, @@ -59,11 +58,6 @@ def start_request() -> accumulator_pb2.AccumulatorRequest: return request -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleErrorClass(Accumulator): def __init__(self, counter): self.counter = counter @@ -101,36 +95,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_accumulator_err_server(): """Module-scoped fixture: starts an async gRPC accumulator error server.""" - loop = asyncio.new_event_loop() - 
thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncAccumulatorError() - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py index 79c24a03..7d6b44dd 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -18,6 +16,7 @@ ) from pynumaflow.proto.mapper import map_pb2_grpc from tests.batchmap.utils import request_generator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server pytestmark = pytest.mark.integration @@ -26,11 +25,6 @@ listen_addr = "unix:///tmp/batch_map.sock" -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleClass(BatchMapper): async def handler( self, @@ -86,34 +80,20 @@ async def start_server(udfs): server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() - await server.wait_for_termination() + return server, listen_addr @pytest.fixture(scope="module") def async_batch_map_server(): """Module-scoped fixture: starts an async gRPC batch map server in a background 
thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() + loop = create_async_loop() udfs = NewAsyncBatchMapper() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - - while True: - try: - with grpc.insecure_channel(listen_addr) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) + server = start_async_server(loop, start_server(udfs)) yield loop - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py index 7b0cb247..aaee7227 100644 --- a/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py +++ b/packages/pynumaflow/tests/batchmap/test_async_batch_map_err.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading import grpc import pytest @@ -10,6 +8,7 @@ from pynumaflow.batchmapper import BatchMapAsyncServer from pynumaflow.proto.mapper import map_pb2_grpc from tests.batchmap.utils import request_generator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server pytestmark = pytest.mark.integration @@ -29,11 +28,6 @@ async def err_handler(datums) -> BatchResponses: listen_addr = "unix:///tmp/async_batch_map_err.sock" -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def start_server(): server = grpc.aio.server() server_instance = BatchMapAsyncServer(err_handler) @@ -42,33 +36,16 @@ async def start_server(): server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() - await server.wait_for_termination() + return server, listen_addr @pytest.fixture(scope="module") def 
async_batch_map_err_server(): """Module-scoped fixture: starts an async gRPC batch map error server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - - while True: - try: - with grpc.insecure_channel(listen_addr) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + loop = create_async_loop() + server = start_async_server(loop, start_server()) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/conftest.py b/packages/pynumaflow/tests/conftest.py index be529123..12507fd3 100644 --- a/packages/pynumaflow/tests/conftest.py +++ b/packages/pynumaflow/tests/conftest.py @@ -5,6 +5,93 @@ sync, multiproc, and async test files. """ +import asyncio +import logging +import threading + +import grpc + +_logger = logging.getLogger(__name__) + + +def start_async_server(loop, start_server_coro): + """Start an async gRPC server on the given event loop and wait until it is ready. + + Args: + loop: The asyncio event loop running in a background thread. + start_server_coro: An awaitable that starts the server and returns + a tuple of (grpc.aio.Server, sock_path). + + Returns: + The grpc.aio.Server instance. 
+ """ + future = asyncio.run_coroutine_threadsafe(start_server_coro, loop=loop) + server, sock_path = future.result(timeout=10) + + # Block until the server is accepting connections + while True: + try: + with grpc.insecure_channel(sock_path) as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + _logger.error("error trying to connect to grpc server") + _logger.error(e) + + return server + + +def create_async_loop(): + """Create a new asyncio event loop running in a daemon thread. + + Returns: + The running event loop. + """ + loop = asyncio.new_event_loop() + + def _run(lp): + asyncio.set_event_loop(lp) + lp.run_forever() + + thread = threading.Thread(target=_run, args=(loop,), daemon=True) + thread.start() + return loop + + +def teardown_async_server(loop, server): + """Gracefully shut down an async gRPC server and its event loop. + + Stops the gRPC server, cancels any remaining tasks, then stops the loop. + This prevents 'Task was destroyed but it is pending!' warnings. + """ + + async def _shutdown(): + await server.stop(grace=1) + # Cancel any lingering tasks on this loop, excluding the current + # _shutdown task itself to avoid recursive cancel chains. + current = asyncio.current_task() + tasks = [t for t in asyncio.all_tasks(loop) if not t.done() and t is not current] + for task in tasks: + task.cancel() + # Await each cancelled task individually so a RecursionError in one + # deeply-nested cancel chain does not prevent the others from being + # reaped, and does not propagate up to the caller. 
+ for task in tasks: + try: + await task + except (asyncio.CancelledError, RecursionError, Exception): + pass + + try: + future = asyncio.run_coroutine_threadsafe(_shutdown(), loop=loop) + future.result(timeout=10) + except Exception as e: + _logger.error("error during async server teardown: %s", e) + finally: + loop.call_soon_threadsafe(loop.stop) + def collect_responses(method): """Collect all responses from a grpc_testing stream method until exhausted. diff --git a/packages/pynumaflow/tests/map/test_async_mapper.py b/packages/pynumaflow/tests/map/test_async_mapper.py index 9c8d3f16..8fd22499 100644 --- a/packages/pynumaflow/tests/map/test_async_mapper.py +++ b/packages/pynumaflow/tests/map/test_async_mapper.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import Iterator import grpc @@ -17,6 +15,7 @@ from pynumaflow.mapper.async_server import MapAsyncServer from pynumaflow.proto.common import metadata_pb2 from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.map.utils import get_test_datums pytestmark = pytest.mark.integration @@ -49,11 +48,6 @@ async def async_map_handler(keys: list[str], datum: Datum) -> Messages: return messages -def _startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def _start_server(udfs): _server_options = [ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), @@ -64,37 +58,21 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_map_server(): """Module-scoped fixture: starts an async gRPC map server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() + loop = create_async_loop() server_obj 
= MapAsyncServer(mapper_instance=async_map_handler) udfs = server_obj.servicer - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - _server = future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) + server = start_async_server(loop, _start_server(udfs)) yield loop - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py index b47a85ae..1afa6c2c 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections import Counter from collections.abc import AsyncIterable @@ -15,6 +13,7 @@ ) from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server import pytest @@ -39,47 +38,24 @@ async def async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIterab yield Message(str.encode(msg), keys=keys) -def _startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def _start_server(udfs): server = grpc.aio.server() map_pb2_grpc.add_MapServicer_to_server(udfs, server) server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_map_stream_server(): """Module-scoped fixture: starts an async gRPC map stream server in a background thread.""" - loop 
= asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() server_obj = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) udfs = server_obj.servicer - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py index 2e8b933f..510cf75c 100644 --- a/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py +++ b/packages/pynumaflow/tests/mapstream/test_async_map_stream_err.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -10,6 +8,7 @@ from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer from pynumaflow.proto.mapper import map_pb2_grpc from tests.mapstream.utils import request_generator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server pytestmark = pytest.mark.integration @@ -33,11 +32,6 @@ async def err_async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIt raise RuntimeError("Got a runtime error from map stream handler.") -def _startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def _start_server(): server = grpc.aio.server() server_instance = MapStreamAsyncServer(err_async_map_stream_handler) @@ -46,34 +40,16 @@ async def 
_start_server(): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_map_stream_err_server(): """Module-scoped fixture: starts an async gRPC map stream error server.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() - - future = asyncio.run_coroutine_threadsafe(_start_server(), loop=loop) - future.result(timeout=10) - - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + loop = create_async_loop() + server = start_async_server(loop, _start_server()) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce.py b/packages/pynumaflow/tests/reduce/test_async_reduce.py index e063de39..105d8ff7 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -17,6 +15,7 @@ ReduceAsyncServer, Reducer, ) +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -66,11 +65,6 @@ def start_request() -> (Datum, tuple): return request, metadata -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleClass(Reducer): def __init__(self, counter): self.counter = counter @@ -108,36 +102,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server 
on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_reduce_server(): """Module-scoped fixture: starts an async gRPC reduce server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncReducer() - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py index d4875bf5..793b2a3e 100644 --- a/packages/pynumaflow/tests/reduce/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reduce/test_async_reduce_err.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -16,6 +14,7 @@ ReduceAsyncServer, ) from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -70,11 +69,6 @@ def start_request(multiple_window: False) -> (Datum, tuple): return request, metadata -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def err_handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: interval_window = 
md.interval_window counter = 0 @@ -101,36 +95,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_reduce_err_server(): """Module-scoped fixture: starts an async gRPC reduce error server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncReducer() - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py index f8e47eca..a2b65c5f 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -18,6 +16,7 @@ ) from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -71,11 +70,6 @@ def start_request() -> (Datum, tuple): return request, metadata -def 
startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleClass(ReduceStreamer): def __init__(self, counter): self.counter = counter @@ -128,36 +122,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_reduce_stream_server(): """Module-scoped fixture: starts an async gRPC reduce stream server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncReduceStreamer() - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py index 45400620..4c0f9eb5 100644 --- a/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py +++ b/packages/pynumaflow/tests/reducestreamer/test_async_reduce_err.py @@ -1,6 +1,5 @@ import asyncio import logging -import threading from collections.abc import AsyncIterable from unittest.mock import MagicMock @@ -20,6 +19,7 @@ from pynumaflow.reducestreamer.servicer.async_servicer import AsyncReduceStreamServicer from pynumaflow.reducestreamer.servicer.task_manager import TaskManager from pynumaflow.shared.asynciter 
import NonBlockingIterator +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -74,11 +74,6 @@ def start_request(multiple_window: False) -> (Datum, tuple): return request, metadata -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - class ExampleClass(ReduceStreamer): def __init__(self, counter): self.counter = counter @@ -133,36 +128,17 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_reduce_stream_err_server(): """Module-scoped fixture: starts an async gRPC reduce stream error server.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() udfs = NewAsyncReduceStreamer() - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - # Wait for the server to be ready - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/sink/test_async_sink.py b/packages/pynumaflow/tests/sink/test_async_sink.py index b92ca1f5..7a884ed2 100644 --- a/packages/pynumaflow/tests/sink/test_async_sink.py +++ b/packages/pynumaflow/tests/sink/test_async_sink.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading from collections.abc import AsyncIterable import grpc @@ -24,6 
+22,7 @@ from pynumaflow.sinker import Responses, Response, Message, UserMetadata from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 from pynumaflow.sinker.async_server import SinkAsyncServer +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.sink.test_server import ( mock_message, mock_err_message, @@ -105,11 +104,6 @@ def request_generator(count, req_type="success", session=1, handshake=True): yield sink_pb2.SinkRequest(status=sink_pb2.TransmissionStatus(eot=True)) -def _startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def _start_server(): server = grpc.aio.server() server_instance = SinkAsyncServer(sinker_instance=udsink_handler) @@ -118,34 +112,19 @@ async def _start_server(): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_sink_server(): """Module-scoped fixture: starts an async gRPC sink server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() + loop = create_async_loop() - future = asyncio.run_coroutine_threadsafe(_start_server(), loop=loop) - future.result(timeout=10) - - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) + server = start_async_server(loop, _start_server()) yield loop - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() diff --git a/packages/pynumaflow/tests/source/test_async_source.py b/packages/pynumaflow/tests/source/test_async_source.py index c9f30a6a..00c8c0d4 100644 --- 
a/packages/pynumaflow/tests/source/test_async_source.py +++ b/packages/pynumaflow/tests/source/test_async_source.py @@ -1,7 +1,5 @@ -import asyncio from collections.abc import Iterator import logging -import threading import grpc import pytest @@ -13,6 +11,7 @@ from pynumaflow.sourcer import ( SourceAsyncServer, ) +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.source.utils import ( read_req_source_fn, ack_req_source_fn, @@ -29,11 +28,6 @@ server_port = "unix:///tmp/async_source.sock" -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - def NewAsyncSourcer(): class_instance = AsyncSource() server = SourceAsyncServer(sourcer_instance=class_instance) @@ -48,7 +42,7 @@ async def start_server(udfs): server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() - await server.wait_for_termination() + return server, listen_addr def request_generator(count, request, req_type, send_handshake: bool = True): @@ -66,28 +60,14 @@ def request_generator(count, request, req_type, send_handshake: bool = True): @pytest.fixture(scope="module") def async_source_server(): """Module-scoped fixture: starts an async gRPC source server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() + loop = create_async_loop() udfs = NewAsyncSourcer() - asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) - - while True: - try: - with grpc.insecure_channel(server_port) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) + server = start_async_server(loop, start_server(udfs)) yield loop - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) def 
test_read_source(async_source_server) -> None: diff --git a/packages/pynumaflow/tests/source/test_async_source_err.py b/packages/pynumaflow/tests/source/test_async_source_err.py index 53360aea..8a3985a8 100644 --- a/packages/pynumaflow/tests/source/test_async_source_err.py +++ b/packages/pynumaflow/tests/source/test_async_source_err.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading import grpc import pytest @@ -10,6 +8,7 @@ from google.protobuf import empty_pb2 as _empty_pb2 from pynumaflow.sourcer import SourceAsyncServer +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.source.test_async_source import request_generator from tests.source.utils import ( read_req_source_fn, @@ -25,11 +24,6 @@ server_port = "unix:///tmp/async_err_source.sock" -def startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def start_server(): server = grpc.aio.server() class_instance = AsyncSourceError() @@ -40,33 +34,19 @@ async def start_server(): server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) await server.start() - await server.wait_for_termination() + return server, listen_addr @pytest.fixture(scope="module") def async_source_err_server(): """Module-scoped fixture: starts an async gRPC source error server in a background thread.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) - thread.start() + loop = create_async_loop() - asyncio.run_coroutine_threadsafe(start_server(), loop=loop) - - while True: - try: - with grpc.insecure_channel(server_port) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) + server = start_async_server(loop, start_server()) yield loop - loop.stop() - LOGGER.info("stopped the event loop") + 
teardown_async_server(loop, server) def test_read_error(async_source_err_server) -> None: diff --git a/packages/pynumaflow/tests/sourcetransform/test_async.py b/packages/pynumaflow/tests/sourcetransform/test_async.py index e0bddc31..89482b68 100644 --- a/packages/pynumaflow/tests/sourcetransform/test_async.py +++ b/packages/pynumaflow/tests/sourcetransform/test_async.py @@ -1,6 +1,4 @@ -import asyncio import logging -import threading import grpc import pytest @@ -13,6 +11,7 @@ from pynumaflow.proto.sourcetransformer import transform_pb2_grpc from pynumaflow.sourcetransformer import Datum, Messages, Message, SourceTransformer from pynumaflow.sourcetransformer.async_server import SourceTransformAsyncServer +from tests.conftest import create_async_loop, start_async_server, teardown_async_server from tests.sourcetransform.utils import get_test_datums from tests.testing_utils import ( mock_new_event_time, @@ -48,11 +47,6 @@ def request_generator(req): yield from req -def _startup_callable(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - async def _start_server(udfs): _server_options = [ ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), @@ -63,37 +57,19 @@ async def _start_server(udfs): server.add_insecure_port(SOCK_PATH) logging.info("Starting server on %s", SOCK_PATH) await server.start() - return server + return server, SOCK_PATH @pytest.fixture(scope="module") def async_st_server(): """Module-scoped fixture: starts an async gRPC source transform server.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() handle = SimpleAsyncSourceTrn() server_obj = SourceTransformAsyncServer(source_transform_instance=handle) udfs = server_obj.servicer - future = asyncio.run_coroutine_threadsafe(_start_server(udfs), loop=loop) - future.result(timeout=10) - - while True: - try: - with grpc.insecure_channel(SOCK_PATH) as channel: - f = 
grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the event loop") + teardown_async_server(loop, server) @pytest.fixture() @@ -285,37 +261,19 @@ async def _start_metadata_server(udfs): server.add_insecure_port(METADATA_SOCK_PATH) logging.info("Starting metadata server on %s", METADATA_SOCK_PATH) await server.start() - return server + return server, METADATA_SOCK_PATH @pytest.fixture(scope="module") def async_st_metadata_server(): """Module-scoped fixture: starts an async gRPC metadata source transform server.""" - loop = asyncio.new_event_loop() - thread = threading.Thread(target=_startup_callable, args=(loop,), daemon=True) - thread.start() - + loop = create_async_loop() handle = MetadataAsyncSourceTransformer() server_obj = SourceTransformAsyncServer(source_transform_instance=handle) udfs = server_obj.servicer - future = asyncio.run_coroutine_threadsafe(_start_metadata_server(udfs), loop=loop) - future.result(timeout=10) - - while True: - try: - with grpc.insecure_channel(METADATA_SOCK_PATH) as channel: - f = grpc.channel_ready_future(channel) - f.result(timeout=10) - if f.done(): - break - except grpc.FutureTimeoutError as e: - LOGGER.error("error trying to connect to grpc server") - LOGGER.error(e) - + server = start_async_server(loop, _start_metadata_server(udfs)) yield loop - - loop.stop() - LOGGER.info("stopped the metadata event loop") + teardown_async_server(loop, server) @pytest.fixture()