From c3caaddaf3ba657346442f7d60dad81bfff767b9 Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Wed, 23 Oct 2024 10:50:17 +0200 Subject: [PATCH] allow preconfigured redis clients --- broadcaster/backends/redis.py | 15 +++++++++++---- tests/test_broadcast.py | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index 1be4195..effb166 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -10,8 +10,15 @@ class RedisBackend(BroadcastBackend): - def __init__(self, url: str): - self._conn = redis.Redis.from_url(url) + _conn: redis.Redis + + def __init__(self, url: str | None = None, *, conn: redis.Redis | None = None): + if url is None: + assert conn is not None, "conn must be provided if url is not" + self._conn = conn + else: + self._conn = redis.Redis.from_url(url) + self._pubsub = self._conn.pubsub() self._ready = asyncio.Event() self._queue: asyncio.Queue[Event] = asyncio.Queue() @@ -19,10 +26,10 @@ def __init__(self, url: str): async def connect(self) -> None: self._listener = asyncio.create_task(self._pubsub_listener()) - await self._pubsub.connect() + await self._pubsub.connect() # type: ignore[no-untyped-call] async def disconnect(self) -> None: - await self._pubsub.aclose() + await self._pubsub.aclose() # type: ignore[no-untyped-call] await self._conn.aclose() if self._listener is not None: self._listener.cancel() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index a8bd3eb..b88b317 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -4,9 +4,11 @@ import typing import pytest +from redis import asyncio as redis from broadcaster import Broadcast, BroadcastBackend, Event from broadcaster.backends.kafka import KafkaBackend +from broadcaster.backends.redis import RedisBackend class CustomBackend(BroadcastBackend): @@ -56,6 +58,23 @@ async def test_redis(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_redis_configured_client(): + backend = RedisBackend(conn=redis.Redis.from_url("redis://localhost:6379")) + async with Broadcast(backend=backend) as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + await broadcast.publish("chatroom", "hello") + event = await subscriber.get() + assert event.channel == "chatroom" + assert event.message == "hello" + + +@pytest.mark.asyncio +async def test_redis_requires_url_or_connection(): + with pytest.raises(AssertionError, match="conn must be provided if url is not"): + RedisBackend() + + @pytest.mark.asyncio async def test_redis_stream(): async with Broadcast("redis-stream://localhost:6379") as broadcast: