From 6fd3b75652dfb60b9032236bb9ab37ad70a5fe7b Mon Sep 17 00:00:00 2001 From: Me Date: Wed, 1 Jan 2025 09:37:38 -0700 Subject: [PATCH 1/3] =?UTF-8?q?=E2=9C=A8=20Add=20`RedisPydanticStream`=20b?= =?UTF-8?q?ackend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- broadcaster/_base.py | 5 +++ broadcaster/backends/redis.py | 57 +++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + requirements.txt | 2 +- tests/test_broadcast.py | 27 +++++++++++++++++ 5 files changed, 91 insertions(+), 1 deletion(-) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index a63b22b..1d1de35 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -43,6 +43,11 @@ def _create_backend(self, url: str) -> BroadcastBackend: return RedisStreamBackend(url) + elif parsed_url.scheme == "redis-pydantic-stream": + from broadcaster.backends.redis import RedisPydanticStreamBackend + + return RedisPydanticStreamBackend(url) + elif parsed_url.scheme in ("postgres", "postgresql"): from broadcaster.backends.postgres import PostgresBackend diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index 1be4195..fa3535d 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -1,9 +1,12 @@ from __future__ import annotations import asyncio +import inspect +import sys import typing from redis import asyncio as redis +from pydantic import BaseModel from .._base import Event from .base import BroadcastBackend @@ -108,3 +111,57 @@ async def next_published(self) -> Event: channel=stream.decode("utf-8"), message=message.get(b"message", b"").decode("utf-8"), ) + + +class RedisPydanticStreamBackend(RedisStreamBackend): + """Redis Stream backend for broadcasting messages using Pydantic models.""" + + def __init__(self: typing.Self, url: str) -> None: + """Create a new Redis Stream backend.""" + url = url.replace("redis-pydantic-stream", "redis", 1) + self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {} + self._ready = asyncio.Event() + self._producer = redis.Redis.from_url(url) + self._consumer = redis.Redis.from_url(url) + self._module_cache: dict[str, type(BaseModel)] = {} + + def _build_module_cache(self: typing.Self) -> None: + """Build a cache of Pydantic models.""" + modules = list(sys.modules.keys()) + for module_name in modules: + for _, obj in inspect.getmembers(sys.modules[module_name]): + if inspect.isclass(obj) and issubclass(obj, BaseModel): + self._module_cache[obj.__name__] = obj + + async def publish(self: typing.Self, channel: str, message: BaseModel) -> None: + """Publish a message to a channel.""" + msg_type: str = message.__class__.__name__ + message_json: str = message.model_dump_json() + await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json}) + + async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: + """Wait for messages to be published.""" + await self._ready.wait() + self._build_module_cache() + messages = None + while not messages: + messages = await self._consumer.xread(self.streams, count=1, block=100) + return messages + + async def next_published(self: typing.Self) -> Event | None: + """Get the next published message.""" + messages = await self.wait_for_messages() + stream, events = messages[0] + _msg_id, message = events[0] + self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8") + msg_type = message.get(b"msg_type", b"").decode("utf-8") + message_data = message.get(b"message", b"").decode("utf-8") + message_obj: BaseModel | None = None + if msg_type in self._module_cache: + message_obj = self._module_cache[msg_type].model_validate_json(message_data) + if not message_obj: + return None + return Event( + channel=stream.decode("utf-8"), + message=message_obj, + ) diff --git a/pyproject.toml b/pyproject.toml index c4e8036..042e703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ redis = ["redis"] postgres = ["asyncpg"] kafka = ["aiokafka"] +pydantic = ["pydantic", "redis"] test = ["pytest", "pytest-asyncio"] [project.urls] diff --git a/requirements.txt b/requirements.txt index ed2926b..b50e044 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ --e .[redis,postgres,kafka] +-e .[redis,postgres,kafka,pydantic] # Documentation mkdocs==1.5.3 diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index a8bd3eb..cb07fb2 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -5,10 +5,17 @@ import pytest +from pydantic import BaseModel + from broadcaster import Broadcast, BroadcastBackend, Event from broadcaster.backends.kafka import KafkaBackend +class PydanticEvent(BaseModel): + event: str + data: str + + class CustomBackend(BroadcastBackend): def __init__(self, url: str): self._subscribed: set[str] = set() @@ -71,6 +78,26 @@ async def test_redis_stream(): assert event.message == "hello" +@pytest.mark.asyncio +async def test_redis_pydantic_stream(): + async with Broadcast("redis-pydantic-stream://localhost:6379") as broadcast: + async with broadcast.subscribe("chatroom") as subscriber: + message = PydanticEvent(event="on_message", data="hello") + await broadcast.publish("chatroom", message) + event = await subscriber.get() + assert event.channel == "chatroom" + assert isinstance(event.message, PydanticEvent) + assert event.message.event == message.event + assert event.message.data == message.data + async with broadcast.subscribe("chatroom1") as subscriber: + await broadcast.publish("chatroom1", message) + event = await subscriber.get() + assert event.channel == "chatroom1" + assert isinstance(event.message, PydanticEvent) + assert event.message.event == message.event + assert event.message.data == message.data + + @pytest.mark.asyncio async def test_postgres(): async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast: From 1086033e46f09549ae854c42c09630b4ad56f722 Mon Sep 17 00:00:00 2001 From: Me Date: Wed, 1 Jan 2025 11:09:38 -0700 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=9A=A8=20Fixing=20mypy's=20complaints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- broadcaster/_base.py | 4 +++- broadcaster/backends/redis.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/broadcaster/_base.py b/broadcaster/_base.py index 1d1de35..8c4b42a 100644 --- a/broadcaster/_base.py +++ b/broadcaster/_base.py @@ -5,12 +5,14 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast from urllib.parse import urlparse +from pydantic import BaseModel + if TYPE_CHECKING: # pragma: no cover from broadcaster.backends.base import BroadcastBackend class Event: - def __init__(self, channel: str, message: str) -> None: + def __init__(self, channel: str, message: str | BaseModel) -> None: self.channel = channel self.message = message diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index fa3535d..48e09d6 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -116,16 +116,16 @@ async def next_published(self) -> Event: class RedisPydanticStreamBackend(RedisStreamBackend): """Redis Stream backend for broadcasting messages using Pydantic models.""" - def __init__(self: typing.Self, url: str) -> None: + def __init__(self, url: str) -> None: """Create a new Redis Stream backend.""" url = url.replace("redis-pydantic-stream", "redis", 1) self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {} self._ready = asyncio.Event() self._producer = redis.Redis.from_url(url) self._consumer = redis.Redis.from_url(url) - self._module_cache: dict[str, type(BaseModel)] = {} + self._module_cache: dict[str, type[BaseModel]] = {} - def _build_module_cache(self: typing.Self) -> None: + def _build_module_cache(self) -> None: """Build a cache of Pydantic models.""" modules = list(sys.modules.keys()) for module_name in modules: @@ -133,13 +133,17 @@ def _build_module_cache(self: typing.Self) -> None: if inspect.isclass(obj) and issubclass(obj, BaseModel): self._module_cache[obj.__name__] = obj - async def publish(self: typing.Self, channel: str, message: BaseModel) -> None: + async def publish(self, channel: str, message: BaseModel) -> None: """Publish a message to a channel.""" msg_type: str = message.__class__.__name__ + + if msg_type not in self._module_cache: + self._module_cache[msg_type] = message.__class__ + message_json: str = message.model_dump_json() await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json}) - async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: + async def wait_for_messages(self) -> list[StreamMessageType]: """Wait for messages to be published.""" await self._ready.wait() self._build_module_cache() @@ -148,7 +152,7 @@ async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: messages = await self._consumer.xread(self.streams, count=1, block=100) return messages - async def next_published(self: typing.Self) -> Event | None: + async def next_published(self) -> Event: """Get the next published message.""" messages = await self.wait_for_messages() stream, events = messages[0] @@ -160,7 +164,7 @@ async def next_published(self: typing.Self) -> Event | None: if msg_type in self._module_cache: message_obj = self._module_cache[msg_type].model_validate_json(message_data) if not message_obj: - return None + return Event(stream.decode("utf-8"), message_data) return Event( channel=stream.decode("utf-8"), message=message_obj, From bf99ea783acf0ee988c34221216d579b9525b066 Mon Sep 17 00:00:00 2001 From: Me Date: Wed, 1 Jan 2025 11:12:52 -0700 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=9A=A8=20Fixing=20ruff's=20complaints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- broadcaster/backends/redis.py | 2 +- tests/test_broadcast.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py index 48e09d6..a26d3ad 100644 --- a/broadcaster/backends/redis.py +++ b/broadcaster/backends/redis.py @@ -5,8 +5,8 @@ import sys import typing -from redis import asyncio as redis from pydantic import BaseModel +from redis import asyncio as redis from .._base import Event from .base import BroadcastBackend diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index cb07fb2..4cc1498 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -4,7 +4,6 @@ import typing import pytest - from pydantic import BaseModel from broadcaster import Broadcast, BroadcastBackend, Event