Skip to content

Commit ba80bbb

Browse files
committed
Propagate contextvars through anyio streams
TODO: - Update a recipe to show it working - Consider adding an integration test of some kind
1 parent 02c664c commit ba80bbb

File tree

4 files changed

+140
-40
lines changed

4 files changed

+140
-40
lines changed

src/mcp/client/streamable_http.py

Lines changed: 115 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@
4545
LAST_EVENT_ID = "last-event-id"
4646

4747
# Reconnection defaults
48-
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
48+
DEFAULT_RECONNECTION_DELAY_MS = (
49+
1000 # 1 second fallback when server doesn't provide retry
50+
)
4951
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
5052

5153

@@ -104,7 +106,10 @@ def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
104106

105107
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
106108
"""Check if the message is an initialized notification."""
107-
return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized"
109+
return (
110+
isinstance(message, JSONRPCNotification)
111+
and message.method == "notifications/initialized"
112+
)
108113

109114
def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
110115
"""Extract and store session ID from response headers."""
@@ -113,16 +118,23 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N
113118
self.session_id = new_session_id
114119
logger.info(f"Received session ID: {self.session_id}")
115120

116-
def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None:
121+
def _maybe_extract_protocol_version_from_message(
122+
self, message: JSONRPCMessage
123+
) -> None:
117124
"""Extract protocol version from initialization response message."""
118125
if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch
119126
try:
120127
# Parse the result as InitializeResult for type safety
121-
init_result = InitializeResult.model_validate(message.result, by_name=False)
128+
init_result = InitializeResult.model_validate(
129+
message.result, by_name=False
130+
)
122131
self.protocol_version = str(init_result.protocol_version)
123132
logger.info(f"Negotiated protocol version: {self.protocol_version}")
124133
except Exception: # pragma: no cover
125-
logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True)
134+
logger.warning(
135+
"Failed to parse initialization response as InitializeResult",
136+
exc_info=True,
137+
)
126138
logger.warning(f"Raw result: {message.result}")
127139

128140
async def _handle_sse_event(
@@ -150,7 +162,9 @@ async def _handle_sse_event(
150162
self._maybe_extract_protocol_version_from_message(message)
151163

152164
# If this is a response and we have original_request_id, replace it
153-
if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError):
165+
if original_request_id is not None and isinstance(
166+
message, JSONRPCResponse | JSONRPCError
167+
):
154168
message.id = original_request_id
155169

156170
session_message = SessionMessage(message)
@@ -167,8 +181,14 @@ async def _handle_sse_event(
167181
except Exception as exc: # pragma: no cover
168182
logger.exception("Error parsing SSE message")
169183
if original_request_id is not None:
170-
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}")
171-
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
184+
error_data = ErrorData(
185+
code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}"
186+
)
187+
error_msg = SessionMessage(
188+
JSONRPCError(
189+
jsonrpc="2.0", id=original_request_id, error=error_data
190+
)
191+
)
172192
await read_stream_writer.send(error_msg)
173193
return True
174194
await read_stream_writer.send(exc)
@@ -177,7 +197,9 @@ async def _handle_sse_event(
177197
logger.warning(f"Unknown SSE event: {sse.event}")
178198
return False
179199

180-
async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: StreamWriter) -> None:
200+
async def handle_get_stream(
201+
self, client: httpx.AsyncClient, read_stream_writer: StreamWriter
202+
) -> None:
181203
"""Handle GET stream for server-initiated messages with auto-reconnect."""
182204
last_event_id: str | None = None
183205
retry_interval_ms: int | None = None
@@ -192,7 +214,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
192214
if last_event_id:
193215
headers[LAST_EVENT_ID] = last_event_id
194216

195-
async with aconnect_sse(client, "GET", self.url, headers=headers) as event_source:
217+
async with aconnect_sse(
218+
client, "GET", self.url, headers=headers
219+
) as event_source:
196220
event_source.response.raise_for_status()
197221
logger.debug("GET SSE connection established")
198222

@@ -214,11 +238,17 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
214238
attempt += 1
215239

216240
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
217-
logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
241+
logger.debug(
242+
f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded"
243+
)
218244
return
219245

220246
# Wait before reconnecting
221-
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
247+
delay_ms = (
248+
retry_interval_ms
249+
if retry_interval_ms is not None
250+
else DEFAULT_RECONNECTION_DELAY_MS
251+
)
222252
logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
223253
await anyio.sleep(delay_ms / 1000.0)
224254

@@ -228,14 +258,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
228258
if ctx.metadata and ctx.metadata.resumption_token:
229259
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
230260
else:
231-
raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover
261+
raise ResumptionError(
262+
"Resumption request requires a resumption token"
263+
) # pragma: no cover
232264

233265
# Extract original request ID to map responses
234266
original_request_id = None
235267
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
236268
original_request_id = ctx.session_message.message.id
237269

238-
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
270+
async with aconnect_sse(
271+
ctx.client, "GET", self.url, headers=headers
272+
) as event_source:
239273
event_source.response.raise_for_status()
240274
logger.debug("Resumption GET SSE connection established")
241275

@@ -268,8 +302,12 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
268302

269303
if response.status_code == 404: # pragma: no branch
270304
if isinstance(message, JSONRPCRequest): # pragma: no branch
271-
error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated")
272-
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
305+
error_data = ErrorData(
306+
code=INVALID_REQUEST, message="Session terminated"
307+
)
308+
session_message = SessionMessage(
309+
JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)
310+
)
273311
await ctx.read_stream_writer.send(session_message)
274312
return
275313

@@ -283,14 +321,22 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
283321
content_type = response.headers.get("content-type", "").lower()
284322
if content_type.startswith("application/json"):
285323
await self._handle_json_response(
286-
response, ctx.read_stream_writer, is_initialization, request_id=message.id
324+
response,
325+
ctx.read_stream_writer,
326+
is_initialization,
327+
request_id=message.id,
287328
)
288329
elif content_type.startswith("text/event-stream"):
289330
await self._handle_sse_response(response, ctx, is_initialization)
290331
else:
291332
logger.error(f"Unexpected content type: {content_type}")
292-
error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}")
293-
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
333+
error_data = ErrorData(
334+
code=INVALID_REQUEST,
335+
message=f"Unexpected content type: {content_type}",
336+
)
337+
error_msg = SessionMessage(
338+
JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)
339+
)
294340
await ctx.read_stream_writer.send(error_msg)
295341

296342
async def _handle_json_response(
@@ -314,8 +360,12 @@ async def _handle_json_response(
314360
await read_stream_writer.send(session_message)
315361
except (httpx.StreamError, ValidationError) as exc:
316362
logger.exception("Error parsing JSON response")
317-
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}")
318-
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
363+
error_data = ErrorData(
364+
code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}"
365+
)
366+
error_msg = SessionMessage(
367+
JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data)
368+
)
319369
await read_stream_writer.send(error_msg)
320370

321371
async def _handle_sse_response(
@@ -348,7 +398,11 @@ async def _handle_sse_response(
348398
sse,
349399
ctx.read_stream_writer,
350400
original_request_id=original_request_id,
351-
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
401+
resumption_callback=(
402+
ctx.metadata.on_resumption_token_update
403+
if ctx.metadata
404+
else None
405+
),
352406
is_initialization=is_initialization,
353407
)
354408
# If the SSE event indicates completion, like returning respose/error
@@ -374,11 +428,17 @@ async def _handle_reconnection(
374428
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
375429
# Bail if max retries exceeded
376430
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
377-
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
431+
logger.debug(
432+
f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded"
433+
)
378434
return
379435

380436
# Always wait - use server value or default
381-
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
437+
delay_ms = (
438+
retry_interval_ms
439+
if retry_interval_ms is not None
440+
else DEFAULT_RECONNECTION_DELAY_MS
441+
)
382442
await anyio.sleep(delay_ms / 1000.0)
383443

384444
headers = self._prepare_headers()
@@ -390,7 +450,9 @@ async def _handle_reconnection(
390450
original_request_id = ctx.session_message.message.id
391451

392452
try:
393-
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
453+
async with aconnect_sse(
454+
ctx.client, "GET", self.url, headers=headers
455+
) as event_source:
394456
event_source.response.raise_for_status()
395457
logger.info("Reconnected to SSE stream")
396458

@@ -408,19 +470,25 @@ async def _handle_reconnection(
408470
sse,
409471
ctx.read_stream_writer,
410472
original_request_id,
411-
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
473+
ctx.metadata.on_resumption_token_update
474+
if ctx.metadata
475+
else None,
412476
)
413477
if is_complete:
414478
await event_source.response.aclose()
415479
return
416480

417481
# Stream ended again without response - reconnect again (reset attempt counter)
418482
logger.info("SSE stream disconnected, reconnecting...")
419-
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
483+
await self._handle_reconnection(
484+
ctx, reconnect_last_event_id, reconnect_retry_ms, 0
485+
)
420486
except Exception as e: # pragma: no cover
421487
logger.debug(f"Reconnection failed: {e}")
422488
# Try to reconnect again if we still have an event ID
423-
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
489+
await self._handle_reconnection(
490+
ctx, last_event_id, retry_interval_ms, attempt + 1
491+
)
424492

425493
async def post_writer(
426494
self,
@@ -434,7 +502,8 @@ async def post_writer(
434502
"""Handle writing requests to the server."""
435503
try:
436504
async with write_stream_reader:
437-
async for session_message in write_stream_reader:
505+
506+
async def handle_message(session_message: SessionMessage) -> None:
438507
message = session_message.message
439508
metadata = (
440509
session_message.metadata
@@ -471,8 +540,14 @@ async def handle_request_async():
471540
else:
472541
await handle_request_async()
473542

474-
except Exception: # pragma: lax no cover
475-
logger.exception("Error in post_writer")
543+
async for session_message in write_stream_reader:
544+
async with anyio.create_task_group() as tg_local:
545+
session_message.context.run(
546+
tg_local.start_soon, handle_message, session_message
547+
)
548+
549+
except Exception:
550+
logger.exception("Error in post_writer") # pragma: no cover
476551
finally:
477552
await read_stream_writer.aclose()
478553
await write_stream.aclose()
@@ -526,8 +601,12 @@ async def streamable_http_client(
526601
Example:
527602
See examples/snippets/clients/ for usage patterns.
528603
"""
529-
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
530-
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
604+
read_stream_writer, read_stream = anyio.create_memory_object_stream[
605+
SessionMessage | Exception
606+
](0)
607+
write_stream, write_stream_reader = anyio.create_memory_object_stream[
608+
SessionMessage
609+
](0)
531610

532611
# Determine if we need to create and manage the client
533612
client_provided = http_client is not None
@@ -549,7 +628,9 @@ async def streamable_http_client(
549628
await stack.enter_async_context(client)
550629

551630
def start_get_stream() -> None:
552-
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
631+
tg.start_soon(
632+
transport.handle_get_stream, client, read_stream_writer
633+
)
553634

554635
tg.start_soon(
555636
transport.post_writer,

src/mcp/server/lowlevel/server.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,14 @@ async def run(
683683
async for message in session.incoming_messages:
684684
logger.debug("Received message: %s", message)
685685

686-
tg.start_soon(
686+
if isinstance(message, RequestResponder) and message.context is not None:
687+
logger.debug("Got a context to propagate, %s", message.context)
688+
context = message.context
689+
else:
690+
context = contextvars.copy_context()
691+
692+
context.run(
693+
tg.start_soon,
687694
self._handle_message,
688695
message,
689696
session,

src/mcp/shared/message.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
to support transport-specific features like resumability.
55
"""
66

7+
import contextvars
78
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
9+
from dataclasses import dataclass, field
910

1011
from mcp.types import JSONRPCMessage, RequestId
1112

@@ -46,4 +47,5 @@ class SessionMessage:
4647
"""A message with specific metadata for transport-specific features."""
4748

4849
message: JSONRPCMessage
50+
context: contextvars.Context = field(default_factory=contextvars.copy_context)
4951
metadata: MessageMetadata = None

0 commit comments

Comments
 (0)