4545LAST_EVENT_ID = "last-event-id"
4646
4747# Reconnection defaults
48- DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
48+ DEFAULT_RECONNECTION_DELAY_MS = (
49+ 1000 # 1 second fallback when server doesn't provide retry
50+ )
4951MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
5052
5153
@@ -104,7 +106,10 @@ def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
104106
105107 def _is_initialized_notification (self , message : JSONRPCMessage ) -> bool :
106108 """Check if the message is an initialized notification."""
107- return isinstance (message , JSONRPCNotification ) and message .method == "notifications/initialized"
109+ return (
110+ isinstance (message , JSONRPCNotification )
111+ and message .method == "notifications/initialized"
112+ )
108113
109114 def _maybe_extract_session_id_from_response (self , response : httpx .Response ) -> None :
110115 """Extract and store session ID from response headers."""
@@ -113,16 +118,23 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N
113118 self .session_id = new_session_id
114119 logger .info (f"Received session ID: { self .session_id } " )
115120
116- def _maybe_extract_protocol_version_from_message (self , message : JSONRPCMessage ) -> None :
121+ def _maybe_extract_protocol_version_from_message (
122+ self , message : JSONRPCMessage
123+ ) -> None :
117124 """Extract protocol version from initialization response message."""
118125 if isinstance (message , JSONRPCResponse ) and message .result : # pragma: no branch
119126 try :
120127 # Parse the result as InitializeResult for type safety
121- init_result = InitializeResult .model_validate (message .result , by_name = False )
128+ init_result = InitializeResult .model_validate (
129+ message .result , by_name = False
130+ )
122131 self .protocol_version = str (init_result .protocol_version )
123132 logger .info (f"Negotiated protocol version: { self .protocol_version } " )
124133 except Exception : # pragma: no cover
125- logger .warning ("Failed to parse initialization response as InitializeResult" , exc_info = True )
134+ logger .warning (
135+ "Failed to parse initialization response as InitializeResult" ,
136+ exc_info = True ,
137+ )
126138 logger .warning (f"Raw result: { message .result } " )
127139
128140 async def _handle_sse_event (
@@ -150,7 +162,9 @@ async def _handle_sse_event(
150162 self ._maybe_extract_protocol_version_from_message (message )
151163
152164 # If this is a response and we have original_request_id, replace it
153- if original_request_id is not None and isinstance (message , JSONRPCResponse | JSONRPCError ):
165+ if original_request_id is not None and isinstance (
166+ message , JSONRPCResponse | JSONRPCError
167+ ):
154168 message .id = original_request_id
155169
156170 session_message = SessionMessage (message )
@@ -167,8 +181,14 @@ async def _handle_sse_event(
167181 except Exception as exc : # pragma: no cover
168182 logger .exception ("Error parsing SSE message" )
169183 if original_request_id is not None :
170- error_data = ErrorData (code = PARSE_ERROR , message = f"Failed to parse SSE message: { exc } " )
171- error_msg = SessionMessage (JSONRPCError (jsonrpc = "2.0" , id = original_request_id , error = error_data ))
184+ error_data = ErrorData (
185+ code = PARSE_ERROR , message = f"Failed to parse SSE message: { exc } "
186+ )
187+ error_msg = SessionMessage (
188+ JSONRPCError (
189+ jsonrpc = "2.0" , id = original_request_id , error = error_data
190+ )
191+ )
172192 await read_stream_writer .send (error_msg )
173193 return True
174194 await read_stream_writer .send (exc )
@@ -177,7 +197,9 @@ async def _handle_sse_event(
177197 logger .warning (f"Unknown SSE event: { sse .event } " )
178198 return False
179199
180- async def handle_get_stream (self , client : httpx .AsyncClient , read_stream_writer : StreamWriter ) -> None :
200+ async def handle_get_stream (
201+ self , client : httpx .AsyncClient , read_stream_writer : StreamWriter
202+ ) -> None :
181203 """Handle GET stream for server-initiated messages with auto-reconnect."""
182204 last_event_id : str | None = None
183205 retry_interval_ms : int | None = None
@@ -192,7 +214,9 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
192214 if last_event_id :
193215 headers [LAST_EVENT_ID ] = last_event_id
194216
195- async with aconnect_sse (client , "GET" , self .url , headers = headers ) as event_source :
217+ async with aconnect_sse (
218+ client , "GET" , self .url , headers = headers
219+ ) as event_source :
196220 event_source .response .raise_for_status ()
197221 logger .debug ("GET SSE connection established" )
198222
@@ -214,11 +238,17 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
214238 attempt += 1
215239
216240 if attempt >= MAX_RECONNECTION_ATTEMPTS : # pragma: no cover
217- logger .debug (f"GET stream max reconnection attempts ({ MAX_RECONNECTION_ATTEMPTS } ) exceeded" )
241+ logger .debug (
242+ f"GET stream max reconnection attempts ({ MAX_RECONNECTION_ATTEMPTS } ) exceeded"
243+ )
218244 return
219245
220246 # Wait before reconnecting
221- delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
247+ delay_ms = (
248+ retry_interval_ms
249+ if retry_interval_ms is not None
250+ else DEFAULT_RECONNECTION_DELAY_MS
251+ )
222252 logger .info (f"GET stream disconnected, reconnecting in { delay_ms } ms..." )
223253 await anyio .sleep (delay_ms / 1000.0 )
224254
@@ -228,14 +258,18 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
228258 if ctx .metadata and ctx .metadata .resumption_token :
229259 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
230260 else :
231- raise ResumptionError ("Resumption request requires a resumption token" ) # pragma: no cover
261+ raise ResumptionError (
262+ "Resumption request requires a resumption token"
263+ ) # pragma: no cover
232264
233265 # Extract original request ID to map responses
234266 original_request_id = None
235267 if isinstance (ctx .session_message .message , JSONRPCRequest ): # pragma: no branch
236268 original_request_id = ctx .session_message .message .id
237269
238- async with aconnect_sse (ctx .client , "GET" , self .url , headers = headers ) as event_source :
270+ async with aconnect_sse (
271+ ctx .client , "GET" , self .url , headers = headers
272+ ) as event_source :
239273 event_source .response .raise_for_status ()
240274 logger .debug ("Resumption GET SSE connection established" )
241275
@@ -268,8 +302,12 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
268302
269303 if response .status_code == 404 : # pragma: no branch
270304 if isinstance (message , JSONRPCRequest ): # pragma: no branch
271- error_data = ErrorData (code = INVALID_REQUEST , message = "Session terminated" )
272- session_message = SessionMessage (JSONRPCError (jsonrpc = "2.0" , id = message .id , error = error_data ))
305+ error_data = ErrorData (
306+ code = INVALID_REQUEST , message = "Session terminated"
307+ )
308+ session_message = SessionMessage (
309+ JSONRPCError (jsonrpc = "2.0" , id = message .id , error = error_data )
310+ )
273311 await ctx .read_stream_writer .send (session_message )
274312 return
275313
@@ -283,14 +321,22 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
283321 content_type = response .headers .get ("content-type" , "" ).lower ()
284322 if content_type .startswith ("application/json" ):
285323 await self ._handle_json_response (
286- response , ctx .read_stream_writer , is_initialization , request_id = message .id
324+ response ,
325+ ctx .read_stream_writer ,
326+ is_initialization ,
327+ request_id = message .id ,
287328 )
288329 elif content_type .startswith ("text/event-stream" ):
289330 await self ._handle_sse_response (response , ctx , is_initialization )
290331 else :
291332 logger .error (f"Unexpected content type: { content_type } " )
292- error_data = ErrorData (code = INVALID_REQUEST , message = f"Unexpected content type: { content_type } " )
293- error_msg = SessionMessage (JSONRPCError (jsonrpc = "2.0" , id = message .id , error = error_data ))
333+ error_data = ErrorData (
334+ code = INVALID_REQUEST ,
335+ message = f"Unexpected content type: { content_type } " ,
336+ )
337+ error_msg = SessionMessage (
338+ JSONRPCError (jsonrpc = "2.0" , id = message .id , error = error_data )
339+ )
294340 await ctx .read_stream_writer .send (error_msg )
295341
296342 async def _handle_json_response (
@@ -314,8 +360,12 @@ async def _handle_json_response(
314360 await read_stream_writer .send (session_message )
315361 except (httpx .StreamError , ValidationError ) as exc :
316362 logger .exception ("Error parsing JSON response" )
317- error_data = ErrorData (code = PARSE_ERROR , message = f"Failed to parse JSON response: { exc } " )
318- error_msg = SessionMessage (JSONRPCError (jsonrpc = "2.0" , id = request_id , error = error_data ))
363+ error_data = ErrorData (
364+ code = PARSE_ERROR , message = f"Failed to parse JSON response: { exc } "
365+ )
366+ error_msg = SessionMessage (
367+ JSONRPCError (jsonrpc = "2.0" , id = request_id , error = error_data )
368+ )
319369 await read_stream_writer .send (error_msg )
320370
321371 async def _handle_sse_response (
@@ -348,7 +398,11 @@ async def _handle_sse_response(
348398 sse ,
349399 ctx .read_stream_writer ,
350400 original_request_id = original_request_id ,
351- resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
401+ resumption_callback = (
402+ ctx .metadata .on_resumption_token_update
403+ if ctx .metadata
404+ else None
405+ ),
352406 is_initialization = is_initialization ,
353407 )
354408 # If the SSE event indicates completion, like returning respose/error
@@ -374,11 +428,17 @@ async def _handle_reconnection(
374428 """Reconnect with Last-Event-ID to resume stream after server disconnect."""
375429 # Bail if max retries exceeded
376430 if attempt >= MAX_RECONNECTION_ATTEMPTS : # pragma: no cover
377- logger .debug (f"Max reconnection attempts ({ MAX_RECONNECTION_ATTEMPTS } ) exceeded" )
431+ logger .debug (
432+ f"Max reconnection attempts ({ MAX_RECONNECTION_ATTEMPTS } ) exceeded"
433+ )
378434 return
379435
380436 # Always wait - use server value or default
381- delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
437+ delay_ms = (
438+ retry_interval_ms
439+ if retry_interval_ms is not None
440+ else DEFAULT_RECONNECTION_DELAY_MS
441+ )
382442 await anyio .sleep (delay_ms / 1000.0 )
383443
384444 headers = self ._prepare_headers ()
@@ -390,7 +450,9 @@ async def _handle_reconnection(
390450 original_request_id = ctx .session_message .message .id
391451
392452 try :
393- async with aconnect_sse (ctx .client , "GET" , self .url , headers = headers ) as event_source :
453+ async with aconnect_sse (
454+ ctx .client , "GET" , self .url , headers = headers
455+ ) as event_source :
394456 event_source .response .raise_for_status ()
395457 logger .info ("Reconnected to SSE stream" )
396458
@@ -408,19 +470,25 @@ async def _handle_reconnection(
408470 sse ,
409471 ctx .read_stream_writer ,
410472 original_request_id ,
411- ctx .metadata .on_resumption_token_update if ctx .metadata else None ,
473+ ctx .metadata .on_resumption_token_update
474+ if ctx .metadata
475+ else None ,
412476 )
413477 if is_complete :
414478 await event_source .response .aclose ()
415479 return
416480
417481 # Stream ended again without response - reconnect again (reset attempt counter)
418482 logger .info ("SSE stream disconnected, reconnecting..." )
419- await self ._handle_reconnection (ctx , reconnect_last_event_id , reconnect_retry_ms , 0 )
483+ await self ._handle_reconnection (
484+ ctx , reconnect_last_event_id , reconnect_retry_ms , 0
485+ )
420486 except Exception as e : # pragma: no cover
421487 logger .debug (f"Reconnection failed: { e } " )
422488 # Try to reconnect again if we still have an event ID
423- await self ._handle_reconnection (ctx , last_event_id , retry_interval_ms , attempt + 1 )
489+ await self ._handle_reconnection (
490+ ctx , last_event_id , retry_interval_ms , attempt + 1
491+ )
424492
425493 async def post_writer (
426494 self ,
@@ -434,7 +502,8 @@ async def post_writer(
434502 """Handle writing requests to the server."""
435503 try :
436504 async with write_stream_reader :
437- async for session_message in write_stream_reader :
505+
506+ async def handle_message (session_message : SessionMessage ) -> None :
438507 message = session_message .message
439508 metadata = (
440509 session_message .metadata
@@ -471,8 +540,14 @@ async def handle_request_async():
471540 else :
472541 await handle_request_async ()
473542
474- except Exception : # pragma: lax no cover
475- logger .exception ("Error in post_writer" )
543+ async for session_message in write_stream_reader :
544+ async with anyio .create_task_group () as tg_local :
545+ session_message .context .run (
546+ tg_local .start_soon , handle_message , session_message
547+ )
548+
549+ except Exception :
550+ logger .exception ("Error in post_writer" ) # pragma: no cover
476551 finally :
477552 await read_stream_writer .aclose ()
478553 await write_stream .aclose ()
@@ -526,8 +601,12 @@ async def streamable_http_client(
526601 Example:
527602 See examples/snippets/clients/ for usage patterns.
528603 """
529- read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
530- write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
604+ read_stream_writer , read_stream = anyio .create_memory_object_stream [
605+ SessionMessage | Exception
606+ ](0 )
607+ write_stream , write_stream_reader = anyio .create_memory_object_stream [
608+ SessionMessage
609+ ](0 )
531610
532611 # Determine if we need to create and manage the client
533612 client_provided = http_client is not None
@@ -549,7 +628,9 @@ async def streamable_http_client(
549628 await stack .enter_async_context (client )
550629
551630 def start_get_stream () -> None :
552- tg .start_soon (transport .handle_get_stream , client , read_stream_writer )
631+ tg .start_soon (
632+ transport .handle_get_stream , client , read_stream_writer
633+ )
553634
554635 tg .start_soon (
555636 transport .post_writer ,
0 commit comments