diff --git a/broadcaster/_base.py b/broadcaster/_base.py
index a63b22b..8c4b42a 100644
--- a/broadcaster/_base.py
+++ b/broadcaster/_base.py
@@ -5,12 +5,14 @@
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast
 from urllib.parse import urlparse
 
+from pydantic import BaseModel
+
 if TYPE_CHECKING:  # pragma: no cover
     from broadcaster.backends.base import BroadcastBackend
 
 
 class Event:
-    def __init__(self, channel: str, message: str) -> None:
+    def __init__(self, channel: str, message: str | BaseModel) -> None:
         self.channel = channel
         self.message = message
 
@@ -43,6 +45,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:
 
             return RedisStreamBackend(url)
 
+        elif parsed_url.scheme == "redis-pydantic-stream":
+            from broadcaster.backends.redis import RedisPydanticStreamBackend
+
+            return RedisPydanticStreamBackend(url)
+
         elif parsed_url.scheme in ("postgres", "postgresql"):
             from broadcaster.backends.postgres import PostgresBackend
 
diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py
index 1be4195..a26d3ad 100644
--- a/broadcaster/backends/redis.py
+++ b/broadcaster/backends/redis.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
+import sys
 import typing
 
+from pydantic import BaseModel
 from redis import asyncio as redis
 
 from .._base import Event
@@ -108,3 +111,64 @@ async def next_published(self) -> Event:
             channel=stream.decode("utf-8"),
             message=message.get(b"message", b"").decode("utf-8"),
         )
+
+
+class RedisPydanticStreamBackend(RedisStreamBackend):
+    """Redis Stream backend for broadcasting messages using Pydantic models."""
+
+    def __init__(self, url: str) -> None:
+        """Create a new Redis Stream backend."""
+        url = url.replace("redis-pydantic-stream", "redis", 1)
+        self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
+        self._ready = asyncio.Event()
+        self._producer = redis.Redis.from_url(url)
+        self._consumer = redis.Redis.from_url(url)
+        self._module_cache: dict[str, type[BaseModel]] = {}
+
+    def _build_module_cache(self) -> None:
+        """Build a cache of importable Pydantic model classes, keyed by class name."""
+        modules = list(sys.modules.keys())
+        for module_name in modules:
+            # sys.modules may mutate during iteration and may contain None
+            # placeholders; never index it directly.
+            module = sys.modules.get(module_name)
+            if module is None:
+                continue
+            for _, obj in inspect.getmembers(module):
+                if inspect.isclass(obj) and issubclass(obj, BaseModel):
+                    self._module_cache[obj.__name__] = obj
+
+    async def publish(self, channel: str, message: BaseModel) -> None:
+        """Publish a message to a channel."""
+        msg_type: str = message.__class__.__name__
+
+        if msg_type not in self._module_cache:
+            self._module_cache[msg_type] = message.__class__
+
+        message_json: str = message.model_dump_json()
+        await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json})
+
+    async def wait_for_messages(self) -> list[StreamMessageType]:
+        """Wait for messages to be published."""
+        await self._ready.wait()
+        self._build_module_cache()  # refresh so models imported after startup are seen
+        messages = None
+        while not messages:
+            messages = await self._consumer.xread(self.streams, count=1, block=100)
+        return messages
+
+    async def next_published(self) -> Event:
+        """Get the next published message, revalidated into its Pydantic model."""
+        messages = await self.wait_for_messages()
+        stream, events = messages[0]
+        _msg_id, message = events[0]
+        self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8")
+        msg_type = message.get(b"msg_type", b"").decode("utf-8")
+        message_data = message.get(b"message", b"").decode("utf-8")
+        message_obj: BaseModel | None = None
+        if msg_type in self._module_cache:
+            message_obj = self._module_cache[msg_type].model_validate_json(message_data)
+        if message_obj is None:  # unknown model type: fall back to the raw JSON string
+            return Event(stream.decode("utf-8"), message_data)
+        return Event(
+            channel=stream.decode("utf-8"),
+            message=message_obj,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index c4e8036..042e703 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
 redis = ["redis"]
 postgres = ["asyncpg"]
 kafka = ["aiokafka"]
+pydantic = ["pydantic", "redis"]
 test = ["pytest", "pytest-asyncio"]
 
 [project.urls]
diff --git a/requirements.txt b/requirements.txt
index ed2926b..b50e044 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
--e .[redis,postgres,kafka]
+-e .[redis,postgres,kafka,pydantic]
 
 # Documentation
 mkdocs==1.5.3
diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py
index a8bd3eb..4cc1498 100644
--- a/tests/test_broadcast.py
+++ b/tests/test_broadcast.py
@@ -4,11 +4,17 @@
 import typing
 
 import pytest
+from pydantic import BaseModel
 
 from broadcaster import Broadcast, BroadcastBackend, Event
 from broadcaster.backends.kafka import KafkaBackend
 
 
+class PydanticEvent(BaseModel):
+    event: str
+    data: str
+
+
 class CustomBackend(BroadcastBackend):
     def __init__(self, url: str):
         self._subscribed: set[str] = set()
@@ -71,6 +77,26 @@ async def test_redis_stream():
         assert event.message == "hello"
 
 
+@pytest.mark.asyncio
+async def test_redis_pydantic_stream():
+    async with Broadcast("redis-pydantic-stream://localhost:6379") as broadcast:
+        async with broadcast.subscribe("chatroom") as subscriber:
+            message = PydanticEvent(event="on_message", data="hello")
+            await broadcast.publish("chatroom", message)
+            event = await subscriber.get()
+            assert event.channel == "chatroom"
+            assert isinstance(event.message, PydanticEvent)
+            assert event.message.event == message.event
+            assert event.message.data == message.data
+        async with broadcast.subscribe("chatroom1") as subscriber:
+            await broadcast.publish("chatroom1", message)
+            event = await subscriber.get()
+            assert event.channel == "chatroom1"
+            assert isinstance(event.message, PydanticEvent)
+            assert event.message.event == message.event
+            assert event.message.data == message.data
+
+
 @pytest.mark.asyncio
 async def test_postgres():
     async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast: