diff --git a/broadcaster/_base.py b/broadcaster/_base.py index a63b22b..b00f2a6 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -29,7 +29,7 @@ class Broadcast: def __init__(self, url: str | None = None, *, backend: BroadcastBackend | None = None) -> None: assert url or backend, "Either `url` or `backend` must be provided." self._backend = backend or self._create_backend(cast(str, url)) - self._subscribers: dict[str, set[asyncio.Queue[Event | None]]] = {} + self._subscribers: dict[str, set[asyncio.Queue[Event | BaseException | None]]] = {} def _create_backend(self, url: str) -> BroadcastBackend: parsed_url = urlparse(url) @@ -69,10 +69,23 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None: async def connect(self) -> None: await self._backend.connect() self._listener_task = asyncio.create_task(self._listener()) + self._listener_task.add_done_callback(self.drop) + + def drop(self, task: asyncio.Task[None]) -> None: + try: + exc = task.exception() + except asyncio.CancelledError: + pass + else: + for queues in self._subscribers.values(): + for queue in queues: + queue.put_nowait(exc) async def disconnect(self) -> None: if self._listener_task.done(): - self._listener_task.result() + exc = self._listener_task.exception() + if exc is None: + self._listener_task.result() else: self._listener_task.cancel() await self._backend.disconnect() @@ -88,7 +101,7 @@ async def publish(self, channel: str, message: Any) -> None: @asynccontextmanager async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]: - queue: asyncio.Queue[Event | None] = asyncio.Queue() + queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue() try: if not self._subscribers.get(channel): @@ -107,7 +120,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]: class Subscriber: - def __init__(self, queue: asyncio.Queue[Event | None]) -> None: + def __init__(self, queue: asyncio.Queue[Event | BaseException | None]) -> None: self._queue = queue async def __aiter__(self) -> AsyncGenerator[Event | None, None]: @@ -119,6 +132,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]: async def get(self) -> Event: item = await self._queue.get() + if isinstance(item, BaseException): + raise item if item is None: raise Unsubscribed() return item diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index 1be4195..235ab7b 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -14,11 +14,12 @@ def __init__(self, url: str): self._conn = redis.Redis.from_url(url) self._pubsub = self._conn.pubsub() self._ready = asyncio.Event() - self._queue: asyncio.Queue[Event] = asyncio.Queue() + self._queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue() self._listener: asyncio.Task[None] | None = None async def connect(self) -> None: self._listener = asyncio.create_task(self._pubsub_listener()) + self._listener.add_done_callback(self.drop) await self._pubsub.connect() async def disconnect(self) -> None: @@ -27,6 +28,14 @@ async def disconnect(self) -> None: if self._listener is not None: self._listener.cancel() + def drop(self, task: asyncio.Task[None]) -> None: + try: + exc = task.exception() + except asyncio.CancelledError: + pass + else: + self._queue.put_nowait(exc) + async def subscribe(self, channel: str) -> None: self._ready.set() await self._pubsub.subscribe(channel) @@ -38,7 +47,12 @@ async def publish(self, channel: str, message: typing.Any) -> None: await self._conn.publish(channel, message) async def next_published(self) -> Event: - return await self._queue.get() + result = await self._queue.get() + if result is None: + raise RuntimeError + if isinstance(result, BaseException): + raise result + return result async def _pubsub_listener(self) -> None: # redis-py does not listen to the pubsub connection if there are no channels subscribed diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index a8bd3eb..113b134 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -4,6 +4,7 @@ import typing import pytest +import redis from broadcaster import Broadcast, BroadcastBackend, Event from broadcaster.backends.kafka import KafkaBackend @@ -56,6 +57,45 @@ async def test_redis(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_redis_server_disconnect(): + with pytest.raises(redis.ConnectionError) as exc: + async with Broadcast("redis://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + await broadcast._backend._conn.connection_pool.aclose() # type: ignore[attr-defined] + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + await subscriber.get() + assert False + + assert exc.value.args == ("Connection closed by server.",) + + +@pytest.mark.asyncio +async def test_redis_does_not_log_loop_error_messages_if_subscribing(caplog): + async with Broadcast("redis://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + + assert caplog.messages == [] + + +@pytest.mark.asyncio +async def test_redis_does_not_log_loop_error_messages_if_not_subscribing(caplog): + async with Broadcast("redis://localhost:6379") as broadcast: + await broadcast.publish("chatroom", "hello") + + # Give the loop an opportunity to catch any errors before checking + # the logs. + await asyncio.sleep(0.1) + assert caplog.messages == [] + + @pytest.mark.asyncio async def test_redis_stream(): async with Broadcast("redis-stream://localhost:6379") as broadcast: