diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 5c1459dff..7fc732f43 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -23,10 +23,37 @@ async def run_server(): import anyio import anyio.lowlevel +import pydantic_core from mcp import types from mcp.shared._context_streams import create_context_streams -from mcp.shared.message import SessionMessage +from mcp.shared.message import SessionMessage, extract_raw_request_id + + +def _error_response_for_invalid_line(line: str) -> SessionMessage: + """Build the JSON-RPC error response for a stdin line that failed message validation. + + Correlates the error with the originating request where possible: for lines that + are valid JSON but an invalid JSON-RPC envelope, the request id is extracted + best-effort from the raw payload (Invalid Request, -32600); for lines that are + not valid JSON, a null id is used (Parse error, -32700), per the JSON-RPC 2.0 + specification. + + Args: + line: The raw stdin line that failed to validate as a JSON-RPC message. + + Returns: + A `SessionMessage` wrapping the `JSONRPCError` to write back to the client. + """ + try: + raw_message = pydantic_core.from_json(line) + except ValueError: + request_id = None + error = types.ErrorData(code=types.PARSE_ERROR, message="Parse error") + else: + request_id = extract_raw_request_id(raw_message) + error = types.ErrorData(code=types.INVALID_REQUEST, message="Invalid Request") + return SessionMessage(types.JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) @asynccontextmanager @@ -53,6 +80,13 @@ async def stdin_reader(): try: message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: + try: + await write_stream.send(_error_response_for_invalid_line(line)) + except anyio.ClosedResourceError: + # The server side already closed the write stream; the + # error response cannot be delivered, but the exception + # below still surfaces the bad line in-stream. + await anyio.lowlevel.checkpoint() await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 220d46f9a..4b385ff33 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -27,12 +27,11 @@ from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage, extract_raw_request_id from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS, is_version_at_least from mcp.types import ( DEFAULT_NEGOTIATED_VERSION, INTERNAL_ERROR, - INVALID_PARAMS, INVALID_REQUEST, PARSE_ERROR, ErrorData, @@ -288,8 +287,14 @@ def _create_error_response( status_code: HTTPStatus, error_code: int = INVALID_REQUEST, headers: dict[str, str] | None = None, + request_id: RequestId | None = None, ) -> Response: - """Create an error response with a simple string message.""" + """Create an error response with a simple string message. + + ``request_id`` correlates the error with the originating request when it + could be extracted from the (possibly invalid) request body; it defaults + to ``None`` (a null id) per the JSON-RPC 2.0 specification. + """ response_headers = {"Content-Type": CONTENT_TYPE_JSON} if headers: response_headers.update(headers) @@ -300,7 +305,7 @@ def _create_error_response( # Return a properly formatted JSON error response error_response = JSONRPCError( jsonrpc="2.0", - id=None, + id=request_id, error=ErrorData(code=error_code, message=error_message), ) @@ -468,10 +473,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) except ValidationError as e: + # Correlate the error with the originating request: even though the + # envelope is invalid, the id is often still extractable from the raw + # payload (falls back to a null id per the JSON-RPC 2.0 spec). response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, - INVALID_PARAMS, + INVALID_REQUEST, + request_id=extract_raw_request_id(raw_message), ) await response(scope, receive, send) return diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 1858eeac3..8ab0ba5ec 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any +from typing import Any, cast from mcp.types import JSONRPCMessage, RequestId @@ -14,6 +14,29 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] + +def extract_raw_request_id(raw_message: Any) -> RequestId | None: + """Best-effort extraction of a JSON-RPC request id from an unvalidated payload. + + Used to correlate error responses with the originating request when an incoming + message fails JSON-RPC envelope validation: even though the envelope is invalid, + the ``id`` member is often still present in the raw parsed JSON. + + Args: + raw_message: The parsed JSON payload, before any envelope validation. + + Returns: + The request id when it is a valid JSON-RPC id type (a string, or an integer + that is not a bool — ``bool`` subclasses ``int`` but is not a valid id), + otherwise ``None``. + """ + if isinstance(raw_message, dict): + raw_id = cast("dict[Any, Any]", raw_message).get("id") + if isinstance(raw_id, str) or (isinstance(raw_id, int) and not isinstance(raw_id, bool)): + return raw_id + return None + + # Callback type for closing SSE streams without terminating CloseSSEStreamCallback = Callable[[], Awaitable[None]] diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 85e64ded4..c2cc9dae8 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -15,7 +15,7 @@ from mcp.server import Server, ServerRequestContext from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( - INVALID_PARAMS, + INVALID_REQUEST, PARSE_ERROR, CallToolRequestParams, CallToolResult, @@ -129,7 +129,7 @@ async def test_non_json_content_type_is_rejected() -> None: @requirement("hosting:http:parse-error-400") @requirement("hosting:http:batch") async def test_malformed_and_batched_bodies_return_400() -> None: - """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid Request.""" async with mounted_app(_server()) as (http, _): session_id = await initialize_via_http(http) not_json = await http.post( @@ -149,7 +149,7 @@ async def test_malformed_and_batched_bodies_return_400() -> None: assert not_json.status_code == 400 assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR assert batched.status_code == 400 - assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_REQUEST @requirement("hosting:http:protocol-version-400") diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 054a157b3..b541f7037 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -11,7 +11,15 @@ from mcp.server.mcpserver import MCPServer from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter +from mcp.types import ( + INVALID_REQUEST, + PARSE_ERROR, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) @pytest.mark.anyio @@ -96,6 +104,46 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> Non assert second.message == valid +@pytest.mark.anyio +async def test_stdio_server_replies_to_invalid_messages_with_correlated_errors() -> None: + """Invalid stdin lines are answered with a JSON-RPC error carrying the original id. + + Lines that are valid JSON but invalid JSON-RPC envelopes get an Invalid Request + error with the id extracted best-effort from the raw payload; lines that are not + valid JSON get a Parse error with a null id, per the JSON-RPC 2.0 specification. + The exception is still surfaced in-stream for each bad line. + """ + invalid_lines = [ + '{"jsonrpc": "1.0", "id": 3, "method": "ping", "params": {}}', + '{"id": 4, "method": "ping", "params": {}}', + '{"jsonrpc": "2.0", "id": 8, "method": 12345, "params": {}}', + "this is not valid json", + ] + stdin = io.StringIO("".join(line + "\n" for line in invalid_lines)) + stdout = io.StringIO() + + with anyio.fail_after(5): + async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): + async with read_stream: + for _ in invalid_lines: + received = await read_stream.receive() + assert isinstance(received, Exception) + await write_stream.aclose() + + stdout.seek(0) + error_responses = [JSONRPCError.model_validate_json(line.strip()) for line in stdout.readlines()] + assert [error_response.id for error_response in error_responses] == [3, 4, 8, None] + assert [error_response.error.code for error_response in error_responses] == [ + INVALID_REQUEST, + INVALID_REQUEST, + INVALID_REQUEST, + PARSE_ERROR, + ] + + class _KeepOpenBytesIO(io.BytesIO): """A BytesIO that survives its TextIOWrapper being closed. diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7db7e68fb..6d0368f31 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -499,6 +499,40 @@ async def test_json_parsing(basic_app: Starlette) -> None: assert "Validation error" in response.text +@pytest.mark.anyio +@pytest.mark.parametrize( + ("body", "expected_id"), + [ + pytest.param({"jsonrpc": "1.0", "id": 3, "method": "ping", "params": {}}, 3, id="wrong-jsonrpc-version"), + pytest.param({"id": 4, "method": "ping", "params": {}}, 4, id="missing-jsonrpc-field"), + pytest.param({"jsonrpc": "2.0", "id": 8, "method": 12345, "params": {}}, 8, id="method-not-a-string"), + pytest.param({"jsonrpc": "2.0", "id": 2.5, "method": 12345, "params": {}}, None, id="id-not-a-valid-type"), + ], +) +async def test_validation_error_preserves_request_id( + basic_app: Starlette, body: dict[str, Any], expected_id: int | None +) -> None: + """An envelope-invalid message is answered with an error carrying the original request id. + + The id is extracted best-effort from the raw payload so the client can correlate the + error response with its request; when no valid id can be extracted, the error falls + back to a null id per the JSON-RPC 2.0 specification. + """ + async with make_client(basic_app) as client: + response = await client.post( + "/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=body, + ) + assert response.status_code == 400 + error = types.JSONRPCError.model_validate_json(response.text) + assert error.id == expected_id + assert error.error.code == types.INVALID_REQUEST + + @pytest.mark.anyio async def test_method_not_allowed(basic_app: Starlette) -> None: """Unsupported HTTP methods are rejected with 405."""