diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc8234..c73093bb4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -54,6 +54,18 @@ logger = logging.getLogger(__name__) +def _normalize_resource_url(resource: str) -> str: + """Undo the trailing slash URL parsers add to bare-domain URLs (e.g. pydantic's AnyHttpUrl). + + RFC 9728 requires exact-string identity on the resource identifier, so only a root path + with no query or fragment is stripped; trailing slashes on deeper paths are preserved. + """ + parsed = urlparse(resource) + if parsed.path == "/" and not parsed.params and not parsed.query and not parsed.fragment: + return f"{parsed.scheme}://{parsed.netloc}" + return resource + + class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -152,7 +164,7 @@ def get_resource_url(self) -> str: # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) + prm_resource = _normalize_resource_url(str(self.protected_resource_metadata.resource)) if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): resource = prm_resource @@ -441,9 +453,8 @@ async def _refresh_token(self) -> httpx.Request: "client_id": self.context.client_info.client_id, } - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 + # The RFC 8707 resource param is deliberately omitted: some providers + # (e.g. Entra ID v2.0) reject it on refresh_token grants. # Prepare authentication based on preferred method headers = {"Content-Type": "application/x-www-form-urlencoded"} diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ca7a495e6..f8a13abbd 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -745,7 +745,7 @@ class TestProtectedResourceMetadata: @pytest.mark.anyio async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): - """Test resource parameter is included for protocol version >= 2025-06-18.""" + """Test resource parameter is included in initial token requests for protocol version >= 2025-06-18.""" # Set protocol version to 2025-06-18 oauth_provider.context.protocol_version = "2025-06-18" oauth_provider.context.client_info = OAuthClientInformationFull( @@ -762,7 +762,8 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ expected_resource = quote(oauth_provider.context.get_resource_url(), safe="") assert f"resource={expected_resource}" in content - # Test in refresh token + # Refresh requests never include the resource parameter: some providers + # (e.g. Entra ID v2.0) reject RFC 8707 resource values on refresh_token grants. oauth_provider.context.current_tokens = OAuthToken( access_token="test_access", token_type="Bearer", @@ -770,7 +771,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ ) refresh_request = await oauth_provider._refresh_token() refresh_content = refresh_request.content.decode() - assert "resource=" in refresh_content + assert "resource=" not in refresh_content @pytest.mark.anyio async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider): @@ -800,7 +801,7 @@ async def test_resource_param_excluded_with_old_protocol_version(self, oauth_pro @pytest.mark.anyio async def test_resource_param_included_with_protected_resource_metadata(self, oauth_provider: OAuthClientProvider): - """Test resource parameter is always included when protected resource metadata exists.""" + """Test resource parameter is included in initial token requests when protected resource metadata exists.""" # Set old protocol version but with protected resource metadata oauth_provider.context.protocol_version = "2025-03-26" oauth_provider.context.protected_resource_metadata = ProtectedResourceMetadata( @@ -818,6 +819,16 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa content = request.content.decode() assert "resource=" in content + # Even with PRM present, refresh requests omit the resource parameter + oauth_provider.context.current_tokens = OAuthToken( + access_token="test_access", + token_type="Bearer", + refresh_token="test_refresh", + ) + refresh_request = await oauth_provider._refresh_token() + refresh_content = refresh_request.content.decode() + assert "resource=" not in refresh_content + @pytest.mark.parametrize( ("protocol_version", "expected"), @@ -967,6 +978,47 @@ async def test_get_resource_url_uses_canonical_when_prm_mismatches( assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp") +@pytest.mark.anyio +async def test_get_resource_url_removes_root_prm_trailing_slash( + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> None: + """Bare-domain PRM resources should not pick up the trailing slash AnyHttpUrl adds.""" + provider = OAuthClientProvider( + server_url="https://api.example.com", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + # AnyHttpUrl normalizes "https://api.example.com" to "https://api.example.com/" + provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + + assert provider.context.get_resource_url() == snapshot("https://api.example.com") + + +@pytest.mark.anyio +async def test_get_resource_url_preserves_non_root_trailing_slash( + client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> None: + """RFC 9728 requires exact-string identity, so intentional trailing slashes on deeper paths stay.""" + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp/", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com/v1/mcp/"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + + assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp/") + + class TestRegistrationResponse: """Test client registration response handling.""" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py index aa552ae8a..7e63feb52 100644 --- a/tests/interaction/auth/test_lifecycle.py +++ b/tests/interaction/auth/test_lifecycle.py @@ -104,8 +104,10 @@ async def test_an_expired_access_token_is_transparently_refreshed_before_the_nex The provider tells the client `expires_in=-3600` for the first token while keeping the server-side `expires_at` in the future, so the connect's retry succeeds and the next request finds the token expired and refreshes. The recorded requests prove exactly one - `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used - after the refresh is the second access token, which is the one persisted to storage. + `grant_type=refresh_token` exchange without the resource indicator (some providers, + e.g. Entra ID v2.0, reject RFC 8707 resource values on refresh_token grants), and the + bearer used after the refresh is the second access token, which is the one persisted + to storage. """ recorded, on_request = record_requests() provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) @@ -123,9 +125,8 @@ async def test_an_expired_access_token_is_transparently_refreshed_before_the_nex assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) refresh_body = bodies[1] - assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token"]) assert refresh_body["refresh_token"].startswith("refresh_") - assert refresh_body["resource"].startswith(BASE_URL) bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} assert len(bearers) == 2