diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..c02b9edf4 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -57,8 +57,8 @@ async def sse_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = anyio.create_memory_object_stream(1) + write_stream, write_stream_reader = anyio.create_memory_object_stream(1) async with anyio.create_task_group() as tg: try: @@ -113,11 +113,23 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): logger.debug(f"Received server message: {message}") except Exception as exc: # pragma: no cover logger.exception("Error parsing server message") # pragma: no cover - await read_stream_writer.send(exc) # pragma: no cover + try: # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + except ( # pragma: no cover + anyio.ClosedResourceError, + anyio.BrokenResourceError, + ): + return # pragma: no cover continue # pragma: no cover session_message = SessionMessage(message) - await read_stream_writer.send(session_message) + try: + await read_stream_writer.send(session_message) + except ( + anyio.ClosedResourceError, + anyio.BrokenResourceError, + ): # pragma: no cover + return # pragma: no cover case _: # pragma: no cover logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover except SSEError as sse_exc: # pragma: lax no cover @@ -125,7 +137,10 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): raise sse_exc except Exception as exc: # pragma: lax no cover logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) + try: + await read_stream_writer.send(exc) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + pass finally: await read_stream_writer.aclose() @@ -156,6 +171,8 @@ async def post_writer(endpoint_url: str): try: yield read_stream, write_stream finally: + await read_stream_writer.aclose() + await write_stream.aclose() tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..26882cc3f 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -155,7 +155,11 @@ async def _handle_sse_event( message.id = original_request_id session_message = SessionMessage(message) - await read_stream_writer.send(session_message) + try: + await read_stream_writer.send(session_message) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + logger.debug("Read stream closed, stopping SSE event handling") + return True # Call resumption token callback if we have an ID if sse.id and resumption_callback: @@ -170,9 +174,15 @@ async def _handle_sse_event( if original_request_id is not None: error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}") error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data)) - await read_stream_writer.send(error_msg) + try: + await read_stream_writer.send(error_msg) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + pass return True - await read_stream_writer.send(exc) + try: + await read_stream_writer.send(exc) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + pass return False else: # pragma: no cover logger.warning(f"Unknown SSE event: {sse.event}") @@ -271,14 +281,20 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: if isinstance(message, JSONRPCRequest): # pragma: no branch error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) - await ctx.read_stream_writer.send(session_message) + try: + await ctx.read_stream_writer.send(session_message) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + pass return if response.status_code >= 400: if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) - await ctx.read_stream_writer.send(session_message) + try: + await ctx.read_stream_writer.send(session_message) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + pass return if is_initialization: @@ -298,7 +314,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.error(f"Unexpected content type: {content_type}") error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}") error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) - await ctx.read_stream_writer.send(error_msg) + try: + await ctx.read_stream_writer.send(error_msg) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + pass async def _handle_json_response( self, @@ -318,12 +337,18 @@ async def _handle_json_response( self._maybe_extract_protocol_version_from_message(message) session_message = SessionMessage(message) - await read_stream_writer.send(session_message) + try: + await read_stream_writer.send(session_message) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + return except (httpx.StreamError, ValidationError) as exc: logger.exception("Error parsing JSON response") error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}") error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data)) - await read_stream_writer.send(error_msg) + try: + await read_stream_writer.send(error_msg) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover + return async def _handle_sse_response( self, @@ -533,8 +558,8 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](1) # Determine if we need to create and manage the client client_provided = http_client is not None @@ -573,6 +598,10 @@ def start_get_stream() -> None: finally: if transport.session_id and terminate_on_close: await transport.terminate_session(client) + # Close streams before cancelling to unblock tasks + # waiting on stream send/receive during shutdown. + await read_stream_writer.aclose() + await write_stream.aclose() tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..22094fce6 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -224,9 +224,10 @@ async def __aexit__( exc_tb: TracebackType | None, ) -> bool | None: await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. + # Close streams first so _receive_loop exits cooperatively, + # then cancel the task group as a fallback. + await self._read_stream.aclose() + await self._write_stream.aclose() self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..a5c58a1d1 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -416,3 +416,45 @@ async def make_request(client_session: ClientSession): # Pending request completed successfully assert len(result_holder) == 1 assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_session_exit_closes_streams_before_cancel(): + """Verify BaseSession.__aexit__ closes streams before cancelling task group. + + The receive loop should exit via ClosedResourceError on the read stream, + not via forced task group cancellation. This prevents AnyIO cancellation + busy-loops when tasks are blocked on stream operations. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _server_write = server_streams + + async def slow_server(): + """Read a request but never respond, keeping the session busy.""" + try: + await server_read.receive() + # Hold the connection open + await anyio.sleep(60) + except (anyio.ClosedResourceError, anyio.get_cancelled_exc_class()): + pass + + async with anyio.create_task_group() as outer_tg: + outer_tg.start_soon(slow_server) + + with anyio.fail_after(5): # pragma: no branch + async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session: + # Fire a request in a background task (will never get a response) + async with anyio.create_task_group() as inner_tg: # pragma: no branch + + async def send_and_ignore(): + try: + await client_session.send_ping() + except (MCPError, anyio.get_cancelled_exc_class()): + pass + + inner_tg.start_soon(send_and_ignore) + await anyio.sleep(0.1) + inner_tg.cancel_scope.cancel() + + outer_tg.cancel_scope.cancel() # pragma: lax no cover diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698..95dc37075 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -2247,3 +2247,52 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_streamable_http_client_exit_with_pending_requests(basic_server: None, basic_server_url: str): + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1805. + + Sends tool calls to a server-side handler that blocks indefinitely (lock + never released), then exits the client context while responses are still + pending. Verifies that shutdown completes within the timeout and does not + hang or busy-loop in AnyIO cancellation delivery. + """ + with anyio.fail_after(10): # pragma: no branch + async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call_blocked_tool(): + try: + await session.call_tool("wait_for_lock_with_notification", {}) + except (MCPError, anyio.get_cancelled_exc_class()): + pass + + # Fire off multiple requests that will block server-side + for _ in range(3): + tg.start_soon(call_blocked_tool) + + # Give the server a moment to receive them, then bail out + await anyio.sleep(0.2) + tg.cancel_scope.cancel() + + # If we reach here, shutdown completed without hanging. + await anyio.sleep(0.1) + + +@pytest.mark.anyio +async def test_streamable_http_client_rapid_connect_disconnect(basic_server: None, basic_server_url: str): + """Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1805. + + Rapidly connects, initializes, and disconnects multiple times. Verifies no + resource leak or cancellation busy-loop across iterations. + """ + for _ in range(5): # pragma: no branch + with anyio.fail_after(10): # pragma: no branch + async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + await anyio.sleep(0.1)