Skip to content

Commit 148ec42

Browse files
peisukeBartok9
authored andcommitted
test: cover refresh failure paths and double-check refresh skip
After dropping the function-level `# pragma: no cover` from `_handle_refresh_response` and removing the per-line pragmas from the refactored Phase 2, the strict-no-cover audit identified covered lines still marked pragma'd and surfaced previously-untested branches. Three new tests close the coverage gaps: * `test_refresh_with_failed_status_clears_tokens` — exercises the ``response.status_code != 200`` branch in `_handle_refresh_response` and the `self._initialized = False` reset on refresh failure. * `test_refresh_with_invalid_json_clears_tokens` — exercises the ValidationError branch when the refresh body is not valid JSON. * `test_double_check_inside_refresh_lock_skips_second_refresh` — uses monkeypatch to flip `is_token_valid` between Phase 1 (False) and the double-check inside `refresh_lock` (True), exercising the branch where a second coroutine's refresh is correctly elided. Also: convert the new tests from the legacy Test* class pattern to plain top-level `test_*` functions per AGENTS.md, and drop unneeded per-line `# pragma: no cover` markers in the refactored auth_flow. Coverage report: 100.00% on `src/mcp/client/auth/oauth2.py`, strict-no-cover clean.
1 parent b731b2d commit 148ec42

2 files changed

Lines changed: 191 additions & 106 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
543543
refresh_request: httpx.Request | None = None
544544
async with self.context.lock:
545545
if not self.context.is_token_valid() and self.context.can_refresh_token():
546-
refresh_request = await self._refresh_token() # pragma: no cover
546+
refresh_request = await self._refresh_token()
547547
if refresh_request is not None:
548548
# yield runs outside any lock so a long network round trip
549549
# does not block unrelated concurrent requests.
550-
refresh_response = yield refresh_request # pragma: no cover
550+
refresh_response = yield refresh_request
551551
async with self.context.lock:
552-
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
552+
if not await self._handle_refresh_response(refresh_response):
553553
# Refresh failed; fall through to 401 handling below.
554554
self._initialized = False
555555

tests/client/test_auth.py

Lines changed: 188 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -2640,114 +2640,199 @@ async def callback_handler() -> tuple[str, str | None]:
26402640
pass
26412641

26422642

2643-
class TestConcurrentRequestsDoNotDeadlock:
2644-
"""Regression tests for #1326.
2643+
@pytest.mark.anyio
2644+
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2645+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2646+
):
2647+
"""Regression for #1326: a second request reaches its yield while the
2648+
first is still suspended (= simulating a server-side long-poll).
26452649
2646-
Ensures that ``OAuthClientProvider.async_auth_flow`` does not serialize
2647-
concurrent unrelated requests behind a long-running one (e.g. GET SSE
2648-
long-poll). The fix narrows ``context.lock`` to state mutation only; the
2649-
actual ``yield request`` runs outside any lock.
2650+
Before the lock-scope fix, ``async_auth_flow`` held ``context.lock``
2651+
across ``yield request``. A GET SSE long-poll would therefore hold the
2652+
lock for the entire SSE lifetime, blocking any concurrent request
2653+
waiting on the same provider's lock.
26502654
"""
2655+
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2656+
# flow (Phase 4) is triggered — we exercise the steady-state Phase 3
2657+
# yield path that previously held the lock.
2658+
oauth_provider.context.current_tokens = valid_tokens
2659+
oauth_provider.context.token_expiry_time = time.time() + 1800
2660+
oauth_provider.context.client_info = OAuthClientInformationFull(
2661+
client_id="test_client_id",
2662+
client_secret="test_client_secret",
2663+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2664+
)
2665+
oauth_provider._initialized = True
26512666

2652-
@pytest.mark.anyio
2653-
async def test_concurrent_request_not_blocked_by_pending_long_running_request(
2654-
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2655-
):
2656-
"""A second request must reach its yield while the first is still
2657-
suspended at its yield (= simulating a server-side long-poll).
2658-
2659-
Before this fix, ``async_auth_flow`` held ``context.lock`` across
2660-
``yield request``. A GET SSE long-poll would therefore hold the lock
2661-
for the entire SSE lifetime, blocking any concurrent request waiting
2662-
on the same provider's lock and producing a multi-second stall.
2663-
"""
2664-
# Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2665-
# flow (Phase 4) is triggered — we want to exercise the steady-state
2666-
# Phase 3 yield path that previously held the lock.
2667-
oauth_provider.context.current_tokens = valid_tokens
2668-
oauth_provider.context.token_expiry_time = time.time() + 1800
2669-
oauth_provider.context.client_info = OAuthClientInformationFull(
2670-
client_id="test_client_id",
2671-
client_secret="test_client_secret",
2672-
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2673-
)
2674-
oauth_provider._initialized = True
2667+
# Flow 1: drive to yield, then leave suspended (simulating long-poll).
2668+
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2669+
slow_flow = oauth_provider.async_auth_flow(slow_request)
2670+
yielded_slow = await slow_flow.__anext__()
2671+
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
26752672

2676-
# Flow 1: simulate a slow request. Drive it to its yield, then
2677-
# deliberately do not send a response — it stays suspended at the
2678-
# yield, just like a GET SSE long-poll waiting for the next event.
2679-
slow_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2680-
slow_flow = oauth_provider.async_auth_flow(slow_request)
2681-
yielded_slow = await slow_flow.__anext__()
2682-
assert yielded_slow.headers.get("Authorization") == "Bearer test_access_token"
2683-
2684-
# Flow 2: a concurrent request on the same provider. With the fix,
2685-
# context.lock is not held during Flow 1's yield, so Flow 2 reaches
2686-
# its yield almost immediately. Without the fix, this would block
2687-
# until Flow 1 receives a response — i.e., it would hit the timeout.
2688-
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2689-
fast_flow = oauth_provider.async_auth_flow(fast_request)
2690-
with anyio.fail_after(1.0):
2691-
yielded_fast = await fast_flow.__anext__()
2692-
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
2693-
2694-
# Clean up both generators in deterministic order.
2695-
with contextlib.suppress(StopAsyncIteration):
2696-
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2697-
with contextlib.suppress(StopAsyncIteration):
2698-
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
2673+
# Flow 2: concurrent request. With the fix this reaches its yield
2674+
# immediately; without the fix it would block on context.lock.
2675+
fast_request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2676+
fast_flow = oauth_provider.async_auth_flow(fast_request)
2677+
with anyio.fail_after(5):
2678+
yielded_fast = await fast_flow.__anext__()
2679+
assert yielded_fast.headers.get("Authorization") == "Bearer test_access_token"
26992680

2700-
@pytest.mark.anyio
2701-
async def test_concurrent_token_refresh_is_single_flight(
2702-
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2703-
):
2704-
"""When concurrent requests both observe an expired token, only one
2705-
refresh request is sent: ``refresh_lock`` provides single-flight
2706-
semantics so the second waiter re-checks state and proceeds without
2707-
re-triggering refresh.
2708-
"""
2709-
# Mark the token as expired so the next auth_flow run enters Phase 2.
2710-
oauth_provider.context.current_tokens = valid_tokens
2711-
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2712-
oauth_provider.context.client_info = OAuthClientInformationFull(
2713-
client_id="test_client_id",
2714-
client_secret="test_client_secret",
2715-
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2716-
)
2717-
oauth_provider._initialized = True
2681+
with contextlib.suppress(StopAsyncIteration):
2682+
await fast_flow.asend(httpx.Response(200, request=yielded_fast))
2683+
with contextlib.suppress(StopAsyncIteration):
2684+
await slow_flow.asend(httpx.Response(200, request=yielded_slow))
27182685

2719-
# Flow A: drive it to the refresh yield and suspend there.
2720-
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2721-
flow_a = oauth_provider.async_auth_flow(request_a)
2722-
refresh_a = await flow_a.__anext__()
2723-
assert "grant_type=refresh_token" in refresh_a.read().decode()
27242686

2725-
# Complete Flow A's refresh with a fresh token.
2726-
refresh_response = httpx.Response(
2727-
200,
2728-
content=(
2729-
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2730-
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2731-
),
2732-
request=refresh_a,
2733-
)
2734-
request_a_post = await flow_a.asend(refresh_response)
2735-
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2736-
2737-
# Flow B starts after Flow A's refresh has completed. Because token
2738-
# state was updated under context.lock, Flow B observes the fresh
2739-
# token in Phase 1, skips Phase 2 entirely, and reaches its yield
2740-
# directly. No second refresh request is sent.
2741-
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2742-
flow_b = oauth_provider.async_auth_flow(request_b)
2743-
with anyio.fail_after(1.0):
2744-
request_b_yielded = await flow_b.__anext__()
2745-
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2746-
# Confirm Flow B yielded the original POST, not a refresh request.
2747-
assert request_b_yielded.method == "POST"
2748-
2749-
# Clean up.
2750-
with contextlib.suppress(StopAsyncIteration):
2751-
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2687+
@pytest.mark.anyio
2688+
async def test_refresh_lock_double_check_skips_redundant_refresh(
2689+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
2690+
):
2691+
"""Two flows enter Phase 2 with an expired token. After the first
2692+
completes a refresh, the second observes the fresh token via the
2693+
Phase 2 double-check inside ``refresh_lock`` (or directly in Phase 1
2694+
if it arrives late) and skips its own refresh.
2695+
"""
2696+
oauth_provider.context.current_tokens = valid_tokens
2697+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2698+
oauth_provider.context.client_info = OAuthClientInformationFull(
2699+
client_id="test_client_id",
2700+
client_secret="test_client_secret",
2701+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2702+
)
2703+
oauth_provider._initialized = True
2704+
2705+
# Flow A: drive to refresh yield, then complete refresh.
2706+
request_a = httpx.Request("GET", "https://api.example.com/v1/mcp")
2707+
flow_a = oauth_provider.async_auth_flow(request_a)
2708+
refresh_a = await flow_a.__anext__()
2709+
assert "grant_type=refresh_token" in refresh_a.read().decode()
2710+
2711+
refresh_response = httpx.Response(
2712+
200,
2713+
content=(
2714+
b'{"access_token": "new_access_token", "token_type": "Bearer", '
2715+
b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2716+
),
2717+
request=refresh_a,
2718+
)
2719+
request_a_post = await flow_a.asend(refresh_response)
2720+
assert request_a_post.headers.get("Authorization") == "Bearer new_access_token"
2721+
2722+
# Flow B: state already refreshed; Phase 1 sees valid token, skips Phase 2.
2723+
request_b = httpx.Request("POST", "https://api.example.com/v1/mcp")
2724+
flow_b = oauth_provider.async_auth_flow(request_b)
2725+
with anyio.fail_after(5):
2726+
request_b_yielded = await flow_b.__anext__()
2727+
assert request_b_yielded.method == "POST"
2728+
assert request_b_yielded.headers.get("Authorization") == "Bearer new_access_token"
2729+
2730+
with contextlib.suppress(StopAsyncIteration):
2731+
await flow_b.asend(httpx.Response(200, request=request_b_yielded))
2732+
with contextlib.suppress(StopAsyncIteration):
2733+
await flow_a.asend(httpx.Response(200, request=request_a_post))
2734+
2735+
2736+
@pytest.mark.anyio
2737+
async def test_refresh_with_failed_status_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
2738+
"""A non-2xx refresh response clears stored tokens and marks the provider
2739+
uninitialized so the next request triggers a full OAuth flow.
2740+
"""
2741+
oauth_provider.context.current_tokens = valid_tokens
2742+
oauth_provider.context.token_expiry_time = time.time() - 100
2743+
oauth_provider.context.client_info = OAuthClientInformationFull(
2744+
client_id="test_client_id",
2745+
client_secret="test_client_secret",
2746+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2747+
)
2748+
oauth_provider._initialized = True
2749+
2750+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2751+
flow = oauth_provider.async_auth_flow(request)
2752+
refresh_request = await flow.__anext__()
2753+
assert "grant_type=refresh_token" in refresh_request.read().decode()
2754+
2755+
# Refresh server returns 401.
2756+
refresh_response = httpx.Response(401, content=b'{"error": "invalid_grant"}', request=refresh_request)
2757+
with contextlib.suppress(StopAsyncIteration):
2758+
# After failed refresh, the flow proceeds to Phase 3 yielding the
2759+
# original request without a fresh Authorization header. We don't
2760+
# exercise the subsequent 401/full OAuth path here.
2761+
await flow.asend(refresh_response)
2762+
2763+
assert oauth_provider.context.current_tokens is None
2764+
2765+
2766+
@pytest.mark.anyio
2767+
async def test_refresh_with_invalid_json_clears_tokens(oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken):
2768+
"""A 200 refresh response with a malformed body clears stored tokens —
2769+
the pydantic ValidationError branch is taken.
2770+
"""
2771+
oauth_provider.context.current_tokens = valid_tokens
2772+
oauth_provider.context.token_expiry_time = time.time() - 100
2773+
oauth_provider.context.client_info = OAuthClientInformationFull(
2774+
client_id="test_client_id",
2775+
client_secret="test_client_secret",
2776+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2777+
)
2778+
oauth_provider._initialized = True
2779+
2780+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2781+
flow = oauth_provider.async_auth_flow(request)
2782+
refresh_request = await flow.__anext__()
2783+
2784+
# Body does not parse as OAuthToken.
2785+
refresh_response = httpx.Response(200, content=b"not json", request=refresh_request)
2786+
with contextlib.suppress(StopAsyncIteration):
2787+
await flow.asend(refresh_response)
2788+
2789+
assert oauth_provider.context.current_tokens is None
2790+
2791+
2792+
@pytest.mark.anyio
2793+
async def test_double_check_inside_refresh_lock_skips_second_refresh(
2794+
oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken, monkeypatch: pytest.MonkeyPatch
2795+
):
2796+
"""Exercise the double-check branch inside ``refresh_lock``: ``is_token_valid``
2797+
returns False in Phase 1 (= the flow decides to refresh) but True inside
2798+
the inner ``context.lock`` block (= another coroutine refreshed while we
2799+
were waiting on ``refresh_lock``). The flow must skip ``_refresh_token``
2800+
and proceed straight to Phase 3.
2801+
"""
2802+
oauth_provider.context.current_tokens = valid_tokens
2803+
oauth_provider.context.token_expiry_time = time.time() - 100 # expired
2804+
oauth_provider.context.client_info = OAuthClientInformationFull(
2805+
client_id="test_client_id",
2806+
client_secret="test_client_secret",
2807+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2808+
)
2809+
oauth_provider._initialized = True
2810+
2811+
# Toggle is_token_valid: False on the first call (Phase 1 decision),
2812+
# True on the second (double-check inside refresh_lock).
2813+
call_count = {"n": 0}
2814+
original_is_valid = oauth_provider.context.__class__.is_token_valid
2815+
2816+
def fake_is_token_valid(self: object) -> bool:
2817+
call_count["n"] += 1
2818+
if call_count["n"] == 1:
2819+
return False
2820+
# By the second call, "another coroutine" refreshed; reset token expiry
2821+
# so callers downstream see a valid token.
2822+
oauth_provider.context.token_expiry_time = time.time() + 1800
2823+
return True
2824+
2825+
monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", fake_is_token_valid)
2826+
try:
2827+
request = httpx.Request("POST", "https://api.example.com/v1/mcp")
2828+
flow = oauth_provider.async_auth_flow(request)
2829+
# No refresh yield is expected — the flow goes directly to its own
2830+
# request yield with the (now-valid) token header attached.
2831+
with anyio.fail_after(5):
2832+
yielded = await flow.__anext__()
2833+
assert yielded.method == "POST"
2834+
assert yielded.headers.get("Authorization") == "Bearer test_access_token"
27522835
with contextlib.suppress(StopAsyncIteration):
2753-
await flow_a.asend(httpx.Response(200, request=request_a_post))
2836+
await flow.asend(httpx.Response(200, request=yielded))
2837+
finally:
2838+
monkeypatch.setattr(oauth_provider.context.__class__, "is_token_valid", original_is_valid)

0 commit comments

Comments
 (0)