From 682ff1e9e3c8c66aafa5a9f1acd2e21deeeaa9a4 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Thu, 6 Mar 2025 10:02:33 -0800 Subject: [PATCH 01/60] wip --- pyproject.toml | 3 +- src/mcp/server/auth/__init__.py | 3 + src/mcp/server/auth/errors.py | 135 +++++ src/mcp/server/auth/handlers/__init__.py | 3 + src/mcp/server/auth/handlers/authorize.py | 150 +++++ src/mcp/server/auth/handlers/metadata.py | 43 ++ src/mcp/server/auth/handlers/register.py | 106 ++++ src/mcp/server/auth/handlers/revoke.py | 58 ++ src/mcp/server/auth/handlers/token.py | 142 +++++ src/mcp/server/auth/middleware/__init__.py | 3 + src/mcp/server/auth/middleware/bearer_auth.py | 98 +++ src/mcp/server/auth/middleware/client_auth.py | 118 ++++ src/mcp/server/auth/provider.py | 162 +++++ src/mcp/server/auth/router.py | 177 ++++++ src/mcp/server/auth/types.py | 23 + src/mcp/server/fastmcp/server.py | 59 +- src/mcp/shared/auth.py | 123 ++++ tests/server/fastmcp/auth/__init__.py | 3 + .../fastmcp/auth/test_auth_integration.py | 558 ++++++++++++++++++ 19 files changed, 1956 insertions(+), 11 deletions(-) create mode 100644 src/mcp/server/auth/__init__.py create mode 100644 src/mcp/server/auth/errors.py create mode 100644 src/mcp/server/auth/handlers/__init__.py create mode 100644 src/mcp/server/auth/handlers/authorize.py create mode 100644 src/mcp/server/auth/handlers/metadata.py create mode 100644 src/mcp/server/auth/handlers/register.py create mode 100644 src/mcp/server/auth/handlers/revoke.py create mode 100644 src/mcp/server/auth/handlers/token.py create mode 100644 src/mcp/server/auth/middleware/__init__.py create mode 100644 src/mcp/server/auth/middleware/bearer_auth.py create mode 100644 src/mcp/server/auth/middleware/client_auth.py create mode 100644 src/mcp/server/auth/provider.py create mode 100644 src/mcp/server/auth/router.py create mode 100644 src/mcp/server/auth/types.py create mode 100644 src/mcp/shared/auth.py create mode 100644 tests/server/fastmcp/auth/__init__.py create mode 100644 tests/server/fastmcp/auth/test_auth_integration.py diff --git a/pyproject.toml b/pyproject.toml index 157263de68..e87136758c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "fastapi", ] [project.optional-dependencies] @@ -47,7 +48,7 @@ dev-dependencies = [ "pytest>=8.3.4", "ruff>=0.8.5", "trio>=0.26.2", - "pytest-flakefinder>=1.1.0", + "pytest-flakefinder==1.1.0", "pytest-xdist>=3.6.1", ] diff --git a/src/mcp/server/auth/__init__.py b/src/mcp/server/auth/__init__.py new file mode 100644 index 0000000000..5ad769fdfe --- /dev/null +++ b/src/mcp/server/auth/__init__.py @@ -0,0 +1,3 @@ +""" +MCP OAuth server authorization components. +""" \ No newline at end of file diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py new file mode 100644 index 0000000000..702df08c91 --- /dev/null +++ b/src/mcp/server/auth/errors.py @@ -0,0 +1,135 @@ +""" +OAuth error classes for MCP authorization. + +Corresponds to TypeScript file: src/server/auth/errors.ts +""" + +from typing import Dict, Optional, Any + + +class OAuthError(Exception): + """ + Base class for all OAuth errors. + + Corresponds to OAuthError in src/server/auth/errors.ts + """ + error_code: str = "server_error" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def to_response_object(self) -> Dict[str, str]: + """Convert error to JSON response object.""" + return { + "error": self.error_code, + "error_description": self.message + } + + +class ServerError(OAuthError): + """ + Server error. + + Corresponds to ServerError in src/server/auth/errors.ts + """ + error_code = "server_error" + + +class InvalidRequestError(OAuthError): + """ + Invalid request error. + + Corresponds to InvalidRequestError in src/server/auth/errors.ts + """ + error_code = "invalid_request" + + +class InvalidClientError(OAuthError): + """ + Invalid client error. + + Corresponds to InvalidClientError in src/server/auth/errors.ts + """ + error_code = "invalid_client" + + +class InvalidGrantError(OAuthError): + """ + Invalid grant error. + + Corresponds to InvalidGrantError in src/server/auth/errors.ts + """ + error_code = "invalid_grant" + + +class UnauthorizedClientError(OAuthError): + """ + Unauthorized client error. + + Corresponds to UnauthorizedClientError in src/server/auth/errors.ts + """ + error_code = "unauthorized_client" + + +class UnsupportedGrantTypeError(OAuthError): + """ + Unsupported grant type error. + + Corresponds to UnsupportedGrantTypeError in src/server/auth/errors.ts + """ + error_code = "unsupported_grant_type" + + +class UnsupportedResponseTypeError(OAuthError): + """ + Unsupported response type error. + + Corresponds to UnsupportedResponseTypeError in src/server/auth/errors.ts + """ + error_code = "unsupported_response_type" + + +class InvalidScopeError(OAuthError): + """ + Invalid scope error. + + Corresponds to InvalidScopeError in src/server/auth/errors.ts + """ + error_code = "invalid_scope" + + +class AccessDeniedError(OAuthError): + """ + Access denied error. + + Corresponds to AccessDeniedError in src/server/auth/errors.ts + """ + error_code = "access_denied" + + +class TemporarilyUnavailableError(OAuthError): + """ + Temporarily unavailable error. + + Corresponds to TemporarilyUnavailableError in src/server/auth/errors.ts + """ + error_code = "temporarily_unavailable" + + +class InvalidTokenError(OAuthError): + """ + Invalid token error. + + Corresponds to InvalidTokenError in src/server/auth/errors.ts + """ + error_code = "invalid_token" + + +class InsufficientScopeError(OAuthError): + """ + Insufficient scope error. + + Corresponds to InsufficientScopeError in src/server/auth/errors.ts + """ + error_code = "insufficient_scope" \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/__init__.py b/src/mcp/server/auth/handlers/__init__.py new file mode 100644 index 0000000000..fb01dab61f --- /dev/null +++ b/src/mcp/server/auth/handlers/__init__.py @@ -0,0 +1,3 @@ +""" +Request handlers for MCP authorization endpoints. +""" \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py new file mode 100644 index 0000000000..2eabd0a6e5 --- /dev/null +++ b/src/mcp/server/auth/handlers/authorize.py @@ -0,0 +1,150 @@ +""" +Handler for OAuth 2.0 Authorization endpoint. + +Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts +""" + +import re +from urllib.parse import urlparse, urlunparse, urlencode +from typing import Any, Callable, Dict, List, Literal, Optional +from urllib.parse import urlencode, parse_qs + +from fastapi import Request, Response +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError +from pydantic_core import Url +from starlette.responses import JSONResponse, RedirectResponse + +from mcp.server.auth.errors import ( + InvalidClientError, + InvalidRequestError, + UnsupportedResponseTypeError, + ServerError, + OAuthError, +) +from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider +from mcp.shared.auth import OAuthClientInformationFull + + +class AuthorizationRequest(BaseModel): + """ + Model for the authorization request parameters. + + Corresponds to request schema in authorizationHandler in src/server/auth/handlers/authorize.ts + """ + client_id: str = Field(..., description="The client ID") + redirect_uri: AnyHttpUrl | None = Field(..., description="URL to redirect to after authorization") + + response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") + code_challenge: str = Field(..., description="PKCE code challenge") + code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method") + state: Optional[str] = Field(None, description="Optional state parameter") + scope: Optional[str] = Field(None, description="Optional scope parameter") + + class Config: + extra = "ignore" + +def validate_scope(requested_scope: str | None, client: OAuthClientInformationFull) -> list[str] | None: + if requested_scope is None: + return None + requested_scopes = requested_scope.split(" ") + allowed_scopes = [] if client.scope is None else client.scope.split(" ") + for scope in requested_scopes: + if scope not in allowed_scopes: + raise InvalidRequestError(f"Client was not registered with scope {scope}") + return requested_scopes + +def validate_redirect_uri(auth_request: AuthorizationRequest, client: OAuthClientInformationFull) -> AnyHttpUrl: + if auth_request.redirect_uri is not None: + # Validate redirect_uri against client's registered redirect URIs + if auth_request.redirect_uri not in client.redirect_uris: + raise InvalidRequestError( + f"Redirect URI '{auth_request.redirect_uri}' not registered for client" + ) + return auth_request.redirect_uri + elif len(client.redirect_uris) == 1: + return client.redirect_uris[0] + else: + raise InvalidRequestError("redirect_uri must be specified when client has multiple registered URIs") + +def create_authorization_handler(provider: OAuthServerProvider) -> Callable: + """ + Create a handler for the OAuth 2.0 Authorization endpoint. + + Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts + + """ + + async def authorization_handler(request: Request) -> Response: + """ + Handler for the OAuth 2.0 Authorization endpoint. + """ + # Validate request parameters + try: + if request.method == "GET": + auth_request = AuthorizationRequest.model_validate(request.query_params) + else: + auth_request = AuthorizationRequest.model_validate_json(await request.body()) + except ValidationError as e: + raise InvalidRequestError(str(e)) + + # Get client information + try: + client = await provider.clients_store.get_client(auth_request.client_id) + except OAuthError as e: + # TODO: proper error rendering + raise InvalidClientError(str(e)) + + if not client: + raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found") + + + # do validation which is dependent on the client configuration + redirect_uri = validate_redirect_uri(auth_request, client) + scopes = validate_scope(auth_request.scope, client) + + auth_params = AuthorizationParams( + state=auth_request.state, + scopes=scopes, + code_challenge=auth_request.code_challenge, + redirect_uri=redirect_uri, + ) + + response = RedirectResponse(url="", status_code=302, headers={"Cache-Control": "no-store"}) + + try: + # Let the provider handle the authorization flow + await provider.authorize(client, auth_params, response) + + return response + except Exception as e: + return RedirectResponse( + url=create_error_redirect(redirect_uri, e, auth_request.state), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + + return authorization_handler + +def create_error_redirect(redirect_uri: AnyUrl, error: Exception, state: Optional[str]) -> str: + parsed_uri = urlparse(str(redirect_uri)) + if isinstance(error, OAuthError): + query_params = { + "error": error.error_code, + "error_description": str(error) + } + else: + query_params = { + "error": "internal_error", + "error_description": "An unknown error occurred" + } + # TODO: should we add error_uri? + # if error.error_uri: + # query_params["error_uri"] = str(error.error_uri) + if state: + query_params["state"] = state + + new_query = urlencode(query_params) + if parsed_uri.query: + new_query = f"{parsed_uri.query}&{new_query}" + + return urlunparse(parsed_uri._replace(query=new_query)) \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py new file mode 100644 index 0000000000..2acee117a9 --- /dev/null +++ b/src/mcp/server/auth/handlers/metadata.py @@ -0,0 +1,43 @@ +""" +Handler for OAuth 2.0 Authorization Server Metadata. + +Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts +""" + +from typing import Any, Callable, Dict, Optional +from fastapi import Request, Response +from starlette.responses import JSONResponse + + +def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: + """ + Create a handler for OAuth 2.0 Authorization Server Metadata. + + Corresponds to metadataHandler in src/server/auth/handlers/metadata.ts + + Args: + metadata: The metadata to return in the response + + Returns: + A FastAPI route handler function + """ + + async def metadata_handler(request: Request) -> Response: + """ + Handler for the OAuth 2.0 Authorization Server Metadata endpoint. + + Args: + request: The FastAPI request + + Returns: + JSON response with the authorization server metadata + """ + # Remove any None values from metadata + clean_metadata = {k: v for k, v in metadata.items() if v is not None} + + return JSONResponse( + content=clean_metadata, + headers={"Cache-Control": "public, max-age=3600"} # Cache for 1 hour + ) + + return metadata_handler \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py new file mode 100644 index 0000000000..47527ea4e7 --- /dev/null +++ b/src/mcp/server/auth/handlers/register.py @@ -0,0 +1,106 @@ +""" +Handler for OAuth 2.0 Dynamic Client Registration. + +Corresponds to TypeScript file: src/server/auth/handlers/register.ts +""" + +import random +import secrets +import time +from typing import Any, Callable, Dict, List, Optional +from uuid import uuid4 + +from fastapi import Request, Response +from pydantic import ValidationError +from starlette.responses import JSONResponse + +from mcp.server.auth.errors import ( + InvalidRequestError, + ServerError, + OAuthError, +) +from mcp.server.auth.provider import OAuthRegisteredClientsStore +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + + +def create_registration_handler(clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None) -> Callable: + """ + Create a handler for OAuth 2.0 Dynamic Client Registration. + + Corresponds to clientRegistrationHandler in src/server/auth/handlers/register.ts + + Args: + clients_store: The store for registered clients + + Returns: + A FastAPI route handler function + """ + + async def registration_handler(request: Request) -> Response: + """ + Handler for the OAuth 2.0 Dynamic Client Registration endpoint. + + Args: + request: The FastAPI request + + Returns: + JSON response with client information or error + """ + try: + # Validate client metadata + try: + client_metadata = OAuthClientMetadata.model_validate_json(await request.body()) + except ValidationError as e: + raise InvalidRequestError(f"Invalid client metadata: {str(e)}") + + client_id = str(uuid4()) + client_secret = None + if client_metadata.token_endpoint_auth_method != "none": + # cryptographically secure random 32-byte hex string + client_secret = secrets.token_hex(32) + + client_id_issued_at = int(time.time()) + client_secret_expires_at = client_id_issued_at + client_secret_expiry_seconds if client_secret_expiry_seconds is not None else None + + client_info = OAuthClientInformationFull( + client_id=client_id, + client_id_issued_at=client_id_issued_at, + client_secret=client_secret, + client_secret_expires_at=client_secret_expires_at, + # passthrough information from the client request + redirect_uris=client_metadata.redirect_uris, + token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, + grant_types=client_metadata.grant_types, + response_types=client_metadata.response_types, + client_name=client_metadata.client_name, + client_uri=client_metadata.client_uri, + logo_uri=client_metadata.logo_uri, + scope=client_metadata.scope, + contacts=client_metadata.contacts, + tos_uri=client_metadata.tos_uri, + policy_uri=client_metadata.policy_uri, + jwks_uri=client_metadata.jwks_uri, + jwks=client_metadata.jwks, + software_id=client_metadata.software_id, + software_version=client_metadata.software_version, + ) + # Register client + client = await clients_store.register_client(client_info) + if not client: + raise ServerError("Failed to register client") + + # Return client information + return JSONResponse( + content=client.model_dump(exclude_none=True), + status_code=201 + ) + + except OAuthError as e: + # Handle OAuth errors + status_code = 500 if isinstance(e, ServerError) else 400 + return JSONResponse( + status_code=status_code, + content=e.to_response_object() + ) + + return registration_handler \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py new file mode 100644 index 0000000000..59a11918a1 --- /dev/null +++ b/src/mcp/server/auth/handlers/revoke.py @@ -0,0 +1,58 @@ +""" +Handler for OAuth 2.0 Token Revocation. + +Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts +""" + +from typing import Any, Callable, Dict, Optional + +from fastapi import Request, Response +from pydantic import ValidationError +from starlette.responses import JSONResponse, Response as StarletteResponse + +from mcp.server.auth.errors import ( + InvalidRequestError, + ServerError, + OAuthError, +) +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthClientInformationFull, OAuthTokenRevocationRequest + + +def create_revocation_handler(provider: OAuthServerProvider) -> Callable: + """ + Create a handler for OAuth 2.0 Token Revocation. + + Corresponds to revocationHandler in src/server/auth/handlers/revoke.ts + + Args: + provider: The OAuth server provider + + Returns: + A FastAPI route handler function + """ + + async def revocation_handler(request: Request, client_auth: OAuthClientInformationFull) -> Response: + """ + Handler for the OAuth 2.0 Token Revocation endpoint. + """ + # Validate revocation request + try: + revocation_request = OAuthTokenRevocationRequest.model_validate_json(await request.body()) + except ValidationError as e: + raise InvalidRequestError(str(e)) + + # Revoke token + if provider.revoke_token: + await provider.revoke_token(client_auth, revocation_request) + + # Return successful empty response + return StarletteResponse( + status_code=200, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + } + ) + + return revocation_handler \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py new file mode 100644 index 0000000000..9164991a69 --- /dev/null +++ b/src/mcp/server/auth/handlers/token.py @@ -0,0 +1,142 @@ +""" +Handler for OAuth 2.0 Token endpoint. + +Corresponds to TypeScript file: src/server/auth/handlers/token.ts +""" + +import base64 +import hashlib +import json +from typing import Any, Callable, Dict, List, Optional, Union + +from fastapi import Request, Response +from pydantic import BaseModel, Field, ValidationError +from starlette.responses import JSONResponse + +from mcp.server.auth.errors import ( + InvalidClientError, + InvalidGrantError, + InvalidRequestError, + ServerError, + UnsupportedGrantTypeError, + OAuthError, +) +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthClientInformationFull, OAuthTokens +from mcp.server.auth.middleware.client_auth import ClientAuthDependency + +class AuthorizationCodeRequest(BaseModel): + """ + Model for the authorization code grant request parameters. + + Corresponds to AuthorizationCodeExchangeSchema in src/server/auth/handlers/token.ts + """ + grant_type: str = Field(..., description="Must be 'authorization_code'") + code: str = Field(..., description="The authorization code") + code_verifier: str = Field(..., description="PKCE code verifier") + + class Config: + extra = "ignore" + + +class RefreshTokenRequest(BaseModel): + """ + Model for the refresh token grant request parameters. + + Corresponds to RefreshTokenExchangeSchema in src/server/auth/handlers/token.ts + """ + grant_type: str = Field(..., description="Must be 'refresh_token'") + refresh_token: str = Field(..., description="The refresh token") + scope: Optional[str] = Field(None, description="Optional scope parameter") + + class Config: + extra = "ignore" + + +def create_token_handler(provider: OAuthServerProvider) -> Callable: + """ + Create a handler for the OAuth 2.0 Token endpoint. + + Corresponds to tokenHandler in src/server/auth/handlers/token.ts + + Args: + provider: The OAuth server provider + + Returns: + A FastAPI route handler function + """ + + async def token_handler(request: Request, client_auth: OAuthClientInformationFull) -> Response: + """ + Handler for the OAuth 2.0 Token endpoint. + + Args: + request: The FastAPI request + + Returns: + JSON response with tokens or error + """ + params = json.loads(await request.body()) + + + # Check grant_type first to determine which validation model to use + if "grant_type" not in params: + raise InvalidRequestError("Missing required parameter: grant_type") + grant_type = params["grant_type"] + + tokens: OAuthTokens + + if grant_type == "authorization_code": + # Validate authorization code parameters + try: + code_request = AuthorizationCodeRequest.model_validate(params) + except ValidationError as e: + raise InvalidRequestError(str(e)) + + # Verify PKCE code verifier + expected_challenge = await provider.challenge_for_authorization_code( + client_auth, code_request.code + ) + if expected_challenge is None: + raise InvalidRequestError("Invalid authorization code") + + # Calculate challenge from verifier + sha256 = hashlib.sha256(code_request.code_verifier.encode()).digest() + actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + if actual_challenge != expected_challenge: + raise InvalidRequestError("code_verifier does not match the challenge") + + # Exchange authorization code for tokens + tokens = await provider.exchange_authorization_code(client_auth, code_request.code) + + elif grant_type == "refresh_token": + # Validate refresh token parameters + try: + refresh_request = RefreshTokenRequest.model_validate(params) + except ValidationError as e: + raise InvalidRequestError(str(e)) + + # Parse scopes if provided + scopes = refresh_request.scope.split(" ") if refresh_request.scope else None + + # Exchange refresh token for new tokens + tokens = await provider.exchange_refresh_token( + client_auth, refresh_request.refresh_token, scopes + ) + + else: + raise InvalidRequestError( + f"Unsupported grant_type: {grant_type}" + ) + + return JSONResponse( + content=tokens, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + } + ) + + + return token_handler \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/__init__.py b/src/mcp/server/auth/middleware/__init__.py new file mode 100644 index 0000000000..60de91e41f --- /dev/null +++ b/src/mcp/server/auth/middleware/__init__.py @@ -0,0 +1,3 @@ +""" +Middleware for MCP authorization. +""" \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py new file mode 100644 index 0000000000..c7b181434d --- /dev/null +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -0,0 +1,98 @@ +""" +Bearer token authentication dependency for FastAPI. + +Corresponds to TypeScript file: src/server/auth/middleware/bearerAuth.ts +""" + +import time +from typing import List, Optional + +from fastapi import Request, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError +from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.types import AuthInfo + + +class BearerAuthDependency: + """ + Dependency that requires a valid Bearer token in the Authorization header. + + This will validate the token with the auth provider and return the resulting + auth info. + + Corresponds to requireBearerAuth in src/server/auth/middleware/bearerAuth.ts + """ + + def __init__( + self, + provider: OAuthServerProvider, + required_scopes: Optional[List[str]] = None + ): + """ + Initialize the dependency. + + Args: + provider: Authentication provider to validate tokens + required_scopes: Optional list of scopes that the token must have + """ + self.provider = provider + self.required_scopes = required_scopes or [] + self.bearer_scheme = HTTPBearer() + + async def __call__(self, request: Request) -> AuthInfo: + """ + Process the request and validate the bearer token. + + Args: + request: FastAPI request + + Returns: + Authenticated auth info + + Raises: + HTTPException: If token validation fails + """ + try: + # Extract and validate the authorization header using FastAPI's built-in scheme + credentials: HTTPAuthorizationCredentials = await self.bearer_scheme(request) + token = credentials.credentials + + # Validate the token with the provider + auth_info: AuthInfo = await self.provider.verify_access_token(token) + + # Check if the token has all required scopes + if self.required_scopes: + has_all_scopes = all(scope in auth_info.scopes for scope in self.required_scopes) + if not has_all_scopes: + raise InsufficientScopeError("Insufficient scope") + + # Check if the token is expired + if auth_info.expires_at and auth_info.expires_at < int(time.time()): + raise InvalidTokenError("Token has expired") + + return auth_info + + except InvalidTokenError as e: + # Return a 401 Unauthorized response with appropriate headers + headers = {"WWW-Authenticate": f'Bearer error="{e.error_code}", error_description="{str(e)}"'} + raise HTTPException( + status_code=401, + detail=e.to_response_object(), + headers=headers + ) + except InsufficientScopeError as e: + # Return a 403 Forbidden response with appropriate headers + headers = {"WWW-Authenticate": f'Bearer error="{e.error_code}", error_description="{str(e)}"'} + raise HTTPException( + status_code=403, + detail=e.to_response_object(), + headers=headers + ) + except OAuthError as e: + # Return a 400 Bad Request response for other OAuth errors + raise HTTPException( + status_code=400, + detail=e.to_response_object() + ) \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py new file mode 100644 index 0000000000..040894381a --- /dev/null +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -0,0 +1,118 @@ +""" +Client authentication dependency for FastAPI. + +Corresponds to TypeScript file: src/server/auth/middleware/clientAuth.ts +""" + +import time +from typing import Optional + +from fastapi import Request, HTTPException, Depends +from pydantic import BaseModel, ValidationError + +from mcp.server.auth.errors import ( + InvalidClientError, + InvalidRequestError, + OAuthError, + ServerError, +) +from mcp.server.auth.provider import OAuthRegisteredClientsStore +from mcp.shared.auth import OAuthClientInformationFull + + +class ClientAuthRequest(BaseModel): + """ + Model for client authentication request body. + + Corresponds to ClientAuthenticatedRequestSchema in src/server/auth/middleware/clientAuth.ts + """ + client_id: str + client_secret: Optional[str] = None + + +class ClientAuthDependency: + """ + Dependency that authenticates a client using client_id and client_secret. + + This will validate the client credentials and return the client information. + + Corresponds to authenticateClient in src/server/auth/middleware/clientAuth.ts + """ + + def __init__(self, clients_store: OAuthRegisteredClientsStore): + """ + Initialize the dependency. + + Args: + clients_store: Store to look up client information + """ + self.clients_store = clients_store + + async def __call__(self, request: Request) -> OAuthClientInformationFull: + """ + Process the request and authenticate the client. + + Args: + request: FastAPI request + + Returns: + Authenticated client information + + Raises: + HTTPException: If client authentication fails + """ + try: + # Parse request body as form data or JSON + content_type = request.headers.get("Content-Type", "") + + if "application/x-www-form-urlencoded" in content_type: + # Parse form data + request_data = await request.form() + elif "application/json" in content_type: + # Parse JSON data + request_data = await request.json() + else: + raise InvalidRequestError("Unsupported content type") + + # Validate client credentials in request + try: + # TODO: can I just pass request_data to model_validate without pydantic complaining about extra params? + client_request = ClientAuthRequest.model_validate({ + "client_id": request_data.get("client_id"), + "client_secret": request_data.get("client_secret"), + }) + except ValidationError as e: + raise InvalidRequestError(str(e)) + + # Look up client information + client_id = client_request.client_id + client_secret = client_request.client_secret + + client = await self.clients_store.get_client(client_id) + if not client: + raise InvalidClientError("Invalid client_id") + + # If client has a secret, validate it + if client.client_secret: + # Check if client_secret is required but not provided + if not client_secret: + raise InvalidClientError("Client secret is required") + + # Check if client_secret matches + if client.client_secret != client_secret: + raise InvalidClientError("Invalid client_secret") + + # Check if client_secret has expired + if (client.client_secret_expires_at and + client.client_secret_expires_at < int(time.time())): + raise InvalidClientError("Client secret has expired") + + return client + + except OAuthError as e: + status_code = 500 if isinstance(e, ServerError) else 400 + # TODO: make sure we're not leaking anything here + raise HTTPException( + status_code=status_code, + detail=e.to_response_object() + ) \ No newline at end of file diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py new file mode 100644 index 0000000000..1412992ac1 --- /dev/null +++ b/src/mcp/server/auth/provider.py @@ -0,0 +1,162 @@ +""" +OAuth server provider interfaces for MCP authorization. + +Corresponds to TypeScript file: src/server/auth/provider.ts +""" + +from typing import Any, Dict, List, Optional, Protocol +from pydantic import AnyHttpUrl, BaseModel +from starlette.responses import Response + +from mcp.shared.auth import OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens +from mcp.server.auth.types import AuthInfo + + +class AuthorizationParams(BaseModel): + """ + Parameters for the authorization flow. + + Corresponds to AuthorizationParams in src/server/auth/provider.ts + """ + state: Optional[str] = None + scopes: Optional[List[str]] = None + code_challenge: str + redirect_uri: AnyHttpUrl + + +class OAuthRegisteredClientsStore(Protocol): + """ + Interface for storing and retrieving registered OAuth clients. + + Corresponds to OAuthRegisteredClientsStore in src/server/auth/clients.ts + """ + + async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: + """ + Retrieves client information by client ID. + + Args: + client_id: The ID of the client to retrieve. + + Returns: + The client information, or None if the client does not exist. + """ + ... + + async def register_client(self, + client_info: OAuthClientInformationFull + ) -> Optional[OAuthClientInformationFull]: + """ + Registers a new client and returns client information. + + Args: + metadata: The client metadata to register. + + Returns: + The client information, or None if registration failed. + """ + ... + + +class OAuthServerProvider(Protocol): + """ + Implements an end-to-end OAuth server. + + Corresponds to OAuthServerProvider in src/server/auth/provider.ts + """ + + @property + def clients_store(self) -> OAuthRegisteredClientsStore: + """ + A store used to read information about registered OAuth clients. + """ + ... + + # TODO: do we really want to be putting the response in this method? + async def authorize(self, + client: OAuthClientInformationFull, + params: AuthorizationParams, + response: Response) -> None: + """ + Begins the authorization flow, which can be implemented by this server or via redirection. + Must eventually issue a redirect with authorization response or error to the given redirect URI. + + Args: + client: The client requesting authorization. + params: Parameters for the authorization request. + response: The response object to write to. + """ + ... + + async def challenge_for_authorization_code(self, + client: OAuthClientInformationFull, + authorization_code: str) -> str | None: + """ + Returns the code_challenge that was used when the indicated authorization began. + + Args: + client: The client that requested the authorization code. + authorization_code: The authorization code to get the challenge for. + + Returns: + The code challenge that was used when the authorization began. + """ + ... + + async def exchange_authorization_code(self, + client: OAuthClientInformationFull, + authorization_code: str) -> OAuthTokens: + """ + Exchanges an authorization code for an access token. + + Args: + client: The client exchanging the authorization code. + authorization_code: The authorization code to exchange. + + Returns: + The access and refresh tokens. + """ + ... + + async def exchange_refresh_token(self, + client: OAuthClientInformationFull, + refresh_token: str, + scopes: Optional[List[str]] = None) -> OAuthTokens: + """ + Exchanges a refresh token for an access token. + + Args: + client: The client exchanging the refresh token. + refresh_token: The refresh token to exchange. + scopes: Optional scopes to request with the new access token. + + Returns: + The new access and refresh tokens. + """ + ... + + async def verify_access_token(self, token: str) -> AuthInfo: + """ + Verifies an access token and returns information about it. + + Args: + token: The access token to verify. + + Returns: + Information about the verified token. + """ + ... + + async def revoke_token(self, + client: OAuthClientInformationFull, + request: OAuthTokenRevocationRequest) -> None: + """ + Revokes an access or refresh token. + + If the given token is invalid or already revoked, this method should do nothing. + + Args: + client: The client revoking the token. + request: The token revocation request. + """ + ... \ No newline at end of file diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py new file mode 100644 index 0000000000..8fdcdf6a09 --- /dev/null +++ b/src/mcp/server/auth/router.py @@ -0,0 +1,177 @@ +""" +Router for OAuth authorization endpoints. + +Corresponds to TypeScript file: src/server/auth/router.ts +""" + +from dataclasses import dataclass +import re +from typing import Dict, List, Optional, Any, Union +from urllib.parse import urlparse + +from fastapi import Depends, FastAPI, APIRouter, Request, Response +from pydantic import AnyUrl, BaseModel + +from mcp.server.auth.middleware.client_auth import ClientAuthDependency +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthMetadata +from mcp.server.auth.handlers.metadata import create_metadata_handler +from mcp.server.auth.handlers.authorize import create_authorization_handler +from mcp.server.auth.handlers.token import create_token_handler +from mcp.server.auth.handlers.revoke import create_revocation_handler + + +@dataclass +class ClientRegistrationOptions: + enabled: bool = False + client_secret_expiry_seconds: Optional[int] = None + +@dataclass +class RevocationOptions: + enabled: bool = False + + +def validate_issuer_url(url: AnyUrl): + """ + Validate that the issuer URL meets OAuth 2.0 requirements. + + Args: + url: The issuer URL to validate + + Raises: + ValueError: If the issuer URL is invalid + """ + + # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing + if (url.scheme != "https" and + url.host != "localhost" and + not (url.host is not None and url.host.startswith("127.0.0.1"))): + raise ValueError("Issuer URL must be HTTPS") + + # No fragments or query parameters allowed + if url.fragment: + raise ValueError("Issuer URL must not have a fragment") + if url.query: + raise ValueError("Issuer URL must not have a query string") + + +AUTHORIZATION_PATH = "/authorize" +TOKEN_PATH = "/token" +REGISTRATION_PATH = "/register" +REVOCATION_PATH = "/revoke" + + +def create_auth_router( + provider: OAuthServerProvider, + issuer_url: AnyUrl, + service_documentation_url: AnyUrl | None = None, + client_registration_options: ClientRegistrationOptions | None = None, + revocation_options: RevocationOptions | None = None + ) -> APIRouter: + """ + Create a FastAPI application with standard MCP authorization endpoints. + + Corresponds to mcpAuthRouter in src/server/auth/router.ts + + Args: + provider: OAuth server provider + issuer_url: Issuer URL for the authorization server + service_documentation_url: Optional URL for service documentation + + Returns: + FastAPI application with authorization endpoints + """ + + validate_issuer_url(issuer_url) + + client_registration_options = client_registration_options or ClientRegistrationOptions() + revocation_options = revocation_options or RevocationOptions() + + client_auth = ClientAuthDependency(provider.clients_store) + + auth_app = APIRouter() + + + # Create handlers + + # Add routes + metadata = build_metadata(issuer_url, service_documentation_url, client_registration_options, revocation_options) + auth_app.add_api_route( + "/.well-known/oauth-authorization-server", + create_metadata_handler(metadata), + methods=["GET"] + ) + + # NOTE: reviewed + auth_app.add_api_route( + AUTHORIZATION_PATH, + create_authorization_handler(provider), + methods=["GET", "POST"] + ) + + # Add token endpoint with client auth dependency + # NOTE: reviewed + auth_app.add_api_route( + TOKEN_PATH, + create_token_handler(provider), + methods=["POST"], + dependencies=[Depends(client_auth)] + ) + + # Add registration endpoint if supported + if client_registration_options.enabled: + from mcp.server.auth.handlers.register import create_registration_handler + registration_handler = create_registration_handler( + provider.clients_store, + client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, + ) + # NOTE: reviewed + auth_app.add_api_route( + REGISTRATION_PATH, + registration_handler, + methods=["POST"] + ) + + # Add revocation endpoint if supported + if revocation_options.enabled: + # NOTE: reviewed + auth_app.add_api_route( + REVOCATION_PATH, + create_revocation_handler(provider), + methods=["POST"], + dependencies=[Depends(client_auth)] + ) + + return auth_app + +def build_metadata( + issuer_url: AnyUrl, + service_documentation_url: Optional[AnyUrl], + client_registration_options: ClientRegistrationOptions, + revocation_options: RevocationOptions, + ) -> Dict[str, Any]: + issuer_url_str = str(issuer_url).rstrip("/") + # Create metadata + metadata = { + "issuer": issuer_url_str, + "service_documentation": str(service_documentation_url).rstrip("/") if service_documentation_url else None, + + "authorization_endpoint": f"{issuer_url_str}{AUTHORIZATION_PATH}", + "response_types_supported": ["code"], + "code_challenge_methods_supported": ["S256"], + + "token_endpoint": f"{issuer_url_str}{TOKEN_PATH}", + "token_endpoint_auth_methods_supported": ["client_secret_post"], + "grant_types_supported": ["authorization_code", "refresh_token"], + } + + # Add registration endpoint if supported + if client_registration_options.enabled: + metadata["registration_endpoint"] = f"{issuer_url_str}{REGISTRATION_PATH}" + + # Add revocation endpoint if supported + if revocation_options.enabled: + metadata["revocation_endpoint"] = f"{issuer_url_str}{REVOCATION_PATH}" + metadata["revocation_endpoint_auth_methods_supported"] = ["client_secret_post"] + + return metadata \ No newline at end of file diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py new file mode 100644 index 0000000000..98d9ebde4d --- /dev/null +++ b/src/mcp/server/auth/types.py @@ -0,0 +1,23 @@ +""" +Authorization types for MCP server. + +Corresponds to TypeScript file: src/server/auth/types.ts +""" + +from typing import List, Optional +from pydantic import BaseModel + + +class AuthInfo(BaseModel): + """ + Information about a validated access token, provided to request handlers. + + Corresponds to AuthInfo in src/server/auth/types.ts + """ + token: str + client_id: str + scopes: List[str] + expires_at: Optional[int] = None + + class Config: + extra = "ignore" \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 1f5736e43f..793a0b0755 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -11,7 +11,7 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Callable, Generic, Literal, Sequence +from typing import Any, Callable, Generic, Literal, Optional, Sequence import anyio import pydantic_core @@ -20,6 +20,9 @@ from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict +from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions +from mcp.server.auth.types import AuthInfo from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -89,6 +92,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]): Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None ) = Field(None, description="Lifespan context manager") + auth_issuer_url: AnyUrl | None = Field(None, description="Auth issuer URL") + auth_service_documentation_url: AnyUrl | None = Field(None, description="Service documentation URL") + auth_client_registration_options: ClientRegistrationOptions | None = None + auth_revocation_options: RevocationOptions | None = None + auth_required_scopes: list[str] | None = None + + def lifespan_wrapper( app: FastMCP, @@ -104,7 +114,11 @@ async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: class FastMCP: def __init__( - self, name: str | None = None, instructions: str | None = None, **settings: Any + self, + name: str | None = None, + instructions: str | None = None, + auth_provider: OAuthServerProvider | None = None, + **settings: Any ): self.settings = Settings(**settings) @@ -124,6 +138,7 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) + self._auth_provider = auth_provider self.dependencies = self.settings.dependencies # Set up MCP protocol handlers @@ -463,10 +478,24 @@ async def run_sse_async(self) -> None: """Run the server using SSE transport.""" from starlette.applications import Starlette from starlette.routing import Mount, Route + from starlette.middleware import Middleware + from fastapi import FastAPI, Depends + + # Import auth dependency if needed + auth_dependencies = [] + if self._auth_provider: + from mcp.server.auth.middleware.bearer_auth import BearerAuthDependency + auth_dependencies = [Depends(BearerAuthDependency( + provider=self._auth_provider, + required_scopes=self.settings.auth_required_scopes + ))] sse = SseServerTransport("/messages/") async def handle_sse(request): + # Add client ID from auth context into request context if available + request_meta = {} + async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: @@ -476,16 +505,26 @@ async def handle_sse(request): self._mcp_server.create_initialization_options(), ) - starlette_app = Starlette( - debug=self.settings.debug, - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - ) + # Create Starlette app + app = FastAPI(debug=self.settings.debug) + + # Add routes with auth dependency if required + app.add_api_route("/sse", endpoint=handle_sse, dependencies=auth_dependencies) + # TODO: convert this to a handler so it can take a dependency + app.mount("/messages/", sse.handle_post_message) # , dependencies=auth_dependencies) + + # Add auth endpoints if auth provider is configured + if self._auth_provider and self.settings.auth_issuer_url: + from mcp.server.auth.router import create_auth_router + auth_app = create_auth_router( + self._auth_provider, + self.settings.auth_issuer_url, + self.settings.auth_service_documentation_url + ) + app.mount("/", auth_app) config = uvicorn.Config( - starlette_app, + app, host=self.settings.host, port=self.settings.port, log_level=self.settings.log_level.lower(), diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py new file mode 100644 index 0000000000..f751065a20 --- /dev/null +++ b/src/mcp/shared/auth.py @@ -0,0 +1,123 @@ +""" +Authorization types and models for MCP OAuth implementation. + +Corresponds to TypeScript file: src/shared/auth.ts +""" + +from typing import Any, Dict, List, Optional, Union +from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator + + +class OAuthErrorResponse(BaseModel): + """ + OAuth 2.1 error response. + + Corresponds to OAuthErrorResponseSchema in src/shared/auth.ts + """ + error: str + error_description: Optional[str] = None + error_uri: Optional[AnyHttpUrl] = None + + +class OAuthTokens(BaseModel): + """ + OAuth 2.1 token response. + + Corresponds to OAuthTokensSchema in src/shared/auth.ts + """ + access_token: str + token_type: str + expires_in: Optional[int] = None + scope: Optional[str] = None + refresh_token: Optional[str] = None + + +class OAuthClientMetadata(BaseModel): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + + Corresponds to OAuthClientMetadataSchema in src/shared/auth.ts + """ + redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) + token_endpoint_auth_method: Optional[str] + grant_types: Optional[List[str]] + response_types: Optional[List[str]] = None + client_name: Optional[str] = None + client_uri: Optional[AnyHttpUrl] = None + logo_uri: Optional[AnyHttpUrl] = None + scope: Optional[str] = None + contacts: Optional[List[str]] = None + tos_uri: Optional[AnyHttpUrl] = None + policy_uri: Optional[AnyHttpUrl] = None + jwks_uri: Optional[AnyHttpUrl] = None + jwks: Optional[Any] = None + software_id: Optional[str] = None + software_version: Optional[str] = None + + +class OAuthClientInformation(BaseModel): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration client information. + + Corresponds to OAuthClientInformationSchema in src/shared/auth.ts + """ + client_id: str + client_secret: Optional[str] = None + client_id_issued_at: Optional[int] = None + client_secret_expires_at: Optional[int] = None + + +class OAuthClientInformationFull(OAuthClientMetadata, OAuthClientInformation): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration full response + (client information plus metadata). + + Corresponds to OAuthClientInformationFullSchema in src/shared/auth.ts + """ + pass + + +class OAuthClientRegistrationError(BaseModel): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration error response. + + Corresponds to OAuthClientRegistrationErrorSchema in src/shared/auth.ts + """ + error: str + error_description: Optional[str] = None + + +class OAuthTokenRevocationRequest(BaseModel): + """ + RFC 7009 OAuth 2.0 Token Revocation request. + + Corresponds to OAuthTokenRevocationRequestSchema in src/shared/auth.ts + """ + token: str + token_type_hint: Optional[str] = None + + +class OAuthMetadata(BaseModel): + """ + RFC 8414 OAuth 2.0 Authorization Server Metadata. + + Corresponds to OAuthMetadataSchema in src/shared/auth.ts + """ + issuer: str + authorization_endpoint: str + token_endpoint: str + registration_endpoint: Optional[str] = None + scopes_supported: Optional[List[str]] = None + response_types_supported: List[str] + response_modes_supported: Optional[List[str]] = None + grant_types_supported: Optional[List[str]] = None + token_endpoint_auth_methods_supported: Optional[List[str]] = None + token_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + service_documentation: Optional[str] = None + revocation_endpoint: Optional[str] = None + revocation_endpoint_auth_methods_supported: Optional[List[str]] = None + revocation_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + introspection_endpoint: Optional[str] = None + introspection_endpoint_auth_methods_supported: Optional[List[str]] = None + introspection_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + code_challenge_methods_supported: Optional[List[str]] = None \ No newline at end of file diff --git a/tests/server/fastmcp/auth/__init__.py b/tests/server/fastmcp/auth/__init__.py new file mode 100644 index 0000000000..304b8cd87a --- /dev/null +++ b/tests/server/fastmcp/auth/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for the MCP server auth components. +""" \ No newline at end of file diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py new file mode 100644 index 0000000000..3d7e51fbdf --- /dev/null +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -0,0 +1,558 @@ +""" +Integration tests for MCP authorization components. +""" + +import base64 +import hashlib +import json +import time +from typing import Any, Dict, List, Optional, cast +from urllib.parse import urlparse, parse_qs + +import anyio +from pydantic import AnyUrl +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient +from starlette.datastructures import MutableHeaders +from starlette.responses import RedirectResponse, JSONResponse +from starlette.requests import Request + +from mcp.server.auth.errors import InvalidTokenError +from mcp.server.auth.middleware.bearer_auth import BearerAuthDependency +from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, OAuthRegisteredClientsStore +from mcp.server.auth.router import create_auth_router +from mcp.server.auth.types import AuthInfo +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthTokenRevocationRequest, + OAuthTokens, +) +from mcp.server.fastmcp import FastMCP + + +# Mock client store for testing +class MockClientStore: + def __init__(self): + self.clients = {} + + async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> OAuthClientInformationFull: + self.clients[client_info.client_id] = client_info + return client_info + + +# Mock OAuth provider for testing +class MockOAuthProvider: + def __init__(self): + self.client_store = MockClientStore() + self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} + self.tokens = {} # token -> {client_id, scopes, expires_at} + self.refresh_tokens = {} # refresh_token -> access_token + + @property + def clients_store(self) -> OAuthRegisteredClientsStore: + return self.client_store + + async def authorize(self, + client: OAuthClientInformationFull, + params: AuthorizationParams, + response: RedirectResponse) -> None: + # Generate an authorization code + code = f"code_{int(time.time())}" + + # Store the code for later verification + self.auth_codes[code] = { + "client_id": client.client_id, + "code_challenge": params.code_challenge, + "redirect_uri": params.redirect_uri, + "expires_at": int(time.time()) + 600, # 10 minutes + } + + # Redirect with code + query = {"code": code} + if params.state: + query["state"] = params.state + + redirect_url = f"{params.redirect_uri}?" + "&".join([f"{k}={v}" for k, v in query.items()]) + response.headers["location"] = redirect_url + + async def challenge_for_authorization_code(self, + client: OAuthClientInformationFull, + authorization_code: str) -> str: + # Get the stored code info + code_info = self.auth_codes.get(authorization_code) + if not code_info: + raise InvalidTokenError("Invalid authorization code") + + # Check if code is expired + if code_info["expires_at"] < int(time.time()): + raise InvalidTokenError("Authorization code has expired") + + # Check if the code was issued to this client + if code_info["client_id"] != client.client_id: + raise InvalidTokenError("Authorization code was not issued to this client") + + return code_info["code_challenge"] + + async def exchange_authorization_code(self, + client: OAuthClientInformationFull, + authorization_code: str) -> OAuthTokens: + # Get the stored code info + code_info = self.auth_codes.get(authorization_code) + if not code_info: + raise InvalidTokenError("Invalid authorization code") + + # Check if code is expired + if code_info["expires_at"] < int(time.time()): + raise InvalidTokenError("Authorization code has expired") + + # Check if the code was issued to this client + if code_info["client_id"] != client.client_id: + raise InvalidTokenError("Authorization code was not issued to this client") + + # Generate an access token and refresh token + access_token = f"access_{int(time.time())}" + refresh_token = f"refresh_{int(time.time())}" + + # Store the tokens + self.tokens[access_token] = { + "client_id": client.client_id, + "scopes": ["read", "write"], + "expires_at": int(time.time()) + 3600, + } + + self.refresh_tokens[refresh_token] = access_token + + # Remove the used code + del self.auth_codes[authorization_code] + + return OAuthTokens( + access_token=access_token, + token_type="bearer", + expires_in=3600, + scope="read write", + refresh_token=refresh_token, + ) + + async def exchange_refresh_token(self, + client: OAuthClientInformationFull, + refresh_token: str, + scopes: Optional[List[str]] = None) -> OAuthTokens: + # Check if refresh token exists + if refresh_token not in self.refresh_tokens: + raise InvalidTokenError("Invalid refresh token") + + # Get the access token for this refresh token + old_access_token = self.refresh_tokens[refresh_token] + + # Check if the access token exists + if old_access_token not in self.tokens: + raise InvalidTokenError("Invalid refresh token") + + # Check if the token was issued to this client + token_info = self.tokens[old_access_token] + if token_info["client_id"] != client.client_id: + raise InvalidTokenError("Refresh token was not issued to this client") + + # Generate a new access token and refresh token + new_access_token = f"access_{int(time.time())}" + new_refresh_token = f"refresh_{int(time.time())}" + + # Store the new tokens + self.tokens[new_access_token] = { + "client_id": client.client_id, + "scopes": scopes or token_info["scopes"], + "expires_at": int(time.time()) + 3600, + } + + self.refresh_tokens[new_refresh_token] = new_access_token + + # Remove the old tokens + del self.refresh_tokens[refresh_token] + del self.tokens[old_access_token] + + return OAuthTokens( + access_token=new_access_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(scopes) if scopes else " ".join(token_info["scopes"]), + refresh_token=new_refresh_token, + ) + + async def verify_access_token(self, token: str) -> AuthInfo: + # Check if token exists + if token not in self.tokens: + raise InvalidTokenError("Invalid access token") + + # Get token info + token_info = self.tokens[token] + + # Check if token is expired + if token_info["expires_at"] < int(time.time()): + raise InvalidTokenError("Access token has expired") + + return AuthInfo( + token=token, + client_id=token_info["client_id"], + scopes=token_info["scopes"], + expires_at=token_info["expires_at"], + ) + + async def revoke_token(self, + client: OAuthClientInformationFull, + request: OAuthTokenRevocationRequest) -> None: + token = request.token + + # Check if it's a refresh token + if token in self.refresh_tokens: + access_token = self.refresh_tokens[token] + + # Check if this refresh token belongs to this client + if self.tokens[access_token]["client_id"] != client.client_id: + # For security reasons, we still return success + return + + # Remove the refresh token and its associated access token + del self.tokens[access_token] + del self.refresh_tokens[token] + + # Check if it's an access token + elif token in self.tokens: + # Check if this access token belongs to this client + if self.tokens[token]["client_id"] != client.client_id: + # For security reasons, we still return success + return + + # Remove the access token + del self.tokens[token] + + # Also remove any refresh tokens that point to this access token + for refresh_token, access_token in list(self.refresh_tokens.items()): + if access_token == token: + del self.refresh_tokens[refresh_token] + + +@pytest.fixture +def mock_oauth_provider(): + return MockOAuthProvider() + + +@pytest.fixture +def auth_app(mock_oauth_provider): + app = create_auth_router( + mock_oauth_provider, + AnyUrl("https://auth.example.com"), + AnyUrl("https://docs.example.com"), + ) + return app + + +@pytest.fixture +def test_client(auth_app): + return TestClient(auth_app) + + +@pytest.mark.anyio +class TestAuthEndpoints: + def test_metadata_endpoint(self, test_client): + """Test the OAuth 2.0 metadata endpoint.""" + response = test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + + metadata = response.json() + assert metadata["issuer"] == "https://auth.example.com" + assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" + assert metadata["token_endpoint"] == "https://auth.example.com/token" + assert metadata["registration_endpoint"] == "https://auth.example.com/register" + assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" + assert metadata["response_types_supported"] == ["code"] + assert metadata["code_challenge_methods_supported"] == ["S256"] + assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] + assert metadata["grant_types_supported"] == ["authorization_code", "refresh_token"] + assert metadata["service_documentation"] == "https://docs.example.com" + + @pytest.mark.anyio + async def test_client_registration(self, test_client, mock_oauth_provider): + """Test client registration.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "client_uri": "https://client.example.com", + } + + response = test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + + client_info = response.json() + assert "client_id" in client_info + assert "client_secret" in client_info + assert client_info["client_name"] == "Test Client" + assert client_info["redirect_uris"] == ["https://client.example.com/callback"] + + # Verify that the client was registered + assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None + + @pytest.mark.anyio + async def test_authorization_flow(self, test_client, mock_oauth_provider): + """Test the full authorization flow.""" + # 1. Register a client + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + } + + response = test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # 2. Create a PKCE challenge + code_verifier = "some_random_verifier_string" + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") + + # 3. Request authorization + response = test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": "test_state", + }, + allow_redirects=False, + ) + assert response.status_code == 302 + + # 4. Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + assert query_params["state"][0] == "test_state" + auth_code = query_params["code"][0] + + # 5. Exchange the authorization code for tokens + response = test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "code": auth_code, + "code_verifier": code_verifier, + }, + ) + assert response.status_code == 200 + + token_response = response.json() + assert "access_token" in token_response + assert "token_type" in token_response + assert "refresh_token" in token_response + assert "expires_in" in token_response + assert token_response["token_type"] == "bearer" + + # 6. Verify the access token + access_token = token_response["access_token"] + refresh_token = token_response["refresh_token"] + + # Create a test client with the token + auth_info = await mock_oauth_provider.verify_access_token(access_token) + assert auth_info.client_id == client_info["client_id"] + assert "read" in auth_info.scopes + assert "write" in auth_info.scopes + + # 7. Refresh the token + response = test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "refresh_token": refresh_token, + }, + ) + assert response.status_code == 200 + + new_token_response = response.json() + assert "access_token" in new_token_response + assert "refresh_token" in new_token_response + assert new_token_response["access_token"] != access_token + assert new_token_response["refresh_token"] != refresh_token + + # 8. Revoke the token + response = test_client.post( + "/revoke", + data={ + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "token": new_token_response["access_token"], + }, + ) + assert response.status_code == 200 + + # Verify that the token was revoked + with pytest.raises(InvalidTokenError): + await mock_oauth_provider.verify_access_token(new_token_response["access_token"]) + + +@pytest.mark.anyio +class TestFastMCPWithAuth: + """Test FastMCP server with authentication.""" + + @pytest.mark.anyio + async def test_fastmcp_with_auth(self, mock_oauth_provider): + """Test creating a FastMCP server with authentication.""" + # Create FastMCP server with auth provider + mcp = FastMCP( + auth_provider=mock_oauth_provider, + auth_issuer_url="https://auth.example.com", + require_auth=True, + ) + + # Add a test tool + @mcp.tool() + def test_tool(x: int) -> str: + return f"Result: {x}" + + # Create a FastAPI app for testing + from fastapi import FastAPI, Depends, Security + + # Override the run method to capture the app + app = None + + async def mock_run_sse(): + nonlocal app + + # Create auth dependency + auth_dependency = BearerAuthDependency( + provider=mock_oauth_provider, + required_scopes=mcp.settings.auth_required_scopes + ) + + # Create FastAPI app + app = FastAPI(debug=mcp.settings.debug) + + # Add a test endpoint that requires authentication + @app.get("/test") + async def test_endpoint(auth: AuthInfo = Depends(auth_dependency)): + return {"status": "ok", "client_id": auth.client_id} + + # Add another endpoint that doesn't require auth for comparison + @app.get("/public") + async def public_endpoint(): + return {"status": "ok"} + + # Add auth endpoints + from mcp.server.auth.router import create_auth_router + auth_app = create_auth_router( + mock_oauth_provider, + cast(AnyUrl, mcp.settings.auth_issuer_url), + mcp.settings.auth_service_documentation_url + ) + app.mount("/", auth_app) + + # Override the run method + mcp.run_sse_async = mock_run_sse + await mcp.run_sse_async() + + assert app is not None + test_client = TestClient(app) + + # Test metadata endpoint + response = test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + + # Test that auth is required for protected endpoints + response = test_client.get("/test") + assert response.status_code == 401 + + # Test that public endpoints don't require auth + response = test_client.get("/public") + assert response.status_code == 200 + + # Register a client + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + } + + response = test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Create a PKCE challenge + code_verifier = "some_random_verifier_string" + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") + + # Request authorization + response = test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": "test_state", + }, + allow_redirects=False, + ) + assert response.status_code == 302 + + # Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + auth_code = query_params["code"][0] + + # Exchange the authorization code for tokens + response = test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "code": auth_code, + "code_verifier": code_verifier, + }, + ) + assert response.status_code == 200 + + token_response = response.json() + assert "access_token" in token_response + + # Test the authenticated endpoint with valid token + response = test_client.get( + "/test", + headers={"Authorization": f"Bearer {token_response['access_token']}"}, + ) + assert response.status_code == 200 + assert response.json()["status"] == "ok" + assert response.json()["client_id"] == client_info["client_id"] + + # Test with invalid token + response = test_client.get( + "/test", + headers={"Authorization": "Bearer invalid_token"}, + ) + assert response.status_code == 401 \ No newline at end of file From 331d51eb04d0f667ff131c6d2315b00fd6751c00 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Thu, 6 Mar 2025 11:19:43 -0800 Subject: [PATCH 02/60] Unwind changes --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e87136758c..157263de68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", - "fastapi", ] [project.optional-dependencies] @@ -48,7 +47,7 @@ dev-dependencies = [ "pytest>=8.3.4", "ruff>=0.8.5", "trio>=0.26.2", - "pytest-flakefinder==1.1.0", + "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", ] From d283f560cc3a461555daba6acd533ab0be157cde Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Thu, 6 Mar 2025 16:21:18 -0800 Subject: [PATCH 03/60] wip --- CLAUDE.md | 10 +- pyproject.toml | 3 + src/mcp/server/auth/handlers/authorize.py | 13 +- src/mcp/server/auth/handlers/metadata.py | 9 +- src/mcp/server/auth/handlers/register.py | 19 +- src/mcp/server/auth/handlers/revoke.py | 26 +-- src/mcp/server/auth/handlers/token.py | 123 ++++++------- src/mcp/server/auth/json_response.py | 6 + src/mcp/server/auth/middleware/bearer_auth.py | 126 +++++++------ src/mcp/server/auth/middleware/client_auth.py | 137 +++++++------- src/mcp/server/auth/provider.py | 2 + src/mcp/server/auth/router.py | 99 +++++------ src/mcp/server/auth/types.py | 1 + src/mcp/server/fastmcp/server.py | 60 ++++--- src/mcp/shared/auth.py | 4 +- .../fastmcp/auth/test_auth_integration.py | 168 ++++++++---------- 16 files changed, 420 insertions(+), 386 deletions(-) create mode 100644 src/mcp/server/auth/json_response.py diff --git a/CLAUDE.md b/CLAUDE.md index e95b75cd58..baed85a238 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,7 +19,7 @@ This document contains critical information about working with this codebase. Fo - Line length: 88 chars maximum 3. Testing Requirements - - Framework: `uv run pytest` + - Framework: `uv run --frozen pytest` - Async testing: use anyio, not asyncio - Coverage: test edge cases and errors - New features require tests @@ -54,9 +54,9 @@ This document contains critical information about working with this codebase. Fo ## Code Formatting 1. Ruff - - Format: `uv run ruff format .` - - Check: `uv run ruff check .` - - Fix: `uv run ruff check . --fix` + - Format: `uv run --frozen ruff format .` + - Check: `uv run --frozen ruff check .` + - Fix: `uv run --frozen ruff check . --fix` - Critical issues: - Line length (88 chars) - Import sorting (I001) @@ -67,7 +67,7 @@ This document contains critical information about working with this codebase. Fo - Imports: split into multiple lines 2. Type Checking - - Tool: `uv run pyright` + - Tool: `uv run --frozen pyright` - Requirements: - Explicit None checks for Optional - Type narrowing for strings diff --git a/pyproject.toml b/pyproject.toml index 157263de68..489d1faa71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,9 @@ strict = [ "src/mcp/server/fastmcp/tools/base.py", ] +[tool.pytest.ini_options] +markers = ["anyio"] + [tool.ruff.lint] select = ["E", "F", "I"] ignore = [] diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 2eabd0a6e5..b13555347e 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -9,10 +9,10 @@ from typing import Any, Callable, Dict, List, Literal, Optional from urllib.parse import urlencode, parse_qs -from fastapi import Request, Response +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError from pydantic_core import Url -from starlette.responses import JSONResponse, RedirectResponse from mcp.server.auth.errors import ( InvalidClientError, @@ -81,9 +81,14 @@ async def authorization_handler(request: Request) -> Response: # Validate request parameters try: if request.method == "GET": - auth_request = AuthorizationRequest.model_validate(request.query_params) + # Convert query_params to dict for pydantic validation + params = dict(request.query_params) + auth_request = AuthorizationRequest.model_validate(params) else: - auth_request = AuthorizationRequest.model_validate_json(await request.body()) + # Parse form data for POST requests + form_data = await request.form() + params = dict(form_data) + auth_request = AuthorizationRequest.model_validate(params) except ValidationError as e: raise InvalidRequestError(str(e)) diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 2acee117a9..2c2ca26507 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -5,8 +5,9 @@ """ from typing import Any, Callable, Dict, Optional -from fastapi import Request, Response -from starlette.responses import JSONResponse + +from starlette.requests import Request +from starlette.responses import JSONResponse, Response def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: @@ -19,7 +20,7 @@ def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: metadata: The metadata to return in the response Returns: - A FastAPI route handler function + A Starlette endpoint handler function """ async def metadata_handler(request: Request) -> Response: @@ -27,7 +28,7 @@ async def metadata_handler(request: Request) -> Response: Handler for the OAuth 2.0 Authorization Server Metadata endpoint. Args: - request: The FastAPI request + request: The Starlette request Returns: JSON response with the authorization server metadata diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 47527ea4e7..150e048e69 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -10,15 +10,16 @@ from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 -from fastapi import Request, Response +from starlette.requests import Request +from starlette.responses import JSONResponse, Response from pydantic import ValidationError -from starlette.responses import JSONResponse from mcp.server.auth.errors import ( InvalidRequestError, ServerError, OAuthError, ) +from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata @@ -31,9 +32,10 @@ def create_registration_handler(clients_store: OAuthRegisteredClientsStore, clie Args: clients_store: The store for registered clients + client_secret_expiry_seconds: Optional expiry time for client secrets Returns: - A FastAPI route handler function + A Starlette endpoint handler function """ async def registration_handler(request: Request) -> Response: @@ -41,15 +43,16 @@ async def registration_handler(request: Request) -> Response: Handler for the OAuth 2.0 Dynamic Client Registration endpoint. Args: - request: The FastAPI request + request: The Starlette request Returns: JSON response with client information or error """ try: - # Validate client metadata + # Parse request body as JSON try: - client_metadata = OAuthClientMetadata.model_validate_json(await request.body()) + body = await request.json() + client_metadata = OAuthClientMetadata.model_validate(body) except ValidationError as e: raise InvalidRequestError(f"Invalid client metadata: {str(e)}") @@ -90,8 +93,8 @@ async def registration_handler(request: Request) -> Response: raise ServerError("Failed to register client") # Return client information - return JSONResponse( - content=client.model_dump(exclude_none=True), + return PydanticJSONResponse( + content=client, status_code=201 ) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 59a11918a1..6280e71c97 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -6,20 +6,24 @@ from typing import Any, Callable, Dict, Optional -from fastapi import Request, Response +from starlette.requests import Request +from starlette.responses import Response from pydantic import ValidationError -from starlette.responses import JSONResponse, Response as StarletteResponse from mcp.server.auth.errors import ( InvalidRequestError, ServerError, OAuthError, ) +from mcp.server.auth.middleware import client_auth from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthClientInformationFull, OAuthTokenRevocationRequest +from mcp.server.auth.middleware.client_auth import ClientAuthRequest, ClientAuthenticator +class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): + pass -def create_revocation_handler(provider: OAuthServerProvider) -> Callable: +def create_revocation_handler(provider: OAuthServerProvider, client_authenticator: ClientAuthenticator) -> Callable: """ Create a handler for OAuth 2.0 Token Revocation. @@ -29,25 +33,27 @@ def create_revocation_handler(provider: OAuthServerProvider) -> Callable: provider: The OAuth server provider Returns: - A FastAPI route handler function + A Starlette endpoint handler function """ - async def revocation_handler(request: Request, client_auth: OAuthClientInformationFull) -> Response: + async def revocation_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Token Revocation endpoint. """ - # Validate revocation request try: - revocation_request = OAuthTokenRevocationRequest.model_validate_json(await request.body()) + revocation_request = RevocationRequest.model_validate_json(await request.body()) except ValidationError as e: - raise InvalidRequestError(str(e)) + raise InvalidRequestError(f"Invalid request body: {e}") + + # Authenticate client + client_auth_result = await client_authenticator(revocation_request) # Revoke token if provider.revoke_token: - await provider.revoke_token(client_auth, revocation_request) + await provider.revoke_token(client_auth_result, revocation_request) # Return successful empty response - return StarletteResponse( + return Response( status_code=200, headers={ "Cache-Control": "no-store", diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 9164991a69..e9d7ff293b 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -7,11 +7,11 @@ import base64 import hashlib import json -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union -from fastapi import Request, Response -from pydantic import BaseModel, Field, ValidationError +from starlette.requests import Request from starlette.responses import JSONResponse +from pydantic import BaseModel, Field, RootModel, TypeAdapter, ValidationError from mcp.server.auth.errors import ( InvalidClientError, @@ -23,37 +23,36 @@ ) from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthClientInformationFull, OAuthTokens -from mcp.server.auth.middleware.client_auth import ClientAuthDependency +from mcp.server.auth.middleware.client_auth import ClientAuthRequest, ClientAuthenticator +from mcp.server.auth.json_response import PydanticJSONResponse -class AuthorizationCodeRequest(BaseModel): +class AuthorizationCodeRequest(ClientAuthRequest): """ Model for the authorization code grant request parameters. Corresponds to AuthorizationCodeExchangeSchema in src/server/auth/handlers/token.ts """ - grant_type: str = Field(..., description="Must be 'authorization_code'") + grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") code_verifier: str = Field(..., description="PKCE code verifier") - - class Config: - extra = "ignore" - -class RefreshTokenRequest(BaseModel): +class RefreshTokenRequest(ClientAuthRequest): """ Model for the refresh token grant request parameters. Corresponds to RefreshTokenExchangeSchema in src/server/auth/handlers/token.ts """ - grant_type: str = Field(..., description="Must be 'refresh_token'") + grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: Optional[str] = Field(None, description="Optional scope parameter") - - class Config: - extra = "ignore" -def create_token_handler(provider: OAuthServerProvider) -> Callable: +class TokenRequest(RootModel): + root: Annotated[Union[AuthorizationCodeRequest, RefreshTokenRequest], Field(discriminator="grant_type")] +# TokenRequest = RootModel(Annotated[Union[AuthorizationCodeRequest, RefreshTokenRequest], Field(discriminator="grant_type")]) + + +def create_token_handler(provider: OAuthServerProvider, client_authenticator: ClientAuthenticator) -> Callable: """ Create a handler for the OAuth 2.0 Token endpoint. @@ -63,74 +62,60 @@ def create_token_handler(provider: OAuthServerProvider) -> Callable: provider: The OAuth server provider Returns: - A FastAPI route handler function + A Starlette endpoint handler function """ - async def token_handler(request: Request, client_auth: OAuthClientInformationFull) -> Response: + async def token_handler(request: Request): """ Handler for the OAuth 2.0 Token endpoint. Args: - request: The FastAPI request + request: The Starlette request Returns: JSON response with tokens or error """ - params = json.loads(await request.body()) + # Parse request body as form data or JSON + content_type = request.headers.get("Content-Type", "") - - # Check grant_type first to determine which validation model to use - if "grant_type" not in params: - raise InvalidRequestError("Missing required parameter: grant_type") - grant_type = params["grant_type"] - + try: + token_request = TokenRequest.model_validate_json(await request.body()).root + except ValidationError as e: + raise InvalidRequestError(f"Invalid request body: {e}") + client_info = await client_authenticator(token_request) + tokens: OAuthTokens - if grant_type == "authorization_code": - # Validate authorization code parameters - try: - code_request = AuthorizationCodeRequest.model_validate(params) - except ValidationError as e: - raise InvalidRequestError(str(e)) + match token_request: + case AuthorizationCodeRequest(): + # Verify PKCE code verifier + expected_challenge = await provider.challenge_for_authorization_code( + client_info, token_request.code + ) + if expected_challenge is None: + raise InvalidRequestError("Invalid authorization code") + + # Calculate challenge from verifier + sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() + actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + if actual_challenge != expected_challenge: + raise InvalidRequestError("code_verifier does not match the challenge") + + # Exchange authorization code for tokens + tokens = await provider.exchange_authorization_code(client_info, token_request.code) - # Verify PKCE code verifier - expected_challenge = await provider.challenge_for_authorization_code( - client_auth, code_request.code - ) - if expected_challenge is None: - raise InvalidRequestError("Invalid authorization code") - - # Calculate challenge from verifier - sha256 = hashlib.sha256(code_request.code_verifier.encode()).digest() - actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - - if actual_challenge != expected_challenge: - raise InvalidRequestError("code_verifier does not match the challenge") - - # Exchange authorization code for tokens - tokens = await provider.exchange_authorization_code(client_auth, code_request.code) - - elif grant_type == "refresh_token": - # Validate refresh token parameters - try: - refresh_request = RefreshTokenRequest.model_validate(params) - except ValidationError as e: - raise InvalidRequestError(str(e)) - - # Parse scopes if provided - scopes = refresh_request.scope.split(" ") if refresh_request.scope else None - - # Exchange refresh token for new tokens - tokens = await provider.exchange_refresh_token( - client_auth, refresh_request.refresh_token, scopes - ) - - else: - raise InvalidRequestError( - f"Unsupported grant_type: {grant_type}" - ) + case RefreshTokenRequest(): + # Parse scopes if provided + scopes = token_request.scope.split(" ") if token_request.scope else None + + # Exchange refresh token for new tokens + tokens = await provider.exchange_refresh_token( + client_info, token_request.refresh_token, scopes + ) + - return JSONResponse( + return PydanticJSONResponse( content=tokens, headers={ "Cache-Control": "no-store", diff --git a/src/mcp/server/auth/json_response.py b/src/mcp/server/auth/json_response.py new file mode 100644 index 0000000000..7dc39bcaac --- /dev/null +++ b/src/mcp/server/auth/json_response.py @@ -0,0 +1,6 @@ +from typing import Any +from starlette.responses import JSONResponse + +class PydanticJSONResponse(JSONResponse): + def render(self, content: Any) -> bytes: + return content.model_dump_json(exclude_none=True).encode("utf-8") \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index c7b181434d..431bf16efc 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,28 +1,34 @@ """ -Bearer token authentication dependency for FastAPI. +Bearer token authentication middleware for ASGI applications. Corresponds to TypeScript file: src/server/auth/middleware/bearerAuth.ts """ import time -from typing import List, Optional +from typing import List, Optional, Callable, Awaitable, cast, Dict, Any -from fastapi import Request, HTTPException -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from starlette.requests import HTTPConnection, Request +from starlette.exceptions import HTTPException +from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser +from starlette.middleware.authentication import AuthenticationMiddleware from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError from mcp.server.auth.provider import OAuthServerProvider from mcp.server.auth.types import AuthInfo -class BearerAuthDependency: - """ - Dependency that requires a valid Bearer token in the Authorization header. - - This will validate the token with the auth provider and return the resulting - auth info. +class AuthenticatedUser(SimpleUser): + """User with authentication info.""" - Corresponds to requireBearerAuth in src/server/auth/middleware/bearerAuth.ts + def __init__(self, auth_info: AuthInfo): + super().__init__(auth_info.user_id or "anonymous") + self.auth_info = auth_info + self.scopes = auth_info.scopes + + +class BearerAuthBackend(AuthenticationBackend): + """ + Authentication backend that validates Bearer tokens. """ def __init__( @@ -31,7 +37,7 @@ def __init__( required_scopes: Optional[List[str]] = None ): """ - Initialize the dependency. + Initialize the backend. Args: provider: Authentication provider to validate tokens @@ -39,28 +45,22 @@ def __init__( """ self.provider = provider self.required_scopes = required_scopes or [] - self.bearer_scheme = HTTPBearer() - async def __call__(self, request: Request) -> AuthInfo: - """ - Process the request and validate the bearer token. - - Args: - request: FastAPI request + async def authenticate(self, conn: HTTPConnection): + + if "Authorization" not in conn.headers: + raise AuthenticationError() + return None - Returns: - Authenticated auth info + auth_header = conn.headers["Authorization"] + if not auth_header.startswith("Bearer "): + return None - Raises: - HTTPException: If token validation fails - """ + token = auth_header[7:] # Remove "Bearer " prefix + try: - # Extract and validate the authorization header using FastAPI's built-in scheme - credentials: HTTPAuthorizationCredentials = await self.bearer_scheme(request) - token = credentials.credentials - # Validate the token with the provider - auth_info: AuthInfo = await self.provider.verify_access_token(token) + auth_info = await self.provider.verify_access_token(token) # Check if the token has all required scopes if self.required_scopes: @@ -72,27 +72,49 @@ async def __call__(self, request: Request) -> AuthInfo: if auth_info.expires_at and auth_info.expires_at < int(time.time()): raise InvalidTokenError("Token has expired") - return auth_info + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) - except InvalidTokenError as e: - # Return a 401 Unauthorized response with appropriate headers - headers = {"WWW-Authenticate": f'Bearer error="{e.error_code}", error_description="{str(e)}"'} - raise HTTPException( - status_code=401, - detail=e.to_response_object(), - headers=headers - ) - except InsufficientScopeError as e: - # Return a 403 Forbidden response with appropriate headers - headers = {"WWW-Authenticate": f'Bearer error="{e.error_code}", error_description="{str(e)}"'} - raise HTTPException( - status_code=403, - detail=e.to_response_object(), - headers=headers - ) - except OAuthError as e: - # Return a 400 Bad Request response for other OAuth errors - raise HTTPException( - status_code=400, - detail=e.to_response_object() - ) \ No newline at end of file + except (InvalidTokenError, InsufficientScopeError, OAuthError): + # Return None to indicate authentication failure + return None + + +class BearerAuthMiddleware: + """ + Middleware that requires a valid Bearer token in the Authorization header. + + This will validate the token with the auth provider and store the resulting + auth info in the request state. + + Corresponds to bearerAuthMiddleware in src/server/auth/middleware/bearerAuth.ts + """ + + def __init__( + self, + app: Any, + provider: OAuthServerProvider, + required_scopes: Optional[List[str]] = None + ): + """ + Initialize the middleware. + + Args: + app: ASGI application + provider: Authentication provider to validate tokens + required_scopes: Optional list of scopes that the token must have + """ + self.app = AuthenticationMiddleware( + app, + backend=BearerAuthBackend(provider, required_scopes) + ) + + async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None: + """ + Process the request and validate the bearer token. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + await self.app(scope, receive, send) \ No newline at end of file diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 040894381a..9aab1d3c12 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,13 +1,14 @@ """ -Client authentication dependency for FastAPI. +Client authentication middleware for ASGI applications. Corresponds to TypeScript file: src/server/auth/middleware/clientAuth.ts """ import time -from typing import Optional +from typing import Optional, Dict, Any, Callable -from fastapi import Request, HTTPException, Depends +from starlette.requests import Request +from starlette.exceptions import HTTPException from pydantic import BaseModel, ValidationError from mcp.server.auth.errors import ( @@ -30,11 +31,11 @@ class ClientAuthRequest(BaseModel): client_secret: Optional[str] = None -class ClientAuthDependency: +class ClientAuthenticator: """ Dependency that authenticates a client using client_id and client_secret. - This will validate the client credentials and return the client information. + This is a callable that can be used to validate client credentials in a request. Corresponds to authenticateClient in src/server/auth/middleware/clientAuth.ts """ @@ -48,71 +49,75 @@ def __init__(self, clients_store: OAuthRegisteredClientsStore): """ self.clients_store = clients_store - async def __call__(self, request: Request) -> OAuthClientInformationFull: + async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFull: + # Look up client information + client = await self.clients_store.get_client(request.client_id) + if not client: + raise InvalidClientError("Invalid client_id") + + # If client from the store expects a secret, validate that the request provides that secret + if client.client_secret: + if not request.client_secret: + raise InvalidClientError("Client secret is required") + + if client.client_secret != request.client_secret: + raise InvalidClientError("Invalid client_secret") + + if (client.client_secret_expires_at and + client.client_secret_expires_at < int(time.time())): + raise InvalidClientError("Client secret has expired") + + return client + + + +class ClientAuthMiddleware: + """ + Middleware that authenticates clients using client_id and client_secret. + + This middleware will validate client credentials and store client information + in the request state. + """ + + def __init__( + self, + app: Any, + clients_store: OAuthRegisteredClientsStore, + ): + """ + Initialize the middleware. + + Args: + app: ASGI application + clients_store: Store for client information + """ + self.app = app + self.client_auth = ClientAuthenticator(clients_store) + + async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None: """ Process the request and authenticate the client. Args: - request: FastAPI request - - Returns: - Authenticated client information - - Raises: - HTTPException: If client authentication fails + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function """ - try: - # Parse request body as form data or JSON - content_type = request.headers.get("Content-Type", "") - - if "application/x-www-form-urlencoded" in content_type: - # Parse form data - request_data = await request.form() - elif "application/json" in content_type: - # Parse JSON data - request_data = await request.json() - else: - raise InvalidRequestError("Unsupported content type") + if scope["type"] != "http": + await self.app(scope, receive, send) + return - # Validate client credentials in request - try: - # TODO: can I just pass request_data to model_validate without pydantic complaining about extra params? - client_request = ClientAuthRequest.model_validate({ - "client_id": request_data.get("client_id"), - "client_secret": request_data.get("client_secret"), - }) - except ValidationError as e: - raise InvalidRequestError(str(e)) - - # Look up client information - client_id = client_request.client_id - client_secret = client_request.client_secret - - client = await self.clients_store.get_client(client_id) - if not client: - raise InvalidClientError("Invalid client_id") - - # If client has a secret, validate it - if client.client_secret: - # Check if client_secret is required but not provided - if not client_secret: - raise InvalidClientError("Client secret is required") - - # Check if client_secret matches - if client.client_secret != client_secret: - raise InvalidClientError("Invalid client_secret") - - # Check if client_secret has expired - if (client.client_secret_expires_at and - client.client_secret_expires_at < int(time.time())): - raise InvalidClientError("Client secret has expired") - - return client + # Create a request object to access the request data + request = Request(scope, receive=receive) + + # Add client authentication to the request + try: + client = await self.client_auth(request) + # Store the client in the request state + request.state.client = client + except HTTPException: + # Continue without authentication + pass - except OAuthError as e: - status_code = 500 if isinstance(e, ServerError) else 400 - # TODO: make sure we're not leaking anything here - raise HTTPException( - status_code=status_code, - detail=e.to_response_object() - ) \ No newline at end of file + # Continue processing the request + await self.app(scope, receive, send) \ No newline at end of file diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 1412992ac1..64995a8359 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -134,6 +134,8 @@ async def exchange_refresh_token(self, The new access and refresh tokens. """ ... + + # TODO: consider methods to generate refresh tokens and access tokens async def verify_access_token(self, token: str) -> AuthInfo: """ diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 8fdcdf6a09..07f703b32f 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -6,13 +6,15 @@ from dataclasses import dataclass import re -from typing import Dict, List, Optional, Any, Union +from typing import Dict, List, Optional, Any, Union, Callable from urllib.parse import urlparse -from fastapi import Depends, FastAPI, APIRouter, Request, Response +from starlette.routing import Route, Router +from starlette.requests import Request +from starlette.middleware import Middleware from pydantic import AnyUrl, BaseModel -from mcp.server.auth.middleware.client_auth import ClientAuthDependency +from mcp.server.auth.middleware.client_auth import ClientAuthMiddleware, ClientAuthenticator from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthMetadata from mcp.server.auth.handlers.metadata import create_metadata_handler @@ -67,9 +69,9 @@ def create_auth_router( service_documentation_url: AnyUrl | None = None, client_registration_options: ClientRegistrationOptions | None = None, revocation_options: RevocationOptions | None = None - ) -> APIRouter: + ) -> Router: """ - Create a FastAPI application with standard MCP authorization endpoints. + Create a Starlette router with standard MCP authorization endpoints. Corresponds to mcpAuthRouter in src/server/auth/router.ts @@ -77,72 +79,69 @@ def create_auth_router( provider: OAuth server provider issuer_url: Issuer URL for the authorization server service_documentation_url: Optional URL for service documentation + client_registration_options: Options for client registration + revocation_options: Options for token revocation Returns: - FastAPI application with authorization endpoints + Starlette router with authorization endpoints """ validate_issuer_url(issuer_url) client_registration_options = client_registration_options or ClientRegistrationOptions() revocation_options = revocation_options or RevocationOptions() - - client_auth = ClientAuthDependency(provider.clients_store) - - auth_app = APIRouter() - - - # Create handlers - - # Add routes - metadata = build_metadata(issuer_url, service_documentation_url, client_registration_options, revocation_options) - auth_app.add_api_route( - "/.well-known/oauth-authorization-server", - create_metadata_handler(metadata), - methods=["GET"] - ) - - # NOTE: reviewed - auth_app.add_api_route( - AUTHORIZATION_PATH, - create_authorization_handler(provider), - methods=["GET", "POST"] - ) - - # Add token endpoint with client auth dependency - # NOTE: reviewed - auth_app.add_api_route( - TOKEN_PATH, - create_token_handler(provider), - methods=["POST"], - dependencies=[Depends(client_auth)] + metadata = build_metadata( + issuer_url, + service_documentation_url, + client_registration_options, + revocation_options, ) + client_authenticator = ClientAuthenticator(provider.clients_store) + + # Create routes + auth_router = Router(routes=[ + Route( + "/.well-known/oauth-authorization-server", + endpoint=create_metadata_handler(metadata), + methods=["GET"] + ), + Route( + AUTHORIZATION_PATH, + endpoint=create_authorization_handler(provider), + methods=["GET", "POST"] + ), + Route( + TOKEN_PATH, + endpoint=create_token_handler(provider, client_authenticator), + methods=["POST"] + ) + ]) - # Add registration endpoint if supported if client_registration_options.enabled: from mcp.server.auth.handlers.register import create_registration_handler registration_handler = create_registration_handler( provider.clients_store, client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, ) - # NOTE: reviewed - auth_app.add_api_route( - REGISTRATION_PATH, - registration_handler, - methods=["POST"] + auth_router.routes.append( + Route( + REGISTRATION_PATH, + endpoint=registration_handler, + methods=["POST"] + ) ) - # Add revocation endpoint if supported if revocation_options.enabled: - # NOTE: reviewed - auth_app.add_api_route( - REVOCATION_PATH, - create_revocation_handler(provider), - methods=["POST"], - dependencies=[Depends(client_auth)] + revocation_handler = create_revocation_handler(provider, client_authenticator) + auth_router.routes.append( + Route( + REVOCATION_PATH, + endpoint=revocation_handler, + methods=["POST"] + ) ) - return auth_app + return auth_router def build_metadata( issuer_url: AnyUrl, diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index 98d9ebde4d..494a4c30b0 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -18,6 +18,7 @@ class AuthInfo(BaseModel): client_id: str scopes: List[str] expires_at: Optional[int] = None + user_id: Optional[str] = None class Config: extra = "ignore" \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 793a0b0755..5e5461c7b7 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -15,11 +15,15 @@ import anyio import pydantic_core +from starlette.applications import Starlette +from starlette.authentication import requires +from starlette.middleware.authentication import AuthenticationMiddleware import uvicorn from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend from mcp.server.auth.provider import OAuthServerProvider from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions from mcp.server.auth.types import AuthInfo @@ -474,24 +478,15 @@ async def run_stdio_async(self) -> None: self._mcp_server.create_initialization_options(), ) - async def run_sse_async(self) -> None: + def starlette_app(self) -> Starlette: """Run the server using SSE transport.""" from starlette.applications import Starlette from starlette.routing import Mount, Route from starlette.middleware import Middleware - from fastapi import FastAPI, Depends - # Import auth dependency if needed - auth_dependencies = [] - if self._auth_provider: - from mcp.server.auth.middleware.bearer_auth import BearerAuthDependency - auth_dependencies = [Depends(BearerAuthDependency( - provider=self._auth_provider, - required_scopes=self.settings.auth_required_scopes - ))] + # Set up auth context and dependencies sse = SseServerTransport("/messages/") - async def handle_sse(request): # Add client ID from auth context into request context if available request_meta = {} @@ -505,26 +500,49 @@ async def handle_sse(request): self._mcp_server.create_initialization_options(), ) - # Create Starlette app - app = FastAPI(debug=self.settings.debug) - - # Add routes with auth dependency if required - app.add_api_route("/sse", endpoint=handle_sse, dependencies=auth_dependencies) - # TODO: convert this to a handler so it can take a dependency - app.mount("/messages/", sse.handle_post_message) # , dependencies=auth_dependencies) + # Create routes + routes = [] + middleware = [] + required_scopes = self.settings.auth_required_scopes or [] # Add auth endpoints if auth provider is configured if self._auth_provider and self.settings.auth_issuer_url: from mcp.server.auth.router import create_auth_router - auth_app = create_auth_router( + if "authenticated" not in required_scopes: + required_scopes.append("authenticated") + + # Set up bearer auth middleware if auth is required + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend( + provider=self._auth_provider, + required_scopes=self.settings.auth_required_scopes + ) + ) + ] + auth_router = create_auth_router( self._auth_provider, self.settings.auth_issuer_url, self.settings.auth_service_documentation_url ) - app.mount("/", auth_app) + + # Add the auth router as a mount + routes.append(Mount("/", app=auth_router)) + routes.append(Route("/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"])) + routes.append(Mount("/messages/", app=requires(required_scopes)(sse.handle_post_message))) + + # Create Starlette app with routes and middleware + return Starlette( + debug=self.settings.debug, + routes=routes, + middleware=middleware + ) + + async def run_sse_async(self) -> None: config = uvicorn.Config( - app, + app=self.starlette_app(), host=self.settings.host, port=self.settings.port, log_level=self.settings.log_level.lower(), diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index f751065a20..3a65ad959a 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -39,8 +39,8 @@ class OAuthClientMetadata(BaseModel): Corresponds to OAuthClientMetadataSchema in src/shared/auth.ts """ redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) - token_endpoint_auth_method: Optional[str] - grant_types: Optional[List[str]] + token_endpoint_auth_method: Optional[str] = None + grant_types: Optional[List[str]] = None response_types: Optional[List[str]] = None client_name: Optional[str] = None client_uri: Optional[AnyHttpUrl] = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 3d7e51fbdf..7e8e69eec0 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -5,6 +5,7 @@ import base64 import hashlib import json +import secrets import time from typing import Any, Dict, List, Optional, cast from urllib.parse import urlparse, parse_qs @@ -12,16 +13,19 @@ import anyio from pydantic import AnyUrl import pytest -from fastapi import FastAPI, Depends -from fastapi.testclient import TestClient +import httpx +from starlette.applications import Starlette from starlette.datastructures import MutableHeaders -from starlette.responses import RedirectResponse, JSONResponse +from starlette.testclient import TestClient +from starlette.routing import Route, Router, Mount +from starlette.responses import RedirectResponse, JSONResponse, Response from starlette.requests import Request +from starlette.middleware import Middleware from mcp.server.auth.errors import InvalidTokenError -from mcp.server.auth.middleware.bearer_auth import BearerAuthDependency +from mcp.server.auth.middleware.client_auth import ClientAuthMiddleware from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, OAuthRegisteredClientsStore -from mcp.server.auth.router import create_auth_router +from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions, create_auth_router from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( OAuthClientInformationFull, @@ -45,7 +49,7 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> OAut # Mock OAuth provider for testing -class MockOAuthProvider: +class MockOAuthProvider(OAuthServerProvider): def __init__(self): self.client_store = MockClientStore() self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} @@ -59,7 +63,7 @@ def clients_store(self) -> OAuthRegisteredClientsStore: async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams, - response: RedirectResponse) -> None: + response: Response): # Generate an authorization code code = f"code_{int(time.time())}" @@ -80,8 +84,8 @@ async def authorize(self, response.headers["location"] = redirect_url async def challenge_for_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> str: + client: OAuthClientInformationFull, + authorization_code: str) -> str: # Get the stored code info code_info = self.auth_codes.get(authorization_code) if not code_info: @@ -98,8 +102,8 @@ async def challenge_for_authorization_code(self, return code_info["code_challenge"] async def exchange_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> OAuthTokens: + client: OAuthClientInformationFull, + authorization_code: str) -> OAuthTokens: # Get the stored code info code_info = self.auth_codes.get(authorization_code) if not code_info: @@ -114,8 +118,8 @@ async def exchange_authorization_code(self, raise InvalidTokenError("Authorization code was not issued to this client") # Generate an access token and refresh token - access_token = f"access_{int(time.time())}" - refresh_token = f"refresh_{int(time.time())}" + access_token = f"access_{secrets.token_hex(32)}" + refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens self.tokens[access_token] = { @@ -138,9 +142,9 @@ async def exchange_authorization_code(self, ) async def exchange_refresh_token(self, - client: OAuthClientInformationFull, - refresh_token: str, - scopes: Optional[List[str]] = None) -> OAuthTokens: + client: OAuthClientInformationFull, + refresh_token: str, + scopes: Optional[List[str]] = None) -> OAuthTokens: # Check if refresh token exists if refresh_token not in self.refresh_tokens: raise InvalidTokenError("Invalid refresh token") @@ -158,8 +162,8 @@ async def exchange_refresh_token(self, raise InvalidTokenError("Refresh token was not issued to this client") # Generate a new access token and refresh token - new_access_token = f"access_{int(time.time())}" - new_refresh_token = f"refresh_{int(time.time())}" + new_access_token = f"access_{secrets.token_hex(32)}" + new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens self.tokens[new_access_token] = { @@ -202,8 +206,8 @@ async def verify_access_token(self, token: str) -> AuthInfo: ) async def revoke_token(self, - client: OAuthClientInformationFull, - request: OAuthTokenRevocationRequest) -> None: + client: OAuthClientInformationFull, + request: OAuthTokenRevocationRequest) -> None: token = request.token # Check if it's a refresh token @@ -242,24 +246,42 @@ def mock_oauth_provider(): @pytest.fixture def auth_app(mock_oauth_provider): - app = create_auth_router( + # Create auth router + auth_router = create_auth_router( mock_oauth_provider, AnyUrl("https://auth.example.com"), AnyUrl("https://docs.example.com"), + client_registration_options=ClientRegistrationOptions( + enabled=True + ), + revocation_options=RevocationOptions( + enabled=True + ) + ) + + # Create Starlette app + app = Starlette( + routes=[ + Mount("/", app=auth_router) + ] ) + return app @pytest.fixture -def test_client(auth_app): - return TestClient(auth_app) - +def test_client(auth_app) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") -@pytest.mark.anyio class TestAuthEndpoints: - def test_metadata_endpoint(self, test_client): + @pytest.mark.anyio + async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): """Test the OAuth 2.0 metadata endpoint.""" - response = test_client.get("/.well-known/oauth-authorization-server") + print("Sending request to metadata endpoint") + response = await test_client.get("/.well-known/oauth-authorization-server") + print(f"Got response: {response.status_code}") + if response.status_code != 200: + print(f"Response content: {response.content}") assert response.status_code == 200 metadata = response.json() @@ -275,7 +297,7 @@ def test_metadata_endpoint(self, test_client): assert metadata["service_documentation"] == "https://docs.example.com" @pytest.mark.anyio - async def test_client_registration(self, test_client, mock_oauth_provider): + async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -283,11 +305,11 @@ async def test_client_registration(self, test_client, mock_oauth_provider): "client_uri": "https://client.example.com", } - response = test_client.post( + response = await test_client.post( "/register", json=client_metadata, ) - assert response.status_code == 201 + assert response.status_code == 201, response.content client_info = response.json() assert "client_id" in client_info @@ -296,10 +318,10 @@ async def test_client_registration(self, test_client, mock_oauth_provider): assert client_info["redirect_uris"] == ["https://client.example.com/callback"] # Verify that the client was registered - assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None + #assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None @pytest.mark.anyio - async def test_authorization_flow(self, test_client, mock_oauth_provider): + async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): """Test the full authorization flow.""" # 1. Register a client client_metadata = { @@ -307,7 +329,7 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): "client_name": "Test Client", } - response = test_client.post( + response = await test_client.post( "/register", json=client_metadata, ) @@ -321,7 +343,7 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): ).decode().rstrip("=") # 3. Request authorization - response = test_client.get( + response = await test_client.get( "/authorize", params={ "response_type": "code", @@ -331,7 +353,6 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): "code_challenge_method": "S256", "state": "test_state", }, - allow_redirects=False, ) assert response.status_code == 302 @@ -345,9 +366,9 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): auth_code = query_params["code"][0] # 5. Exchange the authorization code for tokens - response = test_client.post( + response = await test_client.post( "/token", - data={ + json={ "grant_type": "authorization_code", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], @@ -375,9 +396,9 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): assert "write" in auth_info.scopes # 7. Refresh the token - response = test_client.post( + response = await test_client.post( "/token", - data={ + json={ "grant_type": "refresh_token", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], @@ -393,9 +414,9 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): assert new_token_response["refresh_token"] != refresh_token # 8. Revoke the token - response = test_client.post( + response = await test_client.post( "/revoke", - data={ + json={ "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "token": new_token_response["access_token"], @@ -408,12 +429,11 @@ async def test_authorization_flow(self, test_client, mock_oauth_provider): await mock_oauth_provider.verify_access_token(new_token_response["access_token"]) -@pytest.mark.anyio class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" @pytest.mark.anyio - async def test_fastmcp_with_auth(self, mock_oauth_provider): + async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): """Test creating a FastMCP server with authentication.""" # Create FastMCP server with auth provider mcp = FastMCP( @@ -427,60 +447,19 @@ async def test_fastmcp_with_auth(self, mock_oauth_provider): def test_tool(x: int) -> str: return f"Result: {x}" - # Create a FastAPI app for testing - from fastapi import FastAPI, Depends, Security - - # Override the run method to capture the app - app = None - - async def mock_run_sse(): - nonlocal app - - # Create auth dependency - auth_dependency = BearerAuthDependency( - provider=mock_oauth_provider, - required_scopes=mcp.settings.auth_required_scopes - ) - - # Create FastAPI app - app = FastAPI(debug=mcp.settings.debug) - - # Add a test endpoint that requires authentication - @app.get("/test") - async def test_endpoint(auth: AuthInfo = Depends(auth_dependency)): - return {"status": "ok", "client_id": auth.client_id} - - # Add another endpoint that doesn't require auth for comparison - @app.get("/public") - async def public_endpoint(): - return {"status": "ok"} - - # Add auth endpoints - from mcp.server.auth.router import create_auth_router - auth_app = create_auth_router( - mock_oauth_provider, - cast(AnyUrl, mcp.settings.auth_issuer_url), - mcp.settings.auth_service_documentation_url - ) - app.mount("/", auth_app) - - # Override the run method - mcp.run_sse_async = mock_run_sse - await mcp.run_sse_async() - - assert app is not None - test_client = TestClient(app) + transport = httpx.ASGITransport(app=mcp.starlette_app()) # pyright: ignore + test_client = httpx.AsyncClient(transport=transport, base_url="http://mcptest.com") # Test metadata endpoint - response = test_client.get("/.well-known/oauth-authorization-server") + response = await test_client.get("/.well-known/oauth-authorization-server") assert response.status_code == 200 # Test that auth is required for protected endpoints - response = test_client.get("/test") + response = await test_client.get("/test") assert response.status_code == 401 # Test that public endpoints don't require auth - response = test_client.get("/public") + response = await test_client.get("/public") assert response.status_code == 200 # Register a client @@ -489,7 +468,7 @@ async def public_endpoint(): "client_name": "Test Client", } - response = test_client.post( + response = await test_client.post( "/register", json=client_metadata, ) @@ -503,7 +482,7 @@ async def public_endpoint(): ).decode().rstrip("=") # Request authorization - response = test_client.get( + response = await test_client.get( "/authorize", params={ "response_type": "code", @@ -513,7 +492,6 @@ async def public_endpoint(): "code_challenge_method": "S256", "state": "test_state", }, - allow_redirects=False, ) assert response.status_code == 302 @@ -526,7 +504,7 @@ async def public_endpoint(): auth_code = query_params["code"][0] # Exchange the authorization code for tokens - response = test_client.post( + response = await test_client.post( "/token", data={ "grant_type": "authorization_code", @@ -542,7 +520,7 @@ async def public_endpoint(): assert "access_token" in token_response # Test the authenticated endpoint with valid token - response = test_client.get( + response = await test_client.get( "/test", headers={"Authorization": f"Bearer {token_response['access_token']}"}, ) @@ -551,7 +529,7 @@ async def public_endpoint(): assert response.json()["client_id"] == client_info["client_id"] # Test with invalid token - response = test_client.get( + response = await test_client.get( "/test", headers={"Authorization": "Bearer invalid_token"}, ) From e96d280a683f0dbfab4f3939d4f5cbd93c47928a Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Sun, 9 Mar 2025 21:53:34 -0700 Subject: [PATCH 04/60] Get tests passing --- src/mcp/server/auth/middleware/bearer_auth.py | 50 ++--- src/mcp/server/fastmcp/server.py | 23 +- src/mcp/server/sse.py | 9 +- .../fastmcp/auth/streaming_asgi_transport.py | 197 ++++++++++++++++++ .../fastmcp/auth/test_auth_integration.py | 88 +++++--- 5 files changed, 297 insertions(+), 70 deletions(-) create mode 100644 tests/server/fastmcp/auth/streaming_asgi_transport.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 431bf16efc..6a023f3215 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -9,8 +9,9 @@ from starlette.requests import HTTPConnection, Request from starlette.exceptions import HTTPException -from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser +from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser, has_required_scope from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.types import Scope from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError from mcp.server.auth.provider import OAuthServerProvider @@ -34,22 +35,12 @@ class BearerAuthBackend(AuthenticationBackend): def __init__( self, provider: OAuthServerProvider, - required_scopes: Optional[List[str]] = None ): - """ - Initialize the backend. - - Args: - provider: Authentication provider to validate tokens - required_scopes: Optional list of scopes that the token must have - """ self.provider = provider - self.required_scopes = required_scopes or [] async def authenticate(self, conn: HTTPConnection): if "Authorization" not in conn.headers: - raise AuthenticationError() return None auth_header = conn.headers["Authorization"] @@ -61,14 +52,7 @@ async def authenticate(self, conn: HTTPConnection): try: # Validate the token with the provider auth_info = await self.provider.verify_access_token(token) - - # Check if the token has all required scopes - if self.required_scopes: - has_all_scopes = all(scope in auth_info.scopes for scope in self.required_scopes) - if not has_all_scopes: - raise InsufficientScopeError("Insufficient scope") - - # Check if the token is expired + if auth_info.expires_at and auth_info.expires_at < int(time.time()): raise InvalidTokenError("Token has expired") @@ -79,7 +63,7 @@ async def authenticate(self, conn: HTTPConnection): return None -class BearerAuthMiddleware: +class RequireAuthMiddleware: """ Middleware that requires a valid Bearer token in the Authorization header. @@ -92,8 +76,7 @@ class BearerAuthMiddleware: def __init__( self, app: Any, - provider: OAuthServerProvider, - required_scopes: Optional[List[str]] = None + required_scopes: list[str] ): """ Initialize the middleware. @@ -103,18 +86,15 @@ def __init__( provider: Authentication provider to validate tokens required_scopes: Optional list of scopes that the token must have """ - self.app = AuthenticationMiddleware( - app, - backend=BearerAuthBackend(provider, required_scopes) - ) - - async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None: - """ - Process the request and validate the bearer token. + self.app = app + self.required_scopes = required_scopes + + async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + auth_credentials = scope.get('auth') - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ + for required_scope in self.required_scopes: + # auth_credentials should always be provided; this is just paranoia + if auth_credentials is None or required_scope not in auth_credentials.scopes: + raise HTTPException(status_code=403, detail="Insufficient scope") + await self.app(scope, receive, send) \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 5e5461c7b7..af3b41b797 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -18,12 +18,13 @@ from starlette.applications import Starlette from starlette.authentication import requires from starlette.middleware.authentication import AuthenticationMiddleware +from sse_starlette import EventSourceResponse import uvicorn from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthServerProvider from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions from mcp.server.auth.types import AuthInfo @@ -487,7 +488,7 @@ def starlette_app(self) -> Starlette: # Set up auth context and dependencies sse = SseServerTransport("/messages/") - async def handle_sse(request): + async def handle_sse(request) -> EventSourceResponse: # Add client ID from auth context into request context if available request_meta = {} @@ -499,17 +500,17 @@ async def handle_sse(request): streams[1], self._mcp_server.create_initialization_options(), ) + return streams[2] # Create routes routes = [] middleware = [] required_scopes = self.settings.auth_required_scopes or [] + auth_router = None # Add auth endpoints if auth provider is configured if self._auth_provider and self.settings.auth_issuer_url: from mcp.server.auth.router import create_auth_router - if "authenticated" not in required_scopes: - required_scopes.append("authenticated") # Set up bearer auth middleware if auth is required middleware = [ @@ -517,21 +518,23 @@ async def handle_sse(request): AuthenticationMiddleware, backend=BearerAuthBackend( provider=self._auth_provider, - required_scopes=self.settings.auth_required_scopes ) ) ] auth_router = create_auth_router( - self._auth_provider, - self.settings.auth_issuer_url, - self.settings.auth_service_documentation_url + provider=self._auth_provider, + issuer_url=self.settings.auth_issuer_url, + service_documentation_url=self.settings.auth_service_documentation_url, + client_registration_options=self.settings.auth_client_registration_options, + revocation_options=self.settings.auth_revocation_options ) # Add the auth router as a mount - routes.append(Mount("/", app=auth_router)) routes.append(Route("/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"])) - routes.append(Mount("/messages/", app=requires(required_scopes)(sse.handle_post_message))) + routes.append(Mount("/messages/", app=RequireAuthMiddleware(sse.handle_post_message, required_scopes))) + if auth_router: + routes.append(Mount("/", app=auth_router)) # Create Starlette app with routes and middleware return Starlette( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753d01..75c1f7302e 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -34,6 +34,7 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager from typing import Any +from typing_extensions import deprecated from urllib.parse import quote from uuid import UUID, uuid4 @@ -44,6 +45,7 @@ async def handle_sse(request): from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send +from sse_starlette import EventSourceResponse import mcp.types as types @@ -78,6 +80,7 @@ def __init__(self, endpoint: str) -> None: self._read_stream_writers = {} logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + @deprecated("use connect_sse_v2 instead") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": @@ -128,7 +131,11 @@ async def sse_writer(): tg.start_soon(response, scope, receive, send) logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + # TODO: hold on; shouldn't we be returning the EventSourceResponse? + # I think this is why the tests hang + # TODO: we probably shouldn't return response here, since it's a breaking change + # this is just to test + yield (read_stream, write_stream, response) async def handle_post_message( self, scope: Scope, receive: Receive, send: Send diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py new file mode 100644 index 0000000000..66774ba67b --- /dev/null +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -0,0 +1,197 @@ +""" +A modified version of httpx.ASGITransport that supports streaming responses. + +This transport runs the ASGI app as a separate anyio task, allowing it to +handle streaming responses like SSE where the app doesn't terminate until +the connection is closed. +""" + +import typing +from typing import Any, Dict, List, Optional, Tuple, cast + +import anyio +import anyio.streams.memory +from anyio.abc import TaskStatus +import httpx +from httpx._transports.asgi import ASGIResponseStream +from httpx._transports.base import AsyncBaseTransport +from httpx._models import Request, Response +from httpx._types import AsyncByteStream +import asyncio + + + +class StreamingASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app + and supports streaming responses like SSE. + + Unlike the standard ASGITransport, this transport runs the ASGI app in a + separate anyio task, allowing it to handle responses from apps that don't + terminate immediately (like SSE endpoints). + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + * `response_timeout` - Timeout in seconds to wait for the initial response. + Default is 10 seconds. + """ + + def __init__( + self, + app: typing.Callable, + raise_app_exceptions: bool = True, + root_path: str = "", + client: Tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?")[0], + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request body + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response state + status_code = 499 + response_headers = None + response_started = False + response_complete = anyio.Event() + initial_response_ready = anyio.Event() + + # Synchronization for streaming response + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream(100) + content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) + + # ASGI callables. + async def receive() -> Dict[str, Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: Dict[str, Any]) -> None: + nonlocal status_code, response_headers, response_started + + await asgi_send_channel.send(message) + + # Start the ASGI application in a separate task + async def run_app() -> None: + try: + await self.app(scope, receive, send) + except Exception: + if self.raise_app_exceptions: + raise + + if not response_started: + await asgi_send_channel.send({ + "type": "http.response.start", + "status": 500, + "headers": [] + }) + + await asgi_send_channel.send({ + "type": "http.response.body", + "body": b"", + "more_body": False + }) + finally: + await asgi_send_channel.aclose() + + # Process messages from the ASGI app + async def process_messages() -> None: + nonlocal status_code, response_headers, response_started + + try: + async with asgi_receive_channel: + async for message in asgi_receive_channel: + if message["type"] == "http.response.start": + assert not response_started + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + # As soon as we have headers, we can return a response + initial_response_ready.set() + + elif message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + await content_send_channel.send(body) + + if not more_body: + response_complete.set() + await content_send_channel.aclose() + break + finally: + # Ensure events are set even if there's an error + initial_response_ready.set() + response_complete.set() + + # Create tasks for running the app and processing messages + app_task = asyncio.create_task(run_app()) + process_task = asyncio.create_task(process_messages()) + + # Wait for the initial response or timeout + await initial_response_ready.wait() + + # Create a streaming response + return Response(status_code, headers=response_headers, stream=StreamingASGIResponseStream(content_receive_channel)) + + +class StreamingASGIResponseStream(AsyncByteStream): + """ + A modified ASGIResponseStream that supports streaming responses. + + This class extends the standard ASGIResponseStream to handle cases where + the response body continues to be generated after the initial response + is returned. + """ + + def __init__( + self, + receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + ) -> None: + self.receive_channel = receive_channel + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in self.receive_channel: + yield chunk \ No newline at end of file diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 7e8e69eec0..423073779e 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -14,6 +14,7 @@ from pydantic import AnyUrl import pytest import httpx +from httpx_sse import aconnect_sse from starlette.applications import Starlette from starlette.datastructures import MutableHeaders from starlette.testclient import TestClient @@ -21,6 +22,7 @@ from starlette.responses import RedirectResponse, JSONResponse, Response from starlette.requests import Request from starlette.middleware import Middleware +from starlette.types import ASGIApp from mcp.server.auth.errors import InvalidTokenError from mcp.server.auth.middleware.client_auth import ClientAuthMiddleware @@ -33,6 +35,8 @@ OAuthTokens, ) from mcp.server.fastmcp import FastMCP +from mcp.types import JSONRPCRequest +from .streaming_asgi_transport import StreamingASGITransport # Mock client store for testing @@ -440,6 +444,13 @@ async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): auth_provider=mock_oauth_provider, auth_issuer_url="https://auth.example.com", require_auth=True, + auth_client_registration_options=ClientRegistrationOptions( + enabled=True + ), + auth_revocation_options=RevocationOptions( + enabled=True + ), + auth_required_scopes=["read"] ) # Add a test tool @@ -447,22 +458,24 @@ async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): def test_tool(x: int) -> str: return f"Result: {x}" - transport = httpx.ASGITransport(app=mcp.starlette_app()) # pyright: ignore + transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore test_client = httpx.AsyncClient(transport=transport, base_url="http://mcptest.com") + # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") # Test metadata endpoint response = await test_client.get("/.well-known/oauth-authorization-server") assert response.status_code == 200 # Test that auth is required for protected endpoints - response = await test_client.get("/test") - assert response.status_code == 401 - - # Test that public endpoints don't require auth - response = await test_client.get("/public") - assert response.status_code == 200 + response = await test_client.get("/sse") + # TODO: we should return 401/403 depending on whether authn or authz fails + assert response.status_code == 403 + + response = await test_client.post("/messages/") + # TODO: we should return 401/403 depending on whether authn or authz fails + assert response.status_code == 403, response.content - # Register a client + # now, become authenticated and try to go through the flow again client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -506,7 +519,7 @@ def test_tool(x: int) -> str: # Exchange the authorization code for tokens response = await test_client.post( "/token", - data={ + json={ "grant_type": "authorization_code", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], @@ -518,19 +531,46 @@ def test_tool(x: int) -> str: token_response = response.json() assert "access_token" in token_response - + authorization = f"Bearer {token_response['access_token']}" + + # Test the authenticated endpoint with valid token - response = await test_client.get( - "/test", - headers={"Authorization": f"Bearer {token_response['access_token']}"}, - ) - assert response.status_code == 200 - assert response.json()["status"] == "ok" - assert response.json()["client_id"] == client_info["client_id"] - - # Test with invalid token - response = await test_client.get( - "/test", - headers={"Authorization": "Bearer invalid_token"}, - ) - assert response.status_code == 401 \ No newline at end of file + async with aconnect_sse(test_client, "GET", "/sse", headers={"Authorization": authorization}) as event_source: + assert event_source.response.status_code == 200 + events = event_source.aiter_sse() + sse = await events.__anext__() + assert sse.event == "endpoint" + assert sse.data.startswith("/messages/?session_id=") + messages_uri = sse.data + + # verify that we can now post to the /messages endpoint, and get a response on the /sse endpoint + response = await test_client.post( + messages_uri, + headers={"Authorization": authorization}, + content=JSONRPCRequest( + jsonrpc="2.0", + id="123", + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": { + "listChanged": True + }, + "sampling": {}, + }, + "clientInfo": { + "name": "ExampleClient", + "version": "1.0.0" + } + }, + ).model_dump_json(), + ) + assert response.status_code == 202 + assert response.content == b"Accepted" + + sse = await events.__anext__() + assert sse.event == "message" + sse_data = json.loads(sse.data) + assert sse_data["id"] == '123' + assert set(sse_data["result"]["capabilities"].keys()) == set(("experimental", "prompts", "resources", "tools")) \ No newline at end of file From 1e9dd4c213a2e31920df0c7dab4e6db699c3c5bc Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 13:44:38 -0700 Subject: [PATCH 05/60] Clean up provider interface --- src/mcp/server/auth/handlers/authorize.py | 19 ++++++++++++++----- src/mcp/server/auth/handlers/token.py | 4 ++++ src/mcp/server/auth/provider.py | 18 +++++++----------- .../fastmcp/auth/test_auth_integration.py | 16 +++++----------- 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index b13555347e..cb271b1618 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -36,9 +36,9 @@ class AuthorizationRequest(BaseModel): response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") code_challenge: str = Field(..., description="PKCE code challenge") - code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method") + code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") state: Optional[str] = Field(None, description="Optional state parameter") - scope: Optional[str] = Field(None, description="Optional scope parameter") + scope: Optional[str] = Field(None, description="Optional scope; if specified, should be a space-separated list of scope strings") class Config: extra = "ignore" @@ -113,12 +113,21 @@ async def authorization_handler(request: Request) -> Response: code_challenge=auth_request.code_challenge, redirect_uri=redirect_uri, ) - - response = RedirectResponse(url="", status_code=302, headers={"Cache-Control": "no-store"}) try: # Let the provider handle the authorization flow - await provider.authorize(client, auth_params, response) + authorization_code = await provider.create_authorization_code(client, auth_params) + response = RedirectResponse(url="", status_code=302, headers={"Cache-Control": "no-store"}) + + # Redirect with code + parsed_uri = urlparse(str(auth_params.redirect_uri)) + query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] + query_params.append(("code", authorization_code)) + if auth_params.state: + query_params.append(("state", auth_params.state)) + + redirect_url = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + response.headers["location"] = redirect_url return response except Exception as e: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index e9d7ff293b..9b092ccc7f 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -88,6 +88,10 @@ async def token_handler(request: Request): match token_request: case AuthorizationCodeRequest(): + # TODO: verify that the redirect URIs match; does the client actually provide this? + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + # TODO: enforce TTL on the authorization code + # Verify PKCE code verifier expected_challenge = await provider.challenge_for_authorization_code( client_info, token_request.code diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 64995a8359..5b30734d6b 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -72,19 +72,15 @@ def clients_store(self) -> OAuthRegisteredClientsStore: """ ... - # TODO: do we really want to be putting the response in this method? - async def authorize(self, + async def create_authorization_code(self, client: OAuthClientInformationFull, - params: AuthorizationParams, - response: Response) -> None: + params: AuthorizationParams) -> str: """ - Begins the authorization flow, which can be implemented by this server or via redirection. - Must eventually issue a redirect with authorization response or error to the given redirect URI. - - Args: - client: The client requesting authorization. - params: Parameters for the authorization request. - response: The response object to write to. + Generates and stores an authorization code as part of completing the /authorize OAuth step. + + Implementations SHOULD generate an authorization code with at least 160 bits of entropy, + and MUST generate an authorization code with at least 128 bits of entropy. + See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. """ ... diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 423073779e..a22c675dea 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -64,10 +64,9 @@ def __init__(self): def clients_store(self) -> OAuthRegisteredClientsStore: return self.client_store - async def authorize(self, + async def create_authorization_code(self, client: OAuthClientInformationFull, - params: AuthorizationParams, - response: Response): + params: AuthorizationParams) -> str: # Generate an authorization code code = f"code_{int(time.time())}" @@ -78,14 +77,9 @@ async def authorize(self, "redirect_uri": params.redirect_uri, "expires_at": int(time.time()) + 600, # 10 minutes } - - # Redirect with code - query = {"code": code} - if params.state: - query["state"] = params.state - - redirect_url = f"{params.redirect_uri}?" + "&".join([f"{k}={v}" for k, v in query.items()]) - response.headers["location"] = redirect_url + + return code + async def challenge_for_authorization_code(self, client: OAuthClientInformationFull, From d535089a661a94f43f5f0e24846c07d0fd23919e Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 13:45:58 -0700 Subject: [PATCH 06/60] Lint --- src/mcp/server/auth/__init__.py | 2 +- src/mcp/server/auth/errors.py | 52 +-- src/mcp/server/auth/handlers/__init__.py | 2 +- src/mcp/server/auth/handlers/authorize.py | 107 +++--- src/mcp/server/auth/handlers/metadata.py | 22 +- src/mcp/server/auth/handlers/register.py | 49 ++- src/mcp/server/auth/handlers/revoke.py | 44 +-- src/mcp/server/auth/handlers/token.py | 81 +++-- src/mcp/server/auth/json_response.py | 4 +- src/mcp/server/auth/middleware/__init__.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 55 +-- src/mcp/server/auth/middleware/client_auth.py | 55 ++- src/mcp/server/auth/provider.py | 105 +++--- src/mcp/server/auth/router.py | 142 ++++---- src/mcp/server/auth/types.py | 8 +- src/mcp/server/fastmcp/server.py | 65 ++-- src/mcp/server/sse.py | 3 +- src/mcp/shared/auth.py | 15 +- tests/server/fastmcp/auth/__init__.py | 2 +- .../fastmcp/auth/streaming_asgi_transport.py | 56 ++-- .../fastmcp/auth/test_auth_integration.py | 312 +++++++++--------- 21 files changed, 632 insertions(+), 551 deletions(-) diff --git a/src/mcp/server/auth/__init__.py b/src/mcp/server/auth/__init__.py index 5ad769fdfe..6888ffe8d9 100644 --- a/src/mcp/server/auth/__init__.py +++ b/src/mcp/server/auth/__init__.py @@ -1,3 +1,3 @@ """ MCP OAuth server authorization components. -""" \ No newline at end of file +""" diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 702df08c91..badee09844 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -4,132 +4,142 @@ Corresponds to TypeScript file: src/server/auth/errors.ts """ -from typing import Dict, Optional, Any +from typing import Dict class OAuthError(Exception): """ Base class for all OAuth errors. - + Corresponds to OAuthError in src/server/auth/errors.ts """ + error_code: str = "server_error" - + def __init__(self, message: str): super().__init__(message) self.message = message - + def to_response_object(self) -> Dict[str, str]: """Convert error to JSON response object.""" - return { - "error": self.error_code, - "error_description": self.message - } + return {"error": self.error_code, "error_description": self.message} class ServerError(OAuthError): """ Server error. - + Corresponds to ServerError in src/server/auth/errors.ts """ + error_code = "server_error" class InvalidRequestError(OAuthError): """ Invalid request error. - + Corresponds to InvalidRequestError in src/server/auth/errors.ts """ + error_code = "invalid_request" class InvalidClientError(OAuthError): """ Invalid client error. - + Corresponds to InvalidClientError in src/server/auth/errors.ts """ + error_code = "invalid_client" class InvalidGrantError(OAuthError): """ Invalid grant error. - + Corresponds to InvalidGrantError in src/server/auth/errors.ts """ + error_code = "invalid_grant" class UnauthorizedClientError(OAuthError): """ Unauthorized client error. - + Corresponds to UnauthorizedClientError in src/server/auth/errors.ts """ + error_code = "unauthorized_client" class UnsupportedGrantTypeError(OAuthError): """ Unsupported grant type error. - + Corresponds to UnsupportedGrantTypeError in src/server/auth/errors.ts """ + error_code = "unsupported_grant_type" class UnsupportedResponseTypeError(OAuthError): """ Unsupported response type error. - + Corresponds to UnsupportedResponseTypeError in src/server/auth/errors.ts """ + error_code = "unsupported_response_type" class InvalidScopeError(OAuthError): """ Invalid scope error. - + Corresponds to InvalidScopeError in src/server/auth/errors.ts """ + error_code = "invalid_scope" class AccessDeniedError(OAuthError): """ Access denied error. - + Corresponds to AccessDeniedError in src/server/auth/errors.ts """ + error_code = "access_denied" class TemporarilyUnavailableError(OAuthError): """ Temporarily unavailable error. - + Corresponds to TemporarilyUnavailableError in src/server/auth/errors.ts """ + error_code = "temporarily_unavailable" class InvalidTokenError(OAuthError): """ Invalid token error. - + Corresponds to InvalidTokenError in src/server/auth/errors.ts """ + error_code = "invalid_token" class InsufficientScopeError(OAuthError): """ Insufficient scope error. - + Corresponds to InsufficientScopeError in src/server/auth/errors.ts """ - error_code = "insufficient_scope" \ No newline at end of file + + error_code = "insufficient_scope" diff --git a/src/mcp/server/auth/handlers/__init__.py b/src/mcp/server/auth/handlers/__init__.py index fb01dab61f..e99a62de1a 100644 --- a/src/mcp/server/auth/handlers/__init__.py +++ b/src/mcp/server/auth/handlers/__init__.py @@ -1,3 +1,3 @@ """ Request handlers for MCP authorization endpoints. -""" \ No newline at end of file +""" diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index cb271b1618..76b2802465 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -4,21 +4,16 @@ Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts """ -import re -from urllib.parse import urlparse, urlunparse, urlencode -from typing import Any, Callable, Dict, List, Literal, Optional -from urllib.parse import urlencode, parse_qs +from typing import Callable, Literal, Optional +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError -from pydantic_core import Url +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response from mcp.server.auth.errors import ( - InvalidClientError, + InvalidClientError, InvalidRequestError, - UnsupportedResponseTypeError, - ServerError, OAuthError, ) from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider @@ -28,22 +23,35 @@ class AuthorizationRequest(BaseModel): """ Model for the authorization request parameters. - + Corresponds to request schema in authorizationHandler in src/server/auth/handlers/authorize.ts """ + client_id: str = Field(..., description="The client ID") - redirect_uri: AnyHttpUrl | None = Field(..., description="URL to redirect to after authorization") + redirect_uri: AnyHttpUrl | None = Field( + ..., description="URL to redirect to after authorization" + ) - response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") + response_type: Literal["code"] = Field( + ..., description="Must be 'code' for authorization code flow" + ) code_challenge: str = Field(..., description="PKCE code challenge") - code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") + code_challenge_method: Literal["S256"] = Field( + "S256", description="PKCE code challenge method, must be S256" + ) state: Optional[str] = Field(None, description="Optional state parameter") - scope: Optional[str] = Field(None, description="Optional scope; if specified, should be a space-separated list of scope strings") - + scope: Optional[str] = Field( + None, + description="Optional scope; if specified, should be a space-separated list of scope strings", + ) + class Config: extra = "ignore" -def validate_scope(requested_scope: str | None, client: OAuthClientInformationFull) -> list[str] | None: + +def validate_scope( + requested_scope: str | None, client: OAuthClientInformationFull +) -> list[str] | None: if requested_scope is None: return None requested_scopes = requested_scope.split(" ") @@ -53,7 +61,10 @@ def validate_scope(requested_scope: str | None, client: OAuthClientInformationFu raise InvalidRequestError(f"Client was not registered with scope {scope}") return requested_scopes -def validate_redirect_uri(auth_request: AuthorizationRequest, client: OAuthClientInformationFull) -> AnyHttpUrl: + +def validate_redirect_uri( + auth_request: AuthorizationRequest, client: OAuthClientInformationFull +) -> AnyHttpUrl: if auth_request.redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs if auth_request.redirect_uri not in client.redirect_uris: @@ -64,16 +75,19 @@ def validate_redirect_uri(auth_request: AuthorizationRequest, client: OAuthClien elif len(client.redirect_uris) == 1: return client.redirect_uris[0] else: - raise InvalidRequestError("redirect_uri must be specified when client has multiple registered URIs") + raise InvalidRequestError( + "redirect_uri must be specified when client has multiple registered URIs" + ) + def create_authorization_handler(provider: OAuthServerProvider) -> Callable: """ Create a handler for the OAuth 2.0 Authorization endpoint. - + Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts """ - + async def authorization_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Authorization endpoint. @@ -91,74 +105,79 @@ async def authorization_handler(request: Request) -> Response: auth_request = AuthorizationRequest.model_validate(params) except ValidationError as e: raise InvalidRequestError(str(e)) - + # Get client information try: client = await provider.clients_store.get_client(auth_request.client_id) except OAuthError as e: # TODO: proper error rendering raise InvalidClientError(str(e)) - + if not client: raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found") - - + # do validation which is dependent on the client configuration redirect_uri = validate_redirect_uri(auth_request, client) scopes = validate_scope(auth_request.scope, client) - + auth_params = AuthorizationParams( state=auth_request.state, scopes=scopes, code_challenge=auth_request.code_challenge, redirect_uri=redirect_uri, ) - + try: # Let the provider handle the authorization flow - authorization_code = await provider.create_authorization_code(client, auth_params) - response = RedirectResponse(url="", status_code=302, headers={"Cache-Control": "no-store"}) - + authorization_code = await provider.create_authorization_code( + client, auth_params + ) + response = RedirectResponse( + url="", status_code=302, headers={"Cache-Control": "no-store"} + ) + # Redirect with code parsed_uri = urlparse(str(auth_params.redirect_uri)) query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] query_params.append(("code", authorization_code)) if auth_params.state: query_params.append(("state", auth_params.state)) - - redirect_url = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + + redirect_url = urlunparse( + parsed_uri._replace(query=urlencode(query_params)) + ) response.headers["location"] = redirect_url - + return response except Exception as e: return RedirectResponse( url=create_error_redirect(redirect_uri, e, auth_request.state), status_code=302, headers={"Cache-Control": "no-store"}, - ) - + ) + return authorization_handler -def create_error_redirect(redirect_uri: AnyUrl, error: Exception, state: Optional[str]) -> str: + +def create_error_redirect( + redirect_uri: AnyUrl, error: Exception, state: Optional[str] +) -> str: parsed_uri = urlparse(str(redirect_uri)) if isinstance(error, OAuthError): - query_params = { - "error": error.error_code, - "error_description": str(error) - } + query_params = {"error": error.error_code, "error_description": str(error)} else: query_params = { "error": "internal_error", - "error_description": "An unknown error occurred" + "error_description": "An unknown error occurred", } # TODO: should we add error_uri? # if error.error_uri: # query_params["error_uri"] = str(error.error_uri) if state: query_params["state"] = state - + new_query = urlencode(query_params) if parsed_uri.query: new_query = f"{parsed_uri.query}&{new_query}" - - return urlunparse(parsed_uri._replace(query=new_query)) \ No newline at end of file + + return urlunparse(parsed_uri._replace(query=new_query)) diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 2c2ca26507..11a9c904de 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -4,7 +4,7 @@ Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts """ -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -13,32 +13,32 @@ def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: """ Create a handler for OAuth 2.0 Authorization Server Metadata. - + Corresponds to metadataHandler in src/server/auth/handlers/metadata.ts - + Args: metadata: The metadata to return in the response - + Returns: A Starlette endpoint handler function """ - + async def metadata_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Authorization Server Metadata endpoint. - + Args: request: The Starlette request - + Returns: JSON response with the authorization server metadata """ # Remove any None values from metadata clean_metadata = {k: v for k, v in metadata.items() if v is not None} - + return JSONResponse( content=clean_metadata, - headers={"Cache-Control": "public, max-age=3600"} # Cache for 1 hour + headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour ) - - return metadata_handler \ No newline at end of file + + return metadata_handler diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 150e048e69..0437a7abaa 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -4,47 +4,48 @@ Corresponds to TypeScript file: src/server/auth/handlers/register.ts """ -import random import secrets import time -from typing import Any, Callable, Dict, List, Optional +from typing import Callable from uuid import uuid4 +from pydantic import ValidationError from starlette.requests import Request from starlette.responses import JSONResponse, Response -from pydantic import ValidationError from mcp.server.auth.errors import ( InvalidRequestError, - ServerError, OAuthError, + ServerError, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata -def create_registration_handler(clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None) -> Callable: +def create_registration_handler( + clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None +) -> Callable: """ Create a handler for OAuth 2.0 Dynamic Client Registration. - + Corresponds to clientRegistrationHandler in src/server/auth/handlers/register.ts - + Args: clients_store: The store for registered clients client_secret_expiry_seconds: Optional expiry time for client secrets - + Returns: A Starlette endpoint handler function """ - + async def registration_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Dynamic Client Registration endpoint. - + Args: request: The Starlette request - + Returns: JSON response with client information or error """ @@ -55,7 +56,7 @@ async def registration_handler(request: Request) -> Response: client_metadata = OAuthClientMetadata.model_validate(body) except ValidationError as e: raise InvalidRequestError(f"Invalid client metadata: {str(e)}") - + client_id = str(uuid4()) client_secret = None if client_metadata.token_endpoint_auth_method != "none": @@ -63,7 +64,11 @@ async def registration_handler(request: Request) -> Response: client_secret = secrets.token_hex(32) client_id_issued_at = int(time.time()) - client_secret_expires_at = client_id_issued_at + client_secret_expiry_seconds if client_secret_expiry_seconds is not None else None + client_secret_expires_at = ( + client_id_issued_at + client_secret_expiry_seconds + if client_secret_expiry_seconds is not None + else None + ) client_info = OAuthClientInformationFull( client_id=client_id, @@ -91,19 +96,13 @@ async def registration_handler(request: Request) -> Response: client = await clients_store.register_client(client_info) if not client: raise ServerError("Failed to register client") - + # Return client information - return PydanticJSONResponse( - content=client, - status_code=201 - ) - + return PydanticJSONResponse(content=client, status_code=201) + except OAuthError as e: # Handle OAuth errors status_code = 500 if isinstance(e, ServerError) else 400 - return JSONResponse( - status_code=status_code, - content=e.to_response_object() - ) - - return registration_handler \ No newline at end of file + return JSONResponse(status_code=status_code, content=e.to_response_object()) + + return registration_handler diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 6280e71c97..7aa09fa03c 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,61 +4,67 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from typing import Any, Callable, Dict, Optional +from typing import Callable +from pydantic import ValidationError from starlette.requests import Request from starlette.responses import Response -from pydantic import ValidationError from mcp.server.auth.errors import ( InvalidRequestError, - ServerError, - OAuthError, ) -from mcp.server.auth.middleware import client_auth +from mcp.server.auth.middleware.client_auth import ( + ClientAuthenticator, + ClientAuthRequest, +) from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import OAuthClientInformationFull, OAuthTokenRevocationRequest -from mcp.server.auth.middleware.client_auth import ClientAuthRequest, ClientAuthenticator +from mcp.shared.auth import OAuthTokenRevocationRequest + class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): pass -def create_revocation_handler(provider: OAuthServerProvider, client_authenticator: ClientAuthenticator) -> Callable: + +def create_revocation_handler( + provider: OAuthServerProvider, client_authenticator: ClientAuthenticator +) -> Callable: """ Create a handler for OAuth 2.0 Token Revocation. - + Corresponds to revocationHandler in src/server/auth/handlers/revoke.ts - + Args: provider: The OAuth server provider - + Returns: A Starlette endpoint handler function """ - + async def revocation_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Token Revocation endpoint. """ try: - revocation_request = RevocationRequest.model_validate_json(await request.body()) + revocation_request = RevocationRequest.model_validate_json( + await request.body() + ) except ValidationError as e: raise InvalidRequestError(f"Invalid request body: {e}") - + # Authenticate client client_auth_result = await client_authenticator(revocation_request) - + # Revoke token if provider.revoke_token: await provider.revoke_token(client_auth_result, revocation_request) - + # Return successful empty response return Response( status_code=200, headers={ "Cache-Control": "no-store", "Pragma": "no-cache", - } + }, ) - - return revocation_handler \ No newline at end of file + + return revocation_handler diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 9b092ccc7f..c5745f977e 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -6,72 +6,79 @@ import base64 import hashlib -import json -from typing import Annotated, Any, Callable, Dict, List, Literal, Optional, Union +from typing import Annotated, Callable, Literal, Optional, Union +from pydantic import Field, RootModel, ValidationError from starlette.requests import Request -from starlette.responses import JSONResponse -from pydantic import BaseModel, Field, RootModel, TypeAdapter, ValidationError from mcp.server.auth.errors import ( - InvalidClientError, - InvalidGrantError, InvalidRequestError, - ServerError, - UnsupportedGrantTypeError, - OAuthError, ) -from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import OAuthClientInformationFull, OAuthTokens -from mcp.server.auth.middleware.client_auth import ClientAuthRequest, ClientAuthenticator from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.middleware.client_auth import ( + ClientAuthenticator, + ClientAuthRequest, +) +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthTokens + class AuthorizationCodeRequest(ClientAuthRequest): """ Model for the authorization code grant request parameters. - + Corresponds to AuthorizationCodeExchangeSchema in src/server/auth/handlers/token.ts """ + grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") code_verifier: str = Field(..., description="PKCE code verifier") + class RefreshTokenRequest(ClientAuthRequest): """ Model for the refresh token grant request parameters. - + Corresponds to RefreshTokenExchangeSchema in src/server/auth/handlers/token.ts """ + grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: Optional[str] = Field(None, description="Optional scope parameter") class TokenRequest(RootModel): - root: Annotated[Union[AuthorizationCodeRequest, RefreshTokenRequest], Field(discriminator="grant_type")] + root: Annotated[ + Union[AuthorizationCodeRequest, RefreshTokenRequest], + Field(discriminator="grant_type"), + ] + + # TokenRequest = RootModel(Annotated[Union[AuthorizationCodeRequest, RefreshTokenRequest], Field(discriminator="grant_type")]) -def create_token_handler(provider: OAuthServerProvider, client_authenticator: ClientAuthenticator) -> Callable: +def create_token_handler( + provider: OAuthServerProvider, client_authenticator: ClientAuthenticator +) -> Callable: """ Create a handler for the OAuth 2.0 Token endpoint. - + Corresponds to tokenHandler in src/server/auth/handlers/token.ts - + Args: provider: The OAuth server provider - + Returns: A Starlette endpoint handler function """ - + async def token_handler(request: Request): """ Handler for the OAuth 2.0 Token endpoint. - + Args: request: The Starlette request - + Returns: JSON response with tokens or error """ @@ -83,9 +90,9 @@ async def token_handler(request: Request): except ValidationError as e: raise InvalidRequestError(f"Invalid request body: {e}") client_info = await client_authenticator(token_request) - + tokens: OAuthTokens - + match token_request: case AuthorizationCodeRequest(): # TODO: verify that the redirect URIs match; does the client actually provide this? @@ -98,34 +105,36 @@ async def token_handler(request: Request): ) if expected_challenge is None: raise InvalidRequestError("Invalid authorization code") - + # Calculate challenge from verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - + if actual_challenge != expected_challenge: - raise InvalidRequestError("code_verifier does not match the challenge") - + raise InvalidRequestError( + "code_verifier does not match the challenge" + ) + # Exchange authorization code for tokens - tokens = await provider.exchange_authorization_code(client_info, token_request.code) - + tokens = await provider.exchange_authorization_code( + client_info, token_request.code + ) + case RefreshTokenRequest(): # Parse scopes if provided scopes = token_request.scope.split(" ") if token_request.scope else None - + # Exchange refresh token for new tokens tokens = await provider.exchange_refresh_token( client_info, token_request.refresh_token, scopes ) - return PydanticJSONResponse( content=tokens, headers={ "Cache-Control": "no-store", "Pragma": "no-cache", - } + }, ) - - - return token_handler \ No newline at end of file + + return token_handler diff --git a/src/mcp/server/auth/json_response.py b/src/mcp/server/auth/json_response.py index 7dc39bcaac..25971cc916 100644 --- a/src/mcp/server/auth/json_response.py +++ b/src/mcp/server/auth/json_response.py @@ -1,6 +1,8 @@ from typing import Any + from starlette.responses import JSONResponse + class PydanticJSONResponse(JSONResponse): def render(self, content: Any) -> bytes: - return content.model_dump_json(exclude_none=True).encode("utf-8") \ No newline at end of file + return content.model_dump_json(exclude_none=True).encode("utf-8") diff --git a/src/mcp/server/auth/middleware/__init__.py b/src/mcp/server/auth/middleware/__init__.py index 60de91e41f..ba3ff63c34 100644 --- a/src/mcp/server/auth/middleware/__init__.py +++ b/src/mcp/server/auth/middleware/__init__.py @@ -1,3 +1,3 @@ """ Middleware for MCP authorization. -""" \ No newline at end of file +""" diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6a023f3215..bfa15996fa 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -5,12 +5,15 @@ """ import time -from typing import List, Optional, Callable, Awaitable, cast, Dict, Any +from typing import Any, Callable -from starlette.requests import HTTPConnection, Request +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + SimpleUser, +) from starlette.exceptions import HTTPException -from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError, BaseUser, SimpleUser, UnauthenticatedUser, has_required_scope -from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection from starlette.types import Scope from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError @@ -20,7 +23,7 @@ class AuthenticatedUser(SimpleUser): """User with authentication info.""" - + def __init__(self, auth_info: AuthInfo): super().__init__(auth_info.user_id or "anonymous") self.auth_info = auth_info @@ -31,33 +34,32 @@ class BearerAuthBackend(AuthenticationBackend): """ Authentication backend that validates Bearer tokens. """ - + def __init__( self, provider: OAuthServerProvider, ): self.provider = provider - - async def authenticate(self, conn: HTTPConnection): + async def authenticate(self, conn: HTTPConnection): if "Authorization" not in conn.headers: return None - + auth_header = conn.headers["Authorization"] if not auth_header.startswith("Bearer "): return None - + token = auth_header[7:] # Remove "Bearer " prefix - + try: # Validate the token with the provider auth_info = await self.provider.verify_access_token(token) if auth_info.expires_at and auth_info.expires_at < int(time.time()): raise InvalidTokenError("Token has expired") - + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) - + except (InvalidTokenError, InsufficientScopeError, OAuthError): # Return None to indicate authentication failure return None @@ -66,21 +68,17 @@ async def authenticate(self, conn: HTTPConnection): class RequireAuthMiddleware: """ Middleware that requires a valid Bearer token in the Authorization header. - - This will validate the token with the auth provider and store the resulting + + This will validate the token with the auth provider and store the resulting auth info in the request state. - + Corresponds to bearerAuthMiddleware in src/server/auth/middleware/bearerAuth.ts """ - - def __init__( - self, - app: Any, - required_scopes: list[str] - ): + + def __init__(self, app: Any, required_scopes: list[str]): """ Initialize the middleware. - + Args: app: ASGI application provider: Authentication provider to validate tokens @@ -90,11 +88,14 @@ def __init__( self.required_scopes = required_scopes async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: - auth_credentials = scope.get('auth') - + auth_credentials = scope.get("auth") + for required_scope in self.required_scopes: # auth_credentials should always be provided; this is just paranoia - if auth_credentials is None or required_scope not in auth_credentials.scopes: + if ( + auth_credentials is None + or required_scope not in auth_credentials.scopes + ): raise HTTPException(status_code=403, detail="Insufficient scope") - await self.app(scope, receive, send) \ No newline at end of file + await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 9aab1d3c12..33130bf677 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -5,17 +5,14 @@ """ import time -from typing import Optional, Dict, Any, Callable +from typing import Any, Callable, Dict, Optional -from starlette.requests import Request +from pydantic import BaseModel from starlette.exceptions import HTTPException -from pydantic import BaseModel, ValidationError +from starlette.requests import Request from mcp.server.auth.errors import ( InvalidClientError, - InvalidRequestError, - OAuthError, - ServerError, ) from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull @@ -24,9 +21,10 @@ class ClientAuthRequest(BaseModel): """ Model for client authentication request body. - + Corresponds to ClientAuthenticatedRequestSchema in src/server/auth/middleware/clientAuth.ts """ + client_id: str client_secret: Optional[str] = None @@ -34,51 +32,52 @@ class ClientAuthRequest(BaseModel): class ClientAuthenticator: """ Dependency that authenticates a client using client_id and client_secret. - + This is a callable that can be used to validate client credentials in a request. - + Corresponds to authenticateClient in src/server/auth/middleware/clientAuth.ts """ - + def __init__(self, clients_store: OAuthRegisteredClientsStore): """ Initialize the dependency. - + Args: clients_store: Store to look up client information """ self.clients_store = clients_store - + async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFull: # Look up client information client = await self.clients_store.get_client(request.client_id) if not client: raise InvalidClientError("Invalid client_id") - + # If client from the store expects a secret, validate that the request provides that secret if client.client_secret: if not request.client_secret: raise InvalidClientError("Client secret is required") - + if client.client_secret != request.client_secret: raise InvalidClientError("Invalid client_secret") - - if (client.client_secret_expires_at and - client.client_secret_expires_at < int(time.time())): + + if ( + client.client_secret_expires_at + and client.client_secret_expires_at < int(time.time()) + ): raise InvalidClientError("Client secret has expired") - + return client - class ClientAuthMiddleware: """ Middleware that authenticates clients using client_id and client_secret. - + This middleware will validate client credentials and store client information in the request state. """ - + def __init__( self, app: Any, @@ -86,18 +85,18 @@ def __init__( ): """ Initialize the middleware. - + Args: app: ASGI application clients_store: Store for client information """ self.app = app self.client_auth = ClientAuthenticator(clients_store) - + async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None: """ Process the request and authenticate the client. - + Args: scope: ASGI scope receive: ASGI receive function @@ -106,10 +105,10 @@ async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None if scope["type"] != "http": await self.app(scope, receive, send) return - + # Create a request object to access the request data request = Request(scope, receive=receive) - + # Add client authentication to the request try: client = await self.client_auth(request) @@ -118,6 +117,6 @@ async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None except HTTPException: # Continue without authentication pass - + # Continue processing the request - await self.app(scope, receive, send) \ No newline at end of file + await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 5b30734d6b..c9c2ae63bd 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -4,20 +4,25 @@ Corresponds to TypeScript file: src/server/auth/provider.ts """ -from typing import Any, Dict, List, Optional, Protocol +from typing import List, Optional, Protocol + from pydantic import AnyHttpUrl, BaseModel -from starlette.responses import Response -from mcp.shared.auth import OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens from mcp.server.auth.types import AuthInfo +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthTokenRevocationRequest, + OAuthTokens, +) class AuthorizationParams(BaseModel): """ Parameters for the authorization flow. - + Corresponds to AuthorizationParams in src/server/auth/provider.ts """ + state: Optional[str] = None scopes: Optional[List[str]] = None code_challenge: str @@ -27,31 +32,31 @@ class AuthorizationParams(BaseModel): class OAuthRegisteredClientsStore(Protocol): """ Interface for storing and retrieving registered OAuth clients. - + Corresponds to OAuthRegisteredClientsStore in src/server/auth/clients.ts """ - + async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: """ Retrieves client information by client ID. - + Args: client_id: The ID of the client to retrieve. - + Returns: The client information, or None if the client does not exist. """ ... - - async def register_client(self, - client_info: OAuthClientInformationFull - ) -> Optional[OAuthClientInformationFull]: + + async def register_client( + self, client_info: OAuthClientInformationFull + ) -> Optional[OAuthClientInformationFull]: """ Registers a new client and returns client information. - + Args: metadata: The client metadata to register. - + Returns: The client information, or None if registration failed. """ @@ -61,20 +66,20 @@ async def register_client(self, class OAuthServerProvider(Protocol): """ Implements an end-to-end OAuth server. - + Corresponds to OAuthServerProvider in src/server/auth/provider.ts """ - + @property def clients_store(self) -> OAuthRegisteredClientsStore: """ A store used to read information about registered OAuth clients. """ ... - - async def create_authorization_code(self, - client: OAuthClientInformationFull, - params: AuthorizationParams) -> str: + + async def create_authorization_code( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: """ Generates and stores an authorization code as part of completing the /authorize OAuth step. @@ -83,78 +88,80 @@ async def create_authorization_code(self, See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. """ ... - - async def challenge_for_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> str | None: + + async def challenge_for_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> str | None: """ Returns the code_challenge that was used when the indicated authorization began. - + Args: client: The client that requested the authorization code. authorization_code: The authorization code to get the challenge for. - + Returns: The code challenge that was used when the authorization began. """ ... - - async def exchange_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> OAuthTokens: + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> OAuthTokens: """ Exchanges an authorization code for an access token. - + Args: client: The client exchanging the authorization code. authorization_code: The authorization code to exchange. - + Returns: The access and refresh tokens. """ ... - - async def exchange_refresh_token(self, - client: OAuthClientInformationFull, - refresh_token: str, - scopes: Optional[List[str]] = None) -> OAuthTokens: + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: str, + scopes: Optional[List[str]] = None, + ) -> OAuthTokens: """ Exchanges a refresh token for an access token. - + Args: client: The client exchanging the refresh token. refresh_token: The refresh token to exchange. scopes: Optional scopes to request with the new access token. - + Returns: The new access and refresh tokens. """ ... # TODO: consider methods to generate refresh tokens and access tokens - + async def verify_access_token(self, token: str) -> AuthInfo: """ Verifies an access token and returns information about it. - + Args: token: The access token to verify. - + Returns: Information about the verified token. """ ... - - async def revoke_token(self, - client: OAuthClientInformationFull, - request: OAuthTokenRevocationRequest) -> None: + + async def revoke_token( + self, client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest + ) -> None: """ Revokes an access or refresh token. - + If the given token is invalid or already revoked, this method should do nothing. - + Args: client: The client revoking the token. request: The token revocation request. """ - ... \ No newline at end of file + ... diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 07f703b32f..4dfa8e6aee 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -5,29 +5,27 @@ """ from dataclasses import dataclass -import re -from typing import Dict, List, Optional, Any, Union, Callable -from urllib.parse import urlparse +from typing import Any, Dict, Optional +from pydantic import AnyUrl from starlette.routing import Route, Router -from starlette.requests import Request -from starlette.middleware import Middleware -from pydantic import AnyUrl, BaseModel -from mcp.server.auth.middleware.client_auth import ClientAuthMiddleware, ClientAuthenticator -from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import OAuthMetadata -from mcp.server.auth.handlers.metadata import create_metadata_handler from mcp.server.auth.handlers.authorize import create_authorization_handler -from mcp.server.auth.handlers.token import create_token_handler +from mcp.server.auth.handlers.metadata import create_metadata_handler from mcp.server.auth.handlers.revoke import create_revocation_handler +from mcp.server.auth.handlers.token import create_token_handler +from mcp.server.auth.middleware.client_auth import ( + ClientAuthenticator, +) +from mcp.server.auth.provider import OAuthServerProvider @dataclass class ClientRegistrationOptions: enabled: bool = False client_secret_expiry_seconds: Optional[int] = None - + + @dataclass class RevocationOptions: enabled: bool = False @@ -36,20 +34,22 @@ class RevocationOptions: def validate_issuer_url(url: AnyUrl): """ Validate that the issuer URL meets OAuth 2.0 requirements. - + Args: url: The issuer URL to validate - + Raises: ValueError: If the issuer URL is invalid """ - + # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if (url.scheme != "https" and - url.host != "localhost" and - not (url.host is not None and url.host.startswith("127.0.0.1"))): + if ( + url.scheme != "https" + and url.host != "localhost" + and not (url.host is not None and url.host.startswith("127.0.0.1")) + ): raise ValueError("Issuer URL must be HTTPS") - + # No fragments or query parameters allowed if url.fragment: raise ValueError("Issuer URL must not have a fragment") @@ -64,31 +64,33 @@ def validate_issuer_url(url: AnyUrl): def create_auth_router( - provider: OAuthServerProvider, - issuer_url: AnyUrl, - service_documentation_url: AnyUrl | None = None, - client_registration_options: ClientRegistrationOptions | None = None, - revocation_options: RevocationOptions | None = None - ) -> Router: + provider: OAuthServerProvider, + issuer_url: AnyUrl, + service_documentation_url: AnyUrl | None = None, + client_registration_options: ClientRegistrationOptions | None = None, + revocation_options: RevocationOptions | None = None, +) -> Router: """ Create a Starlette router with standard MCP authorization endpoints. - + Corresponds to mcpAuthRouter in src/server/auth/router.ts - + Args: provider: OAuth server provider issuer_url: Issuer URL for the authorization server service_documentation_url: Optional URL for service documentation client_registration_options: Options for client registration revocation_options: Options for token revocation - + Returns: Starlette router with authorization endpoints """ validate_issuer_url(issuer_url) - - client_registration_options = client_registration_options or ClientRegistrationOptions() + + client_registration_options = ( + client_registration_options or ClientRegistrationOptions() + ) revocation_options = revocation_options or RevocationOptions() metadata = build_metadata( issuer_url, @@ -97,80 +99,76 @@ def create_auth_router( revocation_options, ) client_authenticator = ClientAuthenticator(provider.clients_store) - + # Create routes - auth_router = Router(routes=[ - Route( - "/.well-known/oauth-authorization-server", - endpoint=create_metadata_handler(metadata), - methods=["GET"] - ), - Route( - AUTHORIZATION_PATH, - endpoint=create_authorization_handler(provider), - methods=["GET", "POST"] - ), - Route( - TOKEN_PATH, - endpoint=create_token_handler(provider, client_authenticator), - methods=["POST"] - ) - ]) - + auth_router = Router( + routes=[ + Route( + "/.well-known/oauth-authorization-server", + endpoint=create_metadata_handler(metadata), + methods=["GET"], + ), + Route( + AUTHORIZATION_PATH, + endpoint=create_authorization_handler(provider), + methods=["GET", "POST"], + ), + Route( + TOKEN_PATH, + endpoint=create_token_handler(provider, client_authenticator), + methods=["POST"], + ), + ] + ) + if client_registration_options.enabled: from mcp.server.auth.handlers.register import create_registration_handler + registration_handler = create_registration_handler( provider.clients_store, client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, ) auth_router.routes.append( - Route( - REGISTRATION_PATH, - endpoint=registration_handler, - methods=["POST"] - ) + Route(REGISTRATION_PATH, endpoint=registration_handler, methods=["POST"]) ) - + if revocation_options.enabled: revocation_handler = create_revocation_handler(provider, client_authenticator) auth_router.routes.append( - Route( - REVOCATION_PATH, - endpoint=revocation_handler, - methods=["POST"] - ) + Route(REVOCATION_PATH, endpoint=revocation_handler, methods=["POST"]) ) - + return auth_router + def build_metadata( - issuer_url: AnyUrl, - service_documentation_url: Optional[AnyUrl], - client_registration_options: ClientRegistrationOptions, - revocation_options: RevocationOptions, - ) -> Dict[str, Any]: + issuer_url: AnyUrl, + service_documentation_url: Optional[AnyUrl], + client_registration_options: ClientRegistrationOptions, + revocation_options: RevocationOptions, +) -> Dict[str, Any]: issuer_url_str = str(issuer_url).rstrip("/") # Create metadata metadata = { "issuer": issuer_url_str, - "service_documentation": str(service_documentation_url).rstrip("/") if service_documentation_url else None, - + "service_documentation": str(service_documentation_url).rstrip("/") + if service_documentation_url + else None, "authorization_endpoint": f"{issuer_url_str}{AUTHORIZATION_PATH}", "response_types_supported": ["code"], "code_challenge_methods_supported": ["S256"], - "token_endpoint": f"{issuer_url_str}{TOKEN_PATH}", "token_endpoint_auth_methods_supported": ["client_secret_post"], "grant_types_supported": ["authorization_code", "refresh_token"], } - + # Add registration endpoint if supported if client_registration_options.enabled: metadata["registration_endpoint"] = f"{issuer_url_str}{REGISTRATION_PATH}" - + # Add revocation endpoint if supported if revocation_options.enabled: metadata["revocation_endpoint"] = f"{issuer_url_str}{REVOCATION_PATH}" metadata["revocation_endpoint_auth_methods_supported"] = ["client_secret_post"] - return metadata \ No newline at end of file + return metadata diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index 494a4c30b0..3edc4cb93c 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -5,20 +5,22 @@ """ from typing import List, Optional + from pydantic import BaseModel class AuthInfo(BaseModel): """ Information about a validated access token, provided to request handlers. - + Corresponds to AuthInfo in src/server/auth/types.ts """ + token: str client_id: str scopes: List[str] expires_at: Optional[int] = None user_id: Optional[str] = None - + class Config: - extra = "ignore" \ No newline at end of file + extra = "ignore" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index af3b41b797..8b0ae3b9dc 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -11,23 +11,25 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Callable, Generic, Literal, Optional, Sequence +from typing import Any, Callable, Generic, Literal, Sequence import anyio import pydantic_core -from starlette.applications import Starlette -from starlette.authentication import requires -from starlette.middleware.authentication import AuthenticationMiddleware -from sse_starlette import EventSourceResponse import uvicorn from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict +from sse_starlette import EventSourceResponse +from starlette.applications import Starlette +from starlette.authentication import requires +from starlette.middleware.authentication import AuthenticationMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.middleware.bearer_auth import ( + BearerAuthBackend, + RequireAuthMiddleware, +) from mcp.server.auth.provider import OAuthServerProvider from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions -from mcp.server.auth.types import AuthInfo from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -98,13 +100,14 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) = Field(None, description="Lifespan context manager") auth_issuer_url: AnyUrl | None = Field(None, description="Auth issuer URL") - auth_service_documentation_url: AnyUrl | None = Field(None, description="Service documentation URL") + auth_service_documentation_url: AnyUrl | None = Field( + None, description="Service documentation URL" + ) auth_client_registration_options: ClientRegistrationOptions | None = None - auth_revocation_options: RevocationOptions | None = None + auth_revocation_options: RevocationOptions | None = None auth_required_scopes: list[str] | None = None - def lifespan_wrapper( app: FastMCP, lifespan: Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]], @@ -119,11 +122,11 @@ async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: class FastMCP: def __init__( - self, - name: str | None = None, - instructions: str | None = None, + self, + name: str | None = None, + instructions: str | None = None, auth_provider: OAuthServerProvider | None = None, - **settings: Any + **settings: Any, ): self.settings = Settings(**settings) @@ -482,16 +485,17 @@ async def run_stdio_async(self) -> None: def starlette_app(self) -> Starlette: """Run the server using SSE transport.""" from starlette.applications import Starlette - from starlette.routing import Mount, Route from starlette.middleware import Middleware - + from starlette.routing import Mount, Route + # Set up auth context and dependencies sse = SseServerTransport("/messages/") + async def handle_sse(request) -> EventSourceResponse: # Add client ID from auth context into request context if available request_meta = {} - + async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: @@ -507,7 +511,7 @@ async def handle_sse(request) -> EventSourceResponse: middleware = [] required_scopes = self.settings.auth_required_scopes or [] auth_router = None - + # Add auth endpoints if auth provider is configured if self._auth_provider and self.settings.auth_issuer_url: from mcp.server.auth.router import create_auth_router @@ -518,7 +522,7 @@ async def handle_sse(request) -> EventSourceResponse: AuthenticationMiddleware, backend=BearerAuthBackend( provider=self._auth_provider, - ) + ), ) ] auth_router = create_auth_router( @@ -526,21 +530,28 @@ async def handle_sse(request) -> EventSourceResponse: issuer_url=self.settings.auth_issuer_url, service_documentation_url=self.settings.auth_service_documentation_url, client_registration_options=self.settings.auth_client_registration_options, - revocation_options=self.settings.auth_revocation_options + revocation_options=self.settings.auth_revocation_options, ) - + # Add the auth router as a mount - routes.append(Route("/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"])) - routes.append(Mount("/messages/", app=RequireAuthMiddleware(sse.handle_post_message, required_scopes))) + routes.append( + Route( + "/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"] + ) + ) + routes.append( + Mount( + "/messages/", + app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + ) + ) if auth_router: routes.append(Mount("/", app=auth_router)) - + # Create Starlette app with routes and middleware return Starlette( - debug=self.settings.debug, - routes=routes, - middleware=middleware + debug=self.settings.debug, routes=routes, middleware=middleware ) async def run_sse_async(self) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 75c1f7302e..cd1b5502fa 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -34,7 +34,6 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager from typing import Any -from typing_extensions import deprecated from urllib.parse import quote from uuid import UUID, uuid4 @@ -45,7 +44,7 @@ async def handle_sse(request): from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send -from sse_starlette import EventSourceResponse +from typing_extensions import deprecated import mcp.types as types diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 3a65ad959a..97ac8f2149 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -4,8 +4,9 @@ Corresponds to TypeScript file: src/shared/auth.ts """ -from typing import Any, Dict, List, Optional, Union -from pydantic import AnyHttpUrl, BaseModel, Field, field_validator, model_validator +from typing import Any, List, Optional + +from pydantic import AnyHttpUrl, BaseModel, Field class OAuthErrorResponse(BaseModel): @@ -14,6 +15,7 @@ class OAuthErrorResponse(BaseModel): Corresponds to OAuthErrorResponseSchema in src/shared/auth.ts """ + error: str error_description: Optional[str] = None error_uri: Optional[AnyHttpUrl] = None @@ -25,6 +27,7 @@ class OAuthTokens(BaseModel): Corresponds to OAuthTokensSchema in src/shared/auth.ts """ + access_token: str token_type: str expires_in: Optional[int] = None @@ -38,6 +41,7 @@ class OAuthClientMetadata(BaseModel): Corresponds to OAuthClientMetadataSchema in src/shared/auth.ts """ + redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) token_endpoint_auth_method: Optional[str] = None grant_types: Optional[List[str]] = None @@ -61,6 +65,7 @@ class OAuthClientInformation(BaseModel): Corresponds to OAuthClientInformationSchema in src/shared/auth.ts """ + client_id: str client_secret: Optional[str] = None client_id_issued_at: Optional[int] = None @@ -74,6 +79,7 @@ class OAuthClientInformationFull(OAuthClientMetadata, OAuthClientInformation): Corresponds to OAuthClientInformationFullSchema in src/shared/auth.ts """ + pass @@ -83,6 +89,7 @@ class OAuthClientRegistrationError(BaseModel): Corresponds to OAuthClientRegistrationErrorSchema in src/shared/auth.ts """ + error: str error_description: Optional[str] = None @@ -93,6 +100,7 @@ class OAuthTokenRevocationRequest(BaseModel): Corresponds to OAuthTokenRevocationRequestSchema in src/shared/auth.ts """ + token: str token_type_hint: Optional[str] = None @@ -103,6 +111,7 @@ class OAuthMetadata(BaseModel): Corresponds to OAuthMetadataSchema in src/shared/auth.ts """ + issuer: str authorization_endpoint: str token_endpoint: str @@ -120,4 +129,4 @@ class OAuthMetadata(BaseModel): introspection_endpoint: Optional[str] = None introspection_endpoint_auth_methods_supported: Optional[List[str]] = None introspection_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - code_challenge_methods_supported: Optional[List[str]] = None \ No newline at end of file + code_challenge_methods_supported: Optional[List[str]] = None diff --git a/tests/server/fastmcp/auth/__init__.py b/tests/server/fastmcp/auth/__init__.py index 304b8cd87a..64d318ec46 100644 --- a/tests/server/fastmcp/auth/__init__.py +++ b/tests/server/fastmcp/auth/__init__.py @@ -1,3 +1,3 @@ """ Tests for the MCP server auth components. -""" \ No newline at end of file +""" diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py index 66774ba67b..bb54e46385 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -6,19 +6,15 @@ the connection is closed. """ +import asyncio import typing -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, Tuple import anyio import anyio.streams.memory -from anyio.abc import TaskStatus -import httpx -from httpx._transports.asgi import ASGIResponseStream -from httpx._transports.base import AsyncBaseTransport from httpx._models import Request, Response +from httpx._transports.base import AsyncBaseTransport from httpx._types import AsyncByteStream -import asyncio - class StreamingASGITransport(AsyncBaseTransport): @@ -89,7 +85,9 @@ async def handle_async_request( # Synchronization for streaming response asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream(100) - content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) + content_send_channel, content_receive_channel = ( + anyio.create_memory_object_stream[bytes](100) + ) # ASGI callables. async def receive() -> Dict[str, Any]: @@ -118,26 +116,22 @@ async def run_app() -> None: except Exception: if self.raise_app_exceptions: raise - + if not response_started: - await asgi_send_channel.send({ - "type": "http.response.start", - "status": 500, - "headers": [] - }) - - await asgi_send_channel.send({ - "type": "http.response.body", - "body": b"", - "more_body": False - }) + await asgi_send_channel.send( + {"type": "http.response.start", "status": 500, "headers": []} + ) + + await asgi_send_channel.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) finally: await asgi_send_channel.aclose() # Process messages from the ASGI app async def process_messages() -> None: nonlocal status_code, response_headers, response_started - + try: async with asgi_receive_channel: async for message in asgi_receive_channel: @@ -146,7 +140,7 @@ async def process_messages() -> None: status_code = message["status"] response_headers = message.get("headers", []) response_started = True - + # As soon as we have headers, we can return a response initial_response_ready.set() @@ -169,29 +163,33 @@ async def process_messages() -> None: # Create tasks for running the app and processing messages app_task = asyncio.create_task(run_app()) process_task = asyncio.create_task(process_messages()) - + # Wait for the initial response or timeout await initial_response_ready.wait() # Create a streaming response - return Response(status_code, headers=response_headers, stream=StreamingASGIResponseStream(content_receive_channel)) + return Response( + status_code, + headers=response_headers, + stream=StreamingASGIResponseStream(content_receive_channel), + ) class StreamingASGIResponseStream(AsyncByteStream): """ A modified ASGIResponseStream that supports streaming responses. - + This class extends the standard ASGIResponseStream to handle cases where the response body continues to be generated after the initial response is returned. """ - + def __init__( - self, + self, receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], ) -> None: self.receive_channel = receive_channel - + async def __aiter__(self) -> typing.AsyncIterator[bytes]: async for chunk in self.receive_channel: - yield chunk \ No newline at end of file + yield chunk diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index a22c675dea..1728e915ab 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -7,35 +7,36 @@ import json import secrets import time -from typing import Any, Dict, List, Optional, cast -from urllib.parse import urlparse, parse_qs +from typing import List, Optional +from urllib.parse import parse_qs, urlparse -import anyio -from pydantic import AnyUrl -import pytest import httpx +import pytest from httpx_sse import aconnect_sse +from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.datastructures import MutableHeaders -from starlette.testclient import TestClient -from starlette.routing import Route, Router, Mount -from starlette.responses import RedirectResponse, JSONResponse, Response -from starlette.requests import Request -from starlette.middleware import Middleware -from starlette.types import ASGIApp +from starlette.routing import Mount from mcp.server.auth.errors import InvalidTokenError -from mcp.server.auth.middleware.client_auth import ClientAuthMiddleware -from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, OAuthRegisteredClientsStore -from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions, create_auth_router +from mcp.server.auth.provider import ( + AuthorizationParams, + OAuthRegisteredClientsStore, + OAuthServerProvider, +) +from mcp.server.auth.router import ( + ClientRegistrationOptions, + RevocationOptions, + create_auth_router, +) from mcp.server.auth.types import AuthInfo +from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens, ) -from mcp.server.fastmcp import FastMCP from mcp.types import JSONRPCRequest + from .streaming_asgi_transport import StreamingASGITransport @@ -43,11 +44,13 @@ class MockClientStore: def __init__(self): self.clients = {} - + async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: return self.clients.get(client_id) - - async def register_client(self, client_info: OAuthClientInformationFull) -> OAuthClientInformationFull: + + async def register_client( + self, client_info: OAuthClientInformationFull + ) -> OAuthClientInformationFull: self.clients[client_info.client_id] = client_info return client_info @@ -59,17 +62,17 @@ def __init__(self): self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} self.tokens = {} # token -> {client_id, scopes, expires_at} self.refresh_tokens = {} # refresh_token -> access_token - + @property def clients_store(self) -> OAuthRegisteredClientsStore: return self.client_store - - async def create_authorization_code(self, - client: OAuthClientInformationFull, - params: AuthorizationParams) -> str: + + async def create_authorization_code( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: # Generate an authorization code code = f"code_{int(time.time())}" - + # Store the code for later verification self.auth_codes[code] = { "client_id": client.client_id, @@ -80,57 +83,56 @@ async def create_authorization_code(self, return code - - async def challenge_for_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> str: + async def challenge_for_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> str: # Get the stored code info code_info = self.auth_codes.get(authorization_code) if not code_info: raise InvalidTokenError("Invalid authorization code") - + # Check if code is expired if code_info["expires_at"] < int(time.time()): raise InvalidTokenError("Authorization code has expired") - + # Check if the code was issued to this client if code_info["client_id"] != client.client_id: raise InvalidTokenError("Authorization code was not issued to this client") - + return code_info["code_challenge"] - - async def exchange_authorization_code(self, - client: OAuthClientInformationFull, - authorization_code: str) -> OAuthTokens: + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> OAuthTokens: # Get the stored code info code_info = self.auth_codes.get(authorization_code) if not code_info: raise InvalidTokenError("Invalid authorization code") - + # Check if code is expired if code_info["expires_at"] < int(time.time()): raise InvalidTokenError("Authorization code has expired") - + # Check if the code was issued to this client if code_info["client_id"] != client.client_id: raise InvalidTokenError("Authorization code was not issued to this client") - + # Generate an access token and refresh token access_token = f"access_{secrets.token_hex(32)}" refresh_token = f"refresh_{secrets.token_hex(32)}" - + # Store the tokens self.tokens[access_token] = { "client_id": client.client_id, "scopes": ["read", "write"], "expires_at": int(time.time()) + 3600, } - + self.refresh_tokens[refresh_token] = access_token - + # Remove the used code del self.auth_codes[authorization_code] - + return OAuthTokens( access_token=access_token, token_type="bearer", @@ -138,44 +140,46 @@ async def exchange_authorization_code(self, scope="read write", refresh_token=refresh_token, ) - - async def exchange_refresh_token(self, - client: OAuthClientInformationFull, - refresh_token: str, - scopes: Optional[List[str]] = None) -> OAuthTokens: + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: str, + scopes: Optional[List[str]] = None, + ) -> OAuthTokens: # Check if refresh token exists if refresh_token not in self.refresh_tokens: raise InvalidTokenError("Invalid refresh token") - + # Get the access token for this refresh token old_access_token = self.refresh_tokens[refresh_token] - + # Check if the access token exists if old_access_token not in self.tokens: raise InvalidTokenError("Invalid refresh token") - + # Check if the token was issued to this client token_info = self.tokens[old_access_token] if token_info["client_id"] != client.client_id: raise InvalidTokenError("Refresh token was not issued to this client") - + # Generate a new access token and refresh token new_access_token = f"access_{secrets.token_hex(32)}" new_refresh_token = f"refresh_{secrets.token_hex(32)}" - + # Store the new tokens self.tokens[new_access_token] = { "client_id": client.client_id, "scopes": scopes or token_info["scopes"], "expires_at": int(time.time()) + 3600, } - + self.refresh_tokens[new_refresh_token] = new_access_token - + # Remove the old tokens del self.refresh_tokens[refresh_token] del self.tokens[old_access_token] - + return OAuthTokens( access_token=new_access_token, token_type="bearer", @@ -183,54 +187,54 @@ async def exchange_refresh_token(self, scope=" ".join(scopes) if scopes else " ".join(token_info["scopes"]), refresh_token=new_refresh_token, ) - + async def verify_access_token(self, token: str) -> AuthInfo: # Check if token exists if token not in self.tokens: raise InvalidTokenError("Invalid access token") - + # Get token info token_info = self.tokens[token] - + # Check if token is expired if token_info["expires_at"] < int(time.time()): raise InvalidTokenError("Access token has expired") - + return AuthInfo( token=token, client_id=token_info["client_id"], scopes=token_info["scopes"], expires_at=token_info["expires_at"], ) - - async def revoke_token(self, - client: OAuthClientInformationFull, - request: OAuthTokenRevocationRequest) -> None: + + async def revoke_token( + self, client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest + ) -> None: token = request.token - + # Check if it's a refresh token if token in self.refresh_tokens: access_token = self.refresh_tokens[token] - + # Check if this refresh token belongs to this client if self.tokens[access_token]["client_id"] != client.client_id: # For security reasons, we still return success return - + # Remove the refresh token and its associated access token del self.tokens[access_token] del self.refresh_tokens[token] - + # Check if it's an access token elif token in self.tokens: # Check if this access token belongs to this client if self.tokens[token]["client_id"] != client.client_id: # For security reasons, we still return success return - + # Remove the access token del self.tokens[token] - + # Also remove any refresh tokens that point to this access token for refresh_token, access_token in list(self.refresh_tokens.items()): if access_token == token: @@ -249,27 +253,22 @@ def auth_app(mock_oauth_provider): mock_oauth_provider, AnyUrl("https://auth.example.com"), AnyUrl("https://docs.example.com"), - client_registration_options=ClientRegistrationOptions( - enabled=True - ), - revocation_options=RevocationOptions( - enabled=True - ) + client_registration_options=ClientRegistrationOptions(enabled=True), + revocation_options=RevocationOptions(enabled=True), ) - + # Create Starlette app - app = Starlette( - routes=[ - Mount("/", app=auth_router) - ] - ) - + app = Starlette(routes=[Mount("/", app=auth_router)]) + return app @pytest.fixture def test_client(auth_app) -> httpx.AsyncClient: - return httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") + return httpx.AsyncClient( + transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" + ) + class TestAuthEndpoints: @pytest.mark.anyio @@ -281,65 +280,78 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): if response.status_code != 200: print(f"Response content: {response.content}") assert response.status_code == 200 - + metadata = response.json() assert metadata["issuer"] == "https://auth.example.com" - assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" + assert ( + metadata["authorization_endpoint"] == "https://auth.example.com/authorize" + ) assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register" assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["response_types_supported"] == ["code"] assert metadata["code_challenge_methods_supported"] == ["S256"] - assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] - assert metadata["grant_types_supported"] == ["authorization_code", "refresh_token"] + assert metadata["token_endpoint_auth_methods_supported"] == [ + "client_secret_post" + ] + assert metadata["grant_types_supported"] == [ + "authorization_code", + "refresh_token", + ] assert metadata["service_documentation"] == "https://docs.example.com" - + @pytest.mark.anyio - async def test_client_registration(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): + async def test_client_registration( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", "client_uri": "https://client.example.com", } - + response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201, response.content - + client_info = response.json() assert "client_id" in client_info assert "client_secret" in client_info assert client_info["client_name"] == "Test Client" assert client_info["redirect_uris"] == ["https://client.example.com/callback"] - + # Verify that the client was registered - #assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None - + # assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None + @pytest.mark.anyio - async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider): + async def test_authorization_flow( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): """Test the full authorization flow.""" # 1. Register a client client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", } - + response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201 client_info = response.json() - + # 2. Create a PKCE challenge code_verifier = "some_random_verifier_string" - code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ).decode().rstrip("=") - + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + # 3. Request authorization response = await test_client.get( "/authorize", @@ -353,16 +365,16 @@ async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oau }, ) assert response.status_code == 302 - + # 4. Extract the authorization code from the redirect URL redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "code" in query_params assert query_params["state"][0] == "test_state" auth_code = query_params["code"][0] - + # 5. Exchange the authorization code for tokens response = await test_client.post( "/token", @@ -375,24 +387,24 @@ async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oau }, ) assert response.status_code == 200 - + token_response = response.json() assert "access_token" in token_response assert "token_type" in token_response assert "refresh_token" in token_response assert "expires_in" in token_response assert token_response["token_type"] == "bearer" - + # 6. Verify the access token access_token = token_response["access_token"] refresh_token = token_response["refresh_token"] - + # Create a test client with the token auth_info = await mock_oauth_provider.verify_access_token(access_token) assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes - + # 7. Refresh the token response = await test_client.post( "/token", @@ -404,13 +416,13 @@ async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oau }, ) assert response.status_code == 200 - + new_token_response = response.json() assert "access_token" in new_token_response assert "refresh_token" in new_token_response assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token - + # 8. Revoke the token response = await test_client.post( "/revoke", @@ -421,15 +433,17 @@ async def test_authorization_flow(self, test_client: httpx.AsyncClient, mock_oau }, ) assert response.status_code == 200 - + # Verify that the token was revoked with pytest.raises(InvalidTokenError): - await mock_oauth_provider.verify_access_token(new_token_response["access_token"]) + await mock_oauth_provider.verify_access_token( + new_token_response["access_token"] + ) class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" - + @pytest.mark.anyio async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): """Test creating a FastMCP server with authentication.""" @@ -438,28 +452,26 @@ async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): auth_provider=mock_oauth_provider, auth_issuer_url="https://auth.example.com", require_auth=True, - auth_client_registration_options=ClientRegistrationOptions( - enabled=True - ), - auth_revocation_options=RevocationOptions( - enabled=True - ), - auth_required_scopes=["read"] + auth_client_registration_options=ClientRegistrationOptions(enabled=True), + auth_revocation_options=RevocationOptions(enabled=True), + auth_required_scopes=["read"], ) - + # Add a test tool @mcp.tool() def test_tool(x: int) -> str: return f"Result: {x}" - - transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore - test_client = httpx.AsyncClient(transport=transport, base_url="http://mcptest.com") + + transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore + test_client = httpx.AsyncClient( + transport=transport, base_url="http://mcptest.com" + ) # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") - + # Test metadata endpoint response = await test_client.get("/.well-known/oauth-authorization-server") assert response.status_code == 200 - + # Test that auth is required for protected endpoints response = await test_client.get("/sse") # TODO: we should return 401/403 depending on whether authn or authz fails @@ -468,26 +480,28 @@ def test_tool(x: int) -> str: response = await test_client.post("/messages/") # TODO: we should return 401/403 depending on whether authn or authz fails assert response.status_code == 403, response.content - + # now, become authenticated and try to go through the flow again client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", } - + response = await test_client.post( "/register", json=client_metadata, ) assert response.status_code == 201 client_info = response.json() - + # Create a PKCE challenge code_verifier = "some_random_verifier_string" - code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(code_verifier.encode()).digest() - ).decode().rstrip("=") - + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + # Request authorization response = await test_client.get( "/authorize", @@ -501,15 +515,15 @@ def test_tool(x: int) -> str: }, ) assert response.status_code == 302 - + # Extract the authorization code from the redirect URL redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "code" in query_params auth_code = query_params["code"][0] - + # Exchange the authorization code for tokens response = await test_client.post( "/token", @@ -522,21 +536,22 @@ def test_tool(x: int) -> str: }, ) assert response.status_code == 200 - + token_response = response.json() assert "access_token" in token_response authorization = f"Bearer {token_response['access_token']}" - # Test the authenticated endpoint with valid token - async with aconnect_sse(test_client, "GET", "/sse", headers={"Authorization": authorization}) as event_source: + async with aconnect_sse( + test_client, "GET", "/sse", headers={"Authorization": authorization} + ) as event_source: assert event_source.response.status_code == 200 events = event_source.aiter_sse() sse = await events.__anext__() assert sse.event == "endpoint" assert sse.data.startswith("/messages/?session_id=") messages_uri = sse.data - + # verify that we can now post to the /messages endpoint, and get a response on the /sse endpoint response = await test_client.post( messages_uri, @@ -548,15 +563,10 @@ def test_tool(x: int) -> str: params={ "protocolVersion": "2024-11-05", "capabilities": { - "roots": { - "listChanged": True - }, + "roots": {"listChanged": True}, "sampling": {}, }, - "clientInfo": { - "name": "ExampleClient", - "version": "1.0.0" - } + "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, }, ).model_dump_json(), ) @@ -566,5 +576,7 @@ def test_tool(x: int) -> str: sse = await events.__anext__() assert sse.event == "message" sse_data = json.loads(sse.data) - assert sse_data["id"] == '123' - assert set(sse_data["result"]["capabilities"].keys()) == set(("experimental", "prompts", "resources", "tools")) \ No newline at end of file + assert sse_data["id"] == "123" + assert set(sse_data["result"]["capabilities"].keys()) == set( + ("experimental", "prompts", "resources", "tools") + ) From 031cadff64ecde859f59b1a5ab19997ffe84979e Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 14:09:12 -0700 Subject: [PATCH 07/60] Clean up registration endpoint --- src/mcp/server/auth/handlers/authorize.py | 9 +--- src/mcp/server/auth/handlers/token.py | 44 +++---------------- src/mcp/server/auth/middleware/client_auth.py | 12 ++--- src/mcp/shared/auth.py | 20 ++++++--- .../fastmcp/auth/test_auth_integration.py | 1 + 5 files changed, 27 insertions(+), 59 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 76b2802465..a359456551 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -21,12 +21,6 @@ class AuthorizationRequest(BaseModel): - """ - Model for the authorization request parameters. - - Corresponds to request schema in authorizationHandler in src/server/auth/handlers/authorize.ts - """ - client_id: str = Field(..., description="The client ID") redirect_uri: AnyHttpUrl | None = Field( ..., description="URL to redirect to after authorization" @@ -42,7 +36,8 @@ class AuthorizationRequest(BaseModel): state: Optional[str] = Field(None, description="Optional state parameter") scope: Optional[str] = Field( None, - description="Optional scope; if specified, should be a space-separated list of scope strings", + description="Optional scope; if specified, should be " \ + "a space-separated list of scope strings", ) class Config: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index c5745f977e..e5c37f7737 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -24,24 +24,13 @@ class AuthorizationCodeRequest(ClientAuthRequest): - """ - Model for the authorization code grant request parameters. - - Corresponds to AuthorizationCodeExchangeSchema in src/server/auth/handlers/token.ts - """ - grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") code_verifier: str = Field(..., description="PKCE code verifier") + # TODO: this should take redirect_uri class RefreshTokenRequest(ClientAuthRequest): - """ - Model for the refresh token grant request parameters. - - Corresponds to RefreshTokenExchangeSchema in src/server/auth/handlers/token.ts - """ - grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: Optional[str] = Field(None, description="Optional scope parameter") @@ -54,48 +43,25 @@ class TokenRequest(RootModel): ] -# TokenRequest = RootModel(Annotated[Union[AuthorizationCodeRequest, RefreshTokenRequest], Field(discriminator="grant_type")]) - def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: - """ - Create a handler for the OAuth 2.0 Token endpoint. - - Corresponds to tokenHandler in src/server/auth/handlers/token.ts - - Args: - provider: The OAuth server provider - - Returns: - A Starlette endpoint handler function - """ - async def token_handler(request: Request): - """ - Handler for the OAuth 2.0 Token endpoint. - - Args: - request: The Starlette request - - Returns: - JSON response with tokens or error - """ - # Parse request body as form data or JSON - content_type = request.headers.get("Content-Type", "") - try: token_request = TokenRequest.model_validate_json(await request.body()).root except ValidationError as e: raise InvalidRequestError(f"Invalid request body: {e}") client_info = await client_authenticator(token_request) + if token_request.grant_type not in client_info.grant_types: + raise InvalidRequestError(f"Unsupported grant type (supported grant types are {client_info.grant_types})") + tokens: OAuthTokens match token_request: case AuthorizationCodeRequest(): - # TODO: verify that the redirect URIs match; does the client actually provide this? + # TODO: verify that the redirect URIs match # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 # TODO: enforce TTL on the authorization code diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 33130bf677..524bcdf369 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -31,13 +31,13 @@ class ClientAuthRequest(BaseModel): class ClientAuthenticator: """ - Dependency that authenticates a client using client_id and client_secret. - - This is a callable that can be used to validate client credentials in a request. - - Corresponds to authenticateClient in src/server/auth/middleware/clientAuth.ts + ClientAuthenticator is a callable which validates requests from a client application, + used to verify /token and /revoke calls. + If, during registration, the client requested to be issued a secret, the authenticator + asserts that /token and /register calls must be authenticated with that same token. + NOTE: clients can opt for no authentication during registration, in which case this logic + is skipped. """ - def __init__(self, clients_store: OAuthRegisteredClientsStore): """ Initialize the dependency. diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 97ac8f2149..961a73acda 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -4,7 +4,7 @@ Corresponds to TypeScript file: src/shared/auth.ts """ -from typing import Any, List, Optional +from typing import Any, List, Literal, Optional from pydantic import AnyHttpUrl, BaseModel, Field @@ -38,18 +38,24 @@ class OAuthTokens(BaseModel): class OAuthClientMetadata(BaseModel): """ RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. - - Corresponds to OAuthClientMetadataSchema in src/shared/auth.ts + See https://datatracker.ietf.org/doc/html/rfc7591#section-2 + for the full specification. """ redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) - token_endpoint_auth_method: Optional[str] = None - grant_types: Optional[List[str]] = None - response_types: Optional[List[str]] = None + # token_endpoint_auth_method: this implementation only supports none & client_secret_basic; + # ie: we do not support client_secret_post + token_endpoint_auth_method: Literal["none", "client_secret_basic"] = "client_secret_basic" + # grant_types: this implementation only supports authorization_code & refresh_token + grant_types: List[Literal["authorization_code", "refresh_token"]] = ["authorization_code"] + # this implementation only supports code; ie: it does not support implicit grants + response_types: List[Literal["code"]] = ["code"] + scope: Optional[str] = None + + # these fields are currently unused, but we support & store them for potential future use client_name: Optional[str] = None client_uri: Optional[AnyHttpUrl] = None logo_uri: Optional[AnyHttpUrl] = None - scope: Optional[str] = None contacts: Optional[List[str]] = None tos_uri: Optional[AnyHttpUrl] = None policy_uri: Optional[AnyHttpUrl] = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 1728e915ab..0e2461784f 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -335,6 +335,7 @@ async def test_authorization_flow( client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"] } response = await test_client.post( From 765efb6a096ef187c103f5a95f5d3423885514b5 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 14:13:06 -0700 Subject: [PATCH 08/60] Lint --- src/mcp/server/auth/handlers/token.py | 5 ++++- src/mcp/server/auth/middleware/client_auth.py | 18 +++++++++++------- src/mcp/server/auth/provider.py | 6 ++++-- src/mcp/server/fastmcp/server.py | 1 - src/mcp/server/sse.py | 3 ++- src/mcp/shared/auth.py | 12 ++++++++---- .../fastmcp/auth/streaming_asgi_transport.py | 4 ++-- .../fastmcp/auth/test_auth_integration.py | 7 +++++-- 8 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index e5c37f7737..f564f39478 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -55,7 +55,10 @@ async def token_handler(request: Request): client_info = await client_authenticator(token_request) if token_request.grant_type not in client_info.grant_types: - raise InvalidRequestError(f"Unsupported grant type (supported grant types are {client_info.grant_types})") + raise InvalidRequestError( + f"Unsupported grant type (supported grant types are " + f"{client_info.grant_types})" + ) tokens: OAuthTokens diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 524bcdf369..f56e7f058f 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -22,7 +22,8 @@ class ClientAuthRequest(BaseModel): """ Model for client authentication request body. - Corresponds to ClientAuthenticatedRequestSchema in src/server/auth/middleware/clientAuth.ts + Corresponds to ClientAuthenticatedRequestSchema in + src/server/auth/middleware/clientAuth.ts """ client_id: str @@ -31,12 +32,14 @@ class ClientAuthRequest(BaseModel): class ClientAuthenticator: """ - ClientAuthenticator is a callable which validates requests from a client application, + ClientAuthenticator is a callable which validates requests from a client + application, used to verify /token and /revoke calls. - If, during registration, the client requested to be issued a secret, the authenticator - asserts that /token and /register calls must be authenticated with that same token. - NOTE: clients can opt for no authentication during registration, in which case this logic - is skipped. + If, during registration, the client requested to be issued a secret, the + authenticator asserts that /token and /register calls must be authenticated with + that same token. + NOTE: clients can opt for no authentication during registration, in which case this + logic is skipped. """ def __init__(self, clients_store: OAuthRegisteredClientsStore): """ @@ -53,7 +56,8 @@ async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFu if not client: raise InvalidClientError("Invalid client_id") - # If client from the store expects a secret, validate that the request provides that secret + # If client from the store expects a secret, validate that the request provides + # that secret if client.client_secret: if not request.client_secret: raise InvalidClientError("Client secret is required") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index c9c2ae63bd..437c6514da 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -81,9 +81,11 @@ async def create_authorization_code( self, client: OAuthClientInformationFull, params: AuthorizationParams ) -> str: """ - Generates and stores an authorization code as part of completing the /authorize OAuth step. + Generates and stores an authorization code as part of completing the /authorize + OAuth step. - Implementations SHOULD generate an authorization code with at least 160 bits of entropy, + Implementations SHOULD generate an authorization code with at least 160 bits of + entropy, and MUST generate an authorization code with at least 128 bits of entropy. See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. """ diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 8b0ae3b9dc..c30b67c4a2 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -494,7 +494,6 @@ def starlette_app(self) -> Starlette: async def handle_sse(request) -> EventSourceResponse: # Add client ID from auth context into request context if available - request_meta = {} async with sse.connect_sse( request.scope, request.receive, request._send diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cd1b5502fa..ef63b9ce47 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -132,7 +132,8 @@ async def sse_writer(): logger.debug("Yielding read and write streams") # TODO: hold on; shouldn't we be returning the EventSourceResponse? # I think this is why the tests hang - # TODO: we probably shouldn't return response here, since it's a breaking change + # TODO: we probably shouldn't return response here, since it's a breaking + # change # this is just to test yield (read_stream, write_stream, response) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 961a73acda..2fb0372ae4 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -43,16 +43,20 @@ class OAuthClientMetadata(BaseModel): """ redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & client_secret_basic; + # token_endpoint_auth_method: this implementation only supports none & + # client_secret_basic; # ie: we do not support client_secret_post - token_endpoint_auth_method: Literal["none", "client_secret_basic"] = "client_secret_basic" + token_endpoint_auth_method: Literal["none", "client_secret_basic"] = \ + "client_secret_basic" # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: List[Literal["authorization_code", "refresh_token"]] = ["authorization_code"] + grant_types: List[Literal["authorization_code", "refresh_token"]] = \ + ["authorization_code"] # this implementation only supports code; ie: it does not support implicit grants response_types: List[Literal["code"]] = ["code"] scope: Optional[str] = None - # these fields are currently unused, but we support & store them for potential future use + # these fields are currently unused, but we support & store them for potential + # future use client_name: Optional[str] = None client_uri: Optional[AnyHttpUrl] = None logo_uri: Optional[AnyHttpUrl] = None diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py index bb54e46385..eb1ba4342e 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -161,8 +161,8 @@ async def process_messages() -> None: response_complete.set() # Create tasks for running the app and processing messages - app_task = asyncio.create_task(run_app()) - process_task = asyncio.create_task(process_messages()) + asyncio.create_task(run_app()) + asyncio.create_task(process_messages()) # Wait for the initial response or timeout await initial_response_ready.wait() diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 0e2461784f..4bed508677 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -324,7 +324,9 @@ async def test_client_registration( assert client_info["redirect_uris"] == ["https://client.example.com/callback"] # Verify that the client was registered - # assert await mock_oauth_provider.clients_store.get_client(client_info["client_id"]) is not None + # assert await mock_oauth_provider.clients_store.get_client( + # client_info["client_id"] + # ) is not None @pytest.mark.anyio async def test_authorization_flow( @@ -553,7 +555,8 @@ def test_tool(x: int) -> str: assert sse.data.startswith("/messages/?session_id=") messages_uri = sse.data - # verify that we can now post to the /messages endpoint, and get a response on the /sse endpoint + # verify that we can now post to the /messages endpoint, and get a response + # on the /sse endpoint response = await test_client.post( messages_uri, headers={"Authorization": authorization}, From 0637bc3c09013388438ec3fd67878e6b37b62d74 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 14:47:55 -0700 Subject: [PATCH 09/60] update token + revoke to use form data --- CLAUDE.md | 4 ++ pyproject.toml | 1 + src/mcp/server/auth/handlers/authorize.py | 4 +- src/mcp/server/auth/handlers/revoke.py | 5 +- src/mcp/server/auth/handlers/token.py | 3 +- src/mcp/server/auth/middleware/client_auth.py | 11 +-- src/mcp/server/auth/provider.py | 4 +- src/mcp/server/sse.py | 2 +- src/mcp/shared/auth.py | 12 ++-- .../fastmcp/auth/test_auth_integration.py | 67 ++++++++++++++++--- uv.lock | 11 +++ 11 files changed, 96 insertions(+), 28 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index baed85a238..619f3bb44b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo - Add None checks - Narrow string types - Match existing patterns + - Pytest: + - If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD="" + to the start of the pytest run command eg: + `PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest` 3. Best Practices - Check git status before commits diff --git a/pyproject.toml b/pyproject.toml index 489d1faa71..429b7d6633 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "python-multipart", ] [project.optional-dependencies] diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index a359456551..6194803b1f 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -36,8 +36,8 @@ class AuthorizationRequest(BaseModel): state: Optional[str] = Field(None, description="Optional state parameter") scope: Optional[str] = Field( None, - description="Optional scope; if specified, should be " \ - "a space-separated list of scope strings", + description="Optional scope; if specified, should be " + "a space-separated list of scope strings", ) class Config: diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 7aa09fa03c..d8ce89ea1a 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -45,9 +45,8 @@ async def revocation_handler(request: Request) -> Response: Handler for the OAuth 2.0 Token Revocation endpoint. """ try: - revocation_request = RevocationRequest.model_validate_json( - await request.body() - ) + form_data = await request.form() + revocation_request = RevocationRequest.model_validate(dict(form_data)) except ValidationError as e: raise InvalidRequestError(f"Invalid request body: {e}") diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index f564f39478..a054d69201 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -49,7 +49,8 @@ def create_token_handler( ) -> Callable: async def token_handler(request: Request): try: - token_request = TokenRequest.model_validate_json(await request.body()).root + form_data = await request.form() + token_request = TokenRequest.model_validate(dict(form_data)).root except ValidationError as e: raise InvalidRequestError(f"Invalid request body: {e}") client_info = await client_authenticator(token_request) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index f56e7f058f..f24aefca28 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -22,7 +22,7 @@ class ClientAuthRequest(BaseModel): """ Model for client authentication request body. - Corresponds to ClientAuthenticatedRequestSchema in + Corresponds to ClientAuthenticatedRequestSchema in src/server/auth/middleware/clientAuth.ts """ @@ -32,15 +32,16 @@ class ClientAuthRequest(BaseModel): class ClientAuthenticator: """ - ClientAuthenticator is a callable which validates requests from a client + ClientAuthenticator is a callable which validates requests from a client application, used to verify /token and /revoke calls. - If, during registration, the client requested to be issued a secret, the - authenticator asserts that /token and /register calls must be authenticated with + If, during registration, the client requested to be issued a secret, the + authenticator asserts that /token and /register calls must be authenticated with that same token. - NOTE: clients can opt for no authentication during registration, in which case this + NOTE: clients can opt for no authentication during registration, in which case this logic is skipped. """ + def __init__(self, clients_store: OAuthRegisteredClientsStore): """ Initialize the dependency. diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 437c6514da..4936d195a3 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -81,10 +81,10 @@ async def create_authorization_code( self, client: OAuthClientInformationFull, params: AuthorizationParams ) -> str: """ - Generates and stores an authorization code as part of completing the /authorize + Generates and stores an authorization code as part of completing the /authorize OAuth step. - Implementations SHOULD generate an authorization code with at least 160 bits of + Implementations SHOULD generate an authorization code with at least 160 bits of entropy, and MUST generate an authorization code with at least 128 bits of entropy. See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index ef63b9ce47..db36bffad5 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -132,7 +132,7 @@ async def sse_writer(): logger.debug("Yielding read and write streams") # TODO: hold on; shouldn't we be returning the EventSourceResponse? # I think this is why the tests hang - # TODO: we probably shouldn't return response here, since it's a breaking + # TODO: we probably shouldn't return response here, since it's a breaking # change # this is just to test yield (read_stream, write_stream, response) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 2fb0372ae4..bc113b440f 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -43,19 +43,21 @@ class OAuthClientMetadata(BaseModel): """ redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & + # token_endpoint_auth_method: this implementation only supports none & # client_secret_basic; # ie: we do not support client_secret_post - token_endpoint_auth_method: Literal["none", "client_secret_basic"] = \ + token_endpoint_auth_method: Literal["none", "client_secret_basic"] = ( "client_secret_basic" + ) # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: List[Literal["authorization_code", "refresh_token"]] = \ - ["authorization_code"] + grant_types: List[Literal["authorization_code", "refresh_token"]] = [ + "authorization_code" + ] # this implementation only supports code; ie: it does not support implicit grants response_types: List[Literal["code"]] = ["code"] scope: Optional[str] = None - # these fields are currently unused, but we support & store them for potential + # these fields are currently unused, but we support & store them for potential # future use client_name: Optional[str] = None client_uri: Optional[AnyHttpUrl] = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 4bed508677..81a76d0be2 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -327,6 +327,55 @@ async def test_client_registration( # assert await mock_oauth_provider.clients_store.get_client( # client_info["client_id"] # ) is not None + + @pytest.mark.anyio + async def test_authorize_form_post( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): + """Test the authorization endpoint using POST with form-encoded data.""" + # Register a client + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Create a PKCE challenge + code_verifier = "some_random_verifier_string" + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + # Use POST with form-encoded data for authorization + response = await test_client.post( + "/authorize", + data={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": "test_form_state", + }, + ) + assert response.status_code == 302 + + # Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + assert query_params["state"][0] == "test_form_state" @pytest.mark.anyio async def test_authorization_flow( @@ -337,7 +386,7 @@ async def test_authorization_flow( client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", - "grant_types": ["authorization_code", "refresh_token"] + "grant_types": ["authorization_code", "refresh_token"], } response = await test_client.post( @@ -355,7 +404,7 @@ async def test_authorization_flow( .rstrip("=") ) - # 3. Request authorization + # 3. Request authorization using GET with query params response = await test_client.get( "/authorize", params={ @@ -381,7 +430,7 @@ async def test_authorization_flow( # 5. Exchange the authorization code for tokens response = await test_client.post( "/token", - json={ + data={ "grant_type": "authorization_code", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], @@ -411,7 +460,7 @@ async def test_authorization_flow( # 7. Refresh the token response = await test_client.post( "/token", - json={ + data={ "grant_type": "refresh_token", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], @@ -429,7 +478,7 @@ async def test_authorization_flow( # 8. Revoke the token response = await test_client.post( "/revoke", - json={ + data={ "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "token": new_token_response["access_token"], @@ -505,10 +554,10 @@ def test_tool(x: int) -> str: .rstrip("=") ) - # Request authorization - response = await test_client.get( + # Request authorization using POST with form-encoded data + response = await test_client.post( "/authorize", - params={ + data={ "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", @@ -530,7 +579,7 @@ def test_tool(x: int) -> str: # Exchange the authorization code for tokens response = await test_client.post( "/token", - json={ + data={ "grant_type": "authorization_code", "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], diff --git a/uv.lock b/uv.lock index e17a8dc188..b1887c3506 100644 --- a/uv.lock +++ b/uv.lock @@ -202,6 +202,7 @@ dependencies = [ { name = "sse-starlette" }, { name = "starlette" }, { name = "uvicorn" }, + { name = "python-multipart" }, ] [package.optional-dependencies] @@ -826,3 +827,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, ] + +[[package]] + +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 85321 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size=11111 }, +] From b99633af2a3de7e68a6852f46d3f7c50478898dc Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 15:28:11 -0700 Subject: [PATCH 10/60] Adjust more things to fit spec --- src/mcp/server/auth/handlers/authorize.py | 4 +- src/mcp/server/auth/handlers/revoke.py | 20 ++------- src/mcp/server/auth/handlers/token.py | 39 ++++++++++------ src/mcp/server/auth/provider.py | 20 +++++++-- src/mcp/shared/auth.py | 12 ++--- .../fastmcp/auth/test_auth_integration.py | 45 ++++++------------- 6 files changed, 65 insertions(+), 75 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 6194803b1f..eef8ccfb7f 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -39,9 +39,7 @@ class AuthorizationRequest(BaseModel): description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) - - class Config: - extra = "ignore" + def validate_scope( diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index d8ce89ea1a..7efc23d7a8 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,9 +4,9 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from typing import Callable +from typing import Callable, Optional -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import Response @@ -17,8 +17,8 @@ ClientAuthenticator, ClientAuthRequest, ) -from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import OAuthTokenRevocationRequest +from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest + class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): @@ -28,18 +28,6 @@ class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): def create_revocation_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: - """ - Create a handler for OAuth 2.0 Token Revocation. - - Corresponds to revocationHandler in src/server/auth/handlers/revoke.ts - - Args: - provider: The OAuth server provider - - Returns: - A Starlette endpoint handler function - """ - async def revocation_handler(request: Request) -> Response: """ Handler for the OAuth 2.0 Token Revocation endpoint. diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index a054d69201..866efcff0a 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -6,9 +6,10 @@ import base64 import hashlib +import time from typing import Annotated, Callable, Literal, Optional, Union -from pydantic import Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, Field, RootModel, ValidationError from starlette.requests import Request from mcp.server.auth.errors import ( @@ -24,13 +25,19 @@ class AuthorizationCodeRequest(ClientAuthRequest): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") + redirect_uri: AnyHttpUrl | None = Field( + ..., description="Must be the same as redirect URI provided in /authorize" + ) + client_id: str + # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 code_verifier: str = Field(..., description="PKCE code verifier") - # TODO: this should take redirect_uri class RefreshTokenRequest(ClientAuthRequest): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: Optional[str] = Field(None, description="Optional scope parameter") @@ -42,7 +49,7 @@ class TokenRequest(RootModel): Field(discriminator="grant_type"), ] - +AUTH_CODE_TTL = 300 # seconds def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator @@ -65,22 +72,28 @@ async def token_handler(request: Request): match token_request: case AuthorizationCodeRequest(): - # TODO: verify that the redirect URIs match - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 - # TODO: enforce TTL on the authorization code - - # Verify PKCE code verifier - expected_challenge = await provider.challenge_for_authorization_code( + auth_code_metadata = await provider.load_authorization_code_metadata( client_info, token_request.code ) - if expected_challenge is None: + if auth_code_metadata is None or auth_code_metadata.client_id != token_request.client_id: raise InvalidRequestError("Invalid authorization code") - # Calculate challenge from verifier + # make auth codes expire after a deadline + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 + expires_at = auth_code_metadata.issued_at + AUTH_CODE_TTL + if expires_at < time.time(): + raise InvalidRequestError("authorization code has expired") + + # verify redirect_uri doesn't change between /authorize and /tokens + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + if token_request.redirect_uri != auth_code_metadata.redirect_uri: + raise InvalidRequestError("redirect_uri did not match redirect_uri used when authorization code was created") + + # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - if actual_challenge != expected_challenge: + if hashed_code_verifier != auth_code_metadata.code_challenge: raise InvalidRequestError( "code_verifier does not match the challenge" ) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 4936d195a3..d996dcb454 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -4,7 +4,7 @@ Corresponds to TypeScript file: src/server/auth/provider.ts """ -from typing import List, Optional, Protocol +from typing import List, Literal, Optional, Protocol from pydantic import AnyHttpUrl, BaseModel @@ -28,6 +28,18 @@ class AuthorizationParams(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl +class AuthorizationCodeMeta(BaseModel): + issued_at: float + client_id: str + code_challenge: str + redirect_uri: AnyHttpUrl +class OAuthTokenRevocationRequest(BaseModel): + """ + # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 + """ + + token: str + token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None class OAuthRegisteredClientsStore(Protocol): """ @@ -91,11 +103,11 @@ async def create_authorization_code( """ ... - async def challenge_for_authorization_code( + async def load_authorization_code_metadata( self, client: OAuthClientInformationFull, authorization_code: str - ) -> str | None: + ) -> AuthorizationCodeMeta | None: """ - Returns the code_challenge that was used when the indicated authorization began. + Loads metadata for the authorization code challenge. Args: client: The client that requested the authorization code. diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bc113b440f..298053181d 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -11,25 +11,21 @@ class OAuthErrorResponse(BaseModel): """ - OAuth 2.1 error response. - - Corresponds to OAuthErrorResponseSchema in src/shared/auth.ts + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 """ - error: str + error: Literal["invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"] error_description: Optional[str] = None error_uri: Optional[AnyHttpUrl] = None class OAuthTokens(BaseModel): """ - OAuth 2.1 token response. - - Corresponds to OAuthTokensSchema in src/shared/auth.ts + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 """ access_token: str - token_type: str + token_type: Literal["bearer"] = "bearer" expires_in: Optional[int] = None scope: Optional[str] = None refresh_token: Optional[str] = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 81a76d0be2..055be4fe1b 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -19,9 +19,11 @@ from mcp.server.auth.errors import InvalidTokenError from mcp.server.auth.provider import ( + AuthorizationCodeMeta, AuthorizationParams, OAuthRegisteredClientsStore, OAuthServerProvider, + OAuthTokenRevocationRequest, ) from mcp.server.auth.router import ( ClientRegistrationOptions, @@ -32,7 +34,6 @@ from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, - OAuthTokenRevocationRequest, OAuthTokens, ) from mcp.types import JSONRPCRequest @@ -74,32 +75,19 @@ async def create_authorization_code( code = f"code_{int(time.time())}" # Store the code for later verification - self.auth_codes[code] = { - "client_id": client.client_id, - "code_challenge": params.code_challenge, - "redirect_uri": params.redirect_uri, - "expires_at": int(time.time()) + 600, # 10 minutes - } + self.auth_codes[code] = AuthorizationCodeMeta( + client_id= client.client_id, + code_challenge= params.code_challenge, + redirect_uri= params.redirect_uri, + issued_at= time.time(), + ) return code - async def challenge_for_authorization_code( + async def load_authorization_code_metadata( self, client: OAuthClientInformationFull, authorization_code: str - ) -> str: - # Get the stored code info - code_info = self.auth_codes.get(authorization_code) - if not code_info: - raise InvalidTokenError("Invalid authorization code") - - # Check if code is expired - if code_info["expires_at"] < int(time.time()): - raise InvalidTokenError("Authorization code has expired") - - # Check if the code was issued to this client - if code_info["client_id"] != client.client_id: - raise InvalidTokenError("Authorization code was not issued to this client") - - return code_info["code_challenge"] + ) -> AuthorizationCodeMeta | None: + return self.auth_codes.get(authorization_code) async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -109,14 +97,6 @@ async def exchange_authorization_code( if not code_info: raise InvalidTokenError("Invalid authorization code") - # Check if code is expired - if code_info["expires_at"] < int(time.time()): - raise InvalidTokenError("Authorization code has expired") - - # Check if the code was issued to this client - if code_info["client_id"] != client.client_id: - raise InvalidTokenError("Authorization code was not issued to this client") - # Generate an access token and refresh token access_token = f"access_{secrets.token_hex(32)}" refresh_token = f"refresh_{secrets.token_hex(32)}" @@ -436,6 +416,7 @@ async def test_authorization_flow( "client_secret": client_info["client_secret"], "code": auth_code, "code_verifier": code_verifier, + "redirect_uri": "https://client.example.com/callback", }, ) assert response.status_code == 200 @@ -465,6 +446,7 @@ async def test_authorization_flow( "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "refresh_token": refresh_token, + "redirect_uri": "https://client.example.com/callback", }, ) assert response.status_code == 200 @@ -585,6 +567,7 @@ def test_tool(x: int) -> str: "client_secret": client_info["client_secret"], "code": auth_code, "code_verifier": code_verifier, + "redirect_uri": "https://client.example.com/callback", }, ) assert response.status_code == 200 From 9ae1c2174b6fe5105bfd730f3485e03eb10471a9 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 15:29:06 -0700 Subject: [PATCH 11/60] Lint --- src/mcp/server/auth/handlers/revoke.py | 5 ++--- src/mcp/server/auth/provider.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 7efc23d7a8..33d5e1af7c 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,9 +4,9 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from typing import Callable, Optional +from typing import Callable -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from starlette.requests import Request from starlette.responses import Response @@ -20,7 +20,6 @@ from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest - class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): pass diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index d996dcb454..01529a1a9f 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -11,7 +11,6 @@ from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( OAuthClientInformationFull, - OAuthTokenRevocationRequest, OAuthTokens, ) From 50683b9cb752663eb7a378e3db637196fbdabec7 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 15:29:36 -0700 Subject: [PATCH 12/60] Remove dup --- src/mcp/shared/auth.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 298053181d..a8f4acfa8b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -102,15 +102,7 @@ class OAuthClientRegistrationError(BaseModel): error_description: Optional[str] = None -class OAuthTokenRevocationRequest(BaseModel): - """ - RFC 7009 OAuth 2.0 Token Revocation request. - - Corresponds to OAuthTokenRevocationRequestSchema in src/shared/auth.ts - """ - token: str - token_type_hint: Optional[str] = None class OAuthMetadata(BaseModel): From 2c5f26a86ddb52c0f7b531d7a10470dddd43264d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 15:32:14 -0700 Subject: [PATCH 13/60] Comment --- src/mcp/server/auth/handlers/authorize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index eef8ccfb7f..a0ef5dc220 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -26,6 +26,7 @@ class AuthorizationRequest(BaseModel): ..., description="URL to redirect to after authorization" ) + # see OAuthClientMetadata; we only support `code` response_type: Literal["code"] = Field( ..., description="Must be 'code' for authorization code flow" ) From e60599461b397450b10e423534ccd71e1b062ada Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 16:29:06 -0700 Subject: [PATCH 14/60] Refactor back to authorize() --- src/mcp/server/auth/handlers/authorize.py | 24 +++---- src/mcp/server/auth/handlers/token.py | 72 ++++++++++++------- src/mcp/server/auth/provider.py | 58 ++++++++++++--- src/mcp/shared/auth.py | 4 +- .../fastmcp/auth/test_auth_integration.py | 42 +++++------ 5 files changed, 125 insertions(+), 75 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index a0ef5dc220..4d5c7d4572 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -19,6 +19,10 @@ from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider from mcp.shared.auth import OAuthClientInformationFull +import logging + +logger = logging.getLogger(__name__) + class AuthorizationRequest(BaseModel): client_id: str = Field(..., description="The client ID") @@ -122,28 +126,18 @@ async def authorization_handler(request: Request) -> Response: ) try: - # Let the provider handle the authorization flow - authorization_code = await provider.create_authorization_code( - client, auth_params - ) + # Let the provider pick the next URI to redirect to response = RedirectResponse( url="", status_code=302, headers={"Cache-Control": "no-store"} ) - - # Redirect with code - parsed_uri = urlparse(str(auth_params.redirect_uri)) - query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] - query_params.append(("code", authorization_code)) - if auth_params.state: - query_params.append(("state", auth_params.state)) - - redirect_url = urlunparse( - parsed_uri._replace(query=urlencode(query_params)) + response.headers["location"] = await provider.authorize( + client, auth_params ) - response.headers["location"] = redirect_url return response except Exception as e: + logger.exception("error from authorize()", exc_info=e) + return RedirectResponse( url=create_error_redirect(redirect_uri, e, auth_request.state), status_code=302, diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 866efcff0a..712cf8e2f1 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -21,7 +21,7 @@ ClientAuthRequest, ) from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import OAuthTokens +from mcp.shared.auth import TokenErrorResponse, TokenSuccessResponse class AuthorizationCodeRequest(ClientAuthRequest): @@ -54,53 +54,79 @@ class TokenRequest(RootModel): def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: + def response(obj: TokenSuccessResponse | TokenErrorResponse): + return PydanticJSONResponse( + content=obj, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + }, + ) + async def token_handler(request: Request): try: form_data = await request.form() token_request = TokenRequest.model_validate(dict(form_data)).root - except ValidationError as e: - raise InvalidRequestError(f"Invalid request body: {e}") + except ValidationError as validation_error: + return response(TokenErrorResponse( + error="invalid_request", + error_description="\n".join(e['msg'] for e in validation_error.errors()) + + )) client_info = await client_authenticator(token_request) if token_request.grant_type not in client_info.grant_types: - raise InvalidRequestError( - f"Unsupported grant type (supported grant types are " + return response(TokenErrorResponse( + error="unsupported_grant_type", + error_description=f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})" - ) + )) - tokens: OAuthTokens + tokens: TokenSuccessResponse match token_request: case AuthorizationCodeRequest(): - auth_code_metadata = await provider.load_authorization_code_metadata( + auth_code = await provider.load_authorization_code( client_info, token_request.code ) - if auth_code_metadata is None or auth_code_metadata.client_id != token_request.client_id: - raise InvalidRequestError("Invalid authorization code") + if auth_code is None or auth_code.client_id != token_request.client_id: + # if the authoriation code belongs to a different client, pretend it doesn't exist + return response(TokenErrorResponse( + error="invalid_grant", + error_description=f"authorization code does not exist" + )) # make auth codes expire after a deadline # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 - expires_at = auth_code_metadata.issued_at + AUTH_CODE_TTL + expires_at = auth_code.issued_at + AUTH_CODE_TTL if expires_at < time.time(): - raise InvalidRequestError("authorization code has expired") + return response(TokenErrorResponse( + error="invalid_grant", + error_description=f"authorization code has expired" + )) # verify redirect_uri doesn't change between /authorize and /tokens # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 - if token_request.redirect_uri != auth_code_metadata.redirect_uri: - raise InvalidRequestError("redirect_uri did not match redirect_uri used when authorization code was created") + if token_request.redirect_uri != auth_code.redirect_uri: + return response(TokenErrorResponse( + error="invalid_request", + error_description=f"redirect_uri did not match redirect_uri used when authorization code was created" + )) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - if hashed_code_verifier != auth_code_metadata.code_challenge: - raise InvalidRequestError( - "code_verifier does not match the challenge" - ) + if hashed_code_verifier != auth_code.code_challenge: + # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 + return response(TokenErrorResponse( + error="invalid_grant", + error_description=f"incorrect code_verifier" + )) # Exchange authorization code for tokens tokens = await provider.exchange_authorization_code( - client_info, token_request.code + client_info, auth_code ) case RefreshTokenRequest(): @@ -112,12 +138,6 @@ async def token_handler(request: Request): client_info, token_request.refresh_token, scopes ) - return PydanticJSONResponse( - content=tokens, - headers={ - "Cache-Control": "no-store", - "Pragma": "no-cache", - }, - ) + return response(tokens) return token_handler diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 01529a1a9f..e4a159b200 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -5,13 +5,14 @@ """ from typing import List, Literal, Optional, Protocol +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, BaseModel +from pydantic import AnyHttpUrl, AnyUrl, BaseModel from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( OAuthClientInformationFull, - OAuthTokens, + TokenSuccessResponse, ) @@ -27,7 +28,9 @@ class AuthorizationParams(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl -class AuthorizationCodeMeta(BaseModel): +class AuthorizationCode(BaseModel): + code: str + scopes: list[str] issued_at: float client_id: str code_challenge: str @@ -88,12 +91,33 @@ def clients_store(self) -> OAuthRegisteredClientsStore: """ ... - async def create_authorization_code( + async def authorize( self, client: OAuthClientInformationFull, params: AuthorizationParams ) -> str: """ - Generates and stores an authorization code as part of completing the /authorize - OAuth step. + Called as part of the /authorize endpoint, and returns a URL that the client + will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform + a second OAuth exchange with that provider. In this sort of setup, the client + has an OAuth connection with the MCP server, and the MCP server has an OAuth + connection with the 3rd-party provider. At the end of this flow, the client + should be redirected to the redirect_uri from params.redirect_uri. + + +--------+ +------------+ +-------------------+ + | | | | | | + | Client | --> | MCP Server | --> | 3rd Party OAuth | + | | | | | Server | + +--------+ +------------+ +-------------------+ + | ^ | + +------------+ | | | + | | | | Redirect | + |redirect_uri|<-----+ +------------------+ + | | + +------------+ + + Implementations will need to define another handler on the MCP server return + flow to perform the second redirect, and generates and stores an authorization + code as part of completing the OAuth authorization step. Implementations SHOULD generate an authorization code with at least 160 bits of entropy, @@ -102,9 +126,9 @@ async def create_authorization_code( """ ... - async def load_authorization_code_metadata( + async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCodeMeta | None: + ) -> AuthorizationCode | None: """ Loads metadata for the authorization code challenge. @@ -118,8 +142,8 @@ async def load_authorization_code_metadata( ... async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> OAuthTokens: + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> TokenSuccessResponse: """ Exchanges an authorization code for an access token. @@ -137,7 +161,7 @@ async def exchange_refresh_token( client: OAuthClientInformationFull, refresh_token: str, scopes: Optional[List[str]] = None, - ) -> OAuthTokens: + ) -> TokenSuccessResponse: """ Exchanges a refresh token for an access token. @@ -178,3 +202,15 @@ async def revoke_token( request: The token revocation request. """ ... + +def construct_redirect_uri(redirect_uri_base: str, authorization_code: AuthorizationCode, state: Optional[str]) -> str: + parsed_uri = urlparse(redirect_uri_base) + query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] + query_params.append(("code", authorization_code.code)) + if state: + query_params.append(("state", state)) + + redirect_uri = urlunparse( + parsed_uri._replace(query=urlencode(query_params)) + ) + return redirect_uri \ No newline at end of file diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index a8f4acfa8b..9bcdaef150 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -9,7 +9,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field -class OAuthErrorResponse(BaseModel): +class TokenErrorResponse(BaseModel): """ See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 """ @@ -19,7 +19,7 @@ class OAuthErrorResponse(BaseModel): error_uri: Optional[AnyHttpUrl] = None -class OAuthTokens(BaseModel): +class TokenSuccessResponse(BaseModel): """ See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 """ diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 055be4fe1b..fbdded8752 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -19,11 +19,12 @@ from mcp.server.auth.errors import InvalidTokenError from mcp.server.auth.provider import ( - AuthorizationCodeMeta, + AuthorizationCode, AuthorizationParams, OAuthRegisteredClientsStore, OAuthServerProvider, OAuthTokenRevocationRequest, + construct_redirect_uri, ) from mcp.server.auth.router import ( ClientRegistrationOptions, @@ -34,7 +35,7 @@ from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, - OAuthTokens, + TokenSuccessResponse, ) from mcp.types import JSONRPCRequest @@ -68,33 +69,32 @@ def __init__(self): def clients_store(self) -> OAuthRegisteredClientsStore: return self.client_store - async def create_authorization_code( + async def authorize( self, client: OAuthClientInformationFull, params: AuthorizationParams ) -> str: - # Generate an authorization code - code = f"code_{int(time.time())}" - - # Store the code for later verification - self.auth_codes[code] = AuthorizationCodeMeta( + # toy authorize implementation which just immediately generates an authorization + # code and completes the redirect + code = AuthorizationCode( + code=f"code_{int(time.time())}", client_id= client.client_id, code_challenge= params.code_challenge, redirect_uri= params.redirect_uri, issued_at= time.time(), + scopes=params.scopes or ["read", "write"] ) + self.auth_codes[code.code] = code - return code + return construct_redirect_uri(str(params.redirect_uri), code, params.state) - async def load_authorization_code_metadata( + async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCodeMeta | None: + ) -> AuthorizationCode | None: return self.auth_codes.get(authorization_code) async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> OAuthTokens: - # Get the stored code info - code_info = self.auth_codes.get(authorization_code) - if not code_info: + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> TokenSuccessResponse: + if authorization_code.code not in self.auth_codes: raise InvalidTokenError("Invalid authorization code") # Generate an access token and refresh token @@ -104,16 +104,16 @@ async def exchange_authorization_code( # Store the tokens self.tokens[access_token] = { "client_id": client.client_id, - "scopes": ["read", "write"], + "scopes": authorization_code.scopes, "expires_at": int(time.time()) + 3600, } self.refresh_tokens[refresh_token] = access_token # Remove the used code - del self.auth_codes[authorization_code] + del self.auth_codes[authorization_code.code] - return OAuthTokens( + return TokenSuccessResponse( access_token=access_token, token_type="bearer", expires_in=3600, @@ -126,7 +126,7 @@ async def exchange_refresh_token( client: OAuthClientInformationFull, refresh_token: str, scopes: Optional[List[str]] = None, - ) -> OAuthTokens: + ) -> TokenSuccessResponse: # Check if refresh token exists if refresh_token not in self.refresh_tokens: raise InvalidTokenError("Invalid refresh token") @@ -160,7 +160,7 @@ async def exchange_refresh_token( del self.refresh_tokens[refresh_token] del self.tokens[old_access_token] - return OAuthTokens( + return TokenSuccessResponse( access_token=new_access_token, token_type="bearer", expires_in=3600, From e7c5f87fd30910e61fd0321a7f725d49eb782eba Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 17:29:34 -0700 Subject: [PATCH 15/60] Improve validation for /token --- src/mcp/server/auth/handlers/token.py | 35 +- src/mcp/server/auth/middleware/bearer_auth.py | 2 +- src/mcp/server/auth/provider.py | 17 +- src/mcp/server/auth/types.py | 4 - .../fastmcp/auth/test_auth_integration.py | 413 +++++++++++++++++- 5 files changed, 438 insertions(+), 33 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 712cf8e2f1..e258992da6 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -49,14 +49,18 @@ class TokenRequest(RootModel): Field(discriminator="grant_type"), ] -AUTH_CODE_TTL = 300 # seconds def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: def response(obj: TokenSuccessResponse | TokenErrorResponse): + status_code = 200 + if isinstance(obj, TokenErrorResponse): + status_code = 400 + return PydanticJSONResponse( content=obj, + status_code=status_code, headers={ "Cache-Control": "no-store", "Pragma": "no-cache", @@ -98,8 +102,7 @@ async def token_handler(request: Request): # make auth codes expire after a deadline # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 - expires_at = auth_code.issued_at + AUTH_CODE_TTL - if expires_at < time.time(): + if auth_code.expires_at < time.time(): return response(TokenErrorResponse( error="invalid_grant", error_description=f"authorization code has expired" @@ -130,12 +133,34 @@ async def token_handler(request: Request): ) case RefreshTokenRequest(): + refresh_token = await provider.load_refresh_token(client_info, token_request.refresh_token) + if refresh_token is None or refresh_token.client_id != token_request.client_id: + # if the authoriation code belongs to a different client, pretend it doesn't exist + return response(TokenErrorResponse( + error="invalid_grant", + error_description=f"refresh token does not exist" + )) + + if refresh_token.expires_at and refresh_token.expires_at < time.time(): + # if the authoriation code belongs to a different client, pretend it doesn't exist + return response(TokenErrorResponse( + error="invalid_grant", + error_description=f"refresh token has expired" + )) + # Parse scopes if provided - scopes = token_request.scope.split(" ") if token_request.scope else None + scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes + + for scope in scopes: + if scope not in refresh_token.scopes: + return response(TokenErrorResponse( + error="invalid_scope", + error_description=f"cannot request scope `{scope}` not provided by refresh token" + )) # Exchange refresh token for new tokens tokens = await provider.exchange_refresh_token( - client_info, token_request.refresh_token, scopes + client_info, refresh_token, scopes ) return response(tokens) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index bfa15996fa..796dba7046 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -25,7 +25,7 @@ class AuthenticatedUser(SimpleUser): """User with authentication info.""" def __init__(self, auth_info: AuthInfo): - super().__init__(auth_info.user_id or "anonymous") + super().__init__(auth_info.client_id) self.auth_info = auth_info self.scopes = auth_info.scopes diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index e4a159b200..fb354ef163 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -31,10 +31,18 @@ class AuthorizationParams(BaseModel): class AuthorizationCode(BaseModel): code: str scopes: list[str] - issued_at: float + expires_at: float client_id: str code_challenge: str redirect_uri: AnyHttpUrl + +class RefreshToken(BaseModel): + token: str + client_id: str + scopes: List[str] + expires_at: Optional[int] = None + + class OAuthTokenRevocationRequest(BaseModel): """ # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 @@ -156,11 +164,14 @@ async def exchange_authorization_code( """ ... + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + ... + async def exchange_refresh_token( self, client: OAuthClientInformationFull, - refresh_token: str, - scopes: Optional[List[str]] = None, + refresh_token: RefreshToken, + scopes: List[str], ) -> TokenSuccessResponse: """ Exchanges a refresh token for an access token. diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index 3edc4cb93c..f0593d8644 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -20,7 +20,3 @@ class AuthInfo(BaseModel): client_id: str scopes: List[str] expires_at: Optional[int] = None - user_id: Optional[str] = None - - class Config: - extra = "ignore" diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index fbdded8752..792394ffe6 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -7,6 +7,7 @@ import json import secrets import time +import unittest.mock from typing import List, Optional from urllib.parse import parse_qs, urlparse @@ -24,6 +25,7 @@ OAuthRegisteredClientsStore, OAuthServerProvider, OAuthTokenRevocationRequest, + RefreshToken, construct_redirect_uri, ) from mcp.server.auth.router import ( @@ -36,6 +38,7 @@ from mcp.shared.auth import ( OAuthClientInformationFull, TokenSuccessResponse, + TokenErrorResponse, ) from mcp.types import JSONRPCRequest @@ -79,7 +82,7 @@ async def authorize( client_id= client.client_id, code_challenge= params.code_challenge, redirect_uri= params.redirect_uri, - issued_at= time.time(), + expires_at=time.time() + 300, scopes=params.scopes or ["read", "write"] ) self.auth_codes[code.code] = code @@ -102,11 +105,12 @@ async def exchange_authorization_code( refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens - self.tokens[access_token] = { - "client_id": client.client_id, - "scopes": authorization_code.scopes, - "expires_at": int(time.time()) + 3600, - } + self.tokens[access_token] = AuthInfo( + token=access_token, + client_id= client.client_id, + scopes= authorization_code.scopes, + expires_at=int(time.time()) + 3600, + ) self.refresh_tokens[refresh_token] = access_token @@ -121,18 +125,35 @@ async def exchange_authorization_code( refresh_token=refresh_token, ) + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + old_access_token = self.refresh_tokens.get(refresh_token) + if old_access_token is None: + return None + token_info = self.tokens.get(old_access_token) + if token_info is None: + return None + + # Create a RefreshToken object that matches what is expected in later code + refresh_obj = RefreshToken( + token=refresh_token, + client_id=token_info.client_id, + scopes=token_info.scopes, + expires_at=token_info.expires_at, + ) + + return refresh_obj + async def exchange_refresh_token( self, client: OAuthClientInformationFull, - refresh_token: str, - scopes: Optional[List[str]] = None, + refresh_token: RefreshToken, + scopes: List[str], ) -> TokenSuccessResponse: # Check if refresh token exists - if refresh_token not in self.refresh_tokens: + if refresh_token.token not in self.refresh_tokens: raise InvalidTokenError("Invalid refresh token") - # Get the access token for this refresh token - old_access_token = self.refresh_tokens[refresh_token] + old_access_token = self.refresh_tokens[refresh_token.token] # Check if the access token exists if old_access_token not in self.tokens: @@ -140,7 +161,7 @@ async def exchange_refresh_token( # Check if the token was issued to this client token_info = self.tokens[old_access_token] - if token_info["client_id"] != client.client_id: + if token_info.client_id != client.client_id: raise InvalidTokenError("Refresh token was not issued to this client") # Generate a new access token and refresh token @@ -150,21 +171,21 @@ async def exchange_refresh_token( # Store the new tokens self.tokens[new_access_token] = { "client_id": client.client_id, - "scopes": scopes or token_info["scopes"], + "scopes": scopes or token_info.scopes, "expires_at": int(time.time()) + 3600, } self.refresh_tokens[new_refresh_token] = new_access_token # Remove the old tokens - del self.refresh_tokens[refresh_token] + del self.refresh_tokens[refresh_token.token] del self.tokens[old_access_token] return TokenSuccessResponse( access_token=new_access_token, token_type="bearer", expires_in=3600, - scope=" ".join(scopes) if scopes else " ".join(token_info["scopes"]), + scope=" ".join(scopes) if scopes else " ".join(token_info.scopes), refresh_token=new_refresh_token, ) @@ -177,14 +198,14 @@ async def verify_access_token(self, token: str) -> AuthInfo: token_info = self.tokens[token] # Check if token is expired - if token_info["expires_at"] < int(time.time()): + if token_info.expires_at < int(time.time()): raise InvalidTokenError("Access token has expired") return AuthInfo( token=token, - client_id=token_info["client_id"], - scopes=token_info["scopes"], - expires_at=token_info["expires_at"], + client_id=token_info.client_id, + scopes=token_info.scopes, + expires_at=token_info.expires_at, ) async def revoke_token( @@ -250,6 +271,119 @@ def test_client(auth_app) -> httpx.AsyncClient: ) +@pytest.fixture +async def registered_client(test_client: httpx.AsyncClient, request): + """Create and register a test client. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("registered_client", + [{"grant_types": ["authorization_code"]}], + indirect=True) + """ + # Default client metadata + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"], + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + client_metadata.update(request.param) + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201, f"Failed to register client: {response.content}" + + client_info = response.json() + return client_info + + +@pytest.fixture +def pkce_challenge(): + """Create a PKCE challenge with code_verifier and code_challenge.""" + code_verifier = "some_random_verifier_string" + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + return {"code_verifier": code_verifier, "code_challenge": code_challenge} + + +@pytest.fixture +async def auth_code(test_client, registered_client, pkce_challenge, request): + """Get an authorization code. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("auth_code", + [{"redirect_uri": "https://client.example.com/other-callback"}], + indirect=True) + """ + # Default authorize params + auth_params = { + "response_type": "code", + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + auth_params.update(request.param) + + response = await test_client.get("/authorize", params=auth_params) + assert response.status_code == 302, f"Failed to get auth code: {response.content}" + + # Extract the authorization code + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params, f"No code in response: {query_params}" + auth_code = query_params["code"][0] + + return { + "code": auth_code, + "redirect_uri": auth_params["redirect_uri"], + "state": query_params.get("state", [None])[0], + } + + +@pytest.fixture +async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): + """Exchange authorization code for tokens. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("tokens", + [{"code_verifier": "wrong_verifier"}], + indirect=True) + """ + # Default token request params + token_params = { + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + token_params.update(request.param) + + response = await test_client.post("/token", data=token_params) + + # Don't assert success here since some tests will intentionally cause errors + return { + "response": response, + "params": token_params, + } + + class TestAuthEndpoints: @pytest.mark.anyio async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): @@ -279,6 +413,245 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): "refresh_token", ] assert metadata["service_documentation"] == "https://docs.example.com" + + @pytest.mark.anyio + async def test_token_validation_error(self, test_client: httpx.AsyncClient): + """Test token endpoint error - validation error.""" + # Missing required fields + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + # Missing code, code_verifier, client_id, etc. + }, + ) + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert "error_description" in error_response # Contains validation error messages + + @pytest.mark.anyio + @pytest.mark.parametrize("registered_client", [{"grant_types": ["authorization_code"]}], indirect=True) + async def test_token_unsupported_grant_type(self, test_client, registered_client): + """Test token endpoint error - unsupported grant type.""" + # Try to use refresh_token grant type with a client that only supports authorization_code + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": "some_refresh_token", + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "unsupported_grant_type" + assert "supported grant types" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): + """Test token endpoint error - authorization code does not exist.""" + # Try to use a non-existent authorization code + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": "non_existent_auth_code", + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + print(f"Status code: {response.status_code}") + print(f"Response body: {response.content}") + print(f"Response JSON: {response.json()}") + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "authorization code does not exist" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_expired_auth_code( + self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + ): + """Test token endpoint error - authorization code has expired.""" + # Get the current time for our time mocking + current_time = time.time() + + # Find the auth code object + code_value = auth_code["code"] + found_code = None + for code_obj in mock_oauth_provider.auth_codes.values(): + if code_obj.code == code_value: + found_code = code_obj + break + + assert found_code is not None + + # Authorization codes are typically short-lived (5 minutes = 300 seconds) + # So we'll mock time to be 10 minutes (600 seconds) in the future + with unittest.mock.patch('time.time', return_value=current_time + 600): + # Try to use the expired authorization code + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": code_value, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "authorization code has expired" in error_response["error_description"] + + @pytest.mark.anyio + @pytest.mark.parametrize("registered_client", + [{"redirect_uris": ["https://client.example.com/callback", + "https://client.example.com/other-callback"]}], + indirect=True) + async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): + """Test token endpoint error - redirect URI mismatch.""" + # Try to use the code with a different redirect URI + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/other-callback", # Different from the one used in /authorize + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert "redirect_uri did not match" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): + """Test token endpoint error - PKCE code verifier mismatch.""" + # Try to use the code with an incorrect code verifier + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": "incorrect_code_verifier", # Different from the one used to create challenge + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "incorrect code_verifier" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_invalid_refresh_token(self, test_client, registered_client): + """Test token endpoint error - refresh token does not exist.""" + # Try to use a non-existent refresh token + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": "non_existent_refresh_token", + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "refresh token does not exist" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_expired_refresh_token( + self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + ): + """Test token endpoint error - refresh token has expired.""" + # Step 1: First, let's create a token and refresh token at the current time + current_time = time.time() + + # Exchange authorization code for tokens normally + token_response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert token_response.status_code == 200 + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Step 2: Now let's time travel forward 4 hours (tokens expire in 1 hour by default) + # Mock the time.time() function to return a value 4 hours in the future + with unittest.mock.patch('time.time', return_value=current_time + 14400): # 4 hours = 14400 seconds + # Try to use the refresh token which should now be considered expired + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": refresh_token, + }, + ) + + # In the "future", the token should be considered expired + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "refresh token has expired" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_invalid_scope( + self, test_client, registered_client, auth_code, pkce_challenge + ): + """Test token endpoint error - invalid scope in refresh token request.""" + # Exchange authorization code for tokens + token_response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert token_response.status_code == 200 + + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Try to use refresh token with an invalid scope + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": refresh_token, + "scope": "read write invalid_scope", # Adding an invalid scope + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_scope" + assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio async def test_client_registration( @@ -358,7 +731,7 @@ async def test_authorize_form_post( assert query_params["state"][0] == "test_form_state" @pytest.mark.anyio - async def test_authorization_flow( + async def test_authorization_get( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider ): """Test the full authorization flow.""" From 83c0c9f7b5a16e85fb814b3bda9675bd5b10399d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 17:50:13 -0700 Subject: [PATCH 16/60] Improve validation for registration --- src/mcp/server/auth/errors.py | 5 + src/mcp/server/auth/handlers/register.py | 134 ++++++++---------- src/mcp/server/auth/handlers/token.py | 3 +- src/mcp/server/auth/provider.py | 9 +- .../fastmcp/auth/test_auth_integration.py | 62 ++++++++ 5 files changed, 129 insertions(+), 84 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index badee09844..863a17b55e 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -6,6 +6,8 @@ from typing import Dict +from pydantic import ValidationError + class OAuthError(Exception): """ @@ -143,3 +145,6 @@ class InsufficientScopeError(OAuthError): """ error_code = "insufficient_scope" + +def stringify_pydantic_error(validation_error: ValidationError) -> str: + return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) \ No newline at end of file diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 0437a7abaa..4378dc9493 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -6,10 +6,10 @@ import secrets import time -from typing import Callable +from typing import Callable, Literal from uuid import uuid4 -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -17,92 +17,72 @@ InvalidRequestError, OAuthError, ServerError, + stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata +class ErrorResponse(BaseModel): + error: Literal["invalid_redirect_uri", "invalid_client_metadata", "invalid_software_statement", "unapproved_software_statement"] + error_description: str + def create_registration_handler( clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None ) -> Callable: - """ - Create a handler for OAuth 2.0 Dynamic Client Registration. - - Corresponds to clientRegistrationHandler in src/server/auth/handlers/register.ts - - Args: - clients_store: The store for registered clients - client_secret_expiry_seconds: Optional expiry time for client secrets - - Returns: - A Starlette endpoint handler function - """ - async def registration_handler(request: Request) -> Response: - """ - Handler for the OAuth 2.0 Dynamic Client Registration endpoint. - - Args: - request: The Starlette request - - Returns: - JSON response with client information or error - """ + # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: # Parse request body as JSON - try: - body = await request.json() - client_metadata = OAuthClientMetadata.model_validate(body) - except ValidationError as e: - raise InvalidRequestError(f"Invalid client metadata: {str(e)}") - - client_id = str(uuid4()) - client_secret = None - if client_metadata.token_endpoint_auth_method != "none": - # cryptographically secure random 32-byte hex string - client_secret = secrets.token_hex(32) - - client_id_issued_at = int(time.time()) - client_secret_expires_at = ( - client_id_issued_at + client_secret_expiry_seconds - if client_secret_expiry_seconds is not None - else None - ) - - client_info = OAuthClientInformationFull( - client_id=client_id, - client_id_issued_at=client_id_issued_at, - client_secret=client_secret, - client_secret_expires_at=client_secret_expires_at, - # passthrough information from the client request - redirect_uris=client_metadata.redirect_uris, - token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, - grant_types=client_metadata.grant_types, - response_types=client_metadata.response_types, - client_name=client_metadata.client_name, - client_uri=client_metadata.client_uri, - logo_uri=client_metadata.logo_uri, - scope=client_metadata.scope, - contacts=client_metadata.contacts, - tos_uri=client_metadata.tos_uri, - policy_uri=client_metadata.policy_uri, - jwks_uri=client_metadata.jwks_uri, - jwks=client_metadata.jwks, - software_id=client_metadata.software_id, - software_version=client_metadata.software_version, - ) - # Register client - client = await clients_store.register_client(client_info) - if not client: - raise ServerError("Failed to register client") - - # Return client information - return PydanticJSONResponse(content=client, status_code=201) - - except OAuthError as e: - # Handle OAuth errors - status_code = 500 if isinstance(e, ServerError) else 400 - return JSONResponse(status_code=status_code, content=e.to_response_object()) + body = await request.json() + client_metadata = OAuthClientMetadata.model_validate(body) + except ValidationError as validation_error: + return PydanticJSONResponse(content=ErrorResponse( + error="invalid_client_metadata", + error_description=stringify_pydantic_error(validation_error) + ), status_code=400) + raise InvalidRequestError(f"Invalid client metadata: {str(e)}") + + client_id = str(uuid4()) + client_secret = None + if client_metadata.token_endpoint_auth_method != "none": + # cryptographically secure random 32-byte hex string + client_secret = secrets.token_hex(32) + + client_id_issued_at = int(time.time()) + client_secret_expires_at = ( + client_id_issued_at + client_secret_expiry_seconds + if client_secret_expiry_seconds is not None + else None + ) + + client_info = OAuthClientInformationFull( + client_id=client_id, + client_id_issued_at=client_id_issued_at, + client_secret=client_secret, + client_secret_expires_at=client_secret_expires_at, + # passthrough information from the client request + redirect_uris=client_metadata.redirect_uris, + token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, + grant_types=client_metadata.grant_types, + response_types=client_metadata.response_types, + client_name=client_metadata.client_name, + client_uri=client_metadata.client_uri, + logo_uri=client_metadata.logo_uri, + scope=client_metadata.scope, + contacts=client_metadata.contacts, + tos_uri=client_metadata.tos_uri, + policy_uri=client_metadata.policy_uri, + jwks_uri=client_metadata.jwks_uri, + jwks=client_metadata.jwks, + software_id=client_metadata.software_id, + software_version=client_metadata.software_version, + ) + # Register client + client = await clients_store.register_client(client_info) + + # Return client information + return PydanticJSONResponse(content=client, status_code=201) return registration_handler diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index e258992da6..0c8efe9292 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -14,6 +14,7 @@ from mcp.server.auth.errors import ( InvalidRequestError, + stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( @@ -74,7 +75,7 @@ async def token_handler(request: Request): except ValidationError as validation_error: return response(TokenErrorResponse( error="invalid_request", - error_description="\n".join(e['msg'] for e in validation_error.errors()) + error_description=stringify_pydantic_error(validation_error) )) client_info = await client_authenticator(token_request) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index fb354ef163..c15c1540ca 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -72,15 +72,12 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul async def register_client( self, client_info: OAuthClientInformationFull - ) -> Optional[OAuthClientInformationFull]: + ) -> None: """ - Registers a new client and returns client information. + Registers a new client Args: - metadata: The client metadata to register. - - Returns: - The client information, or None if registration failed. + client_info: The client metadata to register. """ ... diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 792394ffe6..8243ad7543 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -681,6 +681,68 @@ async def test_client_registration( # client_info["client_id"] # ) is not None + @pytest.mark.anyio + async def test_client_registration_missing_required_fields( + self, test_client: httpx.AsyncClient + ): + """Test client registration with missing required fields.""" + # Missing redirect_uris which is a required field + client_metadata = { + "client_name": "Test Client", + "client_uri": "https://client.example.com", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert error_data["error_description"] == "redirect_uris: Field required" + + @pytest.mark.anyio + async def test_client_registration_invalid_uri( + self, test_client: httpx.AsyncClient + ): + """Test client registration with invalid URIs.""" + # Invalid redirect_uri format + client_metadata = { + "redirect_uris": ["not-a-valid-uri"], + "client_name": "Test Client", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert error_data["error_description"] == "redirect_uris.0: Input should be a valid URL, relative URL without a base" + + @pytest.mark.anyio + async def test_client_registration_empty_redirect_uris( + self, test_client: httpx.AsyncClient + ): + """Test client registration with empty redirect_uris array.""" + client_metadata = { + "redirect_uris": [], # Empty array + "client_name": "Test Client", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" + @pytest.mark.anyio async def test_authorize_form_post( self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider From 0c1aae97c7de1ea09ba33b356d3b1d41fe9a010b Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 21:38:57 -0700 Subject: [PATCH 17/60] Improve /authorize validation & add tests --- src/mcp/server/auth/handlers/authorize.py | 218 ++++++++---- src/mcp/server/auth/provider.py | 8 +- .../fastmcp/auth/test_auth_integration.py | 321 +++++++++++++++--- 3 files changed, 443 insertions(+), 104 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 4d5c7d4572..9d0b3c1d3a 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -4,10 +4,11 @@ Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts """ -from typing import Callable, Literal, Optional +from typing import Callable, Literal, Optional, Union from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, ValidationError +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from starlette.datastructures import FormData, QueryParams from starlette.requests import Request from starlette.responses import RedirectResponse, Response @@ -15,9 +16,11 @@ InvalidClientError, InvalidRequestError, OAuthError, + stringify_pydantic_error, ) -from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider +from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri from mcp.shared.auth import OAuthClientInformationFull +from mcp.server.auth.json_response import PydanticJSONResponse import logging @@ -25,9 +28,10 @@ class AuthorizationRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 client_id: str = Field(..., description="The client ID") redirect_uri: AnyHttpUrl | None = Field( - ..., description="URL to redirect to after authorization" + None, description="URL to redirect to after authorization" ) # see OAuthClientMetadata; we only support `code` @@ -61,71 +65,160 @@ def validate_scope( def validate_redirect_uri( - auth_request: AuthorizationRequest, client: OAuthClientInformationFull + redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull ) -> AnyHttpUrl: - if auth_request.redirect_uri is not None: + if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs - if auth_request.redirect_uri not in client.redirect_uris: + if redirect_uri not in client.redirect_uris: raise InvalidRequestError( - f"Redirect URI '{auth_request.redirect_uri}' not registered for client" + f"Redirect URI '{redirect_uri}' not registered for client" ) - return auth_request.redirect_uri + return redirect_uri elif len(client.redirect_uris) == 1: return client.redirect_uris[0] else: raise InvalidRequestError( "redirect_uri must be specified when client has multiple registered URIs" ) +ErrorCode = Literal[ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable" + ] +class ErrorResponse(BaseModel): + error: ErrorCode + error_description: str + error_uri: Optional[AnyUrl] = None + # must be set if provided in the request + state: Optional[str] + +def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]: + if params is None: + return None + value = params.get(key) + if isinstance(value, str): + return value + return None + +class AnyHttpUrlModel(RootModel): + root: AnyHttpUrl def create_authorization_handler(provider: OAuthServerProvider) -> Callable: - """ - Create a handler for the OAuth 2.0 Authorization endpoint. + async def authorization_handler(request: Request) -> Response: + # implements authorization requests for grant_type=code; + # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 - Corresponds to authorizationHandler in src/server/auth/handlers/authorize.ts + state = None + redirect_uri = None + client = None + params = None - """ + async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True): + nonlocal client, redirect_uri, state + if client is None and attempt_load_client: + # make last-ditch attempt to load the client + client_id = best_effort_extract_string("client_id", params) + client = client_id and await provider.clients_store.get_client(client_id) + if redirect_uri is None and client: + # make last-ditch effort to load the redirect uri + if params is not None and "redirect_uri" not in params: + raw_redirect_uri = None + else: + raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root + try: + redirect_uri = validate_redirect_uri(raw_redirect_uri, client) + except (ValidationError, InvalidRequestError): + pass + if state is None: + # make last-ditch effort to load state + state = best_effort_extract_string("state", params) - async def authorization_handler(request: Request) -> Response: - """ - Handler for the OAuth 2.0 Authorization endpoint. - """ - # Validate request parameters + error_resp = ErrorResponse( + error=error, + error_description=error_description, + state=state, + ) + + if redirect_uri and client: + return RedirectResponse( + url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + else: + return PydanticJSONResponse( + status_code=400, + content=error_resp, + headers={"Cache-Control": "no-store"}, + ) + try: + # Parse request parameters if request.method == "GET": # Convert query_params to dict for pydantic validation - params = dict(request.query_params) - auth_request = AuthorizationRequest.model_validate(params) + params = request.query_params else: # Parse form data for POST requests - form_data = await request.form() - params = dict(form_data) + params = await request.form() + + # Save state if it exists, even before validation + state = best_effort_extract_string("state", params) + + try: auth_request = AuthorizationRequest.model_validate(params) - except ValidationError as e: - raise InvalidRequestError(str(e)) + state = auth_request.state # Update with validated state + except ValidationError as validation_error: + error: ErrorCode = "invalid_request" + for e in validation_error.errors(): + if e['loc'] == ('response_type',) and e['type'] == 'literal_error': + error = "unsupported_response_type" + break + return await error_response(error, stringify_pydantic_error(validation_error)) - # Get client information - try: + # Get client information client = await provider.clients_store.get_client(auth_request.client_id) - except OAuthError as e: - # TODO: proper error rendering - raise InvalidClientError(str(e)) - - if not client: - raise InvalidClientError(f"Client ID '{auth_request.client_id}' not found") - - # do validation which is dependent on the client configuration - redirect_uri = validate_redirect_uri(auth_request, client) - scopes = validate_scope(auth_request.scope, client) - - auth_params = AuthorizationParams( - state=auth_request.state, - scopes=scopes, - code_challenge=auth_request.code_challenge, - redirect_uri=redirect_uri, - ) + if not client: + # For client_id validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=f"Client ID '{auth_request.client_id}' not found", + attempt_load_client=False, + ) - try: + + # Validate redirect_uri against client's registered URIs + try: + redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client) + except InvalidRequestError as validation_error: + # For redirect_uri validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=validation_error.message, + ) + + # Validate scope - for scope errors, we can redirect + try: + scopes = validate_scope(auth_request.scope, client) + except InvalidRequestError as validation_error: + # For scope errors, redirect with error parameters + return await error_response( + error="invalid_scope", + error_description=validation_error.message, + ) + + # Setup authorization parameters + auth_params = AuthorizationParams( + state=state, + scopes=scopes, + code_challenge=auth_request.code_challenge, + redirect_uri=redirect_uri, + ) + # Let the provider pick the next URI to redirect to response = RedirectResponse( url="", status_code=302, headers={"Cache-Control": "no-store"} @@ -133,36 +226,39 @@ async def authorization_handler(request: Request) -> Response: response.headers["location"] = await provider.authorize( client, auth_params ) - return response - except Exception as e: - logger.exception("error from authorize()", exc_info=e) - - return RedirectResponse( - url=create_error_redirect(redirect_uri, e, auth_request.state), - status_code=302, - headers={"Cache-Control": "no-store"}, - ) + + except Exception as validation_error: + # Catch-all for unexpected errors + logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) + return await error_response(error="server_error", error_description="An unexpected error occurred") return authorization_handler def create_error_redirect( - redirect_uri: AnyUrl, error: Exception, state: Optional[str] + redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse] ) -> str: parsed_uri = urlparse(str(redirect_uri)) - if isinstance(error, OAuthError): + + if isinstance(error, ErrorResponse): + # Convert ErrorResponse to dict + error_dict = error.model_dump(exclude_none=True) + query_params = {} + for key, value in error_dict.items(): + if value is not None: + if key == "error_uri" and hasattr(value, "__str__"): + query_params[key] = str(value) + else: + query_params[key] = value + + elif isinstance(error, OAuthError): query_params = {"error": error.error_code, "error_description": str(error)} else: query_params = { - "error": "internal_error", + "error": "server_error", "error_description": "An unknown error occurred", } - # TODO: should we add error_uri? - # if error.error_uri: - # query_params["error_uri"] = str(error.error_uri) - if state: - query_params["state"] = state new_query = urlencode(query_params) if parsed_uri.query: diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index c15c1540ca..24109bda3a 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -211,12 +211,12 @@ async def revoke_token( """ ... -def construct_redirect_uri(redirect_uri_base: str, authorization_code: AuthorizationCode, state: Optional[str]) -> str: +def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: parsed_uri = urlparse(redirect_uri_base) query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] - query_params.append(("code", authorization_code.code)) - if state: - query_params.append(("state", state)) + for k, v in params.items(): + if v is not None: + query_params.append((k, v)) redirect_uri = urlunparse( parsed_uri._replace(query=urlencode(query_params)) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 8243ad7543..49a586d838 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -87,7 +87,7 @@ async def authorize( ) self.auth_codes[code.code] = code - return construct_redirect_uri(str(params.redirect_uri), code, params.state) + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -745,7 +745,7 @@ async def test_client_registration_empty_redirect_uris( @pytest.mark.anyio async def test_authorize_form_post( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge ): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client @@ -762,14 +762,6 @@ async def test_authorize_form_post( assert response.status_code == 201 client_info = response.json() - # Create a PKCE challenge - code_verifier = "some_random_verifier_string" - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) - # Use POST with form-encoded data for authorization response = await test_client.post( "/authorize", @@ -777,7 +769,7 @@ async def test_authorize_form_post( "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", - "code_challenge": code_challenge, + "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_form_state", }, @@ -794,7 +786,7 @@ async def test_authorize_form_post( @pytest.mark.anyio async def test_authorization_get( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge ): """Test the full authorization flow.""" # 1. Register a client @@ -811,29 +803,21 @@ async def test_authorization_get( assert response.status_code == 201 client_info = response.json() - # 2. Create a PKCE challenge - code_verifier = "some_random_verifier_string" - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) - - # 3. Request authorization using GET with query params + # 2. Request authorization using GET with query params response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", - "code_challenge": code_challenge, + "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) assert response.status_code == 302 - # 4. Extract the authorization code from the redirect URL + # 3. Extract the authorization code from the redirect URL redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) @@ -842,7 +826,7 @@ async def test_authorization_get( assert query_params["state"][0] == "test_state" auth_code = query_params["code"][0] - # 5. Exchange the authorization code for tokens + # 4. Exchange the authorization code for tokens response = await test_client.post( "/token", data={ @@ -850,7 +834,7 @@ async def test_authorization_get( "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "code": auth_code, - "code_verifier": code_verifier, + "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) @@ -863,7 +847,7 @@ async def test_authorization_get( assert "expires_in" in token_response assert token_response["token_type"] == "bearer" - # 6. Verify the access token + # 5. Verify the access token access_token = token_response["access_token"] refresh_token = token_response["refresh_token"] @@ -873,7 +857,7 @@ async def test_authorization_get( assert "read" in auth_info.scopes assert "write" in auth_info.scopes - # 7. Refresh the token + # 6. Refresh the token response = await test_client.post( "/token", data={ @@ -892,7 +876,7 @@ async def test_authorization_get( assert new_token_response["access_token"] != access_token assert new_token_response["refresh_token"] != refresh_token - # 8. Revoke the token + # 7. Revoke the token response = await test_client.post( "/revoke", data={ @@ -914,7 +898,7 @@ class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" @pytest.mark.anyio - async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider): + async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider, pkce_challenge): """Test creating a FastMCP server with authentication.""" # Create FastMCP server with auth provider mcp = FastMCP( @@ -963,14 +947,6 @@ def test_tool(x: int) -> str: assert response.status_code == 201 client_info = response.json() - # Create a PKCE challenge - code_verifier = "some_random_verifier_string" - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) - # Request authorization using POST with form-encoded data response = await test_client.post( "/authorize", @@ -978,7 +954,7 @@ def test_tool(x: int) -> str: "response_type": "code", "client_id": client_info["client_id"], "redirect_uri": "https://client.example.com/callback", - "code_challenge": code_challenge, + "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, @@ -1001,7 +977,7 @@ def test_tool(x: int) -> str: "client_id": client_info["client_id"], "client_secret": client_info["client_secret"], "code": auth_code, - "code_verifier": code_verifier, + "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": "https://client.example.com/callback", }, ) @@ -1051,3 +1027,270 @@ def test_tool(x: int) -> str: assert set(sse_data["result"]["capabilities"].keys()) == set( ("experimental", "prompts", "resources", "tools") ) + + +class TestAuthorizeEndpointErrors: + """Test error handling in the OAuth authorization endpoint.""" + + @pytest.mark.anyio + async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + """Test authorization endpoint with missing client_id. + + According to the OAuth2.0 spec, if client_id is missing, the server should + inform the resource owner and NOT redirect. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + # Missing client_id + "redirect_uri": "https://client.example.com/callback", + "state": "test_state", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256" + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400 + # The response should include an error message about missing client_id + assert "client_id" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + """Test authorization endpoint with invalid client_id. + + According to the OAuth2.0 spec, if client_id is invalid, the server should + inform the resource owner and NOT redirect. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": "invalid_client_id_that_does_not_exist", + "redirect_uri": "https://client.example.com/callback", + "state": "test_state", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256" + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400 + # The response should include an error message about invalid client_id + assert "client" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_missing_redirect_uri( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with missing redirect_uri. + + If client has only one registered redirect_uri, it can be omitted. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing redirect_uri + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect to the registered redirect_uri + assert response.status_code == 302, response.content + redirect_url = response.headers["location"] + assert redirect_url.startswith("https://client.example.com/callback") + + @pytest.mark.anyio + async def test_authorize_invalid_redirect_uri( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with invalid redirect_uri. + + According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, + the server should inform the resource owner and NOT redirect. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + "redirect_uri": "https://attacker.example.com/callback", # Non-matching URI + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400, response.content + # The response should include an error message about redirect_uri mismatch + assert "redirect" in response.text.lower() + + @pytest.mark.anyio + @pytest.mark.parametrize("registered_client", + [{"redirect_uris": ["https://client.example.com/callback", + "https://client.example.com/other-callback"]}], + indirect=True) + async def test_authorize_missing_redirect_uri_multiple_registered( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with missing redirect_uri when client has multiple registered URIs. + + If client has multiple registered redirect_uris, redirect_uri must be provided. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing redirect_uri + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should NOT redirect, should return a 400 error + assert response.status_code == 400 + # The response should include an error message about missing redirect_uri + assert "redirect_uri" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_unsupported_response_type( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with unsupported response_type. + + According to the OAuth2.0 spec, for other errors like unsupported_response_type, + the server should redirect with error parameters. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "token", # Unsupported (we only support "code") + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "unsupported_response_type" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_missing_response_type( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with missing response_type. + + Missing required parameter should result in invalid_request error. + """ + + response = await test_client.get( + "/authorize", + params={ + # Missing response_type + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_request" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_missing_pkce_challenge( + self, test_client: httpx.AsyncClient, registered_client + ): + """Test authorization endpoint with missing PKCE code_challenge. + + Missing PKCE parameters should result in invalid_request error. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing code_challenge + "state": "test_state", + # using default URL + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_request" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_invalid_scope( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with invalid scope. + + Invalid scope should redirect with invalid_scope error. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "scope": "invalid_scope_that_does_not_exist", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_scope" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" From 038fb045f076b6aa8b151febdd7f93f827834316 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 21:46:43 -0700 Subject: [PATCH 18/60] Hoist oauth token expiration check into bearer auth middleware --- src/mcp/server/auth/middleware/bearer_auth.py | 5 +++- src/mcp/server/auth/provider.py | 6 ++--- .../fastmcp/auth/test_auth_integration.py | 25 ++++++++----------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 796dba7046..b89d7eca3d 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -53,7 +53,10 @@ async def authenticate(self, conn: HTTPConnection): try: # Validate the token with the provider - auth_info = await self.provider.verify_access_token(token) + auth_info = await self.provider.load_access_token(token) + + if not auth_info: + raise InvalidTokenError("Invalid access token") if auth_info.expires_at and auth_info.expires_at < int(time.time()): raise InvalidTokenError("Token has expired") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 24109bda3a..3013ae4397 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -183,9 +183,7 @@ async def exchange_refresh_token( """ ... - # TODO: consider methods to generate refresh tokens and access tokens - - async def verify_access_token(self, token: str) -> AuthInfo: + async def load_access_token(self, token: str) -> AuthInfo | None: """ Verifies an access token and returns information about it. @@ -193,7 +191,7 @@ async def verify_access_token(self, token: str) -> AuthInfo: token: The access token to verify. Returns: - Information about the verified token. + Information about the verified token, or None if the token is invalid. """ ... diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 49a586d838..a4e82b4d89 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -189,19 +189,14 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) - async def verify_access_token(self, token: str) -> AuthInfo: - # Check if token exists - if token not in self.tokens: - raise InvalidTokenError("Invalid access token") - - # Get token info - token_info = self.tokens[token] + async def load_access_token(self, token: str) -> AuthInfo | None: + token_info = self.tokens.get(token) # Check if token is expired - if token_info.expires_at < int(time.time()): - raise InvalidTokenError("Access token has expired") + # if token_info.expires_at < int(time.time()): + # raise InvalidTokenError("Access token has expired") - return AuthInfo( + return token_info and AuthInfo( token=token, client_id=token_info.client_id, scopes=token_info.scopes, @@ -852,7 +847,8 @@ async def test_authorization_get( refresh_token = token_response["refresh_token"] # Create a test client with the token - auth_info = await mock_oauth_provider.verify_access_token(access_token) + auth_info = await mock_oauth_provider.load_access_token(access_token) + assert auth_info assert auth_info.client_id == client_info["client_id"] assert "read" in auth_info.scopes assert "write" in auth_info.scopes @@ -888,10 +884,9 @@ async def test_authorization_get( assert response.status_code == 200 # Verify that the token was revoked - with pytest.raises(InvalidTokenError): - await mock_oauth_provider.verify_access_token( - new_token_response["access_token"] - ) + assert await mock_oauth_provider.load_access_token( + new_token_response["access_token"] + ) is None class TestFastMCPWithAuth: From a4e17f3f13446e26fe1f6dc3dd7688e72a0c5bb5 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 10 Mar 2025 21:56:48 -0700 Subject: [PATCH 19/60] Add tests for /revoke validation --- src/mcp/server/auth/handlers/revoke.py | 9 +++++- .../fastmcp/auth/test_auth_integration.py | 31 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 33d5e1af7c..3ede08c1f1 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,6 +4,7 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ +from tokenize import Token from typing import Callable from pydantic import ValidationError @@ -12,12 +13,15 @@ from mcp.server.auth.errors import ( InvalidRequestError, + stringify_pydantic_error, ) from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, ClientAuthRequest, ) from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.shared.auth import TokenErrorResponse class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): @@ -35,7 +39,10 @@ async def revocation_handler(request: Request) -> Response: form_data = await request.form() revocation_request = RevocationRequest.model_validate(dict(form_data)) except ValidationError as e: - raise InvalidRequestError(f"Invalid request body: {e}") + return PydanticJSONResponse(status_code=400,content=TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(e) + )) # Authenticate client client_auth_result = await client_authenticator(revocation_request) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index a4e82b4d89..785c5a7ad4 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -887,6 +887,37 @@ async def test_authorization_get( assert await mock_oauth_provider.load_access_token( new_token_response["access_token"] ) is None + @pytest.mark.anyio + async def test_revoke_invalid_token(self, test_client, registered_client): + """Test revoking an invalid token.""" + response = await test_client.post( + "/revoke", + data={ + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "token": "invalid_token", + }, + ) + # per RFC, this should return 200 even if the token is invalid + assert response.status_code == 200 + @pytest.mark.anyio + async def test_revoke_with_malformed_token(self, test_client, registered_client): + response = await test_client.post( + "/revoke", + data={ + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "token": 123, + "token_type_hint": "asdf" + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert "token_type_hint" in error_response["error_description"] + + + class TestFastMCPWithAuth: From 5f11c601f4743e2fed2dd1dd67671ad0edbbca22 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 07:04:02 -0700 Subject: [PATCH 20/60] Lint + typecheck --- src/mcp/server/auth/errors.py | 6 +- src/mcp/server/auth/handlers/authorize.py | 94 +++-- src/mcp/server/auth/handlers/register.py | 29 +- src/mcp/server/auth/handlers/revoke.py | 15 +- src/mcp/server/auth/handlers/token.py | 127 ++++--- src/mcp/server/auth/middleware/client_auth.py | 54 +-- src/mcp/server/auth/provider.py | 23 +- src/mcp/shared/auth.py | 12 +- .../fastmcp/auth/test_auth_integration.py | 340 +++++++++++------- 9 files changed, 392 insertions(+), 308 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 863a17b55e..cc92b33894 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -146,5 +146,9 @@ class InsufficientScopeError(OAuthError): error_code = "insufficient_scope" + def stringify_pydantic_error(validation_error: ValidationError) -> str: - return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) \ No newline at end of file + return "\n".join( + f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" + for e in validation_error.errors() + ) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 9d0b3c1d3a..31a9eee212 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -4,8 +4,9 @@ Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts """ +import logging from typing import Callable, Literal, Optional, Union -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +from urllib.parse import urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.datastructures import FormData, QueryParams @@ -13,16 +14,17 @@ from starlette.responses import RedirectResponse, Response from mcp.server.auth.errors import ( - InvalidClientError, InvalidRequestError, OAuthError, stringify_pydantic_error, ) -from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider, construct_redirect_uri -from mcp.shared.auth import OAuthClientInformationFull from mcp.server.auth.json_response import PydanticJSONResponse - -import logging +from mcp.server.auth.provider import ( + AuthorizationParams, + OAuthServerProvider, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull logger = logging.getLogger(__name__) @@ -48,7 +50,6 @@ class AuthorizationRequest(BaseModel): description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) - def validate_scope( @@ -80,15 +81,19 @@ def validate_redirect_uri( raise InvalidRequestError( "redirect_uri must be specified when client has multiple registered URIs" ) + + ErrorCode = Literal[ - "invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", - "server_error", - "temporarily_unavailable" - ] + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", +] + + class ErrorResponse(BaseModel): error: ErrorCode error_description: str @@ -96,7 +101,10 @@ class ErrorResponse(BaseModel): # must be set if provided in the request state: Optional[str] -def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> Optional[str]: + +def best_effort_extract_string( + key: str, params: None | FormData | QueryParams +) -> Optional[str]: if params is None: return None value = params.get(key) @@ -104,6 +112,7 @@ def best_effort_extract_string(key: str, params: None | FormData | QueryParams) return value return None + class AnyHttpUrlModel(RootModel): root: AnyHttpUrl @@ -118,18 +127,24 @@ async def authorization_handler(request: Request) -> Response: client = None params = None - async def error_response(error: ErrorCode, error_description: str, attempt_load_client: bool = True): + async def error_response( + error: ErrorCode, error_description: str, attempt_load_client: bool = True + ): nonlocal client, redirect_uri, state if client is None and attempt_load_client: # make last-ditch attempt to load the client client_id = best_effort_extract_string("client_id", params) - client = client_id and await provider.clients_store.get_client(client_id) + client = client_id and await provider.clients_store.get_client( + client_id + ) if redirect_uri is None and client: # make last-ditch effort to load the redirect uri if params is not None and "redirect_uri" not in params: raw_redirect_uri = None else: - raw_redirect_uri = AnyHttpUrlModel.model_validate(best_effort_extract_string("redirect_uri", params)).root + raw_redirect_uri = AnyHttpUrlModel.model_validate( + best_effort_extract_string("redirect_uri", params) + ).root try: redirect_uri = validate_redirect_uri(raw_redirect_uri, client) except (ValidationError, InvalidRequestError): @@ -146,7 +161,9 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ if redirect_uri and client: return RedirectResponse( - url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), + url=construct_redirect_uri( + str(redirect_uri), **error_resp.model_dump(exclude_none=True) + ), status_code=302, headers={"Cache-Control": "no-store"}, ) @@ -156,7 +173,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ content=error_resp, headers={"Cache-Control": "no-store"}, ) - + try: # Parse request parameters if request.method == "GET": @@ -165,20 +182,22 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ else: # Parse form data for POST requests params = await request.form() - + # Save state if it exists, even before validation state = best_effort_extract_string("state", params) - + try: auth_request = AuthorizationRequest.model_validate(params) state = auth_request.state # Update with validated state except ValidationError as validation_error: error: ErrorCode = "invalid_request" for e in validation_error.errors(): - if e['loc'] == ('response_type',) and e['type'] == 'literal_error': + if e["loc"] == ("response_type",) and e["type"] == "literal_error": error = "unsupported_response_type" break - return await error_response(error, stringify_pydantic_error(validation_error)) + return await error_response( + error, stringify_pydantic_error(validation_error) + ) # Get client information client = await provider.clients_store.get_client(auth_request.client_id) @@ -190,7 +209,6 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ attempt_load_client=False, ) - # Validate redirect_uri against client's registered URIs try: redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client) @@ -200,7 +218,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ error="invalid_request", error_description=validation_error.message, ) - + # Validate scope - for scope errors, we can redirect try: scopes = validate_scope(auth_request.scope, client) @@ -210,7 +228,7 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ error="invalid_scope", error_description=validation_error.message, ) - + # Setup authorization parameters auth_params = AuthorizationParams( state=state, @@ -218,20 +236,22 @@ async def error_response(error: ErrorCode, error_description: str, attempt_load_ code_challenge=auth_request.code_challenge, redirect_uri=redirect_uri, ) - + # Let the provider pick the next URI to redirect to response = RedirectResponse( url="", status_code=302, headers={"Cache-Control": "no-store"} ) - response.headers["location"] = await provider.authorize( - client, auth_params - ) + response.headers["location"] = await provider.authorize(client, auth_params) return response - + except Exception as validation_error: # Catch-all for unexpected errors - logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) - return await error_response(error="server_error", error_description="An unexpected error occurred") + logger.exception( + "Unexpected error in authorization_handler", exc_info=validation_error + ) + return await error_response( + error="server_error", error_description="An unexpected error occurred" + ) return authorization_handler @@ -240,7 +260,7 @@ def create_error_redirect( redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse] ) -> str: parsed_uri = urlparse(str(redirect_uri)) - + if isinstance(error, ErrorResponse): # Convert ErrorResponse to dict error_dict = error.model_dump(exclude_none=True) @@ -251,7 +271,7 @@ def create_error_redirect( query_params[key] = str(value) else: query_params[key] = value - + elif isinstance(error, OAuthError): query_params = {"error": error.error_code, "error_description": str(error)} else: diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 4378dc9493..f9e814f6de 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -11,20 +11,21 @@ from pydantic import BaseModel, ValidationError from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import Response -from mcp.server.auth.errors import ( - InvalidRequestError, - OAuthError, - ServerError, - stringify_pydantic_error, -) +from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + class ErrorResponse(BaseModel): - error: Literal["invalid_redirect_uri", "invalid_client_metadata", "invalid_software_statement", "unapproved_software_statement"] + error: Literal[ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_software_statement", + "unapproved_software_statement", + ] error_description: str @@ -38,11 +39,13 @@ async def registration_handler(request: Request) -> Response: body = await request.json() client_metadata = OAuthClientMetadata.model_validate(body) except ValidationError as validation_error: - return PydanticJSONResponse(content=ErrorResponse( - error="invalid_client_metadata", - error_description=stringify_pydantic_error(validation_error) - ), status_code=400) - raise InvalidRequestError(f"Invalid client metadata: {str(e)}") + return PydanticJSONResponse( + content=ErrorResponse( + error="invalid_client_metadata", + error_description=stringify_pydantic_error(validation_error), + ), + status_code=400, + ) client_id = str(uuid4()) client_secret = None diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 3ede08c1f1..1863685fc4 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,7 +4,6 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from tokenize import Token from typing import Callable from pydantic import ValidationError @@ -12,15 +11,14 @@ from starlette.responses import Response from mcp.server.auth.errors import ( - InvalidRequestError, stringify_pydantic_error, ) +from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, ClientAuthRequest, ) from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest -from mcp.server.auth.json_response import PydanticJSONResponse from mcp.shared.auth import TokenErrorResponse @@ -39,10 +37,13 @@ async def revocation_handler(request: Request) -> Response: form_data = await request.form() revocation_request = RevocationRequest.model_validate(dict(form_data)) except ValidationError as e: - return PydanticJSONResponse(status_code=400,content=TokenErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(e) - )) + return PydanticJSONResponse( + status_code=400, + content=TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(e), + ), + ) # Authenticate client client_auth_result = await client_authenticator(revocation_request) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0c8efe9292..c6dbcd0bb6 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -13,7 +13,6 @@ from starlette.requests import Request from mcp.server.auth.errors import ( - InvalidRequestError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse @@ -58,7 +57,7 @@ def response(obj: TokenSuccessResponse | TokenErrorResponse): status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 - + return PydanticJSONResponse( content=obj, status_code=status_code, @@ -73,19 +72,24 @@ async def token_handler(request: Request): form_data = await request.form() token_request = TokenRequest.model_validate(dict(form_data)).root except ValidationError as validation_error: - return response(TokenErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(validation_error) - - )) + return response( + TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(validation_error), + ) + ) client_info = await client_authenticator(token_request) if token_request.grant_type not in client_info.grant_types: - return response(TokenErrorResponse( - error="unsupported_grant_type", - error_description=f"Unsupported grant type (supported grant types are " - f"{client_info.grant_types})" - )) + return response( + TokenErrorResponse( + error="unsupported_grant_type", + error_description=( + f"Unsupported grant type (supported grant types are " + f"{client_info.grant_types})" + ), + ) + ) tokens: TokenSuccessResponse @@ -95,38 +99,50 @@ async def token_handler(request: Request): client_info, token_request.code ) if auth_code is None or auth_code.client_id != token_request.client_id: - # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"authorization code does not exist" - )) + # if code belongs to different client, pretend it doesn't exist + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code does not exist", + ) + ) # make auth codes expire after a deadline # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 if auth_code.expires_at < time.time(): - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"authorization code has expired" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code has expired", + ) + ) # verify redirect_uri doesn't change between /authorize and /tokens # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 if token_request.redirect_uri != auth_code.redirect_uri: - return response(TokenErrorResponse( - error="invalid_request", - error_description=f"redirect_uri did not match redirect_uri used when authorization code was created" - )) + return response( + TokenErrorResponse( + error="invalid_request", + error_description=( + "redirect_uri didn't match the one used when creating auth code" + ), + ) + ) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + hashed_code_verifier = ( + base64.urlsafe_b64encode(sha256).decode().rstrip("=") + ) if hashed_code_verifier != auth_code.code_challenge: # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"incorrect code_verifier" - )) + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="incorrect code_verifier", + ) + ) # Exchange authorization code for tokens tokens = await provider.exchange_authorization_code( @@ -134,30 +150,47 @@ async def token_handler(request: Request): ) case RefreshTokenRequest(): - refresh_token = await provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"refresh token does not exist" - )) + refresh_token = await provider.load_refresh_token( + client_info, token_request.refresh_token + ) + if ( + refresh_token is None + or refresh_token.client_id != token_request.client_id + ): + # if token belongs to different client, pretend it doesn't exist + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token does not exist", + ) + ) if refresh_token.expires_at and refresh_token.expires_at < time.time(): - # if the authoriation code belongs to a different client, pretend it doesn't exist - return response(TokenErrorResponse( - error="invalid_grant", - error_description=f"refresh token has expired" - )) + # if the refresh token has expired, pretend it doesn't exist + return response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token has expired", + ) + ) # Parse scopes if provided - scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else refresh_token.scopes + ) for scope in scopes: if scope not in refresh_token.scopes: - return response(TokenErrorResponse( - error="invalid_scope", - error_description=f"cannot request scope `{scope}` not provided by refresh token" - )) + return response( + TokenErrorResponse( + error="invalid_scope", + error_description=( + f"cannot request scope `{scope}` not provided by refresh token" + ), + ) + ) # Exchange refresh token for new tokens tokens = await provider.exchange_refresh_token( diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index f24aefca28..df4732de3a 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -72,56 +72,4 @@ async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFu ): raise InvalidClientError("Client secret has expired") - return client - - -class ClientAuthMiddleware: - """ - Middleware that authenticates clients using client_id and client_secret. - - This middleware will validate client credentials and store client information - in the request state. - """ - - def __init__( - self, - app: Any, - clients_store: OAuthRegisteredClientsStore, - ): - """ - Initialize the middleware. - - Args: - app: ASGI application - clients_store: Store for client information - """ - self.app = app - self.client_auth = ClientAuthenticator(clients_store) - - async def __call__(self, scope: Dict, receive: Callable, send: Callable) -> None: - """ - Process the request and authenticate the client. - - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - # Create a request object to access the request data - request = Request(scope, receive=receive) - - # Add client authentication to the request - try: - client = await self.client_auth(request) - # Store the client in the request state - request.state.client = client - except HTTPException: - # Continue without authentication - pass - - # Continue processing the request - await self.app(scope, receive, send) + return client \ No newline at end of file diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 3013ae4397..954d8a57ed 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional, Protocol from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from pydantic import AnyHttpUrl, BaseModel from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( @@ -28,6 +28,7 @@ class AuthorizationParams(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl + class AuthorizationCode(BaseModel): code: str scopes: list[str] @@ -36,6 +37,7 @@ class AuthorizationCode(BaseModel): code_challenge: str redirect_uri: AnyHttpUrl + class RefreshToken(BaseModel): token: str client_id: str @@ -51,6 +53,7 @@ class OAuthTokenRevocationRequest(BaseModel): token: str token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None + class OAuthRegisteredClientsStore(Protocol): """ Interface for storing and retrieving registered OAuth clients. @@ -70,9 +73,7 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul """ ... - async def register_client( - self, client_info: OAuthClientInformationFull - ) -> None: + async def register_client(self, client_info: OAuthClientInformationFull) -> None: """ Registers a new client @@ -118,7 +119,7 @@ async def authorize( | | | | Redirect | |redirect_uri|<-----+ +------------------+ | | - +------------+ + +------------+ Implementations will need to define another handler on the MCP server return flow to perform the second redirect, and generates and stores an authorization @@ -161,8 +162,9 @@ async def exchange_authorization_code( """ ... - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: - ... + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: ... async def exchange_refresh_token( self, @@ -209,6 +211,7 @@ async def revoke_token( """ ... + def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: parsed_uri = urlparse(redirect_uri_base) query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] @@ -216,7 +219,5 @@ def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: if v is not None: query_params.append((k, v)) - redirect_uri = urlunparse( - parsed_uri._replace(query=urlencode(query_params)) - ) - return redirect_uri \ No newline at end of file + redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + return redirect_uri diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 9bcdaef150..963fcc7236 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -14,7 +14,14 @@ class TokenErrorResponse(BaseModel): See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 """ - error: Literal["invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"] + error: Literal[ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", + ] error_description: Optional[str] = None error_uri: Optional[AnyHttpUrl] = None @@ -102,9 +109,6 @@ class OAuthClientRegistrationError(BaseModel): error_description: Optional[str] = None - - - class OAuthMetadata(BaseModel): """ RFC 8414 OAuth 2.0 Authorization Server Metadata. diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 785c5a7ad4..ee04d78552 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -38,7 +38,6 @@ from mcp.shared.auth import ( OAuthClientInformationFull, TokenSuccessResponse, - TokenErrorResponse, ) from mcp.types import JSONRPCRequest @@ -55,9 +54,8 @@ async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFul async def register_client( self, client_info: OAuthClientInformationFull - ) -> OAuthClientInformationFull: + ): self.clients[client_info.client_id] = client_info - return client_info # Mock OAuth provider for testing @@ -79,15 +77,17 @@ async def authorize( # code and completes the redirect code = AuthorizationCode( code=f"code_{int(time.time())}", - client_id= client.client_id, - code_challenge= params.code_challenge, - redirect_uri= params.redirect_uri, + client_id=client.client_id, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, expires_at=time.time() + 300, - scopes=params.scopes or ["read", "write"] + scopes=params.scopes or ["read", "write"], ) self.auth_codes[code.code] = code - return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) + return construct_redirect_uri( + str(params.redirect_uri), code=code.code, state=params.state + ) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -107,8 +107,8 @@ async def exchange_authorization_code( # Store the tokens self.tokens[access_token] = AuthInfo( token=access_token, - client_id= client.client_id, - scopes= authorization_code.scopes, + client_id=client.client_id, + scopes=authorization_code.scopes, expires_at=int(time.time()) + 3600, ) @@ -125,14 +125,16 @@ async def exchange_authorization_code( refresh_token=refresh_token, ) - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: old_access_token = self.refresh_tokens.get(refresh_token) if old_access_token is None: return None token_info = self.tokens.get(old_access_token) if token_info is None: return None - + # Create a RefreshToken object that matches what is expected in later code refresh_obj = RefreshToken( token=refresh_token, @@ -140,7 +142,7 @@ async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_t scopes=token_info.scopes, expires_at=token_info.expires_at, ) - + return refresh_obj async def exchange_refresh_token( @@ -269,10 +271,10 @@ def test_client(auth_app) -> httpx.AsyncClient: @pytest.fixture async def registered_client(test_client: httpx.AsyncClient, request): """Create and register a test client. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("registered_client", - [{"grant_types": ["authorization_code"]}], + @pytest.mark.parametrize("registered_client", + [{"grant_types": ["authorization_code"]}], indirect=True) """ # Default client metadata @@ -281,14 +283,14 @@ async def registered_client(test_client: httpx.AsyncClient, request): "client_name": "Test Client", "grant_types": ["authorization_code", "refresh_token"], } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: client_metadata.update(request.param) - + response = await test_client.post("/register", json=client_metadata) assert response.status_code == 201, f"Failed to register client: {response.content}" - + client_info = response.json() return client_info @@ -302,17 +304,17 @@ def pkce_challenge(): .decode() .rstrip("=") ) - + return {"code_verifier": code_verifier, "code_challenge": code_challenge} @pytest.fixture async def auth_code(test_client, registered_client, pkce_challenge, request): """Get an authorization code. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("auth_code", - [{"redirect_uri": "https://client.example.com/other-callback"}], + @pytest.mark.parametrize("auth_code", + [{"redirect_uri": "https://client.example.com/other-callback"}], indirect=True) """ # Default authorize params @@ -324,22 +326,22 @@ async def auth_code(test_client, registered_client, pkce_challenge, request): "code_challenge_method": "S256", "state": "test_state", } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: auth_params.update(request.param) - + response = await test_client.get("/authorize", params=auth_params) assert response.status_code == 302, f"Failed to get auth code: {response.content}" - + # Extract the authorization code redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "code" in query_params, f"No code in response: {query_params}" auth_code = query_params["code"][0] - + return { "code": auth_code, "redirect_uri": auth_params["redirect_uri"], @@ -350,10 +352,10 @@ async def auth_code(test_client, registered_client, pkce_challenge, request): @pytest.fixture async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): """Exchange authorization code for tokens. - + Parameters can be customized via indirect parameterization: - @pytest.mark.parametrize("tokens", - [{"code_verifier": "wrong_verifier"}], + @pytest.mark.parametrize("tokens", + [{"code_verifier": "wrong_verifier"}], indirect=True) """ # Default token request params @@ -365,13 +367,13 @@ async def tokens(test_client, registered_client, auth_code, pkce_challenge, requ "code_verifier": pkce_challenge["code_verifier"], "redirect_uri": auth_code["redirect_uri"], } - + # Override with any parameters from the test if hasattr(request, "param") and request.param: token_params.update(request.param) - + response = await test_client.post("/token", data=token_params) - + # Don't assert success here since some tests will intentionally cause errors return { "response": response, @@ -408,7 +410,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): "refresh_token", ] assert metadata["service_documentation"] == "https://docs.example.com" - + @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): """Test token endpoint error - validation error.""" @@ -422,13 +424,17 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): ) error_response = response.json() assert error_response["error"] == "invalid_request" - assert "error_description" in error_response # Contains validation error messages - + assert ( + "error_description" in error_response + ) # Contains validation error messages + @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", [{"grant_types": ["authorization_code"]}], indirect=True) + @pytest.mark.parametrize( + "registered_client", [{"grant_types": ["authorization_code"]}], indirect=True + ) async def test_token_unsupported_grant_type(self, test_client, registered_client): """Test token endpoint error - unsupported grant type.""" - # Try to use refresh_token grant type with a client that only supports authorization_code + # Try refresh_token grant with client that only supports authorization_code response = await test_client.post( "/token", data={ @@ -442,9 +448,11 @@ async def test_token_unsupported_grant_type(self, test_client, registered_client error_response = response.json() assert error_response["error"] == "unsupported_grant_type" assert "supported grant types" in error_response["error_description"] - + @pytest.mark.anyio - async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): + async def test_token_invalid_auth_code( + self, test_client, registered_client, pkce_challenge + ): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -464,29 +472,36 @@ async def test_token_invalid_auth_code(self, test_client, registered_client, pkc assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert "authorization code does not exist" in error_response["error_description"] - + assert ( + "authorization code does not exist" in error_response["error_description"] + ) + @pytest.mark.anyio async def test_token_expired_auth_code( - self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, ): """Test token endpoint error - authorization code has expired.""" # Get the current time for our time mocking current_time = time.time() - - # Find the auth code object + + # Find the auth code object code_value = auth_code["code"] found_code = None for code_obj in mock_oauth_provider.auth_codes.values(): if code_obj.code == code_value: found_code = code_obj break - + assert found_code is not None - + # Authorization codes are typically short-lived (5 minutes = 300 seconds) # So we'll mock time to be 10 minutes (600 seconds) in the future - with unittest.mock.patch('time.time', return_value=current_time + 600): + with unittest.mock.patch("time.time", return_value=current_time + 600): # Try to use the expired authorization code response = await test_client.post( "/token", @@ -502,14 +517,26 @@ async def test_token_expired_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert "authorization code has expired" in error_response["error_description"] - + assert ( + "authorization code has expired" in error_response["error_description"] + ) + @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", - [{"redirect_uris": ["https://client.example.com/callback", - "https://client.example.com/other-callback"]}], - indirect=True) - async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) + async def test_token_redirect_uri_mismatch( + self, test_client, registered_client, auth_code, pkce_challenge + ): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -520,16 +547,19 @@ async def test_token_redirect_uri_mismatch(self, test_client, registered_client, "client_secret": registered_client["client_secret"], "code": auth_code["code"], "code_verifier": pkce_challenge["code_verifier"], - "redirect_uri": "https://client.example.com/other-callback", # Different from the one used in /authorize + # Different from the one used in /authorize + "redirect_uri": "https://client.example.com/other-callback", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_request" assert "redirect_uri did not match" in error_response["error_description"] - + @pytest.mark.anyio - async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): + async def test_token_code_verifier_mismatch( + self, test_client, registered_client, auth_code + ): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -539,7 +569,8 @@ async def test_token_code_verifier_mismatch(self, test_client, registered_client "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "code": auth_code["code"], - "code_verifier": "incorrect_code_verifier", # Different from the one used to create challenge + # Different from the one used to create challenge + "code_verifier": "incorrect_code_verifier", "redirect_uri": auth_code["redirect_uri"], }, ) @@ -547,7 +578,7 @@ async def test_token_code_verifier_mismatch(self, test_client, registered_client error_response = response.json() assert error_response["error"] == "invalid_grant" assert "incorrect code_verifier" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_invalid_refresh_token(self, test_client, registered_client): """Test token endpoint error - refresh token does not exist.""" @@ -565,15 +596,20 @@ async def test_token_invalid_refresh_token(self, test_client, registered_client) error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token does not exist" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_expired_refresh_token( - self, test_client, registered_client, auth_code, pkce_challenge, mock_oauth_provider + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, ): """Test token endpoint error - refresh token has expired.""" # Step 1: First, let's create a token and refresh token at the current time current_time = time.time() - + # Exchange authorization code for tokens normally token_response = await test_client.post( "/token", @@ -589,10 +625,12 @@ async def test_token_expired_refresh_token( assert token_response.status_code == 200 tokens = token_response.json() refresh_token = tokens["refresh_token"] - - # Step 2: Now let's time travel forward 4 hours (tokens expire in 1 hour by default) + + # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) # Mock the time.time() function to return a value 4 hours in the future - with unittest.mock.patch('time.time', return_value=current_time + 14400): # 4 hours = 14400 seconds + with unittest.mock.patch( + "time.time", return_value=current_time + 14400 + ): # 4 hours = 14400 seconds # Try to use the refresh token which should now be considered expired response = await test_client.post( "/token", @@ -603,13 +641,13 @@ async def test_token_expired_refresh_token( "refresh_token": refresh_token, }, ) - + # In the "future", the token should be considered expired assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" assert "refresh token has expired" in error_response["error_description"] - + @pytest.mark.anyio async def test_token_invalid_scope( self, test_client, registered_client, auth_code, pkce_challenge @@ -628,10 +666,10 @@ async def test_token_invalid_scope( }, ) assert token_response.status_code == 200 - + tokens = token_response.json() refresh_token = tokens["refresh_token"] - + # Try to use refresh token with an invalid scope response = await test_client.post( "/token", @@ -675,7 +713,7 @@ async def test_client_registration( # assert await mock_oauth_provider.clients_store.get_client( # client_info["client_id"] # ) is not None - + @pytest.mark.anyio async def test_client_registration_missing_required_fields( self, test_client: httpx.AsyncClient @@ -696,7 +734,7 @@ async def test_client_registration_missing_required_fields( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == "redirect_uris: Field required" - + @pytest.mark.anyio async def test_client_registration_invalid_uri( self, test_client: httpx.AsyncClient @@ -716,8 +754,14 @@ async def test_client_registration_invalid_uri( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "redirect_uris.0: Input should be a valid URL, relative URL without a base" - + assert ( + error_data["error_description"] + == ( + "redirect_uris.0: Input should be a valid URL, " + "relative URL without a base" + ) + ) + @pytest.mark.anyio async def test_client_registration_empty_redirect_uris( self, test_client: httpx.AsyncClient @@ -736,11 +780,17 @@ async def test_client_registration_empty_redirect_uris( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" - + assert ( + error_data["error_description"] + == "redirect_uris: List should have at least 1 item after validation, not 0" + ) + @pytest.mark.anyio async def test_authorize_form_post( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, ): """Test the authorization endpoint using POST with form-encoded data.""" # Register a client @@ -781,7 +831,10 @@ async def test_authorize_form_post( @pytest.mark.anyio async def test_authorization_get( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, ): """Test the full authorization flow.""" # 1. Register a client @@ -884,9 +937,13 @@ async def test_authorization_get( assert response.status_code == 200 # Verify that the token was revoked - assert await mock_oauth_provider.load_access_token( - new_token_response["access_token"] - ) is None + assert ( + await mock_oauth_provider.load_access_token( + new_token_response["access_token"] + ) + is None + ) + @pytest.mark.anyio async def test_revoke_invalid_token(self, test_client, registered_client): """Test revoking an invalid token.""" @@ -900,6 +957,7 @@ async def test_revoke_invalid_token(self, test_client, registered_client): ) # per RFC, this should return 200 even if the token is invalid assert response.status_code == 200 + @pytest.mark.anyio async def test_revoke_with_malformed_token(self, test_client, registered_client): response = await test_client.post( @@ -908,23 +966,22 @@ async def test_revoke_with_malformed_token(self, test_client, registered_client) "client_id": registered_client["client_id"], "client_secret": registered_client["client_secret"], "token": 123, - "token_type_hint": "asdf" + "token_type_hint": "asdf", }, ) assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_request" assert "token_type_hint" in error_response["error_description"] - - - class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" @pytest.mark.anyio - async def test_fastmcp_with_auth(self, mock_oauth_provider: MockOAuthProvider, pkce_challenge): + async def test_fastmcp_with_auth( + self, mock_oauth_provider: MockOAuthProvider, pkce_challenge + ): """Test creating a FastMCP server with authentication.""" # Create FastMCP server with auth provider mcp = FastMCP( @@ -1053,15 +1110,17 @@ def test_tool(x: int) -> str: assert set(sse_data["result"]["capabilities"].keys()) == set( ("experimental", "prompts", "resources", "tools") ) - - + + class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" - + @pytest.mark.anyio - async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_missing_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): """Test authorization endpoint with missing client_id. - + According to the OAuth2.0 spec, if client_id is missing, the server should inform the resource owner and NOT redirect. """ @@ -1073,19 +1132,21 @@ async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256" + "code_challenge_method": "S256", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about missing client_id assert "client_id" in response.text.lower() - + @pytest.mark.anyio - async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): + async def test_authorize_invalid_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): """Test authorization endpoint with invalid client_id. - + According to the OAuth2.0 spec, if client_id is invalid, the server should inform the resource owner and NOT redirect. """ @@ -1097,24 +1158,24 @@ async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, "redirect_uri": "https://client.example.com/callback", "state": "test_state", "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256" + "code_challenge_method": "S256", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400 # The response should include an error message about invalid client_id assert "client" in response.text.lower() - + @pytest.mark.anyio async def test_authorize_missing_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing redirect_uri. - + If client has only one registered redirect_uri, it can be omitted. """ - + response = await test_client.get( "/authorize", params={ @@ -1126,52 +1187,61 @@ async def test_authorize_missing_redirect_uri( "state": "test_state", }, ) - + # Should redirect to the registered redirect_uri assert response.status_code == 302, response.content redirect_url = response.headers["location"] assert redirect_url.startswith("https://client.example.com/callback") - + @pytest.mark.anyio async def test_authorize_invalid_redirect_uri( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with invalid redirect_uri. - + According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, the server should inform the resource owner and NOT redirect. """ - + response = await test_client.get( "/authorize", params={ "response_type": "code", "client_id": registered_client["client_id"], - "redirect_uri": "https://attacker.example.com/callback", # Non-matching URI + # Non-matching URI + "redirect_uri": "https://attacker.example.com/callback", "code_challenge": pkce_challenge["code_challenge"], "code_challenge_method": "S256", "state": "test_state", }, ) - + # Should NOT redirect, should show an error page assert response.status_code == 400, response.content # The response should include an error message about redirect_uri mismatch assert "redirect" in response.text.lower() @pytest.mark.anyio - @pytest.mark.parametrize("registered_client", - [{"redirect_uris": ["https://client.example.com/callback", - "https://client.example.com/other-callback"]}], - indirect=True) + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) async def test_authorize_missing_redirect_uri_multiple_registered( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): - """Test authorization endpoint with missing redirect_uri when client has multiple registered URIs. - + """Test endpoint with missing redirect_uri with multiple registered URIs. + If client has multiple registered redirect_uris, redirect_uri must be provided. """ - + response = await test_client.get( "/authorize", params={ @@ -1183,22 +1253,22 @@ async def test_authorize_missing_redirect_uri_multiple_registered( "state": "test_state", }, ) - + # Should NOT redirect, should return a 400 error assert response.status_code == 400 # The response should include an error message about missing redirect_uri assert "redirect_uri" in response.text.lower() - + @pytest.mark.anyio async def test_authorize_unsupported_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with unsupported response_type. - + According to the OAuth2.0 spec, for other errors like unsupported_response_type, the server should redirect with error parameters. """ - + response = await test_client.get( "/authorize", params={ @@ -1210,28 +1280,28 @@ async def test_authorize_unsupported_response_type( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "unsupported_response_type" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_missing_response_type( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with missing response_type. - + Missing required parameter should result in invalid_request error. """ - + response = await test_client.get( "/authorize", params={ @@ -1243,25 +1313,25 @@ async def test_authorize_missing_response_type( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_missing_pkce_challenge( self, test_client: httpx.AsyncClient, registered_client ): """Test authorization endpoint with missing PKCE code_challenge. - + Missing PKCE parameters should result in invalid_request error. """ response = await test_client.get( @@ -1274,28 +1344,28 @@ async def test_authorize_missing_pkce_challenge( # using default URL }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_request" # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" - + @pytest.mark.anyio async def test_authorize_invalid_scope( self, test_client: httpx.AsyncClient, registered_client, pkce_challenge ): """Test authorization endpoint with invalid scope. - + Invalid scope should redirect with invalid_scope error. """ - + response = await test_client.get( "/authorize", params={ @@ -1308,13 +1378,13 @@ async def test_authorize_invalid_scope( "state": "test_state", }, ) - + # Should redirect with error parameters assert response.status_code == 302 redirect_url = response.headers["location"] parsed_url = urlparse(redirect_url) query_params = parse_qs(parsed_url.query) - + assert "error" in query_params assert query_params["error"][0] == "invalid_scope" # State should be preserved From 571913a89397a9f1c29e7150d9cf3adecc367073 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 07:36:22 -0700 Subject: [PATCH 21/60] Clean up unused error classes --- src/mcp/server/auth/errors.py | 126 +++--------------- src/mcp/server/auth/handlers/authorize.py | 4 +- src/mcp/server/auth/handlers/register.py | 4 +- src/mcp/server/auth/handlers/revoke.py | 6 +- src/mcp/server/auth/handlers/token.py | 15 ++- src/mcp/server/auth/middleware/bearer_auth.py | 19 +-- src/mcp/server/auth/middleware/client_auth.py | 4 +- .../fastmcp/auth/test_auth_integration.py | 25 ++-- 8 files changed, 61 insertions(+), 142 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index cc92b33894..08686d2eb3 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -4,9 +4,15 @@ Corresponds to TypeScript file: src/server/auth/errors.ts """ -from typing import Dict +from typing import Literal -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError + +ErrorCode = Literal["invalid_request", "invalid_client"] + +class ErrorResponse(BaseModel): + error: ErrorCode + error_description: str class OAuthError(Exception): @@ -16,25 +22,17 @@ class OAuthError(Exception): Corresponds to OAuthError in src/server/auth/errors.ts """ - error_code: str = "server_error" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def to_response_object(self) -> Dict[str, str]: - """Convert error to JSON response object.""" - return {"error": self.error_code, "error_description": self.message} + error_code: ErrorCode + def __init__(self, error_description: str): + super().__init__(error_description) + self.error_description = error_description -class ServerError(OAuthError): - """ - Server error. - - Corresponds to ServerError in src/server/auth/errors.ts - """ - - error_code = "server_error" + def error_response(self) -> ErrorResponse: + return ErrorResponse( + error=self.error_code, + error_description=self.error_description, + ) class InvalidRequestError(OAuthError): @@ -57,96 +55,6 @@ class InvalidClientError(OAuthError): error_code = "invalid_client" -class InvalidGrantError(OAuthError): - """ - Invalid grant error. - - Corresponds to InvalidGrantError in src/server/auth/errors.ts - """ - - error_code = "invalid_grant" - - -class UnauthorizedClientError(OAuthError): - """ - Unauthorized client error. - - Corresponds to UnauthorizedClientError in src/server/auth/errors.ts - """ - - error_code = "unauthorized_client" - - -class UnsupportedGrantTypeError(OAuthError): - """ - Unsupported grant type error. - - Corresponds to UnsupportedGrantTypeError in src/server/auth/errors.ts - """ - - error_code = "unsupported_grant_type" - - -class UnsupportedResponseTypeError(OAuthError): - """ - Unsupported response type error. - - Corresponds to UnsupportedResponseTypeError in src/server/auth/errors.ts - """ - - error_code = "unsupported_response_type" - - -class InvalidScopeError(OAuthError): - """ - Invalid scope error. - - Corresponds to InvalidScopeError in src/server/auth/errors.ts - """ - - error_code = "invalid_scope" - - -class AccessDeniedError(OAuthError): - """ - Access denied error. - - Corresponds to AccessDeniedError in src/server/auth/errors.ts - """ - - error_code = "access_denied" - - -class TemporarilyUnavailableError(OAuthError): - """ - Temporarily unavailable error. - - Corresponds to TemporarilyUnavailableError in src/server/auth/errors.ts - """ - - error_code = "temporarily_unavailable" - - -class InvalidTokenError(OAuthError): - """ - Invalid token error. - - Corresponds to InvalidTokenError in src/server/auth/errors.ts - """ - - error_code = "invalid_token" - - -class InsufficientScopeError(OAuthError): - """ - Insufficient scope error. - - Corresponds to InsufficientScopeError in src/server/auth/errors.ts - """ - - error_code = "insufficient_scope" - - def stringify_pydantic_error(validation_error: ValidationError) -> str: return "\n".join( f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 31a9eee212..d86408dc55 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -216,7 +216,7 @@ async def error_response( # For redirect_uri validation errors, return direct error (no redirect) return await error_response( error="invalid_request", - error_description=validation_error.message, + error_description=validation_error.error_description, ) # Validate scope - for scope errors, we can redirect @@ -226,7 +226,7 @@ async def error_response( # For scope errors, redirect with error parameters return await error_response( error="invalid_scope", - error_description=validation_error.message, + error_description=validation_error.error_description, ) # Setup authorization parameters diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index f9e814f6de..66afd1b66c 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -83,9 +83,9 @@ async def registration_handler(request: Request) -> Response: software_version=client_metadata.software_version, ) # Register client - client = await clients_store.register_client(client_info) + await clients_store.register_client(client_info) # Return client information - return PydanticJSONResponse(content=client, status_code=201) + return PydanticJSONResponse(content=client_info, status_code=201) return registration_handler diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 1863685fc4..01f126cba2 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -11,6 +11,7 @@ from starlette.responses import Response from mcp.server.auth.errors import ( + InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse @@ -46,7 +47,10 @@ async def revocation_handler(request: Request) -> Response: ) # Authenticate client - client_auth_result = await client_authenticator(revocation_request) + try: + client_auth_result = await client_authenticator(revocation_request) + except InvalidClientError as e: + return PydanticJSONResponse(status_code=401, content=e.error_response()) # Revoke token if provider.revoke_token: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index c6dbcd0bb6..b67bf5bd99 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -13,6 +13,8 @@ from starlette.requests import Request from mcp.server.auth.errors import ( + ErrorResponse, + InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse @@ -53,7 +55,7 @@ class TokenRequest(RootModel): def create_token_handler( provider: OAuthServerProvider, client_authenticator: ClientAuthenticator ) -> Callable: - def response(obj: TokenSuccessResponse | TokenErrorResponse): + def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 @@ -78,7 +80,11 @@ async def token_handler(request: Request): error_description=stringify_pydantic_error(validation_error), ) ) - client_info = await client_authenticator(token_request) + + try: + client_info = await client_authenticator(token_request) + except InvalidClientError as e: + return response(e.error_response()) if token_request.grant_type not in client_info.grant_types: return response( @@ -124,8 +130,9 @@ async def token_handler(request: Request): TokenErrorResponse( error="invalid_request", error_description=( - "redirect_uri didn't match the one used when creating auth code" - ), + "redirect_uri did not match the one " + "used when creating auth code" + ), ) ) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index b89d7eca3d..ab597ac903 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -16,7 +16,6 @@ from starlette.requests import HTTPConnection from starlette.types import Scope -from mcp.server.auth.errors import InsufficientScopeError, InvalidTokenError, OAuthError from mcp.server.auth.provider import OAuthServerProvider from mcp.server.auth.types import AuthInfo @@ -51,21 +50,17 @@ async def authenticate(self, conn: HTTPConnection): token = auth_header[7:] # Remove "Bearer " prefix - try: - # Validate the token with the provider - auth_info = await self.provider.load_access_token(token) + # Validate the token with the provider + auth_info = await self.provider.load_access_token(token) - if not auth_info: - raise InvalidTokenError("Invalid access token") + if not auth_info: + return None - if auth_info.expires_at and auth_info.expires_at < int(time.time()): - raise InvalidTokenError("Token has expired") + if auth_info.expires_at and auth_info.expires_at < int(time.time()): + return None - return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) - except (InvalidTokenError, InsufficientScopeError, OAuthError): - # Return None to indicate authentication failure - return None class RequireAuthMiddleware: diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index df4732de3a..3a16d960da 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -5,11 +5,9 @@ """ import time -from typing import Any, Callable, Dict, Optional +from typing import Optional from pydantic import BaseModel -from starlette.exceptions import HTTPException -from starlette.requests import Request from mcp.server.auth.errors import ( InvalidClientError, diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index ee04d78552..9f756c0507 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -18,7 +18,6 @@ from starlette.applications import Starlette from starlette.routing import Mount -from mcp.server.auth.errors import InvalidTokenError from mcp.server.auth.provider import ( AuthorizationCode, AuthorizationParams, @@ -97,8 +96,7 @@ async def load_authorization_code( async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode ) -> TokenSuccessResponse: - if authorization_code.code not in self.auth_codes: - raise InvalidTokenError("Invalid authorization code") + assert authorization_code.code in self.auth_codes # Generate an access token and refresh token access_token = f"access_{secrets.token_hex(32)}" @@ -152,19 +150,16 @@ async def exchange_refresh_token( scopes: List[str], ) -> TokenSuccessResponse: # Check if refresh token exists - if refresh_token.token not in self.refresh_tokens: - raise InvalidTokenError("Invalid refresh token") + assert refresh_token.token in self.refresh_tokens old_access_token = self.refresh_tokens[refresh_token.token] # Check if the access token exists - if old_access_token not in self.tokens: - raise InvalidTokenError("Invalid refresh token") + assert old_access_token in self.tokens # Check if the token was issued to this client token_info = self.tokens[old_access_token] - if token_info.client_id != client.client_id: - raise InvalidTokenError("Refresh token was not issued to this client") + assert token_info.client_id == client.client_id # Generate a new access token and refresh token new_access_token = f"access_{secrets.token_hex(32)}" @@ -1017,6 +1012,18 @@ def test_tool(x: int) -> str: # TODO: we should return 401/403 depending on whether authn or authz fails assert response.status_code == 403, response.content + response = await test_client.post( + "/messages/", + headers={"Authorization": "invalid"}, + ) + assert response.status_code == 403 + + response = await test_client.post( + "/messages/", + headers={"Authorization": "Bearer invalid"}, + ) + assert response.status_code == 403 + # now, become authenticated and try to go through the flow again client_metadata = { "redirect_uris": ["https://client.example.com/callback"], From d43647f8f385207a04eb7f0eb737875cfbe70cfd Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 07:43:19 -0700 Subject: [PATCH 22/60] Update to use Python 3.10 types --- src/mcp/server/auth/handlers/authorize.py | 14 +++++++------- src/mcp/server/auth/handlers/metadata.py | 4 ++-- src/mcp/server/auth/handlers/token.py | 6 +++--- src/mcp/server/auth/middleware/client_auth.py | 7 ++----- src/mcp/server/auth/provider.py | 16 ++++++++-------- src/mcp/server/auth/router.py | 8 ++++---- src/mcp/server/auth/types.py | 6 ++---- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index d86408dc55..b9bace0ca6 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -5,7 +5,7 @@ """ import logging -from typing import Callable, Literal, Optional, Union +from typing import Callable, Literal from urllib.parse import urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError @@ -44,8 +44,8 @@ class AuthorizationRequest(BaseModel): code_challenge_method: Literal["S256"] = Field( "S256", description="PKCE code challenge method, must be S256" ) - state: Optional[str] = Field(None, description="Optional state parameter") - scope: Optional[str] = Field( + state: str | None = Field(None, description="Optional state parameter") + scope: str | None = Field( None, description="Optional scope; if specified, should be " "a space-separated list of scope strings", @@ -97,14 +97,14 @@ def validate_redirect_uri( class ErrorResponse(BaseModel): error: ErrorCode error_description: str - error_uri: Optional[AnyUrl] = None + error_uri: AnyUrl | None = None # must be set if provided in the request - state: Optional[str] + state: str | None = None def best_effort_extract_string( key: str, params: None | FormData | QueryParams -) -> Optional[str]: +) -> str | None: if params is None: return None value = params.get(key) @@ -257,7 +257,7 @@ async def error_response( def create_error_redirect( - redirect_uri: AnyUrl, error: Union[Exception, ErrorResponse] + redirect_uri: AnyUrl, error: Exception | ErrorResponse ) -> str: parsed_uri = urlparse(str(redirect_uri)) diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 11a9c904de..e77157af37 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -4,13 +4,13 @@ Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts """ -from typing import Any, Callable, Dict +from typing import Any, Callable from starlette.requests import Request from starlette.responses import JSONResponse, Response -def create_metadata_handler(metadata: Dict[str, Any]) -> Callable: +def create_metadata_handler(metadata: dict[str, Any]) -> Callable: """ Create a handler for OAuth 2.0 Authorization Server Metadata. diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index b67bf5bd99..01cf0554f3 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -7,7 +7,7 @@ import base64 import hashlib import time -from typing import Annotated, Callable, Literal, Optional, Union +from typing import Annotated, Callable, Literal from pydantic import AnyHttpUrl, Field, RootModel, ValidationError from starlette.requests import Request @@ -42,12 +42,12 @@ class RefreshTokenRequest(ClientAuthRequest): # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") - scope: Optional[str] = Field(None, description="Optional scope parameter") + scope: str | None = Field(None, description="Optional scope parameter") class TokenRequest(RootModel): root: Annotated[ - Union[AuthorizationCodeRequest, RefreshTokenRequest], + AuthorizationCodeRequest | RefreshTokenRequest, Field(discriminator="grant_type"), ] diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 3a16d960da..4546d92215 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -5,13 +5,10 @@ """ import time -from typing import Optional from pydantic import BaseModel -from mcp.server.auth.errors import ( - InvalidClientError, -) +from mcp.server.auth.errors import InvalidClientError from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull @@ -25,7 +22,7 @@ class ClientAuthRequest(BaseModel): """ client_id: str - client_secret: Optional[str] = None + client_secret: str | None = None class ClientAuthenticator: diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 954d8a57ed..6eb039746b 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -4,7 +4,7 @@ Corresponds to TypeScript file: src/server/auth/provider.ts """ -from typing import List, Literal, Optional, Protocol +from typing import Literal, Protocol from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, BaseModel @@ -23,8 +23,8 @@ class AuthorizationParams(BaseModel): Corresponds to AuthorizationParams in src/server/auth/provider.ts """ - state: Optional[str] = None - scopes: Optional[List[str]] = None + state: str | None = None + scopes: list[str] | None = None code_challenge: str redirect_uri: AnyHttpUrl @@ -41,8 +41,8 @@ class AuthorizationCode(BaseModel): class RefreshToken(BaseModel): token: str client_id: str - scopes: List[str] - expires_at: Optional[int] = None + scopes: list[str] + expires_at: int | None = None class OAuthTokenRevocationRequest(BaseModel): @@ -51,7 +51,7 @@ class OAuthTokenRevocationRequest(BaseModel): """ token: str - token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None + token_type_hint: Literal["access_token", "refresh_token"] | None = None class OAuthRegisteredClientsStore(Protocol): @@ -61,7 +61,7 @@ class OAuthRegisteredClientsStore(Protocol): Corresponds to OAuthRegisteredClientsStore in src/server/auth/clients.ts """ - async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ Retrieves client information by client ID. @@ -170,7 +170,7 @@ async def exchange_refresh_token( self, client: OAuthClientInformationFull, refresh_token: RefreshToken, - scopes: List[str], + scopes: list[str], ) -> TokenSuccessResponse: """ Exchanges a refresh token for an access token. diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 4dfa8e6aee..5fa82f82b2 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any from pydantic import AnyUrl from starlette.routing import Route, Router @@ -23,7 +23,7 @@ @dataclass class ClientRegistrationOptions: enabled: bool = False - client_secret_expiry_seconds: Optional[int] = None + client_secret_expiry_seconds: int | None = None @dataclass @@ -143,10 +143,10 @@ def create_auth_router( def build_metadata( issuer_url: AnyUrl, - service_documentation_url: Optional[AnyUrl], + service_documentation_url: AnyUrl | None, client_registration_options: ClientRegistrationOptions, revocation_options: RevocationOptions, -) -> Dict[str, Any]: +) -> dict[str, Any]: issuer_url_str = str(issuer_url).rstrip("/") # Create metadata metadata = { diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index f0593d8644..eb47b65770 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -4,8 +4,6 @@ Corresponds to TypeScript file: src/server/auth/types.ts """ -from typing import List, Optional - from pydantic import BaseModel @@ -18,5 +16,5 @@ class AuthInfo(BaseModel): token: str client_id: str - scopes: List[str] - expires_at: Optional[int] = None + scopes: list[str] + expires_at: int | None = None From 9d72c1e598f41e8e1741e20b541ce1776c57882d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 07:56:54 -0700 Subject: [PATCH 23/60] Use classes for handlers --- src/mcp/server/auth/handlers/authorize.py | 22 ++++++--- src/mcp/server/auth/handlers/metadata.py | 33 +++----------- src/mcp/server/auth/handlers/register.py | 20 ++++----- src/mcp/server/auth/handlers/revoke.py | 20 ++++----- src/mcp/server/auth/handlers/token.py | 54 ++++++++++++----------- src/mcp/server/auth/router.py | 33 +++++++------- 6 files changed, 87 insertions(+), 95 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index b9bace0ca6..59ea1f62e9 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -5,7 +5,8 @@ """ import logging -from typing import Callable, Literal +from dataclasses import dataclass +from typing import Literal from urllib.parse import urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError @@ -117,8 +118,11 @@ class AnyHttpUrlModel(RootModel): root: AnyHttpUrl -def create_authorization_handler(provider: OAuthServerProvider) -> Callable: - async def authorization_handler(request: Request) -> Response: +@dataclass +class AuthorizationHandler: + provider: OAuthServerProvider + + async def handle(self, request: Request) -> Response: # implements authorization requests for grant_type=code; # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 @@ -134,7 +138,7 @@ async def error_response( if client is None and attempt_load_client: # make last-ditch attempt to load the client client_id = best_effort_extract_string("client_id", params) - client = client_id and await provider.clients_store.get_client( + client = client_id and await self.provider.clients_store.get_client( client_id ) if redirect_uri is None and client: @@ -200,7 +204,9 @@ async def error_response( ) # Get client information - client = await provider.clients_store.get_client(auth_request.client_id) + client = await self.provider.clients_store.get_client( + auth_request.client_id, + ) if not client: # For client_id validation errors, return direct error (no redirect) return await error_response( @@ -241,7 +247,10 @@ async def error_response( response = RedirectResponse( url="", status_code=302, headers={"Cache-Control": "no-store"} ) - response.headers["location"] = await provider.authorize(client, auth_params) + response.headers["location"] = await self.provider.authorize( + client, + auth_params, + ) return response except Exception as validation_error: @@ -253,7 +262,6 @@ async def error_response( error="server_error", error_description="An unexpected error occurred" ) - return authorization_handler def create_error_redirect( diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index e77157af37..39cc889402 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -4,41 +4,22 @@ Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts """ -from typing import Any, Callable +from dataclasses import dataclass +from typing import Any from starlette.requests import Request from starlette.responses import JSONResponse, Response -def create_metadata_handler(metadata: dict[str, Any]) -> Callable: - """ - Create a handler for OAuth 2.0 Authorization Server Metadata. +@dataclass +class MetadataHandler: + metadata: dict[str, Any] - Corresponds to metadataHandler in src/server/auth/handlers/metadata.ts - - Args: - metadata: The metadata to return in the response - - Returns: - A Starlette endpoint handler function - """ - - async def metadata_handler(request: Request) -> Response: - """ - Handler for the OAuth 2.0 Authorization Server Metadata endpoint. - - Args: - request: The Starlette request - - Returns: - JSON response with the authorization server metadata - """ + async def handle(self, request: Request) -> Response: # Remove any None values from metadata - clean_metadata = {k: v for k, v in metadata.items() if v is not None} + clean_metadata = {k: v for k, v in self.metadata.items() if v is not None} return JSONResponse( content=clean_metadata, headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour ) - - return metadata_handler diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 66afd1b66c..6c41b85855 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -6,7 +6,8 @@ import secrets import time -from typing import Callable, Literal +from dataclasses import dataclass +from typing import Literal from uuid import uuid4 from pydantic import BaseModel, ValidationError @@ -29,10 +30,11 @@ class ErrorResponse(BaseModel): error_description: str -def create_registration_handler( - clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None -) -> Callable: - async def registration_handler(request: Request) -> Response: +@dataclass +class RegistrationHandler: + clients_store: OAuthRegisteredClientsStore + client_secret_expiry_seconds: int | None + async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: # Parse request body as JSON @@ -55,8 +57,8 @@ async def registration_handler(request: Request) -> Response: client_id_issued_at = int(time.time()) client_secret_expires_at = ( - client_id_issued_at + client_secret_expiry_seconds - if client_secret_expiry_seconds is not None + client_id_issued_at + self.client_secret_expiry_seconds + if self.client_secret_expiry_seconds is not None else None ) @@ -83,9 +85,7 @@ async def registration_handler(request: Request) -> Response: software_version=client_metadata.software_version, ) # Register client - await clients_store.register_client(client_info) + await self.clients_store.register_client(client_info) # Return client information return PydanticJSONResponse(content=client_info, status_code=201) - - return registration_handler diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 01f126cba2..d31fe62285 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -4,7 +4,7 @@ Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts """ -from typing import Callable +from dataclasses import dataclass from pydantic import ValidationError from starlette.requests import Request @@ -27,10 +27,12 @@ class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): pass -def create_revocation_handler( - provider: OAuthServerProvider, client_authenticator: ClientAuthenticator -) -> Callable: - async def revocation_handler(request: Request) -> Response: +@dataclass +class RevocationHandler: + provider: OAuthServerProvider + client_authenticator: ClientAuthenticator + + async def handle(self, request: Request) -> Response: """ Handler for the OAuth 2.0 Token Revocation endpoint. """ @@ -48,13 +50,13 @@ async def revocation_handler(request: Request) -> Response: # Authenticate client try: - client_auth_result = await client_authenticator(revocation_request) + client_auth_result = await self.client_authenticator(revocation_request) except InvalidClientError as e: return PydanticJSONResponse(status_code=401, content=e.error_response()) # Revoke token - if provider.revoke_token: - await provider.revoke_token(client_auth_result, revocation_request) + if self.provider.revoke_token: + await self.provider.revoke_token(client_auth_result, revocation_request) # Return successful empty response return Response( @@ -64,5 +66,3 @@ async def revocation_handler(request: Request) -> Response: "Pragma": "no-cache", }, ) - - return revocation_handler diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 01cf0554f3..0698262a5c 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -7,7 +7,8 @@ import base64 import hashlib import time -from typing import Annotated, Callable, Literal +from dataclasses import dataclass +from typing import Annotated, Literal from pydantic import AnyHttpUrl, Field, RootModel, ValidationError from starlette.requests import Request @@ -52,10 +53,12 @@ class TokenRequest(RootModel): ] -def create_token_handler( - provider: OAuthServerProvider, client_authenticator: ClientAuthenticator -) -> Callable: - def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): +@dataclass +class TokenHandler: + provider: OAuthServerProvider + client_authenticator: ClientAuthenticator + + def response(self, obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): status_code = 200 if isinstance(obj, TokenErrorResponse): status_code = 400 @@ -69,12 +72,12 @@ def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): }, ) - async def token_handler(request: Request): + async def handle(self, request: Request): try: form_data = await request.form() token_request = TokenRequest.model_validate(dict(form_data)).root except ValidationError as validation_error: - return response( + return self.response( TokenErrorResponse( error="invalid_request", error_description=stringify_pydantic_error(validation_error), @@ -82,12 +85,12 @@ async def token_handler(request: Request): ) try: - client_info = await client_authenticator(token_request) + client_info = await self.client_authenticator(token_request) except InvalidClientError as e: - return response(e.error_response()) + return self.response(e.error_response()) if token_request.grant_type not in client_info.grant_types: - return response( + return self.response( TokenErrorResponse( error="unsupported_grant_type", error_description=( @@ -101,12 +104,12 @@ async def token_handler(request: Request): match token_request: case AuthorizationCodeRequest(): - auth_code = await provider.load_authorization_code( + auth_code = await self.provider.load_authorization_code( client_info, token_request.code ) if auth_code is None or auth_code.client_id != token_request.client_id: # if code belongs to different client, pretend it doesn't exist - return response( + return self.response( TokenErrorResponse( error="invalid_grant", error_description="authorization code does not exist", @@ -116,7 +119,7 @@ async def token_handler(request: Request): # make auth codes expire after a deadline # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 if auth_code.expires_at < time.time(): - return response( + return self.response( TokenErrorResponse( error="invalid_grant", error_description="authorization code has expired", @@ -126,7 +129,7 @@ async def token_handler(request: Request): # verify redirect_uri doesn't change between /authorize and /tokens # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 if token_request.redirect_uri != auth_code.redirect_uri: - return response( + return self.response( TokenErrorResponse( error="invalid_request", error_description=( @@ -144,7 +147,7 @@ async def token_handler(request: Request): if hashed_code_verifier != auth_code.code_challenge: # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 - return response( + return self.response( TokenErrorResponse( error="invalid_grant", error_description="incorrect code_verifier", @@ -152,12 +155,12 @@ async def token_handler(request: Request): ) # Exchange authorization code for tokens - tokens = await provider.exchange_authorization_code( + tokens = await self.provider.exchange_authorization_code( client_info, auth_code ) case RefreshTokenRequest(): - refresh_token = await provider.load_refresh_token( + refresh_token = await self.provider.load_refresh_token( client_info, token_request.refresh_token ) if ( @@ -165,7 +168,7 @@ async def token_handler(request: Request): or refresh_token.client_id != token_request.client_id ): # if token belongs to different client, pretend it doesn't exist - return response( + return self.response( TokenErrorResponse( error="invalid_grant", error_description="refresh token does not exist", @@ -174,7 +177,7 @@ async def token_handler(request: Request): if refresh_token.expires_at and refresh_token.expires_at < time.time(): # if the refresh token has expired, pretend it doesn't exist - return response( + return self.response( TokenErrorResponse( error="invalid_grant", error_description="refresh token has expired", @@ -190,20 +193,19 @@ async def token_handler(request: Request): for scope in scopes: if scope not in refresh_token.scopes: - return response( + return self.response( TokenErrorResponse( error="invalid_scope", error_description=( - f"cannot request scope `{scope}` not provided by refresh token" - ), + f"cannot request scope `{scope}` " + "not provided by refresh token" + ), ) ) # Exchange refresh token for new tokens - tokens = await provider.exchange_refresh_token( + tokens = await self.provider.exchange_refresh_token( client_info, refresh_token, scopes ) - return response(tokens) - - return token_handler + return self.response(tokens) diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 5fa82f82b2..0cc2b921a3 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -10,13 +10,12 @@ from pydantic import AnyUrl from starlette.routing import Route, Router -from mcp.server.auth.handlers.authorize import create_authorization_handler -from mcp.server.auth.handlers.metadata import create_metadata_handler -from mcp.server.auth.handlers.revoke import create_revocation_handler -from mcp.server.auth.handlers.token import create_token_handler -from mcp.server.auth.middleware.client_auth import ( - ClientAuthenticator, -) +from mcp.server.auth.handlers.authorize import AuthorizationHandler +from mcp.server.auth.handlers.metadata import MetadataHandler +from mcp.server.auth.handlers.register import RegistrationHandler +from mcp.server.auth.handlers.revoke import RevocationHandler +from mcp.server.auth.handlers.token import TokenHandler +from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthServerProvider @@ -105,37 +104,39 @@ def create_auth_router( routes=[ Route( "/.well-known/oauth-authorization-server", - endpoint=create_metadata_handler(metadata), + endpoint=MetadataHandler(metadata).handle, methods=["GET"], ), Route( AUTHORIZATION_PATH, - endpoint=create_authorization_handler(provider), + endpoint=AuthorizationHandler(provider).handle, methods=["GET", "POST"], ), Route( TOKEN_PATH, - endpoint=create_token_handler(provider, client_authenticator), + endpoint=TokenHandler(provider, client_authenticator).handle, methods=["POST"], ), ] ) if client_registration_options.enabled: - from mcp.server.auth.handlers.register import create_registration_handler - - registration_handler = create_registration_handler( + registration_handler = RegistrationHandler( provider.clients_store, client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, ) auth_router.routes.append( - Route(REGISTRATION_PATH, endpoint=registration_handler, methods=["POST"]) + Route( + REGISTRATION_PATH, + endpoint=registration_handler.handle, + methods=["POST"], + ) ) if revocation_options.enabled: - revocation_handler = create_revocation_handler(provider, client_authenticator) + revocation_handler = RevocationHandler(provider, client_authenticator) auth_router.routes.append( - Route(REVOCATION_PATH, endpoint=revocation_handler, methods=["POST"]) + Route(REVOCATION_PATH, endpoint=revocation_handler.handle, methods=["POST"]) ) return auth_router From a5079af9844b169a7bc34668275554fced20565a Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 07:58:13 -0700 Subject: [PATCH 24/60] Refactor --- src/mcp/server/auth/handlers/authorize.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 59ea1f62e9..160643f9cb 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -244,14 +244,14 @@ async def error_response( ) # Let the provider pick the next URI to redirect to - response = RedirectResponse( - url="", status_code=302, headers={"Cache-Control": "no-store"} + return RedirectResponse( + url=await self.provider.authorize( + client, + auth_params, + ), + status_code=302, + headers={"Cache-Control": "no-store"} ) - response.headers["location"] = await self.provider.authorize( - client, - auth_params, - ) - return response except Exception as validation_error: # Catch-all for unexpected errors From c4c26087c224d443853723913666ef0218652586 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 08:02:15 -0700 Subject: [PATCH 25/60] Simplify bearer auth logic --- src/mcp/server/auth/middleware/bearer_auth.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index ab597ac903..5d9b72f2e7 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -41,11 +41,8 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - if "Authorization" not in conn.headers: - return None - - auth_header = conn.headers["Authorization"] - if not auth_header.startswith("Bearer "): + auth_header = conn.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): return None token = auth_header[7:] # Remove "Bearer " prefix From bc62d73214b62415b8e6f1fe346d8a4320054041 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 08:09:45 -0700 Subject: [PATCH 26/60] Avoid asyncio dependency in tests --- .../fastmcp/auth/streaming_asgi_transport.py | 9 +- .../fastmcp/auth/test_auth_integration.py | 225 +++++++++--------- 2 files changed, 120 insertions(+), 114 deletions(-) diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py index eb1ba4342e..6ada601a27 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -6,12 +6,13 @@ the connection is closed. """ -import asyncio import typing from typing import Any, Dict, Tuple import anyio +import anyio.abc import anyio.streams.memory + from httpx._models import Request, Response from httpx._transports.base import AsyncBaseTransport from httpx._types import AsyncByteStream @@ -41,6 +42,7 @@ class StreamingASGITransport(AsyncBaseTransport): def __init__( self, app: typing.Callable, + task_group: anyio.abc.TaskGroup, raise_app_exceptions: bool = True, root_path: str = "", client: Tuple[str, int] = ("127.0.0.1", 123), @@ -49,6 +51,7 @@ def __init__( self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client + self.task_group = task_group async def handle_async_request( self, @@ -161,8 +164,8 @@ async def process_messages() -> None: response_complete.set() # Create tasks for running the app and processing messages - asyncio.create_task(run_app()) - asyncio.create_task(process_messages()) + self.task_group.start_soon(run_app) + self.task_group.start_soon(process_messages) # Wait for the initial response or timeout await initial_response_ready.wait() diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 9f756c0507..73991e299d 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -11,6 +11,7 @@ from typing import List, Optional from urllib.parse import parse_qs, urlparse +import anyio import httpx import pytest from httpx_sse import aconnect_sse @@ -993,130 +994,132 @@ async def test_fastmcp_with_auth( def test_tool(x: int) -> str: return f"Result: {x}" - transport = StreamingASGITransport(app=mcp.starlette_app()) # pyright: ignore - test_client = httpx.AsyncClient( - transport=transport, base_url="http://mcptest.com" - ) - # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") + async with anyio.create_task_group() as task_group: + transport = StreamingASGITransport(app=mcp.starlette_app(), task_group=task_group) # pyright: ignore + test_client = httpx.AsyncClient( + transport=transport, base_url="http://mcptest.com" + ) + # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") - # Test metadata endpoint - response = await test_client.get("/.well-known/oauth-authorization-server") - assert response.status_code == 200 + # Test metadata endpoint + response = await test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 - # Test that auth is required for protected endpoints - response = await test_client.get("/sse") - # TODO: we should return 401/403 depending on whether authn or authz fails - assert response.status_code == 403 + # Test that auth is required for protected endpoints + response = await test_client.get("/sse") + # TODO: we should return 401/403 depending on whether authn or authz fails + assert response.status_code == 403 - response = await test_client.post("/messages/") - # TODO: we should return 401/403 depending on whether authn or authz fails - assert response.status_code == 403, response.content + response = await test_client.post("/messages/") + # TODO: we should return 401/403 depending on whether authn or authz fails + assert response.status_code == 403, response.content - response = await test_client.post( - "/messages/", - headers={"Authorization": "invalid"}, - ) - assert response.status_code == 403 - - response = await test_client.post( - "/messages/", - headers={"Authorization": "Bearer invalid"}, - ) - assert response.status_code == 403 + response = await test_client.post( + "/messages/", + headers={"Authorization": "invalid"}, + ) + assert response.status_code == 403 - # now, become authenticated and try to go through the flow again - client_metadata = { - "redirect_uris": ["https://client.example.com/callback"], - "client_name": "Test Client", - } + response = await test_client.post( + "/messages/", + headers={"Authorization": "Bearer invalid"}, + ) + assert response.status_code == 403 - response = await test_client.post( - "/register", - json=client_metadata, - ) - assert response.status_code == 201 - client_info = response.json() + # now, become authenticated and try to go through the flow again + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + } - # Request authorization using POST with form-encoded data - response = await test_client.post( - "/authorize", - data={ - "response_type": "code", - "client_id": client_info["client_id"], - "redirect_uri": "https://client.example.com/callback", - "code_challenge": pkce_challenge["code_challenge"], - "code_challenge_method": "S256", - "state": "test_state", - }, - ) - assert response.status_code == 302 + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() - # Extract the authorization code from the redirect URL - redirect_url = response.headers["location"] - parsed_url = urlparse(redirect_url) - query_params = parse_qs(parsed_url.query) + # Request authorization using POST with form-encoded data + response = await test_client.post( + "/authorize", + data={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert response.status_code == 302 - assert "code" in query_params - auth_code = query_params["code"][0] + # Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) - # Exchange the authorization code for tokens - response = await test_client.post( - "/token", - data={ - "grant_type": "authorization_code", - "client_id": client_info["client_id"], - "client_secret": client_info["client_secret"], - "code": auth_code, - "code_verifier": pkce_challenge["code_verifier"], - "redirect_uri": "https://client.example.com/callback", - }, - ) - assert response.status_code == 200 + assert "code" in query_params + auth_code = query_params["code"][0] - token_response = response.json() - assert "access_token" in token_response - authorization = f"Bearer {token_response['access_token']}" - - # Test the authenticated endpoint with valid token - async with aconnect_sse( - test_client, "GET", "/sse", headers={"Authorization": authorization} - ) as event_source: - assert event_source.response.status_code == 200 - events = event_source.aiter_sse() - sse = await events.__anext__() - assert sse.event == "endpoint" - assert sse.data.startswith("/messages/?session_id=") - messages_uri = sse.data - - # verify that we can now post to the /messages endpoint, and get a response - # on the /sse endpoint + # Exchange the authorization code for tokens response = await test_client.post( - messages_uri, - headers={"Authorization": authorization}, - content=JSONRPCRequest( - jsonrpc="2.0", - id="123", - method="initialize", - params={ - "protocolVersion": "2024-11-05", - "capabilities": { - "roots": {"listChanged": True}, - "sampling": {}, - }, - "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, - }, - ).model_dump_json(), - ) - assert response.status_code == 202 - assert response.content == b"Accepted" - - sse = await events.__anext__() - assert sse.event == "message" - sse_data = json.loads(sse.data) - assert sse_data["id"] == "123" - assert set(sse_data["result"]["capabilities"].keys()) == set( - ("experimental", "prompts", "resources", "tools") + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, ) + assert response.status_code == 200 + + token_response = response.json() + assert "access_token" in token_response + authorization = f"Bearer {token_response['access_token']}" + + # Test the authenticated endpoint with valid token + async with aconnect_sse( + test_client, "GET", "/sse", headers={"Authorization": authorization} + ) as event_source: + assert event_source.response.status_code == 200 + events = event_source.aiter_sse() + sse = await events.__anext__() + assert sse.event == "endpoint" + assert sse.data.startswith("/messages/?session_id=") + messages_uri = sse.data + + # verify that we can now post to the /messages endpoint, and get a response + # on the /sse endpoint + response = await test_client.post( + messages_uri, + headers={"Authorization": authorization}, + content=JSONRPCRequest( + jsonrpc="2.0", + id="123", + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": {"listChanged": True}, + "sampling": {}, + }, + "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, + }, + ).model_dump_json(), + ) + assert response.status_code == 202 + assert response.content == b"Accepted" + + sse = await events.__anext__() + assert sse.event == "message" + sse_data = json.loads(sse.data) + assert sse_data["id"] == "123" + assert set(sse_data["result"]["capabilities"].keys()) == set( + ("experimental", "prompts", "resources", "tools") + ) + task_group.cancel_scope.cancel() class TestAuthorizeEndpointErrors: From 3852179c7dc1801c67c193632ccb10667346dd28 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 08:10:45 -0700 Subject: [PATCH 27/60] Add comment --- tests/server/fastmcp/auth/test_auth_integration.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 73991e299d..82ec6067f2 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1119,6 +1119,9 @@ def test_tool(x: int) -> str: assert set(sse_data["result"]["capabilities"].keys()) == set( ("experimental", "prompts", "resources", "tools") ) + # the /sse endpoint will never finish; normally, the client could just + # disconnect, but in tests the easiest way to do this is to cancel the + # task group task_group.cancel_scope.cancel() From 874838a58f54c11287a32b2f6d89e21ecc9a7c5a Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 08:11:48 -0700 Subject: [PATCH 28/60] Lint --- tests/server/fastmcp/auth/streaming_asgi_transport.py | 1 - tests/server/fastmcp/auth/test_auth_integration.py | 10 ++++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/tests/server/fastmcp/auth/streaming_asgi_transport.py index 6ada601a27..7bb07b50a4 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/tests/server/fastmcp/auth/streaming_asgi_transport.py @@ -12,7 +12,6 @@ import anyio import anyio.abc import anyio.streams.memory - from httpx._models import Request, Response from httpx._transports.base import AsyncBaseTransport from httpx._types import AsyncByteStream diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 82ec6067f2..fb6d58deb9 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -995,11 +995,13 @@ def test_tool(x: int) -> str: return f"Result: {x}" async with anyio.create_task_group() as task_group: - transport = StreamingASGITransport(app=mcp.starlette_app(), task_group=task_group) # pyright: ignore + transport = StreamingASGITransport( + app=mcp.starlette_app(), + task_group=task_group, + ) test_client = httpx.AsyncClient( transport=transport, base_url="http://mcptest.com" ) - # test_client = httpx.AsyncClient(app=mcp.starlette_app(), base_url="http://mcptest.com") # Test metadata endpoint response = await test_client.get("/.well-known/oauth-authorization-server") @@ -1090,8 +1092,8 @@ def test_tool(x: int) -> str: assert sse.data.startswith("/messages/?session_id=") messages_uri = sse.data - # verify that we can now post to the /messages endpoint, and get a response - # on the /sse endpoint + # verify that we can now post to the /messages endpoint, + # and get a response on the /sse endpoint response = await test_client.post( messages_uri, headers={"Authorization": authorization}, From f788d7900beed7fa0ae582c9e855d6b35b664af5 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 08:15:16 -0700 Subject: [PATCH 29/60] Add json_response.py comment --- src/mcp/server/auth/json_response.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/server/auth/json_response.py b/src/mcp/server/auth/json_response.py index 25971cc916..bd95bd693b 100644 --- a/src/mcp/server/auth/json_response.py +++ b/src/mcp/server/auth/json_response.py @@ -4,5 +4,7 @@ class PydanticJSONResponse(JSONResponse): + # use pydantic json serialization instead of the stock `json.dumps`, + # so that we can handle serializing pydantic models like AnyHttpUrl def render(self, content: Any) -> bytes: return content.model_dump_json(exclude_none=True).encode("utf-8") From 152feb94df7324f47ba9d1a1521bcc9062a1fd8d Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 11:14:19 -0700 Subject: [PATCH 30/60] Format --- src/mcp/server/auth/errors.py | 1 + src/mcp/server/auth/handlers/authorize.py | 3 +-- src/mcp/server/auth/handlers/register.py | 1 + src/mcp/server/auth/handlers/token.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 1 - src/mcp/server/auth/middleware/client_auth.py | 2 +- tests/server/fastmcp/auth/test_auth_integration.py | 13 ++++--------- 7 files changed, 9 insertions(+), 14 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 08686d2eb3..e82afcfe4e 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -10,6 +10,7 @@ ErrorCode = Literal["invalid_request", "invalid_client"] + class ErrorResponse(BaseModel): error: ErrorCode error_description: str diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 160643f9cb..7f50b4bd15 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -250,7 +250,7 @@ async def error_response( auth_params, ), status_code=302, - headers={"Cache-Control": "no-store"} + headers={"Cache-Control": "no-store"}, ) except Exception as validation_error: @@ -263,7 +263,6 @@ async def error_response( ) - def create_error_redirect( redirect_uri: AnyUrl, error: Exception | ErrorResponse ) -> str: diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 6c41b85855..51947ee960 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -34,6 +34,7 @@ class ErrorResponse(BaseModel): class RegistrationHandler: clients_store: OAuthRegisteredClientsStore client_secret_expiry_seconds: int | None + async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 try: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0698262a5c..3b48008cdf 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -83,7 +83,7 @@ async def handle(self, request: Request): error_description=stringify_pydantic_error(validation_error), ) ) - + try: client_info = await self.client_authenticator(token_request) except InvalidClientError as e: diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 5d9b72f2e7..139035b9ab 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -59,7 +59,6 @@ async def authenticate(self, conn: HTTPConnection): return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) - class RequireAuthMiddleware: """ Middleware that requires a valid Bearer token in the Authorization header. diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 4546d92215..2219a74e2e 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -67,4 +67,4 @@ async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFu ): raise InvalidClientError("Client secret has expired") - return client \ No newline at end of file + return client diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index fb6d58deb9..c8144e6c2e 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -52,9 +52,7 @@ def __init__(self): async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: return self.clients.get(client_id) - async def register_client( - self, client_info: OAuthClientInformationFull - ): + async def register_client(self, client_info: OAuthClientInformationFull): self.clients[client_info.client_id] = client_info @@ -750,12 +748,9 @@ async def test_client_registration_invalid_uri( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert ( - error_data["error_description"] - == ( - "redirect_uris.0: Input should be a valid URL, " - "relative URL without a base" - ) + assert error_data["error_description"] == ( + "redirect_uris.0: Input should be a valid URL, " + "relative URL without a base" ) @pytest.mark.anyio From f37ebc46e5b19d3ee0e6e57f937dfd12e40e106c Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 14:22:43 -0700 Subject: [PATCH 31/60] Move around the response models to be closer to the handlers --- src/mcp/server/auth/handlers/authorize.py | 50 ++++++++++--------- src/mcp/server/auth/handlers/register.py | 12 +++-- src/mcp/server/auth/handlers/revoke.py | 35 +++++++------ src/mcp/server/auth/handlers/token.py | 32 ++++++++++-- src/mcp/server/auth/provider.py | 25 ++++------ src/mcp/shared/auth.py | 19 +------ .../fastmcp/auth/test_auth_integration.py | 34 ++++--------- 7 files changed, 103 insertions(+), 104 deletions(-) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 7f50b4bd15..ef4af9d0c7 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -53,6 +53,25 @@ class AuthorizationRequest(BaseModel): ) +AuthorizationErrorCode = Literal[ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", +] + + +class AuthorizationErrorResponse(BaseModel): + error: AuthorizationErrorCode + error_description: str + error_uri: AnyUrl | None = None + # must be set if provided in the request + state: str | None = None + + def validate_scope( requested_scope: str | None, client: OAuthClientInformationFull ) -> list[str] | None: @@ -84,25 +103,6 @@ def validate_redirect_uri( ) -ErrorCode = Literal[ - "invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", - "server_error", - "temporarily_unavailable", -] - - -class ErrorResponse(BaseModel): - error: ErrorCode - error_description: str - error_uri: AnyUrl | None = None - # must be set if provided in the request - state: str | None = None - - def best_effort_extract_string( key: str, params: None | FormData | QueryParams ) -> str | None: @@ -132,7 +132,9 @@ async def handle(self, request: Request) -> Response: params = None async def error_response( - error: ErrorCode, error_description: str, attempt_load_client: bool = True + error: AuthorizationErrorCode, + error_description: str, + attempt_load_client: bool = True, ): nonlocal client, redirect_uri, state if client is None and attempt_load_client: @@ -157,7 +159,7 @@ async def error_response( # make last-ditch effort to load state state = best_effort_extract_string("state", params) - error_resp = ErrorResponse( + error_resp = AuthorizationErrorResponse( error=error, error_description=error_description, state=state, @@ -194,7 +196,7 @@ async def error_response( auth_request = AuthorizationRequest.model_validate(params) state = auth_request.state # Update with validated state except ValidationError as validation_error: - error: ErrorCode = "invalid_request" + error: AuthorizationErrorCode = "invalid_request" for e in validation_error.errors(): if e["loc"] == ("response_type",) and e["type"] == "literal_error": error = "unsupported_response_type" @@ -264,11 +266,11 @@ async def error_response( def create_error_redirect( - redirect_uri: AnyUrl, error: Exception | ErrorResponse + redirect_uri: AnyUrl, error: Exception | AuthorizationErrorResponse ) -> str: parsed_uri = urlparse(str(redirect_uri)) - if isinstance(error, ErrorResponse): + if isinstance(error, AuthorizationErrorResponse): # Convert ErrorResponse to dict error_dict = error.model_dump(exclude_none=True) query_params = {} diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 51947ee960..8213aaa322 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -10,7 +10,7 @@ from typing import Literal from uuid import uuid4 -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, RootModel, ValidationError from starlette.requests import Request from starlette.responses import Response @@ -20,7 +20,13 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata -class ErrorResponse(BaseModel): +class RegistrationRequest(RootModel): + # this wrapper is a no-op; it's just to separate out the types exposed to the + # provider from what we use in the HTTP handler + root: OAuthClientMetadata + + +class RegistrationErrorResponse(BaseModel): error: Literal[ "invalid_redirect_uri", "invalid_client_metadata", @@ -43,7 +49,7 @@ async def handle(self, request: Request) -> Response: client_metadata = OAuthClientMetadata.model_validate(body) except ValidationError as validation_error: return PydanticJSONResponse( - content=ErrorResponse( + content=RegistrationErrorResponse( error="invalid_client_metadata", error_description=stringify_pydantic_error(validation_error), ), diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index d31fe62285..6711506f9c 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -5,26 +5,34 @@ """ from dataclasses import dataclass +from typing import Literal -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import Response from mcp.server.auth.errors import ( - InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, - ClientAuthRequest, ) -from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest -from mcp.shared.auth import TokenErrorResponse +from mcp.server.auth.provider import OAuthServerProvider -class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest): - pass +class RevocationRequest(BaseModel): + """ + # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 + """ + + token: str + token_type_hint: Literal["access_token", "refresh_token"] | None = None + + +class RevocationErrorResponse(BaseModel): + error: Literal["invalid_request",] + error_description: str | None = None @dataclass @@ -42,21 +50,16 @@ async def handle(self, request: Request) -> Response: except ValidationError as e: return PydanticJSONResponse( status_code=400, - content=TokenErrorResponse( + content=RevocationErrorResponse( error="invalid_request", error_description=stringify_pydantic_error(e), ), ) - # Authenticate client - try: - client_auth_result = await self.client_authenticator(revocation_request) - except InvalidClientError as e: - return PydanticJSONResponse(status_code=401, content=e.error_response()) - # Revoke token - if self.provider.revoke_token: - await self.provider.revoke_token(client_auth_result, revocation_request) + await self.provider.revoke_token( + revocation_request.token, revocation_request.token_type_hint + ) # Return successful empty response return Response( diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 3b48008cdf..f005dff232 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from typing import Annotated, Literal -from pydantic import AnyHttpUrl, Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request from mcp.server.auth.errors import ( @@ -24,7 +24,7 @@ ClientAuthRequest, ) from mcp.server.auth.provider import OAuthServerProvider -from mcp.shared.auth import TokenErrorResponse, TokenSuccessResponse +from mcp.shared.auth import OAuthToken class AuthorizationCodeRequest(ClientAuthRequest): @@ -53,6 +53,30 @@ class TokenRequest(RootModel): ] +class TokenErrorResponse(BaseModel): + """ + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + """ + + error: Literal[ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", + ] + error_description: str | None = None + error_uri: AnyHttpUrl | None = None + + +class TokenSuccessResponse(RootModel): + # this is just a wrapper over OAuthToken; the only reason we do this + # is to have some separation between the HTTP response type, and the + # type returned by the provider + root: OAuthToken + + @dataclass class TokenHandler: provider: OAuthServerProvider @@ -100,7 +124,7 @@ async def handle(self, request: Request): ) ) - tokens: TokenSuccessResponse + tokens: OAuthToken match token_request: case AuthorizationCodeRequest(): @@ -208,4 +232,4 @@ async def handle(self, request: Request): client_info, refresh_token, scopes ) - return self.response(tokens) + return self.response(TokenSuccessResponse(root=tokens)) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 6eb039746b..ac1f6343cc 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -12,7 +12,7 @@ from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( OAuthClientInformationFull, - TokenSuccessResponse, + OAuthToken, ) @@ -45,15 +45,6 @@ class RefreshToken(BaseModel): expires_at: int | None = None -class OAuthTokenRevocationRequest(BaseModel): - """ - # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 - """ - - token: str - token_type_hint: Literal["access_token", "refresh_token"] | None = None - - class OAuthRegisteredClientsStore(Protocol): """ Interface for storing and retrieving registered OAuth clients. @@ -149,7 +140,7 @@ async def load_authorization_code( async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode - ) -> TokenSuccessResponse: + ) -> OAuthToken: """ Exchanges an authorization code for an access token. @@ -171,7 +162,7 @@ async def exchange_refresh_token( client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str], - ) -> TokenSuccessResponse: + ) -> OAuthToken: """ Exchanges a refresh token for an access token. @@ -198,7 +189,9 @@ async def load_access_token(self, token: str) -> AuthInfo | None: ... async def revoke_token( - self, client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest + self, + token: str, + token_type_hint: Literal["access_token", "refresh_token"] | None = None, ) -> None: """ Revokes an access or refresh token. @@ -206,8 +199,10 @@ async def revoke_token( If the given token is invalid or already revoked, this method should do nothing. Args: - client: The client revoking the token. - request: The token revocation request. + token: the token to revoke + token_type_hint: hint about the type of token to revoke; optional. if the + token cannot be located using this hint, the provider MUST extend its search + to include all tokens. """ ... diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 963fcc7236..16c07a70a8 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -9,24 +9,7 @@ from pydantic import AnyHttpUrl, BaseModel, Field -class TokenErrorResponse(BaseModel): - """ - See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 - """ - - error: Literal[ - "invalid_request", - "invalid_client", - "invalid_grant", - "unauthorized_client", - "unsupported_grant_type", - "invalid_scope", - ] - error_description: Optional[str] = None - error_uri: Optional[AnyHttpUrl] = None - - -class TokenSuccessResponse(BaseModel): +class OAuthToken(BaseModel): """ See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 """ diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index c8144e6c2e..11a9ccd44f 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -8,7 +8,7 @@ import secrets import time import unittest.mock -from typing import List, Optional +from typing import List, Literal, Optional from urllib.parse import parse_qs, urlparse import anyio @@ -24,7 +24,6 @@ AuthorizationParams, OAuthRegisteredClientsStore, OAuthServerProvider, - OAuthTokenRevocationRequest, RefreshToken, construct_redirect_uri, ) @@ -37,7 +36,7 @@ from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, - TokenSuccessResponse, + OAuthToken, ) from mcp.types import JSONRPCRequest @@ -94,7 +93,7 @@ async def load_authorization_code( async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode - ) -> TokenSuccessResponse: + ) -> OAuthToken: assert authorization_code.code in self.auth_codes # Generate an access token and refresh token @@ -114,7 +113,7 @@ async def exchange_authorization_code( # Remove the used code del self.auth_codes[authorization_code.code] - return TokenSuccessResponse( + return OAuthToken( access_token=access_token, token_type="bearer", expires_in=3600, @@ -147,7 +146,7 @@ async def exchange_refresh_token( client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: List[str], - ) -> TokenSuccessResponse: + ) -> OAuthToken: # Check if refresh token exists assert refresh_token.token in self.refresh_tokens @@ -177,7 +176,7 @@ async def exchange_refresh_token( del self.refresh_tokens[refresh_token.token] del self.tokens[old_access_token] - return TokenSuccessResponse( + return OAuthToken( access_token=new_access_token, token_type="bearer", expires_in=3600, @@ -200,30 +199,17 @@ async def load_access_token(self, token: str) -> AuthInfo | None: ) async def revoke_token( - self, client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest + self, + token: str, + token_type_hint: Literal["access_token", "refresh_token"] | None = None, ) -> None: - token = request.token - # Check if it's a refresh token if token in self.refresh_tokens: - access_token = self.refresh_tokens[token] - - # Check if this refresh token belongs to this client - if self.tokens[access_token]["client_id"] != client.client_id: - # For security reasons, we still return success - return - - # Remove the refresh token and its associated access token - del self.tokens[access_token] + # Remove the refresh token del self.refresh_tokens[token] # Check if it's an access token elif token in self.tokens: - # Check if this access token belongs to this client - if self.tokens[token]["client_id"] != client.client_id: - # For security reasons, we still return success - return - # Remove the access token del self.tokens[token] From c2873fdb16ca939fea14e06f6e018aa67301370c Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 14:27:32 -0700 Subject: [PATCH 32/60] Get rid of silly TS comments --- src/mcp/server/auth/errors.py | 12 ------------ src/mcp/server/auth/handlers/authorize.py | 6 ------ src/mcp/server/auth/handlers/metadata.py | 6 ------ src/mcp/server/auth/handlers/register.py | 6 ------ src/mcp/server/auth/handlers/revoke.py | 6 ------ src/mcp/server/auth/handlers/token.py | 6 ------ src/mcp/server/auth/middleware/bearer_auth.py | 8 -------- src/mcp/server/auth/middleware/client_auth.py | 13 +------------ src/mcp/server/auth/router.py | 8 -------- src/mcp/server/auth/types.py | 12 ------------ src/mcp/shared/auth.py | 14 -------------- 11 files changed, 1 insertion(+), 96 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index e82afcfe4e..e629e28acb 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -1,9 +1,3 @@ -""" -OAuth error classes for MCP authorization. - -Corresponds to TypeScript file: src/server/auth/errors.ts -""" - from typing import Literal from pydantic import BaseModel, ValidationError @@ -19,8 +13,6 @@ class ErrorResponse(BaseModel): class OAuthError(Exception): """ Base class for all OAuth errors. - - Corresponds to OAuthError in src/server/auth/errors.ts """ error_code: ErrorCode @@ -39,8 +31,6 @@ def error_response(self) -> ErrorResponse: class InvalidRequestError(OAuthError): """ Invalid request error. - - Corresponds to InvalidRequestError in src/server/auth/errors.ts """ error_code = "invalid_request" @@ -49,8 +39,6 @@ class InvalidRequestError(OAuthError): class InvalidClientError(OAuthError): """ Invalid client error. - - Corresponds to InvalidClientError in src/server/auth/errors.ts """ error_code = "invalid_client" diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index ef4af9d0c7..6c99bcfb7a 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -1,9 +1,3 @@ -""" -Handler for OAuth 2.0 Authorization endpoint. - -Corresponds to TypeScript file: src/server/auth/handlers/authorize.ts -""" - import logging from dataclasses import dataclass from typing import Literal diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 39cc889402..43a37affae 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -1,9 +1,3 @@ -""" -Handler for OAuth 2.0 Authorization Server Metadata. - -Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts -""" - from dataclasses import dataclass from typing import Any diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 8213aaa322..893e7a7f8e 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -1,9 +1,3 @@ -""" -Handler for OAuth 2.0 Dynamic Client Registration. - -Corresponds to TypeScript file: src/server/auth/handlers/register.ts -""" - import secrets import time from dataclasses import dataclass diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 6711506f9c..e45c935912 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -1,9 +1,3 @@ -""" -Handler for OAuth 2.0 Token Revocation. - -Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts -""" - from dataclasses import dataclass from typing import Literal diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index f005dff232..8cdf216474 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -1,9 +1,3 @@ -""" -Handler for OAuth 2.0 Token endpoint. - -Corresponds to TypeScript file: src/server/auth/handlers/token.ts -""" - import base64 import hashlib import time diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 139035b9ab..fbd4f4d152 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,9 +1,3 @@ -""" -Bearer token authentication middleware for ASGI applications. - -Corresponds to TypeScript file: src/server/auth/middleware/bearerAuth.ts -""" - import time from typing import Any, Callable @@ -65,8 +59,6 @@ class RequireAuthMiddleware: This will validate the token with the auth provider and store the resulting auth info in the request state. - - Corresponds to bearerAuthMiddleware in src/server/auth/middleware/bearerAuth.ts """ def __init__(self, app: Any, required_scopes: list[str]): diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 2219a74e2e..d70d56749d 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,9 +1,3 @@ -""" -Client authentication middleware for ASGI applications. - -Corresponds to TypeScript file: src/server/auth/middleware/clientAuth.ts -""" - import time from pydantic import BaseModel @@ -14,12 +8,7 @@ class ClientAuthRequest(BaseModel): - """ - Model for client authentication request body. - - Corresponds to ClientAuthenticatedRequestSchema in - src/server/auth/middleware/clientAuth.ts - """ + # TODO: mix this directly into TokenRequest client_id: str client_secret: str | None = None diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 0cc2b921a3..1e49aef5f6 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -1,9 +1,3 @@ -""" -Router for OAuth authorization endpoints. - -Corresponds to TypeScript file: src/server/auth/router.ts -""" - from dataclasses import dataclass from typing import Any @@ -72,8 +66,6 @@ def create_auth_router( """ Create a Starlette router with standard MCP authorization endpoints. - Corresponds to mcpAuthRouter in src/server/auth/router.ts - Args: provider: OAuth server provider issuer_url: Issuer URL for the authorization server diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py index eb47b65770..6e03b1ffad 100644 --- a/src/mcp/server/auth/types.py +++ b/src/mcp/server/auth/types.py @@ -1,19 +1,7 @@ -""" -Authorization types for MCP server. - -Corresponds to TypeScript file: src/server/auth/types.ts -""" - from pydantic import BaseModel class AuthInfo(BaseModel): - """ - Information about a validated access token, provided to request handlers. - - Corresponds to AuthInfo in src/server/auth/types.ts - """ - token: str client_id: str scopes: list[str] diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 16c07a70a8..e62f8d762c 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,9 +1,3 @@ -""" -Authorization types and models for MCP OAuth implementation. - -Corresponds to TypeScript file: src/shared/auth.ts -""" - from typing import Any, List, Literal, Optional from pydantic import AnyHttpUrl, BaseModel, Field @@ -60,8 +54,6 @@ class OAuthClientMetadata(BaseModel): class OAuthClientInformation(BaseModel): """ RFC 7591 OAuth 2.0 Dynamic Client Registration client information. - - Corresponds to OAuthClientInformationSchema in src/shared/auth.ts """ client_id: str @@ -74,8 +66,6 @@ class OAuthClientInformationFull(OAuthClientMetadata, OAuthClientInformation): """ RFC 7591 OAuth 2.0 Dynamic Client Registration full response (client information plus metadata). - - Corresponds to OAuthClientInformationFullSchema in src/shared/auth.ts """ pass @@ -84,8 +74,6 @@ class OAuthClientInformationFull(OAuthClientMetadata, OAuthClientInformation): class OAuthClientRegistrationError(BaseModel): """ RFC 7591 OAuth 2.0 Dynamic Client Registration error response. - - Corresponds to OAuthClientRegistrationErrorSchema in src/shared/auth.ts """ error: str @@ -95,8 +83,6 @@ class OAuthClientRegistrationError(BaseModel): class OAuthMetadata(BaseModel): """ RFC 8414 OAuth 2.0 Authorization Server Metadata. - - Corresponds to OAuthMetadataSchema in src/shared/auth.ts """ issuer: str From fe2c029096e5c8e2a5d8973d7db9d3b1eefa7d59 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 15:27:03 -0700 Subject: [PATCH 33/60] Remove ClientAuthRequest --- src/mcp/server/auth/handlers/token.py | 15 ++++++++--- src/mcp/server/auth/middleware/client_auth.py | 24 ++++++----------- src/mcp/server/auth/provider.py | 26 +------------------ 3 files changed, 20 insertions(+), 45 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 8cdf216474..14c92e4a1d 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -15,13 +15,12 @@ from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, - ClientAuthRequest, ) from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthToken -class AuthorizationCodeRequest(ClientAuthRequest): +class AuthorizationCodeRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") @@ -29,15 +28,20 @@ class AuthorizationCodeRequest(ClientAuthRequest): ..., description="Must be the same as redirect URI provided in /authorize" ) client_id: str + # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + client_secret: str | None = None # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 code_verifier: str = Field(..., description="PKCE code verifier") -class RefreshTokenRequest(ClientAuthRequest): +class RefreshTokenRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 grant_type: Literal["refresh_token"] refresh_token: str = Field(..., description="The refresh token") scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + client_secret: str | None = None class TokenRequest(RootModel): @@ -103,7 +107,10 @@ async def handle(self, request: Request): ) try: - client_info = await self.client_authenticator(token_request) + client_info = await self.client_authenticator.authenticate( + client_id=token_request.client_id, + client_secret=token_request.client_secret, + ) except InvalidClientError as e: return self.response(e.error_response()) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index d70d56749d..cda5d79a52 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,26 +1,16 @@ import time -from pydantic import BaseModel - from mcp.server.auth.errors import InvalidClientError from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull -class ClientAuthRequest(BaseModel): - # TODO: mix this directly into TokenRequest - - client_id: str - client_secret: str | None = None - - class ClientAuthenticator: """ ClientAuthenticator is a callable which validates requests from a client - application, - used to verify /token and /revoke calls. + application, used to verify /token calls. If, during registration, the client requested to be issued a secret, the - authenticator asserts that /token and /register calls must be authenticated with + authenticator asserts that /token calls must be authenticated with that same token. NOTE: clients can opt for no authentication during registration, in which case this logic is skipped. @@ -35,19 +25,21 @@ def __init__(self, clients_store: OAuthRegisteredClientsStore): """ self.clients_store = clients_store - async def __call__(self, request: ClientAuthRequest) -> OAuthClientInformationFull: + async def authenticate( + self, client_id: str, client_secret: str | None + ) -> OAuthClientInformationFull: # Look up client information - client = await self.clients_store.get_client(request.client_id) + client = await self.clients_store.get_client(client_id) if not client: raise InvalidClientError("Invalid client_id") # If client from the store expects a secret, validate that the request provides # that secret if client.client_secret: - if not request.client_secret: + if not client_secret: raise InvalidClientError("Client secret is required") - if client.client_secret != request.client_secret: + if client.client_secret != client_secret: raise InvalidClientError("Invalid client_secret") if ( diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index ac1f6343cc..466acccbff 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,9 +1,3 @@ -""" -OAuth server provider interfaces for MCP authorization. - -Corresponds to TypeScript file: src/server/auth/provider.ts -""" - from typing import Literal, Protocol from urllib.parse import parse_qs, urlencode, urlparse, urlunparse @@ -17,12 +11,6 @@ class AuthorizationParams(BaseModel): - """ - Parameters for the authorization flow. - - Corresponds to AuthorizationParams in src/server/auth/provider.ts - """ - state: str | None = None scopes: list[str] | None = None code_challenge: str @@ -46,12 +34,6 @@ class RefreshToken(BaseModel): class OAuthRegisteredClientsStore(Protocol): - """ - Interface for storing and retrieving registered OAuth clients. - - Corresponds to OAuthRegisteredClientsStore in src/server/auth/clients.ts - """ - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ Retrieves client information by client ID. @@ -66,7 +48,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull) -> None: """ - Registers a new client + Saves client information as part of registering it. Args: client_info: The client metadata to register. @@ -75,12 +57,6 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None class OAuthServerProvider(Protocol): - """ - Implements an end-to-end OAuth server. - - Corresponds to OAuthServerProvider in src/server/auth/provider.ts - """ - @property def clients_store(self) -> OAuthRegisteredClientsStore: """ From 3a13f5d8e3458258cdaedb242ee049e93af9ee18 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 15:34:16 -0700 Subject: [PATCH 34/60] Reorganize AuthInfo --- src/mcp/server/auth/middleware/bearer_auth.py | 3 +-- src/mcp/server/auth/provider.py | 8 +++++++- src/mcp/server/auth/types.py | 8 -------- tests/server/fastmcp/auth/test_auth_integration.py | 2 +- 4 files changed, 9 insertions(+), 12 deletions(-) delete mode 100644 src/mcp/server/auth/types.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index fbd4f4d152..6a64648b81 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -10,8 +10,7 @@ from starlette.requests import HTTPConnection from starlette.types import Scope -from mcp.server.auth.provider import OAuthServerProvider -from mcp.server.auth.types import AuthInfo +from mcp.server.auth.provider import AuthInfo, OAuthServerProvider class AuthenticatedUser(SimpleUser): diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 466acccbff..e0ee171ab6 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -3,7 +3,6 @@ from pydantic import AnyHttpUrl, BaseModel -from mcp.server.auth.types import AuthInfo from mcp.shared.auth import ( OAuthClientInformationFull, OAuthToken, @@ -33,6 +32,13 @@ class RefreshToken(BaseModel): expires_at: int | None = None +class AuthInfo(BaseModel): + token: str + client_id: str + scopes: list[str] + expires_at: int | None = None + + class OAuthRegisteredClientsStore(Protocol): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ diff --git a/src/mcp/server/auth/types.py b/src/mcp/server/auth/types.py deleted file mode 100644 index 6e03b1ffad..0000000000 --- a/src/mcp/server/auth/types.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class AuthInfo(BaseModel): - token: str - client_id: str - scopes: list[str] - expires_at: int | None = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 11a9ccd44f..458d46c166 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -20,6 +20,7 @@ from starlette.routing import Mount from mcp.server.auth.provider import ( + AuthInfo, AuthorizationCode, AuthorizationParams, OAuthRegisteredClientsStore, @@ -32,7 +33,6 @@ RevocationOptions, create_auth_router, ) -from mcp.server.auth.types import AuthInfo from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, From 37c5fc4e22ee488541cec001fbc331d42318adba Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 15:57:27 -0700 Subject: [PATCH 35/60] Refactor client metadata endpoint --- src/mcp/server/auth/handlers/metadata.py | 15 ++++++------ src/mcp/server/auth/router.py | 30 ++++++++++++------------ src/mcp/shared/auth.py | 23 +++--------------- 3 files changed, 25 insertions(+), 43 deletions(-) diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index 43a37affae..e37e5d311f 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -1,19 +1,18 @@ from dataclasses import dataclass -from typing import Any from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import Response + +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.shared.auth import OAuthMetadata @dataclass class MetadataHandler: - metadata: dict[str, Any] + metadata: OAuthMetadata async def handle(self, request: Request) -> Response: - # Remove any None values from metadata - clean_metadata = {k: v for k, v in self.metadata.items() if v is not None} - - return JSONResponse( - content=clean_metadata, + return PydanticJSONResponse( + content=self.metadata, headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour ) diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 1e49aef5f6..85e2a21c34 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Any from pydantic import AnyUrl from starlette.routing import Route, Router @@ -11,6 +10,7 @@ from mcp.server.auth.handlers.token import TokenHandler from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthMetadata @dataclass @@ -139,29 +139,29 @@ def build_metadata( service_documentation_url: AnyUrl | None, client_registration_options: ClientRegistrationOptions, revocation_options: RevocationOptions, -) -> dict[str, Any]: +) -> OAuthMetadata: issuer_url_str = str(issuer_url).rstrip("/") # Create metadata - metadata = { - "issuer": issuer_url_str, - "service_documentation": str(service_documentation_url).rstrip("/") + metadata = OAuthMetadata( + issuer=issuer_url_str, + service_documentation=str(service_documentation_url).rstrip("/") if service_documentation_url else None, - "authorization_endpoint": f"{issuer_url_str}{AUTHORIZATION_PATH}", - "response_types_supported": ["code"], - "code_challenge_methods_supported": ["S256"], - "token_endpoint": f"{issuer_url_str}{TOKEN_PATH}", - "token_endpoint_auth_methods_supported": ["client_secret_post"], - "grant_types_supported": ["authorization_code", "refresh_token"], - } + authorization_endpoint=f"{issuer_url_str}{AUTHORIZATION_PATH}", + response_types_supported=["code"], + code_challenge_methods_supported=["S256"], + token_endpoint=f"{issuer_url_str}{TOKEN_PATH}", + token_endpoint_auth_methods_supported=["client_secret_post"], + grant_types_supported=["authorization_code", "refresh_token"], + ) # Add registration endpoint if supported if client_registration_options.enabled: - metadata["registration_endpoint"] = f"{issuer_url_str}{REGISTRATION_PATH}" + metadata.registration_endpoint = f"{issuer_url_str}{REGISTRATION_PATH}" # Add revocation endpoint if supported if revocation_options.enabled: - metadata["revocation_endpoint"] = f"{issuer_url_str}{REVOCATION_PATH}" - metadata["revocation_endpoint_auth_methods_supported"] = ["client_secret_post"] + metadata.revocation_endpoint = f"{issuer_url_str}{REVOCATION_PATH}" + metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] return metadata diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index e62f8d762c..debcda47f0 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -51,9 +51,10 @@ class OAuthClientMetadata(BaseModel): software_version: Optional[str] = None -class OAuthClientInformation(BaseModel): +class OAuthClientInformationFull(OAuthClientMetadata): """ - RFC 7591 OAuth 2.0 Dynamic Client Registration client information. + RFC 7591 OAuth 2.0 Dynamic Client Registration full response + (client information plus metadata). """ client_id: str @@ -62,24 +63,6 @@ class OAuthClientInformation(BaseModel): client_secret_expires_at: Optional[int] = None -class OAuthClientInformationFull(OAuthClientMetadata, OAuthClientInformation): - """ - RFC 7591 OAuth 2.0 Dynamic Client Registration full response - (client information plus metadata). - """ - - pass - - -class OAuthClientRegistrationError(BaseModel): - """ - RFC 7591 OAuth 2.0 Dynamic Client Registration error response. - """ - - error: str - error_description: Optional[str] = None - - class OAuthMetadata(BaseModel): """ RFC 8414 OAuth 2.0 Authorization Server Metadata. From 792d3020e2495ee8ef995e18df63e4b3bfa23a3b Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 17:05:47 -0700 Subject: [PATCH 36/60] Make metadata more spec compliant --- src/mcp/server/auth/router.py | 62 ++++++++++++++----- src/mcp/server/fastmcp/server.py | 12 ++-- src/mcp/shared/auth.py | 56 ++++++++++------- .../fastmcp/auth/test_auth_integration.py | 10 +-- 4 files changed, 92 insertions(+), 48 deletions(-) diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 85e2a21c34..ba33d7ea7b 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +from typing import Callable -from pydantic import AnyUrl +from pydantic import AnyHttpUrl from starlette.routing import Route, Router from mcp.server.auth.handlers.authorize import AuthorizationHandler @@ -24,7 +25,7 @@ class RevocationOptions: enabled: bool = False -def validate_issuer_url(url: AnyUrl): +def validate_issuer_url(url: AnyHttpUrl): """ Validate that the issuer URL meets OAuth 2.0 requirements. @@ -58,8 +59,8 @@ def validate_issuer_url(url: AnyUrl): def create_auth_router( provider: OAuthServerProvider, - issuer_url: AnyUrl, - service_documentation_url: AnyUrl | None = None, + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None = None, client_registration_options: ClientRegistrationOptions | None = None, revocation_options: RevocationOptions | None = None, ) -> Router: @@ -134,34 +135,61 @@ def create_auth_router( return auth_router +def modify_url_path(url: AnyHttpUrl, path_mapper: Callable[[str], str]) -> AnyHttpUrl: + return AnyHttpUrl.build( + scheme=url.scheme, + username=url.username, + password=url.password, + host=url.host, + port=url.port, + path=path_mapper(url.path or ""), + query=url.query, + fragment=url.fragment, + ) + + def build_metadata( - issuer_url: AnyUrl, - service_documentation_url: AnyUrl | None, + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None, client_registration_options: ClientRegistrationOptions, revocation_options: RevocationOptions, ) -> OAuthMetadata: - issuer_url_str = str(issuer_url).rstrip("/") + authorization_url = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + AUTHORIZATION_PATH.lstrip("/") + ) + token_url = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + TOKEN_PATH.lstrip("/") + ) # Create metadata metadata = OAuthMetadata( - issuer=issuer_url_str, - service_documentation=str(service_documentation_url).rstrip("/") - if service_documentation_url - else None, - authorization_endpoint=f"{issuer_url_str}{AUTHORIZATION_PATH}", + issuer=issuer_url, + authorization_endpoint=authorization_url, + token_endpoint=token_url, + scopes_supported=None, response_types_supported=["code"], - code_challenge_methods_supported=["S256"], - token_endpoint=f"{issuer_url_str}{TOKEN_PATH}", - token_endpoint_auth_methods_supported=["client_secret_post"], + response_modes_supported=None, grant_types_supported=["authorization_code", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_post"], + token_endpoint_auth_signing_alg_values_supported=None, + service_documentation=service_documentation_url, + ui_locales_supported=None, + op_policy_uri=None, + op_tos_uri=None, + introspection_endpoint=None, + code_challenge_methods_supported=["S256"], ) # Add registration endpoint if supported if client_registration_options.enabled: - metadata.registration_endpoint = f"{issuer_url_str}{REGISTRATION_PATH}" + metadata.registration_endpoint = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + REGISTRATION_PATH.lstrip("/") + ) # Add revocation endpoint if supported if revocation_options.enabled: - metadata.revocation_endpoint = f"{issuer_url_str}{REVOCATION_PATH}" + metadata.revocation_endpoint = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + REVOCATION_PATH.lstrip("/") + ) metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] return metadata diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c30b67c4a2..65d075d3ae 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -16,7 +16,7 @@ import anyio import pydantic_core import uvicorn -from pydantic import BaseModel, Field +from pydantic import AnyHttpUrl, BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from sse_starlette import EventSourceResponse @@ -99,9 +99,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]): Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None ) = Field(None, description="Lifespan context manager") - auth_issuer_url: AnyUrl | None = Field(None, description="Auth issuer URL") - auth_service_documentation_url: AnyUrl | None = Field( - None, description="Service documentation URL" + auth_issuer_url: AnyHttpUrl | None = Field( + None, + description="URL advertised as OAuth issuer; this should be the URL the server " + "is reachable at", + ) + auth_service_documentation_url: AnyHttpUrl | None = Field( + None, description="Service documentation URL advertised by OAuth" ) auth_client_registration_options: ClientRegistrationOptions | None = None auth_revocation_options: RevocationOptions | None = None diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index debcda47f0..dde4b25df0 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -24,10 +24,10 @@ class OAuthClientMetadata(BaseModel): redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) # token_endpoint_auth_method: this implementation only supports none & - # client_secret_basic; - # ie: we do not support client_secret_post - token_endpoint_auth_method: Literal["none", "client_secret_basic"] = ( - "client_secret_basic" + # client_secret_post; + # ie: we do not support client_secret_basic + token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( + "client_secret_post" ) # grant_types: this implementation only supports authorization_code & refresh_token grant_types: List[Literal["authorization_code", "refresh_token"]] = [ @@ -66,23 +66,35 @@ class OAuthClientInformationFull(OAuthClientMetadata): class OAuthMetadata(BaseModel): """ RFC 8414 OAuth 2.0 Authorization Server Metadata. + See https://datatracker.ietf.org/doc/html/rfc8414#section-2 """ - issuer: str - authorization_endpoint: str - token_endpoint: str - registration_endpoint: Optional[str] = None - scopes_supported: Optional[List[str]] = None - response_types_supported: List[str] - response_modes_supported: Optional[List[str]] = None - grant_types_supported: Optional[List[str]] = None - token_endpoint_auth_methods_supported: Optional[List[str]] = None - token_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - service_documentation: Optional[str] = None - revocation_endpoint: Optional[str] = None - revocation_endpoint_auth_methods_supported: Optional[List[str]] = None - revocation_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - introspection_endpoint: Optional[str] = None - introspection_endpoint_auth_methods_supported: Optional[List[str]] = None - introspection_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - code_challenge_methods_supported: Optional[List[str]] = None + issuer: AnyHttpUrl + authorization_endpoint: AnyHttpUrl + token_endpoint: AnyHttpUrl + registration_endpoint: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[Literal["code"]] = ["code"] + response_modes_supported: list[Literal["query", "fragment"]] | None = None + grant_types_supported: ( + list[Literal["authorization_code", "refresh_token"]] | None + ) = None + token_endpoint_auth_methods_supported: ( + list[Literal["none", "client_secret_post"]] | None + ) = None + token_endpoint_auth_signing_alg_values_supported: None = None + service_documentation: AnyHttpUrl | None = None + ui_locales_supported: list[str] | None = None + op_policy_uri: AnyHttpUrl | None = None + op_tos_uri: AnyHttpUrl | None = None + revocation_endpoint: AnyHttpUrl | None = None + revocation_endpoint_auth_methods_supported: ( + list[Literal["client_secret_post"]] | None + ) = None + revocation_endpoint_auth_signing_alg_values_supported: None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_endpoint_auth_methods_supported: ( + list[Literal["client_secret_post"]] | None + ) = None + introspection_endpoint_auth_signing_alg_values_supported: None = None + code_challenge_methods_supported: list[Literal["S256"]] | None = None diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 458d46c166..c18b7bf112 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -15,7 +15,7 @@ import httpx import pytest from httpx_sse import aconnect_sse -from pydantic import AnyUrl +from pydantic import AnyHttpUrl from starlette.applications import Starlette from starlette.routing import Mount @@ -229,8 +229,8 @@ def auth_app(mock_oauth_provider): # Create auth router auth_router = create_auth_router( mock_oauth_provider, - AnyUrl("https://auth.example.com"), - AnyUrl("https://docs.example.com"), + AnyHttpUrl("https://auth.example.com"), + AnyHttpUrl("https://docs.example.com"), client_registration_options=ClientRegistrationOptions(enabled=True), revocation_options=RevocationOptions(enabled=True), ) @@ -373,7 +373,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert response.status_code == 200 metadata = response.json() - assert metadata["issuer"] == "https://auth.example.com" + assert metadata["issuer"] == "https://auth.example.com/" assert ( metadata["authorization_endpoint"] == "https://auth.example.com/authorize" ) @@ -389,7 +389,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): "authorization_code", "refresh_token", ] - assert metadata["service_documentation"] == "https://docs.example.com" + assert metadata["service_documentation"] == "https://docs.example.com/" @pytest.mark.anyio async def test_token_validation_error(self, test_client: httpx.AsyncClient): From 6c48b1107b80b91fae17fea0d7438cf392b9c5e1 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 17:08:52 -0700 Subject: [PATCH 37/60] Use python 3.10 types everywhere --- src/mcp/shared/auth.py | 42 +++++++++---------- .../fastmcp/auth/test_auth_integration.py | 6 +-- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index dde4b25df0..29b360039b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,4 +1,4 @@ -from typing import Any, List, Literal, Optional +from typing import Any, Literal from pydantic import AnyHttpUrl, BaseModel, Field @@ -10,9 +10,9 @@ class OAuthToken(BaseModel): access_token: str token_type: Literal["bearer"] = "bearer" - expires_in: Optional[int] = None - scope: Optional[str] = None - refresh_token: Optional[str] = None + expires_in: int | None = None + scope: str | None = None + refresh_token: str | None = None class OAuthClientMetadata(BaseModel): @@ -22,7 +22,7 @@ class OAuthClientMetadata(BaseModel): for the full specification. """ - redirect_uris: List[AnyHttpUrl] = Field(..., min_length=1) + redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1) # token_endpoint_auth_method: this implementation only supports none & # client_secret_post; # ie: we do not support client_secret_basic @@ -30,25 +30,25 @@ class OAuthClientMetadata(BaseModel): "client_secret_post" ) # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: List[Literal["authorization_code", "refresh_token"]] = [ + grant_types: list[Literal["authorization_code", "refresh_token"]] = [ "authorization_code" ] # this implementation only supports code; ie: it does not support implicit grants - response_types: List[Literal["code"]] = ["code"] - scope: Optional[str] = None + response_types: list[Literal["code"]] = ["code"] + scope: str | None = None # these fields are currently unused, but we support & store them for potential # future use - client_name: Optional[str] = None - client_uri: Optional[AnyHttpUrl] = None - logo_uri: Optional[AnyHttpUrl] = None - contacts: Optional[List[str]] = None - tos_uri: Optional[AnyHttpUrl] = None - policy_uri: Optional[AnyHttpUrl] = None - jwks_uri: Optional[AnyHttpUrl] = None - jwks: Optional[Any] = None - software_id: Optional[str] = None - software_version: Optional[str] = None + client_name: str | None = None + client_uri: AnyHttpUrl | None = None + logo_uri: AnyHttpUrl | None = None + contacts: list[str] | None = None + tos_uri: AnyHttpUrl | None = None + policy_uri: AnyHttpUrl | None = None + jwks_uri: AnyHttpUrl | None = None + jwks: Any | None = None + software_id: str | None = None + software_version: str | None = None class OAuthClientInformationFull(OAuthClientMetadata): @@ -58,9 +58,9 @@ class OAuthClientInformationFull(OAuthClientMetadata): """ client_id: str - client_secret: Optional[str] = None - client_id_issued_at: Optional[int] = None - client_secret_expires_at: Optional[int] = None + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None class OAuthMetadata(BaseModel): diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index c18b7bf112..02b6a005e9 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -8,7 +8,7 @@ import secrets import time import unittest.mock -from typing import List, Literal, Optional +from typing import Literal from urllib.parse import parse_qs, urlparse import anyio @@ -48,7 +48,7 @@ class MockClientStore: def __init__(self): self.clients = {} - async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) async def register_client(self, client_info: OAuthClientInformationFull): @@ -145,7 +145,7 @@ async def exchange_refresh_token( self, client: OAuthClientInformationFull, refresh_token: RefreshToken, - scopes: List[str], + scopes: list[str], ) -> OAuthToken: # Check if refresh token exists assert refresh_token.token in self.refresh_tokens From a437566229b5b97eff4233cba5a2a86466665b43 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 17:30:06 -0700 Subject: [PATCH 38/60] Add back authorization to the /revoke endpoint, simplify revoke --- src/mcp/server/auth/handlers/revoke.py | 34 +++++++++++--- src/mcp/server/auth/provider.py | 5 +-- .../fastmcp/auth/test_auth_integration.py | 45 +++++++++---------- 3 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index e45c935912..5a2359cf8a 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from functools import partial from typing import Literal from pydantic import BaseModel, ValidationError @@ -6,13 +7,14 @@ from starlette.responses import Response from mcp.server.auth.errors import ( + InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( ClientAuthenticator, ) -from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.provider import AuthInfo, OAuthServerProvider, RefreshToken class RevocationRequest(BaseModel): @@ -22,6 +24,8 @@ class RevocationRequest(BaseModel): token: str token_type_hint: Literal["access_token", "refresh_token"] | None = None + client_id: str + client_secret: str | None class RevocationErrorResponse(BaseModel): @@ -50,10 +54,30 @@ async def handle(self, request: Request) -> Response: ), ) - # Revoke token - await self.provider.revoke_token( - revocation_request.token, revocation_request.token_type_hint - ) + # Authenticate client + try: + client = await self.client_authenticator.authenticate( + revocation_request.client_id, revocation_request.client_secret + ) + except InvalidClientError as e: + return PydanticJSONResponse(status_code=401, content=e.error_response()) + + loaders = [ + self.provider.load_access_token, + partial(self.provider.load_refresh_token, client), + ] + if revocation_request.token_type_hint == "refresh_token": + loaders = reversed(loaders) + + token: None | AuthInfo | RefreshToken = None + for loader in loaders: + token = await loader(revocation_request.token) + if token is not None: + break + + if token and token.client_id == client.client_id: + # Revoke token + await self.provider.revoke_token(token) # Return successful empty response return Response( diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index e0ee171ab6..a7254be3c6 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,4 +1,4 @@ -from typing import Literal, Protocol +from typing import Protocol from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, BaseModel @@ -172,8 +172,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None: async def revoke_token( self, - token: str, - token_type_hint: Literal["access_token", "refresh_token"] | None = None, + token: AuthInfo | RefreshToken, ) -> None: """ Revokes an access or refresh token. diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 02b6a005e9..3c058add0c 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -8,7 +8,6 @@ import secrets import time import unittest.mock -from typing import Literal from urllib.parse import parse_qs, urlparse import anyio @@ -164,11 +163,12 @@ async def exchange_refresh_token( new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens - self.tokens[new_access_token] = { - "client_id": client.client_id, - "scopes": scopes or token_info.scopes, - "expires_at": int(time.time()) + 3600, - } + self.tokens[new_access_token] = AuthInfo( + token=new_access_token, + client_id=client.client_id, + scopes=scopes or token_info.scopes, + expires_at=int(time.time()) + 3600, + ) self.refresh_tokens[new_refresh_token] = new_access_token @@ -198,25 +198,20 @@ async def load_access_token(self, token: str) -> AuthInfo | None: expires_at=token_info.expires_at, ) - async def revoke_token( - self, - token: str, - token_type_hint: Literal["access_token", "refresh_token"] | None = None, - ) -> None: - # Check if it's a refresh token - if token in self.refresh_tokens: - # Remove the refresh token - del self.refresh_tokens[token] - - # Check if it's an access token - elif token in self.tokens: - # Remove the access token - del self.tokens[token] - - # Also remove any refresh tokens that point to this access token - for refresh_token, access_token in list(self.refresh_tokens.items()): - if access_token == token: - del self.refresh_tokens[refresh_token] + async def revoke_token(self, token: OAuthToken | RefreshToken) -> None: + match token: + case RefreshToken(): + # Remove the refresh token + del self.refresh_tokens[token.token] + + case AuthInfo(): + # Remove the access token + del self.tokens[token.token] + + # Also remove any refresh tokens that point to this access token + for refresh_token, access_token in list(self.refresh_tokens.items()): + if access_token == token.token: + del self.refresh_tokens[refresh_token] @pytest.fixture From 9fee92976c165d02692cb4da315b87cca1cbdde5 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Tue, 11 Mar 2025 22:40:34 -0700 Subject: [PATCH 39/60] Move around validation logic --- src/mcp/server/auth/errors.py | 16 ------ src/mcp/server/auth/handlers/authorize.py | 53 +++++-------------- src/mcp/server/auth/handlers/revoke.py | 14 +++-- src/mcp/server/auth/handlers/token.py | 11 ++-- src/mcp/server/auth/middleware/client_auth.py | 14 +++-- src/mcp/shared/auth.py | 36 +++++++++++++ .../fastmcp/auth/test_auth_integration.py | 2 +- 7 files changed, 76 insertions(+), 70 deletions(-) diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index e629e28acb..9353285986 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -28,22 +28,6 @@ def error_response(self) -> ErrorResponse: ) -class InvalidRequestError(OAuthError): - """ - Invalid request error. - """ - - error_code = "invalid_request" - - -class InvalidClientError(OAuthError): - """ - Invalid client error. - """ - - error_code = "invalid_client" - - def stringify_pydantic_error(validation_error: ValidationError) -> str: return "\n".join( f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 6c99bcfb7a..3f78b7e87f 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -9,7 +9,6 @@ from starlette.responses import RedirectResponse, Response from mcp.server.auth.errors import ( - InvalidRequestError, OAuthError, stringify_pydantic_error, ) @@ -19,7 +18,10 @@ OAuthServerProvider, construct_redirect_uri, ) -from mcp.shared.auth import OAuthClientInformationFull +from mcp.shared.auth import ( + InvalidRedirectUriError, + InvalidScopeError, +) logger = logging.getLogger(__name__) @@ -66,37 +68,6 @@ class AuthorizationErrorResponse(BaseModel): state: str | None = None -def validate_scope( - requested_scope: str | None, client: OAuthClientInformationFull -) -> list[str] | None: - if requested_scope is None: - return None - requested_scopes = requested_scope.split(" ") - allowed_scopes = [] if client.scope is None else client.scope.split(" ") - for scope in requested_scopes: - if scope not in allowed_scopes: - raise InvalidRequestError(f"Client was not registered with scope {scope}") - return requested_scopes - - -def validate_redirect_uri( - redirect_uri: AnyHttpUrl | None, client: OAuthClientInformationFull -) -> AnyHttpUrl: - if redirect_uri is not None: - # Validate redirect_uri against client's registered redirect URIs - if redirect_uri not in client.redirect_uris: - raise InvalidRequestError( - f"Redirect URI '{redirect_uri}' not registered for client" - ) - return redirect_uri - elif len(client.redirect_uris) == 1: - return client.redirect_uris[0] - else: - raise InvalidRequestError( - "redirect_uri must be specified when client has multiple registered URIs" - ) - - def best_effort_extract_string( key: str, params: None | FormData | QueryParams ) -> str | None: @@ -146,8 +117,8 @@ async def error_response( best_effort_extract_string("redirect_uri", params) ).root try: - redirect_uri = validate_redirect_uri(raw_redirect_uri, client) - except (ValidationError, InvalidRequestError): + redirect_uri = client.validate_redirect_uri(raw_redirect_uri) + except (ValidationError, InvalidRedirectUriError): pass if state is None: # make last-ditch effort to load state @@ -213,22 +184,22 @@ async def error_response( # Validate redirect_uri against client's registered URIs try: - redirect_uri = validate_redirect_uri(auth_request.redirect_uri, client) - except InvalidRequestError as validation_error: + redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri) + except InvalidRedirectUriError as validation_error: # For redirect_uri validation errors, return direct error (no redirect) return await error_response( error="invalid_request", - error_description=validation_error.error_description, + error_description=validation_error.message, ) # Validate scope - for scope errors, we can redirect try: - scopes = validate_scope(auth_request.scope, client) - except InvalidRequestError as validation_error: + scopes = client.validate_scope(auth_request.scope) + except InvalidScopeError as validation_error: # For scope errors, redirect with error parameters return await error_response( error="invalid_scope", - error_description=validation_error.error_description, + error_description=validation_error.message, ) # Setup authorization parameters diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 5a2359cf8a..141fc81e88 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -7,11 +7,11 @@ from starlette.responses import Response from mcp.server.auth.errors import ( - InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( + AuthenticationError, ClientAuthenticator, ) from mcp.server.auth.provider import AuthInfo, OAuthServerProvider, RefreshToken @@ -29,7 +29,7 @@ class RevocationRequest(BaseModel): class RevocationErrorResponse(BaseModel): - error: Literal["invalid_request",] + error: Literal["invalid_request", "unauthorized_client"] error_description: str | None = None @@ -59,8 +59,14 @@ async def handle(self, request: Request) -> Response: client = await self.client_authenticator.authenticate( revocation_request.client_id, revocation_request.client_secret ) - except InvalidClientError as e: - return PydanticJSONResponse(status_code=401, content=e.error_response()) + except AuthenticationError as e: + return PydanticJSONResponse( + status_code=401, + content=RevocationErrorResponse( + error="unauthorized_client", + error_description=e.message, + ), + ) loaders = [ self.provider.load_access_token, diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 14c92e4a1d..a60c091c07 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -9,11 +9,11 @@ from mcp.server.auth.errors import ( ErrorResponse, - InvalidClientError, stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.middleware.client_auth import ( + AuthenticationError, ClientAuthenticator, ) from mcp.server.auth.provider import OAuthServerProvider @@ -111,8 +111,13 @@ async def handle(self, request: Request): client_id=token_request.client_id, client_secret=token_request.client_secret, ) - except InvalidClientError as e: - return self.response(e.error_response()) + except AuthenticationError as e: + return self.response( + TokenErrorResponse( + error="unauthorized_client", + error_description=e.message, + ) + ) if token_request.grant_type not in client_info.grant_types: return self.response( diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index cda5d79a52..56cd93ae9e 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,10 +1,14 @@ import time -from mcp.server.auth.errors import InvalidClientError from mcp.server.auth.provider import OAuthRegisteredClientsStore from mcp.shared.auth import OAuthClientInformationFull +class AuthenticationError(Exception): + def __init__(self, message: str): + self.message = message + + class ClientAuthenticator: """ ClientAuthenticator is a callable which validates requests from a client @@ -31,21 +35,21 @@ async def authenticate( # Look up client information client = await self.clients_store.get_client(client_id) if not client: - raise InvalidClientError("Invalid client_id") + raise AuthenticationError("Invalid client_id") # If client from the store expects a secret, validate that the request provides # that secret if client.client_secret: if not client_secret: - raise InvalidClientError("Client secret is required") + raise AuthenticationError("Client secret is required") if client.client_secret != client_secret: - raise InvalidClientError("Invalid client_secret") + raise AuthenticationError("Invalid client_secret") if ( client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()) ): - raise InvalidClientError("Client secret has expired") + raise AuthenticationError("Client secret has expired") return client diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 29b360039b..bcf287e5ed 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -15,6 +15,16 @@ class OAuthToken(BaseModel): refresh_token: str | None = None +class InvalidScopeError(Exception): + def __init__(self, message: str): + self.message = message + + +class InvalidRedirectUriError(Exception): + def __init__(self, message: str): + self.message = message + + class OAuthClientMetadata(BaseModel): """ RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. @@ -50,6 +60,32 @@ class OAuthClientMetadata(BaseModel): software_id: str | None = None software_version: str | None = None + def validate_scope(self, requested_scope: str | None) -> list[str] | None: + if requested_scope is None: + return None + requested_scopes = requested_scope.split(" ") + allowed_scopes = [] if self.scope is None else self.scope.split(" ") + for scope in requested_scopes: + if scope not in allowed_scopes: + raise InvalidScopeError(f"Client was not registered with scope {scope}") + return requested_scopes + + def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: + if redirect_uri is not None: + # Validate redirect_uri against client's registered redirect URIs + if redirect_uri not in self.redirect_uris: + raise InvalidRedirectUriError( + f"Redirect URI '{redirect_uri}' not registered for client" + ) + return redirect_uri + elif len(self.redirect_uris) == 1: + return self.redirect_uris[0] + else: + raise InvalidRedirectUriError( + "redirect_uri must be specified when client " + "has multiple registered URIs" + ) + class OAuthClientInformationFull(OAuthClientMetadata): """ diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 3c058add0c..38f58d4a58 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -198,7 +198,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None: expires_at=token_info.expires_at, ) - async def revoke_token(self, token: OAuthToken | RefreshToken) -> None: + async def revoke_token(self, token: AuthInfo | RefreshToken) -> None: match token: case RefreshToken(): # Remove the refresh token From d79be8f227d7dd23f2a9782d4edb559273a70a31 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 09:26:38 -0700 Subject: [PATCH 40/60] Fixups while integrating new auth capabilities --- src/mcp/server/auth/provider.py | 25 +++++--- src/mcp/server/auth/router.py | 62 +++++++------------ src/mcp/server/fastmcp/server.py | 61 +++++++++++++----- src/mcp/shared/auth.py | 3 +- .../fastmcp/auth/test_auth_integration.py | 6 +- 5 files changed, 91 insertions(+), 66 deletions(-) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index a7254be3c6..10e666028c 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,4 +1,4 @@ -from typing import Protocol +from typing import Generic, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, BaseModel @@ -62,7 +62,16 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None ... -class OAuthServerProvider(Protocol): +# NOTE: FastMCP doesn't render any of these types in the user response, so it's +# OK to add fields to subclasses which should not be exposed externally. +AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) +RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) +AuthInfoT = TypeVar("AuthInfoT", bound=AuthInfo) + + +class OAuthServerProvider( + Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AuthInfoT] +): @property def clients_store(self) -> OAuthRegisteredClientsStore: """ @@ -107,7 +116,7 @@ async def authorize( async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCode | None: + ) -> AuthorizationCodeT | None: """ Loads metadata for the authorization code challenge. @@ -121,7 +130,7 @@ async def load_authorization_code( ... async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT ) -> OAuthToken: """ Exchanges an authorization code for an access token. @@ -137,12 +146,12 @@ async def exchange_authorization_code( async def load_refresh_token( self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshToken | None: ... + ) -> RefreshTokenT | None: ... async def exchange_refresh_token( self, client: OAuthClientInformationFull, - refresh_token: RefreshToken, + refresh_token: RefreshTokenT, scopes: list[str], ) -> OAuthToken: """ @@ -158,7 +167,7 @@ async def exchange_refresh_token( """ ... - async def load_access_token(self, token: str) -> AuthInfo | None: + async def load_access_token(self, token: str) -> AuthInfoT | None: """ Verifies an access token and returns information about it. @@ -172,7 +181,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None: async def revoke_token( self, - token: AuthInfo | RefreshToken, + token: AuthInfoT | RefreshTokenT, ) -> None: """ Revokes an access or refresh token. diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index ba33d7ea7b..4e3fc2b048 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -2,7 +2,7 @@ from typing import Callable from pydantic import AnyHttpUrl -from starlette.routing import Route, Router +from starlette.routing import Route from mcp.server.auth.handlers.authorize import AuthorizationHandler from mcp.server.auth.handlers.metadata import MetadataHandler @@ -57,27 +57,13 @@ def validate_issuer_url(url: AnyHttpUrl): REVOCATION_PATH = "/revoke" -def create_auth_router( +def create_auth_routes( provider: OAuthServerProvider, issuer_url: AnyHttpUrl, service_documentation_url: AnyHttpUrl | None = None, client_registration_options: ClientRegistrationOptions | None = None, revocation_options: RevocationOptions | None = None, -) -> Router: - """ - Create a Starlette router with standard MCP authorization endpoints. - - Args: - provider: OAuth server provider - issuer_url: Issuer URL for the authorization server - service_documentation_url: Optional URL for service documentation - client_registration_options: Options for client registration - revocation_options: Options for token revocation - - Returns: - Starlette router with authorization endpoints - """ - +) -> list[Route]: validate_issuer_url(issuer_url) client_registration_options = ( @@ -93,32 +79,30 @@ def create_auth_router( client_authenticator = ClientAuthenticator(provider.clients_store) # Create routes - auth_router = Router( - routes=[ - Route( - "/.well-known/oauth-authorization-server", - endpoint=MetadataHandler(metadata).handle, - methods=["GET"], - ), - Route( - AUTHORIZATION_PATH, - endpoint=AuthorizationHandler(provider).handle, - methods=["GET", "POST"], - ), - Route( - TOKEN_PATH, - endpoint=TokenHandler(provider, client_authenticator).handle, - methods=["POST"], - ), - ] - ) + routes = [ + Route( + "/.well-known/oauth-authorization-server", + endpoint=MetadataHandler(metadata).handle, + methods=["GET"], + ), + Route( + AUTHORIZATION_PATH, + endpoint=AuthorizationHandler(provider).handle, + methods=["GET", "POST"], + ), + Route( + TOKEN_PATH, + endpoint=TokenHandler(provider, client_authenticator).handle, + methods=["POST"], + ), + ] if client_registration_options.enabled: registration_handler = RegistrationHandler( provider.clients_store, client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, ) - auth_router.routes.append( + routes.append( Route( REGISTRATION_PATH, endpoint=registration_handler.handle, @@ -128,11 +112,11 @@ def create_auth_router( if revocation_options.enabled: revocation_handler = RevocationHandler(provider, client_authenticator) - auth_router.routes.append( + routes.append( Route(REVOCATION_PATH, endpoint=revocation_handler.handle, methods=["POST"]) ) - return auth_router + return routes def modify_url_path(url: AnyHttpUrl, path_mapper: Callable[[str], str]) -> AnyHttpUrl: diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 65d075d3ae..fc40305b84 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -11,7 +11,7 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Callable, Generic, Literal, Sequence +from typing import Any, Awaitable, Callable, Generic, Literal, Sequence import anyio import pydantic_core @@ -24,6 +24,7 @@ from starlette.authentication import requires from starlette.middleware.authentication import AuthenticationMiddleware +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import ( BearerAuthBackend, RequireAuthMiddleware, @@ -151,6 +152,7 @@ def __init__( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) self._auth_provider = auth_provider + self._custom_starlette_routes = [] self.dependencies = self.settings.dependencies # Set up MCP protocol handlers @@ -477,6 +479,33 @@ def decorator(func: AnyFunction) -> AnyFunction: return decorator + def custom_route( + self, + path: str, + methods: list[str], + name: str | None = None, + include_in_schema: bool = True, + ): + from starlette.requests import Request + from starlette.responses import Response + from starlette.routing import Route + + def decorator( + func: Callable[[Request], Awaitable[Response]], + ) -> Callable[[Request], Awaitable[Response]]: + self._custom_starlette_routes.append( + Route( + path, + endpoint=func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + ) + return func + + return decorator + async def run_stdio_async(self) -> None: """Run the server using stdio transport.""" async with stdio_server() as (read_stream, write_stream): @@ -513,31 +542,33 @@ async def handle_sse(request) -> EventSourceResponse: routes = [] middleware = [] required_scopes = self.settings.auth_required_scopes or [] - auth_router = None # Add auth endpoints if auth provider is configured if self._auth_provider and self.settings.auth_issuer_url: - from mcp.server.auth.router import create_auth_router + from mcp.server.auth.router import create_auth_routes - # Set up bearer auth middleware if auth is required middleware = [ + # extract auth info from request (but do not require it) Middleware( AuthenticationMiddleware, backend=BearerAuthBackend( provider=self._auth_provider, ), - ) + ), + # Add the auth context middleware to store + # authenticated user in a contextvar + Middleware(AuthContextMiddleware), ] - auth_router = create_auth_router( - provider=self._auth_provider, - issuer_url=self.settings.auth_issuer_url, - service_documentation_url=self.settings.auth_service_documentation_url, - client_registration_options=self.settings.auth_client_registration_options, - revocation_options=self.settings.auth_revocation_options, + routes.extend( + create_auth_routes( + provider=self._auth_provider, + issuer_url=self.settings.auth_issuer_url, + service_documentation_url=self.settings.auth_service_documentation_url, + client_registration_options=self.settings.auth_client_registration_options, + revocation_options=self.settings.auth_revocation_options, + ) ) - # Add the auth router as a mount - routes.append( Route( "/sse", endpoint=requires(required_scopes)(handle_sse), methods=["GET"] @@ -549,8 +580,8 @@ async def handle_sse(request) -> EventSourceResponse: app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), ) ) - if auth_router: - routes.append(Mount("/", app=auth_router)) + # mount these routes last, so they have the lowest route matching precedence + routes.extend(self._custom_starlette_routes) # Create Starlette app with routes and middleware return Starlette( diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index bcf287e5ed..22f8a971d6 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -41,7 +41,8 @@ class OAuthClientMetadata(BaseModel): ) # grant_types: this implementation only supports authorization_code & refresh_token grant_types: list[Literal["authorization_code", "refresh_token"]] = [ - "authorization_code" + "authorization_code", + "refresh_token", ] # this implementation only supports code; ie: it does not support implicit grants response_types: list[Literal["code"]] = ["code"] diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 38f58d4a58..07babdb830 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -30,7 +30,7 @@ from mcp.server.auth.router import ( ClientRegistrationOptions, RevocationOptions, - create_auth_router, + create_auth_routes, ) from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( @@ -222,7 +222,7 @@ def mock_oauth_provider(): @pytest.fixture def auth_app(mock_oauth_provider): # Create auth router - auth_router = create_auth_router( + auth_routes = create_auth_routes( mock_oauth_provider, AnyHttpUrl("https://auth.example.com"), AnyHttpUrl("https://docs.example.com"), @@ -231,7 +231,7 @@ def auth_app(mock_oauth_provider): ) # Create Starlette app - app = Starlette(routes=[Mount("/", app=auth_router)]) + app = Starlette(routes=auth_routes) return app From 8d637b432eae7ffcfb6a21c22be4372cdcea743f Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 10:23:15 -0700 Subject: [PATCH 41/60] Pull all auth settings out into a separate config --- src/mcp/server/auth/router.py | 23 ++++++++--- src/mcp/server/fastmcp/server.py | 40 ++++++++++--------- .../fastmcp/auth/test_auth_integration.py | 12 +++--- 3 files changed, 45 insertions(+), 30 deletions(-) diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/router.py index 4e3fc2b048..4b1893f4bd 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/router.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass from typing import Callable -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, BaseModel, Field from starlette.routing import Route from mcp.server.auth.handlers.authorize import AuthorizationHandler @@ -14,17 +13,29 @@ from mcp.shared.auth import OAuthMetadata -@dataclass -class ClientRegistrationOptions: +class ClientRegistrationOptions(BaseModel): enabled: bool = False client_secret_expiry_seconds: int | None = None -@dataclass -class RevocationOptions: +class RevocationOptions(BaseModel): enabled: bool = False +class AuthSettings(BaseModel): + issuer_url: AnyHttpUrl = Field( + ..., + description="URL advertised as OAuth issuer; this should be the URL the server " + "is reachable at", + ) + service_documentation_url: AnyHttpUrl | None = Field( + None, description="Service documentation URL advertised by OAuth" + ) + client_registration_options: ClientRegistrationOptions | None = None + revocation_options: RevocationOptions | None = None + required_scopes: list[str] | None = None + + def validate_issuer_url(url: AnyHttpUrl): """ Validate that the issuer URL meets OAuth 2.0 requirements. diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index fc40305b84..82ec57cb70 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -16,7 +16,7 @@ import anyio import pydantic_core import uvicorn -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from sse_starlette import EventSourceResponse @@ -30,7 +30,9 @@ RequireAuthMiddleware, ) from mcp.server.auth.provider import OAuthServerProvider -from mcp.server.auth.router import ClientRegistrationOptions, RevocationOptions +from mcp.server.auth.router import ( + AuthSettings, +) from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -71,6 +73,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): model_config = SettingsConfigDict( env_prefix="FASTMCP_", env_file=".env", + env_nested_delimiter="__", + nested_model_default_partial_update=True, extra="ignore", ) @@ -100,17 +104,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): Callable[["FastMCP"], AbstractAsyncContextManager[LifespanResultT]] | None ) = Field(None, description="Lifespan context manager") - auth_issuer_url: AnyHttpUrl | None = Field( - None, - description="URL advertised as OAuth issuer; this should be the URL the server " - "is reachable at", - ) - auth_service_documentation_url: AnyHttpUrl | None = Field( - None, description="Service documentation URL advertised by OAuth" - ) - auth_client_registration_options: ClientRegistrationOptions | None = None - auth_revocation_options: RevocationOptions | None = None - auth_required_scopes: list[str] | None = None + auth: AuthSettings | None = None def lifespan_wrapper( @@ -151,6 +145,11 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) + if (self.settings.auth is not None) != (auth_provider is not None): + raise ValueError( + "settings.auth must be specified if and only if auth_provider " + "is specified" + ) self._auth_provider = auth_provider self._custom_starlette_routes = [] self.dependencies = self.settings.dependencies @@ -541,12 +540,15 @@ async def handle_sse(request) -> EventSourceResponse: # Create routes routes = [] middleware = [] - required_scopes = self.settings.auth_required_scopes or [] + required_scopes = [] # Add auth endpoints if auth provider is configured - if self._auth_provider and self.settings.auth_issuer_url: + if self._auth_provider: + assert self.settings.auth from mcp.server.auth.router import create_auth_routes + required_scopes = self.settings.auth.required_scopes or [] + middleware = [ # extract auth info from request (but do not require it) Middleware( @@ -562,10 +564,10 @@ async def handle_sse(request) -> EventSourceResponse: routes.extend( create_auth_routes( provider=self._auth_provider, - issuer_url=self.settings.auth_issuer_url, - service_documentation_url=self.settings.auth_service_documentation_url, - client_registration_options=self.settings.auth_client_registration_options, - revocation_options=self.settings.auth_revocation_options, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, ) ) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 07babdb830..28b26f21be 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -16,7 +16,6 @@ from httpx_sse import aconnect_sse from pydantic import AnyHttpUrl from starlette.applications import Starlette -from starlette.routing import Mount from mcp.server.auth.provider import ( AuthInfo, @@ -28,6 +27,7 @@ construct_redirect_uri, ) from mcp.server.auth.router import ( + AuthSettings, ClientRegistrationOptions, RevocationOptions, create_auth_routes, @@ -958,11 +958,13 @@ async def test_fastmcp_with_auth( # Create FastMCP server with auth provider mcp = FastMCP( auth_provider=mock_oauth_provider, - auth_issuer_url="https://auth.example.com", require_auth=True, - auth_client_registration_options=ClientRegistrationOptions(enabled=True), - auth_revocation_options=RevocationOptions(enabled=True), - auth_required_scopes=["read"], + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + client_registration_options=ClientRegistrationOptions(enabled=True), + revocation_options=RevocationOptions(enabled=True), + required_scopes=["read"], + ), ) # Add a test tool From 8c86bce36275c6eb5018dafe613841078d5f75da Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 10:25:54 -0700 Subject: [PATCH 42/60] Move router file to be routes --- src/mcp/server/auth/{router.py => routes.py} | 4 +--- src/mcp/server/fastmcp/server.py | 4 ++-- tests/server/fastmcp/auth/test_auth_integration.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) rename src/mcp/server/auth/{router.py => routes.py} (97%) diff --git a/src/mcp/server/auth/router.py b/src/mcp/server/auth/routes.py similarity index 97% rename from src/mcp/server/auth/router.py rename to src/mcp/server/auth/routes.py index 4b1893f4bd..898df924b2 100644 --- a/src/mcp/server/auth/router.py +++ b/src/mcp/server/auth/routes.py @@ -28,9 +28,7 @@ class AuthSettings(BaseModel): description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at", ) - service_documentation_url: AnyHttpUrl | None = Field( - None, description="Service documentation URL advertised by OAuth" - ) + service_documentation_url: AnyHttpUrl | None = None client_registration_options: ClientRegistrationOptions | None = None revocation_options: RevocationOptions | None = None required_scopes: list[str] | None = None diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 82ec57cb70..778b0dcc1b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -30,7 +30,7 @@ RequireAuthMiddleware, ) from mcp.server.auth.provider import OAuthServerProvider -from mcp.server.auth.router import ( +from mcp.server.auth.routes import ( AuthSettings, ) from mcp.server.fastmcp.exceptions import ResourceError @@ -545,7 +545,7 @@ async def handle_sse(request) -> EventSourceResponse: # Add auth endpoints if auth provider is configured if self._auth_provider: assert self.settings.auth - from mcp.server.auth.router import create_auth_routes + from mcp.server.auth.routes import create_auth_routes required_scopes = self.settings.auth.required_scopes or [] diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 28b26f21be..a06123fed5 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -26,7 +26,7 @@ RefreshToken, construct_redirect_uri, ) -from mcp.server.auth.router import ( +from mcp.server.auth.routes import ( AuthSettings, ClientRegistrationOptions, RevocationOptions, From 31618c148e9600dec4b0a3baeed62e8b2695c6f1 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 10:26:03 -0700 Subject: [PATCH 43/60] Add auth context middleware --- .../server/auth/middleware/auth_context.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 src/mcp/server/auth/middleware/auth_context.py diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py new file mode 100644 index 0000000000..7de643c891 --- /dev/null +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -0,0 +1,57 @@ +import contextvars + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AuthInfo + +# Create a contextvar to store the authenticated user +# The default is None, indicating no authenticated user is present +auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]( + "auth_context", default=None +) + + +def get_current_auth_info() -> AuthInfo | None: + """ + Get the auth info from the current context. + + Returns: + The auth info if an authenticated user is available, None otherwise. + """ + auth_user = auth_context_var.get() + return auth_user.auth_info if auth_user else None + + +class AuthContextMiddleware(BaseHTTPMiddleware): + """ + Middleware that extracts the authenticated user from the request + and sets it in a contextvar for easy access throughout the request lifecycle. + + This middleware should be added after the AuthenticationMiddleware in the + middleware stack to ensure that the user is properly authenticated before + being stored in the context. + """ + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + # Get the authenticated user from the request if it exists + user = getattr(request, "user", None) + + # Only set the context var if the user is an AuthenticatedUser + if isinstance(user, AuthenticatedUser): + # Set the authenticated user in the contextvar + token = auth_context_var.set(user) + try: + # Process the request + response = await call_next(request) + return response + finally: + # Reset the contextvar after the request is processed + auth_context_var.reset(token) + else: + # No authenticated user, just process the request + return await call_next(request) From 5ebbc19b713bd896764cad369cdf8432823e50c7 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 10:42:14 -0700 Subject: [PATCH 44/60] Validate scopes + provide default --- src/mcp/server/auth/handlers/register.py | 26 +++++++- src/mcp/server/auth/routes.py | 26 +------- src/mcp/server/auth/settings.py | 24 ++++++++ src/mcp/server/fastmcp/server.py | 2 +- .../fastmcp/auth/test_auth_integration.py | 59 ++++++++++++++++++- 5 files changed, 108 insertions(+), 29 deletions(-) create mode 100644 src/mcp/server/auth/settings.py diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 893e7a7f8e..e79355eea6 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -11,6 +11,7 @@ from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import OAuthRegisteredClientsStore +from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata @@ -33,7 +34,7 @@ class RegistrationErrorResponse(BaseModel): @dataclass class RegistrationHandler: clients_store: OAuthRegisteredClientsStore - client_secret_expiry_seconds: int | None + options: ClientRegistrationOptions async def handle(self, request: Request) -> Response: # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 @@ -41,6 +42,8 @@ async def handle(self, request: Request) -> Response: # Parse request body as JSON body = await request.json() client_metadata = OAuthClientMetadata.model_validate(body) + + # Scope validation is handled below except ValidationError as validation_error: return PydanticJSONResponse( content=RegistrationErrorResponse( @@ -56,10 +59,27 @@ async def handle(self, request: Request) -> Response: # cryptographically secure random 32-byte hex string client_secret = secrets.token_hex(32) + if client_metadata.scope is None and self.options.default_scopes is not None: + client_metadata.scope = " ".join(self.options.default_scopes) + elif ( + client_metadata.scope is not None and self.options.valid_scopes is not None + ): + requested_scopes = set(client_metadata.scope.split()) + valid_scopes = set(self.options.valid_scopes) + if not requested_scopes.issubset(valid_scopes): + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description="Requested scopes are not valid: " + f"{', '.join(requested_scopes - valid_scopes)}", + ), + status_code=400, + ) + client_id_issued_at = int(time.time()) client_secret_expires_at = ( - client_id_issued_at + self.client_secret_expiry_seconds - if self.client_secret_expiry_seconds is not None + client_id_issued_at + self.options.client_secret_expiry_seconds + if self.options.client_secret_expiry_seconds is not None else None ) diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 898df924b2..49387247ab 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -1,6 +1,6 @@ from typing import Callable -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl from starlette.routing import Route from mcp.server.auth.handlers.authorize import AuthorizationHandler @@ -10,30 +10,10 @@ from mcp.server.auth.handlers.token import TokenHandler from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions from mcp.shared.auth import OAuthMetadata -class ClientRegistrationOptions(BaseModel): - enabled: bool = False - client_secret_expiry_seconds: int | None = None - - -class RevocationOptions(BaseModel): - enabled: bool = False - - -class AuthSettings(BaseModel): - issuer_url: AnyHttpUrl = Field( - ..., - description="URL advertised as OAuth issuer; this should be the URL the server " - "is reachable at", - ) - service_documentation_url: AnyHttpUrl | None = None - client_registration_options: ClientRegistrationOptions | None = None - revocation_options: RevocationOptions | None = None - required_scopes: list[str] | None = None - - def validate_issuer_url(url: AnyHttpUrl): """ Validate that the issuer URL meets OAuth 2.0 requirements. @@ -109,7 +89,7 @@ def create_auth_routes( if client_registration_options.enabled: registration_handler = RegistrationHandler( provider.clients_store, - client_secret_expiry_seconds=client_registration_options.client_secret_expiry_seconds, + options=client_registration_options, ) routes.append( Route( diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py new file mode 100644 index 0000000000..1086bb77e2 --- /dev/null +++ b/src/mcp/server/auth/settings.py @@ -0,0 +1,24 @@ +from pydantic import AnyHttpUrl, BaseModel, Field + + +class ClientRegistrationOptions(BaseModel): + enabled: bool = False + client_secret_expiry_seconds: int | None = None + valid_scopes: list[str] | None = None + default_scopes: list[str] | None = None + + +class RevocationOptions(BaseModel): + enabled: bool = False + + +class AuthSettings(BaseModel): + issuer_url: AnyHttpUrl = Field( + ..., + description="URL advertised as OAuth issuer; this should be the URL the server " + "is reachable at", + ) + service_documentation_url: AnyHttpUrl | None = None + client_registration_options: ClientRegistrationOptions | None = None + revocation_options: RevocationOptions | None = None + required_scopes: list[str] | None = None diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 778b0dcc1b..66244b7465 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -30,7 +30,7 @@ RequireAuthMiddleware, ) from mcp.server.auth.provider import OAuthServerProvider -from mcp.server.auth.routes import ( +from mcp.server.auth.settings import ( AuthSettings, ) from mcp.server.fastmcp.exceptions import ResourceError diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index a06123fed5..efee1fe6ac 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -27,11 +27,11 @@ construct_redirect_uri, ) from mcp.server.auth.routes import ( - AuthSettings, ClientRegistrationOptions, RevocationOptions, create_auth_routes, ) +from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP from mcp.shared.auth import ( OAuthClientInformationFull, @@ -226,7 +226,11 @@ def auth_app(mock_oauth_provider): mock_oauth_provider, AnyHttpUrl("https://auth.example.com"), AnyHttpUrl("https://docs.example.com"), - client_registration_options=ClientRegistrationOptions(enabled=True), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["read", "write", "profile"], + default_scopes=["read", "write"], + ), revocation_options=RevocationOptions(enabled=True), ) @@ -946,6 +950,57 @@ async def test_revoke_with_malformed_token(self, test_client, registered_client) assert error_response["error"] == "invalid_request" assert "token_type_hint" in error_response["error_description"] + @pytest.mark.anyio + async def test_client_registration_disallowed_scopes( + self, test_client: httpx.AsyncClient + ): + """Test client registration with scopes that are not allowed.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "scope": "read write profile admin", # 'admin' is not in valid_scopes + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert "scope" in error_data["error_description"] + assert "admin" in error_data["error_description"] + + @pytest.mark.anyio + async def test_client_registration_default_scopes( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + # No scope specified + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Verify client was registered successfully + assert client_info["scope"] == "read write" + + # Retrieve the client from the store to verify default scopes + registered_client = await mock_oauth_provider.clients_store.get_client( + client_info["client_id"] + ) + assert registered_client is not None + + # Check that default scopes were applied + assert registered_client.scope == "read write" + class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" From 50673c6360749521db9181effec0e883ae51781a Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 10:47:05 -0700 Subject: [PATCH 45/60] Validate grant_types on registration --- src/mcp/server/auth/handlers/register.py | 11 +++++ .../fastmcp/auth/test_auth_integration.py | 44 ++++++++++--------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e79355eea6..efcb32e2b9 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -75,6 +75,17 @@ async def handle(self, request: Request) -> Response: ), status_code=400, ) + if set(client_metadata.grant_types) != set( + ["authorization_code", "refresh_token"] + ): + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description="grant_types must be authorization_code " + "and refresh_token", + ), + status_code=400, + ) client_id_issued_at = int(time.time()) client_secret_expires_at = ( diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index efee1fe6ac..ec19b5148d 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -407,27 +407,6 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): "error_description" in error_response ) # Contains validation error messages - @pytest.mark.anyio - @pytest.mark.parametrize( - "registered_client", [{"grant_types": ["authorization_code"]}], indirect=True - ) - async def test_token_unsupported_grant_type(self, test_client, registered_client): - """Test token endpoint error - unsupported grant type.""" - # Try refresh_token grant with client that only supports authorization_code - response = await test_client.post( - "/token", - data={ - "grant_type": "refresh_token", - "client_id": registered_client["client_id"], - "client_secret": registered_client["client_secret"], - "refresh_token": "some_refresh_token", - }, - ) - assert response.status_code == 400 - error_response = response.json() - assert error_response["error"] == "unsupported_grant_type" - assert "supported grant types" in error_response["error_description"] - @pytest.mark.anyio async def test_token_invalid_auth_code( self, test_client, registered_client, pkce_challenge @@ -1001,6 +980,29 @@ async def test_client_registration_default_scopes( # Check that default scopes were applied assert registered_client.scope == "read write" + @pytest.mark.anyio + async def test_client_registration_invalid_grant_type( + self, test_client: httpx.AsyncClient + ): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert ( + error_data["error_description"] + == "grant_types must be authorization_code and refresh_token" + ) + class TestFastMCPWithAuth: """Test FastMCP server with authentication.""" From 02d76f32c54a367758072b8541f755188ce20fba Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 12 Mar 2025 14:35:09 +0000 Subject: [PATCH 46/60] auth: client implementation --- src/mcp/client/auth/__init__.py | 0 src/mcp/client/auth/oauth.py | 495 ++++++++++++++++++++++++++++++++ tests/client/test_oauth.py | 236 +++++++++++++++ 3 files changed, 731 insertions(+) create mode 100644 src/mcp/client/auth/__init__.py create mode 100644 src/mcp/client/auth/oauth.py create mode 100644 tests/client/test_oauth.py diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/mcp/client/auth/oauth.py b/src/mcp/client/auth/oauth.py new file mode 100644 index 0000000000..0f5aa0df02 --- /dev/null +++ b/src/mcp/client/auth/oauth.py @@ -0,0 +1,495 @@ +""" +Authentication functionality for MCP client. + +This module provides authentication mechanisms for the MCP client to authenticate +with an MCP server. It implements the authentication flow as specified in the MCP +authorization specification. +""" + +import json +import logging +from datetime import datetime, timedelta +from typing import Any, Protocol +from urllib.parse import urlparse + +import httpx +from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class AccessToken(BaseModel): + """ + Represents an OAuth 2.0 access token with its associated metadata. + """ + + access_token: str + token_type: str = Field(default="Bearer") + expires_in: timedelta | None = None + refresh_token: str | None = None + scope: str | None = None + + created_at: datetime = Field(default=datetime.now(), exclude=True) + + model_config = ConfigDict(extra="allow") + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return ( + self.expires_in is not None + and datetime.now() >= self.created_at + self.expires_in + ) + + @property + def scopes(self) -> list[str]: + """Convert scope string to list of scopes.""" + if isinstance(self.scope, list): + return self.scope + return self.scope.split() if self.scope else [] + + def to_auth_header(self) -> dict[str, str]: + """Convert token to Authorization header.""" + + return {"Authorization": f"{self.token_type} {self.access_token}"} + + +class AuthConfig(BaseModel): + """ + Configuration for the MCP client authentication. + """ + + client_id: str + client_secret: str | None = None + token_endpoint: str | None = None + redirect_uri: str | None = None + scope: str | None = None + auth_endpoint: str | None = None + model_config = ConfigDict(extra="allow") + + +class ClientMetadata(BaseModel): + """ + OAuth 2.0 Dynamic Client Registration Metadata. + + This model represents the client metadata used when registering a client + with an OAuth 2.0 server using the Dynamic Client Registration protocol + as defined in RFC 7591 Section 2. + """ + + redirect_uris: list[AnyHttpUrl] = Field(default_factory=list) + token_endpoint_auth_method: str | None = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + client_name: str | None = None + client_uri: AnyHttpUrl | None = None + logo_uri: AnyHttpUrl | None = None + scope: str | None = None + contacts: list[str] | None = None + tos_uri: AnyHttpUrl | None = None + policy_uri: AnyHttpUrl | None = None + jwks_uri: AnyHttpUrl | None = None + jwks: dict[str, Any] | None = None + software_id: str | None = None + software_version: str | None = None + + model_config = ConfigDict(extra="allow") + + +class DynamicClientRegistration(ClientMetadata): + """ + Response from OAuth 2.0 Dynamic Client Registration. + + This model represents the response received after registering a client + with an OAuth 2.0 server using the Dynamic Client Registration protocol + as defined in RFC 7591. + + Note that we inherit from ClientMetadata, which contains the client metadata, + since all values sent during the request are also returned in the response, + as per https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.1 + """ + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + model_config = ConfigDict(extra="allow") + + +class ServerMetadataDiscovery(BaseModel): + """ + OAuth 2.0 Authorization Server Metadata Discovery Response. + + This model represents the response received from an OAuth 2.0 server's + metadata discovery endpoint as defined in RFC 8414. + """ + + issuer: AnyHttpUrl + authorization_endpoint: AnyHttpUrl + token_endpoint: AnyHttpUrl + registration_endpoint: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[str] + response_modes_supported: list[str] | None = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None + token_endpoint_auth_signing_alg_values_supported: list[str] | None = None + service_documentation: AnyHttpUrl | None = None + revocation_endpoint: AnyHttpUrl | None = None + revocation_endpoint_auth_methods_supported: list[str] | None = None + revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_endpoint_auth_methods_supported: list[str] | None = None + introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None + + model_config = ConfigDict(extra="allow") + + +class TokenManager: + """ + Manages OAuth tokens for MCP client, handling token refresh and expiration. + """ + + def __init__(self, config: AuthConfig): + self.config = config + self.token: AccessToken | None = None + + @property + def is_authenticated(self) -> bool: + """Check if the client is authenticated with a valid token.""" + return self.token is not None and not self.token.is_expired + + async def refresh_token_if_needed(self) -> bool: + """ + Refresh the token if it's expired or close to expiration. + + Returns: + bool: True if token was refreshed, False otherwise + """ + if not self.token or not self.token.refresh_token: + return False + + if self.token.is_expired(): + await self.refresh() + return True + + return False + + async def refresh(self) -> AccessToken | None: + """ + Refresh the access token using the refresh token. + + Returns: + AccessToken | None: The new token if successful, None otherwise + """ + if ( + not self.token + or not self.token.refresh_token + or not self.config.token_endpoint + ): + return None + + data = { + "grant_type": "refresh_token", + "refresh_token": self.token.refresh_token, + "client_id": self.config.client_id, + } + + # Add client secret if available + if self.config.client_secret: + data["client_secret"] = self.config.client_secret + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + self.config.token_endpoint, + data=data, + headers=headers, + ) + response.raise_for_status() + token_data = response.json() + + # Create and store the token + token = AccessToken(**token_data) + + # If the response didn't include a refresh token, keep the old one + if not token.refresh_token and self._token.refresh_token: + token.refresh_token = self._token.refresh_token + + self._token = token + return token + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error during token refresh: {e.response.status_code}") + if e.response.content: + try: + error_data = json.loads(e.response.content) + logger.error(f"Error details: {error_data}") + except json.JSONDecodeError: + logger.error(f"Error content: {e.response.content}") + return None + + except httpx.RequestError as e: + logger.error(f"Request error during token refresh: {e}") + return None + + except Exception as e: + logger.error(f"Unexpected error during token refresh: {e}") + return None + + async def authenticate_with_client_credentials(self) -> AccessToken | None: + """ + Authenticate using client credentials flow. + + Returns: + AccessToken | None: The access token if successful, None otherwise + """ + if not self.config.token_endpoint or not self.config.client_id: + logger.error("Token endpoint or client ID not configured") + return None + + data = { + "grant_type": "client_credentials", + "client_id": self.config.client_id, + } + + # Add client secret if available + if self.config.client_secret: + data["client_secret"] = self.config.client_secret + + # Add scope if available + if self.config.scope: + data["scope"] = self.config.scope + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + self.config.token_endpoint, + data=data, + headers=headers, + ) + response.raise_for_status() + token_data = response.json() + + # Create and store the token + token = AccessToken(**token_data) + self._token = token + return token + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error during authentication: {e.response.status_code}") + if e.response.content: + try: + error_data = json.loads(e.response.content) + logger.error(f"Error details: {error_data}") + except json.JSONDecodeError: + logger.error(f"Error content: {e.response.content}") + return None + + except httpx.RequestError as e: + logger.error(f"Request error during authentication: {e}") + return None + + except Exception as e: + logger.error(f"Unexpected error during authentication: {e}") + return None + + +class AuthSession: + """ + Client for handling authentication with an MCP server. + + This client provides methods for authenticating with an MCP server using + various OAuth 2.0 flows and managing the resulting tokens. + """ + + def __init__(self, config: AuthConfig): + """ + Initialize the authentication client with the given configuration. + + Args: + config: Authentication configuration + """ + self.config = config + self.token_manager: TokenManager = TokenManager(config) + + async def initialize(self) -> None: + """ + Initialize the client and prepare it for authentication. + """ + if self.token_manager is None: + self.token_manager = TokenManager(self.config) + + async def authenticate_with_client_credentials(self) -> AccessToken | None: + """ + Authenticate using the client credentials flow. + + This flow is typically used for machine-to-machine authentication + where the client is acting on its own behalf, not on behalf of a user. + + Returns: + AccessToken | None: The access token if successful, None otherwise + """ + await self.initialize() + return await self.token_manager.authenticate_with_client_credentials() + + async def get_auth_headers(self) -> dict[str, str]: + """ + Get the authentication headers for API requests. + + This method will refresh the token if needed before returning headers. + + Returns: + dict[str, str]: Authentication headers + """ + await self.initialize() + await self.token_manager.refresh_token_if_needed() + + if not self.token_manager.token: + return {} + + return self.token_manager.token.to_auth_header() + + @property + def is_authenticated(self) -> bool: + """Check if the client is authenticated with a valid token.""" + if self.token_manager is None: + return False + return self.token_manager.is_authenticated + + +class OAuthClientProvider(Protocol): + @property + def client_metadata(self) -> ClientMetadata: ... + + def save_client_information(self, metadata: DynamicClientRegistration) -> None: ... + + +class NotFoundError(Exception): + """Exception raised when a resource or endpoint is not found.""" + + pass + + +class RegistrationFailedError(Exception): + """Exception raised when client registration fails.""" + + pass + + +class OAuthClient: + WELL_KNOWN = "/.well-known/oauth-authorization-server" + + def __init__(self, server_url: AnyHttpUrl, provider: OAuthClientProvider): + self.server_url = server_url + self.http_client = httpx.AsyncClient() + self.provider = provider + self._registration: DynamicClientRegistration | None = None + + async def auth(self): + metadata = await self.discover_auth_metadata() or self._default_metadata() + if metadata.registration_endpoint is None: + raise NotFoundError("Registration endpoint not found") + self._registration = await self.dynamic_client_registration( + self.provider.client_metadata, metadata.registration_endpoint + ) + if self._registration is None: + raise RegistrationFailedError( + f"Registration at {metadata.registration_endpoint} failed" + ) + self.provider.save_client_information(self._registration) + + def _default_metadata(self) -> ServerMetadataDiscovery: + base_url = AnyHttpUrl(str(self.server_url).rstrip("/")) + return ServerMetadataDiscovery( + issuer=base_url, + authorization_endpoint=AnyHttpUrl(f"{base_url}/authorize"), + token_endpoint=AnyHttpUrl(f"{base_url}/token"), + registration_endpoint=AnyHttpUrl(f"{base_url}/register"), + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_post"], + ) + + async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None: + discovery_url = self._build_discovery_url() + + try: + response = await self.http_client.get(str(discovery_url)) + if response.status_code == 404: + return None + response.raise_for_status() + json_data = await response.aread() + return ServerMetadataDiscovery.model_validate_json(json_data) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP status: {e}") + raise + except Exception as e: + logger.error(f"Error during auth metadata discovery: {e}") + raise + + def _build_discovery_url(self) -> AnyHttpUrl: + base_url = str(self.server_url).rstrip("/") + parsed_url = urlparse(base_url) + # HTTPS is required by RFC 8414 + discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}" + return AnyHttpUrl(discovery_url) + + async def dynamic_client_registration( + self, client_metadata: ClientMetadata, registration_endpoint: AnyHttpUrl + ) -> DynamicClientRegistration | None: + """ + Register a client dynamically with an OAuth 2.0 authorization server + following RFC 7591. + + Args: + client_metadata: Typed client registration metadata + registration_endpoint: Where to register clients. + If None, will use discovery + + Returns: + DynamicClientRegistrationResponse if successful, None otherwise + + Raises: + httpx.HTTPStatusError: If the server returns an error status code + Exception: For other errors during registration + """ + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + try: + response = await self.http_client.post( + str(registration_endpoint), + json=client_metadata.model_dump(exclude_none=True), + headers=headers, + ) + if response.status_code == 404: + logger.error( + f"Registration endpoint not found at {registration_endpoint}" + ) + return None + response.raise_for_status() + client_data = await response.aread() + return DynamicClientRegistration.model_validate_json(client_data) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error in client registration: {e.response.status_code}") + if e.response.content: + try: + error_data = json.loads(e.response.content) + logger.error(f"Error details: {error_data}") + except json.JSONDecodeError: + logger.error(f"Error content: {e.response.content}") + except Exception as e: + logger.error(f"Unexpected error during registration: {e}") + + return None diff --git a/tests/client/test_oauth.py b/tests/client/test_oauth.py new file mode 100644 index 0000000000..dee89e97dd --- /dev/null +++ b/tests/client/test_oauth.py @@ -0,0 +1,236 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +from pydantic import AnyHttpUrl + +from mcp.client.auth.oauth import ( + ClientMetadata, + DynamicClientRegistration, + OAuthClient, + OAuthClientProvider, +) + + +class MockOauthClientProvider(OAuthClientProvider): + @property + def client_metadata(self) -> ClientMetadata: + return ClientMetadata( + client_name="Test Client", + redirect_uris=[AnyHttpUrl("https://client.example.com/callback")], + token_endpoint_auth_method="client_secret_post", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ) + + def save_client_information(self, metadata: DynamicClientRegistration) -> None: + pass + + +@pytest.fixture +def server_url(): + return AnyHttpUrl("https://example.com/v1") + + +@pytest.fixture +def http_server_urls(): + return [ + # HTTP URL should be converted to HTTPS + "http://example.com/auth", + # URL with trailing slash + "http://auth.example.org/", + # Complex path + "http://api.example.net/v1/auth/service", + # URL with query parameters (these should be ignored) + "http://example.io/oauth?version=2.0&debug=true", + # URL with port + "http://auth.example.com:8080/v1", + ] + + +@pytest.fixture +def auth_client(server_url): + return OAuthClient(server_url, MockOauthClientProvider()) + + +@pytest.fixture +def mock_http_response(): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.aread = AsyncMock( + return_value=json.dumps( + { + "issuer": "https://example.com/v1", + "authorization_endpoint": "https://example.com/v1/authorize", + "token_endpoint": "https://example.com/v1/token", + "registration_endpoint": "https://example.com/v1/register", + "response_types_supported": ["code"], + } + ) + ) + return mock_response + + +@pytest.fixture +def client_metadata(): + return ClientMetadata( + client_name="Test Client", + redirect_uris=[AnyHttpUrl("https://client.example.com/callback")], + token_endpoint_auth_method="client_secret_post", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ) + + +@pytest.mark.anyio +async def test_discover_auth_metadata(auth_client, mock_http_response): + # Mock the HTTP client's stream method + auth_client.http_client.get = AsyncMock(return_value=mock_http_response) + + # Call the method under test + result = await auth_client.discover_auth_metadata() + + # Assertions + assert result is not None + assert result.issuer == AnyHttpUrl("https://example.com/v1") + assert result.authorization_endpoint == AnyHttpUrl( + "https://example.com/v1/authorize" + ) + assert result.token_endpoint == AnyHttpUrl("https://example.com/v1/token") + assert result.registration_endpoint == AnyHttpUrl("https://example.com/v1/register") + + # Verify the correct URL was used + expected_url = "https://example.com/.well-known/oauth-authorization-server" + auth_client.http_client.get.assert_called_once_with(expected_url) + + +@pytest.mark.anyio +async def test_discover_auth_metadata_not_found(auth_client): + # Mock 404 response + mock_response = MagicMock() + mock_response.status_code = 404 + auth_client.http_client.get = AsyncMock(return_value=mock_response) + + # Call the method under test + result = await auth_client.discover_auth_metadata() + + # Assertions + assert result is None + + +@pytest.mark.anyio +async def test_dynamic_client_registration( + auth_client, client_metadata, mock_http_response +): + # Setup mock response for registration + registration_response = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "client_name": "Test Client", + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + } + mock_http_response.aread = AsyncMock(return_value=json.dumps(registration_response)) + auth_client.http_client.post = AsyncMock(return_value=mock_http_response) + + # Call the method under test + registration_endpoint = "https://example.com/v1/register" + result = await auth_client.dynamic_client_registration( + client_metadata, registration_endpoint + ) + + # Assertions + assert result is not None + assert result.client_id == "test-client-id" + assert result.client_secret == "test-client-secret" + assert result.client_name == "Test Client" + + # Verify the request was made correctly + auth_client.http_client.post.assert_called_once_with( + registration_endpoint, + json=client_metadata.model_dump(exclude_none=True), + headers={"Content-Type": "application/json", "Accept": "application/json"}, + ) + + +@pytest.mark.anyio +async def test_dynamic_client_registration_error(auth_client, client_metadata): + # Mock error response + mock_error_response = AsyncMock() + mock_error_response.__aenter__ = AsyncMock(return_value=mock_error_response) + mock_error_response.__aexit__ = AsyncMock(return_value=None) + mock_error_response.status_code = 400 + mock_error_response.raise_for_status = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Client error '400 Bad Request'", + request=MagicMock(), + response=MagicMock( + status_code=400, + content=json.dumps({"error": "invalid_client_metadata"}), + ), + ) + ) + error_json = json.dumps({"error": "invalid_client_metadata"}) + mock_error_response.content = error_json.encode() + + auth_client.http_client.post = AsyncMock(return_value=mock_error_response) + + # Call the method under test + registration_endpoint = "https://example.com/v1/register" + result = await auth_client.dynamic_client_registration( + client_metadata, registration_endpoint + ) + + # Assertions + assert result is None + + +@pytest.mark.parametrize( + "input_url,expected_discovery_url", + [ + # Basic HTTP URL: protocol should be changed to HTTPS + ( + "http://example.com", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with trailing slash: should be normalized + ( + "https://example.com/", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with complex path: .well-known should be at the root + ( + "https://example.com/api/v1/auth", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with query parameters: parameters should be ignored + ( + "https://auth.example.org?version=2.0&debug=true", + "https://auth.example.org/.well-known/oauth-authorization-server", + ), + # URL with port: port should be preserved + ( + "http://auth.example.net:8080", + "https://auth.example.net:8080/.well-known/oauth-authorization-server", + ), + # URL with subdomain, path, and trailing slash: .well-known should be at the + # root + ( + "http://api.auth.example.com/oauth/v2/", + "https://api.auth.example.com/.well-known/oauth-authorization-server", + ), + ], +) +def test_build_discovery_url_with_various_formats(input_url, expected_discovery_url): + # Create auth client with the given URL + auth_client = OAuthClient(AnyHttpUrl(input_url), MockOauthClientProvider()) + + # Call the method under test + discovery_url = auth_client._build_discovery_url() + + # Assertions + assert discovery_url == AnyHttpUrl(expected_discovery_url) From 88edddcd0a776966cc87d1673b2a8d64b27c4af5 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 12 Mar 2025 20:27:12 +0000 Subject: [PATCH 47/60] update lock --- pyproject.toml | 2 - src/mcp/client/auth/oauth.py | 277 ++++++++++++++++++++++++++++++----- uv.lock | 14 -- 3 files changed, 240 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 429b7d6633..de1186e75b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", - "python-multipart", ] [project.optional-dependencies] @@ -48,7 +47,6 @@ dev-dependencies = [ "pytest>=8.3.4", "ruff>=0.8.5", "trio>=0.26.2", - "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", ] diff --git a/src/mcp/client/auth/oauth.py b/src/mcp/client/auth/oauth.py index 0f5aa0df02..7f5949652a 100644 --- a/src/mcp/client/auth/oauth.py +++ b/src/mcp/client/auth/oauth.py @@ -6,11 +6,13 @@ authorization specification. """ +import base64 +import hashlib import json import logging from datetime import datetime, timedelta from typing import Any, Protocol -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse import httpx from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field @@ -373,7 +375,49 @@ class OAuthClientProvider(Protocol): @property def client_metadata(self) -> ClientMetadata: ... - def save_client_information(self, metadata: DynamicClientRegistration) -> None: ... + @property + def redirect_url(self) -> AnyHttpUrl: ... + + async def open_user_agent(self, url: AnyHttpUrl) -> None: + """ + Opens the user agent to the given URL. + """ + ... + + async def client_registration( + self, endpoint: AnyHttpUrl + ) -> DynamicClientRegistration | None: + """ + Loads the client registration for the given endpoint. + """ + ... + + async def store_client_registration( + self, endpoint: AnyHttpUrl, metadata: DynamicClientRegistration + ) -> None: + """ + Stores the client registration to be retreived for the next session + """ + ... + + def code_verifier(self) -> str: + """ + Loads the PKCE code verifier for the current session. + See https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1 + """ + ... + + async def token(self) -> AccessToken | None: + """ + Loads the token for the current session. + """ + ... + + async def store_token(self, token: AccessToken) -> None: + """ + Stores the token to be retreived for the next session + """ + ... class NotFoundError(Exception): @@ -388,29 +432,64 @@ class RegistrationFailedError(Exception): pass +class GrantNotSupported(Exception): + """Exception raised when a grant type is not supported.""" + + pass + + class OAuthClient: WELL_KNOWN = "/.well-known/oauth-authorization-server" - - def __init__(self, server_url: AnyHttpUrl, provider: OAuthClientProvider): + GRANT_TYPE: str = "authorization_code" + + def __init__( + self, + server_url: AnyHttpUrl, + provider: OAuthClientProvider, + scope: str | None = None, + ): self.server_url = server_url self.http_client = httpx.AsyncClient() self.provider = provider - self._registration: DynamicClientRegistration | None = None + self.scope = scope - async def auth(self): - metadata = await self.discover_auth_metadata() or self._default_metadata() + @property + def discovery_url(self) -> AnyHttpUrl: + base_url = str(self.server_url).rstrip("/") + parsed_url = urlparse(base_url) + # HTTPS is required by RFC 8414 + discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}" + return AnyHttpUrl(discovery_url) + + async def _obtain_client( + self, metadata: ServerMetadataDiscovery + ) -> DynamicClientRegistration: + """ + Obtain a client by either reading it from the OAuthProvider or registering it. + """ if metadata.registration_endpoint is None: raise NotFoundError("Registration endpoint not found") - self._registration = await self.dynamic_client_registration( - self.provider.client_metadata, metadata.registration_endpoint - ) - if self._registration is None: - raise RegistrationFailedError( - f"Registration at {metadata.registration_endpoint} failed" + + if registration := await self.provider.client_registration(metadata.issuer): + return registration + else: + registration = await self.dynamic_client_registration( + self.provider.client_metadata, metadata.registration_endpoint ) - self.provider.save_client_information(self._registration) + if registration is None: + raise RegistrationFailedError( + f"Registration at {metadata.registration_endpoint} failed" + ) - def _default_metadata(self) -> ServerMetadataDiscovery: + await self.provider.store_client_registration(metadata.issuer, registration) + return registration + + def default_metadata(self) -> ServerMetadataDiscovery: + """ + Returns default endpoints as specified in + https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/ + for the server. + """ base_url = AnyHttpUrl(str(self.server_url).rstrip("/")) return ServerMetadataDiscovery( issuer=base_url, @@ -423,10 +502,11 @@ def _default_metadata(self) -> ServerMetadataDiscovery: ) async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None: - discovery_url = self._build_discovery_url() - + """ + Use RFC 8414 to discover the authorization server metadata. + """ try: - response = await self.http_client.get(str(discovery_url)) + response = await self.http_client.get(str(self.discovery_url)) if response.status_code == 404: return None response.raise_for_status() @@ -439,31 +519,12 @@ async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None: logger.error(f"Error during auth metadata discovery: {e}") raise - def _build_discovery_url(self) -> AnyHttpUrl: - base_url = str(self.server_url).rstrip("/") - parsed_url = urlparse(base_url) - # HTTPS is required by RFC 8414 - discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}" - return AnyHttpUrl(discovery_url) - async def dynamic_client_registration( self, client_metadata: ClientMetadata, registration_endpoint: AnyHttpUrl ) -> DynamicClientRegistration | None: """ Register a client dynamically with an OAuth 2.0 authorization server following RFC 7591. - - Args: - client_metadata: Typed client registration metadata - registration_endpoint: Where to register clients. - If None, will use discovery - - Returns: - DynamicClientRegistrationResponse if successful, None otherwise - - Raises: - httpx.HTTPStatusError: If the server returns an error status code - Exception: For other errors during registration """ headers = {"Content-Type": "application/json", "Accept": "application/json"} @@ -493,3 +554,145 @@ async def dynamic_client_registration( logger.error(f"Unexpected error during registration: {e}") return None + + async def exchange_authorization( + self, + metadata: ServerMetadataDiscovery, + registration: DynamicClientRegistration, + code_verifier: str, + authorization_code: str, + ) -> AccessToken: + """Exchange an authorization code for an access token using OAuth 2.1 with PKCE. + + Args: + registration: The client registration information + code_verifier: The PKCE code verifier used to generate the code challenge + authorization_code: The authorization code received from the authorization + server + + Returns: + AccessToken: The resulting access token + + Raises: + GrantNotSupported: If the grant type is not supported + httpx.HTTPStatusError: If the token endpoint request fails + """ + if self.GRANT_TYPE not in (registration.grant_types or []): + raise GrantNotSupported(f"Grant type {self.GRANT_TYPE} not supported") + + code_verifier = self.provider.code_verifier() + # Get token endpoint from server metadata or use default + token_endpoint = str(metadata.token_endpoint) + + # Prepare token request parameters + data = { + "grant_type": self.GRANT_TYPE, + "code": authorization_code, + "redirect_uri": str(self.provider.redirect_url), + "client_id": registration.client_id, + "code_verifier": code_verifier, + } + + # Add client secret if available (optional in OAuth 2.1) + if registration.client_secret: + data["client_secret"] = registration.client_secret + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + response = await self.http_client.post( + token_endpoint, data=data, headers=headers + ) + response.raise_for_status() + token_data = response.json() + + # Create and return the token + return AccessToken(**token_data) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error during token exchange: {e.response.status_code}") + if e.response.content: + try: + error_data = json.loads(e.response.content) + logger.error(f"Error details: {error_data}") + except json.JSONDecodeError: + logger.error(f"Error content: {e.response.content}") + raise + except Exception as e: + logger.error(f"Unexpected error during token exchange: {e}") + raise + + async def auth(self, authorization_code: str, code_verifier: str) -> AccessToken: + """ + Complete the OAuth 2.1 authorization flow by exchanging authorization code + for tokens. + + Args: + authorization_code: The authorization code received from the authorization + server + code_verifier: The PKCE code verifier used to generate the code challenge + + Returns: + AccessToken: The resulting access token + """ + metadata = await self.discover_auth_metadata() or self.default_metadata() + registration = await self._obtain_client(metadata) + + code_verifier = self.provider.code_verifier() + + authorization_url = self.get_authorization_url( + metadata.authorization_endpoint, + self.provider.redirect_url, + registration.client_id, + code_verifier, + self.scope, + ) + + await self.provider.open_user_agent(AnyHttpUrl(authorization_url)) + + return await self.exchange_authorization( + metadata, registration, code_verifier, authorization_code + ) + + def get_authorization_url( + self, + authorization_endpoint: AnyHttpUrl, + redirect_uri: AnyHttpUrl, + client_id: str, + code_verifier: str, + scope: str | None = None, + ) -> AnyHttpUrl: + """Generate an OAuth 2.1 authorization URL for the user agent. + + This method generates a URL that the user agent (browser) should visit to + authenticate the user and authorize the application. It includes PKCE + (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1. + """ + # Create a custom verifier for this authorization request + code_verifier = self.provider.code_verifier() + + # Generate code challenge from verifier using SHA-256 + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + # Build authorization URL with necessary parameters + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": str(redirect_uri), + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + # Add scope if provided or use the one from registration + if scope: + params["scope"] = scope + + # Construct the full authorization URL + return AnyHttpUrl(f"{authorization_endpoint}?{urlencode(params)}") diff --git a/uv.lock b/uv.lock index b1887c3506..9bbfa795fb 100644 --- a/uv.lock +++ b/uv.lock @@ -221,7 +221,6 @@ ws = [ dev = [ { name = "pyright" }, { name = "pytest" }, - { name = "pytest-flakefinder" }, { name = "pytest-xdist" }, { name = "ruff" }, { name = "trio" }, @@ -247,7 +246,6 @@ requires-dist = [ dev = [ { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=8.3.4" }, - { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, { name = "trio", specifier = ">=0.26.2" }, @@ -550,18 +548,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] -[[package]] -name = "pytest-flakefinder" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ec/53/69c56a93ea057895b5761c5318455804873a6cd9d796d7c55d41c2358125/pytest-flakefinder-1.1.0.tar.gz", hash = "sha256:e2412a1920bdb8e7908783b20b3d57e9dad590cc39a93e8596ffdd493b403e0e", size = 6795 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/8b/06787150d0fd0cbd3a8054262b56f91631c7778c1bc91bf4637e47f909ad/pytest_flakefinder-1.1.0-py2.py3-none-any.whl", hash = "sha256:741e0e8eea427052f5b8c89c2b3c3019a50c39a59ce4df6a305a2c2d9ba2bd13", size = 4644 }, -] - [[package]] name = "pytest-xdist" version = "3.6.1" From d774be7daefaf6968f60b9ab9bdd3c8966153e70 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Wed, 12 Mar 2025 20:39:39 +0000 Subject: [PATCH 48/60] fix --- .../simple-chatbot/mcp_simple_chatbot/main.py | 3 +- pyproject.toml | 11 ++++--- src/mcp/client/auth/oauth.py | 4 +-- src/mcp/server/lowlevel/server.py | 6 ++-- src/mcp/server/sse.py | 6 ---- tests/client/test_oauth.py | 31 ++++++++++++++++--- uv.lock | 8 ++--- 7 files changed, 41 insertions(+), 28 deletions(-) diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 30bca72293..7d73e98760 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -322,8 +322,7 @@ async def process_llm_response(self, llm_response: str) -> str: total = result["total"] percentage = (progress / total) * 100 logging.info( - f"Progress: {progress}/{total} " - f"({percentage:.1f}%)" + f"Progress: {progress}/{total} ({percentage:.1f}%)" ) return f"Tool execution result: {result}" diff --git a/pyproject.toml b/pyproject.toml index de1186e75b..4d0d79ba32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ mcp = "mcp.cli:app [cli]" [tool.uv] resolution = "lowest-direct" dev-dependencies = [ - "pyright>=1.1.391", + "pyright>=1.1.396", "pytest>=8.3.4", "ruff>=0.8.5", "trio>=0.26.2", @@ -70,9 +70,6 @@ strict = [ "src/mcp/server/fastmcp/tools/base.py", ] -[tool.pytest.ini_options] -markers = ["anyio"] - [tool.ruff.lint] select = ["E", "F", "I"] ignore = [] @@ -95,8 +92,12 @@ mcp = { workspace = true } xfail_strict = true filterwarnings = [ "error", + # this is a long-standing issue with fastmcp, which is just now being exercised by tests + "ignore:Unclosed:ResourceWarning", # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", + # this is a problem in starlette + "ignore:Please use `import python_multipart` instead.:PendingDeprecationWarning", ] diff --git a/src/mcp/client/auth/oauth.py b/src/mcp/client/auth/oauth.py index 7f5949652a..7763897fac 100644 --- a/src/mcp/client/auth/oauth.py +++ b/src/mcp/client/auth/oauth.py @@ -385,7 +385,7 @@ async def open_user_agent(self, url: AnyHttpUrl) -> None: ... async def client_registration( - self, endpoint: AnyHttpUrl + self, issuer: AnyHttpUrl ) -> DynamicClientRegistration | None: """ Loads the client registration for the given endpoint. @@ -393,7 +393,7 @@ async def client_registration( ... async def store_client_registration( - self, endpoint: AnyHttpUrl, metadata: DynamicClientRegistration + self, issuer: AnyHttpUrl, metadata: DynamicClientRegistration ) -> None: """ Stores the client registration to be retreived for the next session diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 817d1918a9..a09065ec41 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -578,14 +578,12 @@ async def _handle_notification(self, notify: Any): assert type(notify) in self.notification_handlers handler = self.notification_handlers[type(notify)] - logger.debug( - f"Dispatching notification of type " f"{type(notify).__name__}" - ) + logger.debug(f"Dispatching notification of type {type(notify).__name__}") try: await handler(notify) except Exception as err: - logger.error(f"Uncaught exception in notification handler: " f"{err}") + logger.error(f"Uncaught exception in notification handler: {err}") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index db36bffad5..63d1b8bf45 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -79,7 +79,6 @@ def __init__(self, endpoint: str) -> None: self._read_stream_writers = {} logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") - @deprecated("use connect_sse_v2 instead") @asynccontextmanager async def connect_sse(self, scope: Scope, receive: Receive, send: Send): if scope["type"] != "http": @@ -130,11 +129,6 @@ async def sse_writer(): tg.start_soon(response, scope, receive, send) logger.debug("Yielding read and write streams") - # TODO: hold on; shouldn't we be returning the EventSourceResponse? - # I think this is why the tests hang - # TODO: we probably shouldn't return response here, since it's a breaking - # change - # this is just to test yield (read_stream, write_stream, response) async def handle_post_message( diff --git a/tests/client/test_oauth.py b/tests/client/test_oauth.py index dee89e97dd..90ca5683e5 100644 --- a/tests/client/test_oauth.py +++ b/tests/client/test_oauth.py @@ -6,6 +6,7 @@ from pydantic import AnyHttpUrl from mcp.client.auth.oauth import ( + AccessToken, ClientMetadata, DynamicClientRegistration, OAuthClient, @@ -24,7 +25,30 @@ def client_metadata(self) -> ClientMetadata: response_types=["code"], ) - def save_client_information(self, metadata: DynamicClientRegistration) -> None: + @property + def redirect_url(self) -> AnyHttpUrl: + return AnyHttpUrl("https://client.example.com/callback") + + async def open_user_agent(self, url: AnyHttpUrl) -> None: + pass + + async def client_registration( + self, issuer: AnyHttpUrl + ) -> DynamicClientRegistration | None: + return None + + async def store_client_registration( + self, issuer: AnyHttpUrl, metadata: DynamicClientRegistration + ) -> None: + pass + + def code_verifier(self) -> str: + return "test-code-verifier" + + async def token(self) -> AccessToken | None: + return None + + async def store_token(self, token: AccessToken) -> None: pass @@ -229,8 +253,5 @@ def test_build_discovery_url_with_various_formats(input_url, expected_discovery_ # Create auth client with the given URL auth_client = OAuthClient(AnyHttpUrl(input_url), MockOauthClientProvider()) - # Call the method under test - discovery_url = auth_client._build_discovery_url() - # Assertions - assert discovery_url == AnyHttpUrl(expected_discovery_url) + assert auth_client.discovery_url == AnyHttpUrl(expected_discovery_url) diff --git a/uv.lock b/uv.lock index 9bbfa795fb..8671811eea 100644 --- a/uv.lock +++ b/uv.lock @@ -244,7 +244,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "pyright", specifier = ">=1.1.391" }, + { name = "pyright", specifier = ">=1.1.396" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.8.5" }, @@ -520,15 +520,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.391" +version = "1.1.396" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } +sdist = { url = "https://files.pythonhosted.org/packages/bd/73/f20cb1dea1bdc1774e7f860fb69dc0718c7d8dea854a345faec845eb086a/pyright-1.1.396.tar.gz", hash = "sha256:142901f5908f5a0895be3d3befcc18bedcdb8cc1798deecaec86ef7233a29b03", size = 3814400 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, + { url = "https://files.pythonhosted.org/packages/80/be/ecb7cfb42d242b7ee764b52e6ff4782beeec00e3b943a3ec832b281f9da6/pyright-1.1.396-py3-none-any.whl", hash = "sha256:c635e473095b9138c471abccca22b9fedbe63858e0b40d4fc4b67da041891844", size = 5689355 }, ] [[package]] From a09e9580c71ed4ff7892468a7dbe7bfabea1fd0c Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Fri, 14 Mar 2025 20:52:07 +0000 Subject: [PATCH 49/60] foo --- .gitignore | 3 +- src/mcp/client/auth/oauth.py | 562 ++++++++++++++--------------------- src/mcp/client/sse.py | 34 ++- 3 files changed, 258 insertions(+), 341 deletions(-) diff --git a/.gitignore b/.gitignore index 54006f93f2..2754db9d95 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ cython_debug/ #.idea/ # vscode -.vscode/ \ No newline at end of file +.vscode/ +.windsurfrules diff --git a/src/mcp/client/auth/oauth.py b/src/mcp/client/auth/oauth.py index 7763897fac..a43a461dbf 100644 --- a/src/mcp/client/auth/oauth.py +++ b/src/mcp/client/auth/oauth.py @@ -6,16 +6,19 @@ authorization specification. """ +from __future__ import annotations as _annotations + import base64 import hashlib import json import logging +from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol from urllib.parse import urlencode, urlparse import httpx -from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) @@ -55,20 +58,6 @@ def to_auth_header(self) -> dict[str, str]: return {"Authorization": f"{self.token_type} {self.access_token}"} -class AuthConfig(BaseModel): - """ - Configuration for the MCP client authentication. - """ - - client_id: str - client_secret: str | None = None - token_endpoint: str | None = None - redirect_uri: str | None = None - scope: str | None = None - auth_endpoint: str | None = None - model_config = ConfigDict(extra="allow") - - class ClientMetadata(BaseModel): """ OAuth 2.0 Dynamic Client Registration Metadata. @@ -148,229 +137,6 @@ class ServerMetadataDiscovery(BaseModel): model_config = ConfigDict(extra="allow") -class TokenManager: - """ - Manages OAuth tokens for MCP client, handling token refresh and expiration. - """ - - def __init__(self, config: AuthConfig): - self.config = config - self.token: AccessToken | None = None - - @property - def is_authenticated(self) -> bool: - """Check if the client is authenticated with a valid token.""" - return self.token is not None and not self.token.is_expired - - async def refresh_token_if_needed(self) -> bool: - """ - Refresh the token if it's expired or close to expiration. - - Returns: - bool: True if token was refreshed, False otherwise - """ - if not self.token or not self.token.refresh_token: - return False - - if self.token.is_expired(): - await self.refresh() - return True - - return False - - async def refresh(self) -> AccessToken | None: - """ - Refresh the access token using the refresh token. - - Returns: - AccessToken | None: The new token if successful, None otherwise - """ - if ( - not self.token - or not self.token.refresh_token - or not self.config.token_endpoint - ): - return None - - data = { - "grant_type": "refresh_token", - "refresh_token": self.token.refresh_token, - "client_id": self.config.client_id, - } - - # Add client secret if available - if self.config.client_secret: - data["client_secret"] = self.config.client_secret - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - } - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - self.config.token_endpoint, - data=data, - headers=headers, - ) - response.raise_for_status() - token_data = response.json() - - # Create and store the token - token = AccessToken(**token_data) - - # If the response didn't include a refresh token, keep the old one - if not token.refresh_token and self._token.refresh_token: - token.refresh_token = self._token.refresh_token - - self._token = token - return token - - except httpx.HTTPStatusError as e: - logger.error(f"HTTP error during token refresh: {e.response.status_code}") - if e.response.content: - try: - error_data = json.loads(e.response.content) - logger.error(f"Error details: {error_data}") - except json.JSONDecodeError: - logger.error(f"Error content: {e.response.content}") - return None - - except httpx.RequestError as e: - logger.error(f"Request error during token refresh: {e}") - return None - - except Exception as e: - logger.error(f"Unexpected error during token refresh: {e}") - return None - - async def authenticate_with_client_credentials(self) -> AccessToken | None: - """ - Authenticate using client credentials flow. - - Returns: - AccessToken | None: The access token if successful, None otherwise - """ - if not self.config.token_endpoint or not self.config.client_id: - logger.error("Token endpoint or client ID not configured") - return None - - data = { - "grant_type": "client_credentials", - "client_id": self.config.client_id, - } - - # Add client secret if available - if self.config.client_secret: - data["client_secret"] = self.config.client_secret - - # Add scope if available - if self.config.scope: - data["scope"] = self.config.scope - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - } - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - self.config.token_endpoint, - data=data, - headers=headers, - ) - response.raise_for_status() - token_data = response.json() - - # Create and store the token - token = AccessToken(**token_data) - self._token = token - return token - - except httpx.HTTPStatusError as e: - logger.error(f"HTTP error during authentication: {e.response.status_code}") - if e.response.content: - try: - error_data = json.loads(e.response.content) - logger.error(f"Error details: {error_data}") - except json.JSONDecodeError: - logger.error(f"Error content: {e.response.content}") - return None - - except httpx.RequestError as e: - logger.error(f"Request error during authentication: {e}") - return None - - except Exception as e: - logger.error(f"Unexpected error during authentication: {e}") - return None - - -class AuthSession: - """ - Client for handling authentication with an MCP server. - - This client provides methods for authenticating with an MCP server using - various OAuth 2.0 flows and managing the resulting tokens. - """ - - def __init__(self, config: AuthConfig): - """ - Initialize the authentication client with the given configuration. - - Args: - config: Authentication configuration - """ - self.config = config - self.token_manager: TokenManager = TokenManager(config) - - async def initialize(self) -> None: - """ - Initialize the client and prepare it for authentication. - """ - if self.token_manager is None: - self.token_manager = TokenManager(self.config) - - async def authenticate_with_client_credentials(self) -> AccessToken | None: - """ - Authenticate using the client credentials flow. - - This flow is typically used for machine-to-machine authentication - where the client is acting on its own behalf, not on behalf of a user. - - Returns: - AccessToken | None: The access token if successful, None otherwise - """ - await self.initialize() - return await self.token_manager.authenticate_with_client_credentials() - - async def get_auth_headers(self) -> dict[str, str]: - """ - Get the authentication headers for API requests. - - This method will refresh the token if needed before returning headers. - - Returns: - dict[str, str]: Authentication headers - """ - await self.initialize() - await self.token_manager.refresh_token_if_needed() - - if not self.token_manager.token: - return {} - - return self.token_manager.token.to_auth_header() - - @property - def is_authenticated(self) -> bool: - """Check if the client is authenticated with a valid token.""" - if self.token_manager is None: - return False - return self.token_manager.is_authenticated - - class OAuthClientProvider(Protocol): @property def client_metadata(self) -> ClientMetadata: ... @@ -400,6 +166,20 @@ async def store_client_registration( """ ... + async def store_metadata( + self, issuer: AnyHttpUrl, metadata: ServerMetadataDiscovery + ) -> None: + """ + Stores the metadata for the given issuer + """ + ... + + async def metadata(self, issuer: AnyHttpUrl) -> ServerMetadataDiscovery | None: + """ + Loads the metadata for the given issuer + """ + ... + def code_verifier(self) -> str: """ Loads the PKCE code verifier for the current session. @@ -442,24 +222,51 @@ class OAuthClient: WELL_KNOWN = "/.well-known/oauth-authorization-server" GRANT_TYPE: str = "authorization_code" + @dataclass + class State: + metadata: ServerMetadataDiscovery | None = None + registeration: DynamicClientRegistration | None = None + def __init__( self, server_url: AnyHttpUrl, provider: OAuthClientProvider, scope: str | None = None, ): - self.server_url = server_url self.http_client = httpx.AsyncClient() + self.server_url = server_url self.provider = provider self.scope = scope + self.state = self.State() + + @property + def is_authenticated(self) -> bool: + """Check if client has a valid, non-expired token.""" + return self.token is not None and not self.token.is_expired() @property def discovery_url(self) -> AnyHttpUrl: base_url = str(self.server_url).rstrip("/") parsed_url = urlparse(base_url) + # HTTPS is required by RFC 8414 discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}" - return AnyHttpUrl(discovery_url) + return AnyUrl(discovery_url) + + async def _obtain_metadata(self) -> ServerMetadataDiscovery: + if metadata := await self.provider.metadata(self.discovery_url): + return metadata + if metadata := await self.discover_auth_metadata(self.discovery_url): + await self.provider.store_metadata(self.discovery_url, metadata) + return metadata + return self.default_metadata() + + async def metadata(self) -> ServerMetadataDiscovery: + if self.state.metadata is not None: + return self.state.metadata + + self.state.metadata = await self._obtain_metadata() + return self.state.metadata async def _obtain_client( self, metadata: ServerMetadataDiscovery @@ -484,29 +291,39 @@ async def _obtain_client( await self.provider.store_client_registration(metadata.issuer, registration) return registration + async def client_metadata( + self, metadata: ServerMetadataDiscovery + ) -> DynamicClientRegistration: + if self.state.registeration is not None: + return self.state.registeration + else: + return await self._obtain_client(metadata) + def default_metadata(self) -> ServerMetadataDiscovery: """ Returns default endpoints as specified in https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/ for the server. """ - base_url = AnyHttpUrl(str(self.server_url).rstrip("/")) + base_url = AnyUrl(str(self.server_url).rstrip("/")) return ServerMetadataDiscovery( issuer=base_url, - authorization_endpoint=AnyHttpUrl(f"{base_url}/authorize"), - token_endpoint=AnyHttpUrl(f"{base_url}/token"), - registration_endpoint=AnyHttpUrl(f"{base_url}/register"), + authorization_endpoint=AnyUrl(f"{base_url}/authorize"), + token_endpoint=AnyUrl(f"{base_url}/token"), + registration_endpoint=AnyUrl(f"{base_url}/register"), response_types_supported=["code"], grant_types_supported=["authorization_code", "refresh_token"], token_endpoint_auth_methods_supported=["client_secret_post"], ) - async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None: + async def discover_auth_metadata( + self, discovery_url: AnyHttpUrl + ) -> ServerMetadataDiscovery | None: """ Use RFC 8414 to discover the authorization server metadata. """ try: - response = await self.http_client.get(str(self.discovery_url)) + response = await self.http_client.get(str(discovery_url)) if response.status_code == 404: return None response.raise_for_status() @@ -555,40 +372,148 @@ async def dynamic_client_registration( return None - async def exchange_authorization( - self, - metadata: ServerMetadataDiscovery, - registration: DynamicClientRegistration, - code_verifier: str, - authorization_code: str, - ) -> AccessToken: - """Exchange an authorization code for an access token using OAuth 2.1 with PKCE. + async def start_auth(self) -> AnyHttpUrl: + """ + Start the OAuth 2.1 authorization flow by redirecting the user to the + authorization server. + + Returns: + AnyHttpUrl: The authorization URL to redirect the user to + """ + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + + # Generate PKCE code verifier + code_verifier = self.provider.code_verifier() + + # Build authorization URL + authorization_url = get_authorization_url( + metadata.authorization_endpoint, + self.provider.redirect_url, + registration.client_id, + code_verifier, + self.scope, + ) + + # Open the URL in the user's browser + await self.provider.open_user_agent(authorization_url) + + return authorization_url + + async def finalize_auth(self, authorization_code: str) -> AccessToken: + """ + Complete the OAuth 2.1 authorization flow by exchanging authorization code + for tokens. Args: - registration: The client registration information - code_verifier: The PKCE code verifier used to generate the code challenge authorization_code: The authorization code received from the authorization server Returns: AccessToken: The resulting access token + """ + # Get metadata and registration info + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + code_verifier = self.provider.code_verifier() - Raises: - GrantNotSupported: If the grant type is not supported - httpx.HTTPStatusError: If the token endpoint request fails + # Exchange the code for a token + token = await self.exchange_authorization( + metadata, + registration, + self.provider.redirect_url, + code_verifier, + authorization_code, + ) + + # Cache the token and store it for future use + self.token = token + await self.provider.store_token(token) + + return token + + async def refresh_if_needed(self) -> AccessToken | None: + """ + Get the current token from the underlying provider """ - if self.GRANT_TYPE not in (registration.grant_types or []): - raise GrantNotSupported(f"Grant type {self.GRANT_TYPE} not supported") + # Return cached token if it's valid + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + + if token := await self.provider.token(): + if not token.is_expired(): + return token + + token = await self.refresh_token( + token, + metadata.token_endpoint, + registration.client_id, + registration.client_secret, + ) + + if token is not None: + return token + + return None + + async def refresh_token( + self, + token: AccessToken, + token_endpoint: AnyHttpUrl, + client_id: str, + client_secret: str | None = None, + ) -> AccessToken: + """ + Refresh the access token using a refresh token. + """ + data = { + "grant_type": "refresh_token", + "refresh_token": token.refresh_token, + "client_id": client_id, + } + + if client_secret: + data["client_secret"] = client_secret + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + response = await self.http_client.post( + str(token_endpoint), data=data, headers=headers + ) + response.raise_for_status() + token_data = response.json() + return AccessToken(**token_data) + except Exception as e: + logger.error(f"Error refreshing token: {e}") + raise + + async def exchange_authorization( + self, + metadata: ServerMetadataDiscovery, + registration: DynamicClientRegistration, + redirect_uri: AnyHttpUrl, + code_verifier: str, + authorization_code: str, + grant_type: str = "authorization_code", + ) -> AccessToken: + """ + Exchange an authorization code for an access token using OAuth 2.1 with PKCE. + """ + if grant_type not in (registration.grant_types or []): + raise GrantNotSupported(f"Grant type {grant_type} not supported") - code_verifier = self.provider.code_verifier() # Get token endpoint from server metadata or use default token_endpoint = str(metadata.token_endpoint) # Prepare token request parameters data = { - "grant_type": self.GRANT_TYPE, + "grant_type": grant_type, "code": authorization_code, - "redirect_uri": str(self.provider.redirect_url), + "redirect_uri": str(redirect_uri), "client_id": registration.client_id, "code_verifier": code_verifier, } @@ -615,84 +540,45 @@ async def exchange_authorization( except httpx.HTTPStatusError as e: logger.error(f"HTTP error during token exchange: {e.response.status_code}") if e.response.content: - try: - error_data = json.loads(e.response.content) - logger.error(f"Error details: {error_data}") - except json.JSONDecodeError: - logger.error(f"Error content: {e.response.content}") + logger.error(f"Error content: {e.response.content}") raise except Exception as e: logger.error(f"Unexpected error during token exchange: {e}") raise - async def auth(self, authorization_code: str, code_verifier: str) -> AccessToken: - """ - Complete the OAuth 2.1 authorization flow by exchanging authorization code - for tokens. - - Args: - authorization_code: The authorization code received from the authorization - server - code_verifier: The PKCE code verifier used to generate the code challenge - - Returns: - AccessToken: The resulting access token - """ - metadata = await self.discover_auth_metadata() or self.default_metadata() - registration = await self._obtain_client(metadata) - code_verifier = self.provider.code_verifier() +def get_authorization_url( + authorization_endpoint: AnyHttpUrl, + redirect_uri: AnyHttpUrl, + client_id: str, + code_verifier: str, + scope: str | None = None, +) -> AnyHttpUrl: + """Generate an OAuth 2.1 authorization URL for the user agent. - authorization_url = self.get_authorization_url( - metadata.authorization_endpoint, - self.provider.redirect_url, - registration.client_id, - code_verifier, - self.scope, - ) - - await self.provider.open_user_agent(AnyHttpUrl(authorization_url)) - - return await self.exchange_authorization( - metadata, registration, code_verifier, authorization_code - ) - - def get_authorization_url( - self, - authorization_endpoint: AnyHttpUrl, - redirect_uri: AnyHttpUrl, - client_id: str, - code_verifier: str, - scope: str | None = None, - ) -> AnyHttpUrl: - """Generate an OAuth 2.1 authorization URL for the user agent. - - This method generates a URL that the user agent (browser) should visit to - authenticate the user and authorize the application. It includes PKCE - (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1. - """ - # Create a custom verifier for this authorization request - code_verifier = self.provider.code_verifier() - - # Generate code challenge from verifier using SHA-256 - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) - - # Build authorization URL with necessary parameters - params = { - "response_type": "code", - "client_id": client_id, - "redirect_uri": str(redirect_uri), - "code_challenge": code_challenge, - "code_challenge_method": "S256", - } - - # Add scope if provided or use the one from registration - if scope: - params["scope"] = scope - - # Construct the full authorization URL - return AnyHttpUrl(f"{authorization_endpoint}?{urlencode(params)}") + This method generates a URL that the user agent (browser) should visit to + authenticate the user and authorize the application. It includes PKCE + (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1. + """ + # Generate code challenge from verifier using SHA-256 + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + # Build authorization URL with necessary parameters + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": str(redirect_uri), + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + # Add scope if provided or use the one from registration + if scope: + params["scope"] = scope + + # Construct the full authorization URL + return AnyUrl(f"{authorization_endpoint}?{urlencode(params)}") diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb962..acaecc4b25 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import Any +from typing import Any, Union from urllib.parse import urljoin, urlparse import anyio @@ -10,6 +10,8 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.client.auth import http as auth_http +from mcp.client.auth.oauth import AuthSession, OAuthClient logger = logging.getLogger(__name__) @@ -24,6 +26,7 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + auth: Union[AuthSession, OAuthClient, None] = None, ): """ Client transport for SSE. @@ -43,7 +46,33 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + + # Set up headers and auth if needed + if headers is None: + headers = {} + + if auth is not None: + await auth_http.add_auth_headers(headers, auth) + + # Set up event hooks for auth if auth is provided + event_hooks = {} + if auth is not None: + # Create a response hook for authentication + async def auth_hook(response): + if isinstance(auth, AuthSession): + return await auth_http.auth_response_hook( + response, auth_session=auth + ) + else: + return await auth_http.auth_response_hook( + response, oauth_client=auth + ) + + event_hooks["response"] = [auth_hook] + + async with httpx.AsyncClient( + headers=headers, event_hooks=event_hooks + ) as client: async with aconnect_sse( client, "GET", @@ -117,6 +146,7 @@ async def post_writer(endpoint_url: str): exclude_none=True, ), ) + # Handle 401 responses through the auth hook response.raise_for_status() logger.debug( "Client message sent successfully: " From 4e73552027316dce3b3b9fa5a8130341b50d037c Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 11:43:53 -0700 Subject: [PATCH 50/60] Format --- src/mcp/client/websocket.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index b807370a54..3e73b02048 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -15,7 +15,9 @@ @asynccontextmanager -async def websocket_client(url: str) -> AsyncGenerator[ +async def websocket_client( + url: str, +) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], MemoryObjectSendStream[types.JSONRPCMessage], @@ -59,7 +61,7 @@ async def ws_reader(): async def ws_writer(): """ - Reads JSON-RPC messages from write_stream_reader and + Reads JSON-RPC messages from write_stream_reader and sends them to the server. """ async with write_stream_reader: From 56f694e16aa072e5e53fead581260ebae044cddb Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Wed, 19 Mar 2025 15:11:31 -0700 Subject: [PATCH 51/60] Move StreamingASGITransport into the library code, so MCP integrations can use this in their tests --- .../fastmcp/auth => src/mcp/server}/streaming_asgi_transport.py | 2 ++ tests/server/fastmcp/auth/test_auth_integration.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) rename {tests/server/fastmcp/auth => src/mcp/server}/streaming_asgi_transport.py (99%) diff --git a/tests/server/fastmcp/auth/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py similarity index 99% rename from tests/server/fastmcp/auth/streaming_asgi_transport.py rename to src/mcp/server/streaming_asgi_transport.py index 7bb07b50a4..98a706b381 100644 --- a/tests/server/fastmcp/auth/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -4,6 +4,8 @@ This transport runs the ASGI app as a separate anyio task, allowing it to handle streaming responses like SSE where the app doesn't terminate until the connection is closed. + +This is only intended for writing tests for the SSE transport. """ import typing diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index ec19b5148d..245edf1f1a 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -39,7 +39,7 @@ ) from mcp.types import JSONRPCRequest -from .streaming_asgi_transport import StreamingASGITransport +from mcp.server.streaming_asgi_transport import StreamingASGITransport # Mock client store for testing From 60da6822a3f836becba91b73e533bf31ca0d53fa Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 21 Mar 2025 13:58:37 -0700 Subject: [PATCH 52/60] Improved error handling, generic types for provider --- src/mcp/server/auth/handlers/authorize.py | 48 ++- src/mcp/server/auth/handlers/register.py | 35 ++- src/mcp/server/auth/handlers/revoke.py | 4 +- src/mcp/server/auth/handlers/token.py | 43 ++- src/mcp/server/auth/middleware/client_auth.py | 10 +- src/mcp/server/auth/provider.py | 141 +++++++-- src/mcp/server/auth/routes.py | 4 +- src/mcp/server/sse.py | 1 - tests/server/auth/test_error_handling.py | 294 ++++++++++++++++++ .../fastmcp/auth/test_auth_integration.py | 26 +- 10 files changed, 490 insertions(+), 116 deletions(-) create mode 100644 tests/server/auth/test_error_handling.py diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 3f78b7e87f..4223e8cecf 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -14,7 +14,9 @@ ) from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import ( + AuthorizationErrorCode, AuthorizationParams, + AuthorizeError, OAuthServerProvider, construct_redirect_uri, ) @@ -49,20 +51,9 @@ class AuthorizationRequest(BaseModel): ) -AuthorizationErrorCode = Literal[ - "invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", - "server_error", - "temporarily_unavailable", -] - - class AuthorizationErrorResponse(BaseModel): error: AuthorizationErrorCode - error_description: str + error_description: str | None error_uri: AnyUrl | None = None # must be set if provided in the request state: str | None = None @@ -98,16 +89,14 @@ async def handle(self, request: Request) -> Response: async def error_response( error: AuthorizationErrorCode, - error_description: str, + error_description: str | None, attempt_load_client: bool = True, ): nonlocal client, redirect_uri, state if client is None and attempt_load_client: # make last-ditch attempt to load the client client_id = best_effort_extract_string("client_id", params) - client = client_id and await self.provider.clients_store.get_client( - client_id - ) + client = client_id and await self.provider.get_client(client_id) if redirect_uri is None and client: # make last-ditch effort to load the redirect uri if params is not None and "redirect_uri" not in params: @@ -171,7 +160,7 @@ async def error_response( ) # Get client information - client = await self.provider.clients_store.get_client( + client = await self.provider.get_client( auth_request.client_id, ) if not client: @@ -210,15 +199,22 @@ async def error_response( redirect_uri=redirect_uri, ) - # Let the provider pick the next URI to redirect to - return RedirectResponse( - url=await self.provider.authorize( - client, - auth_params, - ), - status_code=302, - headers={"Cache-Control": "no-store"}, - ) + try: + # Let the provider pick the next URI to redirect to + return RedirectResponse( + url=await self.provider.authorize( + client, + auth_params, + ), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + except AuthorizeError as e: + # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 + return await error_response( + error=e.error, + error_description=e.error_description, + ) except Exception as validation_error: # Catch-all for unexpected errors diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index efcb32e2b9..d1f0213c95 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -1,7 +1,6 @@ import secrets import time from dataclasses import dataclass -from typing import Literal from uuid import uuid4 from pydantic import BaseModel, RootModel, ValidationError @@ -10,7 +9,11 @@ from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import OAuthRegisteredClientsStore +from mcp.server.auth.provider import ( + OAuthServerProvider, + RegistrationError, + RegistrationErrorCode, +) from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata @@ -22,18 +25,13 @@ class RegistrationRequest(RootModel): class RegistrationErrorResponse(BaseModel): - error: Literal[ - "invalid_redirect_uri", - "invalid_client_metadata", - "invalid_software_statement", - "unapproved_software_statement", - ] - error_description: str + error: RegistrationErrorCode + error_description: str | None @dataclass class RegistrationHandler: - clients_store: OAuthRegisteredClientsStore + provider: OAuthServerProvider options: ClientRegistrationOptions async def handle(self, request: Request) -> Response: @@ -116,8 +114,17 @@ async def handle(self, request: Request) -> Response: software_id=client_metadata.software_id, software_version=client_metadata.software_version, ) - # Register client - await self.clients_store.register_client(client_info) + try: + # Register client + await self.provider.register_client(client_info) - # Return client information - return PydanticJSONResponse(content=client_info, status_code=201) + # Return client information + return PydanticJSONResponse(content=client_info, status_code=201) + except RegistrationError as e: + # Handle registration errors as defined in RFC 7591 Section 3.2.2 + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error=e.error, error_description=e.error_description + ), + status_code=400, + ) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 141fc81e88..2d8a745b4b 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -81,8 +81,10 @@ async def handle(self, request: Request) -> Response: if token is not None: break + # if token is not found, just return HTTP 200 per the RFC if token and token.client_id == client.client_id: - # Revoke token + # Revoke token; provider is not meant to be able to do validation + # at this point that would result in an error await self.provider.revoke_token(token) # Return successful empty response diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index a60c091c07..54320a2ff4 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -16,7 +16,7 @@ AuthenticationError, ClientAuthenticator, ) -from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.provider import OAuthServerProvider, TokenError, TokenErrorCode from mcp.shared.auth import OAuthToken @@ -56,14 +56,7 @@ class TokenErrorResponse(BaseModel): See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 """ - error: Literal[ - "invalid_request", - "invalid_client", - "invalid_grant", - "unauthorized_client", - "unsupported_grant_type", - "invalid_scope", - ] + error: TokenErrorCode error_description: str | None = None error_uri: AnyHttpUrl | None = None @@ -184,10 +177,18 @@ async def handle(self, request: Request): ) ) - # Exchange authorization code for tokens - tokens = await self.provider.exchange_authorization_code( - client_info, auth_code - ) + try: + # Exchange authorization code for tokens + tokens = await self.provider.exchange_authorization_code( + client_info, auth_code + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) case RefreshTokenRequest(): refresh_token = await self.provider.load_refresh_token( @@ -233,9 +234,17 @@ async def handle(self, request: Request): ) ) - # Exchange refresh token for new tokens - tokens = await self.provider.exchange_refresh_token( - client_info, refresh_token, scopes - ) + try: + # Exchange refresh token for new tokens + tokens = await self.provider.exchange_refresh_token( + client_info, refresh_token, scopes + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) return self.response(TokenSuccessResponse(root=tokens)) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 56cd93ae9e..62a95e313a 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,6 +1,6 @@ import time -from mcp.server.auth.provider import OAuthRegisteredClientsStore +from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthClientInformationFull @@ -20,20 +20,20 @@ class ClientAuthenticator: logic is skipped. """ - def __init__(self, clients_store: OAuthRegisteredClientsStore): + def __init__(self, provider: OAuthServerProvider): """ Initialize the dependency. Args: - clients_store: Store to look up client information + provider: Provider to look up client information """ - self.clients_store = clients_store + self.provider = provider async def authenticate( self, client_id: str, client_secret: str | None ) -> OAuthClientInformationFull: # Look up client information - client = await self.clients_store.get_client(client_id) + client = await self.provider.get_client(client_id) if not client: raise AuthenticationError("Invalid client_id") diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index 10e666028c..b98009cf2d 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -1,4 +1,5 @@ -from typing import Generic, Protocol, TypeVar +from dataclasses import dataclass +from typing import Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, BaseModel @@ -39,11 +40,70 @@ class AuthInfo(BaseModel): expires_at: int | None = None -class OAuthRegisteredClientsStore(Protocol): +RegistrationErrorCode = Literal[ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_software_statement", + "unapproved_software_statement", +] + + +@dataclass(frozen=True) +class RegistrationError(Exception): + error: RegistrationErrorCode + error_description: str | None = None + + +AuthorizationErrorCode = Literal[ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", +] + + +@dataclass(frozen=True) +class AuthorizeError(Exception): + error: AuthorizationErrorCode + error_description: str | None = None + + +TokenErrorCode = Literal[ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", +] + + +@dataclass(frozen=True) +class TokenError(Exception): + error: TokenErrorCode + error_description: str | None = None + + +# NOTE: FastMCP doesn't render any of these types in the user response, so it's +# OK to add fields to subclasses which should not be exposed externally. +AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) +RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) +AuthInfoT = TypeVar("AuthInfoT", bound=AuthInfo) + + +class OAuthServerProvider( + Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AuthInfoT] +): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ Retrieves client information by client ID. + Implementors MAY raise NotImplementedError if dynamic client registration is + disabled in ClientRegistrationOptions. + Args: client_id: The ID of the client to retrieve. @@ -56,26 +116,14 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ Saves client information as part of registering it. + Implementors MAY raise NotImplementedError if dynamic client registration is + disabled in ClientRegistrationOptions. + Args: client_info: The client metadata to register. - """ - ... - -# NOTE: FastMCP doesn't render any of these types in the user response, so it's -# OK to add fields to subclasses which should not be exposed externally. -AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) -RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) -AuthInfoT = TypeVar("AuthInfoT", bound=AuthInfo) - - -class OAuthServerProvider( - Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AuthInfoT] -): - @property - def clients_store(self) -> OAuthRegisteredClientsStore: - """ - A store used to read information about registered OAuth clients. + Raises: + RegistrationError: If the client metadata is invalid. """ ... @@ -111,6 +159,16 @@ async def authorize( entropy, and MUST generate an authorization code with at least 128 bits of entropy. See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. + + Args: + client: The client requesting authorization. + params: The parameters of the authorization request. + + Returns: + A URL to redirect the client to for authorization. + + Raises: + AuthorizeError: If the authorization request is invalid. """ ... @@ -118,14 +176,14 @@ async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str ) -> AuthorizationCodeT | None: """ - Loads metadata for the authorization code challenge. + Loads an AuthorizationCode by its code. Args: client: The client that requested the authorization code. authorization_code: The authorization code to get the challenge for. Returns: - The code challenge that was used when the authorization began. + The AuthorizationCode, or None if not found """ ... @@ -133,20 +191,35 @@ async def exchange_authorization_code( self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT ) -> OAuthToken: """ - Exchanges an authorization code for an access token. + Exchanges an authorization code for an access token and refresh token. Args: client: The client exchanging the authorization code. authorization_code: The authorization code to exchange. Returns: - The access and refresh tokens. + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid """ ... async def load_refresh_token( self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshTokenT | None: ... + ) -> RefreshTokenT | None: + """ + Loads a RefreshToken by its token string. + + Args: + client: The client that is requesting to load the refresh token. + refresh_token: The refresh token string to load. + + Returns: + The RefreshToken object if found, or None if not found. + """ + + ... async def exchange_refresh_token( self, @@ -155,7 +228,9 @@ async def exchange_refresh_token( scopes: list[str], ) -> OAuthToken: """ - Exchanges a refresh token for an access token. + Exchanges a refresh token for an access token and refresh token. + + Implementations SHOULD rotate both the access token and refresh token. Args: client: The client exchanging the refresh token. @@ -163,19 +238,22 @@ async def exchange_refresh_token( scopes: Optional scopes to request with the new access token. Returns: - The new access and refresh tokens. + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid """ ... async def load_access_token(self, token: str) -> AuthInfoT | None: """ - Verifies an access token and returns information about it. + Loads an access token by its token. Args: token: The access token to verify. Returns: - Information about the verified token, or None if the token is invalid. + The AuthInfo, or None if the token is invalid. """ ... @@ -188,11 +266,12 @@ async def revoke_token( If the given token is invalid or already revoked, this method should do nothing. + Implementations SHOULD revoke both the access token and its corresponding + refresh token, regardless of which of the access token or refresh token is + provided. + Args: token: the token to revoke - token_type_hint: hint about the type of token to revoke; optional. if the - token cannot be located using this hint, the provider MUST extend its search - to include all tokens. """ ... diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 49387247ab..581d08d01f 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -65,7 +65,7 @@ def create_auth_routes( client_registration_options, revocation_options, ) - client_authenticator = ClientAuthenticator(provider.clients_store) + client_authenticator = ClientAuthenticator(provider) # Create routes routes = [ @@ -88,7 +88,7 @@ def create_auth_routes( if client_registration_options.enabled: registration_handler = RegistrationHandler( - provider.clients_store, + provider, options=client_registration_options, ) routes.append( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 63d1b8bf45..aab2aa7aee 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -44,7 +44,6 @@ async def handle_sse(request): from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send -from typing_extensions import deprecated import mcp.types as types diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py new file mode 100644 index 0000000000..18e9933e73 --- /dev/null +++ b/tests/server/auth/test_error_handling.py @@ -0,0 +1,294 @@ +""" +Tests for OAuth error handling in the auth handlers. +""" + +import unittest.mock +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest +from httpx import ASGITransport +from pydantic import AnyHttpUrl +from starlette.applications import Starlette + +from mcp.server.auth.provider import ( + AuthorizeError, + RegistrationError, + TokenError, +) +from mcp.server.auth.routes import create_auth_routes +from tests.server.fastmcp.auth.test_auth_integration import ( + MockOAuthProvider, +) + + +@pytest.fixture +def oauth_provider(): + """Return a MockOAuthProvider instance that can be configured to raise errors.""" + return MockOAuthProvider() + + +@pytest.fixture +def app(oauth_provider): + from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions + + # Enable client registration + client_registration_options = ClientRegistrationOptions(enabled=True) + revocation_options = RevocationOptions(enabled=True) + + # Create auth routes + auth_routes = create_auth_routes( + oauth_provider, + issuer_url=AnyHttpUrl("http://localhost"), + client_registration_options=client_registration_options, + revocation_options=revocation_options, + ) + + # Create Starlette app with routes directly + return Starlette(routes=auth_routes) + + +@pytest.fixture +def client(app): + transport = ASGITransport(app=app) + # Use base_url without a path since routes are directly on the app + return httpx.AsyncClient(transport=transport, base_url="http://localhost") + + +@pytest.fixture +def pkce_challenge(): + """Create a PKCE challenge with code_verifier and code_challenge.""" + import base64 + import hashlib + import secrets + + # Generate a code verifier + code_verifier = secrets.token_urlsafe(64)[:128] + + # Create code challenge using S256 method + code_verifier_bytes = code_verifier.encode("ascii") + sha256 = hashlib.sha256(code_verifier_bytes).digest() + code_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + return {"code_verifier": code_verifier, "code_challenge": code_challenge} + + +@pytest.fixture +async def registered_client(client): + """Create and register a test client.""" + # Default client metadata + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "client_name": "Test Client", + } + + response = await client.post("/register", json=client_metadata) + assert response.status_code == 201, f"Failed to register client: {response.content}" + + client_info = response.json() + return client_info + + +class TestRegistrationErrorHandling: + @pytest.mark.anyio + async def test_registration_error_handling(self, client, oauth_provider): + # Mock the register_client method to raise a registration error + with unittest.mock.patch.object( + oauth_provider, + "register_client", + side_effect=RegistrationError( + error="invalid_redirect_uri", + error_description="The redirect URI is invalid", + ), + ): + # Prepare a client registration request + client_data = { + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "client_name": "Test Client", + } + + # Send the registration request + response = await client.post( + "/register", + json=client_data, + ) + + # Verify the response + assert response.status_code == 400, response.content + data = response.json() + assert data["error"] == "invalid_redirect_uri" + assert data["error_description"] == "The redirect URI is invalid" + + +class TestAuthorizeErrorHandling: + @pytest.mark.anyio + async def test_authorize_error_handling( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Mock the authorize method to raise an authorize error + with unittest.mock.patch.object( + oauth_provider, + "authorize", + side_effect=AuthorizeError( + error="access_denied", error_description="The user denied the request" + ), + ): + # Register the client + client_id = registered_client["client_id"] + redirect_uri = registered_client["redirect_uris"][0] + + # Prepare an authorization request + params = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + } + + # Send the authorization request + response = await client.get("/authorize", params=params) + + # Verify the response is a redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert query_params["error"][0] == "access_denied" + assert "error_description" in query_params + assert query_params["state"][0] == "test_state" + + +class TestTokenErrorHandling: + @pytest.mark.anyio + async def test_token_error_handling_auth_code( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Register the client and get an auth code + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Mock the exchange_authorization_code method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_authorization_code", + side_effect=TokenError( + error="invalid_grant", + error_description="The authorization code is invalid", + ), + ): + # Try to exchange the code for tokens + token_response = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + "code_verifier": pkce_challenge["code_verifier"], + }, + ) + + # Verify the response + assert token_response.status_code == 400 + data = token_response.json() + assert data["error"] == "invalid_grant" + assert data["error_description"] == "The authorization code is invalid" + + @pytest.mark.anyio + async def test_token_error_handling_refresh_token( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Register the client and get tokens + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert auth_response.status_code == 302, auth_response.content + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Exchange the code for tokens + token_response = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + "code_verifier": pkce_challenge["code_verifier"], + }, + ) + + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Mock the exchange_refresh_token method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_refresh_token", + side_effect=TokenError( + error="invalid_scope", + error_description="The requested scope is invalid", + ), + ): + # Try to use the refresh token + refresh_response = await client.post( + "/token", + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Verify the response + assert refresh_response.status_code == 400 + data = refresh_response.json() + assert data["error"] == "invalid_scope" + assert data["error_description"] == "The requested scope is invalid" diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 245edf1f1a..8693e65d48 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -21,7 +21,6 @@ AuthInfo, AuthorizationCode, AuthorizationParams, - OAuthRegisteredClientsStore, OAuthServerProvider, RefreshToken, construct_redirect_uri, @@ -33,19 +32,21 @@ ) from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP +from mcp.server.streaming_asgi_transport import StreamingASGITransport from mcp.shared.auth import ( OAuthClientInformationFull, OAuthToken, ) from mcp.types import JSONRPCRequest -from mcp.server.streaming_asgi_transport import StreamingASGITransport - -# Mock client store for testing -class MockClientStore: +# Mock OAuth provider for testing +class MockOAuthProvider(OAuthServerProvider): def __init__(self): self.clients = {} + self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} + self.tokens = {} # token -> {client_id, scopes, expires_at} + self.refresh_tokens = {} # refresh_token -> access_token async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: return self.clients.get(client_id) @@ -53,19 +54,6 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull): self.clients[client_info.client_id] = client_info - -# Mock OAuth provider for testing -class MockOAuthProvider(OAuthServerProvider): - def __init__(self): - self.client_store = MockClientStore() - self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} - self.tokens = {} # token -> {client_id, scopes, expires_at} - self.refresh_tokens = {} # refresh_token -> access_token - - @property - def clients_store(self) -> OAuthRegisteredClientsStore: - return self.client_store - async def authorize( self, client: OAuthClientInformationFull, params: AuthorizationParams ) -> str: @@ -972,7 +960,7 @@ async def test_client_registration_default_scopes( assert client_info["scope"] == "read write" # Retrieve the client from the store to verify default scopes - registered_client = await mock_oauth_provider.clients_store.get_client( + registered_client = await mock_oauth_provider.get_client( client_info["client_id"] ) assert registered_client is not None From 374a0b4903ffcaf4dc0b99cb46ad1465a0f7d8a2 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Fri, 21 Mar 2025 14:01:07 -0700 Subject: [PATCH 53/60] Rename AuthInfo to AccessToken --- src/mcp/server/auth/handlers/revoke.py | 4 ++-- src/mcp/server/auth/middleware/auth_context.py | 10 +++++----- src/mcp/server/auth/middleware/bearer_auth.py | 6 +++--- src/mcp/server/auth/provider.py | 4 ++-- tests/server/fastmcp/auth/test_auth_integration.py | 14 +++++++------- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 2d8a745b4b..b4ea2f2ff6 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -14,7 +14,7 @@ AuthenticationError, ClientAuthenticator, ) -from mcp.server.auth.provider import AuthInfo, OAuthServerProvider, RefreshToken +from mcp.server.auth.provider import AccessToken, OAuthServerProvider, RefreshToken class RevocationRequest(BaseModel): @@ -75,7 +75,7 @@ async def handle(self, request: Request) -> Response: if revocation_request.token_type_hint == "refresh_token": loaders = reversed(loaders) - token: None | AuthInfo | RefreshToken = None + token: None | AccessToken | RefreshToken = None for loader in loaders: token = await loader(revocation_request.token) if token is not None: diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 7de643c891..de7f4e20c3 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -5,7 +5,7 @@ from starlette.responses import Response from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser -from mcp.server.auth.provider import AuthInfo +from mcp.server.auth.provider import AccessToken # Create a contextvar to store the authenticated user # The default is None, indicating no authenticated user is present @@ -14,15 +14,15 @@ ) -def get_current_auth_info() -> AuthInfo | None: +def get_access_token() -> AccessToken | None: """ - Get the auth info from the current context. + Get the access token from the current context. Returns: - The auth info if an authenticated user is available, None otherwise. + The access token if an authenticated user is available, None otherwise. """ auth_user = auth_context_var.get() - return auth_user.auth_info if auth_user else None + return auth_user.access_token if auth_user else None class AuthContextMiddleware(BaseHTTPMiddleware): diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6a64648b81..4f8fd46796 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -10,15 +10,15 @@ from starlette.requests import HTTPConnection from starlette.types import Scope -from mcp.server.auth.provider import AuthInfo, OAuthServerProvider +from mcp.server.auth.provider import AccessToken, OAuthServerProvider class AuthenticatedUser(SimpleUser): """User with authentication info.""" - def __init__(self, auth_info: AuthInfo): + def __init__(self, auth_info: AccessToken): super().__init__(auth_info.client_id) - self.auth_info = auth_info + self.access_token = auth_info self.scopes = auth_info.scopes diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index b98009cf2d..f5f4f18e62 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -33,7 +33,7 @@ class RefreshToken(BaseModel): expires_at: int | None = None -class AuthInfo(BaseModel): +class AccessToken(BaseModel): token: str client_id: str scopes: list[str] @@ -91,7 +91,7 @@ class TokenError(Exception): # OK to add fields to subclasses which should not be exposed externally. AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) -AuthInfoT = TypeVar("AuthInfoT", bound=AuthInfo) +AuthInfoT = TypeVar("AuthInfoT", bound=AccessToken) class OAuthServerProvider( diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 8693e65d48..6ae5e93833 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -18,7 +18,7 @@ from starlette.applications import Starlette from mcp.server.auth.provider import ( - AuthInfo, + AccessToken, AuthorizationCode, AuthorizationParams, OAuthServerProvider, @@ -88,7 +88,7 @@ async def exchange_authorization_code( refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the tokens - self.tokens[access_token] = AuthInfo( + self.tokens[access_token] = AccessToken( token=access_token, client_id=client.client_id, scopes=authorization_code.scopes, @@ -151,7 +151,7 @@ async def exchange_refresh_token( new_refresh_token = f"refresh_{secrets.token_hex(32)}" # Store the new tokens - self.tokens[new_access_token] = AuthInfo( + self.tokens[new_access_token] = AccessToken( token=new_access_token, client_id=client.client_id, scopes=scopes or token_info.scopes, @@ -172,27 +172,27 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) - async def load_access_token(self, token: str) -> AuthInfo | None: + async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) # Check if token is expired # if token_info.expires_at < int(time.time()): # raise InvalidTokenError("Access token has expired") - return token_info and AuthInfo( + return token_info and AccessToken( token=token, client_id=token_info.client_id, scopes=token_info.scopes, expires_at=token_info.expires_at, ) - async def revoke_token(self, token: AuthInfo | RefreshToken) -> None: + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: match token: case RefreshToken(): # Remove the refresh token del self.refresh_tokens[token.token] - case AuthInfo(): + case AccessToken(): # Remove the access token del self.tokens[token.token] From fb5a56831e5f0573d0cd6eaa3647d7470d0b904a Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Sat, 22 Mar 2025 08:14:20 -0700 Subject: [PATCH 54/60] Rename --- src/mcp/server/auth/provider.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index f5f4f18e62..a6d5c0cf04 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -91,11 +91,11 @@ class TokenError(Exception): # OK to add fields to subclasses which should not be exposed externally. AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) -AuthInfoT = TypeVar("AuthInfoT", bound=AccessToken) +AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) class OAuthServerProvider( - Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AuthInfoT] + Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT] ): async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: """ @@ -245,7 +245,7 @@ async def exchange_refresh_token( """ ... - async def load_access_token(self, token: str) -> AuthInfoT | None: + async def load_access_token(self, token: str) -> AccessTokenT | None: """ Loads an access token by its token. @@ -259,7 +259,7 @@ async def load_access_token(self, token: str) -> AuthInfoT | None: async def revoke_token( self, - token: AuthInfoT | RefreshTokenT, + token: AccessTokenT | RefreshTokenT, ) -> None: """ Revokes an access or refresh token. From 76ddc65a2dad159960b73326aab326565969c722 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Sat, 22 Mar 2025 08:24:46 -0700 Subject: [PATCH 55/60] Add docs --- README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/README.md b/README.md index bdbc9bca56..dbba1a6420 100644 --- a/README.md +++ b/README.md @@ -250,6 +250,33 @@ async def long_task(files: list[str], ctx: Context) -> str: return "Processing complete" ``` +### Authentication + +Authentication can be used by servers that want to expose tools accessing protected resources. + +`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by +providing an implementation of the `OAuthServerProvider` protocol. + +``` +mcp = FastMCP("My App", + auth_provider=MyOAuthServerProvider(), + auth=AuthSettings( + issuer_url="https://myapp.com", + revocation_options=RevocationOptions( + enabled=True, + ), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["myscope", "myotherscope"], + default_scopes=["myscope"], + ), + required_scopes=["myscope"], + ), +) +``` + +See [OAuthServerProvider](mcp/server/auth/provider.py) for more details. + ## Running Your Server ### Development Mode From 10e00e7e128d858091d541d56382e72d7df67ea0 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Sat, 22 Mar 2025 08:58:28 -0700 Subject: [PATCH 56/60] Typecheck --- src/mcp/client/sse.py | 4 ++-- src/mcp/server/auth/handlers/authorize.py | 6 +++--- src/mcp/server/auth/handlers/register.py | 5 +++-- src/mcp/server/auth/handlers/revoke.py | 4 ++-- src/mcp/server/auth/handlers/token.py | 15 ++++++++++---- src/mcp/server/auth/middleware/bearer_auth.py | 8 ++++---- src/mcp/server/auth/middleware/client_auth.py | 3 ++- src/mcp/server/auth/routes.py | 3 ++- src/mcp/server/fastmcp/server.py | 19 +++++++++--------- src/mcp/server/streaming_asgi_transport.py | 20 ++++++++++++------- .../fastmcp/auth/test_auth_integration.py | 2 +- tests/shared/test_sse.py | 4 +--- tests/shared/test_ws.py | 4 +--- 13 files changed, 55 insertions(+), 42 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index c84340a154..0812876fc8 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import Any, Union +from typing import Any from urllib.parse import urljoin, urlparse import anyio @@ -26,7 +26,7 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, - auth: Union[AuthSession, OAuthClient, None] = None, + auth: AuthSession | OAuthClient | None = None, ): """ Client transport for SSE. diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 4223e8cecf..b6079da974 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal from urllib.parse import urlencode, urlparse, urlunparse from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError @@ -70,13 +70,13 @@ def best_effort_extract_string( return None -class AnyHttpUrlModel(RootModel): +class AnyHttpUrlModel(RootModel[AnyHttpUrl]): root: AnyHttpUrl @dataclass class AuthorizationHandler: - provider: OAuthServerProvider + provider: OAuthServerProvider[Any, Any, Any] async def handle(self, request: Request) -> Response: # implements authorization requests for grant_type=code; diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index d1f0213c95..29f97319a8 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -1,6 +1,7 @@ import secrets import time from dataclasses import dataclass +from typing import Any from uuid import uuid4 from pydantic import BaseModel, RootModel, ValidationError @@ -18,7 +19,7 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata -class RegistrationRequest(RootModel): +class RegistrationRequest(RootModel[OAuthClientMetadata]): # this wrapper is a no-op; it's just to separate out the types exposed to the # provider from what we use in the HTTP handler root: OAuthClientMetadata @@ -31,7 +32,7 @@ class RegistrationErrorResponse(BaseModel): @dataclass class RegistrationHandler: - provider: OAuthServerProvider + provider: OAuthServerProvider[Any, Any, Any] options: ClientRegistrationOptions async def handle(self, request: Request) -> Response: diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index b4ea2f2ff6..37883cd700 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import partial -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, ValidationError from starlette.requests import Request @@ -35,7 +35,7 @@ class RevocationErrorResponse(BaseModel): @dataclass class RevocationHandler: - provider: OAuthServerProvider + provider: OAuthServerProvider[Any, Any, Any] client_authenticator: ClientAuthenticator async def handle(self, request: Request) -> Response: diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 54320a2ff4..a79cc7f1ba 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -2,7 +2,7 @@ import hashlib import time from dataclasses import dataclass -from typing import Annotated, Literal +from typing import Annotated, Any, Literal from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request @@ -44,7 +44,14 @@ class RefreshTokenRequest(BaseModel): client_secret: str | None = None -class TokenRequest(RootModel): +class TokenRequest( + RootModel[ + Annotated[ + AuthorizationCodeRequest | RefreshTokenRequest, + Field(discriminator="grant_type"), + ] + ] +): root: Annotated[ AuthorizationCodeRequest | RefreshTokenRequest, Field(discriminator="grant_type"), @@ -61,7 +68,7 @@ class TokenErrorResponse(BaseModel): error_uri: AnyHttpUrl | None = None -class TokenSuccessResponse(RootModel): +class TokenSuccessResponse(RootModel[OAuthToken]): # this is just a wrapper over OAuthToken; the only reason we do this # is to have some separation between the HTTP response type, and the # type returned by the provider @@ -70,7 +77,7 @@ class TokenSuccessResponse(RootModel): @dataclass class TokenHandler: - provider: OAuthServerProvider + provider: OAuthServerProvider[Any, Any, Any] client_authenticator: ClientAuthenticator def response(self, obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 4f8fd46796..2785ecd5f0 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,5 +1,5 @@ import time -from typing import Any, Callable +from typing import Any from starlette.authentication import ( AuthCredentials, @@ -8,7 +8,7 @@ ) from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection -from starlette.types import Scope +from starlette.types import Receive, Scope, Send from mcp.server.auth.provider import AccessToken, OAuthServerProvider @@ -29,7 +29,7 @@ class BearerAuthBackend(AuthenticationBackend): def __init__( self, - provider: OAuthServerProvider, + provider: OAuthServerProvider[Any, Any, Any], ): self.provider = provider @@ -72,7 +72,7 @@ def __init__(self, app: Any, required_scopes: list[str]): self.app = app self.required_scopes = required_scopes - async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: auth_credentials = scope.get("auth") for required_scope in self.required_scopes: diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 62a95e313a..da0ab0369f 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -1,4 +1,5 @@ import time +from typing import Any from mcp.server.auth.provider import OAuthServerProvider from mcp.shared.auth import OAuthClientInformationFull @@ -20,7 +21,7 @@ class ClientAuthenticator: logic is skipped. """ - def __init__(self, provider: OAuthServerProvider): + def __init__(self, provider: OAuthServerProvider[Any, Any, Any]): """ Initialize the dependency. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index db8813ef37..3e7e77bcd3 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from pydantic import AnyHttpUrl from starlette.routing import Route diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index b736315e69..0098511bb5 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -5,13 +5,13 @@ import inspect import json import re -from collections.abc import AsyncIterator, Callable, Iterable, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import ( AbstractAsyncContextManager, asynccontextmanager, ) from itertools import chain -from typing import Any, Awaitable, Generic, Literal +from typing import Any, Generic, Literal import anyio import pydantic_core @@ -22,10 +22,10 @@ from sse_starlette import EventSourceResponse from starlette.applications import Starlette from starlette.authentication import requires +from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import Response -from starlette.middleware import Middleware from starlette.routing import Mount, Route from mcp.server.auth.middleware.auth_context import AuthContextMiddleware @@ -491,7 +491,6 @@ def custom_route( name: str | None = None, include_in_schema: bool = True, ): - def decorator( func: Callable[[Request], Awaitable[Response]], ) -> Callable[[Request], Awaitable[Response]]: @@ -541,7 +540,7 @@ async def handle_sse(request: Request) -> EventSourceResponse: async with sse.connect_sse( request.scope, request.receive, - request._send # type: ignore[reportPrivateUsage] + request._send, # type: ignore[reportPrivateUsage] ) as streams: await self._mcp_server.run( streams[0], @@ -586,7 +585,9 @@ async def handle_sse(request: Request) -> EventSourceResponse: routes.append( Route( - self.settings.sse_path, endpoint=requires(required_scopes)(handle_sse), methods=["GET"] + self.settings.sse_path, + endpoint=requires(required_scopes)(handle_sse), + methods=["GET"], ) ) routes.append( @@ -754,9 +755,9 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent Returns: The resource content as either text or bytes """ - assert self._fastmcp is not None, ( - "Context is not available outside of a request" - ) + assert ( + self._fastmcp is not None + ), "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) async def log( diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 98a706b381..4cbd77370d 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -9,7 +9,7 @@ """ import typing -from typing import Any, Dict, Tuple +from typing import Any, cast import anyio import anyio.abc @@ -17,6 +17,7 @@ from httpx._models import Request, Response from httpx._transports.base import AsyncBaseTransport from httpx._types import AsyncByteStream +from starlette.types import ASGIApp, Receive, Scope, Send class StreamingASGITransport(AsyncBaseTransport): @@ -42,11 +43,11 @@ class StreamingASGITransport(AsyncBaseTransport): def __init__( self, - app: typing.Callable, + app: ASGIApp, task_group: anyio.abc.TaskGroup, raise_app_exceptions: bool = True, root_path: str = "", - client: Tuple[str, int] = ("127.0.0.1", 123), + client: tuple[str, int] = ("127.0.0.1", 123), ) -> None: self.app = app self.raise_app_exceptions = raise_app_exceptions @@ -88,13 +89,15 @@ async def handle_async_request( initial_response_ready = anyio.Event() # Synchronization for streaming response - asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream(100) + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[ + dict[str, Any] + ](100) content_send_channel, content_receive_channel = ( anyio.create_memory_object_stream[bytes](100) ) # ASGI callables. - async def receive() -> Dict[str, Any]: + async def receive() -> dict[str, Any]: nonlocal request_complete if request_complete: @@ -108,7 +111,7 @@ async def receive() -> Dict[str, Any]: return {"type": "http.request", "body": b"", "more_body": False} return {"type": "http.request", "body": body, "more_body": True} - async def send(message: Dict[str, Any]) -> None: + async def send(message: dict[str, Any]) -> None: nonlocal status_code, response_headers, response_started await asgi_send_channel.send(message) @@ -116,7 +119,10 @@ async def send(message: Dict[str, Any]) -> None: # Start the ASGI application in a separate task async def run_app() -> None: try: - await self.app(scope, receive, send) + # Cast the receive and send functions to the ASGI types + await self.app( + cast(Scope, scope), cast(Receive, receive), cast(Send, send) + ) except Exception: if self.raise_app_exceptions: raise diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 6ae5e93833..45df6eaf46 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1019,7 +1019,7 @@ def test_tool(x: int) -> str: async with anyio.create_task_group() as task_group: transport = StreamingASGITransport( - app=mcp.starlette_app(), + app=mcp.sse_app(), task_group=task_group, ) test_client = httpx.AsyncClient( diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 43107b5978..f5158c3c37 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 2aca97e154..1381c8153c 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield From 87571d8ff4fee015709438b72c8999d3a026a780 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 24 Mar 2025 17:00:25 -0700 Subject: [PATCH 57/60] Return 401 on missing auth, not 403 --- src/mcp/server/auth/middleware/bearer_auth.py | 3 + src/mcp/server/fastmcp/server.py | 4 +- .../auth/middleware/test_bearer_auth.py | 371 ++++++++++++++++++ .../fastmcp/auth/test_auth_integration.py | 12 +- 4 files changed, 381 insertions(+), 9 deletions(-) create mode 100644 tests/server/auth/middleware/test_bearer_auth.py diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2785ecd5f0..15e6f2fc5d 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -73,6 +73,9 @@ def __init__(self, app: Any, required_scopes: list[str]): self.required_scopes = required_scopes async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + auth_user = scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + raise HTTPException(status_code=401, detail="Unauthorized") auth_credentials = scope.get("auth") for required_scope in self.required_scopes: diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 0098511bb5..460cffac7a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -26,7 +26,7 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import Response -from starlette.routing import Mount, Route +from starlette.routing import Mount, Route, request_response from mcp.server.auth.middleware.auth_context import AuthContextMiddleware from mcp.server.auth.middleware.bearer_auth import ( @@ -586,7 +586,7 @@ async def handle_sse(request: Request) -> EventSourceResponse: routes.append( Route( self.settings.sse_path, - endpoint=requires(required_scopes)(handle_sse), + endpoint=RequireAuthMiddleware(request_response(handle_sse), required_scopes), methods=["GET"], ) ) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py new file mode 100644 index 0000000000..d6ddb7c38d --- /dev/null +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -0,0 +1,371 @@ +""" +Tests for the BearerAuth middleware components. +""" + +import time +from typing import Any, Dict, List, Optional, cast + +import pytest +from starlette.authentication import AuthCredentials +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +from mcp.server.auth.middleware.bearer_auth import ( + AuthenticatedUser, + BearerAuthBackend, + RequireAuthMiddleware, +) +from mcp.server.auth.provider import ( + AccessToken, + OAuthServerProvider, +) + + +class MockOAuthProvider: + """Mock OAuth provider for testing. + + This is a simplified version that only implements the methods needed for testing + the BearerAuthMiddleware components. + """ + + def __init__(self): + self.tokens = {} # token -> AccessToken + + def add_token(self, token: str, access_token: AccessToken) -> None: + """Add a token to the provider.""" + self.tokens[token] = access_token + + async def load_access_token(self, token: str) -> Optional[AccessToken]: + """Load an access token.""" + return self.tokens.get(token) + + +def add_token_to_provider(provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken) -> None: + """Helper function to add a token to a provider. + + This is used to work around type checking issues with our mock provider. + """ + # We know this is actually a MockOAuthProvider + mock_provider = cast(MockOAuthProvider, provider) + mock_provider.add_token(token, access_token) + + +class MockApp: + """Mock ASGI app for testing.""" + + def __init__(self): + self.called = False + self.scope: Optional[Scope] = None + self.receive: Optional[Receive] = None + self.send: Optional[Send] = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.called = True + self.scope = scope + self.receive = receive + self.send = send + + +@pytest.fixture +def mock_oauth_provider() -> OAuthServerProvider[Any, Any, Any]: + """Create a mock OAuth provider.""" + # Use type casting to satisfy the type checker + return cast(OAuthServerProvider[Any, Any, Any], MockOAuthProvider()) + + +@pytest.fixture +def valid_access_token() -> AccessToken: + """Create a valid access token.""" + return AccessToken( + token="valid_token", + client_id="test_client", + scopes=["read", "write"], + expires_at=int(time.time()) + 3600, # 1 hour from now + ) + + +@pytest.fixture +def expired_access_token() -> AccessToken: + """Create an expired access token.""" + return AccessToken( + token="expired_token", + client_id="test_client", + scopes=["read"], + expires_at=int(time.time()) - 3600, # 1 hour ago + ) + + +@pytest.fixture +def no_expiry_access_token() -> AccessToken: + """Create an access token with no expiry.""" + return AccessToken( + token="no_expiry_token", + client_id="test_client", + scopes=["read", "write"], + expires_at=None, + ) + + +@pytest.mark.anyio +class TestBearerAuthBackend: + """Tests for the BearerAuthBackend class.""" + + async def test_no_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + """Test authentication with no Authorization header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request({"type": "http", "headers": []}) + result = await backend.authenticate(request) + assert result is None + + async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + """Test authentication with non-Bearer Authorization header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Basic dXNlcjpwYXNz")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + """Test authentication with invalid token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer invalid_token")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_expired_token( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], expired_access_token: AccessToken + ): + """Test authentication with expired token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer expired_token")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_valid_token( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], valid_access_token: AccessToken + ): + """Test authentication with valid token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer valid_token")], + } + ) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + assert user.scopes == ["read", "write"] + + async def test_token_without_expiry( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken + ): + """Test authentication with token that has no expiry.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer no_expiry_token")], + } + ) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == no_expiry_access_token + assert user.scopes == ["read", "write"] + + +@pytest.mark.anyio +class TestRequireAuthMiddleware: + """Tests for the RequireAuthMiddleware class.""" + + async def test_no_user(self): + """Test middleware with no user in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + scope: Scope = {"type": "http"} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Unauthorized" + assert not app.called + + async def test_non_authenticated_user(self): + """Test middleware with non-authenticated user in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + scope: Scope = {"type": "http", "user": object()} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Unauthorized" + assert not app.called + + async def test_missing_required_scope(self, valid_access_token: AccessToken): + """Test middleware with user missing required scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["admin"]) + + # Create a user with read/write scopes but not admin + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 403 + assert excinfo.value.detail == "Insufficient scope" + assert not app.called + + async def test_no_auth_credentials(self, valid_access_token: AccessToken): + """Test middleware with no auth credentials in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + + scope: Scope = {"type": "http", "user": user} # No auth credentials + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 403 + assert excinfo.value.detail == "Insufficient scope" + assert not app.called + + async def test_has_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with user having all required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send + + async def test_multiple_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with multiple required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send + + async def test_no_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with no required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=[]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index 45df6eaf46..e4c310f7b4 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -1008,7 +1008,7 @@ async def test_fastmcp_with_auth( issuer_url=AnyHttpUrl("https://auth.example.com"), client_registration_options=ClientRegistrationOptions(enabled=True), revocation_options=RevocationOptions(enabled=True), - required_scopes=["read"], + required_scopes=["read", "write"], ), ) @@ -1032,24 +1032,22 @@ def test_tool(x: int) -> str: # Test that auth is required for protected endpoints response = await test_client.get("/sse") - # TODO: we should return 401/403 depending on whether authn or authz fails - assert response.status_code == 403 + assert response.status_code == 401 response = await test_client.post("/messages/") - # TODO: we should return 401/403 depending on whether authn or authz fails - assert response.status_code == 403, response.content + assert response.status_code == 401, response.content response = await test_client.post( "/messages/", headers={"Authorization": "invalid"}, ) - assert response.status_code == 403 + assert response.status_code == 401 response = await test_client.post( "/messages/", headers={"Authorization": "Bearer invalid"}, ) - assert response.status_code == 403 + assert response.status_code == 401 # now, become authenticated and try to go through the flow again client_metadata = { From c6f991bdd9b92349bd48a2340db3b81fea4705a3 Mon Sep 17 00:00:00 2001 From: Peter Raboud Date: Mon, 24 Mar 2025 17:11:50 -0700 Subject: [PATCH 58/60] Convert AuthContextMiddleware to plain ASGI middleware & add tests --- .../server/auth/middleware/auth_context.py | 23 ++-- src/mcp/server/fastmcp/server.py | 5 +- .../auth/middleware/test_bearer_auth.py | 128 ++++++++++-------- 3 files changed, 84 insertions(+), 72 deletions(-) diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index de7f4e20c3..1073c07ada 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,8 +1,6 @@ import contextvars -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request -from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.server.auth.provider import AccessToken @@ -25,7 +23,7 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None -class AuthContextMiddleware(BaseHTTPMiddleware): +class AuthContextMiddleware: """ Middleware that extracts the authenticated user from the request and sets it in a contextvar for easy access throughout the request lifecycle. @@ -35,23 +33,18 @@ class AuthContextMiddleware(BaseHTTPMiddleware): being stored in the context. """ - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - # Get the authenticated user from the request if it exists - user = getattr(request, "user", None) + def __init__(self, app: ASGIApp): + self.app = app - # Only set the context var if the user is an AuthenticatedUser + async def __call__(self, scope: Scope, receive: Receive, send: Send): + user = scope.get("user") if isinstance(user, AuthenticatedUser): # Set the authenticated user in the contextvar token = auth_context_var.set(user) try: - # Process the request - response = await call_next(request) - return response + await self.app(scope, receive, send) finally: - # Reset the contextvar after the request is processed auth_context_var.reset(token) else: # No authenticated user, just process the request - return await call_next(request) + await self.app(scope, receive, send) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 460cffac7a..c2c9ac7249 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -21,7 +21,6 @@ from pydantic_settings import BaseSettings, SettingsConfigDict from sse_starlette import EventSourceResponse from starlette.applications import Starlette -from starlette.authentication import requires from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request @@ -586,7 +585,9 @@ async def handle_sse(request: Request) -> EventSourceResponse: routes.append( Route( self.settings.sse_path, - endpoint=RequireAuthMiddleware(request_response(handle_sse), required_scopes), + endpoint=RequireAuthMiddleware( + request_response(handle_sse), required_scopes + ), methods=["GET"], ) ) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index d6ddb7c38d..a6da24e398 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -3,13 +3,13 @@ """ import time -from typing import Any, Dict, List, Optional, cast +from typing import Any, cast import pytest from starlette.authentication import AuthCredentials from starlette.exceptions import HTTPException from starlette.requests import Request -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import Message, Receive, Scope, Send from mcp.server.auth.middleware.bearer_auth import ( AuthenticatedUser, @@ -24,7 +24,7 @@ class MockOAuthProvider: """Mock OAuth provider for testing. - + This is a simplified version that only implements the methods needed for testing the BearerAuthMiddleware components. """ @@ -36,14 +36,16 @@ def add_token(self, token: str, access_token: AccessToken) -> None: """Add a token to the provider.""" self.tokens[token] = access_token - async def load_access_token(self, token: str) -> Optional[AccessToken]: + async def load_access_token(self, token: str) -> AccessToken | None: """Load an access token.""" return self.tokens.get(token) -def add_token_to_provider(provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken) -> None: +def add_token_to_provider( + provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken +) -> None: """Helper function to add a token to a provider. - + This is used to work around type checking issues with our mock provider. """ # We know this is actually a MockOAuthProvider @@ -56,9 +58,9 @@ class MockApp: def __init__(self): self.called = False - self.scope: Optional[Scope] = None - self.receive: Optional[Receive] = None - self.send: Optional[Send] = None + self.scope: Scope | None = None + self.receive: Receive | None = None + self.send: Send | None = None async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.called = True @@ -111,14 +113,18 @@ def no_expiry_access_token() -> AccessToken: class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" - async def test_no_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + async def test_no_auth_header( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): """Test authentication with no Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None - async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + async def test_non_bearer_auth_header( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): """Test authentication with non-Bearer Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -130,7 +136,9 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthServerProv result = await backend.authenticate(request) assert result is None - async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any]): + async def test_invalid_token( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): """Test authentication with invalid token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -143,11 +151,15 @@ async def test_invalid_token(self, mock_oauth_provider: OAuthServerProvider[Any, assert result is None async def test_expired_token( - self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], expired_access_token: AccessToken + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + expired_access_token: AccessToken, ): """Test authentication with expired token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) + add_token_to_provider( + mock_oauth_provider, "expired_token", expired_access_token + ) request = Request( { "type": "http", @@ -158,7 +170,9 @@ async def test_expired_token( assert result is None async def test_valid_token( - self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], valid_access_token: AccessToken + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + valid_access_token: AccessToken, ): """Test authentication with valid token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) @@ -180,11 +194,15 @@ async def test_valid_token( assert user.scopes == ["read", "write"] async def test_token_without_expiry( - self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any], no_expiry_access_token: AccessToken + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + no_expiry_access_token: AccessToken, ): """Test authentication with token that has no expiry.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) + add_token_to_provider( + mock_oauth_provider, "no_expiry_token", no_expiry_access_token + ) request = Request( { "type": "http", @@ -211,17 +229,17 @@ async def test_no_user(self): app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) scope: Scope = {"type": "http"} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + with pytest.raises(HTTPException) as excinfo: await middleware(scope, receive, send) - + assert excinfo.value.status_code == 401 assert excinfo.value.detail == "Unauthorized" assert not app.called @@ -231,17 +249,17 @@ async def test_non_authenticated_user(self): app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) scope: Scope = {"type": "http", "user": object()} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + with pytest.raises(HTTPException) as excinfo: await middleware(scope, receive, send) - + assert excinfo.value.status_code == 401 assert excinfo.value.detail == "Unauthorized" assert not app.called @@ -250,23 +268,23 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken): """Test middleware with user missing required scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["admin"]) - + # Create a user with read/write scopes but not admin user = AuthenticatedUser(valid_access_token) auth = AuthCredentials(["read", "write"]) - + scope: Scope = {"type": "http", "user": user, "auth": auth} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + with pytest.raises(HTTPException) as excinfo: await middleware(scope, receive, send) - + assert excinfo.value.status_code == 403 assert excinfo.value.detail == "Insufficient scope" assert not app.called @@ -275,22 +293,22 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken): """Test middleware with no auth credentials in scope.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) - + # Create a user with read/write scopes user = AuthenticatedUser(valid_access_token) - + scope: Scope = {"type": "http", "user": user} # No auth credentials - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + with pytest.raises(HTTPException) as excinfo: await middleware(scope, receive, send) - + assert excinfo.value.status_code == 403 assert excinfo.value.detail == "Insufficient scope" assert not app.called @@ -299,22 +317,22 @@ async def test_has_required_scopes(self, valid_access_token: AccessToken): """Test middleware with user having all required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read"]) - + # Create a user with read/write scopes user = AuthenticatedUser(valid_access_token) auth = AuthCredentials(["read", "write"]) - + scope: Scope = {"type": "http", "user": user, "auth": auth} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + await middleware(scope, receive, send) - + assert app.called assert app.scope == scope assert app.receive == receive @@ -324,22 +342,22 @@ async def test_multiple_required_scopes(self, valid_access_token: AccessToken): """Test middleware with multiple required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"]) - + # Create a user with read/write scopes user = AuthenticatedUser(valid_access_token) auth = AuthCredentials(["read", "write"]) - + scope: Scope = {"type": "http", "user": user, "auth": auth} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + await middleware(scope, receive, send) - + assert app.called assert app.scope == scope assert app.receive == receive @@ -349,22 +367,22 @@ async def test_no_required_scopes(self, valid_access_token: AccessToken): """Test middleware with no required scopes.""" app = MockApp() middleware = RequireAuthMiddleware(app, required_scopes=[]) - + # Create a user with read/write scopes user = AuthenticatedUser(valid_access_token) auth = AuthCredentials(["read", "write"]) - + scope: Scope = {"type": "http", "user": user, "auth": auth} - + # Create dummy async functions for receive and send async def receive() -> Message: return {"type": "http.request"} - + async def send(message: Message) -> None: pass - + await middleware(scope, receive, send) - + assert app.called assert app.scope == scope assert app.receive == receive From 800b66a6242959673273b214ff405d810e6cde98 Mon Sep 17 00:00:00 2001 From: Jerome Date: Tue, 25 Mar 2025 13:52:41 +0000 Subject: [PATCH 59/60] Added CORS middleware to allow cross-origin requests - Allow any origin to make requests - Allow GET, POST, and OPTIONS HTTP methods - Allow any headers - Allow sending credentials with requests Also added OPTIONS method to auth routes to handle CORS preflight requests. --- src/mcp/server/auth/handlers/token.py | 3 ++- src/mcp/server/auth/routes.py | 9 ++++++--- src/mcp/server/fastmcp/server.py | 11 ++++++++++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index a79cc7f1ba..0d6e2deade 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -25,7 +25,8 @@ class AuthorizationCodeRequest(BaseModel): grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") redirect_uri: AnyHttpUrl | None = Field( - ..., description="Must be the same as redirect URI provided in /authorize" + default=None, + description="Must be the same as redirect URI provided in /authorize", ) client_id: str # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 3e7e77bcd3..865a61d6d1 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -2,7 +2,10 @@ from typing import Any from pydantic import AnyHttpUrl +from starlette.middleware.cors import CORSMiddleware from starlette.routing import Route +from starlette.requests import Request +from starlette.responses import JSONResponse, Response from mcp.server.auth.handlers.authorize import AuthorizationHandler from mcp.server.auth.handlers.metadata import MetadataHandler @@ -73,17 +76,17 @@ def create_auth_routes( Route( "/.well-known/oauth-authorization-server", endpoint=MetadataHandler(metadata).handle, - methods=["GET"], + methods=["GET", "OPTIONS"], ), Route( AUTHORIZATION_PATH, endpoint=AuthorizationHandler(provider).handle, - methods=["GET", "POST"], + methods=["GET", "POST", "OPTIONS"], ), Route( TOKEN_PATH, endpoint=TokenHandler(provider, client_authenticator).handle, - methods=["POST"], + methods=["POST", "OPTIONS"], ), ] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c2c9ac7249..70604b7e5a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -23,6 +23,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route, request_response @@ -559,8 +560,16 @@ async def handle_sse(request: Request) -> EventSourceResponse: from mcp.server.auth.routes import create_auth_routes required_scopes = self.settings.auth.required_scopes or [] - + middleware = [ + # Add CORS middleware to allow cross-origin requests + Middleware( + CORSMiddleware, + allow_origins=["*"], # Allow any origin + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["*"], + allow_credentials=True, + ), # extract auth info from request (but do not require it) Middleware( AuthenticationMiddleware, From f614ea25e11c8042e5cb9ee09fa7305ff3ae499d Mon Sep 17 00:00:00 2001 From: Jerome Date: Tue, 25 Mar 2025 15:12:58 +0000 Subject: [PATCH 60/60] Commented out redirect_uri checks in authorization code request handling --- src/mcp/server/auth/handlers/token.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 0d6e2deade..aa1ce934e3 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -157,18 +157,18 @@ async def handle(self, request: Request): ) ) - # verify redirect_uri doesn't change between /authorize and /tokens - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 - if token_request.redirect_uri != auth_code.redirect_uri: - return self.response( - TokenErrorResponse( - error="invalid_request", - error_description=( - "redirect_uri did not match the one " - "used when creating auth code" - ), - ) - ) + # # verify redirect_uri doesn't change between /authorize and /tokens + # # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + # if token_request.redirect_uri != auth_code.redirect_uri: + # return self.response( + # TokenErrorResponse( + # error="invalid_request", + # error_description=( + # "redirect_uri did not match the one " + # "used when creating auth code" + # ), + # ) + # ) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()