From d72180ef093f4c557cc2ca8b3807a70d1140e726 Mon Sep 17 00:00:00 2001 From: Jose Eduardo Date: Fri, 18 Oct 2024 15:36:33 +0100 Subject: [PATCH 1/3] Handle Redis pub/sub subscribe errors --- broadcaster/_base.py | 19 +++++++++++++++---- broadcaster/backends/redis.py | 14 ++++++++++++-- tests/test_broadcast.py | 17 +++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index a63b22b..d418e03 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -29,7 +29,7 @@ class Broadcast: def __init__(self, url: str | None = None, *, backend: BroadcastBackend | None = None) -> None: assert url or backend, "Either `url` or `backend` must be provided." self._backend = backend or self._create_backend(cast(str, url)) - self._subscribers: dict[str, set[asyncio.Queue[Event | None]]] = {} + self._subscribers: dict[str, set[asyncio.Queue[Event | BaseException | None]]] = {} def _create_backend(self, url: str) -> BroadcastBackend: parsed_url = urlparse(url) @@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None: async def connect(self) -> None: await self._backend.connect() self._listener_task = asyncio.create_task(self._listener()) + self._listener_task.add_done_callback(self.drop) + + def drop(self, task: asyncio.Task[None]) -> None: + exc = task.exception() + for queues in self._subscribers.values(): + for queue in queues: + queue.put_nowait(exc) async def disconnect(self) -> None: if self._listener_task.done(): - self._listener_task.result() + exc = self._listener_task.exception() + if exc is None: + self._listener_task.result() else: self._listener_task.cancel() await self._backend.disconnect() @@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None: @asynccontextmanager async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]: - queue: asyncio.Queue[Event | None] = asyncio.Queue() + queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue() try: if not self._subscribers.get(channel): @@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]: class Subscriber: - def __init__(self, queue: asyncio.Queue[Event | None]) -> None: + def __init__(self, queue: asyncio.Queue[Event | BaseException | None]) -> None: self._queue = queue async def __aiter__(self) -> AsyncGenerator[Event | None, None]: @@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]: async def get(self) -> Event: item = await self._queue.get() + if isinstance(item, BaseException): + raise item if item is None: raise Unsubscribed() return item diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index 1be4195..bd140cd 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -14,11 +14,12 @@ def __init__(self, url: str): self._conn = redis.Redis.from_url(url) self._pubsub = self._conn.pubsub() self._ready = asyncio.Event() - self._queue: asyncio.Queue[Event] = asyncio.Queue() + self._queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue() self._listener: asyncio.Task[None] | None = None async def connect(self) -> None: self._listener = asyncio.create_task(self._pubsub_listener()) + self._listener.add_done_callback(self.drop) await self._pubsub.connect() async def disconnect(self) -> None: @@ -27,6 +28,10 @@ async def disconnect(self) -> None: if self._listener is not None: self._listener.cancel() + def drop(self, task: asyncio.Task[None]) -> None: + exc = task.exception() + self._queue.put_nowait(exc) + async def subscribe(self, channel: str) -> None: self._ready.set() await self._pubsub.subscribe(channel) @@ -38,7 +43,12 @@ async def publish(self, channel: str, message: typing.Any) -> None: await self._conn.publish(channel, message) async def next_published(self) -> Event: - return await self._queue.get() + result = await self._queue.get() + if result is None: + raise RuntimeError + if isinstance(result, BaseException): + raise result + return result async def _pubsub_listener(self) -> None: # redis-py does not listen to the pubsub connection if there are no channels subscribed diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index a8bd3eb..ac11c2c 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -4,6 +4,7 @@ import typing import pytest +import redis from broadcaster import Broadcast, BroadcastBackend, Event from broadcaster.backends.kafka import KafkaBackend @@ -56,6 +57,22 @@ async def test_redis(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_redis_disconnect(): + with pytest.raises(redis.ConnectionError) as exc: + async with Broadcast("redis://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + await broadcast._backend._conn.connection_pool.aclose() # type: ignore[attr-defined] + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + await subscriber.get() + assert False + + assert exc.value.args == ("Connection closed by server.",) + + @pytest.mark.asyncio async def test_redis_stream(): async with Broadcast("redis-stream://localhost:6379") as broadcast: From 8b934d063cb441e16f50ace5235e16c442a39960 Mon Sep 17 00:00:00 2001 From: Jose Eduardo Date: Thu, 7 Nov 2024 11:15:34 +0000 Subject: [PATCH 2/3] Avoid loop error logs CancelledError tracebacks may appear on disconnect, like: ERROR asyncio:base_events.py:1821 Exception in callback Broadcast.drop(>) handle: >)> Or using `uvloop`: File "uvloop/cbhandles.pyx", line 63, in uvloop.loop.Handle._run --- broadcaster/_base.py | 12 ++++++++---- tests/test_broadcast.py | 14 +++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index d418e03..b00f2a6 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -72,10 +72,14 @@ async def connect(self) -> None: self._listener_task.add_done_callback(self.drop) def drop(self, task: asyncio.Task[None]) -> None: - exc = task.exception() - for queues in self._subscribers.values(): - for queue in queues: - queue.put_nowait(exc) + try: + exc = task.exception() + except asyncio.CancelledError: + pass + else: + for queues in self._subscribers.values(): + for queue in queues: + queue.put_nowait(exc) async def disconnect(self) -> None: if self._listener_task.done(): diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index ac11c2c..f8029ee 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -58,7 +58,7 @@ async def test_redis(): @pytest.mark.asyncio -async def test_redis_disconnect(): +async def test_redis_server_disconnect(): with pytest.raises(redis.ConnectionError) as exc: async with Broadcast("redis://localhost:6379") as broadcast: async with broadcast.subscribe("chatroom") as subscriber: @@ -73,6 +73,18 @@ async def test_redis_disconnect(): assert exc.value.args == ("Connection closed by server.",) +@pytest.mark.asyncio +async def test_redis_does_not_log_loop_error_messages(caplog): + async with Broadcast("redis://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + + assert caplog.messages == [] + + @pytest.mark.asyncio async def test_redis_stream(): async with Broadcast("redis-stream://localhost:6379") as broadcast: From 5f39d75b165ccbf4ef224859ea5510653d56c76f Mon Sep 17 00:00:00 2001 From: Jose Eduardo Date: Fri, 8 Nov 2024 12:58:22 +0000 Subject: [PATCH 3/3] Avoid loop error logs raised from the Redis backend --- broadcaster/backends/redis.py | 8 ++++++-- tests/test_broadcast.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index bd140cd..235ab7b 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -29,8 +29,12 @@ async def disconnect(self) -> None: self._listener.cancel() def drop(self, task: asyncio.Task[None]) -> None: - exc = task.exception() - self._queue.put_nowait(exc) + try: + exc = task.exception() + except asyncio.CancelledError: + pass + else: + self._queue.put_nowait(exc) async def subscribe(self, channel: str) -> None: self._ready.set() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index f8029ee..113b134 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -74,7 +74,7 @@ async def test_redis_server_disconnect(): @pytest.mark.asyncio -async def test_redis_does_not_log_loop_error_messages(caplog): +async def test_redis_does_not_log_loop_error_messages_if_subscribing(caplog): async with Broadcast("redis://localhost:6379") as broadcast: async with broadcast.subscribe("chatroom") as subscriber: await broadcast.publish("chatroom", "hello") @@ -85,6 +85,17 @@ async def test_redis_does_not_log_loop_error_messages(caplog): assert caplog.messages == [] +@pytest.mark.asyncio +async def test_redis_does_not_log_loop_error_messages_if_not_subscribing(caplog): + async with Broadcast("redis://localhost:6379") as broadcast: + await broadcast.publish("chatroom", "hello") + + # Give the loop an opportunity to catch any errors before checking + # the logs. + await asyncio.sleep(0.1) + assert caplog.messages == [] + + @pytest.mark.asyncio async def test_redis_stream(): async with Broadcast("redis-stream://localhost:6379") as broadcast: