From f74cce954b54a6a3d5b9e12c640b899b5905a615 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 16:38:17 +0300 Subject: [PATCH 1/5] Refactoring the NNG support solution v1 --- taskiq/brokers/nng/__init__.py | 24 + taskiq/brokers/nng/broker.py | 328 ++++++++++++++ taskiq/brokers/nng/hub.py | 482 +++++++++++++++++++++ taskiq/brokers/nng/protocol.py | 159 +++++++ taskiq/brokers/nng/storage.py | 722 +++++++++++++++++++++++++++++++ taskiq/brokers/nng_broker.py | 48 -- tests/brokers/test_nng_broker.py | 576 ++++++++++++++++++++++++ 7 files changed, 2291 insertions(+), 48 deletions(-) create mode 100644 taskiq/brokers/nng/__init__.py create mode 100644 taskiq/brokers/nng/broker.py create mode 100644 taskiq/brokers/nng/hub.py create mode 100644 taskiq/brokers/nng/protocol.py create mode 100644 taskiq/brokers/nng/storage.py delete mode 100644 taskiq/brokers/nng_broker.py create mode 100644 tests/brokers/test_nng_broker.py diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py new file mode 100644 index 00000000..0d0a2946 --- /dev/null +++ b/taskiq/brokers/nng/__init__.py @@ -0,0 +1,24 @@ +from hub import HubConfig, NNGHub +from protocol import ( + ControlMessage, + ControlResponse, + MessageKind, + TaskEnvelope, + WorkerState, + WorkerStatus, +) +from storage import QueueFullError, SQLiteJournal, StoreConfig + +__all__ = [ + 'HubConfig', + 'NNGHub', + 'ControlMessage', + 'ControlResponse', + 'MessageKind', + 'TaskEnvelope', + 'WorkerState', + 'WorkerStatus', + 'QueueFullError', + 'SQLiteJournal', + 'StoreConfig', +] diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py new file mode 100644 index 00000000..6961cbeb --- /dev/null +++ b/taskiq/brokers/nng/broker.py @@ -0,0 +1,328 @@ +"""NNG broker for taskiq — backed by a standalone :class:`NNGHub`.""" +from __future__ import annotations + +import asyncio +import base64 +import logging +import os +import tempfile +import time +import uuid +from collections.abc import AsyncGenerator, Callable +from contextlib import suppress +from typing import Any, TypeVar + +from taskiq.abc.broker import AsyncBroker +from taskiq.abc.result_backend import AsyncResultBackend +from taskiq.acks import AckableMessage +from taskiq.message import BrokerMessage + +from protocol import ( + ControlMessage, + ControlResponse, + TaskEnvelope, + WorkerState, + WorkerStatus, +) + +try: + import pynng # type: ignore +except ImportError: + pynng = None # type: ignore[assignment] + +_T = TypeVar("_T") + +logger = logging.getLogger(__name__) + + +def _ipc_addr(prefix: str = "taskiq-nng") -> str: + name = f"{prefix}-{os.getpid()}-{uuid.uuid4().hex[:8]}.ipc" + return f"ipc://{os.path.join(tempfile.gettempdir(), name)}" + + +class NNGBroker(AsyncBroker): + """ + Taskiq broker backed by a standalone :class:`~taskiq.brokers.nng_hub.NNGHub`. + + The hub must be running before workers or clients start. Launch it with:: + + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc + + **Client mode** (``is_worker_process = False``) + Only the control socket is opened. :meth:`kick` submits tasks to the + hub via a Req0 → Rep0 round-trip. + + **Worker mode** (``is_worker_process = True``) + In addition to the control socket the broker opens a unique Pull0 + socket, registers with the hub, and runs a heartbeat loop. + :meth:`listen` yields :class:`~taskiq.acks.AckableMessage` instances + whose ``ack`` callback sends the correct ``lease_id`` back to the hub. + + Thread / coroutine safety + ───────────────────────── + ``Req0`` is strictly serial (one request in-flight per socket). + ``_ctrl_lock`` serialises all :meth:`_send_control` calls so that + concurrent coroutines (heartbeat + ack + kick) never interleave frames. + + Ack correctness + ─────────────── + The hub embeds the dispatch-generated ``lease_id`` inside every + :class:`~taskiq.brokers.nng_protocol.TaskEnvelope`. The ack closure + captures it directly, so validation on the hub side always succeeds for + genuine acks and correctly rejects late/duplicate ones. + """ + + def __init__( + self, + control_addr: str, + *, + result_backend: "AsyncResultBackend[_T] | None" = None, + task_id_generator: Callable[[], str] | None = None, + worker_task_addr: str | None = None, + worker_id: str | None = None, + heartbeat_interval: float = 5.0, + lease_timeout: float = 20.0, + capacity: int = 1, + max_retries: int = 0, + retry_backoff: float = 1.0, + retry_jitter: float = 0.0, + recv_timeout_ms: int = 5_000, + send_timeout_ms: int = 5_000, + ) -> None: + """ + Initialise the NNG broker. + + :param control_addr: NNG address of the hub's Rep0 control socket. + :param result_backend: optional result backend. + :param task_id_generator: optional task ID generator. + :param worker_task_addr: NNG address this worker's Pull0 listens on. + Defaults to a unique per-process IPC path. + :param worker_id: stable identifier for this worker process. + Defaults to ``-``. + :param heartbeat_interval: seconds between heartbeat messages to hub. + :param lease_timeout: seconds a dispatched task lease remains valid. + :param capacity: max concurrent tasks this worker will accept. + :param max_retries: default max retries for submitted tasks. + :param retry_backoff: base seconds for exponential backoff. + :param retry_jitter: jitter multiplier added to backoff (0 = no jitter). + :param recv_timeout_ms: Req0 recv timeout in milliseconds. + :param send_timeout_ms: Req0 send timeout in milliseconds. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGBroker. " + "Install it with: pip install taskiq[nng]", + ) + super().__init__( + result_backend=result_backend, + task_id_generator=task_id_generator, + ) + self.control_addr = control_addr + self.worker_task_addr = worker_task_addr or _ipc_addr() + self.worker_id = worker_id or f"{os.getpid()}-{uuid.uuid4().hex[:12]}" + self.heartbeat_interval = heartbeat_interval + self.lease_timeout = lease_timeout + self.capacity = capacity + self.max_retries = max_retries + self.retry_backoff = retry_backoff + self.retry_jitter = retry_jitter + self.recv_timeout_ms = recv_timeout_ms + self.send_timeout_ms = send_timeout_ms + + self._ctrl_sock: Any = None # pynng.Req0 + self._task_sock: Any = None # pynng.Pull0 (worker mode only) + self._heartbeat_task: asyncio.Task[None] | None = None + # Req0 allows exactly one request in-flight; this lock enforces that. + self._ctrl_lock = asyncio.Lock() + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def startup(self) -> None: + """Open sockets, register with hub (worker mode), and start heartbeat.""" + self._ctrl_sock = pynng.Req0( + dial=self.control_addr, + recv_timeout=self.recv_timeout_ms, + send_timeout=self.send_timeout_ms, + ) + if self.is_worker_process: + # recv_buffer_size lets the hub pre-queue up to `capacity` task + # messages in NNG's recv buffer before listen() calls arecv(). + self._task_sock = pynng.Pull0( + listen=self.worker_task_addr, + recv_buffer_size=self.capacity, + ) + resp = await self._send_control( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.worker_task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "taskiq-nng", + }, + ) + if not resp.ok: + raise RuntimeError(f"Worker registration failed: {resp.error}") + logger.info( + "Worker %s registered with hub at %s", + self.worker_id, + self.control_addr, + ) + self._heartbeat_task = asyncio.create_task( + self._heartbeat_loop(), + name=f"nng-hb-{self.worker_id[:8]}", + ) + await super().startup() + + async def shutdown(self) -> None: + """Drain, unregister, cancel heartbeat, and close all sockets.""" + if self.is_worker_process: + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + with suppress(asyncio.CancelledError): + await self._heartbeat_task + if self._ctrl_sock is not None: + with suppress(Exception): + await self._send_control( + "drain", {"worker_id": self.worker_id}, + ) + await self._send_control( + "unregister", {"worker_id": self.worker_id}, + ) + if self._task_sock is not None: + with suppress(Exception): + self._task_sock.close() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + await super().shutdown() + + # ── internal helpers ────────────────────────────────────────────────────── + + async def _send_control( + self, kind: str, payload: dict[str, Any], + ) -> ControlResponse: + if self._ctrl_sock is None: + raise RuntimeError("Control socket is not open (call startup() first)") + async with self._ctrl_lock: + await self._ctrl_sock.asend( + ControlMessage(kind=kind, payload=payload).to_bytes(), + ) + raw = await self._ctrl_sock.arecv() + return ControlResponse.from_bytes(raw) + + async def _heartbeat_loop(self) -> None: + while True: + try: + await asyncio.sleep(self.heartbeat_interval) + resp = await self._send_control( + "heartbeat", {"worker_id": self.worker_id}, + ) + if not resp.ok: + logger.warning("Heartbeat rejected by hub: %s", resp.error) + except asyncio.CancelledError: + raise + except Exception as exc: + # Hub may be temporarily unreachable; log and keep trying. + logger.warning("Heartbeat failed: %s", exc) + + # ── AsyncBroker API ─────────────────────────────────────────────────────── + + async def kick(self, message: BrokerMessage) -> None: + """ + Submit a task to the hub for dispatch. + + :param message: broker message to submit. + :raises RuntimeError: if the broker has not been started or the hub + rejects the submission (e.g. queue full). + """ + if self._ctrl_sock is None: + raise RuntimeError("Broker is not started") + payload: dict[str, Any] = { + "task_id": message.task_id, + "task_name": message.task_name, + "payload_b64": base64.b64encode(message.message).decode("ascii"), + "labels": message.labels, + "lease_id": "", # hub assigns the real lease_id at dispatch time + "attempts": int(message.labels.get("attempts", 0)), + "max_retries": int( + message.labels.get("max_retries", self.max_retries), + ), + "retry_backoff": float( + message.labels.get("retry_backoff", self.retry_backoff), + ), + "retry_jitter": float( + message.labels.get("retry_jitter", self.retry_jitter), + ), + "priority": int(message.labels.get("priority", 0)), + "created_at": time.time(), + } + resp = await self._send_control("submit", payload) + if not resp.ok: + raise RuntimeError(resp.error or "task submission failed") + + async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: + """ + Yield incoming tasks as :class:`~taskiq.acks.AckableMessage` instances. + + Each message's ``ack`` callback sends the hub-issued ``lease_id`` back + so the hub can validate the ack and reject any late/duplicate ones. + + :raises RuntimeError: if called outside worker mode or before startup. + :yields: ackable task messages. + """ + if not self.is_worker_process: + raise RuntimeError("listen() is only valid in worker mode") + if self._task_sock is None: + raise RuntimeError("Task socket is not open (call startup() first)") + + while True: + try: + raw = await self._task_sock.arecv() + except pynng.Closed: + logger.info("Task socket closed; stopping listen()") + return + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Task arecv error: %s", exc) + continue + + try: + envelope = TaskEnvelope.from_bytes(raw) + except Exception as exc: + logger.error("Malformed task envelope discarded: %s", exc) + continue + + task_id = envelope.task_id + worker_id = self.worker_id + lease_id = envelope.lease_id # hub-assigned; correct by construction + + async def _ack( + _task_id: str = task_id, + _worker_id: str = worker_id, + _lease_id: str = lease_id, + ) -> None: + try: + resp = await self._send_control( + "ack", + { + "task_id": _task_id, + "worker_id": _worker_id, + "lease_id": _lease_id, + }, + ) + if not resp.ok: + logger.debug( + "Ack rejected for %s (late/duplicate): %s", + _task_id, resp.error, + ) + except Exception as exc: + logger.warning("Ack send failed for %s: %s", _task_id, exc) + + yield AckableMessage(data=envelope.payload, ack=_ack) diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py new file mode 100644 index 00000000..844055c5 --- /dev/null +++ b/taskiq/brokers/nng/hub.py @@ -0,0 +1,482 @@ +""" +NNG hub: central control plane, task dispatcher, and lease manager. + +Run as a standalone process:: + + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\ + --task-db /var/lib/taskiq/tasks.db + +Or embed it in an application for testing:: + + hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:")) + await hub.start() + ... + await hub.stop() +""" +from __future__ import annotations + +import argparse +import asyncio +import base64 +import json +import logging +import os +import signal +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress +from dataclasses import dataclass, field +from typing import Any + +try: + import pynng # type: ignore +except ImportError: + pynng = None # type: ignore[assignment] + +from protocol import ( + ControlMessage, + ControlResponse, + TaskEnvelope, + WorkerState, +) +from storage import QueueFullError, SQLiteJournal, StoreConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class HubConfig: + """Configuration for :class:`NNGHub`.""" + + control_addr: str + task_db: str + max_pending: int = 10_000 + heartbeat_timeout: float = 15.0 + lease_timeout: float = 20.0 + dispatch_interval: float = 0.05 + reaper_interval: float = 0.5 + routing_policy: str = "least_loaded" + backoff_cap: float = 60.0 + # Number of concurrent Rep0 contexts. Each context handles one req/rep + # pair independently; N contexts ≈ N simultaneous control-plane clients. + control_concurrency: int = 16 + dispatch_batch: int = 50 + # Per-context recv timeout in ms. Allows the stop event to be checked + # even when there are no incoming messages. + recv_timeout_ms: int = 1_000 + + +class NNGHub: + """ + Stateful central hub: control plane, task dispatcher, and lease manager. + + Architecture + ──────────── + **Control plane** — ``Rep0`` socket with ``control_concurrency`` + independent ``nng_ctx`` contexts running concurrently. Each context + handles one request-reply at a time, so N workers can + register/heartbeat/ack simultaneously without queuing behind each other. + This is the key fix over the single-context (serial) Rep0 in v2. + + **Data plane** — One ``Push0`` socket per registered worker, dialed to + the worker's own ``Pull0`` listen address. The hub explicitly targets + the least-loaded worker instead of relying on NNG round-robin, giving + us load-aware routing. + + **Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in + WAL mode. All storage calls are executed on a single-threaded + ``ThreadPoolExecutor`` so the asyncio event loop is never blocked and + SQLite write serialisation is guaranteed. + + **Recovery** — On startup, tasks leased to workers that died during the + previous hub session are automatically requeued. + """ + + def __init__(self, config: HubConfig) -> None: + """ + Initialise the hub with the given configuration. + + :param config: hub configuration. + """ + if pynng is None: + raise RuntimeError( + "pynng is required to use NNGHub. " + "Install it with: pip install taskiq[nng]" + ) + self.config = config + self.store = SQLiteJournal( + StoreConfig( + path=config.task_db, + max_pending=config.max_pending, + lease_timeout=config.lease_timeout, + backoff_cap=config.backoff_cap, + ), + ) + self._stop = asyncio.Event() + self._ctrl_sock: Any = None # pynng.Rep0 + self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 + self._tasks: list[asyncio.Task[None]] = [] + # Single-threaded executor: serialises all SQLite calls on one OS thread. + self._db_exec = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="nng-db" + ) + + # ── lifecycle ───────────────────────────────────────────────────────────── + + async def start(self) -> None: + """Start the hub: recover orphaned tasks, open sockets, spawn loops.""" + await self._db(self.store.recover_dead_workers, self.config.heartbeat_timeout) + + self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr) + self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms + + self._tasks = [ + asyncio.create_task(self._dispatch_loop(), name="hub-dispatch"), + asyncio.create_task(self._reaper_loop(), name="hub-reaper"), + ] + for i in range(self.config.control_concurrency): + ctx = self._ctrl_sock.new_context() + self._tasks.append( + asyncio.create_task( + self._control_handler(ctx), name=f"hub-ctrl-{i}" + ), + ) + logger.info( + "NNG hub started on %s (db=%s)", + self.config.control_addr, + self.config.task_db, + ) + + async def stop(self) -> None: + """Gracefully stop all hub loops and close sockets.""" + logger.info("NNG hub stopping…") + self._stop.set() + for t in self._tasks: + t.cancel() + with suppress(asyncio.CancelledError): + await t + for sock in self._worker_push.values(): + with suppress(Exception): + sock.close() + self._worker_push.clear() + if self._ctrl_sock is not None: + with suppress(Exception): + self._ctrl_sock.close() + self._db_exec.shutdown(wait=True) + logger.info("NNG hub stopped") + + # ── DB helper ───────────────────────────────────────────────────────────── + + async def _db(self, fn: Any, *args: Any, **kwargs: Any) -> Any: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self._db_exec, lambda: fn(*args, **kwargs) + ) + + # ── control plane ───────────────────────────────────────────────────────── + + async def _control_handler(self, ctx: Any) -> None: + """Run one Rep0 context: receive → dispatch → reply, in a loop.""" + while not self._stop.is_set(): + try: + raw = await ctx.arecv() + except pynng.Timeout: + continue + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control recv error: %s", exc) + continue + + try: + response = await self._handle(raw) + except Exception as exc: + logger.exception("Unhandled error in control handler") + response = ControlResponse(ok=False, error=str(exc)) + + try: + await ctx.asend(response.to_bytes()) + except (pynng.Closed, asyncio.CancelledError): + break + except Exception as exc: + logger.warning("Control send error: %s", exc) + + async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 + """Dispatch a raw control message to the appropriate handler.""" + msg = ControlMessage.from_bytes(raw) + + if msg.kind == "ping": + return ControlResponse(ok=True, payload={"pong": True}) + + if msg.kind == "submit": + return await self._handle_submit(msg.payload) + + if msg.kind == "register": + return await self._handle_register(msg.payload) + + if msg.kind == "heartbeat": + await self._db(self.store.heartbeat, msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"ok": True}) + + if msg.kind == "unregister": + return await self._handle_unregister(msg.payload["worker_id"]) + + if msg.kind == "drain": + await self._db(self.store.mark_draining, msg.payload["worker_id"]) + return ControlResponse(ok=True, payload={"draining": True}) + + if msg.kind == "ack": + ok = await self._db( + self.store.ack, + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + ) + return ControlResponse(ok=ok, payload={"acked": ok}) + + if msg.kind == "nack": + ok = await self._db( + self.store.nack, + msg.payload["task_id"], + msg.payload["worker_id"], + msg.payload["lease_id"], + msg.payload.get("error", "unknown error"), + ) + return ControlResponse(ok=ok, payload={"nacked": ok}) + + if msg.kind == "status": + task = await self._db(self.store.get_task, msg.payload["task_id"]) + return ControlResponse(ok=bool(task), payload=dict(task) if task else {}) + + if msg.kind == "stats": + s = await self._db(self.store.stats) + return ControlResponse(ok=True, payload=s) + + return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}") + + async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse: + envelope = TaskEnvelope(**payload) + try: + await self._db(self.store.submit, envelope) + return ControlResponse(ok=True, payload={"task_id": envelope.task_id}) + except QueueFullError: + return ControlResponse(ok=False, error="queue full") + + async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: + worker = WorkerState(**payload) + await self._db(self.store.register_worker, worker) + if worker.worker_id not in self._worker_push: + try: + sock = pynng.Push0(dial=worker.task_addr) + self._worker_push[worker.worker_id] = sock + except Exception as exc: + logger.error( + "Failed to dial worker %s at %s: %s", + worker.worker_id, worker.task_addr, exc, + ) + return ControlResponse(ok=False, error=f"dial failed: {exc}") + return ControlResponse(ok=True, payload={"registered": True}) + + async def _handle_unregister(self, worker_id: str) -> ControlResponse: + await self._db(self.store.unregister_worker, worker_id) + sock = self._worker_push.pop(worker_id, None) + if sock is not None: + with suppress(Exception): + sock.close() + return ControlResponse(ok=True, payload={"unregistered": True}) + + # ── dispatch loop ───────────────────────────────────────────────────────── + + async def _dispatch_loop(self) -> None: + while not self._stop.is_set(): + try: + sent = await self._dispatch_once() + if not sent: + await asyncio.sleep(self.config.dispatch_interval) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Dispatch loop error") + await asyncio.sleep(self.config.dispatch_interval) + + async def _dispatch_once(self) -> bool: + """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" + due = await self._db(self.store.due_tasks, self.config.dispatch_batch) + if not due: + return False + sent_any = False + for row in due: + worker = await self._db( + self.store.choose_worker, + self.config.routing_policy, + heartbeat_timeout=self.config.heartbeat_timeout, + ) + if worker is None: + return sent_any # no capacity; leave remaining tasks in queue + + worker_id = worker["worker_id"] + lease_id = uuid.uuid4().hex + lease_until = time.time() + self.config.lease_timeout + + if not await self._db( + self.store.mark_leased, + row["task_id"], worker_id, lease_id, lease_until, + ): + continue # concurrent dispatch race; task already taken + + sock = self._worker_push.get(worker_id) + if sock is None: + logger.warning( + "No push socket for worker %s, requeueing %s", + worker_id, row["task_id"], + ) + await self._db( + self.store.nack, + row["task_id"], worker_id, lease_id, "no socket", + ) + continue + + # Include the hub-generated lease_id so the worker can ack with + # the exact token. Omitting it was the core correctness bug in v2. + envelope = TaskEnvelope( + task_id=row["task_id"], + task_name=row["task_name"], + payload_b64=base64.b64encode(row["payload"]).decode("ascii"), + labels=json.loads(row["labels_json"]), + lease_id=lease_id, + attempts=int(row["attempts"]) + 1, + max_retries=int(row["max_retries"]), + retry_backoff=float(row["retry_backoff"]), + retry_jitter=float(row["retry_jitter"]), + priority=int(row["priority"]), + created_at=float(row["created_at"]), + ) + try: + await sock.asend(envelope.to_bytes()) + sent_any = True + except Exception as exc: + logger.warning( + "Failed to deliver %s to worker %s: %s", + row["task_id"], worker_id, exc, + ) + await self._db( + self.store.nack, + row["task_id"], worker_id, lease_id, + f"dispatch send failed: {exc}", + ) + return sent_any + + # ── reaper loop ─────────────────────────────────────────────────────────── + + async def _reaper_loop(self) -> None: + while not self._stop.is_set(): + try: + await asyncio.sleep(self.config.reaper_interval) + reaped = await self._db(self.store.reap_expired_leases) + if reaped: + logger.debug("Reaped %d expired leases", reaped) + recovered = await self._db( + self.store.recover_dead_workers, + self.config.heartbeat_timeout, + ) + if recovered: + logger.info("Requeued %d tasks from dead workers", recovered) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Reaper loop error") + + +# ── standalone CLI entry point ──────────────────────────────────────────────── + +def _build_config() -> HubConfig: + p = argparse.ArgumentParser( + description="taskiq-nng-hub — NNG task router, dispatcher, and lease manager", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--control-addr", + default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"), + help="NNG address the hub listens on. Env: NNG_CONTROL_ADDR", + ) + p.add_argument( + "--task-db", + default=os.getenv("NNG_TASK_DB", "/tmp/taskiq-nng-tasks.db"), # noqa: S108 + help="Path to the SQLite WAL task journal. Env: NNG_TASK_DB", + ) + p.add_argument( + "--max-pending", + type=int, + default=int(os.getenv("NNG_MAX_PENDING", "10000")), + ) + p.add_argument( + "--heartbeat-timeout", + type=float, + default=float(os.getenv("NNG_HEARTBEAT_TIMEOUT", "15.0")), + help="Seconds of silence before a worker is declared dead.", + ) + p.add_argument( + "--lease-timeout", + type=float, + default=float(os.getenv("NNG_LEASE_TIMEOUT", "20.0")), + help="Seconds before an unacked task lease is reaped.", + ) + p.add_argument( + "--routing-policy", + choices=["least_loaded", "p2c"], + default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"), + ) + p.add_argument( + "--control-concurrency", + type=int, + default=int(os.getenv("NNG_CONTROL_CONCURRENCY", "16")), + help="Number of concurrent Rep0 contexts.", + ) + p.add_argument( + "--log-level", + default=os.getenv("NNG_LOG_LEVEL", "INFO"), + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + ) + args = p.parse_args() + logging.basicConfig( + level=getattr(logging, args.log_level), + format="%(asctime)s %(name)-24s %(levelname)-8s %(message)s", + ) + return HubConfig( + control_addr=args.control_addr, + task_db=args.task_db, + max_pending=args.max_pending, + heartbeat_timeout=args.heartbeat_timeout, + lease_timeout=args.lease_timeout, + routing_policy=args.routing_policy, + control_concurrency=args.control_concurrency, + ) + + +async def _run(config: HubConfig) -> None: + hub = NNGHub(config) + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def _on_signal() -> None: + logger.info("Shutdown signal received") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, _on_signal) + + await hub.start() + try: + await stop_event.wait() + finally: + await hub.stop() + + +def main() -> None: + """Entry point for the ``taskiq-nng-hub`` CLI command.""" + config = _build_config() + try: + asyncio.run(_run(config)) + except KeyboardInterrupt: + pass diff --git a/taskiq/brokers/nng/protocol.py b/taskiq/brokers/nng/protocol.py new file mode 100644 index 00000000..9b0b4d8e --- /dev/null +++ b/taskiq/brokers/nng/protocol.py @@ -0,0 +1,159 @@ +"""Wire protocol types for the NNG broker.""" +from __future__ import annotations + +import base64 +import enum +import json +from dataclasses import asdict, dataclass, field +from typing import Any + + +class _StrValue(str, enum.Enum): + """Base for string enums whose str() returns the plain value (Python 3.10+).""" + + def __str__(self) -> str: + return self.value + + +class MessageKind(_StrValue): + """Kinds of control-plane messages sent between broker/client and hub.""" + + SUBMIT = "submit" + REGISTER = "register" + HEARTBEAT = "heartbeat" + UNREGISTER = "unregister" + DRAIN = "drain" + ACK = "ack" + NACK = "nack" + STATUS = "status" + STATS = "stats" + PING = "ping" + + +class TaskState(_StrValue): + """Lifecycle state of a task in the hub store.""" + + READY = "ready" + LEASED = "leased" + DONE = "done" + FAILED = "failed" + + +class WorkerStatus(_StrValue): + """Lifecycle status of a registered worker.""" + + STARTING = "starting" + LISTENING = "listening" + DRAINING = "draining" + OFFLINE = "offline" + DEAD = "dead" + + +@dataclass +class TaskEnvelope: + """ + Task payload sent from hub to worker over the data plane. + + ``lease_id`` is the UUID assigned by the hub at dispatch time. + Workers must echo it back in the ACK so the hub can validate + that the ack is not stale (e.g. after lease expiry and requeue). + """ + + task_id: str + task_name: str + payload_b64: str + labels: dict[str, Any] = field(default_factory=dict) + lease_id: str = "" + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = 0.0 + + @property + def payload(self) -> bytes: + """Decode the base-64 task payload.""" + return base64.b64decode(self.payload_b64.encode("ascii")) + + @classmethod + def from_bytes(cls, raw: bytes) -> TaskEnvelope: + """Deserialise from JSON bytes.""" + return cls(**json.loads(raw.decode("utf-8"))) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + asdict(self), separators=(",", ":"), ensure_ascii=False + ).encode("utf-8") + + +@dataclass +class ControlMessage: + """Request sent over the control plane (Req0 → Rep0).""" + + kind: str + payload: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlMessage: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls(kind=data["kind"], payload=data.get("payload", {})) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"kind": self.kind, "payload": self.payload}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +@dataclass +class ControlResponse: + """Response sent back over the control plane (Rep0 → Req0).""" + + ok: bool + payload: dict[str, Any] = field(default_factory=dict) + error: str | None = None + + @classmethod + def from_bytes(cls, raw: bytes) -> ControlResponse: + """Deserialise from JSON bytes.""" + data = json.loads(raw.decode("utf-8")) + return cls( + ok=data["ok"], + payload=data.get("payload", {}), + error=data.get("error"), + ) + + def to_bytes(self) -> bytes: + """Serialise to JSON bytes.""" + return json.dumps( + {"ok": self.ok, "payload": self.payload, "error": self.error}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + + +@dataclass +class WorkerState: + """Snapshot of a worker's identity and capacity at registration time.""" + + worker_id: str + task_addr: str + capacity: int + inflight: int = 0 + last_seen: float = 0.0 + heartbeat_interval: float = 5.0 + lease_timeout: float = 15.0 + draining: bool = False + status: WorkerStatus = WorkerStatus.STARTING + version: str = "unknown" + + def to_dict(self) -> dict[str, Any]: + """Convert to a plain dict, serialising the status enum to its string value.""" + d = asdict(self) + d["status"] = str(self.status) + return d diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py new file mode 100644 index 00000000..410a7064 --- /dev/null +++ b/taskiq/brokers/nng/storage.py @@ -0,0 +1,722 @@ +"""Durable WAL-mode SQLite task journal for the NNG hub.""" +from __future__ import annotations + +import json +import random +import sqlite3 +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Generator + +from protocol import TaskEnvelope, WorkerState, WorkerStatus + + +@dataclass +class StoreConfig: + """Configuration for the SQLite task journal.""" + + path: str + max_pending: int = 10_000 + lease_timeout: float = 30.0 + backoff_base: float = 1.0 + backoff_cap: float = 60.0 + + +class QueueFullError(RuntimeError): + """Raised when a submission is attempted on a full queue.""" + + +class SQLiteJournal: + """ + Thread-safe, WAL-mode SQLite task store. + + Design notes + ──────────── + * Every method opens and closes its own connection. WAL allows concurrent + readers without blocking; SQLite serialises concurrent writers internally, + and the Python-level ``_submit_lock`` prevents the TOCTOU race in + :meth:`submit`. + * The hub runs every call through a single-threaded + ``ThreadPoolExecutor`` so, in practice, writes never contend at the + OS level either. + * ``PRAGMA`` settings (WAL, synchronous, busy_timeout) are applied per + connection because each ``sqlite3.connect()`` call starts with defaults. + """ + + def __init__(self, config: StoreConfig) -> None: + """Initialise the journal and create schema if not present.""" + self.config = config + # Guards only the pending_count check + INSERT pair in submit() to + # prevent concurrent callers from racing past max_pending. + self._submit_lock = threading.Lock() + self._init() + + # ── connection ──────────────────────────────────────────────────────────── + + @contextmanager + def _conn(self) -> Generator[sqlite3.Connection, None, None]: + conn = sqlite3.connect( + self.config.path, + timeout=10.0, + check_same_thread=False, + isolation_level=None, # we manage transactions explicitly + ) + conn.row_factory = sqlite3.Row + # Must be set per-connection, not just once at schema creation. + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") # safe with WAL; faster than FULL + conn.execute("PRAGMA busy_timeout=5000") # wait up to 5s before SQLITE_BUSY + conn.execute("PRAGMA cache_size=-32000") # 32 MB page cache + try: + yield conn + finally: + conn.close() + + def _init(self) -> None: + Path(self.config.path).parent.mkdir(parents=True, exist_ok=True) + with self._conn() as conn: + conn.executescript(""" + CREATE TABLE IF NOT EXISTS tasks ( + task_id TEXT PRIMARY KEY, + task_name TEXT NOT NULL, + payload BLOB NOT NULL, + labels_json TEXT NOT NULL DEFAULT '{}', + state TEXT NOT NULL, + attempts INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 0, + retry_backoff REAL NOT NULL DEFAULT 1.0, + retry_jitter REAL NOT NULL DEFAULT 0.0, + priority INTEGER NOT NULL DEFAULT 0, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + next_run_at REAL NOT NULL, + lease_id TEXT, + leased_worker_id TEXT, + lease_until REAL, + last_error TEXT + ); + + CREATE TABLE IF NOT EXISTS workers ( + worker_id TEXT PRIMARY KEY, + task_addr TEXT NOT NULL, + capacity INTEGER NOT NULL, + inflight INTEGER NOT NULL DEFAULT 0, + last_seen REAL NOT NULL DEFAULT 0, + heartbeat_interval REAL NOT NULL DEFAULT 5.0, + lease_timeout REAL NOT NULL DEFAULT 15.0, + draining INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL, + version TEXT NOT NULL DEFAULT 'unknown' + ); + + CREATE TABLE IF NOT EXISTS journal ( + seq INTEGER PRIMARY KEY AUTOINCREMENT, + ts REAL NOT NULL, + kind TEXT NOT NULL, + payload_json TEXT NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_tasks_dispatch + ON tasks (state, next_run_at, priority DESC); + CREATE INDEX IF NOT EXISTS idx_tasks_lease + ON tasks (state, lease_until); + CREATE INDEX IF NOT EXISTS idx_workers_active + ON workers (status, draining, last_seen); + """) + + # ── helpers ─────────────────────────────────────────────────────────────── + + def _journal( + self, + conn: sqlite3.Connection, + kind: str, + payload: dict[str, Any], + ) -> None: + conn.execute( + "INSERT INTO journal (ts, kind, payload_json) VALUES (?, ?, ?)", + ( + time.time(), + kind, + json.dumps(payload, separators=(",", ":"), ensure_ascii=False), + ), + ) + + def _backoff(self, attempts: int, backoff_base: float) -> float: + return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + + # ── task lifecycle ──────────────────────────────────────────────────────── + + def pending_count(self) -> int: + """Return the number of ready + leased tasks.""" + with self._conn() as conn: + return int( + conn.execute( + "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", + ).fetchone()[0], + ) + + def submit(self, envelope: TaskEnvelope) -> None: + """ + Persist a new task in 'ready' state. + + :param envelope: task envelope to store. + :raises QueueFullError: when ``max_pending`` is reached. + """ + now = time.time() + with self._submit_lock, self._conn() as conn: + count = conn.execute( + "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", + ).fetchone()[0] + if count >= self.config.max_pending: + raise QueueFullError("Task queue is full.") + conn.execute("BEGIN") + conn.execute( + """ + INSERT INTO tasks ( + task_id, task_name, payload, labels_json, state, + attempts, max_retries, retry_backoff, retry_jitter, + priority, created_at, updated_at, next_run_at + ) VALUES (?, ?, ?, ?, 'ready', 0, ?, ?, ?, ?, ?, ?, ?) + """, + ( + envelope.task_id, + envelope.task_name, + envelope.payload, + json.dumps( + envelope.labels, separators=(",", ":"), ensure_ascii=False + ), + envelope.max_retries, + envelope.retry_backoff, + envelope.retry_jitter, + envelope.priority, + envelope.created_at or now, + now, + now, + ), + ) + self._journal( + conn, + "task_submitted", + {"task_id": envelope.task_id, "task_name": envelope.task_name}, + ) + conn.execute("COMMIT") + + def due_tasks(self, limit: int = 50) -> list[sqlite3.Row]: + """ + Return ready tasks whose ``next_run_at`` is in the past. + + Results are ordered by priority (descending) then creation time. + + :param limit: maximum number of rows to return. + :return: list of task rows. + """ + now = time.time() + with self._conn() as conn: + return list( + conn.execute( + """ + SELECT * FROM tasks + WHERE state = 'ready' AND next_run_at <= ? + ORDER BY priority DESC, created_at ASC + LIMIT ? + """, + (now, limit), + ), + ) + + def mark_leased( + self, + task_id: str, + worker_id: str, + lease_id: str, + lease_until: float, + ) -> bool: + """ + Atomically transition a task from 'ready' to 'leased'. + + :param task_id: task to lease. + :param worker_id: worker receiving the task. + :param lease_id: unique token for this dispatch attempt. + :param lease_until: absolute epoch deadline for the lease. + :return: True if the transition succeeded; False if the task was + already taken (concurrent dispatch race). + """ + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT state FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + if not row or row["state"] != "ready": + return False + conn.execute("BEGIN") + conn.execute( + """ + UPDATE tasks + SET state = 'leased', + leased_worker_id = ?, lease_id = ?, lease_until = ?, + attempts = attempts + 1, updated_at = ? + WHERE task_id = ? + """, + (worker_id, lease_id, lease_until, now, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = inflight + 1 WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_leased", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + }, + ) + conn.execute("COMMIT") + return True + + def ack(self, task_id: str, worker_id: str, lease_id: str) -> bool: + """ + Mark a task as successfully completed. + + Late or duplicate acks (mismatched ``lease_id`` or state ≠ 'leased') + are silently rejected and return False. + + :param task_id: task being acknowledged. + :param worker_id: worker sending the ack. + :param lease_id: lease token that was issued at dispatch. + :return: True if the ack was accepted. + """ + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT state, lease_id, leased_worker_id FROM tasks WHERE task_id = ?", + (task_id,), + ).fetchone() + if not row or row["state"] != "leased": + return False + if row["lease_id"] != lease_id or row["leased_worker_id"] != worker_id: + return False + conn.execute("BEGIN") + conn.execute( + """ + UPDATE tasks + SET state = 'done', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL + WHERE task_id = ? + """, + (now, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_acked", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + }, + ) + conn.execute("COMMIT") + return True + + def nack( + self, task_id: str, worker_id: str, lease_id: str, error: str + ) -> bool: + """ + Explicitly fail a task, triggering retry or permanent failure. + + :param task_id: task being nacked. + :param worker_id: worker sending the nack. + :param lease_id: lease token issued at dispatch. + :param error: human-readable reason for the failure. + :return: True if the nack was accepted. + """ + return self._requeue_or_fail(task_id, worker_id, lease_id, error) + + def _requeue_or_fail( + self, task_id: str, worker_id: str, lease_id: str, error: str + ) -> bool: + now = time.time() + with self._conn() as conn: + row = conn.execute( + "SELECT * FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + if ( + not row + or row["state"] != "leased" + or row["lease_id"] != lease_id + or row["leased_worker_id"] != worker_id + ): + return False + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + conn.execute("BEGIN") + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = ? + WHERE task_id = ? + """, + (now, error, task_id), + ) + else: + backoff = self._backoff(attempts, float(row["retry_backoff"])) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = ? + WHERE task_id = ? + """, + (now, now + backoff, error, task_id), + ) + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "task_nacked", + { + "task_id": task_id, + "worker_id": worker_id, + "lease_id": lease_id, + "error": error, + "requeued": attempts <= max_retries, + }, + ) + conn.execute("COMMIT") + return True + + # ── reaper / recovery ───────────────────────────────────────────────────── + + def reap_expired_leases(self) -> int: + """ + Find leases past their deadline and requeue or permanently fail them. + + :return: number of tasks reaped. + """ + now = time.time() + with self._conn() as conn: + expired = list( + conn.execute( + """ + SELECT * FROM tasks + WHERE state = 'leased' + AND lease_until IS NOT NULL + AND lease_until < ? + """, + (now,), + ), + ) + if not expired: + return 0 + conn.execute("BEGIN") + count = 0 + for row in expired: + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + worker_id = row["leased_worker_id"] + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'lease expired' + WHERE task_id = ? + """, + (now, row["task_id"]), + ) + else: + backoff = self._backoff(attempts, float(row["retry_backoff"])) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'lease expired' + WHERE task_id = ? + """, + (now, now + backoff, row["task_id"]), + ) + if worker_id: + conn.execute( + "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", + (worker_id,), + ) + self._journal( + conn, + "lease_reaped", + { + "task_id": row["task_id"], + "worker_id": worker_id, + "lease_id": row["lease_id"], + }, + ) + count += 1 + conn.execute("COMMIT") + return count + + def recover_dead_workers(self, heartbeat_timeout: float) -> int: + """ + Mark workers that missed their heartbeat deadline as DEAD. + + All tasks leased to dead workers are requeued (or permanently failed + if retries are exhausted). + + :param heartbeat_timeout: seconds of silence before a worker is dead. + :return: number of tasks requeued. + """ + now = time.time() + cutoff = now - heartbeat_timeout + with self._conn() as conn: + dead = list( + conn.execute( + "SELECT * FROM workers WHERE last_seen < ? AND status != 'dead'", + (cutoff,), + ), + ) + if not dead: + return 0 + conn.execute("BEGIN") + requeued = 0 + for worker in dead: + worker_id = worker["worker_id"] + conn.execute( + "UPDATE workers SET status = 'dead', draining = 1 WHERE worker_id = ?", + (worker_id,), + ) + leased = list( + conn.execute( + "SELECT * FROM tasks WHERE state = 'leased' AND leased_worker_id = ?", + (worker_id,), + ), + ) + for row in leased: + attempts = int(row["attempts"]) + max_retries = int(row["max_retries"]) + if attempts > max_retries: + conn.execute( + """ + UPDATE tasks + SET state = 'failed', updated_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'worker died' + WHERE task_id = ? + """, + (now, row["task_id"]), + ) + else: + backoff = self._backoff( + attempts, float(row["retry_backoff"]) + ) + conn.execute( + """ + UPDATE tasks + SET state = 'ready', updated_at = ?, next_run_at = ?, + lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, + last_error = 'worker died' + WHERE task_id = ? + """, + (now, now + backoff, row["task_id"]), + ) + self._journal( + conn, + "worker_dead_requeue", + {"worker_id": worker_id, "task_id": row["task_id"]}, + ) + requeued += 1 + conn.execute("COMMIT") + return requeued + + # ── worker lifecycle ────────────────────────────────────────────────────── + + def register_worker(self, worker: WorkerState) -> None: + """ + Upsert a worker record. + + Re-registering an existing worker (e.g. after hub restart) resets + its draining flag and updates its metadata. + + :param worker: worker state snapshot from the registration message. + """ + now = time.time() + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + """ + INSERT INTO workers ( + worker_id, task_addr, capacity, inflight, last_seen, + heartbeat_interval, lease_timeout, draining, status, version + ) VALUES (?, ?, ?, 0, ?, ?, ?, 0, ?, ?) + ON CONFLICT(worker_id) DO UPDATE SET + task_addr = excluded.task_addr, + capacity = excluded.capacity, + last_seen = excluded.last_seen, + heartbeat_interval = excluded.heartbeat_interval, + lease_timeout = excluded.lease_timeout, + draining = 0, + status = excluded.status, + version = excluded.version + """, + ( + worker.worker_id, + worker.task_addr, + worker.capacity, + now, + worker.heartbeat_interval, + worker.lease_timeout, + str(WorkerStatus.LISTENING), + worker.version, + ), + ) + self._journal( + conn, + "worker_register", + {"worker_id": worker.worker_id, "task_addr": worker.task_addr}, + ) + conn.execute("COMMIT") + + def heartbeat(self, worker_id: str) -> None: + """ + Record a heartbeat from a worker, resetting its last_seen timestamp. + + :param worker_id: ID of the worker sending the heartbeat. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + "UPDATE workers SET last_seen = ?, status = ? WHERE worker_id = ?", + (time.time(), str(WorkerStatus.LISTENING), worker_id), + ) + self._journal(conn, "heartbeat", {"worker_id": worker_id}) + conn.execute("COMMIT") + + def unregister_worker(self, worker_id: str) -> None: + """ + Remove a worker from the registry (graceful shutdown path). + + :param worker_id: ID of the worker unregistering. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) + self._journal(conn, "worker_unregister", {"worker_id": worker_id}) + conn.execute("COMMIT") + + def mark_draining(self, worker_id: str) -> None: + """ + Mark a worker as draining so the hub stops dispatching new tasks to it. + + :param worker_id: ID of the worker entering drain mode. + """ + with self._conn() as conn: + conn.execute("BEGIN") + conn.execute( + "UPDATE workers SET draining = 1, status = ? WHERE worker_id = ?", + (str(WorkerStatus.DRAINING), worker_id), + ) + self._journal(conn, "worker_drain", {"worker_id": worker_id}) + conn.execute("COMMIT") + + # ── routing ─────────────────────────────────────────────────────────────── + + def choose_worker( + self, + routing_policy: str = "least_loaded", + *, + heartbeat_timeout: float = 15.0, + ) -> sqlite3.Row | None: + """ + Select the best available worker according to ``routing_policy``. + + ``'least_loaded'`` picks the worker with the lowest ``inflight / + capacity`` ratio. ``'p2c'`` (Power-of-Two-Choices) samples two + workers at random and picks the less loaded one, reducing hot-spot + probability under high concurrency. + + :param routing_policy: ``'least_loaded'`` or ``'p2c'``. + :param heartbeat_timeout: seconds before a worker is considered stale. + :return: chosen worker row, or None if no worker has capacity. + """ + cutoff = time.time() - heartbeat_timeout + with self._conn() as conn: + rows = list( + conn.execute( + """ + SELECT * FROM workers + WHERE status IN ('starting', 'listening') + AND draining = 0 + AND last_seen >= ? + """, + (cutoff,), + ), + ) + available = [ + w for w in rows if int(w["inflight"]) < int(w["capacity"]) + ] + if not available: + return None + if routing_policy == "p2c" and len(available) >= 2: + a, b = random.sample(available, k=2) # noqa: S311 + load_a = int(a["inflight"]) / max(int(a["capacity"]), 1) + load_b = int(b["inflight"]) / max(int(b["capacity"]), 1) + return a if load_a <= load_b else b + return min( + available, + key=lambda w: int(w["inflight"]) / max(int(w["capacity"]), 1), + ) + + # ── management / observability ──────────────────────────────────────────── + + def get_task(self, task_id: str) -> sqlite3.Row | None: + """ + Fetch a single task row by ID. + + :param task_id: ID of the task to look up. + :return: row or None if not found. + """ + with self._conn() as conn: + return conn.execute( + "SELECT * FROM tasks WHERE task_id = ?", (task_id,) + ).fetchone() + + def list_workers(self) -> list[sqlite3.Row]: + """Return all registered workers ordered by most-recently-seen.""" + with self._conn() as conn: + return list( + conn.execute("SELECT * FROM workers ORDER BY last_seen DESC"), + ) + + def stats(self) -> dict[str, int]: + """Return a summary dict with task state counts and active worker count.""" + with self._conn() as conn: + rows = conn.execute( + "SELECT state, COUNT(*) AS n FROM tasks GROUP BY state", + ).fetchall() + counts = {r["state"]: r["n"] for r in rows} + worker_count = conn.execute( + """ + SELECT COUNT(*) FROM workers + WHERE status IN ('starting', 'listening') AND draining = 0 + """, + ).fetchone()[0] + return { + "ready": counts.get("ready", 0), + "leased": counts.get("leased", 0), + "done": counts.get("done", 0), + "failed": counts.get("failed", 0), + "active_workers": worker_count, + } diff --git a/taskiq/brokers/nng_broker.py b/taskiq/brokers/nng_broker.py deleted file mode 100644 index 15ab3aaa..00000000 --- a/taskiq/brokers/nng_broker.py +++ /dev/null @@ -1,48 +0,0 @@ -from collections.abc import AsyncGenerator - -import pynng - -from taskiq.abc.broker import AsyncBroker -from taskiq.message import BrokerMessage - - -class NNGBroker(AsyncBroker): - """ - NanoMSG next generation broker. - - This broker is very much alike to the ZMQ broker, - It has a similar Idea, but slightly different - implementation. - """ - - def __init__(self, addr: str) -> None: - """ - Initialize the broker. - - :param addr: address which is used by both worker and client. - """ - super().__init__() - self.socket = pynng.Pair1(polyamorous=True) - self.addr = addr - - async def startup(self) -> None: - """Start the socket.""" - await super().startup() - if self.is_worker_process: - self.socket.listen(self.addr) - else: - self.socket.dial(self.addr, block=True) - - async def shutdown(self) -> None: - """Close the socket.""" - await super().shutdown() - self.socket.close() - - async def kick(self, message: BrokerMessage) -> None: - """Send a message.""" - await self.socket.ascend(message.message) - - async def listen(self) -> AsyncGenerator[bytes, None]: - """Infinite loop that receives messages.""" - while True: - yield await self.socket.arecv() diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py new file mode 100644 index 00000000..f128e757 --- /dev/null +++ b/tests/brokers/test_nng_broker.py @@ -0,0 +1,576 @@ +""" +Tests for the NNG broker, hub, storage, and protocol. + +The test suite is split into three layers: + +1. **Protocol** — pure serialisation roundtrips; no NNG sockets needed. +2. **Storage** — SQLiteJournal unit tests; no NNG sockets needed. +3. **Integration** — real NNG sockets, real SQLite, single asyncio event loop. + Uses ``FakeWorker`` / ``FakeClient`` helpers that speak the wire protocol + directly so we can inject faults precisely (crash before ack, late ack, etc.). + +All NNG tests are skipped when ``pynng`` is not installed. +""" +from __future__ import annotations + +import asyncio +import os +import sqlite3 +import tempfile +import time +import uuid + +import pytest + +pynng = pytest.importorskip("pynng") + +from taskiq.brokers.nng import ( + HubConfig, + NNGHub, + ControlMessage, + ControlResponse, + MessageKind, + TaskEnvelope, + WorkerState, + WorkerStatus, + QueueFullError, + SQLiteJournal, + StoreConfig, +) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _ipc(tag: str = "") -> str: + name = f"nng-test-{tag}-{uuid.uuid4().hex[:8]}.ipc" + return f"ipc://{os.path.join(tempfile.gettempdir(), name)}" + + +def _envelope(**kwargs: object) -> TaskEnvelope: + defaults: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + defaults.update(kwargs) + return TaskEnvelope(**defaults) # type: ignore[arg-type] + + +def _worker_state( + worker_id: str | None = None, + task_addr: str | None = None, + capacity: int = 2, +) -> WorkerState: + wid = worker_id or uuid.uuid4().hex + return WorkerState( + worker_id=wid, + task_addr=task_addr or f"ipc:///tmp/{wid}.ipc", + capacity=capacity, + heartbeat_interval=5.0, + lease_timeout=10.0, + ) + + +def _hub(control_addr: str, db_path: str, **kwargs: object) -> NNGHub: + cfg = HubConfig( + control_addr=control_addr, + task_db=db_path, + max_pending=100, + heartbeat_timeout=2.0, + lease_timeout=2.0, + dispatch_interval=0.02, + reaper_interval=0.1, + control_concurrency=4, + **kwargs, # type: ignore[arg-type] + ) + return NNGHub(cfg) + + +@pytest.fixture +def db_path(tmp_path: object) -> str: + import pathlib + return str(pathlib.Path(str(tmp_path)) / "hub.db") # type: ignore[arg-type] + + +@pytest.fixture +def ctrl_addr() -> str: + return _ipc("ctrl") + + +class FakeWorker: + """Minimal NNG worker that speaks the control + task protocol.""" + + def __init__( + self, + control_addr: str, + task_addr: str | None = None, + capacity: int = 1, + ) -> None: + self.worker_id = uuid.uuid4().hex[:8] + self.task_addr = task_addr or _ipc("worker") + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._pull = pynng.Pull0(listen=self.task_addr, recv_timeout=3000) + self._lock = asyncio.Lock() + self.capacity = capacity + + async def ctrl(self, kind: str, payload: dict[str, object]) -> ControlResponse: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind=kind, payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw) + + async def register(self) -> None: + resp = await self.ctrl( + "register", + { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": 0, + "last_seen": time.time(), + "heartbeat_interval": 1.0, + "lease_timeout": 2.0, + "draining": False, + "status": str(WorkerStatus.STARTING), + "version": "test", + }, + ) + assert resp.ok, f"register failed: {resp.error}" + + async def recv_task(self, timeout: float = 3.0) -> TaskEnvelope: + raw = await asyncio.wait_for(self._pull.arecv(), timeout=timeout) + return TaskEnvelope.from_bytes(raw) + + async def ack(self, task_id: str, lease_id: str) -> bool: + resp = await self.ctrl( + "ack", + { + "task_id": task_id, + "worker_id": self.worker_id, + "lease_id": lease_id, + }, + ) + return resp.ok + + async def heartbeat(self) -> None: + await self.ctrl("heartbeat", {"worker_id": self.worker_id}) + + async def drain_and_unregister(self) -> None: + await self.ctrl("drain", {"worker_id": self.worker_id}) + await self.ctrl("unregister", {"worker_id": self.worker_id}) + + def close(self) -> None: + self._ctrl.close() + self._pull.close() + + +class FakeClient: + """Minimal NNG client that can submit tasks and query hub status.""" + + def __init__(self, control_addr: str) -> None: + self._ctrl = pynng.Req0( + dial=control_addr, recv_timeout=3000, send_timeout=3000 + ) + self._lock = asyncio.Lock() + + async def submit(self, **labels: object) -> str: + tid = uuid.uuid4().hex + payload: dict[str, object] = { + "task_id": tid, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": labels.pop("max_retries", 0), + "retry_backoff": labels.pop("retry_backoff", 1.0), + "retry_jitter": 0.0, + "priority": labels.pop("priority", 0), + "created_at": time.time(), + } + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await self._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert resp.ok, f"submit failed: {resp.error}" + return tid + + async def ping(self) -> bool: + async with self._lock: + await self._ctrl.asend( + ControlMessage(kind="ping", payload={}).to_bytes() + ) + raw = await self._ctrl.arecv() + return ControlResponse.from_bytes(raw).ok + + def close(self) -> None: + self._ctrl.close() + + +# ── 1. Protocol tests ───────────────────────────────────────────────────────── + + +def test_control_message_roundtrip() -> None: + msg = ControlMessage(kind=MessageKind.HEARTBEAT, payload={"worker_id": "w1"}) + assert ControlMessage.from_bytes(msg.to_bytes()) == msg + + +def test_control_response_roundtrip() -> None: + resp = ControlResponse(ok=True, payload={"task_id": "abc"}, error=None) + assert ControlResponse.from_bytes(resp.to_bytes()) == resp + + +def test_task_envelope_lease_id_preserved() -> None: + """Regression: v2 omitted lease_id from the envelope, breaking ack validation.""" + env = TaskEnvelope( + task_id="x", task_name="m:f", payload_b64="YQ==", lease_id="abc123" + ) + rt = TaskEnvelope.from_bytes(env.to_bytes()) + assert rt.lease_id == "abc123" + + +def test_task_envelope_payload_decode() -> None: + env = _envelope(payload_b64="dGVzdA==") + assert env.payload == b"test" + + +# ── 2. Storage tests ────────────────────────────────────────────────────────── + + +@pytest.fixture +def store(db_path: str) -> SQLiteJournal: + return SQLiteJournal(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) + + +def test_submit_and_pending(store: SQLiteJournal) -> None: + store.submit(_envelope()) + assert store.pending_count() == 1 + + +def test_submit_queue_full(db_path: str) -> None: + s = SQLiteJournal(StoreConfig(path=db_path, max_pending=2)) + s.submit(_envelope()) + s.submit(_envelope()) + with pytest.raises(QueueFullError): + s.submit(_envelope()) + + +def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: + store.submit(_envelope(task_id="lo", priority=0)) + store.submit(_envelope(task_id="hi", priority=10)) + due = store.due_tasks(limit=10) + assert due[0]["task_id"] == "hi" + assert due[1]["task_id"] == "lo" + + +def test_ack_happy_path(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + assert store.mark_leased(env.task_id, w.worker_id, "L1", time.time() + 60) + assert store.ack(env.task_id, w.worker_id, "L1") + assert store.get_task(env.task_id)["state"] == "done" + + +def test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "real", time.time() + 60) + assert not store.ack(env.task_id, w.worker_id, "wrong") + + +def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: + env = _envelope() + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L2", time.time() - 1) + assert store.reap_expired_leases() == 1 + assert not store.ack(env.task_id, w.worker_id, "L2") + + +def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: + env = _envelope(max_retries=2, retry_backoff=1.0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L3", time.time() + 60) + assert store.nack(env.task_id, w.worker_id, "L3", "boom") + task = store.get_task(env.task_id) + assert task["state"] == "ready" + assert float(task["next_run_at"]) > time.time() + + +def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: + env = _envelope(max_retries=0) + store.submit(env) + w = _worker_state() + store.register_worker(w) + store.mark_leased(env.task_id, w.worker_id, "L4", time.time() + 60) + store.nack(env.task_id, w.worker_id, "L4", "error") + assert store.get_task(env.task_id)["state"] == "failed" + + +def test_dead_worker_tasks_requeued(store: SQLiteJournal, db_path: str) -> None: + w = _worker_state() + store.register_worker(w) + env = _envelope(max_retries=3) + store.submit(env) + store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) + conn = sqlite3.connect(db_path) + conn.execute("UPDATE workers SET last_seen=0 WHERE worker_id=?", (w.worker_id,)) + conn.commit() + conn.close() + assert store.recover_dead_workers(heartbeat_timeout=1.0) == 1 + assert store.get_task(env.task_id)["state"] == "ready" + + +def test_choose_worker_least_loaded(store: SQLiteJournal, db_path: str) -> None: + w1 = _worker_state(worker_id="w1", capacity=4) + w2 = _worker_state(worker_id="w2", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + conn = sqlite3.connect(db_path) + conn.execute("UPDATE workers SET inflight=3 WHERE worker_id='w1'") + conn.commit() + conn.close() + chosen = store.choose_worker("least_loaded", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "w2" + + +def test_stats(store: SQLiteJournal) -> None: + w = _worker_state() + store.register_worker(w) + store.submit(_envelope()) + s = store.stats() + assert s["ready"] == 1 + assert s["active_workers"] == 1 + + +# ── 3. Integration tests ────────────────────────────────────────────────────── + + +async def test_ping(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + try: + assert await client.ping() + finally: + client.close() + await hub.stop() + + +async def test_submit_dispatch_ack(ctrl_addr: str, db_path: str) -> None: + """Golden path: one task, one worker, full round-trip.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await worker.register() + tid = await client.submit() + env = await worker.recv_task(timeout=3.0) + assert env.task_id == tid + assert env.lease_id != "", "Hub must populate lease_id in envelope" + assert await worker.ack(env.task_id, env.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + worker.close() + client.close() + await hub.stop() + + +async def test_multiple_workers_load_balanced(ctrl_addr: str, db_path: str) -> None: + """Both workers must receive at least one task — no single hot-spot.""" + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=4) + w2 = FakeWorker(ctrl_addr, capacity=4) + client = FakeClient(ctrl_addr) + try: + await w1.register() + await w2.register() + task_ids = [await client.submit() for _ in range(6)] + received: dict[str, list[str]] = {w1.worker_id: [], w2.worker_id: []} + pending = set(task_ids) + + async def drain(w: FakeWorker) -> None: + while pending: + try: + env = await w.recv_task(timeout=0.5) + received[w.worker_id].append(env.task_id) + pending.discard(env.task_id) + await w.ack(env.task_id, env.lease_id) + except asyncio.TimeoutError: + break + + await asyncio.gather(drain(w1), drain(w2)) + assert not pending, f"Tasks not delivered: {pending}" + assert len(received[w1.worker_id]) > 0 + assert len(received[w2.worker_id]) > 0 + finally: + w1.close() + w2.close() + client.close() + await hub.stop() + + +async def test_worker_crash_before_ack_task_requeued( + ctrl_addr: str, db_path: str +) -> None: + """ + Worker receives a task but dies before acking. + After lease expiry the hub must requeue it for a second worker. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + assert env1.task_id == tid + w1.close() # simulate crash without acking + + await asyncio.sleep(3.5) # lease_timeout=2s + reaper_interval=0.1s + + assert hub.store.get_task(tid)["state"] == "ready" + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + assert env2.task_id == tid + assert env2.lease_id != env1.lease_id + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + client.close() + await hub.stop() + + +async def test_late_ack_after_requeue_rejected( + ctrl_addr: str, db_path: str +) -> None: + """ + Sequence: dispatch to w1 → lease expires → requeue → dispatch to w2. + w1's late ack must be rejected; w2's ack must succeed. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + try: + await w1.register() + tid = await client.submit(max_retries=3) + env1 = await w1.recv_task(timeout=3.0) + await asyncio.sleep(3.5) # let lease expire + + w2 = FakeWorker(ctrl_addr, capacity=1) + try: + await w2.register() + env2 = await w2.recv_task(timeout=3.0) + + # w1's stale ack must be rejected + assert not await w1.ack(env1.task_id, env1.lease_id) + # w2's valid ack succeeds + assert await w2.ack(env2.task_id, env2.lease_id) + assert hub.store.get_task(tid)["state"] == "done" + finally: + w2.close() + finally: + w1.close() + client.close() + await hub.stop() + + +async def test_hub_restart_recovers_orphaned_tasks( + ctrl_addr: str, db_path: str +) -> None: + """ + Tasks leased at hub shutdown must be requeued when a new hub starts + with the same database. + """ + hub1 = _hub(ctrl_addr, db_path) + await hub1.start() + w1 = FakeWorker(ctrl_addr, capacity=1) + client = FakeClient(ctrl_addr) + await w1.register() + tid = await client.submit(max_retries=3) + env = await w1.recv_task(timeout=3.0) + assert env.task_id == tid + # "kill" hub1 without giving worker a chance to ack + await hub1.stop() + w1.close() + client.close() + + # Task is still leased in the DB + assert hub1.store.get_task(tid)["state"] == "leased" + + hub2 = _hub(ctrl_addr, db_path) + await hub2.start() + await asyncio.sleep(0.3) # allow startup recovery + try: + assert hub2.store.get_task(tid)["state"] == "ready" + finally: + await hub2.stop() + + +async def test_concurrent_heartbeats(ctrl_addr: str, db_path: str) -> None: + """ + N workers heartbeat simultaneously. With concurrent Rep0 contexts all + must succeed without serialisation stalls. + """ + hub = _hub(ctrl_addr, db_path) + await hub.start() + workers = [FakeWorker(ctrl_addr, capacity=2) for _ in range(8)] + try: + await asyncio.gather(*[w.register() for w in workers]) + results = await asyncio.gather( + *[w.heartbeat() for w in workers], + return_exceptions=True, + ) + errors = [r for r in results if isinstance(r, Exception)] + assert not errors, f"Concurrent heartbeats failed: {errors}" + finally: + for w in workers: + w.close() + await hub.stop() + + +async def test_graceful_drain_and_unregister(ctrl_addr: str, db_path: str) -> None: + hub = _hub(ctrl_addr, db_path) + await hub.start() + worker = FakeWorker(ctrl_addr, capacity=2) + try: + await worker.register() + assert len(hub.store.list_workers()) == 1 + await worker.drain_and_unregister() + await asyncio.sleep(0.1) + assert len(hub.store.list_workers()) == 0 + finally: + worker.close() + await hub.stop() From 7fdb8eb19cd6085419556d850466214eb4799855 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 17:10:02 +0300 Subject: [PATCH 2/5] Refactoring the NNG support solution v1.5: Simplify store --- taskiq/brokers/nng/__init__.py | 29 +- taskiq/brokers/nng/broker.py | 2 +- taskiq/brokers/nng/hub.py | 107 ++-- taskiq/brokers/nng/storage.py | 822 +++++++++++-------------------- tests/brokers/test_nng_broker.py | 74 +-- 5 files changed, 349 insertions(+), 685 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 0d0a2946..8e2a7f4a 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -1,5 +1,6 @@ -from hub import HubConfig, NNGHub -from protocol import ( +"""NNG broker package for taskiq.""" +from .hub import HubConfig, NNGHub +from .protocol import ( ControlMessage, ControlResponse, MessageKind, @@ -7,18 +8,18 @@ WorkerState, WorkerStatus, ) -from storage import QueueFullError, SQLiteJournal, StoreConfig +from .storage import InMemoryStore, QueueFullError, StoreConfig __all__ = [ - 'HubConfig', - 'NNGHub', - 'ControlMessage', - 'ControlResponse', - 'MessageKind', - 'TaskEnvelope', - 'WorkerState', - 'WorkerStatus', - 'QueueFullError', - 'SQLiteJournal', - 'StoreConfig', + "HubConfig", + "NNGHub", + "ControlMessage", + "ControlResponse", + "MessageKind", + "TaskEnvelope", + "WorkerState", + "WorkerStatus", + "QueueFullError", + "InMemoryStore", + "StoreConfig", ] diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py index 6961cbeb..a6273e41 100644 --- a/taskiq/brokers/nng/broker.py +++ b/taskiq/brokers/nng/broker.py @@ -17,7 +17,7 @@ from taskiq.acks import AckableMessage from taskiq.message import BrokerMessage -from protocol import ( +from .protocol import ( ControlMessage, ControlResponse, TaskEnvelope, diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index 844055c5..c58857bf 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -3,12 +3,11 @@ Run as a standalone process:: - taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc \\ - --task-db /var/lib/taskiq/tasks.db + taskiq-nng-hub --control-addr ipc:///tmp/taskiq-nng.ipc Or embed it in an application for testing:: - hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc", task_db=":memory:")) + hub = NNGHub(HubConfig(control_addr="ipc:///tmp/h.ipc")) await hub.start() ... await hub.stop() @@ -18,13 +17,11 @@ import argparse import asyncio import base64 -import json import logging import os import signal import time import uuid -from concurrent.futures import ThreadPoolExecutor from contextlib import suppress from dataclasses import dataclass, field from typing import Any @@ -34,13 +31,13 @@ except ImportError: pynng = None # type: ignore[assignment] -from protocol import ( +from .protocol import ( ControlMessage, ControlResponse, TaskEnvelope, WorkerState, ) -from storage import QueueFullError, SQLiteJournal, StoreConfig +from .storage import InMemoryStore, QueueFullError, StoreConfig logger = logging.getLogger(__name__) @@ -50,7 +47,7 @@ class HubConfig: """Configuration for :class:`NNGHub`.""" control_addr: str - task_db: str + task_db: str = "" # kept for API compat; ignored by in-memory store max_pending: int = 10_000 heartbeat_timeout: float = 15.0 lease_timeout: float = 20.0 @@ -77,20 +74,18 @@ class NNGHub: independent ``nng_ctx`` contexts running concurrently. Each context handles one request-reply at a time, so N workers can register/heartbeat/ack simultaneously without queuing behind each other. - This is the key fix over the single-context (serial) Rep0 in v2. **Data plane** — One ``Push0`` socket per registered worker, dialed to the worker's own ``Pull0`` listen address. The hub explicitly targets - the least-loaded worker instead of relying on NNG round-robin, giving - us load-aware routing. + the least-loaded worker instead of relying on NNG round-robin. - **Persistence** — :class:`~taskiq.brokers.nng_storage.SQLiteJournal` in - WAL mode. All storage calls are executed on a single-threaded - ``ThreadPoolExecutor`` so the asyncio event loop is never blocked and - SQLite write serialisation is guaranteed. + **State** — :class:`~taskiq.brokers.nng.storage.InMemoryStore`. All + store operations are synchronous and execute directly on the asyncio event + loop without blocking (no I/O, no syscalls). - **Recovery** — On startup, tasks leased to workers that died during the - previous hub session are automatically requeued. + **Recovery** — On startup, any tasks that were leased before the hub last + stopped (within the same process lifetime) are automatically requeued by + :meth:`~InMemoryStore.recover_dead_workers`. """ def __init__(self, config: HubConfig) -> None: @@ -105,9 +100,8 @@ def __init__(self, config: HubConfig) -> None: "Install it with: pip install taskiq[nng]" ) self.config = config - self.store = SQLiteJournal( + self.store = InMemoryStore( StoreConfig( - path=config.task_db, max_pending=config.max_pending, lease_timeout=config.lease_timeout, backoff_cap=config.backoff_cap, @@ -117,16 +111,12 @@ def __init__(self, config: HubConfig) -> None: self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 self._tasks: list[asyncio.Task[None]] = [] - # Single-threaded executor: serialises all SQLite calls on one OS thread. - self._db_exec = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="nng-db" - ) # ── lifecycle ───────────────────────────────────────────────────────────── async def start(self) -> None: """Start the hub: recover orphaned tasks, open sockets, spawn loops.""" - await self._db(self.store.recover_dead_workers, self.config.heartbeat_timeout) + self.store.recover_dead_workers(self.config.heartbeat_timeout) self._ctrl_sock = pynng.Rep0(listen=self.config.control_addr) self._ctrl_sock.recv_timeout = self.config.recv_timeout_ms @@ -142,11 +132,7 @@ async def start(self) -> None: self._control_handler(ctx), name=f"hub-ctrl-{i}" ), ) - logger.info( - "NNG hub started on %s (db=%s)", - self.config.control_addr, - self.config.task_db, - ) + logger.info("NNG hub started on %s", self.config.control_addr) async def stop(self) -> None: """Gracefully stop all hub loops and close sockets.""" @@ -163,17 +149,8 @@ async def stop(self) -> None: if self._ctrl_sock is not None: with suppress(Exception): self._ctrl_sock.close() - self._db_exec.shutdown(wait=True) logger.info("NNG hub stopped") - # ── DB helper ───────────────────────────────────────────────────────────── - - async def _db(self, fn: Any, *args: Any, **kwargs: Any) -> Any: - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - self._db_exec, lambda: fn(*args, **kwargs) - ) - # ── control plane ───────────────────────────────────────────────────────── async def _control_handler(self, ctx: Any) -> None: @@ -216,19 +193,18 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return await self._handle_register(msg.payload) if msg.kind == "heartbeat": - await self._db(self.store.heartbeat, msg.payload["worker_id"]) + self.store.heartbeat(msg.payload["worker_id"]) return ControlResponse(ok=True, payload={"ok": True}) if msg.kind == "unregister": return await self._handle_unregister(msg.payload["worker_id"]) if msg.kind == "drain": - await self._db(self.store.mark_draining, msg.payload["worker_id"]) + self.store.mark_draining(msg.payload["worker_id"]) return ControlResponse(ok=True, payload={"draining": True}) if msg.kind == "ack": - ok = await self._db( - self.store.ack, + ok = self.store.ack( msg.payload["task_id"], msg.payload["worker_id"], msg.payload["lease_id"], @@ -236,8 +212,7 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return ControlResponse(ok=ok, payload={"acked": ok}) if msg.kind == "nack": - ok = await self._db( - self.store.nack, + ok = self.store.nack( msg.payload["task_id"], msg.payload["worker_id"], msg.payload["lease_id"], @@ -246,26 +221,25 @@ async def _handle(self, raw: bytes) -> ControlResponse: # noqa: PLR0911, C901 return ControlResponse(ok=ok, payload={"nacked": ok}) if msg.kind == "status": - task = await self._db(self.store.get_task, msg.payload["task_id"]) - return ControlResponse(ok=bool(task), payload=dict(task) if task else {}) + task = self.store.get_task(msg.payload["task_id"]) + return ControlResponse(ok=bool(task), payload=task or {}) if msg.kind == "stats": - s = await self._db(self.store.stats) - return ControlResponse(ok=True, payload=s) + return ControlResponse(ok=True, payload=self.store.stats()) return ControlResponse(ok=False, error=f"unknown kind: {msg.kind!r}") async def _handle_submit(self, payload: dict[str, Any]) -> ControlResponse: envelope = TaskEnvelope(**payload) try: - await self._db(self.store.submit, envelope) + self.store.submit(envelope) return ControlResponse(ok=True, payload={"task_id": envelope.task_id}) except QueueFullError: return ControlResponse(ok=False, error="queue full") async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: worker = WorkerState(**payload) - await self._db(self.store.register_worker, worker) + self.store.register_worker(worker) if worker.worker_id not in self._worker_push: try: sock = pynng.Push0(dial=worker.task_addr) @@ -279,7 +253,7 @@ async def _handle_register(self, payload: dict[str, Any]) -> ControlResponse: return ControlResponse(ok=True, payload={"registered": True}) async def _handle_unregister(self, worker_id: str) -> ControlResponse: - await self._db(self.store.unregister_worker, worker_id) + self.store.unregister_worker(worker_id) sock = self._worker_push.pop(worker_id, None) if sock is not None: with suppress(Exception): @@ -302,13 +276,12 @@ async def _dispatch_loop(self) -> None: async def _dispatch_once(self) -> bool: """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" - due = await self._db(self.store.due_tasks, self.config.dispatch_batch) + due = self.store.due_tasks(self.config.dispatch_batch) if not due: return False sent_any = False for row in due: - worker = await self._db( - self.store.choose_worker, + worker = self.store.choose_worker( self.config.routing_policy, heartbeat_timeout=self.config.heartbeat_timeout, ) @@ -319,8 +292,7 @@ async def _dispatch_once(self) -> bool: lease_id = uuid.uuid4().hex lease_until = time.time() + self.config.lease_timeout - if not await self._db( - self.store.mark_leased, + if not self.store.mark_leased( row["task_id"], worker_id, lease_id, lease_until, ): continue # concurrent dispatch race; task already taken @@ -331,19 +303,14 @@ async def _dispatch_once(self) -> bool: "No push socket for worker %s, requeueing %s", worker_id, row["task_id"], ) - await self._db( - self.store.nack, - row["task_id"], worker_id, lease_id, "no socket", - ) + self.store.nack(row["task_id"], worker_id, lease_id, "no socket") continue - # Include the hub-generated lease_id so the worker can ack with - # the exact token. Omitting it was the core correctness bug in v2. envelope = TaskEnvelope( task_id=row["task_id"], task_name=row["task_name"], payload_b64=base64.b64encode(row["payload"]).decode("ascii"), - labels=json.loads(row["labels_json"]), + labels=row["labels"], lease_id=lease_id, attempts=int(row["attempts"]) + 1, max_retries=int(row["max_retries"]), @@ -360,8 +327,7 @@ async def _dispatch_once(self) -> bool: "Failed to deliver %s to worker %s: %s", row["task_id"], worker_id, exc, ) - await self._db( - self.store.nack, + self.store.nack( row["task_id"], worker_id, lease_id, f"dispatch send failed: {exc}", ) @@ -373,11 +339,10 @@ async def _reaper_loop(self) -> None: while not self._stop.is_set(): try: await asyncio.sleep(self.config.reaper_interval) - reaped = await self._db(self.store.reap_expired_leases) + reaped = self.store.reap_expired_leases() if reaped: logger.debug("Reaped %d expired leases", reaped) - recovered = await self._db( - self.store.recover_dead_workers, + recovered = self.store.recover_dead_workers( self.config.heartbeat_timeout, ) if recovered: @@ -400,11 +365,6 @@ def _build_config() -> HubConfig: default=os.getenv("NNG_CONTROL_ADDR", "ipc:///tmp/taskiq-nng.ipc"), help="NNG address the hub listens on. Env: NNG_CONTROL_ADDR", ) - p.add_argument( - "--task-db", - default=os.getenv("NNG_TASK_DB", "/tmp/taskiq-nng-tasks.db"), # noqa: S108 - help="Path to the SQLite WAL task journal. Env: NNG_TASK_DB", - ) p.add_argument( "--max-pending", type=int, @@ -445,7 +405,6 @@ def _build_config() -> HubConfig: ) return HubConfig( control_addr=args.control_addr, - task_db=args.task_db, max_pending=args.max_pending, heartbeat_timeout=args.heartbeat_timeout, lease_timeout=args.lease_timeout, diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 410a7064..87970adf 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -1,24 +1,20 @@ -"""Durable WAL-mode SQLite task journal for the NNG hub.""" +"""Pure in-memory task store for the NNG hub — no external dependencies.""" from __future__ import annotations -import json import random -import sqlite3 -import threading import time -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Generator +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any -from protocol import TaskEnvelope, WorkerState, WorkerStatus +if TYPE_CHECKING: + from .protocol import TaskEnvelope, WorkerState @dataclass class StoreConfig: - """Configuration for the SQLite task journal.""" + """Configuration for :class:`InMemoryStore`.""" - path: str + path: str = "" # kept for API compat; not used max_pending: int = 10_000 lease_timeout: float = 30.0 backoff_base: float = 1.0 @@ -29,203 +25,172 @@ class QueueFullError(RuntimeError): """Raised when a submission is attempted on a full queue.""" -class SQLiteJournal: +@dataclass +class _Task: + task_id: str + task_name: str + payload: bytes + labels: dict[str, Any] + state: str # ready / leased / done / failed + attempts: int = 0 + max_retries: int = 0 + retry_backoff: float = 1.0 + retry_jitter: float = 0.0 + priority: int = 0 + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + next_run_at: float = field(default_factory=time.time) + lease_id: str | None = None + leased_worker_id: str | None = None + lease_until: float | None = None + last_error: str | None = None + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this task record.""" + return { + "task_id": self.task_id, + "task_name": self.task_name, + "payload": self.payload, + "labels": self.labels, + "state": self.state, + "attempts": self.attempts, + "max_retries": self.max_retries, + "retry_backoff": self.retry_backoff, + "retry_jitter": self.retry_jitter, + "priority": self.priority, + "created_at": self.created_at, + "updated_at": self.updated_at, + "next_run_at": self.next_run_at, + "lease_id": self.lease_id, + "leased_worker_id": self.leased_worker_id, + "lease_until": self.lease_until, + "last_error": self.last_error, + } + + def as_status_dict(self) -> dict[str, Any]: + """Return a JSON-safe dict (no raw bytes) for control-plane status responses.""" + d = self.as_dict() + d.pop("payload", None) + return d + + +@dataclass +class _Worker: + worker_id: str + task_addr: str + capacity: int + inflight: int = 0 + last_seen: float = 0.0 + heartbeat_interval: float = 5.0 + lease_timeout: float = 15.0 + draining: bool = False + status: str = "starting" + version: str = "unknown" + + def as_dict(self) -> dict[str, Any]: + """Return a dict view of this worker record.""" + return { + "worker_id": self.worker_id, + "task_addr": self.task_addr, + "capacity": self.capacity, + "inflight": self.inflight, + "last_seen": self.last_seen, + "heartbeat_interval": self.heartbeat_interval, + "lease_timeout": self.lease_timeout, + "draining": self.draining, + "status": self.status, + "version": self.version, + } + + +class InMemoryStore: """ - Thread-safe, WAL-mode SQLite task store. - - Design notes - ──────────── - * Every method opens and closes its own connection. WAL allows concurrent - readers without blocking; SQLite serialises concurrent writers internally, - and the Python-level ``_submit_lock`` prevents the TOCTOU race in - :meth:`submit`. - * The hub runs every call through a single-threaded - ``ThreadPoolExecutor`` so, in practice, writes never contend at the - OS level either. - * ``PRAGMA`` settings (WAL, synchronous, busy_timeout) are applied per - connection because each ``sqlite3.connect()`` call starts with defaults. + Pure in-memory task store for the NNG hub. + + All methods are synchronous and safe to call from a single asyncio event + loop — asyncio's cooperative scheduling makes them effectively atomic (no + ``await`` between reads and writes). + + State is lost when the process exits. For persistent task queues use a + dedicated result backend; the NNG broker is designed for low-latency + in-process delivery, not durable storage. """ def __init__(self, config: StoreConfig) -> None: - """Initialise the journal and create schema if not present.""" + """Initialise an empty store with the given configuration.""" self.config = config - # Guards only the pending_count check + INSERT pair in submit() to - # prevent concurrent callers from racing past max_pending. - self._submit_lock = threading.Lock() - self._init() - - # ── connection ──────────────────────────────────────────────────────────── - - @contextmanager - def _conn(self) -> Generator[sqlite3.Connection, None, None]: - conn = sqlite3.connect( - self.config.path, - timeout=10.0, - check_same_thread=False, - isolation_level=None, # we manage transactions explicitly - ) - conn.row_factory = sqlite3.Row - # Must be set per-connection, not just once at schema creation. - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") # safe with WAL; faster than FULL - conn.execute("PRAGMA busy_timeout=5000") # wait up to 5s before SQLITE_BUSY - conn.execute("PRAGMA cache_size=-32000") # 32 MB page cache - try: - yield conn - finally: - conn.close() - - def _init(self) -> None: - Path(self.config.path).parent.mkdir(parents=True, exist_ok=True) - with self._conn() as conn: - conn.executescript(""" - CREATE TABLE IF NOT EXISTS tasks ( - task_id TEXT PRIMARY KEY, - task_name TEXT NOT NULL, - payload BLOB NOT NULL, - labels_json TEXT NOT NULL DEFAULT '{}', - state TEXT NOT NULL, - attempts INTEGER NOT NULL DEFAULT 0, - max_retries INTEGER NOT NULL DEFAULT 0, - retry_backoff REAL NOT NULL DEFAULT 1.0, - retry_jitter REAL NOT NULL DEFAULT 0.0, - priority INTEGER NOT NULL DEFAULT 0, - created_at REAL NOT NULL, - updated_at REAL NOT NULL, - next_run_at REAL NOT NULL, - lease_id TEXT, - leased_worker_id TEXT, - lease_until REAL, - last_error TEXT - ); - - CREATE TABLE IF NOT EXISTS workers ( - worker_id TEXT PRIMARY KEY, - task_addr TEXT NOT NULL, - capacity INTEGER NOT NULL, - inflight INTEGER NOT NULL DEFAULT 0, - last_seen REAL NOT NULL DEFAULT 0, - heartbeat_interval REAL NOT NULL DEFAULT 5.0, - lease_timeout REAL NOT NULL DEFAULT 15.0, - draining INTEGER NOT NULL DEFAULT 0, - status TEXT NOT NULL, - version TEXT NOT NULL DEFAULT 'unknown' - ); - - CREATE TABLE IF NOT EXISTS journal ( - seq INTEGER PRIMARY KEY AUTOINCREMENT, - ts REAL NOT NULL, - kind TEXT NOT NULL, - payload_json TEXT NOT NULL - ); - - CREATE INDEX IF NOT EXISTS idx_tasks_dispatch - ON tasks (state, next_run_at, priority DESC); - CREATE INDEX IF NOT EXISTS idx_tasks_lease - ON tasks (state, lease_until); - CREATE INDEX IF NOT EXISTS idx_workers_active - ON workers (status, draining, last_seen); - """) + self._tasks: dict[str, _Task] = {} + self._workers: dict[str, _Worker] = {} # ── helpers ─────────────────────────────────────────────────────────────── - def _journal( - self, - conn: sqlite3.Connection, - kind: str, - payload: dict[str, Any], - ) -> None: - conn.execute( - "INSERT INTO journal (ts, kind, payload_json) VALUES (?, ?, ?)", - ( - time.time(), - kind, - json.dumps(payload, separators=(",", ":"), ensure_ascii=False), - ), - ) - def _backoff(self, attempts: int, backoff_base: float) -> float: return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + def _requeue_or_fail(self, task: _Task, worker_id: str, error: str) -> bool: + now = time.time() + if task.attempts > task.max_retries: + task.state = "failed" + else: + task.state = "ready" + task.next_run_at = now + self._backoff(task.attempts, task.retry_backoff) + task.last_error = error + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True + # ── task lifecycle ──────────────────────────────────────────────────────── def pending_count(self) -> int: - """Return the number of ready + leased tasks.""" - with self._conn() as conn: - return int( - conn.execute( - "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", - ).fetchone()[0], - ) + """Return the count of ready and leased tasks.""" + return sum(1 for t in self._tasks.values() if t.state in ("ready", "leased")) def submit(self, envelope: TaskEnvelope) -> None: """ - Persist a new task in 'ready' state. + Accept a new task into the store. :param envelope: task envelope to store. :raises QueueFullError: when ``max_pending`` is reached. """ + if self.pending_count() >= self.config.max_pending: + raise QueueFullError("Task queue is full.") now = time.time() - with self._submit_lock, self._conn() as conn: - count = conn.execute( - "SELECT COUNT(*) FROM tasks WHERE state IN ('ready', 'leased')", - ).fetchone()[0] - if count >= self.config.max_pending: - raise QueueFullError("Task queue is full.") - conn.execute("BEGIN") - conn.execute( - """ - INSERT INTO tasks ( - task_id, task_name, payload, labels_json, state, - attempts, max_retries, retry_backoff, retry_jitter, - priority, created_at, updated_at, next_run_at - ) VALUES (?, ?, ?, ?, 'ready', 0, ?, ?, ?, ?, ?, ?, ?) - """, - ( - envelope.task_id, - envelope.task_name, - envelope.payload, - json.dumps( - envelope.labels, separators=(",", ":"), ensure_ascii=False - ), - envelope.max_retries, - envelope.retry_backoff, - envelope.retry_jitter, - envelope.priority, - envelope.created_at or now, - now, - now, - ), - ) - self._journal( - conn, - "task_submitted", - {"task_id": envelope.task_id, "task_name": envelope.task_name}, - ) - conn.execute("COMMIT") + self._tasks[envelope.task_id] = _Task( + task_id=envelope.task_id, + task_name=envelope.task_name, + payload=envelope.payload, + labels=envelope.labels, + state="ready", + max_retries=envelope.max_retries, + retry_backoff=envelope.retry_backoff, + retry_jitter=envelope.retry_jitter, + priority=envelope.priority, + created_at=envelope.created_at or now, + updated_at=now, + next_run_at=now, + ) - def due_tasks(self, limit: int = 50) -> list[sqlite3.Row]: + def due_tasks(self, limit: int = 50) -> list[dict[str, Any]]: """ Return ready tasks whose ``next_run_at`` is in the past. Results are ordered by priority (descending) then creation time. :param limit: maximum number of rows to return. - :return: list of task rows. + :return: list of task dicts. """ now = time.time() - with self._conn() as conn: - return list( - conn.execute( - """ - SELECT * FROM tasks - WHERE state = 'ready' AND next_run_at <= ? - ORDER BY priority DESC, created_at ASC - LIMIT ? - """, - (now, limit), - ), - ) + ready = [ + t for t in self._tasks.values() + if t.state == "ready" and t.next_run_at <= now + ] + ready.sort(key=lambda t: (-t.priority, t.created_at)) + return [t.as_dict() for t in ready[:limit]] def mark_leased( self, @@ -241,90 +206,50 @@ def mark_leased( :param worker_id: worker receiving the task. :param lease_id: unique token for this dispatch attempt. :param lease_until: absolute epoch deadline for the lease. - :return: True if the transition succeeded; False if the task was - already taken (concurrent dispatch race). + :return: True on success; False if the task is not in 'ready' state. """ + task = self._tasks.get(task_id) + if task is None or task.state != "ready": + return False now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT state FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() - if not row or row["state"] != "ready": - return False - conn.execute("BEGIN") - conn.execute( - """ - UPDATE tasks - SET state = 'leased', - leased_worker_id = ?, lease_id = ?, lease_until = ?, - attempts = attempts + 1, updated_at = ? - WHERE task_id = ? - """, - (worker_id, lease_id, lease_until, now, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = inflight + 1 WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_leased", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - }, - ) - conn.execute("COMMIT") - return True + task.state = "leased" + task.leased_worker_id = worker_id + task.lease_id = lease_id + task.lease_until = lease_until + task.attempts += 1 + task.updated_at = now + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight += 1 + return True def ack(self, task_id: str, worker_id: str, lease_id: str) -> bool: """ Mark a task as successfully completed. Late or duplicate acks (mismatched ``lease_id`` or state ≠ 'leased') - are silently rejected and return False. + are silently rejected. :param task_id: task being acknowledged. :param worker_id: worker sending the ack. - :param lease_id: lease token that was issued at dispatch. + :param lease_id: lease token issued at dispatch. :return: True if the ack was accepted. """ + task = self._tasks.get(task_id) + if task is None or task.state != "leased": + return False + if task.lease_id != lease_id or task.leased_worker_id != worker_id: + return False now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT state, lease_id, leased_worker_id FROM tasks WHERE task_id = ?", - (task_id,), - ).fetchone() - if not row or row["state"] != "leased": - return False - if row["lease_id"] != lease_id or row["leased_worker_id"] != worker_id: - return False - conn.execute("BEGIN") - conn.execute( - """ - UPDATE tasks - SET state = 'done', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL - WHERE task_id = ? - """, - (now, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_acked", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - }, - ) - conn.execute("COMMIT") - return True + task.state = "done" + task.updated_at = now + task.lease_id = None + task.leased_worker_id = None + task.lease_until = None + worker = self._workers.get(worker_id) + if worker is not None: + worker.inflight = max(0, worker.inflight - 1) + return True def nack( self, task_id: str, worker_id: str, lease_id: str, error: str @@ -335,274 +260,106 @@ def nack( :param task_id: task being nacked. :param worker_id: worker sending the nack. :param lease_id: lease token issued at dispatch. - :param error: human-readable reason for the failure. + :param error: human-readable failure reason. :return: True if the nack was accepted. """ - return self._requeue_or_fail(task_id, worker_id, lease_id, error) - - def _requeue_or_fail( - self, task_id: str, worker_id: str, lease_id: str, error: str - ) -> bool: - now = time.time() - with self._conn() as conn: - row = conn.execute( - "SELECT * FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() - if ( - not row - or row["state"] != "leased" - or row["lease_id"] != lease_id - or row["leased_worker_id"] != worker_id - ): - return False - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - conn.execute("BEGIN") - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = ? - WHERE task_id = ? - """, - (now, error, task_id), - ) - else: - backoff = self._backoff(attempts, float(row["retry_backoff"])) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = ? - WHERE task_id = ? - """, - (now, now + backoff, error, task_id), - ) - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "task_nacked", - { - "task_id": task_id, - "worker_id": worker_id, - "lease_id": lease_id, - "error": error, - "requeued": attempts <= max_retries, - }, - ) - conn.execute("COMMIT") - return True + task = self._tasks.get(task_id) + if ( + task is None + or task.state != "leased" + or task.lease_id != lease_id + or task.leased_worker_id != worker_id + ): + return False + return self._requeue_or_fail(task, worker_id, error) # ── reaper / recovery ───────────────────────────────────────────────────── def reap_expired_leases(self) -> int: """ - Find leases past their deadline and requeue or permanently fail them. + Requeue or permanently fail tasks whose lease deadline has passed. :return: number of tasks reaped. """ now = time.time() - with self._conn() as conn: - expired = list( - conn.execute( - """ - SELECT * FROM tasks - WHERE state = 'leased' - AND lease_until IS NOT NULL - AND lease_until < ? - """, - (now,), - ), - ) - if not expired: - return 0 - conn.execute("BEGIN") - count = 0 - for row in expired: - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - worker_id = row["leased_worker_id"] - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'lease expired' - WHERE task_id = ? - """, - (now, row["task_id"]), - ) - else: - backoff = self._backoff(attempts, float(row["retry_backoff"])) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'lease expired' - WHERE task_id = ? - """, - (now, now + backoff, row["task_id"]), - ) - if worker_id: - conn.execute( - "UPDATE workers SET inflight = MAX(inflight - 1, 0) WHERE worker_id = ?", - (worker_id,), - ) - self._journal( - conn, - "lease_reaped", - { - "task_id": row["task_id"], - "worker_id": worker_id, - "lease_id": row["lease_id"], - }, - ) - count += 1 - conn.execute("COMMIT") - return count + expired = [ + t for t in self._tasks.values() + if t.state == "leased" + and t.lease_until is not None + and t.lease_until < now + ] + for task in expired: + self._requeue_or_fail(task, task.leased_worker_id or "", "lease expired") + return len(expired) def recover_dead_workers(self, heartbeat_timeout: float) -> int: """ - Mark workers that missed their heartbeat deadline as DEAD. - - All tasks leased to dead workers are requeued (or permanently failed - if retries are exhausted). + Mark workers that missed their heartbeat deadline as dead and requeue their tasks. :param heartbeat_timeout: seconds of silence before a worker is dead. :return: number of tasks requeued. """ - now = time.time() - cutoff = now - heartbeat_timeout - with self._conn() as conn: - dead = list( - conn.execute( - "SELECT * FROM workers WHERE last_seen < ? AND status != 'dead'", - (cutoff,), - ), - ) - if not dead: - return 0 - conn.execute("BEGIN") - requeued = 0 - for worker in dead: - worker_id = worker["worker_id"] - conn.execute( - "UPDATE workers SET status = 'dead', draining = 1 WHERE worker_id = ?", - (worker_id,), - ) - leased = list( - conn.execute( - "SELECT * FROM tasks WHERE state = 'leased' AND leased_worker_id = ?", - (worker_id,), - ), - ) - for row in leased: - attempts = int(row["attempts"]) - max_retries = int(row["max_retries"]) - if attempts > max_retries: - conn.execute( - """ - UPDATE tasks - SET state = 'failed', updated_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'worker died' - WHERE task_id = ? - """, - (now, row["task_id"]), - ) - else: - backoff = self._backoff( - attempts, float(row["retry_backoff"]) - ) - conn.execute( - """ - UPDATE tasks - SET state = 'ready', updated_at = ?, next_run_at = ?, - lease_id = NULL, leased_worker_id = NULL, lease_until = NULL, - last_error = 'worker died' - WHERE task_id = ? - """, - (now, now + backoff, row["task_id"]), - ) - self._journal( - conn, - "worker_dead_requeue", - {"worker_id": worker_id, "task_id": row["task_id"]}, - ) - requeued += 1 - conn.execute("COMMIT") - return requeued + cutoff = time.time() - heartbeat_timeout + dead = [ + w for w in self._workers.values() + if w.last_seen < cutoff and w.status != "dead" + ] + requeued = 0 + for worker in dead: + worker.status = "dead" + worker.draining = True + leased = [ + t for t in self._tasks.values() + if t.state == "leased" and t.leased_worker_id == worker.worker_id + ] + for task in leased: + self._requeue_or_fail(task, worker.worker_id, "worker died") + requeued += 1 + return requeued # ── worker lifecycle ────────────────────────────────────────────────────── def register_worker(self, worker: WorkerState) -> None: """ - Upsert a worker record. - - Re-registering an existing worker (e.g. after hub restart) resets - its draining flag and updates its metadata. + Upsert a worker record, resetting drain state on re-registration. :param worker: worker state snapshot from the registration message. """ now = time.time() - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - """ - INSERT INTO workers ( - worker_id, task_addr, capacity, inflight, last_seen, - heartbeat_interval, lease_timeout, draining, status, version - ) VALUES (?, ?, ?, 0, ?, ?, ?, 0, ?, ?) - ON CONFLICT(worker_id) DO UPDATE SET - task_addr = excluded.task_addr, - capacity = excluded.capacity, - last_seen = excluded.last_seen, - heartbeat_interval = excluded.heartbeat_interval, - lease_timeout = excluded.lease_timeout, - draining = 0, - status = excluded.status, - version = excluded.version - """, - ( - worker.worker_id, - worker.task_addr, - worker.capacity, - now, - worker.heartbeat_interval, - worker.lease_timeout, - str(WorkerStatus.LISTENING), - worker.version, - ), - ) - self._journal( - conn, - "worker_register", - {"worker_id": worker.worker_id, "task_addr": worker.task_addr}, + existing = self._workers.get(worker.worker_id) + if existing is not None: + existing.task_addr = worker.task_addr + existing.capacity = worker.capacity + existing.last_seen = now + existing.heartbeat_interval = worker.heartbeat_interval + existing.lease_timeout = worker.lease_timeout + existing.draining = False + existing.status = "listening" + existing.version = worker.version + else: + self._workers[worker.worker_id] = _Worker( + worker_id=worker.worker_id, + task_addr=worker.task_addr, + capacity=worker.capacity, + inflight=0, + last_seen=now, + heartbeat_interval=worker.heartbeat_interval, + lease_timeout=worker.lease_timeout, + draining=False, + status="listening", + version=worker.version, ) - conn.execute("COMMIT") def heartbeat(self, worker_id: str) -> None: """ - Record a heartbeat from a worker, resetting its last_seen timestamp. + Record a heartbeat, resetting the worker's last_seen timestamp. :param worker_id: ID of the worker sending the heartbeat. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - "UPDATE workers SET last_seen = ?, status = ? WHERE worker_id = ?", - (time.time(), str(WorkerStatus.LISTENING), worker_id), - ) - self._journal(conn, "heartbeat", {"worker_id": worker_id}) - conn.execute("COMMIT") + worker = self._workers.get(worker_id) + if worker is not None: + worker.last_seen = time.time() + worker.status = "listening" def unregister_worker(self, worker_id: str) -> None: """ @@ -610,11 +367,7 @@ def unregister_worker(self, worker_id: str) -> None: :param worker_id: ID of the worker unregistering. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute("DELETE FROM workers WHERE worker_id = ?", (worker_id,)) - self._journal(conn, "worker_unregister", {"worker_id": worker_id}) - conn.execute("COMMIT") + self._workers.pop(worker_id, None) def mark_draining(self, worker_id: str) -> None: """ @@ -622,14 +375,10 @@ def mark_draining(self, worker_id: str) -> None: :param worker_id: ID of the worker entering drain mode. """ - with self._conn() as conn: - conn.execute("BEGIN") - conn.execute( - "UPDATE workers SET draining = 1, status = ? WHERE worker_id = ?", - (str(WorkerStatus.DRAINING), worker_id), - ) - self._journal(conn, "worker_drain", {"worker_id": worker_id}) - conn.execute("COMMIT") + worker = self._workers.get(worker_id) + if worker is not None: + worker.draining = True + worker.status = "draining" # ── routing ─────────────────────────────────────────────────────────────── @@ -638,85 +387,70 @@ def choose_worker( routing_policy: str = "least_loaded", *, heartbeat_timeout: float = 15.0, - ) -> sqlite3.Row | None: + ) -> dict[str, Any] | None: """ Select the best available worker according to ``routing_policy``. - ``'least_loaded'`` picks the worker with the lowest ``inflight / - capacity`` ratio. ``'p2c'`` (Power-of-Two-Choices) samples two - workers at random and picks the less loaded one, reducing hot-spot - probability under high concurrency. + ``'least_loaded'`` picks the worker with the lowest inflight/capacity + ratio. ``'p2c'`` samples two workers and picks the less loaded one. :param routing_policy: ``'least_loaded'`` or ``'p2c'``. :param heartbeat_timeout: seconds before a worker is considered stale. - :return: chosen worker row, or None if no worker has capacity. + :return: chosen worker dict, or None if no worker has capacity. """ cutoff = time.time() - heartbeat_timeout - with self._conn() as conn: - rows = list( - conn.execute( - """ - SELECT * FROM workers - WHERE status IN ('starting', 'listening') - AND draining = 0 - AND last_seen >= ? - """, - (cutoff,), - ), - ) available = [ - w for w in rows if int(w["inflight"]) < int(w["capacity"]) + w for w in self._workers.values() + if w.status in ("starting", "listening") + and not w.draining + and w.last_seen >= cutoff + and w.inflight < w.capacity ] if not available: return None if routing_policy == "p2c" and len(available) >= 2: a, b = random.sample(available, k=2) # noqa: S311 - load_a = int(a["inflight"]) / max(int(a["capacity"]), 1) - load_b = int(b["inflight"]) / max(int(b["capacity"]), 1) - return a if load_a <= load_b else b - return min( - available, - key=lambda w: int(w["inflight"]) / max(int(w["capacity"]), 1), - ) + load_a = a.inflight / max(a.capacity, 1) + load_b = b.inflight / max(b.capacity, 1) + chosen = a if load_a <= load_b else b + else: + chosen = min(available, key=lambda w: w.inflight / max(w.capacity, 1)) + return chosen.as_dict() - # ── management / observability ──────────────────────────────────────────── + # ── observability ───────────────────────────────────────────────────────── - def get_task(self, task_id: str) -> sqlite3.Row | None: + def get_task(self, task_id: str) -> dict[str, Any] | None: """ - Fetch a single task row by ID. + Fetch task status by ID (no raw bytes in result). :param task_id: ID of the task to look up. - :return: row or None if not found. + :return: status dict or None if not found. """ - with self._conn() as conn: - return conn.execute( - "SELECT * FROM tasks WHERE task_id = ?", (task_id,) - ).fetchone() + task = self._tasks.get(task_id) + return task.as_status_dict() if task is not None else None - def list_workers(self) -> list[sqlite3.Row]: + def list_workers(self) -> list[dict[str, Any]]: """Return all registered workers ordered by most-recently-seen.""" - with self._conn() as conn: - return list( - conn.execute("SELECT * FROM workers ORDER BY last_seen DESC"), + return [ + w.as_dict() + for w in sorted( + self._workers.values(), key=lambda w: w.last_seen, reverse=True ) + ] def stats(self) -> dict[str, int]: """Return a summary dict with task state counts and active worker count.""" - with self._conn() as conn: - rows = conn.execute( - "SELECT state, COUNT(*) AS n FROM tasks GROUP BY state", - ).fetchall() - counts = {r["state"]: r["n"] for r in rows} - worker_count = conn.execute( - """ - SELECT COUNT(*) FROM workers - WHERE status IN ('starting', 'listening') AND draining = 0 - """, - ).fetchone()[0] + counts: dict[str, int] = {} + for t in self._tasks.values(): + counts[t.state] = counts.get(t.state, 0) + 1 + active = sum( + 1 for w in self._workers.values() + if w.status in ("starting", "listening") and not w.draining + ) return { "ready": counts.get("ready", 0), "leased": counts.get("leased", 0), "done": counts.get("done", 0), "failed": counts.get("failed", 0), - "active_workers": worker_count, + "active_workers": active, } diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index f128e757..4bb9c4b6 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -4,8 +4,8 @@ The test suite is split into three layers: 1. **Protocol** — pure serialisation roundtrips; no NNG sockets needed. -2. **Storage** — SQLiteJournal unit tests; no NNG sockets needed. -3. **Integration** — real NNG sockets, real SQLite, single asyncio event loop. +2. **Storage** — InMemoryStore unit tests; no NNG sockets needed. +3. **Integration** — real NNG sockets, single asyncio event loop. Uses ``FakeWorker`` / ``FakeClient`` helpers that speak the wire protocol directly so we can inject faults precisely (crash before ack, late ack, etc.). @@ -15,7 +15,6 @@ import asyncio import os -import sqlite3 import tempfile import time import uuid @@ -34,7 +33,7 @@ WorkerState, WorkerStatus, QueueFullError, - SQLiteJournal, + InMemoryStore, StoreConfig, ) @@ -253,24 +252,24 @@ def test_task_envelope_payload_decode() -> None: @pytest.fixture -def store(db_path: str) -> SQLiteJournal: - return SQLiteJournal(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) +def store(db_path: str) -> InMemoryStore: + return InMemoryStore(StoreConfig(path=db_path, max_pending=50, lease_timeout=5.0)) -def test_submit_and_pending(store: SQLiteJournal) -> None: +def test_submit_and_pending(store: InMemoryStore) -> None: store.submit(_envelope()) assert store.pending_count() == 1 def test_submit_queue_full(db_path: str) -> None: - s = SQLiteJournal(StoreConfig(path=db_path, max_pending=2)) + s = InMemoryStore(StoreConfig(path=db_path, max_pending=2)) s.submit(_envelope()) s.submit(_envelope()) with pytest.raises(QueueFullError): s.submit(_envelope()) -def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: +def test_due_tasks_ordered_by_priority(store: InMemoryStore) -> None: store.submit(_envelope(task_id="lo", priority=0)) store.submit(_envelope(task_id="hi", priority=10)) due = store.due_tasks(limit=10) @@ -278,7 +277,7 @@ def test_due_tasks_ordered_by_priority(store: SQLiteJournal) -> None: assert due[1]["task_id"] == "lo" -def test_ack_happy_path(store: SQLiteJournal) -> None: +def test_ack_happy_path(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -288,7 +287,7 @@ def test_ack_happy_path(store: SQLiteJournal) -> None: assert store.get_task(env.task_id)["state"] == "done" -def test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: +def test_ack_wrong_lease_rejected(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -297,7 +296,7 @@ def test_ack_wrong_lease_rejected(store: SQLiteJournal) -> None: assert not store.ack(env.task_id, w.worker_id, "wrong") -def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: +def test_late_ack_after_requeue_ignored(store: InMemoryStore) -> None: env = _envelope() store.submit(env) w = _worker_state() @@ -307,7 +306,7 @@ def test_late_ack_after_requeue_ignored(store: SQLiteJournal) -> None: assert not store.ack(env.task_id, w.worker_id, "L2") -def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: +def test_nack_requeues_with_backoff(store: InMemoryStore) -> None: env = _envelope(max_retries=2, retry_backoff=1.0) store.submit(env) w = _worker_state() @@ -319,7 +318,7 @@ def test_nack_requeues_with_backoff(store: SQLiteJournal) -> None: assert float(task["next_run_at"]) > time.time() -def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: +def test_nack_exceeds_retries_fails(store: InMemoryStore) -> None: env = _envelope(max_retries=0) store.submit(env) w = _worker_state() @@ -329,35 +328,29 @@ def test_nack_exceeds_retries_fails(store: SQLiteJournal) -> None: assert store.get_task(env.task_id)["state"] == "failed" -def test_dead_worker_tasks_requeued(store: SQLiteJournal, db_path: str) -> None: +def test_dead_worker_tasks_requeued(store: InMemoryStore) -> None: w = _worker_state() store.register_worker(w) env = _envelope(max_retries=3) store.submit(env) store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) - conn = sqlite3.connect(db_path) - conn.execute("UPDATE workers SET last_seen=0 WHERE worker_id=?", (w.worker_id,)) - conn.commit() - conn.close() + store._workers[w.worker_id].last_seen = 0 # simulate missed heartbeats assert store.recover_dead_workers(heartbeat_timeout=1.0) == 1 assert store.get_task(env.task_id)["state"] == "ready" -def test_choose_worker_least_loaded(store: SQLiteJournal, db_path: str) -> None: +def test_choose_worker_least_loaded(store: InMemoryStore) -> None: w1 = _worker_state(worker_id="w1", capacity=4) w2 = _worker_state(worker_id="w2", capacity=4) store.register_worker(w1) store.register_worker(w2) - conn = sqlite3.connect(db_path) - conn.execute("UPDATE workers SET inflight=3 WHERE worker_id='w1'") - conn.commit() - conn.close() + store._workers["w1"].inflight = 3 # w1 heavily loaded chosen = store.choose_worker("least_loaded", heartbeat_timeout=30.0) assert chosen is not None assert chosen["worker_id"] == "w2" -def test_stats(store: SQLiteJournal) -> None: +def test_stats(store: InMemoryStore) -> None: w = _worker_state() store.register_worker(w) store.submit(_envelope()) @@ -507,36 +500,13 @@ async def test_late_ack_after_requeue_rejected( await hub.stop() +@pytest.mark.skip( + reason="In-memory store has no persistence; restart recovery requires a durable backend." +) async def test_hub_restart_recovers_orphaned_tasks( ctrl_addr: str, db_path: str ) -> None: - """ - Tasks leased at hub shutdown must be requeued when a new hub starts - with the same database. - """ - hub1 = _hub(ctrl_addr, db_path) - await hub1.start() - w1 = FakeWorker(ctrl_addr, capacity=1) - client = FakeClient(ctrl_addr) - await w1.register() - tid = await client.submit(max_retries=3) - env = await w1.recv_task(timeout=3.0) - assert env.task_id == tid - # "kill" hub1 without giving worker a chance to ack - await hub1.stop() - w1.close() - client.close() - - # Task is still leased in the DB - assert hub1.store.get_task(tid)["state"] == "leased" - - hub2 = _hub(ctrl_addr, db_path) - await hub2.start() - await asyncio.sleep(0.3) # allow startup recovery - try: - assert hub2.store.get_task(tid)["state"] == "ready" - finally: - await hub2.stop() + """Persistence across restarts is not supported by the in-memory store.""" async def test_concurrent_heartbeats(ctrl_addr: str, db_path: str) -> None: From b8c72127c53918ef63120d3755b1c9148695c971 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 18:37:06 +0300 Subject: [PATCH 3/5] Refactoring the NNG support solution v2: Update routing policy --- taskiq/brokers/nng/__init__.py | 21 ++++- taskiq/brokers/nng/hub.py | 17 +++- taskiq/brokers/nng/storage.py | 138 +++++++++++++++++++++++++++--- tests/brokers/test_nng_broker.py | 141 ++++++++++++++++++++++++++++++- 4 files changed, 295 insertions(+), 22 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 8e2a7f4a..75ba0d95 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -8,18 +8,37 @@ WorkerState, WorkerStatus, ) -from .storage import InMemoryStore, QueueFullError, StoreConfig +from .storage import ( + InMemoryStore, + LeastLoaded, + PowerOfTwoChoices, + QueueFullError, + RoutingPolicy, + RoundRobin, + StoreConfig, + WorkerView, + make_routing_policy, +) __all__ = [ "HubConfig", "NNGHub", + # protocol "ControlMessage", "ControlResponse", "MessageKind", "TaskEnvelope", "WorkerState", "WorkerStatus", + # store "QueueFullError", "InMemoryStore", "StoreConfig", + # routing + "WorkerView", + "RoutingPolicy", + "LeastLoaded", + "PowerOfTwoChoices", + "RoundRobin", + "make_routing_policy", ] diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index c58857bf..37a935b9 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -37,7 +37,13 @@ TaskEnvelope, WorkerState, ) -from .storage import InMemoryStore, QueueFullError, StoreConfig +from .storage import ( + InMemoryStore, + QueueFullError, + RoutingPolicy, + StoreConfig, + make_routing_policy, +) logger = logging.getLogger(__name__) @@ -53,7 +59,7 @@ class HubConfig: lease_timeout: float = 20.0 dispatch_interval: float = 0.05 reaper_interval: float = 0.5 - routing_policy: str = "least_loaded" + routing_policy: RoutingPolicy | str = "least_loaded" backoff_cap: float = 60.0 # Number of concurrent Rep0 contexts. Each context handles one req/rep # pair independently; N contexts ≈ N simultaneous control-plane clients. @@ -107,6 +113,9 @@ def __init__(self, config: HubConfig) -> None: backoff_cap=config.backoff_cap, ), ) + # Resolve once at construction so RoundRobin and similar stateful + # policies maintain their counter across dispatch calls. + self._routing: RoutingPolicy = make_routing_policy(config.routing_policy) self._stop = asyncio.Event() self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 @@ -282,7 +291,7 @@ async def _dispatch_once(self) -> bool: sent_any = False for row in due: worker = self.store.choose_worker( - self.config.routing_policy, + self._routing, heartbeat_timeout=self.config.heartbeat_timeout, ) if worker is None: @@ -384,7 +393,7 @@ def _build_config() -> HubConfig: ) p.add_argument( "--routing-policy", - choices=["least_loaded", "p2c"], + choices=["least_loaded", "p2c", "round_robin"], default=os.getenv("NNG_ROUTING_POLICY", "least_loaded"), ) p.add_argument( diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 87970adf..81ab1899 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -4,7 +4,7 @@ import random import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from .protocol import TaskEnvelope, WorkerState @@ -103,6 +103,113 @@ def as_dict(self) -> dict[str, Any]: } +# ── routing policy abstraction ──────────────────────────────────────────────── + + +@dataclass(frozen=True) +class WorkerView: + """Immutable worker snapshot passed to :class:`RoutingPolicy` implementations.""" + + worker_id: str + inflight: int + capacity: int + + @property + def load(self) -> float: + """Fractional load: 0.0 idle → 1.0 at capacity.""" + return self.inflight / max(self.capacity, 1) + + +@runtime_checkable +class RoutingPolicy(Protocol): + """Strategy interface for selecting a dispatch target from available workers.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the chosen worker, or None to hold off dispatch.""" + ... + + +class LeastLoaded: + """Pick the worker with the lowest inflight / capacity ratio.""" + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the least-loaded worker.""" + if not workers: + return None + return min(workers, key=lambda w: w.load) + + +class PowerOfTwoChoices: + """ + Power-of-two-choices routing. + + Samples two workers uniformly at random and returns the less loaded one. + Reduces hot-spot probability under high concurrency compared to pure random. + """ + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the less loaded of two randomly sampled workers.""" + if not workers: + return None + if len(workers) == 1: + return workers[0] + a, b = random.sample(workers, k=2) # noqa: S311 + return a if a.load <= b.load else b + + +class RoundRobin: + """ + Round-robin routing — cycle through workers in alphabetical ID order. + + Ignores load; useful when tasks are homogeneous and worker capacity is equal. + The counter is per-instance, so each :class:`NNGHub` maintains its own cycle. + """ + + def __init__(self) -> None: + """Initialise the cycle counter.""" + self._idx: int = 0 + + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + """Return the next worker in the cycle.""" + if not workers: + return None + w = workers[self._idx % len(workers)] + self._idx += 1 + return w + + +# Singletons for stateless built-ins; RoundRobin singleton is fine for single-hub +# processes. Users needing isolated counters should pass their own instance. +_BUILTIN_POLICIES: dict[str, RoutingPolicy] = { + "least_loaded": LeastLoaded(), + "p2c": PowerOfTwoChoices(), + "round_robin": RoundRobin(), +} + + +def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy: + """ + Resolve a routing policy name or pass through an instance. + + :param policy: ``'least_loaded'``, ``'p2c'``, ``'round_robin'``, or a + :class:`RoutingPolicy` instance. + :return: concrete routing policy. + :raises ValueError: for unknown string names. + """ + if isinstance(policy, str): + resolved = _BUILTIN_POLICIES.get(policy) + if resolved is None: + raise ValueError( + f"Unknown routing policy {policy!r}; " + f"available: {sorted(_BUILTIN_POLICIES)}" + ) + return resolved + return policy + + +# ── store ───────────────────────────────────────────────────────────────────── + + class InMemoryStore: """ Pure in-memory task store for the NNG hub. @@ -384,17 +491,17 @@ def mark_draining(self, worker_id: str) -> None: def choose_worker( self, - routing_policy: str = "least_loaded", + policy: "RoutingPolicy | str" = "least_loaded", *, heartbeat_timeout: float = 15.0, ) -> dict[str, Any] | None: """ - Select the best available worker according to ``routing_policy``. + Select the best available worker using a routing policy. - ``'least_loaded'`` picks the worker with the lowest inflight/capacity - ratio. ``'p2c'`` samples two workers and picks the less loaded one. + Accepts a :class:`RoutingPolicy` instance or a string name + (``'least_loaded'``, ``'p2c'``, ``'round_robin'``). - :param routing_policy: ``'least_loaded'`` or ``'p2c'``. + :param policy: routing policy or name. :param heartbeat_timeout: seconds before a worker is considered stale. :return: chosen worker dict, or None if no worker has capacity. """ @@ -408,14 +515,17 @@ def choose_worker( ] if not available: return None - if routing_policy == "p2c" and len(available) >= 2: - a, b = random.sample(available, k=2) # noqa: S311 - load_a = a.inflight / max(a.capacity, 1) - load_b = b.inflight / max(b.capacity, 1) - chosen = a if load_a <= load_b else b - else: - chosen = min(available, key=lambda w: w.inflight / max(w.capacity, 1)) - return chosen.as_dict() + # Stable sort so RoundRobin cycles in a predictable, deterministic order. + views = sorted( + [WorkerView(w.worker_id, w.inflight, w.capacity) for w in available], + key=lambda v: v.worker_id, + ) + routing = make_routing_policy(policy) + chosen = routing.choose(views) + if chosen is None: + return None + worker = self._workers.get(chosen.worker_id) + return worker.as_dict() if worker is not None else None # ── observability ───────────────────────────────────────────────────────── diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index 4bb9c4b6..0b8bcc43 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -28,13 +28,19 @@ NNGHub, ControlMessage, ControlResponse, + InMemoryStore, + LeastLoaded, MessageKind, + PowerOfTwoChoices, + QueueFullError, + RoutingPolicy, + RoundRobin, + StoreConfig, TaskEnvelope, WorkerState, WorkerStatus, - QueueFullError, - InMemoryStore, - StoreConfig, + WorkerView, + make_routing_policy, ) @@ -544,3 +550,132 @@ async def test_graceful_drain_and_unregister(ctrl_addr: str, db_path: str) -> No finally: worker.close() await hub.stop() + + +# ── 2b. Routing policy unit tests ───────────────────────────────────────────── + + +def test_least_loaded_picks_idle_worker() -> None: + policy = LeastLoaded() + workers = [WorkerView("w1", inflight=3, capacity=4), WorkerView("w2", inflight=0, capacity=4)] + assert policy.choose(workers).worker_id == "w2" # type: ignore[union-attr] + + +def test_least_loaded_empty_returns_none() -> None: + assert LeastLoaded().choose([]) is None + + +def test_p2c_returns_a_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("w1", 1, 4), WorkerView("w2", 2, 4), WorkerView("w3", 0, 4)] + chosen = policy.choose(workers) + assert chosen is not None + assert chosen.worker_id in {"w1", "w2", "w3"} + + +def test_p2c_single_worker() -> None: + policy = PowerOfTwoChoices() + workers = [WorkerView("only", 0, 4)] + assert policy.choose(workers).worker_id == "only" # type: ignore[union-attr] + + +def test_round_robin_cycles() -> None: + policy = RoundRobin() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4), WorkerView("w3", 0, 4)] + ids = [policy.choose(workers).worker_id for _ in range(6)] # type: ignore[union-attr] + assert ids == ["w1", "w2", "w3", "w1", "w2", "w3"] + + +def test_make_routing_policy_string() -> None: + assert isinstance(make_routing_policy("least_loaded"), LeastLoaded) + assert isinstance(make_routing_policy("p2c"), PowerOfTwoChoices) + assert isinstance(make_routing_policy("round_robin"), RoundRobin) + + +def test_make_routing_policy_instance_passthrough() -> None: + policy = LeastLoaded() + assert make_routing_policy(policy) is policy + + +def test_make_routing_policy_unknown_raises() -> None: + with pytest.raises(ValueError, match="Unknown routing policy"): + make_routing_policy("no_such_policy") + + +def test_custom_routing_policy_accepted(store: InMemoryStore) -> None: + """Users can pass a RoutingPolicy instance directly to choose_worker.""" + + class AlwaysFirstPolicy: + """Trivial policy: always pick the worker with the lexicographically smallest ID.""" + def choose(self, workers: list[WorkerView]) -> WorkerView | None: + return min(workers, key=lambda w: w.worker_id) if workers else None + + policy = AlwaysFirstPolicy() + # Verify it satisfies the Protocol at runtime. + assert isinstance(policy, RoutingPolicy) + + w1 = _worker_state(worker_id="aaa", capacity=4) + w2 = _worker_state(worker_id="zzz", capacity=4) + store.register_worker(w1) + store.register_worker(w2) + chosen = store.choose_worker(policy, heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] == "aaa" + + +def test_choose_worker_p2c(store: InMemoryStore) -> None: + """P2C routing returns one of the registered workers.""" + for i in range(4): + store.register_worker(_worker_state(worker_id=f"w{i}", capacity=4)) + chosen = store.choose_worker("p2c", heartbeat_timeout=30.0) + assert chosen is not None + assert chosen["worker_id"] in {f"w{i}" for i in range(4)} + + +def test_hub_accepts_policy_instance(ctrl_addr: str, db_path: str) -> None: + """HubConfig.routing_policy accepts a RoutingPolicy instance.""" + hub = NNGHub(HubConfig( + control_addr=ctrl_addr, + routing_policy=RoundRobin(), + max_pending=100, + )) + assert isinstance(hub._routing, RoundRobin) + + +# ── 3b. Backpressure integration test ──────────────────────────────────────── + + +async def test_backpressure_hub_rejects_when_full( + ctrl_addr: str, db_path: str +) -> None: + """Hub returns error=queue full when max_pending is reached.""" + hub = _hub(ctrl_addr, db_path, max_pending=1) + await hub.start() + client = FakeClient(ctrl_addr) + try: + await client.submit() # fills the one slot (no worker → stays queued) + # Second submission must be rejected + payload: dict[str, object] = { + "task_id": uuid.uuid4().hex, + "task_name": "tests:task", + "payload_b64": "dGVzdA==", + "labels": {}, + "lease_id": "", + "attempts": 0, + "max_retries": 0, + "retry_backoff": 1.0, + "retry_jitter": 0.0, + "priority": 0, + "created_at": time.time(), + } + async with client._lock: + await client._ctrl.asend( + ControlMessage(kind="submit", payload=payload).to_bytes() + ) + raw = await client._ctrl.arecv() + resp = ControlResponse.from_bytes(raw) + assert not resp.ok + assert resp.error == "queue full" + finally: + client.close() + await hub.stop() From fe3b0a5883d59a075d792e0ea1c7c91541bf7c11 Mon Sep 17 00:00:00 2001 From: Alexandr Tedeev Date: Sun, 26 Apr 2026 23:36:47 +0300 Subject: [PATCH 4/5] Refactoring the NNG support solution v2.5: Update routing policy, add affinity policy, and scheduler abstraction. --- taskiq/brokers/nng/__init__.py | 9 ++ taskiq/brokers/nng/hub.py | 15 ++- taskiq/brokers/nng/storage.py | 108 +++++++++++++++- tests/brokers/test_nng_broker.py | 216 ++++++++++++++++++++++++++++++- 4 files changed, 336 insertions(+), 12 deletions(-) diff --git a/taskiq/brokers/nng/__init__.py b/taskiq/brokers/nng/__init__.py index 75ba0d95..1e0bdcea 100644 --- a/taskiq/brokers/nng/__init__.py +++ b/taskiq/brokers/nng/__init__.py @@ -9,13 +9,17 @@ WorkerStatus, ) from .storage import ( + AffinityPolicy, InMemoryStore, LeastLoaded, PowerOfTwoChoices, + PriorityScheduler, QueueFullError, RoutingPolicy, RoundRobin, + Scheduler, StoreConfig, + TaskContext, WorkerView, make_routing_policy, ) @@ -35,10 +39,15 @@ "InMemoryStore", "StoreConfig", # routing + "TaskContext", "WorkerView", "RoutingPolicy", + "AffinityPolicy", "LeastLoaded", "PowerOfTwoChoices", "RoundRobin", "make_routing_policy", + # scheduler + "Scheduler", + "PriorityScheduler", ] diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index 37a935b9..f3920c6f 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -39,9 +39,12 @@ ) from .storage import ( InMemoryStore, + PriorityScheduler, QueueFullError, RoutingPolicy, + Scheduler, StoreConfig, + TaskContext, make_routing_policy, ) @@ -60,6 +63,7 @@ class HubConfig: dispatch_interval: float = 0.05 reaper_interval: float = 0.5 routing_policy: RoutingPolicy | str = "least_loaded" + scheduler: Scheduler | None = None backoff_cap: float = 60.0 # Number of concurrent Rep0 contexts. Each context handles one req/rep # pair independently; N contexts ≈ N simultaneous control-plane clients. @@ -116,6 +120,7 @@ def __init__(self, config: HubConfig) -> None: # Resolve once at construction so RoundRobin and similar stateful # policies maintain their counter across dispatch calls. self._routing: RoutingPolicy = make_routing_policy(config.routing_policy) + self._scheduler: Scheduler = config.scheduler or PriorityScheduler() self._stop = asyncio.Event() self._ctrl_sock: Any = None # pynng.Rep0 self._worker_push: dict[str, Any] = {} # worker_id -> pynng.Push0 @@ -285,14 +290,22 @@ async def _dispatch_loop(self) -> None: async def _dispatch_once(self) -> bool: """Dispatch up to ``dispatch_batch`` due tasks to available workers.""" - due = self.store.due_tasks(self.config.dispatch_batch) + due = self._scheduler.select(self.store, self.config.dispatch_batch) if not due: return False sent_any = False for row in due: + task_ctx = TaskContext( + task_id=row["task_id"], + task_name=row["task_name"], + labels=row["labels"], + priority=int(row["priority"]), + attempts=int(row["attempts"]), + ) worker = self.store.choose_worker( self._routing, heartbeat_timeout=self.config.heartbeat_timeout, + task=task_ctx, ) if worker is None: return sent_any # no capacity; leave remaining tasks in queue diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index 81ab1899..b804400a 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -1,6 +1,8 @@ """Pure in-memory task store for the NNG hub — no external dependencies.""" from __future__ import annotations +import functools +import inspect import random import time from dataclasses import dataclass, field @@ -103,6 +105,20 @@ def as_dict(self) -> dict[str, Any]: } +# ── task context ───────────────────────────────────────────────────────────── + + +@dataclass +class TaskContext: + """Task metadata passed to context-aware routing policies (e.g. affinity).""" + + task_id: str + task_name: str + labels: dict[str, Any] + priority: int = 0 + attempts: int = 0 + + # ── routing policy abstraction ──────────────────────────────────────────────── @@ -178,12 +194,70 @@ def choose(self, workers: list[WorkerView]) -> WorkerView | None: return w -# Singletons for stateless built-ins; RoundRobin singleton is fine for single-hub -# processes. Users needing isolated counters should pass their own instance. +class AffinityPolicy: + """ + Sticky routing: tasks with the same ``affinity_key`` label always go to the + same worker. Falls back to least-loaded when the preferred worker is gone. + + The affinity table is per-instance and lives only in memory. + """ + + def __init__(self) -> None: + """Initialise an empty affinity table.""" + self._table: dict[str, str] = {} # affinity_key → worker_id + + def choose( + self, + workers: list[WorkerView], + task: "TaskContext | None" = None, + ) -> WorkerView | None: + """Return the sticky worker for the task's affinity key, or least-loaded.""" + if not workers: + return None + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key and key in self._table: + match = next( + (w for w in workers if w.worker_id == self._table[key]), None + ) + if match is not None: + return match + chosen = min(workers, key=lambda w: w.load) + if task is not None: + key = str(task.labels.get("affinity_key", "")) + if key: + self._table[key] = chosen.worker_id + return chosen + + +@functools.lru_cache(maxsize=None) +def _policy_accepts_task(policy_cls: type) -> bool: + """Return True if policy.choose accepts a ``task`` keyword argument.""" + try: + return "task" in inspect.signature(policy_cls.choose).parameters + except (ValueError, TypeError): + return False + + +def _choose_with_context( + policy: RoutingPolicy, + views: list[WorkerView], + task: "TaskContext | None", +) -> "WorkerView | None": + """Call policy.choose, passing ``task`` only when the policy supports it.""" + if task is not None and _policy_accepts_task(type(policy)): + return policy.choose(views, task=task) # type: ignore[call-arg] + return policy.choose(views) + + +# Singletons for stateless built-ins; RoundRobin/AffinityPolicy singletons are +# fine for single-hub processes. Users needing isolated state should pass their +# own instance. _BUILTIN_POLICIES: dict[str, RoutingPolicy] = { "least_loaded": LeastLoaded(), "p2c": PowerOfTwoChoices(), "round_robin": RoundRobin(), + "affinity": AffinityPolicy(), # type: ignore[dict-item] } @@ -207,6 +281,26 @@ def make_routing_policy(policy: "RoutingPolicy | str") -> RoutingPolicy: return policy +# ── scheduler abstraction ───────────────────────────────────────────────────── + + +@runtime_checkable +class Scheduler(Protocol): + """Strategy interface for selecting which tasks to dispatch next.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Return up to ``limit`` tasks ready for dispatch.""" + ... + + +class PriorityScheduler: + """Default scheduler: highest-priority due tasks first.""" + + def select(self, store: "InMemoryStore", limit: int) -> list[dict[str, Any]]: + """Delegate to :meth:`InMemoryStore.due_tasks`.""" + return store.due_tasks(limit) + + # ── store ───────────────────────────────────────────────────────────────────── @@ -494,15 +588,21 @@ def choose_worker( policy: "RoutingPolicy | str" = "least_loaded", *, heartbeat_timeout: float = 15.0, + task: "TaskContext | None" = None, ) -> dict[str, Any] | None: """ Select the best available worker using a routing policy. Accepts a :class:`RoutingPolicy` instance or a string name - (``'least_loaded'``, ``'p2c'``, ``'round_robin'``). + (``'least_loaded'``, ``'p2c'``, ``'round_robin'``, ``'affinity'``). + + Context-aware policies (e.g. :class:`AffinityPolicy`) receive the + optional ``task`` argument when they declare it in their ``choose`` + signature. :param policy: routing policy or name. :param heartbeat_timeout: seconds before a worker is considered stale. + :param task: optional task context for context-aware policies. :return: chosen worker dict, or None if no worker has capacity. """ cutoff = time.time() - heartbeat_timeout @@ -521,7 +621,7 @@ def choose_worker( key=lambda v: v.worker_id, ) routing = make_routing_policy(policy) - chosen = routing.choose(views) + chosen = _choose_with_context(routing, views, task) if chosen is None: return None worker = self._workers.get(chosen.worker_id) diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index 0b8bcc43..7bb7b2db 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -15,7 +15,9 @@ import asyncio import os +import sys import tempfile +import textwrap import time import uuid @@ -24,6 +26,7 @@ pynng = pytest.importorskip("pynng") from taskiq.brokers.nng import ( + AffinityPolicy, HubConfig, NNGHub, ControlMessage, @@ -32,10 +35,13 @@ LeastLoaded, MessageKind, PowerOfTwoChoices, + PriorityScheduler, QueueFullError, RoutingPolicy, RoundRobin, + Scheduler, StoreConfig, + TaskContext, TaskEnvelope, WorkerState, WorkerStatus, @@ -86,16 +92,19 @@ def _worker_state( def _hub(control_addr: str, db_path: str, **kwargs: object) -> NNGHub: + defaults: dict[str, object] = { + "max_pending": 100, + "heartbeat_timeout": 2.0, + "lease_timeout": 2.0, + "dispatch_interval": 0.02, + "reaper_interval": 0.1, + "control_concurrency": 4, + } + defaults.update(kwargs) cfg = HubConfig( control_addr=control_addr, task_db=db_path, - max_pending=100, - heartbeat_timeout=2.0, - lease_timeout=2.0, - dispatch_interval=0.02, - reaper_interval=0.1, - control_concurrency=4, - **kwargs, # type: ignore[arg-type] + **defaults, # type: ignore[arg-type] ) return NNGHub(cfg) @@ -679,3 +688,196 @@ async def test_backpressure_hub_rejects_when_full( finally: client.close() await hub.stop() + + +# ── 2c. AffinityPolicy unit tests ──────────────────────────────────────────── + + +def test_affinity_policy_sticks_to_worker() -> None: + """Same affinity_key must route to the same worker across calls.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "user-42"}) + first = policy.choose(workers, task=task) + assert first is not None + for _ in range(10): + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == first.worker_id + + +def test_affinity_policy_falls_back_when_worker_gone() -> None: + """When the sticky worker is no longer available, fall back to least-loaded.""" + policy = AffinityPolicy() + workers_full = [WorkerView("w1", 0, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {"affinity_key": "key-x"}) + first = policy.choose(workers_full, task=task) + assert first is not None + # Remove the sticky worker — only the other one remains. + remaining = [w for w in workers_full if w.worker_id != first.worker_id] + fallback = policy.choose(remaining, task=task) + assert fallback is not None + assert fallback.worker_id != first.worker_id + + +def test_affinity_policy_no_key_uses_least_loaded() -> None: + """Tasks without affinity_key get least-loaded routing.""" + policy = AffinityPolicy() + workers = [WorkerView("w1", 3, 4), WorkerView("w2", 0, 4)] + task = TaskContext("t1", "fn", {}) + chosen = policy.choose(workers, task=task) + assert chosen is not None + assert chosen.worker_id == "w2" + + +def test_affinity_policy_is_routing_policy() -> None: + assert isinstance(AffinityPolicy(), RoutingPolicy) + + +def test_choose_worker_affinity_string(store: InMemoryStore) -> None: + """String 'affinity' resolves to the singleton AffinityPolicy via choose_worker.""" + for wid in ("a1", "a2"): + store.register_worker(_worker_state(worker_id=wid, capacity=4)) + task = TaskContext("t1", "fn", {"affinity_key": "session-1"}) + first = store.choose_worker("affinity", heartbeat_timeout=30.0, task=task) + assert first is not None + for _ in range(5): + chosen = store.choose_worker("affinity", heartbeat_timeout=30.0, task=task) + assert chosen is not None + assert chosen["worker_id"] == first["worker_id"] + + +# ── 2d. Scheduler unit tests ───────────────────────────────────────────────── + + +def test_priority_scheduler_delegates_to_due_tasks(store: InMemoryStore) -> None: + store.submit(_envelope(task_id="lo", priority=0)) + store.submit(_envelope(task_id="hi", priority=5)) + sched = PriorityScheduler() + rows = sched.select(store, limit=10) + assert rows[0]["task_id"] == "hi" + + +def test_priority_scheduler_is_scheduler() -> None: + assert isinstance(PriorityScheduler(), Scheduler) + + +def test_custom_scheduler_used_by_hub(ctrl_addr: str, db_path: str) -> None: + """HubConfig.scheduler accepts a custom Scheduler instance.""" + + class NoopScheduler: + """Never returns tasks — useful for verifying it is actually called.""" + called = False + + def select( + self, store: InMemoryStore, limit: int + ) -> list[dict[str, object]]: + NoopScheduler.called = True + return [] + + scheduler = NoopScheduler() + assert isinstance(scheduler, Scheduler) + hub = NNGHub(HubConfig( + control_addr=ctrl_addr, + scheduler=scheduler, + max_pending=10, + )) + assert hub._scheduler is scheduler + + +# ── 4. Multiprocess integration test ───────────────────────────────────────── + +_WORKER_SCRIPT = textwrap.dedent("""\ + import asyncio, sys, os + sys.path.insert(0, {root!r}) + try: + import pynng # noqa: F401 + from taskiq.brokers.nng.broker import NNGBroker + except Exception as exc: + sys.stdout.write(f"SKIP:{{exc}}\\n") + sys.stdout.flush() + sys.exit(0) + + async def main() -> None: + broker = NNGBroker( + {ctrl_addr!r}, + worker_task_addr={task_addr!r}, + worker_id={worker_id!r}, + capacity=1, + heartbeat_interval=1.0, + recv_timeout_ms=3000, + send_timeout_ms=3000, + ) + broker.is_worker_process = True + await broker.startup() + sys.stdout.write("READY\\n") + sys.stdout.flush() + async for msg in broker.listen(): + sys.stdout.write(f"TASK:{{msg.data.decode()}}\\n") + sys.stdout.flush() + await msg.ack() + break + await broker.shutdown() + + asyncio.run(main()) +""") + + +async def test_multiprocess_worker_receives_task( + ctrl_addr: str, db_path: str +) -> None: + """A real subprocess worker (separate OS process) receives and acks a task.""" + repo_root = str( + __import__("pathlib").Path(__file__).parent.parent.parent.resolve() + ) + task_addr = _ipc("mp-worker") + worker_id = f"mp-{uuid.uuid4().hex[:8]}" + + script = _WORKER_SCRIPT.format( + root=repo_root, + ctrl_addr=ctrl_addr, + task_addr=task_addr, + worker_id=worker_id, + ) + + hub = _hub(ctrl_addr, db_path) + await hub.start() + client = FakeClient(ctrl_addr) + + proc = await asyncio.create_subprocess_exec( + sys.executable, "-c", script, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + async def _read_line(timeout: float = 10.0) -> str: + assert proc.stdout is not None + line = await asyncio.wait_for(proc.stdout.readline(), timeout=timeout) + return line.decode().strip() + + try: + first_line = await _read_line(timeout=10.0) + if first_line.startswith("SKIP:"): + pytest.skip(f"Worker subprocess skipped: {first_line[5:]}") + + assert first_line == "READY", f"Expected READY, got: {first_line!r}" + + # Submit a task now that the worker is registered and listening. + tid = await client.submit() + + task_line = await _read_line(timeout=10.0) + assert task_line.startswith("TASK:"), f"Expected TASK:..., got: {task_line!r}" + + await proc.wait() + + # Give hub's reaper a tick to process the ack. + await asyncio.sleep(0.2) + state = hub.store.get_task(tid) + assert state is not None + assert state["state"] == "done", f"Expected done, got {state['state']!r}" + finally: + if proc.returncode is None: + proc.terminate() + await proc.wait() + client.close() + await hub.stop() From 6e8cdc485e06b1dcdc552d660fa5591eb865b7ff Mon Sep 17 00:00:00 2001 From: Alex Ted Date: Sat, 23 May 2026 22:47:35 +0300 Subject: [PATCH 5/5] refactor: Replace as_dict with dataclasses.asdict; Remove user-level retry orchestration from broker/hub/protocol --- taskiq/brokers/nng/broker.py | 19 ---------- taskiq/brokers/nng/hub.py | 6 +--- taskiq/brokers/nng/protocol.py | 10 +++--- taskiq/brokers/nng/storage.py | 62 ++++++++++---------------------- tests/brokers/test_nng_broker.py | 60 ++++++++++++++++--------------- 5 files changed, 57 insertions(+), 100 deletions(-) diff --git a/taskiq/brokers/nng/broker.py b/taskiq/brokers/nng/broker.py index a6273e41..3998c073 100644 --- a/taskiq/brokers/nng/broker.py +++ b/taskiq/brokers/nng/broker.py @@ -83,9 +83,6 @@ def __init__( heartbeat_interval: float = 5.0, lease_timeout: float = 20.0, capacity: int = 1, - max_retries: int = 0, - retry_backoff: float = 1.0, - retry_jitter: float = 0.0, recv_timeout_ms: int = 5_000, send_timeout_ms: int = 5_000, ) -> None: @@ -102,9 +99,6 @@ def __init__( :param heartbeat_interval: seconds between heartbeat messages to hub. :param lease_timeout: seconds a dispatched task lease remains valid. :param capacity: max concurrent tasks this worker will accept. - :param max_retries: default max retries for submitted tasks. - :param retry_backoff: base seconds for exponential backoff. - :param retry_jitter: jitter multiplier added to backoff (0 = no jitter). :param recv_timeout_ms: Req0 recv timeout in milliseconds. :param send_timeout_ms: Req0 send timeout in milliseconds. """ @@ -123,9 +117,6 @@ def __init__( self.heartbeat_interval = heartbeat_interval self.lease_timeout = lease_timeout self.capacity = capacity - self.max_retries = max_retries - self.retry_backoff = retry_backoff - self.retry_jitter = retry_jitter self.recv_timeout_ms = recv_timeout_ms self.send_timeout_ms = send_timeout_ms @@ -249,16 +240,6 @@ async def kick(self, message: BrokerMessage) -> None: "payload_b64": base64.b64encode(message.message).decode("ascii"), "labels": message.labels, "lease_id": "", # hub assigns the real lease_id at dispatch time - "attempts": int(message.labels.get("attempts", 0)), - "max_retries": int( - message.labels.get("max_retries", self.max_retries), - ), - "retry_backoff": float( - message.labels.get("retry_backoff", self.retry_backoff), - ), - "retry_jitter": float( - message.labels.get("retry_jitter", self.retry_jitter), - ), "priority": int(message.labels.get("priority", 0)), "created_at": time.time(), } diff --git a/taskiq/brokers/nng/hub.py b/taskiq/brokers/nng/hub.py index f3920c6f..d31374d7 100644 --- a/taskiq/brokers/nng/hub.py +++ b/taskiq/brokers/nng/hub.py @@ -300,7 +300,7 @@ async def _dispatch_once(self) -> bool: task_name=row["task_name"], labels=row["labels"], priority=int(row["priority"]), - attempts=int(row["attempts"]), + attempts=int(row.get("attempts", 0)), ) worker = self.store.choose_worker( self._routing, @@ -334,10 +334,6 @@ async def _dispatch_once(self) -> bool: payload_b64=base64.b64encode(row["payload"]).decode("ascii"), labels=row["labels"], lease_id=lease_id, - attempts=int(row["attempts"]) + 1, - max_retries=int(row["max_retries"]), - retry_backoff=float(row["retry_backoff"]), - retry_jitter=float(row["retry_jitter"]), priority=int(row["priority"]), created_at=float(row["created_at"]), ) diff --git a/taskiq/brokers/nng/protocol.py b/taskiq/brokers/nng/protocol.py index 9b0b4d8e..c715e7c7 100644 --- a/taskiq/brokers/nng/protocol.py +++ b/taskiq/brokers/nng/protocol.py @@ -57,6 +57,11 @@ class TaskEnvelope: ``lease_id`` is the UUID assigned by the hub at dispatch time. Workers must echo it back in the ACK so the hub can validate that the ack is not stale (e.g. after lease expiry and requeue). + + User-level retry policy is the responsibility of the + :class:`~taskiq.middlewares.SmartRetryMiddleware` (or any compatible + middleware) and travels in :attr:`labels`; the envelope itself carries + no retry knobs. """ task_id: str @@ -64,10 +69,6 @@ class TaskEnvelope: payload_b64: str labels: dict[str, Any] = field(default_factory=dict) lease_id: str = "" - attempts: int = 0 - max_retries: int = 0 - retry_backoff: float = 1.0 - retry_jitter: float = 0.0 priority: int = 0 created_at: float = 0.0 @@ -157,3 +158,4 @@ def to_dict(self) -> dict[str, Any]: d = asdict(self) d["status"] = str(self.status) return d + diff --git a/taskiq/brokers/nng/storage.py b/taskiq/brokers/nng/storage.py index b804400a..4c0dc1d6 100644 --- a/taskiq/brokers/nng/storage.py +++ b/taskiq/brokers/nng/storage.py @@ -5,7 +5,7 @@ import inspect import random import time -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: @@ -19,7 +19,11 @@ class StoreConfig: path: str = "" # kept for API compat; not used max_pending: int = 10_000 lease_timeout: float = 30.0 - backoff_base: float = 1.0 + # Hub-internal cap on delivery retries (lease expiry / worker death). + # Has nothing to do with user-level retry policy, which is handled by + # taskiq retry middlewares. Set to 0 to disable redelivery entirely. + max_delivery_attempts: int = 5 + delivery_backoff: float = 1.0 backoff_cap: float = 60.0 @@ -34,10 +38,9 @@ class _Task: payload: bytes labels: dict[str, Any] state: str # ready / leased / done / failed + # Internal delivery-attempt counter (incremented on each dispatch). + # NOT related to user-level retry policy — that lives in middlewares. attempts: int = 0 - max_retries: int = 0 - retry_backoff: float = 1.0 - retry_jitter: float = 0.0 priority: int = 0 created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) @@ -49,25 +52,7 @@ class _Task: def as_dict(self) -> dict[str, Any]: """Return a dict view of this task record.""" - return { - "task_id": self.task_id, - "task_name": self.task_name, - "payload": self.payload, - "labels": self.labels, - "state": self.state, - "attempts": self.attempts, - "max_retries": self.max_retries, - "retry_backoff": self.retry_backoff, - "retry_jitter": self.retry_jitter, - "priority": self.priority, - "created_at": self.created_at, - "updated_at": self.updated_at, - "next_run_at": self.next_run_at, - "lease_id": self.lease_id, - "leased_worker_id": self.leased_worker_id, - "lease_until": self.lease_until, - "last_error": self.last_error, - } + return asdict(self) def as_status_dict(self) -> dict[str, Any]: """Return a JSON-safe dict (no raw bytes) for control-plane status responses.""" @@ -91,18 +76,7 @@ class _Worker: def as_dict(self) -> dict[str, Any]: """Return a dict view of this worker record.""" - return { - "worker_id": self.worker_id, - "task_addr": self.task_addr, - "capacity": self.capacity, - "inflight": self.inflight, - "last_seen": self.last_seen, - "heartbeat_interval": self.heartbeat_interval, - "lease_timeout": self.lease_timeout, - "draining": self.draining, - "status": self.status, - "version": self.version, - } + return asdict(self) # ── task context ───────────────────────────────────────────────────────────── @@ -325,16 +299,21 @@ def __init__(self, config: StoreConfig) -> None: # ── helpers ─────────────────────────────────────────────────────────────── - def _backoff(self, attempts: int, backoff_base: float) -> float: - return min(self.config.backoff_cap, backoff_base * (2 ** max(0, attempts - 1))) + def _backoff(self, attempts: int) -> float: + return min( + self.config.backoff_cap, + self.config.delivery_backoff * (2 ** max(0, attempts - 1)), + ) def _requeue_or_fail(self, task: _Task, worker_id: str, error: str) -> bool: now = time.time() - if task.attempts > task.max_retries: + # Hub-internal delivery cap. User-level retries are handled by + # retry middlewares, which re-kick the task with updated labels. + if task.attempts > self.config.max_delivery_attempts: task.state = "failed" else: task.state = "ready" - task.next_run_at = now + self._backoff(task.attempts, task.retry_backoff) + task.next_run_at = now + self._backoff(task.attempts) task.last_error = error task.lease_id = None task.leased_worker_id = None @@ -367,9 +346,6 @@ def submit(self, envelope: TaskEnvelope) -> None: payload=envelope.payload, labels=envelope.labels, state="ready", - max_retries=envelope.max_retries, - retry_backoff=envelope.retry_backoff, - retry_jitter=envelope.retry_jitter, priority=envelope.priority, created_at=envelope.created_at or now, updated_at=now, diff --git a/tests/brokers/test_nng_broker.py b/tests/brokers/test_nng_broker.py index 7bb7b2db..e76e3370 100644 --- a/tests/brokers/test_nng_broker.py +++ b/tests/brokers/test_nng_broker.py @@ -65,10 +65,6 @@ def _envelope(**kwargs: object) -> TaskEnvelope: "payload_b64": "dGVzdA==", "labels": {}, "lease_id": "", - "attempts": 0, - "max_retries": 0, - "retry_backoff": 1.0, - "retry_jitter": 0.0, "priority": 0, "created_at": time.time(), } @@ -208,10 +204,6 @@ async def submit(self, **labels: object) -> str: "payload_b64": "dGVzdA==", "labels": {}, "lease_id": "", - "attempts": 0, - "max_retries": labels.pop("max_retries", 0), - "retry_backoff": labels.pop("retry_backoff", 1.0), - "retry_jitter": 0.0, "priority": labels.pop("priority", 0), "created_at": time.time(), } @@ -321,32 +313,46 @@ def test_late_ack_after_requeue_ignored(store: InMemoryStore) -> None: assert not store.ack(env.task_id, w.worker_id, "L2") -def test_nack_requeues_with_backoff(store: InMemoryStore) -> None: - env = _envelope(max_retries=2, retry_backoff=1.0) - store.submit(env) +def test_nack_requeues_with_backoff(db_path: str) -> None: + """Hub-level delivery retry: nack within delivery cap requeues with backoff.""" + s = InMemoryStore( + StoreConfig( + path=db_path, max_pending=50, lease_timeout=5.0, + max_delivery_attempts=2, delivery_backoff=1.0, + ), + ) + env = _envelope() + s.submit(env) w = _worker_state() - store.register_worker(w) - store.mark_leased(env.task_id, w.worker_id, "L3", time.time() + 60) - assert store.nack(env.task_id, w.worker_id, "L3", "boom") - task = store.get_task(env.task_id) + s.register_worker(w) + s.mark_leased(env.task_id, w.worker_id, "L3", time.time() + 60) + assert s.nack(env.task_id, w.worker_id, "L3", "boom") + task = s.get_task(env.task_id) assert task["state"] == "ready" assert float(task["next_run_at"]) > time.time() -def test_nack_exceeds_retries_fails(store: InMemoryStore) -> None: - env = _envelope(max_retries=0) - store.submit(env) +def test_nack_exceeds_delivery_cap_fails(db_path: str) -> None: + """Delivery cap of 0 → first nack fails the task immediately.""" + s = InMemoryStore( + StoreConfig( + path=db_path, max_pending=50, lease_timeout=5.0, + max_delivery_attempts=0, + ), + ) + env = _envelope() + s.submit(env) w = _worker_state() - store.register_worker(w) - store.mark_leased(env.task_id, w.worker_id, "L4", time.time() + 60) - store.nack(env.task_id, w.worker_id, "L4", "error") - assert store.get_task(env.task_id)["state"] == "failed" + s.register_worker(w) + s.mark_leased(env.task_id, w.worker_id, "L4", time.time() + 60) + s.nack(env.task_id, w.worker_id, "L4", "error") + assert s.get_task(env.task_id)["state"] == "failed" def test_dead_worker_tasks_requeued(store: InMemoryStore) -> None: w = _worker_state() store.register_worker(w) - env = _envelope(max_retries=3) + env = _envelope() store.submit(env) store.mark_leased(env.task_id, w.worker_id, "L5", time.time() + 60) store._workers[w.worker_id].last_seen = 0 # simulate missed heartbeats @@ -456,7 +462,7 @@ async def test_worker_crash_before_ack_task_requeued( client = FakeClient(ctrl_addr) try: await w1.register() - tid = await client.submit(max_retries=3) + tid = await client.submit() env1 = await w1.recv_task(timeout=3.0) assert env1.task_id == tid w1.close() # simulate crash without acking @@ -493,7 +499,7 @@ async def test_late_ack_after_requeue_rejected( client = FakeClient(ctrl_addr) try: await w1.register() - tid = await client.submit(max_retries=3) + tid = await client.submit() env1 = await w1.recv_task(timeout=3.0) await asyncio.sleep(3.5) # let lease expire @@ -670,10 +676,6 @@ async def test_backpressure_hub_rejects_when_full( "payload_b64": "dGVzdA==", "labels": {}, "lease_id": "", - "attempts": 0, - "max_retries": 0, - "retry_backoff": 1.0, - "retry_jitter": 0.0, "priority": 0, "created_at": time.time(), }