Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/mcp/client/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
from mcp.client.auth.oauth2 import (
OAuthAuthorizationRedirect,
OAuthClientProvider,
PKCEParameters,
TokenStorage,
build_authorization_redirect,
)

__all__ = [
"OAuthAuthorizationRedirect",
"OAuthClientProvider",
"OAuthFlowError",
"OAuthRegistrationError",
"OAuthTokenError",
"PKCEParameters",
"TokenStorage",
"build_authorization_redirect",
]
121 changes: 92 additions & 29 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse

import anyio
import httpx
Expand Down Expand Up @@ -69,6 +69,85 @@ def generate(cls) -> "PKCEParameters":
return cls(code_verifier=code_verifier, code_challenge=code_challenge)


@dataclass(frozen=True)
class OAuthAuthorizationRedirect:
"""Resumable OAuth authorization redirect state.

Proxy and server-side callers can persist this value, send the authorization
URL to a user, and later resume token exchange with the returned code plus
the stored state and code verifier.
"""

authorization_url: str
state: str
code_verifier: str = field(repr=False)


def build_authorization_redirect(
*,
authorization_endpoint: str,
client_info: OAuthClientInformationFull,
client_metadata: OAuthClientMetadata,
pkce_params: PKCEParameters | None = None,
state: str | None = None,
resource_url: str | None = None,
) -> OAuthAuthorizationRedirect:
"""Build an OAuth authorization URL and resumable state.

Args:
authorization_endpoint: Authorization endpoint URL.
client_info: Registered OAuth client information.
client_metadata: Client metadata containing redirect URIs and scopes.
pkce_params: Optional PKCE parameters. Generated when omitted.
state: Optional OAuth state value. Generated when omitted.
resource_url: Optional RFC 8707 resource value.

Returns:
Authorization URL plus the state and code verifier needed to resume.

Raises:
OAuthFlowError: If no client ID or redirect URI is available.
"""
if client_info.client_id is None:
raise OAuthFlowError("No client ID provided for authorization code grant")

if client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant")

pkce_params = pkce_params or PKCEParameters.generate()
state = state or secrets.token_urlsafe(32)

auth_params = {
"response_type": "code",
"client_id": client_info.client_id,
"redirect_uri": str(client_metadata.redirect_uris[0]),
"state": state,
"code_challenge": pkce_params.code_challenge,
"code_challenge_method": "S256",
}

if resource_url:
auth_params["resource"] = resource_url

if client_metadata.scope:
auth_params["scope"] = client_metadata.scope

# OIDC requires prompt=consent when offline_access is requested
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
if "offline_access" in client_metadata.scope.split():
auth_params["prompt"] = "consent"

parsed_endpoint = urlparse(authorization_endpoint)
query_params = parse_qsl(parsed_endpoint.query, keep_blank_values=True)
query_params.extend(auth_params.items())
authorization_url = urlunparse(parsed_endpoint._replace(query=urlencode(query_params)))
return OAuthAuthorizationRedirect(
authorization_url=authorization_url,
state=state,
code_verifier=pkce_params.code_verifier,
)


class TokenStorage(Protocol):
"""Protocol for token storage implementations."""

Expand Down Expand Up @@ -327,45 +406,29 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
if not self.context.client_info:
raise OAuthFlowError("No client info available for authorization") # pragma: no cover

# Generate PKCE parameters
pkce_params = PKCEParameters.generate()
state = secrets.token_urlsafe(32)

auth_params = {
"response_type": "code",
"client_id": self.context.client_info.client_id,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"state": state,
"code_challenge": pkce_params.code_challenge,
"code_challenge_method": "S256",
}

# Only include resource param if conditions are met
resource_url = None
if self.context.should_include_resource_param(self.context.protocol_version):
auth_params["resource"] = self.context.get_resource_url() # RFC 8707
resource_url = self.context.get_resource_url() # RFC 8707

if self.context.client_metadata.scope: # pragma: no branch
auth_params["scope"] = self.context.client_metadata.scope

# OIDC requires prompt=consent when offline_access is requested
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
if "offline_access" in self.context.client_metadata.scope.split():
auth_params["prompt"] = "consent"

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
await self.context.redirect_handler(authorization_url)
redirect = build_authorization_redirect(
authorization_endpoint=auth_endpoint,
client_info=self.context.client_info,
client_metadata=self.context.client_metadata,
resource_url=resource_url,
)
await self.context.redirect_handler(redirect.authorization_url)

# Wait for callback
auth_code, returned_state = await self.context.callback_handler()

if returned_state is None or not secrets.compare_digest(returned_state, state):
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}")
if returned_state is None or not secrets.compare_digest(returned_state, redirect.state):
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {redirect.state}")

if not auth_code:
raise OAuthFlowError("No authorization code received")

# Return auth code and code verifier for token exchange
return auth_code, pkce_params.code_verifier
return auth_code, redirect.code_verifier

def _get_token_endpoint(self) -> str:
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
Expand Down
113 changes: 112 additions & 1 deletion tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from inline_snapshot import Is, snapshot
from pydantic import AnyHttpUrl, AnyUrl

from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth import OAuthClientProvider, PKCEParameters, build_authorization_redirect
from mcp.client.auth.exceptions import OAuthFlowError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
Expand Down Expand Up @@ -72,6 +72,117 @@ def client_metadata():
)


def test_build_authorization_redirect_returns_resumable_oauth_state(client_metadata: OAuthClientMetadata):
"""The authorization step can be built without browser/callback handlers."""
redirect = build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize",
client_info=OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=client_metadata,
pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64),
state="stored-state",
resource_url="https://api.example.com/v1/mcp",
)

assert redirect.state == "stored-state"
assert redirect.code_verifier == "v" * 64

parsed = urlparse(redirect.authorization_url)
assert parsed.scheme == "https"
assert parsed.netloc == "auth.example.com"
assert parsed.path == "/authorize"

params = parse_qs(parsed.query)
assert params == {
"response_type": ["code"],
"client_id": ["test_client"],
"redirect_uri": ["http://localhost:3030/callback"],
"state": ["stored-state"],
"code_challenge": ["c" * 64],
"code_challenge_method": ["S256"],
"resource": ["https://api.example.com/v1/mcp"],
"scope": ["read write"],
}


def test_build_authorization_redirect_requires_redirect_uri():
with pytest.raises(OAuthFlowError, match="No redirect URIs provided"):
build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize",
client_info=OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=OAuthClientMetadata(redirect_uris=None),
)


def test_build_authorization_redirect_requires_client_id():
with pytest.raises(OAuthFlowError, match="No client ID provided"):
build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize",
client_info=OAuthClientInformationFull(
client_id=None,
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]),
)


def test_build_authorization_redirect_prompts_for_offline_access():
redirect = build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize",
client_info=OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=OAuthClientMetadata(
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
scope="read offline_access",
),
pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64),
state="stored-state",
)

assert parse_qs(urlparse(redirect.authorization_url).query)["prompt"] == ["consent"]


def test_build_authorization_redirect_preserves_existing_authorization_endpoint_query():
redirect = build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize?audience=api",
client_info=OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]),
pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64),
state="stored-state",
)

parsed = urlparse(redirect.authorization_url)
assert parsed.path == "/authorize"
assert parse_qs(parsed.query)["audience"] == ["api"]
assert parse_qs(parsed.query)["response_type"] == ["code"]


def test_authorization_redirect_repr_hides_code_verifier():
redirect = build_authorization_redirect(
authorization_endpoint="https://auth.example.com/authorize",
client_info=OAuthClientInformationFull(
client_id="test_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
),
client_metadata=OAuthClientMetadata(redirect_uris=[AnyUrl("http://localhost:3030/callback")]),
pkce_params=PKCEParameters(code_verifier="v" * 64, code_challenge="c" * 64),
state="stored-state",
)

assert "code_verifier" not in repr(redirect)
assert "v" * 64 not in repr(redirect)


@pytest.fixture
def valid_tokens():
return OAuthToken(
Expand Down
Loading