diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 01bcc8234..0a9172f03 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -46,6 +46,7 @@ ) from mcp.shared.auth_utils import ( calculate_token_expiry, + calculate_token_refresh_time, check_resource_allowed, resource_url_from_server_url, ) @@ -113,6 +114,9 @@ class OAuthContext: # Token management current_tokens: OAuthToken | None = None token_expiry_time: float | None = None + # Jittered point (before hard expiry) at which to proactively refresh, so a fleet + # of connectors does not all refresh in the same window. See should_refresh_token. + token_refresh_time: float | None = None # State lock: anyio.Lock = field(default_factory=anyio.Lock) @@ -123,11 +127,12 @@ def get_authorization_base_url(self, server_url: str) -> str: return f"{parsed.scheme}://{parsed.netloc}" def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time using shared util function.""" + """Update token expiry and proactive-refresh times using shared util functions.""" self.token_expiry_time = calculate_token_expiry(token.expires_in) + self.token_refresh_time = calculate_token_refresh_time(token.expires_in) def is_token_valid(self) -> bool: - """Check if current token is valid.""" + """Check if current token is valid (i.e. usable, not past hard expiry).""" return bool( self.current_tokens and self.current_tokens.access_token @@ -138,10 +143,28 @@ def can_refresh_token(self) -> bool: """Check if token can be refreshed.""" return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) + def should_refresh_token(self) -> bool: + """Check if the token should be *proactively* refreshed. + + Returns True when we hold refreshable tokens and have passed the jittered + proactive-refresh point (``token_refresh_time``), even if the token is still + technically valid. Refreshing slightly early -- and at a per-connector jittered + moment -- spreads a fleet's refreshes out instead of bunching them into the + same expiry window. Returns False when no refresh time is known (no expiry + info) so behavior degrades to the existing reactive path. + """ + return bool( + self.current_tokens + and self.can_refresh_token() + and self.token_refresh_time is not None + and time.time() >= self.token_refresh_time + ) + def clear_tokens(self) -> None: """Clear current tokens.""" self.current_tokens = None self.token_expiry_time = None + self.token_refresh_time = None def get_resource_url(self) -> str: """Get resource URL for RFC 8707. @@ -511,7 +534,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - if not self.context.is_token_valid() and self.context.can_refresh_token(): + if ( + not self.context.is_token_valid() or self.context.should_refresh_token() + ) and self.context.can_refresh_token(): + # Refresh either reactively (token already invalid) or proactively + # (past the jittered refresh point, before hard expiry). # Try to refresh token refresh_request = await self._refresh_token() refresh_response = yield refresh_request diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 3ba880f40..a3771294c 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -1,5 +1,6 @@ """Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" +import random import time from urllib.parse import urlparse, urlsplit, urlunsplit @@ -78,3 +79,80 @@ def calculate_token_expiry(expires_in: int | str | None) -> float | None: return None # pragma: no cover # Defensive: handle servers that return expires_in as string return time.time() + int(expires_in) + + +def calculate_token_refresh_time( + expires_in: int | str | None, + *, + refresh_fraction: float = 0.8, + max_jitter_seconds: float = 30.0, + jitter: float | None = None, +) -> float | None: + """Calculate when a token should be *proactively* refreshed. + + Reactive refresh (waiting until a token has already expired) means that, for a + fleet of OAuth-backed MCP connectors provisioned around the same time, every + token tends to expire inside the same narrow window. When they do, all of those + clients try to refresh simultaneously, producing a "thundering herd" of refresh + requests against the authorization server -- contention, rate limiting, and + spurious auth failures. + + To avoid that, this returns a timestamp *before* hard expiry at which the token + should be refreshed: + + refresh_at = now + expires_in * refresh_fraction - jitter + + The jitter is always *subtracted* so it pulls the refresh point earlier and can + never push it past the hard-expiry boundary. Spreading each client's refresh + point by a small random amount means a fleet naturally desynchronizes instead of + refreshing in lockstep. + + Args: + expires_in: Seconds until token expiration (may be a string from some servers). + refresh_fraction: Fraction of the token lifetime after which to refresh. + Defaults to 0.8 (refresh once 80% of the lifetime has elapsed). + max_jitter_seconds: Upper bound (in seconds) of the random jitter subtracted + from the refresh point. Defaults to 30s. + jitter: Optional explicit jitter value (seconds). When provided it is used + directly instead of drawing a random value, which keeps the function + deterministic and testable. When None, a value in + ``[0, max_jitter_seconds]`` is drawn at random. + + Returns: + Unix timestamp at which the token should be proactively refreshed, or None + if ``expires_in`` is None (no expiry information -> nothing to schedule). + The result is always in ``(now, hard_expiry]`` and never in the past. + """ + if expires_in is None: + return None + + expires_in_seconds = int(expires_in) + now = time.time() + hard_expiry = now + expires_in_seconds + + # Base proactive point: refresh once `refresh_fraction` of the lifetime elapsed. + refresh_at = now + expires_in_seconds * refresh_fraction + + # Cap the jitter so it can never reach back before `now`, which matters for very + # short TTLs (e.g. expires_in smaller than max_jitter_seconds). The window we are + # allowed to pull earlier into is (refresh_at - now); never jitter more than that. + available_window = refresh_at - now + effective_max_jitter = min(max_jitter_seconds, max(available_window, 0.0)) + + if jitter is None: + applied_jitter = random.uniform(0, effective_max_jitter) + else: + # Clamp an injected jitter into the valid range to preserve invariants. + applied_jitter = min(max(jitter, 0.0), effective_max_jitter) + + refresh_at -= applied_jitter + + # Final guard: keep the result strictly within (now, hard_expiry]. For tiny or + # zero TTLs this collapses gracefully toward `now` rather than going negative or + # past the hard-expiry boundary. + if refresh_at < now: + refresh_at = now + if refresh_at > hard_expiry: + refresh_at = hard_expiry + + return refresh_at diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index ca7a495e6..a53324b97 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -251,6 +251,7 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O context = oauth_provider.context context.current_tokens = valid_tokens context.token_expiry_time = time.time() + 1800 + context.token_refresh_time = time.time() + 1440 # Clear tokens context.clear_tokens() @@ -258,6 +259,42 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O # Verify cleared assert context.current_tokens is None assert context.token_expiry_time is None + assert context.token_refresh_time is None + + @pytest.mark.anyio + async def test_should_refresh_token(self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken): + """Test should_refresh_token() proactive-refresh logic.""" + context = oauth_provider.context + + # No tokens at all -> never proactively refresh. + assert not context.should_refresh_token() + + context.current_tokens = valid_tokens + context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Token still hard-valid AND before the jittered refresh point -> no refresh. + context.token_expiry_time = time.time() + 1800 + context.token_refresh_time = time.time() + 600 + assert context.is_token_valid() + assert not context.should_refresh_token() + + # Token still hard-valid but we have passed the proactive refresh point -> refresh. + context.token_refresh_time = time.time() - 1 + assert context.is_token_valid() + assert context.should_refresh_token() + + # No refresh time known (e.g. server gave no expiry) -> fall back to reactive only. + context.token_refresh_time = None + assert not context.should_refresh_token() + + # Past the refresh point but no refresh token -> cannot proactively refresh. + context.token_refresh_time = time.time() - 1 + context.current_tokens.refresh_token = None + assert not context.should_refresh_token() class TestOAuthFlow: @@ -506,6 +543,102 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl except StopAsyncIteration: pass # Expected - generator should complete + @pytest.mark.anyio + async def test_async_auth_flow_proactively_refreshes_when_past_jitter_window( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """async_auth_flow refreshes proactively past the jittered window. + + The token is still hard-valid (is_token_valid() is True), but we are past the + proactive refresh point, so the flow should yield a refresh request *before* + sending the original request -- spreading fleet refreshes out instead of + waiting for hard expiry. + """ + context = oauth_provider.context + context.current_tokens = valid_tokens + context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + # Token is still valid for a while, but we are past the proactive refresh point. + context.token_expiry_time = time.time() + 1800 + context.token_refresh_time = time.time() - 1 + assert context.is_token_valid() + assert context.should_refresh_token() + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First yielded request must be a proactive refresh, not the original request. + refresh_request = await auth_flow.__anext__() + assert refresh_request.method == "POST" + assert str(refresh_request.url) == "https://api.example.com/token" + refresh_content = refresh_request.content.decode() + assert "grant_type=refresh_token" in refresh_content + assert "refresh_token=test_refresh_token" in refresh_content + + # Provide a successful refresh response with fresh tokens. + refresh_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' + ), + request=refresh_request, + ) + + # After a successful refresh, the original request is sent with the new token. + actual_request = await auth_flow.asend(refresh_response) + assert actual_request.headers["Authorization"] == "Bearer new_access_token" + assert str(actual_request.url) == "https://api.example.com/v1/mcp" + + # New proactive-refresh point should have been scheduled in the future. + assert context.token_refresh_time is not None + assert context.token_refresh_time > time.time() + + # Close out the generator with a final success response. + final_response = httpx.Response(200, request=actual_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass # Expected - generator completes + + @pytest.mark.anyio + async def test_async_auth_flow_skips_refresh_before_jitter_window( + self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken + ): + """A fresh token (before the proactive window) is used directly, no refresh.""" + context = oauth_provider.context + context.current_tokens = valid_tokens + context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + + # Token valid and well before the proactive refresh point. + context.token_expiry_time = time.time() + 1800 + context.token_refresh_time = time.time() + 600 + assert not context.should_refresh_token() + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First (and only auth-related) yielded request is the original request itself. + actual_request = await auth_flow.__anext__() + assert actual_request.headers["Authorization"] == "Bearer test_access_token" + assert str(actual_request.url) == "https://api.example.com/v1/mcp" + + final_response = httpx.Response(200, request=actual_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass # Expected - generator completes + @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): """Test successful metadata response handling.""" diff --git a/tests/shared/test_auth_utils.py b/tests/shared/test_auth_utils.py index 5ae0e22b0..91ac2ea0a 100644 --- a/tests/shared/test_auth_utils.py +++ b/tests/shared/test_auth_utils.py @@ -1,8 +1,14 @@ """Tests for OAuth 2.0 Resource Indicators utilities.""" +import time + from pydantic import HttpUrl -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +from mcp.shared.auth_utils import ( + calculate_token_refresh_time, + check_resource_allowed, + resource_url_from_server_url, +) # Tests for resource_url_from_server_url function @@ -121,3 +127,98 @@ def test_check_resource_allowed_empty_paths(): assert check_resource_allowed("https://example.com", "https://example.com") is True assert check_resource_allowed("https://example.com/", "https://example.com") is True assert check_resource_allowed("https://example.com/api", "https://example.com") is True + + +# Tests for calculate_token_refresh_time function + + +def test_calculate_token_refresh_time_none_expires_in(): + """None expires_in means no expiry info -> no refresh schedule.""" + assert calculate_token_refresh_time(None) is None + + +def test_calculate_token_refresh_time_normal_ttl_within_window(): + """For a normal TTL the refresh point falls inside the expected jitter window + and strictly before hard expiry.""" + expires_in = 3600 + before = time.time() + refresh_at = calculate_token_refresh_time(expires_in) + after = time.time() + + assert refresh_at is not None + hard_expiry_lower = before + expires_in + # With default fraction 0.8 and up to 30s of jitter subtracted, the refresh + # point lies in [now + 0.8*ttl - 30, now + 0.8*ttl]. + assert before + expires_in * 0.8 - 30.0 <= refresh_at <= after + expires_in * 0.8 + # Must be strictly before hard expiry and in the future. + assert refresh_at < hard_expiry_lower + assert refresh_at > before + + +def test_calculate_token_refresh_time_accepts_string_expires_in(): + """expires_in may arrive as a string from some servers.""" + refresh_at = calculate_token_refresh_time("3600", jitter=0.0) + now = time.time() + assert refresh_at is not None + # Roughly now + 0.8 * 3600 = now + 2880 (allow small scheduling slack). + assert now + 2880 - 5 <= refresh_at <= now + 2880 + 5 + + +def test_calculate_token_refresh_time_injected_jitter_is_deterministic(): + """Injecting jitter makes the function deterministic/testable.""" + expires_in = 1000 + now = time.time() + refresh_at = calculate_token_refresh_time(expires_in, jitter=10.0) + # now + 0.8*1000 - 10 = now + 790 (allow small scheduling slack). + assert now + 790 - 2 <= refresh_at <= now + 790 + 2 # type: ignore[operator] + + +def test_calculate_token_refresh_time_jitter_pulls_earlier(): + """Larger jitter must produce an earlier (smaller) refresh timestamp.""" + expires_in = 1000 + no_jitter = calculate_token_refresh_time(expires_in, jitter=0.0) + small_jitter = calculate_token_refresh_time(expires_in, jitter=5.0) + big_jitter = calculate_token_refresh_time(expires_in, jitter=25.0) + + assert no_jitter is not None and small_jitter is not None and big_jitter is not None + # Jitter is subtracted, so more jitter -> earlier refresh. + assert big_jitter < small_jitter < no_jitter + + +def test_calculate_token_refresh_time_never_past_hard_expiry(): + """The refresh point is always strictly before hard expiry for positive TTLs.""" + for expires_in in (1, 5, 30, 60, 300, 3600, 86400): + before = time.time() + refresh_at = calculate_token_refresh_time(expires_in, jitter=0.0) + assert refresh_at is not None + assert refresh_at <= before + expires_in + assert refresh_at >= before # never in the past + + +def test_calculate_token_refresh_time_tiny_ttl_no_negative(): + """Very short TTLs (smaller than max jitter) must not go negative or before now.""" + now = time.time() + # 10s TTL with a requested 30s jitter: jitter must be clamped to the + # available window (0.8 * 10 = 8s) so the result stays >= now. + refresh_at = calculate_token_refresh_time(10, max_jitter_seconds=30.0, jitter=30.0) + assert refresh_at is not None + assert refresh_at >= now + assert refresh_at <= now + 10 + + +def test_calculate_token_refresh_time_zero_ttl(): + """A zero TTL collapses to roughly now without going negative.""" + now = time.time() + refresh_at = calculate_token_refresh_time(0) + assert refresh_at is not None + assert now - 1 <= refresh_at <= now + 1 + + +def test_calculate_token_refresh_time_custom_fraction(): + """refresh_fraction controls how far into the lifetime we refresh.""" + expires_in = 1000 + now = time.time() + refresh_at = calculate_token_refresh_time(expires_in, refresh_fraction=0.5, jitter=0.0) + assert refresh_at is not None + # now + 0.5 * 1000 = now + 500 (allow small scheduling slack). + assert now + 500 - 2 <= refresh_at <= now + 500 + 2