Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/climatevision/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,12 @@ def create_app() -> FastAPI:
)

app.add_middleware(AuditLogMiddleware)
# Register the RequestIDMiddleware LAST so it sits OUTERMOST in the
# middleware stack: it must run before every other middleware so the
# request_id ContextVar is set in time for AuditLogMiddleware's logger
# call (and for any inference-pipeline code further down).
from climatevision.api.middleware import RequestIDMiddleware
app.add_middleware(RequestIDMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=[
Expand Down
72 changes: 69 additions & 3 deletions src/climatevision/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging
import time
import uuid
from contextvars import ContextVar
from typing import Callable

from fastapi import Request, Response
Expand All @@ -18,6 +19,55 @@
logger = logging.getLogger(__name__)


# Context variable that carries the X-Request-ID through the entire request
# lifecycle, including non-FastAPI code (inference pipeline, helper modules)
# that does not have access to the FastAPI `Request` object.
request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)


class RequestIDMiddleware(BaseHTTPMiddleware):
"""
Propagate ``X-Request-ID`` through the request lifecycle.

Reads the inbound ``X-Request-ID`` header (or generates a fresh UUID4 if
absent), stores it on a :class:`~contextvars.ContextVar` so any code that
runs during the request -- inference pipeline, helper modules, background
tasks scheduled with ``asyncio.create_task`` -- can read the same value,
and echoes it back on the response.

Pair this with :class:`RequestIDLogFilter` so log records emitted during
the request automatically carry ``%(request_id)s``.
"""

async def dispatch(self, request: Request, call_next: Callable) -> Response:
request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4())
token = request_id_var.set(request_id)
# Mirror to ``request.state`` so handlers that already read from
# there (e.g. the existing ``RequestLoggingMiddleware`` /
# ``AuditLogMiddleware``) keep working.
request.state.request_id = request_id
try:
response = await call_next(request)
finally:
request_id_var.reset(token)
response.headers["X-Request-ID"] = request_id
return response


class RequestIDLogFilter(logging.Filter):
"""Inject the current request ID into every log record.

Reads :data:`request_id_var` and exposes it as ``record.request_id`` so
logging formatters can reference ``%(request_id)s``. Records emitted
outside of any request (startup, background workers without context)
receive ``"-"`` so the format string never KeyErrors.
"""

def filter(self, record: logging.LogRecord) -> bool: # noqa: A003
record.request_id = request_id_var.get() or "-"
return True


class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""
Middleware for structured request logging and audit trails.
Expand Down Expand Up @@ -135,9 +185,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:


def setup_logging(log_level: str = "INFO") -> None:
"""Configure structured JSON logging for the API."""
"""Configure structured JSON logging for the API.

Installs :class:`RequestIDLogFilter` on the root logger so every log
record emitted during a request carries the current ``request_id``.
"""
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format='{"timestamp":"%(asctime)s","level":"%(levelname)s","message":"%(message)s"}',
datefmt="%Y-%m-%dT%H:%M:%S"
format=(
'{"timestamp":"%(asctime)s","level":"%(levelname)s",'
'"request_id":"%(request_id)s","message":"%(message)s"}'
),
datefmt="%Y-%m-%dT%H:%M:%S",
)
request_id_filter = RequestIDLogFilter()
root = logging.getLogger()
# Attach to the root logger and to any handlers already installed by
# ``logging.basicConfig`` so existing handlers also see ``request_id``.
if not any(isinstance(f, RequestIDLogFilter) for f in root.filters):
root.addFilter(request_id_filter)
for handler in root.handlers:
if not any(isinstance(f, RequestIDLogFilter) for f in handler.filters):
handler.addFilter(request_id_filter)
109 changes: 109 additions & 0 deletions tests/test_request_id_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tests for the X-Request-ID middleware.

Covers issue #44: request log lines from the inference pipeline and helper
modules need to carry the same identifier as the inbound HTTP request.
"""

from __future__ import annotations

import logging
import re
import uuid

from fastapi import FastAPI
from fastapi.testclient import TestClient

from climatevision.api.middleware import (
RequestIDLogFilter,
RequestIDMiddleware,
request_id_var,
)


UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
)


def _build_app() -> FastAPI:
app = FastAPI()
app.add_middleware(RequestIDMiddleware)

logger = logging.getLogger("climatevision.tests.request_id")

@app.get("/echo")
def echo() -> dict[str, str]:
# Read the contextvar the way inference/pipeline.py would
return {"request_id": request_id_var.get() or ""}

@app.get("/log")
def log_route() -> dict[str, str]:
logger.info("inside-handler", extra={"sample": True})
return {"ok": "1"}

return app


def test_response_has_uuid_request_id_when_header_absent() -> None:
"""Without an X-Request-ID header the middleware mints a fresh UUID4."""
client = TestClient(_build_app())
response = client.get("/echo")
assert response.status_code == 200

request_id = response.headers["X-Request-ID"]
assert UUID_RE.match(request_id), f"not a UUID4 shape: {request_id!r}"
# And the same value flowed into the contextvar that the route saw.
assert response.json()["request_id"] == request_id


def test_response_echoes_explicit_request_id() -> None:
"""When the client sends X-Request-ID the middleware must echo it back."""
client = TestClient(_build_app())
sent = "test-id-deadbeef"
response = client.get("/echo", headers={"X-Request-ID": sent})
assert response.status_code == 200
assert response.headers["X-Request-ID"] == sent
assert response.json()["request_id"] == sent


def test_log_records_include_request_id_via_filter(
caplog: object,
) -> None:
"""Logs emitted inside the request show request_id once the filter is on."""
handler = logging.StreamHandler()
handler.addFilter(RequestIDLogFilter())
handler.setFormatter(
logging.Formatter("%(request_id)s | %(message)s")
)

target_logger = logging.getLogger("climatevision.tests.request_id")
target_logger.setLevel(logging.INFO)
target_logger.addHandler(handler)
try:
# Use pytest's caplog with the same filter so we can introspect records.
with caplog.at_level(logging.INFO, logger="climatevision.tests.request_id"): # type: ignore[attr-defined]
for record_filter in [RequestIDLogFilter()]:
caplog.handler.addFilter(record_filter) # type: ignore[attr-defined]

sent = str(uuid.uuid4())
client = TestClient(_build_app())
response = client.get("/log", headers={"X-Request-ID": sent})
assert response.status_code == 200

records = [
r for r in caplog.records # type: ignore[attr-defined]
if r.name == "climatevision.tests.request_id"
]
assert records, "expected the route handler to emit a log record"
record = records[0]
assert getattr(record, "request_id", None) == sent
finally:
target_logger.removeHandler(handler)


def test_request_id_var_resets_after_response() -> None:
"""The ContextVar must not leak across requests."""
client = TestClient(_build_app())
client.get("/echo", headers={"X-Request-ID": "first"})
# Outside any request the var returns its default ("" via ``or ''``).
assert request_id_var.get() is None
Loading