Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@
logger = logging.getLogger(__name__)


def _normalize_resource_url(resource: str) -> str:
"""Undo the trailing slash URL parsers add to bare-domain URLs (e.g. pydantic's AnyHttpUrl).

RFC 9728 requires exact-string identity on the resource identifier, so only a root path
with no query or fragment is stripped; trailing slashes on deeper paths are preserved.
"""
parsed = urlparse(resource)
if parsed.path == "/" and not parsed.params and not parsed.query and not parsed.fragment:
return f"{parsed.scheme}://{parsed.netloc}"
return resource


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""

Expand Down Expand Up @@ -152,7 +164,7 @@ def get_resource_url(self) -> str:

# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
prm_resource = _normalize_resource_url(str(self.protected_resource_metadata.resource))
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource

Expand Down Expand Up @@ -441,9 +453,8 @@ async def _refresh_token(self) -> httpx.Request:
"client_id": self.context.client_info.client_id,
}

# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
# The RFC 8707 resource param is deliberately omitted: some providers
# (e.g. Entra ID v2.0) reject it on refresh_token grants.

# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
Expand Down
60 changes: 56 additions & 4 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ class TestProtectedResourceMetadata:

@pytest.mark.anyio
async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider):
"""Test resource parameter is included for protocol version >= 2025-06-18."""
"""Test resource parameter is included in initial token requests for protocol version >= 2025-06-18."""
# Set protocol version to 2025-06-18
oauth_provider.context.protocol_version = "2025-06-18"
oauth_provider.context.client_info = OAuthClientInformationFull(
Expand All @@ -762,15 +762,16 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
expected_resource = quote(oauth_provider.context.get_resource_url(), safe="")
assert f"resource={expected_resource}" in content

# Test in refresh token
# Refresh requests never include the resource parameter: some providers
# (e.g. Entra ID v2.0) reject RFC 8707 resource values on refresh_token grants.
oauth_provider.context.current_tokens = OAuthToken(
access_token="test_access",
token_type="Bearer",
refresh_token="test_refresh",
)
refresh_request = await oauth_provider._refresh_token()
refresh_content = refresh_request.content.decode()
assert "resource=" in refresh_content
assert "resource=" not in refresh_content

@pytest.mark.anyio
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
Expand Down Expand Up @@ -800,7 +801,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro

@pytest.mark.anyio
async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider):
"""Test resource parameter is always included when protected resource metadata exists."""
"""Test resource parameter is included in initial token requests when protected resource metadata exists."""
# Set old protocol version but with protected resource metadata
oauth_provider.context.protocol_version = "2025-03-26"
oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata(
Expand All @@ -818,6 +819,16 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
content = request.content.decode()
assert "resource=" in content

# Even with PRM present, refresh requests omit the resource parameter
oauth_provider.context.current_tokens = OAuthToken(
access_token="test_access",
token_type="Bearer",
refresh_token="test_refresh",
)
refresh_request = await oauth_provider._refresh_token()
refresh_content = refresh_request.content.decode()
assert "resource=" not in refresh_content


@pytest.mark.parametrize(
("protocol_version", "expected"),
Expand Down Expand Up @@ -967,6 +978,47 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches(
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")


@pytest.mark.anyio
async def test_get_resource_url_removes_root_prm_trailing_slash(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""Bare-domain PRM resources should not pick up the trailing slash AnyHttpUrl adds."""
provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

# AnyHttpUrl normalizes "https://api.example.com" to "https://api.example.com/"
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)

assert provider.context.get_resource_url() == snapshot("https://api.example.com")


@pytest.mark.anyio
async def test_get_resource_url_preserves_non_root_trailing_slash(
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""RFC 9728 requires exact-string identity, so intentional trailing slashes on deeper paths stay."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp/",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

provider.context.protected_resource_metadata = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com/v1/mcp/"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)

assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp/")


class TestRegistrationResponse:
"""Test client registration response handling."""

Expand Down
9 changes: 5 additions & 4 deletions tests/interaction/auth/test_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ async def test_an_expired_access_token_is_transparently_refreshed_before_the_nex
The provider tells the client `expires_in=-3600` for the first token while keeping the
server-side `expires_at` in the future, so the connect's retry succeeds and the next
request finds the token expired and refreshes. The recorded requests prove exactly one
`grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used
after the refresh is the second access token, which is the one persisted to storage.
`grant_type=refresh_token` exchange without the resource indicator (some providers,
e.g. Entra ID v2.0, reject RFC 8707 resource values on refresh_token grants), and the
bearer used after the refresh is the second access token, which is the one persisted
to storage.
"""
recorded, on_request = record_requests()
provider = InMemoryAuthorizationServerProvider(issue_expired_first=True)
Expand All @@ -123,9 +125,8 @@ async def test_an_expired_access_token_is_transparently_refreshed_before_the_nex
assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"])

refresh_body = bodies[1]
assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"])
assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token"])
assert refresh_body["refresh_token"].startswith("refresh_")
assert refresh_body["resource"].startswith(BASE_URL)

bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers}
assert len(bearers) == 2
Expand Down
Loading