Skip to content

Conversation

@aabmass
Copy link

@aabmass aabmass commented Feb 4, 2026

Fixes #1969
Part of #421

This PR propagates contextvars.Context through internal anyio memory streams and restores it in long running async tasks, so that they run with the correct context. Works for in memory, streamable http, and sse transports.

  • For MCP clients, this captures the caller's active context and restores it before using the transport.
  • For MCP servers, it captures the active context from the transport and restores it before using the Server's handlers

Motivation and Context

The goal is to make contextvars "work" with the MCP SDK as users expect. The main motivation is OpenTelemetry #421, but there are other use cases e.g. per-request auth, stdlib decimal, starlette-context, etc.

How Has This Been Tested?

Added several integration tests which are failing at main. I also manually tested e2e with the some of the samples by using OpenTelemetry httpx instrumentation (I can share screenshots).

Tests failures at main

═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════ inline-snapshot ═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════════
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── Fix snapshots ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
╭───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── tests/test_context_propagation.py ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ @@ -83,7 +83,7 @@                                                                                                                                                                                                                                                                                                                                           │
│                                                                                                                                                                                                                                                                                                                                                             │
│              result = await client.call_tool(name="my_tool")                                                                                                                                                                                                                                                                                                │
│                                                                                                                                                                                                                                                                                                                                                             │
│              assert isinstance(result, types.CallToolResult)                                                                                                                                                                                                                                                                                                │
│ -            assert result.content == snapshot([types.TextContent(text="client_value")])                                                                                                                                                                                                                                                                    │
│ +            assert result.content == snapshot([types.TextContent(text="initial")])                                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                                                                                                             │
│                                                                                                                                                                                                                                                                                                                                                             │
│  class ContextVarMiddleware(BaseHTTPMiddleware):  # pragma: lax no cover                                                                                                                                                                                                                                                                                    │
│ @@ -127,7 +127,7 @@                                                                                                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                                                                                                             │
│          ):                                                                                                                                                                                                                                                                                                                                                 │
│              transport.client_header_value = "expected_value"                                                                                                                                                                                                                                                                                               │
│              result = await client.call_tool("my_tool")                                                                                                                                                                                                                                                                                                     │
│ -            assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])                                                                                                                                                                                                                                    │
│ +            assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=initial")])                                                                                                                                                                                                                                           │
│                                                                                                                                                                                                                                                                                                                                                             │
│                                                                                                                                                                                                                                                                                                                                                             │
│  @pytest.mark.anyio                                                                                                                                                                                                                                                                                                                                         │
│ @@ -142,7 +142,7 @@                                                                                                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                                                                                                             │
│          ):                                                                                                                                                                                                                                                                                                                                                 │
│              with set_test_contextvar("client_value_list"):                                                                                                                                                                                                                                                                                                 │
│                  await client.list_tools()                                                                                                                                                                                                                                                                                                                  │
│ -                assert transport.captured_context_var == snapshot("client_value_list")                                                                                                                                                                                                                                                                     │
│ +                assert transport.captured_context_var == snapshot("initial")                                                                                                                                                                                                                                                                               │
│                                                                                                                                                                                                                                                                                                                                                             │
│              with set_test_contextvar("client_value_call_tool"):                                                                                                                                                                                                                                                                                            │
│                  await client.call_tool("my_tool")                                                                                                                                                                                                                                                                                                          │
│ @@ -167,7 +167,7 @@                                                                                                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                                                                                                             │
│          ) as client:                                                                                                                                                                                                                                                                                                                                       │
│              transport.client_header_value = "expected_value"                                                                                                                                                                                                                                                                                               │
│              result = await client.call_tool("my_tool")                                                                                                                                                                                                                                                                                                     │
│ -            assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=expected_value")])                                                                                                                                                                                                                                    │
│ +            assert result.content == snapshot([types.TextContent(text="from middleware CLIENT_HEADER=initial")])                                                                                                                                                                                                                                           │
│                                                                                                                                                                                                                                                                                                                                                             │
│                                                                                                                                                                                                                                                                                                                                                             │
│  @pytest.mark.anyio                                                                                                                                                                                                                                                                                                                                         │
│ @@ -187,7 +187,7 @@                                                                                                                                                                                                                                                                                                                                         │
│                                                                                                                                                                                                                                                                                                                                                             │
│          ) as client:                                                                                                                                                                                                                                                                                                                                       │
│              with set_test_contextvar("client_value_list"):                                                                                                                                                                                                                                                                                                 │
│                  await client.list_tools()                                                                                                                                                                                                                                                                                                                  │
│ -                assert transport.captured_context_var == snapshot("client_value_list")                                                                                                                                                                                                                                                                     │
│ +                assert transport.captured_context_var == snapshot("initial")                                                                                                                                                                                                                                                                               │
│                                                                                                                                                                                                                                                                                                                                                             │
│              with set_test_contextvar("client_value_call_tool"):                                                                                                                                                                                                                                                                                            │
│                  await client.call_tool("my_tool")                                                                                                                                                                                                                                                                                                          │
╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
These changes are not applied.
Use --inline-snapshot=fix to apply them, or use the interactive mode with --inline-snapshot=review

Breaking Changes

No

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation update

Checklist

  • I have read the MCP Documentation
  • My code follows the repository's style guidelines
  • New and existing tests pass locally
  • I have added appropriate error handling
  • I have added or updated documentation as needed

Additional context

N/A

@aabmass aabmass force-pushed the propagate-contextvars branch 6 times, most recently from d4b03e2 to 8cdb649 Compare February 6, 2026 03:44
@aabmass aabmass force-pushed the propagate-contextvars branch 2 times, most recently from e3adab2 to ba80bbb Compare February 9, 2026 14:53
See if tests still fail
This fix covers in memory, streamable http, and sse transports.
@aabmass aabmass force-pushed the propagate-contextvars branch from ba80bbb to b936a62 Compare February 10, 2026 03:27
@aabmass aabmass marked this pull request as ready for review February 10, 2026 03:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Propagate ContextVars to Transport Layer in MCP Clients

1 participant