From d55ddd32e3304f26955e9b23e062228fdad98dbe Mon Sep 17 00:00:00 2001 From: Arne Baumann Date: Wed, 27 May 2026 17:02:20 +0200 Subject: [PATCH] feat(auth): add JWT Bearer token authentication support Extends api.auth to accept Auth0 JWT Bearer tokens alongside existing cookie-based sessions. Bearer auth is opt-in via AUTH_JWT_ENABLED and AUTH_JWT_AUDIENCE; each require_* dependency tries Bearer first, then falls back to cookie. Adds JWKS fetching with per-domain TTL cache. Co-Authored-By: Claude Sonnet 4.6 --- pyproject.toml | 6 + src/aignostics_foundry_core/AGENTS.md | 21 +- src/aignostics_foundry_core/api/auth.py | 317 ++++++++++++---- .../aignostics_foundry_core/api/auth_test.py | 346 ++++++++++++++++-- uv.lock | 62 +++- 5 files changed, 632 insertions(+), 120 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7253600..e7f432e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,9 @@ dependencies = [ "auth0-fastapi>=1.0.0b5,<2", "certifi>=2024", "fastapi>=0.110,<1", + "httpx2>=2.2.0,<3", "loguru>=0.7,<1", + "PyJWT[cryptography]>=2.10,<3", "platformdirs>=4,<5", "psutil>=6", "pydantic>=2,<3", @@ -301,7 +303,11 @@ style = [ # Changelog template configuration template = ".cz-templates/CHANGELOG.md.j2" +[tool.deptry] +package_module_name_map = { "httpx2" = ["httpx"] } + [tool.deptry.per_rule_ignores] "DEP002" = [ "asyncpg", + "auth0-fastapi", ] diff --git a/src/aignostics_foundry_core/AGENTS.md b/src/aignostics_foundry_core/AGENTS.md index 33714c1..2a8220e 100644 --- a/src/aignostics_foundry_core/AGENTS.md +++ b/src/aignostics_foundry_core/AGENTS.md @@ -11,7 +11,7 @@ This file provides an overview of all modules in `aignostics_foundry_core`, thei | **models** | Shared output format enum | `OutputFormat` StrEnum with `YAML` and `JSON` values for use in CLI and API responses | | **process** | Current process introspection | `ProcessInfo`, `ParentProcessInfo` Pydantic models and `get_process_info()` for runtime process metadata; `SUBPROCESS_CREATION_FLAGS` for subprocess creation | | **api.exceptions** | API exception hierarchy and FastAPI handlers | `ApiException` (500), `NotFoundException` (404), `AccessDeniedException` (401); `api_exception_handler`, `unhandled_exception_handler`, `validation_exception_handler` for FastAPI registration | -| **api.auth** | Auth0 authentication FastAPI dependencies | `AuthSettings` (env-prefix and env files derived from `ctx.env_prefix`/`ctx.env_file`), `UnauthenticatedError`, `ForbiddenError` (403); `get_auth_client`, `get_user`, `require_authenticated`, `require_admin`, `require_internal`, `require_internal_admin` FastAPI dependencies; Auth0 cookie security schemes | +| **api.auth** | Auth0 authentication FastAPI dependencies (cookie + Bearer JWT) | `AuthSettings` (env-prefix from `ctx.env_prefix`; fields: `cookie_enabled`, `enabled` (deprecated alias), `jwt_enabled`, `jwt_audience`, domain, credentials, org, role); `UnauthenticatedError`, `ForbiddenError` (403); `get_auth_client`, `get_user` (tries Bearer first, falls back to cookie), `require_authenticated`, `require_admin`, `require_internal`, `require_internal_admin` FastAPI dependencies; Auth0 cookie + Bearer security schemes for OpenAPI | | **api.core** | Versioned API router and FastAPI factory | `VersionedAPIRouter` (tracks all created instances), `API_TAG_*` constants, `create_public/authenticated/admin/internal/internal_admin_router` factories, `build_api_metadata`, `build_versioned_api_tags`, `build_root_api_tags`, `get_versioned_api_instances(versions, build_metadata=None, *, context=None)`, `init_api()` | | **api** | Consolidated API sub-package | Re-exports all public symbols from `api.exceptions`, `api.auth`, and `api.core`; import any API symbol directly from `aignostics_foundry_core.api` | | **log** | Configurable loguru logging initialisation | `logging_initialize(filter_func=None, *, context=None)`, `LogSettings` (env-prefix configurable), `InterceptHandler` for stdlib-to-loguru bridging | @@ -108,22 +108,21 @@ This file provides an overview of all modules in `aignostics_foundry_core`, thei ### api.auth -**Auth0 authentication and authorization FastAPI dependencies** +**Auth0 authentication and authorization FastAPI dependencies (cookie + Bearer JWT)** -- **Purpose**: Provides Auth0 cookie-based session authentication dependencies for FastAPI routes. All project-specific settings (org ID, role claim) are loaded from `AuthSettings` whose env prefix is configurable at instantiation. +- **Purpose**: Provides Auth0 cookie-based session and JWT Bearer token authentication dependencies for FastAPI routes. Each `require_*` dependency accepts either an Auth0 session cookie **or** a Bearer JWT — Bearer is tried first, cookie is the fallback. - **Key Features**: - - `AuthSettings(OpaqueSettings)` — uses the active FoundryContext to derive both the env prefix (`{ctx.env_prefix}AUTH_`) and the env file list (`ctx.env_file`). Fields: `internal_org_id` (required `str`; identifies the internal organization), `auth0_role_claim` (required `str`; JWT claim name for role). Both fields are mandatory — no defaults are provided. + - `AuthSettings(OpaqueSettings)` — uses the active FoundryContext (`{ctx.env_prefix}AUTH_`). Key fields: `cookie_enabled` (`AUTH_COOKIE_ENABLED`; new primary name), `enabled` (`AUTH_ENABLED`; deprecated alias for `cookie_enabled`, kept for backwards compat), `jwt_enabled` (`AUTH_JWT_ENABLED`; opt-in Bearer JWT auth), `jwt_audience` (`AUTH_JWT_AUDIENCE`; required when `jwt_enabled=True`), `domain`, `client_id`, `client_secret`, `internal_org_id`, `role_claim`, `session_secret`, `session_expiration`. - `UnauthenticatedError(Exception)` — raised when a user session is missing or invalid - `ForbiddenError(ApiException)` — `status_code = 403`; raised when user lacks required role or org membership - `get_auth_client(request)` — retrieves `AuthClient` from `request.app.state.auth_client`; raises `RuntimeError` if not configured - - `get_user(request, _cookie)` — async FastAPI dependency; returns user dict from Auth0 session or `None`; validates expiry; sets Sentry user context - - `require_authenticated` — dependency: requires a valid session - - `require_admin` — dependency: requires admin role - - `require_internal` — dependency: requires internal organization membership - - `require_internal_admin` — dependency: requires internal org membership AND admin role - - Auth0 cookie security scheme constants: `AUTH0_SESSION_COOKIE_NAME`, `AUTH0_TRANSACTION_COOKIE_NAME`, `AUTH0_ROLE_ADMIN` + - `get_user(request, _cookie, _bearer)` — async FastAPI dependency; tries Bearer JWT first (when `jwt_enabled=True`), falls back to cookie; returns user dict or `None`; sets Sentry user context + - `require_authenticated`, `require_admin`, `require_internal`, `require_internal_admin` — FastAPI dependencies; each accepts both cookie and Bearer schemes in OpenAPI + - Cookie security schemes: `auth0_session_scheme`, `auth0_admin_scheme`, `auth0_internal_scheme`, `auth0_internal_admin_scheme` (`APIKeyCookie`) + - Bearer security schemes: `auth0_bearer_scheme`, `auth0_admin_bearer_scheme`, `auth0_internal_bearer_scheme`, `auth0_internal_admin_bearer_scheme` (`HTTPBearer`) + - Constants: `AUTH0_SESSION_COOKIE_NAME`, `AUTH0_TRANSACTION_COOKIE_NAME`, `AUTH0_ROLE_ADMIN`, `AUTH0_JWKS_ALGORITHMS`, `AUTH0_JWKS_CACHE_TTL` - **Location**: `aignostics_foundry_core/api/auth.py` -- **Dependencies**: `auth0-fastapi>=1.0.0b5,<2`, `fastapi>=0.110,<1`, `loguru>=0.7,<1` (all mandatory) +- **Dependencies**: `auth0-fastapi>=1.0.0b5,<2`, `fastapi>=0.110,<1`, `loguru>=0.7,<1`, `PyJWT[cryptography]>=2.10,<3`, `httpx>=0.28,<1` (all mandatory) - **Import**: `from aignostics_foundry_core.api.auth import AuthSettings, ForbiddenError, UnauthenticatedError, get_auth_client, get_user, require_authenticated, require_admin, require_internal, require_internal_admin` ### api.core diff --git a/src/aignostics_foundry_core/api/auth.py b/src/aignostics_foundry_core/api/auth.py index 194ae49..0731d5e 100644 --- a/src/aignostics_foundry_core/api/auth.py +++ b/src/aignostics_foundry_core/api/auth.py @@ -1,28 +1,39 @@ """Authentication utilities for FastAPI. This module provides: -- Auth0 cookie schemes for OpenAPI documentation +- Auth0 cookie and Bearer JWT schemes for OpenAPI documentation - Authentication dependencies (require_authenticated, require_admin, etc.) -- get_user: Get authenticated user from session +- get_user: Get authenticated user from session cookie or Bearer JWT - get_auth_client: Get Auth0 client from app state -- AuthSettings: Full auth configuration (enabled, session, domain, credentials, org, role claim) +- AuthSettings: Full auth configuration (cookie_enabled, jwt_enabled, session, domain, + credentials, org, role claim, JWT audience) """ -import time -from typing import Annotated, Any +from __future__ import annotations -from auth0_fastapi.auth.auth_client import AuthClient -from fastapi import Request, Security -from fastapi.security import APIKeyCookie +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Annotated, Any + +import httpx +import jwt +from fastapi import Request, Response, Security +from fastapi.security import APIKeyCookie, HTTPAuthorizationCredentials, HTTPBearer +from jwt.algorithms import RSAAlgorithm from loguru import logger from pydantic import Field, PlainSerializer, SecretStr, StringConstraints, model_validator from pydantic_settings import SettingsConfigDict from aignostics_foundry_core.foundry import get_context +from aignostics_foundry_core.sentry import set_sentry_user from aignostics_foundry_core.settings import OpaqueSettings, load_settings from .exceptions import ApiException +if TYPE_CHECKING: + from auth0_fastapi.auth.auth_client import AuthClient + from jwt.algorithms import AllowedRSAKeys + AUTH0_SESSION_COOKIE_NAME = "_a0_session" AUTH0_TRANSACTION_COOKIE_NAME = "_a0_tx" AUTH0_COOKIE_SCHEME_NAME = "Auth0Cookie" @@ -30,6 +41,17 @@ AUTH0_ROLE_ADMIN = "admin" USER_NOT_AUTHENTICATED = "User is not authenticated" AUTH_SESSION_EXPIRATION_DEFAULT = 60 * 60 * 24 # 1 day in seconds +AUTH0_JWKS_ALGORITHMS = ["RS256"] +AUTH0_JWKS_CACHE_TTL = 3600 # seconds + + +@dataclass(frozen=True) +class _JwksCacheEntry: + jwks: dict[str, Any] + fetched_at: float + + +_jwks_cache: dict[str, _JwksCacheEntry] = {} class AuthSettings(OpaqueSettings): @@ -40,7 +62,11 @@ class AuthSettings(OpaqueSettings): :func:`aignostics_foundry_core.foundry.get_context`. Fields: - enabled: Enable Auth0 authentication (AUTH_ENABLED). + cookie_enabled: Enable Auth0 cookie-based authentication (AUTH_COOKIE_ENABLED). + enabled: Deprecated alias for cookie_enabled. Use AUTH_COOKIE_ENABLED instead + (AUTH_ENABLED still accepted for backwards compatibility). + jwt_enabled: Enable JWT Bearer token authentication (AUTH_JWT_ENABLED). + jwt_audience: Auth0 API audience identifier for JWT validation (AUTH_JWT_AUDIENCE). session_secret: Secret used to sign session cookies (AUTH_SESSION_SECRET). session_expiration: Session cookie expiration in seconds (AUTH_SESSION_EXPIRATION). domain: Auth0 domain (AUTH_DOMAIN). @@ -50,34 +76,66 @@ class AuthSettings(OpaqueSettings): role_claim: JWT claim name containing the user's role (AUTH_ROLE_CLAIM). Cross-field rules (validated after field assignment): - - enabled=True requires session_secret not None, client_secret not None, + - cookie_enabled=True (or enabled=True) requires session_secret, client_secret, non-empty domain, client_id, internal_org_id, and role_claim + - jwt_enabled=True requires non-empty domain and jwt_audience """ model_config = SettingsConfigDict(extra="ignore") - enabled: bool = Field(default=False) + cookie_enabled: bool = Field(default=False) + enabled: bool = Field(default=False) # deprecated; kept for AUTH_ENABLED backwards compat + jwt_enabled: bool = Field(default=False) + jwt_audience: str = Field(default="") session_secret: Annotated[ SecretStr | None, PlainSerializer(func=OpaqueSettings.serialize_sensitive_info, return_type=str, when_used="always"), ] = Field(default=None) session_expiration: int = Field(default=AUTH_SESSION_EXPIRATION_DEFAULT, gt=60, le=31536000) - domain: Annotated[str, StringConstraints(max_length=255)] = Field(default="") - client_id: Annotated[str, StringConstraints(max_length=32)] = Field(default="") + domain: Annotated[str, StringConstraints(max_length=255, strip_whitespace=True)] = Field(default="") + client_id: Annotated[str, StringConstraints(max_length=32, strip_whitespace=True)] = Field(default="") client_secret: Annotated[ SecretStr | None, PlainSerializer(func=OpaqueSettings.serialize_sensitive_info, return_type=str, when_used="always"), ] = Field(default=None, min_length=64, max_length=64) - internal_org_id: str = "" - role_claim: str = "" + internal_org_id: Annotated[str, StringConstraints(max_length=255, strip_whitespace=True)] = Field(default="") + role_claim: Annotated[str, StringConstraints(max_length=255, strip_whitespace=True)] = Field(default="") def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 """Initialise settings, deriving env_prefix and env files from the active FoundryContext.""" ctx = get_context() super().__init__(_env_prefix=f"{ctx.env_prefix}AUTH_", _env_file=ctx.env_file, **kwargs) # pyright: ignore[reportCallIssue] + def _validate_cookie_auth(self) -> None: + """Validate cookie auth required fields when cookie auth is active. + + Raises: + ValueError: If any required cookie auth field is missing or invalid. + """ + cookie_active = self.cookie_enabled or self.enabled + if not cookie_active: + return + if self.session_secret is None: + msg = "AUTH_SESSION_SECRET must not be None when cookie auth is enabled" + raise ValueError(msg) + if self.client_secret is None: + msg = "AUTH_CLIENT_SECRET must not be None when cookie auth is enabled" + raise ValueError(msg) + if not self.domain: + msg = "AUTH_DOMAIN must not be empty when cookie auth is enabled" + raise ValueError(msg) + if not self.client_id: + msg = "AUTH_CLIENT_ID must not be empty when cookie auth is enabled" + raise ValueError(msg) + if not self.internal_org_id: + msg = "AUTH_INTERNAL_ORG_ID must not be empty when cookie auth is enabled" + raise ValueError(msg) + if not self.role_claim: + msg = "AUTH_ROLE_CLAIM must not be empty when cookie auth is enabled" + raise ValueError(msg) + @model_validator(mode="after") - def validate_auth_dependencies(self) -> "AuthSettings": + def validate_auth_dependencies(self) -> AuthSettings: """Validate cross-field auth dependencies. Returns: @@ -86,23 +144,12 @@ def validate_auth_dependencies(self) -> "AuthSettings": Raises: ValueError: If any cross-field dependency is violated. """ - if self.enabled and self.session_secret is None: - msg = "AUTH_SESSION_SECRET must not be None when AUTH_ENABLED is True" - raise ValueError(msg) - if self.enabled and self.client_secret is None: - msg = "AUTH_CLIENT_SECRET must not be None when AUTH_ENABLED is True" - raise ValueError(msg) - if self.enabled and not self.domain: - msg = "AUTH_DOMAIN must not be empty when AUTH_ENABLED is True" - raise ValueError(msg) - if self.enabled and not self.client_id: - msg = "AUTH_CLIENT_ID must not be empty when AUTH_ENABLED is True" - raise ValueError(msg) - if self.enabled and not self.internal_org_id: - msg = "AUTH_INTERNAL_ORG_ID must not be empty when AUTH_ENABLED is True" + self._validate_cookie_auth() + if self.jwt_enabled and not self.domain: + msg = "AUTH_DOMAIN must not be empty when AUTH_JWT_ENABLED is True" raise ValueError(msg) - if self.enabled and not self.role_claim: - msg = "AUTH_ROLE_CLAIM must not be empty when AUTH_ENABLED is True" + if self.jwt_enabled and not self.jwt_audience: + msg = "AUTH_JWT_AUDIENCE must not be empty when AUTH_JWT_ENABLED is True" raise ValueError(msg) return self @@ -181,10 +228,120 @@ def get_auth_client(request: Request) -> AuthClient: auto_error=False, ) # Security scheme for internal admin endpoints +auth0_bearer_scheme = HTTPBearer( + scheme_name="Auth0Bearer", + description="Auth0 JWT Bearer token authentication.", + auto_error=False, +) + +auth0_admin_bearer_scheme = HTTPBearer( + scheme_name="Auth0AdminBearer", + description="Auth0 JWT Bearer token authentication with admin role requirement. " + f"User must have '{AUTH0_ROLE_ADMIN}' role in their configured role_claim.", + auto_error=False, +) + +auth0_internal_bearer_scheme = HTTPBearer( + scheme_name="Auth0InternalBearer", + description="Auth0 JWT Bearer token authentication with internal organization membership requirement. " + "User must be a member of the configured internal organization.", + auto_error=False, +) + +auth0_internal_admin_bearer_scheme = HTTPBearer( + scheme_name="Auth0InternalAdminBearer", + description=( + "Auth0 JWT Bearer token authentication with internal organization membership AND admin role requirements. " + f"User must be a member of the internal organization AND have '{AUTH0_ROLE_ADMIN}' role." + ), + auto_error=False, +) + + +async def _fetch_jwks(domain: str, *, force_refresh: bool = False) -> dict[str, Any]: + """Fetch JWKS from Auth0, caching the result per domain for AUTH0_JWKS_CACHE_TTL seconds. + + On fetch failure falls back to the last known good cache entry when one exists. + + Args: + domain: Auth0 domain to fetch JWKS from. + force_refresh: Bypass the TTL check and always fetch from the network. + + Returns: + Parsed JWKS JSON as a dict. + """ + entry = _jwks_cache.get(domain) + if not force_refresh and entry is not None and (time.time() - entry.fetched_at) < AUTH0_JWKS_CACHE_TTL: + return entry.jwks + + try: + async with httpx.AsyncClient() as client: + resp = await client.get(f"https://{domain}/.well-known/jwks.json", timeout=10) + resp.raise_for_status() + result: dict[str, Any] = resp.json() + _jwks_cache[domain] = _JwksCacheEntry(jwks=result, fetched_at=time.time()) + return result + except Exception: + if entry is not None: + logger.warning("JWKS refresh failed for domain {}; using stale cache", domain) + return entry.jwks + raise + + +async def _extract_public_key(token: str, domain: str) -> AllowedRSAKeys | None: + """Resolve the RSA public key for a JWT's kid from Auth0 JWKS. + + Fetches JWKS from cache; on a kid-miss, force-refreshes once before giving up. + + Args: + token: The raw JWT string (used only to read the unverified header). + domain: Auth0 domain to fetch JWKS from. + + Returns: + RSA public key object on success, or None if the kid cannot be resolved. + """ + jwks = await _fetch_jwks(domain) + header = jwt.get_unverified_header(token) + kid = header.get("kid") + + key_data = next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None) + if key_data is None: + logger.debug("JWT kid not found in cache; force-refreshing JWKS", domain=domain, kid=kid) + jwks = await _fetch_jwks(domain, force_refresh=True) + + key_data = next((k for k in jwks.get("keys", []) if k.get("kid") == kid), None) + if key_data is None: + logger.warning("JWT kid not found in JWKS after refresh", domain=domain, kid=kid) + return None + + return RSAAlgorithm.from_jwk(key_data) + + +async def _validate_jwt(token: str, auth_settings: AuthSettings) -> dict[str, Any] | None: + """Validate a Bearer JWT against the Auth0 JWKS. + + Returns: + Decoded JWT claims dict on success, or None if validation fails. + """ + try: + public_key = await _extract_public_key(token, auth_settings.domain) + payload: dict[str, Any] = jwt.decode( + token, + public_key, # pyright: ignore[reportArgumentType] # from_jwk returns public key from JWKS + algorithms=AUTH0_JWKS_ALGORITHMS, + audience=auth_settings.jwt_audience, + issuer=f"https://{auth_settings.domain}/", + ) + return payload + except Exception: # noqa: BLE001 + logger.debug("JWT validation failed") + return None + async def _require_authenticated_impl( request: Request, _cookie: str | None, + _bearer: HTTPAuthorizationCredentials | None = None, role: str | None = None, ) -> None: """Internal implementation for authenticated session check with optional role. @@ -192,6 +349,7 @@ async def _require_authenticated_impl( Args: request: The incoming request. _cookie: The session cookie. + _bearer: Optional Bearer JWT credentials. role: Optional role required (e.g., "admin"). If specified, user must have this role in their configured role_claim. @@ -201,147 +359,161 @@ async def _require_authenticated_impl( """ auth_settings = load_settings(AuthSettings) - user = await get_user(request, _cookie) + user = await get_user(request, _cookie, _bearer) if not user: - msg = USER_NOT_AUTHENTICATED - logger.critical(msg) - raise ForbiddenError(msg) + logger.critical(USER_NOT_AUTHENTICATED) + raise ForbiddenError(USER_NOT_AUTHENTICATED) + + log = logger.bind(user_id=user.get("sub")) # Check role if specified if role is not None: user_role = user.get(auth_settings.role_claim) if user_role != role: + log.warning("Role check failed", required_role=role, actual_role=user_role) msg = f"User role '{user_role}' does not match required role '{role}'" - logger.warning(msg) raise ForbiddenError(msg) - logger.debug(f"User has required role: {role}") + log.debug("Role check passed", role=role) async def require_authenticated( request: Request, _cookie: Annotated[str | None, Security(auth0_session_scheme)], + _bearer: Annotated[HTTPAuthorizationCredentials | None, Security(auth0_bearer_scheme)], ) -> None: """Require an authenticated session (FastAPI dependency). + Accepts either an Auth0 session cookie or a valid JWT Bearer token. + Args: request: The incoming request. _cookie: The session cookie (auto-injected by FastAPI). + _bearer: JWT Bearer credentials (auto-injected by FastAPI). Raises: - UnauthenticatedError: If the session is not valid or missing. + ForbiddenError: If the session is not valid or missing. """ - await _require_authenticated_impl(request, _cookie) + await _require_authenticated_impl(request, _cookie, _bearer) async def require_admin( request: Request, _cookie: Annotated[str | None, Security(auth0_admin_scheme)], + _bearer: Annotated[HTTPAuthorizationCredentials | None, Security(auth0_admin_bearer_scheme)], ) -> None: """Require admin role (FastAPI dependency). + Accepts either an Auth0 session cookie or a valid JWT Bearer token. + Args: request: The incoming request. _cookie: The session cookie (auto-injected by FastAPI). + _bearer: JWT Bearer credentials (auto-injected by FastAPI). Raises: - UnauthenticatedError: If the session is not valid or missing. - ForbiddenError: If user doesn't have admin role. + ForbiddenError: If the session is not valid or user doesn't have admin role. """ - await _require_authenticated_impl(request, _cookie, role=AUTH0_ROLE_ADMIN) + await _require_authenticated_impl(request, _cookie, _bearer, role=AUTH0_ROLE_ADMIN) async def require_internal( request: Request, _cookie: Annotated[str | None, Security(auth0_internal_scheme)], + _bearer: Annotated[HTTPAuthorizationCredentials | None, Security(auth0_internal_bearer_scheme)], ) -> None: """Require internal organization membership (FastAPI dependency). Checks if the authenticated user is a member of the configured internal organization. The internal organization is identified by the FOUNDRY_AUTH_INTERNAL_ORG_ID setting. + Accepts either an Auth0 session cookie or a valid JWT Bearer token. Args: request: The incoming request. _cookie: The session cookie (auto-injected by FastAPI). + _bearer: JWT Bearer credentials (auto-injected by FastAPI). Raises: - UnauthenticatedError: If the session is not valid or missing. - ForbiddenError: If user is not a member of the internal organization. + ForbiddenError: If the session is not valid or user is not in the internal org. """ auth_settings = load_settings(AuthSettings) - user = await get_user(request, _cookie) + user = await get_user(request, _cookie, _bearer) if not user: - msg = USER_NOT_AUTHENTICATED - logger.critical(msg) - raise ForbiddenError(msg) - # Check organization membership + logger.critical(USER_NOT_AUTHENTICATED) + raise ForbiddenError(USER_NOT_AUTHENTICATED) + user_org_id = user.get("org_id") + log = logger.bind(user_id=user.get("sub"), user_org=user_org_id) + if user_org_id != auth_settings.internal_org_id: + log.warning("Org membership check failed") msg = f"User is not a member of the internal organization (org_id: {user_org_id})" - logger.warning(msg) raise ForbiddenError(msg) - logger.debug(f"User is member of internal organization: {auth_settings.internal_org_id}") + log.debug("Org membership check passed") async def require_internal_admin( request: Request, _cookie: Annotated[str | None, Security(auth0_internal_admin_scheme)], + _bearer: Annotated[HTTPAuthorizationCredentials | None, Security(auth0_internal_admin_bearer_scheme)], ) -> None: """Require internal organization membership AND admin role (FastAPI dependency). Checks if the authenticated user is both: 1. A member of the configured internal organization (FOUNDRY_AUTH_INTERNAL_ORG_ID) 2. Has the admin role in their configured role_claim + Accepts either an Auth0 session cookie or a valid JWT Bearer token. Args: request: The incoming request. _cookie: The session cookie (auto-injected by FastAPI). + _bearer: JWT Bearer credentials (auto-injected by FastAPI). Raises: - UnauthenticatedError: If the session is not valid or missing. ForbiddenError: If user is not internal or doesn't have admin role. """ auth_settings = load_settings(AuthSettings) - user = await get_user(request, _cookie) + user = await get_user(request, _cookie, _bearer) if not user: - msg = USER_NOT_AUTHENTICATED - logger.critical(msg) - raise ForbiddenError(msg) + logger.critical(USER_NOT_AUTHENTICATED) + raise ForbiddenError(USER_NOT_AUTHENTICATED) - # Check organization membership user_org_id = user.get("org_id") + user_role = user.get(auth_settings.role_claim) + log = logger.bind(user_id=user.get("sub"), user_org=user_org_id, user_role=user_role) + if user_org_id != auth_settings.internal_org_id: + log.warning("Org membership check failed") msg = f"User is not a member of the internal organization (org_id: {user_org_id})" - logger.warning(msg) raise ForbiddenError(msg) - # Check admin role - user_role = user.get(auth_settings.role_claim) if user_role != AUTH0_ROLE_ADMIN: + log.warning("Role check failed", required_role=AUTH0_ROLE_ADMIN) msg = f"User role '{user_role}' does not match required role '{AUTH0_ROLE_ADMIN}'" - logger.warning(msg) raise ForbiddenError(msg) - logger.debug(f"User is internal admin: org={auth_settings.internal_org_id}, role={AUTH0_ROLE_ADMIN}") + log.debug("Internal admin check passed") async def get_user( request: Request, _cookie: Annotated[str | None, Security(auth0_session_scheme)], + _bearer: Annotated[HTTPAuthorizationCredentials | None, Security(auth0_bearer_scheme)], ) -> dict[str, Any] | None: """Get authenticated user information (FastAPI dependency). - This dependency ensures the user is authenticated and returns their user data - from the Auth0 session. Internally reads from the encrypted session cookie. + Tries Bearer JWT first (when jwt_enabled=True and a token is present), then falls + back to the Auth0 encrypted session cookie. Returns None if neither authenticates. Args: request: The incoming request. _cookie: The session cookie (auto-injected by FastAPI). + _bearer: JWT Bearer credentials (auto-injected by FastAPI). Returns: - User dictionary from Auth0 session containing claims like 'sub', 'email', 'name', etc. + User dictionary containing claims like 'sub', 'email', 'name', etc., or None if not authenticated. Example: @@ -349,12 +521,17 @@ async def get_user( async def me(user: Annotated[dict[str, Any], Depends(get_user)]): return {"email": user.get("email")} """ - from fastapi import Response # noqa: PLC0415 - - from aignostics_foundry_core.sentry import set_sentry_user # noqa: PLC0415 - auth_settings = load_settings(AuthSettings) + # Try Bearer JWT first + if _bearer and auth_settings.jwt_enabled: + jwt_user = await _validate_jwt(_bearer.credentials, auth_settings) + if jwt_user: + set_sentry_user(jwt_user, role_claim=auth_settings.role_claim) + return jwt_user + logger.debug("Bearer token present but JWT validation failed; falling back to cookie") + + # Cookie path try: auth_client = get_auth_client(request) session: dict = await auth_client.require_session(request, Response()) # type: ignore[reportUnknownVariableType] diff --git a/tests/aignostics_foundry_core/api/auth_test.py b/tests/aignostics_foundry_core/api/auth_test.py index e602970..52dbcda 100644 --- a/tests/aignostics_foundry_core/api/auth_test.py +++ b/tests/aignostics_foundry_core/api/auth_test.py @@ -2,7 +2,7 @@ import time from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pydantic import pytest @@ -13,6 +13,10 @@ AuthSettings, ForbiddenError, UnauthenticatedError, + _fetch_jwks, + _jwks_cache, + _JwksCacheEntry, + _validate_jwt, get_auth_client, get_user, require_admin, @@ -22,7 +26,11 @@ ) from aignostics_foundry_core.foundry import set_context from tests.aignostics_foundry_core.api import INTERNAL_ORG_ID_VAR_NAME, ROLE_CLAIM_VAR_NAME -from tests.conftest import make_context +from tests.conftest import TEST_PROJECT_PREFIX, make_context + +_JWT_ENABLED_VAR_NAME = f"{TEST_PROJECT_PREFIX}AUTH_JWT_ENABLED" +_JWT_AUDIENCE_VAR_NAME = f"{TEST_PROJECT_PREFIX}AUTH_JWT_AUDIENCE" +_DOMAIN_VAR_NAME = f"{TEST_PROJECT_PREFIX}AUTH_DOMAIN" _INTERNAL_ORG_ID = "org_internal_123" _OTHER_ORG_ID = "org_other_456" @@ -34,6 +42,10 @@ _TEST_CLIENT_SECRET = "x" * 64 _TEST_CLIENT_ID = "x" * 32 _TEST_DOMAIN = "example.auth0.com" +_TEST_JWT_AUDIENCE = "https://api.example.com" +_TEST_BEARER_TOKEN = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3QifQ.test.test" # noqa: S105 +_TEST_KID = "test-kid" +_FETCH_JWKS_PATH = "aignostics_foundry_core.api.auth._fetch_jwks" @pytest.mark.unit @@ -91,7 +103,10 @@ class TestAuthSettings: def test_auth_settings_defaults(self) -> None: """AuthSettings has correct defaults when no env vars are set.""" settings = AuthSettings() + assert settings.cookie_enabled is False assert settings.enabled is False + assert settings.jwt_enabled is False + assert not settings.jwt_audience assert not settings.internal_org_id assert not settings.role_claim assert not settings.domain @@ -110,6 +125,16 @@ def test_auth_settings_uses_context_env_prefix(self, monkeypatch: pytest.MonkeyP settings = AuthSettings() assert settings.role_claim == "https://custom/role" + def test_cookie_enabled_requires_session_secret(self) -> None: + """cookie_enabled=True with session_secret=None raises ValidationError.""" + with pytest.raises(pydantic.ValidationError): + AuthSettings(cookie_enabled=True, session_secret=None) + + def test_deprecated_enabled_still_triggers_validation(self) -> None: + """enabled=True (deprecated flag) still enforces all cookie auth validations.""" + with pytest.raises(pydantic.ValidationError): + AuthSettings(enabled=True, session_secret=None) + def test_enabled_requires_session_secret(self) -> None: """enabled=True with session_secret=None raises ValidationError.""" with pytest.raises(pydantic.ValidationError): @@ -170,6 +195,40 @@ def test_enabled_requires_non_empty_role_claim(self) -> None: role_claim="", ) + def test_cookie_enabled_with_all_required_fields_passes(self) -> None: + """cookie_enabled=True with all required fields set validates successfully.""" + settings = AuthSettings( + cookie_enabled=True, + session_secret=_TEST_SESSION_SECRET, + client_secret=_TEST_CLIENT_SECRET, + domain=_TEST_DOMAIN, + client_id=_TEST_CLIENT_ID, + internal_org_id=_INTERNAL_ORG_ID, + role_claim=_TEST_ROLE_CLAIM, + ) + assert settings.cookie_enabled is True + + def test_jwt_enabled_requires_domain(self) -> None: + """jwt_enabled=True with empty domain raises ValidationError.""" + with pytest.raises(pydantic.ValidationError, match="AUTH_DOMAIN"): + AuthSettings(jwt_enabled=True, domain="", jwt_audience=_TEST_JWT_AUDIENCE) + + def test_jwt_enabled_requires_audience(self) -> None: + """jwt_enabled=True with empty jwt_audience raises ValidationError.""" + with pytest.raises(pydantic.ValidationError, match="AUTH_JWT_AUDIENCE"): + AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience="") + + def test_jwt_enabled_with_all_fields_passes(self) -> None: + """jwt_enabled=True with domain and jwt_audience set does not raise.""" + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + assert settings.jwt_enabled is True + + def test_jwt_enabled_independent_of_cookie_enabled(self) -> None: + """jwt_enabled=True can be set without cookie_enabled=True.""" + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + assert settings.cookie_enabled is False + assert settings.enabled is False + @pytest.mark.integration class TestAuthSettingsEnvFile: @@ -195,6 +254,13 @@ def test_auth_settings_reads_fields_from_env_file_via_context( assert settings.role_claim == "claim_from_env_file" +def _make_bearer(token: str = _TEST_BEARER_TOKEN) -> MagicMock: + """Create a mock HTTPAuthorizationCredentials with the given token.""" + bearer = MagicMock() + bearer.credentials = token + return bearer + + @pytest.mark.integration class TestGetUser: """Tests for get_user FastAPI dependency.""" @@ -205,7 +271,7 @@ async def test_get_user_returns_none_without_session(self) -> None: request.app.state = MagicMock(spec=[]) # no auth_client → get_auth_client raises naturally cookie = None - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) assert result is None @@ -219,7 +285,7 @@ async def test_get_user_returns_none_for_expired_session(self) -> None: fake_client.require_session = AsyncMock(return_value={"user": expired_user}) request.app.state.auth_client = fake_client - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) assert result is None @@ -231,7 +297,7 @@ async def test_get_user_returns_none_when_session_has_no_user_key(self) -> None: fake_client.require_session = AsyncMock(return_value={}) request.app.state.auth_client = fake_client - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) assert result is None @@ -243,7 +309,7 @@ async def test_get_user_returns_none_when_exp_claim_missing(self) -> None: fake_client.require_session = AsyncMock(return_value={"user": {"sub": "x"}}) request.app.state.auth_client = fake_client - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) assert result is None @@ -255,7 +321,7 @@ async def test_get_user_returns_none_when_session_is_not_a_dict(self) -> None: fake_client.require_session = AsyncMock(return_value="not-a-dict") request.app.state.auth_client = fake_client - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) assert result is None @@ -268,10 +334,162 @@ async def test_get_user_returns_user_for_valid_session(self) -> None: fake_client.require_session = AsyncMock(return_value={"user": user}) request.app.state.auth_client = fake_client - result = await get_user(request, cookie) + result = await get_user(request, cookie, None) + + assert result == user + + async def test_get_user_bearer_takes_priority_over_cookie(self, monkeypatch: pytest.MonkeyPatch) -> None: + """get_user returns JWT user when both bearer and cookie are valid.""" + monkeypatch.setenv(_JWT_ENABLED_VAR_NAME, "true") + monkeypatch.setenv(_DOMAIN_VAR_NAME, _TEST_DOMAIN) + monkeypatch.setenv(_JWT_AUDIENCE_VAR_NAME, _TEST_JWT_AUDIENCE) + + jwt_user = {"sub": "jwt|user", "email": "jwt@example.com", "exp": int(time.time()) + 3600} + cookie_user = {"sub": _USER_SUB, "email": _USER_EMAIL, "exp": int(time.time()) + 3600} + + request = MagicMock() + fake_client = MagicMock() + fake_client.require_session = AsyncMock(return_value={"user": cookie_user}) + request.app.state.auth_client = fake_client + + with patch("aignostics_foundry_core.api.auth._validate_jwt", AsyncMock(return_value=jwt_user)): + result = await get_user(request, "cookie-value", _make_bearer()) + + assert result == jwt_user + + async def test_get_user_falls_back_to_cookie_when_bearer_absent(self) -> None: + """get_user uses cookie when _bearer is None.""" + user = {"sub": _USER_SUB, "email": _USER_EMAIL, "exp": int(time.time()) + 3600} + request = MagicMock() + fake_client = MagicMock() + fake_client.require_session = AsyncMock(return_value={"user": user}) + request.app.state.auth_client = fake_client + + result = await get_user(request, "cookie-value", None) + + assert result == user + + async def test_get_user_falls_back_to_cookie_when_jwt_disabled(self) -> None: + """get_user uses cookie when jwt_enabled=False even if bearer token is present.""" + user = {"sub": _USER_SUB, "email": _USER_EMAIL, "exp": int(time.time()) + 3600} + request = MagicMock() + fake_client = MagicMock() + fake_client.require_session = AsyncMock(return_value={"user": user}) + request.app.state.auth_client = fake_client + + # jwt_enabled defaults to False; bearer is present but should be ignored + result = await get_user(request, "cookie-value", _make_bearer()) assert result == user + async def test_get_user_falls_back_to_cookie_when_bearer_invalid(self, monkeypatch: pytest.MonkeyPatch) -> None: + """get_user falls back to cookie when JWT validation fails.""" + monkeypatch.setenv(_JWT_ENABLED_VAR_NAME, "true") + monkeypatch.setenv(_DOMAIN_VAR_NAME, _TEST_DOMAIN) + monkeypatch.setenv(_JWT_AUDIENCE_VAR_NAME, _TEST_JWT_AUDIENCE) + + cookie_user = {"sub": _USER_SUB, "email": _USER_EMAIL, "exp": int(time.time()) + 3600} + request = MagicMock() + fake_client = MagicMock() + fake_client.require_session = AsyncMock(return_value={"user": cookie_user}) + request.app.state.auth_client = fake_client + + with patch("aignostics_foundry_core.api.auth._validate_jwt", AsyncMock(return_value=None)): + result = await get_user(request, "cookie-value", _make_bearer()) + + assert result == cookie_user + + +@pytest.mark.unit +class TestValidateJwt: + """Unit tests for _validate_jwt (JWT validation helper).""" + + @pytest.fixture + def mock_jwks(self) -> dict: + """A minimal JWKS response with one RSA key entry.""" + return {"keys": [{"kid": _TEST_KID, "kty": "RSA", "use": "sig"}]} + + async def test_validate_jwt_returns_none_when_kid_absent_after_refresh(self) -> None: + """_validate_jwt returns None when kid is absent from JWKS even after a force-refresh.""" + jwks_without_kid: dict = {"keys": []} + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + + with patch(_FETCH_JWKS_PATH, AsyncMock(return_value=jwks_without_kid)): + import jwt + + with patch.object(jwt, "get_unverified_header", return_value={"kid": _TEST_KID, "alg": "RS256"}): + result = await _validate_jwt(_TEST_BEARER_TOKEN, settings) + + assert result is None + + async def test_validate_jwt_force_refreshes_on_kid_miss_and_succeeds(self) -> None: + """_validate_jwt retries with a force-refreshed JWKS when the kid is missing from cache.""" + import jwt + from jwt.algorithms import RSAAlgorithm + + stale_jwks: dict = {"keys": []} + fresh_jwks: dict = {"keys": [{"kid": _TEST_KID, "kty": "RSA", "use": "sig"}]} + expected_payload = {"sub": _USER_SUB, "exp": int(time.time()) + 3600} + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + + fetch_mock = AsyncMock(side_effect=[stale_jwks, fresh_jwks]) + with ( + patch(_FETCH_JWKS_PATH, fetch_mock), + patch.object(jwt, "get_unverified_header", return_value={"kid": _TEST_KID, "alg": "RS256"}), + patch.object(RSAAlgorithm, "from_jwk", return_value=MagicMock()), + patch.object(jwt, "decode", return_value=expected_payload), + ): + result = await _validate_jwt(_TEST_BEARER_TOKEN, settings) + + assert result == expected_payload + assert fetch_mock.call_count == 2 + _, kwargs = fetch_mock.call_args + assert kwargs.get("force_refresh") is True + + async def test_validate_jwt_returns_none_on_fetch_failure(self) -> None: + """_validate_jwt returns None when JWKS fetch raises an exception.""" + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + + with patch(_FETCH_JWKS_PATH, AsyncMock(side_effect=RuntimeError("network error"))): + result = await _validate_jwt(_TEST_BEARER_TOKEN, settings) + + assert result is None + + async def test_validate_jwt_returns_none_for_invalid_token(self, mock_jwks: dict) -> None: + """_validate_jwt returns None when jwt.decode raises (e.g., expired or bad signature).""" + import jwt + from jwt.algorithms import RSAAlgorithm + + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + + with ( + patch(_FETCH_JWKS_PATH, AsyncMock(return_value=mock_jwks)), + patch.object(jwt, "get_unverified_header", return_value={"kid": _TEST_KID, "alg": "RS256"}), + patch.object(RSAAlgorithm, "from_jwk", return_value=MagicMock()), + patch.object(jwt, "decode", side_effect=jwt.ExpiredSignatureError("expired")), + ): + result = await _validate_jwt(_TEST_BEARER_TOKEN, settings) + + assert result is None + + async def test_validate_jwt_returns_payload_for_valid_token(self, mock_jwks: dict) -> None: + """_validate_jwt returns decoded payload when token is valid.""" + import jwt + from jwt.algorithms import RSAAlgorithm + + expected_payload = {"sub": _USER_SUB, "email": _USER_EMAIL, "exp": int(time.time()) + 3600} + settings = AuthSettings(jwt_enabled=True, domain=_TEST_DOMAIN, jwt_audience=_TEST_JWT_AUDIENCE) + + with ( + patch(_FETCH_JWKS_PATH, AsyncMock(return_value=mock_jwks)), + patch.object(jwt, "get_unverified_header", return_value={"kid": _TEST_KID, "alg": "RS256"}), + patch.object(RSAAlgorithm, "from_jwk", return_value=MagicMock()), + patch.object(jwt, "decode", return_value=expected_payload), + ): + result = await _validate_jwt(_TEST_BEARER_TOKEN, settings) + + assert result == expected_payload + @pytest.mark.integration class TestRequireAuthenticated: @@ -283,7 +501,7 @@ async def test_unauthenticated_user_raises_forbidden_error(self) -> None: request.app.state = MagicMock(spec=[]) # no auth_client → get_user returns None with pytest.raises(ForbiddenError, match=_USER_NOT_AUTHENTICATED): - await require_authenticated(request, None) + await require_authenticated(request, None, None) async def test_authenticated_user_passes(self) -> None: """require_authenticated returns None without raising when user is authenticated.""" @@ -293,7 +511,7 @@ async def test_authenticated_user_passes(self) -> None: fake_client.require_session = AsyncMock(return_value={"user": user}) request.app.state.auth_client = fake_client - result = await require_authenticated(request, None) + result = await require_authenticated(request, None, None) assert result is None @@ -308,7 +526,7 @@ async def test_no_user_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPa request.app.state = MagicMock(spec=[]) # no auth_client → get_user returns None with pytest.raises(ForbiddenError): - await require_admin(request, None) + await require_admin(request, None, None) async def test_wrong_role_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: """require_admin raises ForbiddenError when user has a non-admin role.""" @@ -320,7 +538,7 @@ async def test_wrong_role_raises_forbidden_error(self, monkeypatch: pytest.Monke request.app.state.auth_client = fake_client with pytest.raises(ForbiddenError, match="does not match required role"): - await require_admin(request, None) + await require_admin(request, None, None) async def test_admin_role_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: """require_admin returns None without raising when user has the admin role.""" @@ -331,7 +549,7 @@ async def test_admin_role_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: fake_client.require_session = AsyncMock(return_value={"user": user}) request.app.state.auth_client = fake_client - result = await require_admin(request, None) + result = await require_admin(request, None, None) assert result is None @@ -339,15 +557,15 @@ async def test_admin_role_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: class TestRequireInternal: """Tests for require_internal FastAPI dependency.""" - async def test_unauthenticated_user_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_unauthenticated_user_raises_forbidden_error(self) -> None: """require_internal raises ForbiddenError when no session is available.""" request = MagicMock() request.app.state = MagicMock(spec=[]) # no auth_client → get_user returns None with pytest.raises(ForbiddenError, match=_USER_NOT_AUTHENTICATED): - await require_internal(request, None) + await require_internal(request, None, None) - async def test_wrong_org_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_wrong_org_raises_forbidden_error(self) -> None: """require_internal raises ForbiddenError when user belongs to a different org.""" request = MagicMock() user = {"sub": _USER_SUB, "org_id": _OTHER_ORG_ID, "exp": int(time.time()) + 3600} @@ -356,7 +574,7 @@ async def test_wrong_org_raises_forbidden_error(self, monkeypatch: pytest.Monkey request.app.state.auth_client = fake_client with pytest.raises(ForbiddenError, match="not a member of the internal organization"): - await require_internal(request, None) + await require_internal(request, None, None) async def test_internal_org_member_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: """require_internal returns None without raising when user is in the internal org.""" @@ -367,7 +585,7 @@ async def test_internal_org_member_passes(self, monkeypatch: pytest.MonkeyPatch) fake_client.require_session = AsyncMock(return_value={"user": user}) request.app.state.auth_client = fake_client - result = await require_internal(request, None) + result = await require_internal(request, None, None) assert result is None @@ -375,15 +593,15 @@ async def test_internal_org_member_passes(self, monkeypatch: pytest.MonkeyPatch) class TestRequireInternalAdmin: """Tests for require_internal_admin FastAPI dependency.""" - async def test_unauthenticated_user_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_unauthenticated_user_raises_forbidden_error(self) -> None: """require_internal_admin raises ForbiddenError when no session is available.""" request = MagicMock() request.app.state = MagicMock(spec=[]) # no auth_client → get_user returns None with pytest.raises(ForbiddenError, match=_USER_NOT_AUTHENTICATED): - await require_internal_admin(request, None) + await require_internal_admin(request, None, None) - async def test_wrong_org_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: + async def test_wrong_org_raises_forbidden_error(self) -> None: """require_internal_admin raises ForbiddenError when user belongs to a different org.""" request = MagicMock() user = {"sub": _USER_SUB, "org_id": _OTHER_ORG_ID, "exp": int(time.time()) + 3600} @@ -392,7 +610,7 @@ async def test_wrong_org_raises_forbidden_error(self, monkeypatch: pytest.Monkey request.app.state.auth_client = fake_client with pytest.raises(ForbiddenError, match="not a member of the internal organization"): - await require_internal_admin(request, None) + await require_internal_admin(request, None, None) async def test_correct_org_wrong_role_raises_forbidden_error(self, monkeypatch: pytest.MonkeyPatch) -> None: """require_internal_admin raises ForbiddenError when user is in internal org but lacks admin role.""" @@ -410,7 +628,7 @@ async def test_correct_org_wrong_role_raises_forbidden_error(self, monkeypatch: request.app.state.auth_client = fake_client with pytest.raises(ForbiddenError, match="does not match required role"): - await require_internal_admin(request, None) + await require_internal_admin(request, None, None) async def test_internal_admin_passes(self, monkeypatch: pytest.MonkeyPatch) -> None: """require_internal_admin returns None without raising when user is internal org admin.""" @@ -427,5 +645,85 @@ async def test_internal_admin_passes(self, monkeypatch: pytest.MonkeyPatch) -> N fake_client.require_session = AsyncMock(return_value={"user": user}) request.app.state.auth_client = fake_client - result = await require_internal_admin(request, None) + result = await require_internal_admin(request, None, None) assert result is None + + +@pytest.mark.unit +class TestFetchJwks: + """Unit tests for _fetch_jwks (JWKS fetching and caching helper).""" + + def setup_method(self) -> None: + """Clear the JWKS cache before each test for isolation.""" + _jwks_cache.clear() + + async def test_returns_cached_jwks_when_fresh(self) -> None: + """_fetch_jwks returns the in-memory cached JWKS without an HTTP call when fresh.""" + cached_jwks: dict = {"keys": [{"kid": _TEST_KID}]} + _jwks_cache[_TEST_DOMAIN] = _JwksCacheEntry(jwks=cached_jwks, fetched_at=time.time()) + + result = await _fetch_jwks(_TEST_DOMAIN) + + assert result is cached_jwks + + async def test_fetches_and_caches_on_cache_miss(self) -> None: + """_fetch_jwks makes an HTTP request, stores the result in cache, and returns it.""" + fetched_jwks: dict = {"keys": [{"kid": _TEST_KID, "kty": "RSA"}]} + mock_response = MagicMock() + mock_response.json.return_value = fetched_jwks + + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await _fetch_jwks(_TEST_DOMAIN) + + assert result == fetched_jwks + assert _TEST_DOMAIN in _jwks_cache + assert _jwks_cache[_TEST_DOMAIN].jwks == fetched_jwks + + async def test_force_refresh_bypasses_fresh_cache(self) -> None: + """_fetch_jwks hits the network even when a fresh cache entry exists if force_refresh=True.""" + stale_jwks: dict = {"keys": [{"kid": "old-kid"}]} + fresh_jwks: dict = {"keys": [{"kid": _TEST_KID, "kty": "RSA"}]} + _jwks_cache[_TEST_DOMAIN] = _JwksCacheEntry(jwks=stale_jwks, fetched_at=time.time()) + + mock_response = MagicMock() + mock_response.json.return_value = fresh_jwks + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await _fetch_jwks(_TEST_DOMAIN, force_refresh=True) + + assert result == fresh_jwks + mock_client.get.assert_called_once() + + async def test_stale_cache_returned_on_fetch_failure(self) -> None: + """_fetch_jwks returns the stale cached JWKS when the network request fails.""" + stale_jwks: dict = {"keys": [{"kid": _TEST_KID}]} + _jwks_cache[_TEST_DOMAIN] = _JwksCacheEntry(jwks=stale_jwks, fetched_at=0.0) # expired + + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=RuntimeError("network error")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_client): + result = await _fetch_jwks(_TEST_DOMAIN) + + assert result is stale_jwks + + async def test_raises_when_fetch_fails_and_no_cache(self) -> None: + """_fetch_jwks re-raises the exception when the network request fails and no cache exists.""" + mock_client = AsyncMock() + mock_client.get = AsyncMock(side_effect=RuntimeError("network error")) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("httpx.AsyncClient", return_value=mock_client), pytest.raises(RuntimeError, match="network error"): + await _fetch_jwks(_TEST_DOMAIN) diff --git a/uv.lock b/uv.lock index 629ff20..30d4b61 100644 --- a/uv.lock +++ b/uv.lock @@ -18,12 +18,14 @@ dependencies = [ { name = "certifi" }, { name = "chancy", extra = ["cron"] }, { name = "fastapi" }, + { name = "httpx2" }, { name = "loguru" }, { name = "nicegui" }, { name = "platformdirs" }, { name = "psutil" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-dotenv" }, { name = "rich" }, { name = "sentry-sdk" }, @@ -73,12 +75,14 @@ requires-dist = [ { name = "certifi", specifier = ">=2024" }, { name = "chancy", extras = ["cron"], specifier = ">=0.25.1,<1" }, { name = "fastapi", specifier = ">=0.110,<1" }, + { name = "httpx2", specifier = ">=2.2.0,<3" }, { name = "loguru", specifier = ">=0.7,<1" }, { name = "nicegui", specifier = ">=3,<4" }, { name = "platformdirs", specifier = ">=4,<5" }, { name = "psutil", specifier = ">=6" }, { name = "pydantic", specifier = ">=2,<3" }, { name = "pydantic-settings", specifier = ">=2,<3" }, + { name = "pyjwt", extras = ["cryptography"], specifier = ">=2.10,<3" }, { name = "python-dotenv", specifier = ">=1,<2" }, { name = "rich", specifier = ">=15,<16" }, { name = "sentry-sdk", specifier = ">=2,<3" }, @@ -1047,11 +1051,11 @@ wheels = [ [[package]] name = "docutils" -version = "0.22.4" +version = "0.23" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" } +sdist = { url = "https://files.pythonhosted.org/packages/39/a4/5180d9afc57e8fca05601dd652bdff19604c218814037fe90ffc7625a50a/docutils-0.23.tar.gz", hash = "sha256:746f5060322511280a1e50eb76846ed6bf2342984b2ac04dc42caa1a8d78799e", size = 2303823, upload-time = "2026-05-27T17:41:06.934Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" }, + { url = "https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl", hash = "sha256:25d013af9bf23bc1c7b2b093dff4208166c53a94786c9e447808335ef1185fea", size = 634701, upload-time = "2026-05-27T17:40:58.442Z" }, ] [[package]] @@ -1294,6 +1298,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, ] +[[package]] +name = "httpcore2" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1a/7e/8ab39aab1d392845b6512009a9be57d24a5bd4ec7a22d02e513d0645e7a8/httpcore2-2.2.0.tar.gz", hash = "sha256:10e0e142f1ecc1c1cb2a9ebbce82e57f16169f61d163ea336abf36799e89294b", size = 63533, upload-time = "2026-05-17T05:29:55.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/22/64de17e7956e8c002f7558ed667d924c2a288344aeff4bd8ff5dc5fdb70b/httpcore2-2.2.0-py3-none-any.whl", hash = "sha256:ce859f268bf8d34fa2d7753e09e4dd5194f557e1b3038439b68a89b2999572fa", size = 79288, upload-time = "2026-05-17T05:29:52.56Z" }, +] + [[package]] name = "httptools" version = "0.8.0" @@ -1352,6 +1369,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx2" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore2" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/aa/c3119de1aa7ad870a01aaddbf3bc3445ed9a681c31d45e3838fd8b7bc155/httpx2-2.2.0.tar.gz", hash = "sha256:f3428d59b1752b8f5629826277262fb4d65e3a683f48af8a5b16c4d012e0b801", size = 80477, upload-time = "2026-05-17T05:29:57.376Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/e0/e0a52596c14194e428c20de4903f4abec38c0dfb5364d20f1d4a2b6266ef/httpx2-2.2.0-py3-none-any.whl", hash = "sha256:12347ebd2daeaefd50b529359778fff767082a09c5826752c963e71269722ff0", size = 74083, upload-time = "2026-05-17T05:29:54.543Z" }, +] + [[package]] name = "humanize" version = "4.15.0" @@ -2163,11 +2195,11 @@ wheels = [ [[package]] name = "platformdirs" -version = "4.9.6" +version = "4.10.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9f/4a/0883b8e3802965322523f0b200ecf33d31f10991d0401162f4b23c698b42/platformdirs-4.9.6.tar.gz", hash = "sha256:3bfa75b0ad0db84096ae777218481852c0ebc6c727b3168c1b9e0118e458cf0a", size = 29400, upload-time = "2026-04-09T00:04:10.812Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/47/e4501f49c178ae1d9f4a75073fda4204f52647993f075a9db4d14930e0c5/platformdirs-4.10.0.tar.gz", hash = "sha256:31e761a6a0ca04faf7353ea759bdba55652be214725111e5aac52dfa29d4bef7", size = 31224, upload-time = "2026-05-28T03:32:53.587Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" }, + { url = "https://files.pythonhosted.org/packages/81/e6/cd9575ac904136b3cbf7aa7ee819ef86eedb7274e46f230e94ea4342e729/platformdirs-4.10.0-py3-none-any.whl", hash = "sha256:fb516cdb12eb0d857d0cd85a7c57cea4d060bee4578d6cf5a14dfdf8cbf8784a", size = 22743, upload-time = "2026-05-28T03:32:52.175Z" }, ] [[package]] @@ -2854,15 +2886,15 @@ wheels = [ [[package]] name = "python-discovery" -version = "1.3.1" +version = "1.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "platformdirs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/48/60/e88788207d81e46362cfbef0d4aaf4c0f49efc3c12d4c3fa3f542c34ebec/python_discovery-1.3.1.tar.gz", hash = "sha256:62f6db28064c9613e7ca76cb3f00c38c839a07c31c00dfe7ed0986493d2150a6", size = 68011, upload-time = "2026-05-12T20:53:36.336Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/12/38c1a0b1e64806780c9563e3fc9f6e472251839662587cfbe9bfaf2ae10a/python_discovery-1.4.0.tar.gz", hash = "sha256:eb8bc7daad3c226c147e45bb4e970a1feb1bf4048ee178e6db59e197b8010ce3", size = 68455, upload-time = "2026-05-28T01:15:37.639Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/6f/a05a317a66fee0aad270011461f1a63a453ed12471249f172f7d2e2bc7b4/python_discovery-1.3.1-py3-none-any.whl", hash = "sha256:ed188687ebb3b82c01a17cd5ac62fc94d9f6487a7f1a0f9dfe89753fec91039c", size = 33185, upload-time = "2026-05-12T20:53:34.969Z" }, + { url = "https://files.pythonhosted.org/packages/c8/8d/3d316429f65029532bb1e28ff77b797d86b5ac3915bb44ca4e19aa283d43/python_discovery-1.4.0-py3-none-any.whl", hash = "sha256:26ed78d703e234879a66244c7d4114563fb13ec5cd30a2d1357e5fb4850782da", size = 33217, upload-time = "2026-05-28T01:15:36.573Z" }, ] [[package]] @@ -3211,15 +3243,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.60.0" +version = "2.61.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/a2/2e6c090db384cc515069f4f85542bd5baf6786852073020ea73d4a76d3ea/sentry_sdk-2.60.0.tar.gz", hash = "sha256:0bd25e54e78ca02d0be512529fa644bbbf9e8470d7b26371294012d4ca93c978", size = 452946, upload-time = "2026-05-13T13:34:52.516Z" } +sdist = { url = "https://files.pythonhosted.org/packages/52/4d/3c66e6045bd2071256b6b6fdcb0cc02b86ce54b2acc2ceac79af8e0efbb5/sentry_sdk-2.61.0.tar.gz", hash = "sha256:1ca9b4bb777eb5be67004edab7eb894f21c6301f1d05ed64966719ad5d1764ce", size = 458510, upload-time = "2026-05-28T09:40:28.917Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/41/f2b800b7f12a05dd48c2a6280d4dd812d1425fc66ed3fe3fd99420c41d1a/sentry_sdk-2.60.0-py3-none-any.whl", hash = "sha256:28a536c03291c8bcb363cf35c611b32738ec118ff64d8d6383b096448ac4c803", size = 475616, upload-time = "2026-05-13T13:34:50.259Z" }, + { url = "https://files.pythonhosted.org/packages/21/5a/9794736d5802689c1a48862e6afe6b7f3e86cc37c15d4a84bc0143877dc1/sentry_sdk-2.61.0-py3-none-any.whl", hash = "sha256:ec4d30273909cb1d198e03208b16ee70e2bc5d90a16fd9f1fb2fc6a72e1f03dc", size = 483111, upload-time = "2026-05-28T09:40:27.027Z" }, ] [[package]] @@ -3630,7 +3662,7 @@ wheels = [ [[package]] name = "virtualenv" -version = "21.3.3" +version = "21.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -3638,9 +3670,9 @@ dependencies = [ { name = "platformdirs" }, { name = "python-discovery" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/15/ba/1f6e8c957e4932be060dcdc482d339c12e0216351478add3645cdaa53c05/virtualenv-21.3.3.tar.gz", hash = "sha256:f5bda277e553b1c2b3c1a8debfc30496e1288cc93ce6b7b71b3280047e317328", size = 7613784, upload-time = "2026-05-13T18:01:30.19Z" } +sdist = { url = "https://files.pythonhosted.org/packages/95/f0/b47ecf438211a25a97f8f0e4b23c22bc2496ebfea18dd6ec16210f09cc36/virtualenv-21.4.1.tar.gz", hash = "sha256:2ca543c713b72840ceffd94e9bdedfbd09a661defa1f7f69e5429ad4059442e2", size = 7613344, upload-time = "2026-05-28T04:12:49.905Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/34/a9dbe051de88a63eb7408ea66630bac38e72f7f6077d4be58737106860d9/virtualenv-21.3.3-py3-none-any.whl", hash = "sha256:7d5987d8369e098e41406efb780a3d4ca79280097293899e351a6407ee153ab3", size = 7594554, upload-time = "2026-05-13T18:01:27.815Z" }, + { url = "https://files.pythonhosted.org/packages/ff/dc/ac4f3a987a87e1a18556896f257c4e15c95ed157b7975347ec6b313b75ce/virtualenv-21.4.1-py3-none-any.whl", hash = "sha256:caf4ff72d1b4039057f41d8e8466e859513d67c0400d9c6b62c02c9d1ebc3e12", size = 7594078, upload-time = "2026-05-28T04:12:47.686Z" }, ] [[package]]