diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py
index 729b213..ccd62ee 100644
--- a/src/climatevision/api/main.py
+++ b/src/climatevision/api/main.py
@@ -364,6 +364,12 @@ def create_app() -> FastAPI:
     )
     app.add_middleware(AuditLogMiddleware)
 
+    # Register RequestIDMiddleware AFTER AuditLogMiddleware so it wraps it
+    # (later-added middleware runs first): the request_id ContextVar is then
+    # set before AuditLogMiddleware logs. NOTE(review): CORSMiddleware is
+    # added below, so it -- not this middleware -- is actually outermost.
+    from climatevision.api.middleware import RequestIDMiddleware
+    app.add_middleware(RequestIDMiddleware)
     app.add_middleware(
         CORSMiddleware,
         allow_origins=[
diff --git a/src/climatevision/api/middleware.py b/src/climatevision/api/middleware.py
index 7a6a3d0..dfed900 100644
--- a/src/climatevision/api/middleware.py
+++ b/src/climatevision/api/middleware.py
@@ -10,6 +10,7 @@
 import logging
 import time
 import uuid
+from contextvars import ContextVar
 from typing import Callable
 
 from fastapi import Request, Response
@@ -18,6 +19,78 @@
 logger = logging.getLogger(__name__)
 
 
+# Context variable that carries the X-Request-ID through the entire request
+# lifecycle, including non-FastAPI code (inference pipeline, helper modules)
+# that does not have access to the FastAPI `Request` object.
+request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)
+
+# Limits applied to client-supplied X-Request-ID values. The header is
+# untrusted and is interpolated into the hand-built JSON log line and echoed
+# on the response, so only short, plain-ASCII identifiers are accepted.
+_MAX_REQUEST_ID_LENGTH = 128
+_REQUEST_ID_EXTRA_CHARS = frozenset("-_.:")
+
+
+def _is_safe_request_id(value: str) -> bool:
+    """Return True if *value* is safe to log and echo back verbatim."""
+    if not value or len(value) > _MAX_REQUEST_ID_LENGTH:
+        return False
+    return all(
+        ch.isascii() and (ch.isalnum() or ch in _REQUEST_ID_EXTRA_CHARS)
+        for ch in value
+    )
+
+
+class RequestIDMiddleware(BaseHTTPMiddleware):
+    """
+    Propagate ``X-Request-ID`` through the request lifecycle.
+
+    Reads the inbound ``X-Request-ID`` header (or generates a fresh UUID4 if
+    absent or not a short plain-ASCII token), stores it on a
+    :class:`~contextvars.ContextVar` so any code that runs during the
+    request -- inference pipeline, helper modules, background tasks scheduled
+    with ``asyncio.create_task`` -- can read the same value, and echoes it
+    back on the response.
+
+    Pair this with :class:`RequestIDLogFilter` so log records emitted during
+    the request automatically carry ``%(request_id)s``.
+    """
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        inbound = request.headers.get("X-Request-ID", "")
+        # Never trust the client value: a quote or newline in it would corrupt
+        # the JSON log line, so anything unexpected is replaced by a new UUID4.
+        if _is_safe_request_id(inbound):
+            request_id = inbound
+        else:
+            request_id = str(uuid.uuid4())
+        token = request_id_var.set(request_id)
+        # Mirror to ``request.state`` so handlers that already read from
+        # there (e.g. the existing ``RequestLoggingMiddleware`` /
+        # ``AuditLogMiddleware``) keep working.
+        request.state.request_id = request_id
+        try:
+            response = await call_next(request)
+        finally:
+            request_id_var.reset(token)
+        response.headers["X-Request-ID"] = request_id
+        return response
+
+
+class RequestIDLogFilter(logging.Filter):
+    """Inject the current request ID into every log record.
+
+    Reads :data:`request_id_var` and exposes it as ``record.request_id`` so
+    logging formatters can reference ``%(request_id)s``. Records emitted
+    outside of any request (startup, background workers without context)
+    receive ``"-"`` so the format string never KeyErrors.
+    """
+
+    def filter(self, record: logging.LogRecord) -> bool:  # noqa: A003
+        record.request_id = request_id_var.get() or "-"
+        return True
+
+
 class RequestLoggingMiddleware(BaseHTTPMiddleware):
     """
     Middleware for structured request logging and audit trails.
@@ -135,9 +208,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
 
 
 def setup_logging(log_level: str = "INFO") -> None:
-    """Configure structured JSON logging for the API."""
+    """Configure structured JSON logging for the API.
+
+    Installs :class:`RequestIDLogFilter` on the root logger and its handlers
+    so every log record those handlers emit carries ``request_id``.
+    """
     logging.basicConfig(
         level=getattr(logging, log_level.upper()),
-        format='{"timestamp":"%(asctime)s","level":"%(levelname)s","message":"%(message)s"}',
-        datefmt="%Y-%m-%dT%H:%M:%S"
+        format=(
+            '{"timestamp":"%(asctime)s","level":"%(levelname)s",'
+            '"request_id":"%(request_id)s","message":"%(message)s"}'
+        ),
+        datefmt="%Y-%m-%dT%H:%M:%S",
     )
+    request_id_filter = RequestIDLogFilter()
+    root = logging.getLogger()
+    # Filter the root handlers (handler filters see every record they emit)
+    # as well as the root logger (which only sees records logged on itself).
+    if not any(isinstance(f, RequestIDLogFilter) for f in root.filters):
+        root.addFilter(request_id_filter)
+    for handler in root.handlers:
+        if not any(isinstance(f, RequestIDLogFilter) for f in handler.filters):
+            handler.addFilter(request_id_filter)
diff --git a/tests/test_request_id_middleware.py b/tests/test_request_id_middleware.py
new file mode 100644
index 0000000..dcfd13d
--- /dev/null
+++ b/tests/test_request_id_middleware.py
@@ -0,0 +1,123 @@
+"""Tests for the X-Request-ID middleware.
+
+Covers issue #44: request log lines from the inference pipeline and helper
+modules need to carry the same identifier as the inbound HTTP request.
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+import uuid
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from climatevision.api.middleware import (
+    RequestIDLogFilter,
+    RequestIDMiddleware,
+    request_id_var,
+)
+
+
+UUID_RE = re.compile(
+    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
+)
+
+
+def _build_app() -> FastAPI:
+    app = FastAPI()
+    app.add_middleware(RequestIDMiddleware)
+
+    logger = logging.getLogger("climatevision.tests.request_id")
+
+    @app.get("/echo")
+    def echo() -> dict[str, str]:
+        # Read the contextvar the way inference/pipeline.py would
+        return {"request_id": request_id_var.get() or ""}
+
+    @app.get("/log")
+    def log_route() -> dict[str, str]:
+        logger.info("inside-handler", extra={"sample": True})
+        return {"ok": "1"}
+
+    return app
+
+
+def test_response_has_uuid_request_id_when_header_absent() -> None:
+    """Without an X-Request-ID header the middleware mints a fresh UUID4."""
+    client = TestClient(_build_app())
+    response = client.get("/echo")
+    assert response.status_code == 200
+
+    request_id = response.headers["X-Request-ID"]
+    assert UUID_RE.match(request_id), f"not a UUID4 shape: {request_id!r}"
+    # And the same value flowed into the contextvar that the route saw.
+    assert response.json()["request_id"] == request_id
+
+
+def test_response_echoes_explicit_request_id() -> None:
+    """When the client sends X-Request-ID the middleware must echo it back."""
+    client = TestClient(_build_app())
+    sent = "test-id-deadbeef"
+    response = client.get("/echo", headers={"X-Request-ID": sent})
+    assert response.status_code == 200
+    assert response.headers["X-Request-ID"] == sent
+    assert response.json()["request_id"] == sent
+
+
+def test_unsafe_request_id_is_replaced() -> None:
+    """A header value that could corrupt the JSON log line is not echoed."""
+    client = TestClient(_build_app())
+    sent = 'evil","level":"CRITICAL'
+    response = client.get("/echo", headers={"X-Request-ID": sent})
+    assert response.status_code == 200
+
+    request_id = response.headers["X-Request-ID"]
+    assert UUID_RE.match(request_id), f"not a UUID4 shape: {request_id!r}"
+    assert response.json()["request_id"] == request_id
+
+
+def test_log_records_include_request_id_via_filter(
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Logs emitted inside the request show request_id once the filter is on."""
+    handler = logging.StreamHandler()
+    handler.addFilter(RequestIDLogFilter())
+    handler.setFormatter(
+        logging.Formatter("%(request_id)s | %(message)s")
+    )
+
+    target_logger = logging.getLogger("climatevision.tests.request_id")
+    target_logger.setLevel(logging.INFO)
+    target_logger.addHandler(handler)
+    try:
+        # Use pytest's caplog with the same filter so we can introspect records.
+        with caplog.at_level(logging.INFO, logger="climatevision.tests.request_id"):
+            caplog.handler.addFilter(RequestIDLogFilter())
+
+            sent = str(uuid.uuid4())
+            client = TestClient(_build_app())
+            response = client.get("/log", headers={"X-Request-ID": sent})
+            assert response.status_code == 200
+
+        records = [
+            r for r in caplog.records
+            if r.name == "climatevision.tests.request_id"
+        ]
+        assert records, "expected the route handler to emit a log record"
+        record = records[0]
+        assert getattr(record, "request_id", None) == sent
+    finally:
+        target_logger.removeHandler(handler)
+
+
+def test_request_id_var_resets_after_response() -> None:
+    """The ContextVar must not leak across requests."""
+    client = TestClient(_build_app())
+    client.get("/echo", headers={"X-Request-ID": "first"})
+    # Outside any request the var returns its declared default, ``None``.
+    # NOTE(review): presumably TestClient runs the app outside this thread's
+    # context, so this may pass even without the reset -- confirm.
+    assert request_id_var.get() is None