diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index ab3179ecb..d23a8208f 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -5,16 +5,20 @@ from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError from mcp.client.auth.oauth2 import ( + OAuthAuthorizationRedirect, OAuthClientProvider, PKCEParameters, TokenStorage, + build_authorization_redirect, ) __all__ = [ + "OAuthAuthorizationRedirect", "OAuthClientProvider", "OAuthFlowError", "OAuthRegistrationError", "OAuthTokenError", "PKCEParameters", "TokenStorage", + "build_authorization_redirect", ] diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc8234..aef0944df 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -12,7 +12,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field from typing import Any, Protocol -from urllib.parse import quote, urlencode, urljoin, urlparse +from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse import anyio import httpx @@ -69,6 +69,85 @@ def generate(cls) -> "PKCEParameters": return cls(code_verifier=code_verifier, code_challenge=code_challenge) +@dataclass(frozen=True) +class OAuthAuthorizationRedirect: + """Resumable OAuth authorization redirect state. + + Proxy and server-side callers can persist this value, send the authorization + URL to a user, and later resume token exchange with the returned code plus + the stored state and code verifier. + """ + + authorization_url: str + state: str + code_verifier: str = field(repr=False) + + +def build_authorization_redirect( + *, + authorization_endpoint: str, + client_info: OAuthClientInformationFull, + client_metadata: OAuthClientMetadata, + pkce_params: PKCEParameters | None = None, + state: str | None = None, + resource_url: str | None = None, +) -> OAuthAuthorizationRedirect: + """Build an OAuth authorization URL and resumable state. + + Args: + authorization_endpoint: Authorization endpoint URL. + client_info: Registered OAuth client information. + client_metadata: Client metadata containing redirect URIs and scopes. + pkce_params: Optional PKCE parameters. Generated when omitted. + state: Optional OAuth state value. Generated when omitted. + resource_url: Optional RFC 8707 resource value. + + Returns: + Authorization URL plus the state and code verifier needed to resume. + + Raises: + OAuthFlowError: If no client ID or redirect URI is available. + """ + if client_info.client_id is None: + raise OAuthFlowError("No client ID provided for authorization code grant") + + if client_metadata.redirect_uris is None: + raise OAuthFlowError("No redirect URIs provided for authorization code grant") + + pkce_params = pkce_params or PKCEParameters.generate() + state = state or secrets.token_urlsafe(32) + + auth_params = { + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": str(client_metadata.redirect_uris[0]), + "state": state, + "code_challenge": pkce_params.code_challenge, + "code_challenge_method": "S256", + } + + if resource_url: + auth_params["resource"] = resource_url + + if client_metadata.scope: + auth_params["scope"] = client_metadata.scope + + # OIDC requires prompt=consent when offline_access is requested + # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess + if "offline_access" in client_metadata.scope.split(): + auth_params["prompt"] = "consent" + + parsed_endpoint = urlparse(authorization_endpoint) + query_params = parse_qsl(parsed_endpoint.query, keep_blank_values=True) + query_params.extend(auth_params.items()) + authorization_url = urlunparse(parsed_endpoint._replace(query=urlencode(query_params))) + return OAuthAuthorizationRedirect( + authorization_url=authorization_url, + state=state, + code_verifier=pkce_params.code_verifier, + ) + + class TokenStorage(Protocol): """Protocol for token storage implementations.""" @@ -327,45 +406,29 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: if not self.context.client_info: raise OAuthFlowError("No client info available for authorization") # pragma: no cover - # Generate PKCE parameters - pkce_params = PKCEParameters.generate() - state = secrets.token_urlsafe(32) - - auth_params = { - "response_type": "code", - "client_id": self.context.client_info.client_id, - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "state": state, - "code_challenge": pkce_params.code_challenge, - "code_challenge_method": "S256", - } - - # Only include resource param if conditions are met + resource_url = None if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 + resource_url = self.context.get_resource_url() # RFC 8707 - if self.context.client_metadata.scope: # pragma: no branch - auth_params["scope"] = self.context.client_metadata.scope - - # OIDC requires prompt=consent when offline_access is requested - # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess - if "offline_access" in self.context.client_metadata.scope.split(): - auth_params["prompt"] = "consent" - - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" - await self.context.redirect_handler(authorization_url) + redirect = build_authorization_redirect( + authorization_endpoint=auth_endpoint, + client_info=self.context.client_info, + client_metadata=self.context.client_metadata, + resource_url=resource_url, + ) + await self.context.redirect_handler(redirect.authorization_url) # Wait for callback auth_code, returned_state = await self.context.callback_handler() - if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") + if returned_state is None or not secrets.compare_digest(returned_state, redirect.state): + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {redirect.state}") if not auth_code: raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange - return auth_code, pkce_params.code_verifier + return auth_code, redirect.code_verifier def _get_token_endpoint(self) -> str: if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ca7a495e6..bd7e87ece 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -10,7 +10,7 @@ from inline_snapshot import Is, snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth import OAuthClientProvider, PKCEParameters, build_authorization_redirect from mcp.client.auth.exceptions import OAuthFlowError from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, @@ -72,6 +72,117 @@ def client_metadata(): ) +def test_build_authorization_redirect_returns_resumable_oauth_state(client_metadata: OAuthClientMetadata): + """The authorization step can be built without browser/callback handlers.""" + redirect = build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize", + client_info=OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=client_metadata, + pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64), + state="stored-state", + resource_url="https://api.example.com/v1/mcp", + ) + + assert redirect.state == "stored-state" + assert redirect.code_verifier == "v" * 64 + + parsed = urlparse(redirect.authorization_url) + assert parsed.scheme == "https" + assert parsed.netloc == "auth.example.com" + assert parsed.path == "/authorize" + + params = parse_qs(parsed.query) + assert params == { + "response_type": ["code"], + "client_id": ["test_client"], + "redirect_uri": ["http://localhost:3030/callback"], + "state": ["stored-state"], + "code_challenge": ["c" * 64], + "code_challenge_method": ["S256"], + "resource": ["https://api.example.com/v1/mcp"], + "scope": ["read write"], + } + + +def test_build_authorization_redirect_requires_redirect_uri(): + with pytest.raises(OAuthFlowError, match="No redirect URIs provided"): + build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize", + client_info=OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=OAuthClientMetadata(redirect_uris=None), + ) + + +def test_build_authorization_redirect_requires_client_id(): + with pytest.raises(OAuthFlowError, match="No client ID provided"): + build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize", + client_info=OAuthClientInformationFull( + client_id=None, + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]), + ) + + +def test_build_authorization_redirect_prompts_for_offline_access(): + redirect = build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize", + client_info=OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=OAuthClientMetadata( + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + scope="read offline_access", + ), + pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64), + state="stored-state", + ) + + assert parse_qs(urlparse(redirect.authorization_url).query)["prompt"] == ["consent"] + + +def test_build_authorization_redirect_preserves_existing_authorization_endpoint_query(): + redirect = build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize?audience=api", + client_info=OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]), + pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64), + state="stored-state", + ) + + parsed = urlparse(redirect.authorization_url) + assert parsed.path == "/authorize" + assert parse_qs(parsed.query)["audience"] == ["api"] + assert parse_qs(parsed.query)["response_type"] == ["code"] + + +def test_authorization_redirect_repr_hides_code_verifier(): + redirect = build_authorization_redirect( + authorization_endpoint="https://auth.example.com/authorize", + client_info=OAuthClientInformationFull( + client_id="test_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ), + client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]), + pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64), + state="stored-state", + ) + + assert "code_verifier" not in repr(redirect) + assert "v" * 64 not in repr(redirect) + + @pytest.fixture def valid_tokens(): return OAuthToken(