@@ -2640,114 +2640,199 @@ async def callback_handler() -> tuple[str, str | None]:
26402640 pass
26412641
26422642
2643- class TestConcurrentRequestsDoNotDeadlock :
2644- """Regression tests for #1326.
2643+ @pytest .mark .anyio
2644+ async def test_concurrent_request_not_blocked_by_pending_long_running_request (
2645+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2646+ ):
2647+ """Regression for #1326: a second request reaches its yield while the
2648+ first is still suspended (= simulating a server-side long-poll).
26452649
2646- Ensures that ``OAuthClientProvider. async_auth_flow`` does not serialize
2647- concurrent unrelated requests behind a long-running one (e.g. GET SSE
2648- long-poll). The fix narrows ``context.lock`` to state mutation only; the
2649- actual ``yield request`` runs outside any lock.
2650+ Before the lock-scope fix, `` async_auth_flow`` held ``context.lock``
2651+ across ``yield request``. A GET SSE long-poll would therefore hold the
2652+ lock for the entire SSE lifetime, blocking any concurrent request
2653+ waiting on the same provider's lock.
26502654 """
2655+ # Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2656+ # flow (Phase 4) is triggered — we exercise the steady-state Phase 3
2657+ # yield path that previously held the lock.
2658+ oauth_provider .context .current_tokens = valid_tokens
2659+ oauth_provider .context .token_expiry_time = time .time () + 1800
2660+ oauth_provider .context .client_info = OAuthClientInformationFull (
2661+ client_id = "test_client_id" ,
2662+ client_secret = "test_client_secret" ,
2663+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2664+ )
2665+ oauth_provider ._initialized = True
26512666
2652- @pytest .mark .anyio
2653- async def test_concurrent_request_not_blocked_by_pending_long_running_request (
2654- self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2655- ):
2656- """A second request must reach its yield while the first is still
2657- suspended at its yield (= simulating a server-side long-poll).
2658-
2659- Before this fix, ``async_auth_flow`` held ``context.lock`` across
2660- ``yield request``. A GET SSE long-poll would therefore hold the lock
2661- for the entire SSE lifetime, blocking any concurrent request waiting
2662- on the same provider's lock and producing a multi-second stall.
2663- """
2664- # Set up valid tokens so neither refresh (Phase 2) nor full OAuth
2665- # flow (Phase 4) is triggered — we want to exercise the steady-state
2666- # Phase 3 yield path that previously held the lock.
2667- oauth_provider .context .current_tokens = valid_tokens
2668- oauth_provider .context .token_expiry_time = time .time () + 1800
2669- oauth_provider .context .client_info = OAuthClientInformationFull (
2670- client_id = "test_client_id" ,
2671- client_secret = "test_client_secret" ,
2672- redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2673- )
2674- oauth_provider ._initialized = True
2667+ # Flow 1: drive to yield, then leave suspended (simulating long-poll).
2668+ slow_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2669+ slow_flow = oauth_provider .async_auth_flow (slow_request )
2670+ yielded_slow = await slow_flow .__anext__ ()
2671+ assert yielded_slow .headers .get ("Authorization" ) == "Bearer test_access_token"
26752672
2676- # Flow 1: simulate a slow request. Drive it to its yield, then
2677- # deliberately do not send a response — it stays suspended at the
2678- # yield, just like a GET SSE long-poll waiting for the next event.
2679- slow_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2680- slow_flow = oauth_provider .async_auth_flow (slow_request )
2681- yielded_slow = await slow_flow .__anext__ ()
2682- assert yielded_slow .headers .get ("Authorization" ) == "Bearer test_access_token"
2683-
2684- # Flow 2: a concurrent request on the same provider. With the fix,
2685- # context.lock is not held during Flow 1's yield, so Flow 2 reaches
2686- # its yield almost immediately. Without the fix, this would block
2687- # until Flow 1 receives a response — i.e., it would hit the timeout.
2688- fast_request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2689- fast_flow = oauth_provider .async_auth_flow (fast_request )
2690- with anyio .fail_after (1.0 ):
2691- yielded_fast = await fast_flow .__anext__ ()
2692- assert yielded_fast .headers .get ("Authorization" ) == "Bearer test_access_token"
2693-
2694- # Clean up both generators in deterministic order.
2695- with contextlib .suppress (StopAsyncIteration ):
2696- await fast_flow .asend (httpx .Response (200 , request = yielded_fast ))
2697- with contextlib .suppress (StopAsyncIteration ):
2698- await slow_flow .asend (httpx .Response (200 , request = yielded_slow ))
2673+ # Flow 2: concurrent request. With the fix this reaches its yield
2674+ # immediately; without the fix it would block on context.lock.
2675+ fast_request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2676+ fast_flow = oauth_provider .async_auth_flow (fast_request )
2677+ with anyio .fail_after (5 ):
2678+ yielded_fast = await fast_flow .__anext__ ()
2679+ assert yielded_fast .headers .get ("Authorization" ) == "Bearer test_access_token"
26992680
2700- @pytest .mark .anyio
2701- async def test_concurrent_token_refresh_is_single_flight (
2702- self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2703- ):
2704- """When concurrent requests both observe an expired token, only one
2705- refresh request is sent: ``refresh_lock`` provides single-flight
2706- semantics so the second waiter re-checks state and proceeds without
2707- re-triggering refresh.
2708- """
2709- # Mark the token as expired so the next auth_flow run enters Phase 2.
2710- oauth_provider .context .current_tokens = valid_tokens
2711- oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2712- oauth_provider .context .client_info = OAuthClientInformationFull (
2713- client_id = "test_client_id" ,
2714- client_secret = "test_client_secret" ,
2715- redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2716- )
2717- oauth_provider ._initialized = True
2681+ with contextlib .suppress (StopAsyncIteration ):
2682+ await fast_flow .asend (httpx .Response (200 , request = yielded_fast ))
2683+ with contextlib .suppress (StopAsyncIteration ):
2684+ await slow_flow .asend (httpx .Response (200 , request = yielded_slow ))
27182685
2719- # Flow A: drive it to the refresh yield and suspend there.
2720- request_a = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2721- flow_a = oauth_provider .async_auth_flow (request_a )
2722- refresh_a = await flow_a .__anext__ ()
2723- assert "grant_type=refresh_token" in refresh_a .read ().decode ()
27242686
2725- # Complete Flow A's refresh with a fresh token.
2726- refresh_response = httpx .Response (
2727- 200 ,
2728- content = (
2729- b'{"access_token": "new_access_token", "token_type": "Bearer", '
2730- b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2731- ),
2732- request = refresh_a ,
2733- )
2734- request_a_post = await flow_a .asend (refresh_response )
2735- assert request_a_post .headers .get ("Authorization" ) == "Bearer new_access_token"
2736-
2737- # Flow B starts after Flow A's refresh has completed. Because token
2738- # state was updated under context.lock, Flow B observes the fresh
2739- # token in Phase 1, skips Phase 2 entirely, and reaches its yield
2740- # directly. No second refresh request is sent.
2741- request_b = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2742- flow_b = oauth_provider .async_auth_flow (request_b )
2743- with anyio .fail_after (1.0 ):
2744- request_b_yielded = await flow_b .__anext__ ()
2745- assert request_b_yielded .headers .get ("Authorization" ) == "Bearer new_access_token"
2746- # Confirm Flow B yielded the original POST, not a refresh request.
2747- assert request_b_yielded .method == "POST"
2748-
2749- # Clean up.
2750- with contextlib .suppress (StopAsyncIteration ):
2751- await flow_b .asend (httpx .Response (200 , request = request_b_yielded ))
2687+ @pytest .mark .anyio
2688+ async def test_refresh_lock_double_check_skips_redundant_refresh (
2689+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
2690+ ):
2691+ """Two flows enter Phase 2 with an expired token. After the first
2692+ completes a refresh, the second observes the fresh token via the
2693+ Phase 2 double-check inside ``refresh_lock`` (or directly in Phase 1
2694+ if it arrives late) and skips its own refresh.
2695+ """
2696+ oauth_provider .context .current_tokens = valid_tokens
2697+ oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2698+ oauth_provider .context .client_info = OAuthClientInformationFull (
2699+ client_id = "test_client_id" ,
2700+ client_secret = "test_client_secret" ,
2701+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2702+ )
2703+ oauth_provider ._initialized = True
2704+
2705+ # Flow A: drive to refresh yield, then complete refresh.
2706+ request_a = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
2707+ flow_a = oauth_provider .async_auth_flow (request_a )
2708+ refresh_a = await flow_a .__anext__ ()
2709+ assert "grant_type=refresh_token" in refresh_a .read ().decode ()
2710+
2711+ refresh_response = httpx .Response (
2712+ 200 ,
2713+ content = (
2714+ b'{"access_token": "new_access_token", "token_type": "Bearer", '
2715+ b'"expires_in": 3600, "refresh_token": "new_refresh_token"}'
2716+ ),
2717+ request = refresh_a ,
2718+ )
2719+ request_a_post = await flow_a .asend (refresh_response )
2720+ assert request_a_post .headers .get ("Authorization" ) == "Bearer new_access_token"
2721+
2722+ # Flow B: state already refreshed; Phase 1 sees valid token, skips Phase 2.
2723+ request_b = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2724+ flow_b = oauth_provider .async_auth_flow (request_b )
2725+ with anyio .fail_after (5 ):
2726+ request_b_yielded = await flow_b .__anext__ ()
2727+ assert request_b_yielded .method == "POST"
2728+ assert request_b_yielded .headers .get ("Authorization" ) == "Bearer new_access_token"
2729+
2730+ with contextlib .suppress (StopAsyncIteration ):
2731+ await flow_b .asend (httpx .Response (200 , request = request_b_yielded ))
2732+ with contextlib .suppress (StopAsyncIteration ):
2733+ await flow_a .asend (httpx .Response (200 , request = request_a_post ))
2734+
2735+
2736+ @pytest .mark .anyio
2737+ async def test_refresh_with_failed_status_clears_tokens (oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
2738+ """A non-2xx refresh response clears stored tokens and marks the provider
2739+ uninitialized so the next request triggers a full OAuth flow.
2740+ """
2741+ oauth_provider .context .current_tokens = valid_tokens
2742+ oauth_provider .context .token_expiry_time = time .time () - 100
2743+ oauth_provider .context .client_info = OAuthClientInformationFull (
2744+ client_id = "test_client_id" ,
2745+ client_secret = "test_client_secret" ,
2746+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2747+ )
2748+ oauth_provider ._initialized = True
2749+
2750+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2751+ flow = oauth_provider .async_auth_flow (request )
2752+ refresh_request = await flow .__anext__ ()
2753+ assert "grant_type=refresh_token" in refresh_request .read ().decode ()
2754+
2755+ # Refresh server returns 401.
2756+ refresh_response = httpx .Response (401 , content = b'{"error": "invalid_grant"}' , request = refresh_request )
2757+ with contextlib .suppress (StopAsyncIteration ):
2758+ # After failed refresh, the flow proceeds to Phase 3 yielding the
2759+ # original request without a fresh Authorization header. We don't
2760+ # exercise the subsequent 401/full OAuth path here.
2761+ await flow .asend (refresh_response )
2762+
2763+ assert oauth_provider .context .current_tokens is None
2764+
2765+
2766+ @pytest .mark .anyio
2767+ async def test_refresh_with_invalid_json_clears_tokens (oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken ):
2768+ """A 200 refresh response with a malformed body clears stored tokens —
2769+ the pydantic ValidationError branch is taken.
2770+ """
2771+ oauth_provider .context .current_tokens = valid_tokens
2772+ oauth_provider .context .token_expiry_time = time .time () - 100
2773+ oauth_provider .context .client_info = OAuthClientInformationFull (
2774+ client_id = "test_client_id" ,
2775+ client_secret = "test_client_secret" ,
2776+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2777+ )
2778+ oauth_provider ._initialized = True
2779+
2780+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2781+ flow = oauth_provider .async_auth_flow (request )
2782+ refresh_request = await flow .__anext__ ()
2783+
2784+ # Body does not parse as OAuthToken.
2785+ refresh_response = httpx .Response (200 , content = b"not json" , request = refresh_request )
2786+ with contextlib .suppress (StopAsyncIteration ):
2787+ await flow .asend (refresh_response )
2788+
2789+ assert oauth_provider .context .current_tokens is None
2790+
2791+
2792+ @pytest .mark .anyio
2793+ async def test_double_check_inside_refresh_lock_skips_second_refresh (
2794+ oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken , monkeypatch : pytest .MonkeyPatch
2795+ ):
2796+ """Exercise the double-check branch inside ``refresh_lock``: ``is_token_valid``
2797+ returns False in Phase 1 (= the flow decides to refresh) but True inside
2798+ the inner ``context.lock`` block (= another coroutine refreshed while we
2799+ were waiting on ``refresh_lock``). The flow must skip ``_refresh_token``
2800+ and proceed straight to Phase 3.
2801+ """
2802+ oauth_provider .context .current_tokens = valid_tokens
2803+ oauth_provider .context .token_expiry_time = time .time () - 100 # expired
2804+ oauth_provider .context .client_info = OAuthClientInformationFull (
2805+ client_id = "test_client_id" ,
2806+ client_secret = "test_client_secret" ,
2807+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2808+ )
2809+ oauth_provider ._initialized = True
2810+
2811+ # Toggle is_token_valid: False on the first call (Phase 1 decision),
2812+ # True on the second (double-check inside refresh_lock).
2813+ call_count = {"n" : 0 }
2814+ original_is_valid = oauth_provider .context .__class__ .is_token_valid
2815+
2816+ def fake_is_token_valid (self : object ) -> bool :
2817+ call_count ["n" ] += 1
2818+ if call_count ["n" ] == 1 :
2819+ return False
2820+ # By the second call, "another coroutine" refreshed; reset token expiry
2821+ # so callers downstream see a valid token.
2822+ oauth_provider .context .token_expiry_time = time .time () + 1800
2823+ return True
2824+
2825+ monkeypatch .setattr (oauth_provider .context .__class__ , "is_token_valid" , fake_is_token_valid )
2826+ try :
2827+ request = httpx .Request ("POST" , "https://api.example.com/v1/mcp" )
2828+ flow = oauth_provider .async_auth_flow (request )
2829+ # No refresh yield is expected — the flow goes directly to its own
2830+ # request yield with the (now-valid) token header attached.
2831+ with anyio .fail_after (5 ):
2832+ yielded = await flow .__anext__ ()
2833+ assert yielded .method == "POST"
2834+ assert yielded .headers .get ("Authorization" ) == "Bearer test_access_token"
27522835 with contextlib .suppress (StopAsyncIteration ):
2753- await flow_a .asend (httpx .Response (200 , request = request_a_post ))
2836+ await flow .asend (httpx .Response (200 , request = yielded ))
2837+ finally :
2838+ monkeypatch .setattr (oauth_provider .context .__class__ , "is_token_valid" , original_is_valid )
0 commit comments