diff --git a/docs/migration.md b/docs/migration.md index bd70bca99..d13b164aa 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -674,7 +674,7 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` -`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. +`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus a new `protocol_version: str` field, so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 8f7296202..e0f406a20 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -87,7 +87,7 @@ class Connection: """The protocol version negotiated during `initialize`; `None` before initialization. Stateless connections don't require the handshake, so this normally stays `None` there (a client that sends `initialize` anyway still - commits it). Handlers read this as `ServerSession.protocol_version`.""" + commits it). For the per-request value, read `ctx.protocol_version`.""" initialized: anyio.Event """Set when `notifications/initialized` arrives (matches TS `oninitialized`); diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 61003ac9f..b7effb70f 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -31,6 +31,7 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]): session: ServerSession lifespan_context: LifespanContextT + protocol_version: str request_id: RequestId | None = None meta: RequestParamsMeta | None = None request: RequestT | None = None diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index e65b2e68c..8dd9a2fac 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -34,7 +34,7 @@ from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import ServerMessageMetadata +from mcp.shared.message import MessageMetadata, ServerMessageMetadata from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -42,6 +42,7 @@ INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, ErrorData, Implementation, InitializeRequestParams, @@ -79,6 +80,30 @@ def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None: return None +def _resolve_protocol_version( + negotiated: str | None, + meta: RequestParamsMeta | None, + md: MessageMetadata, +) -> str: + """Resolve the protocol version for this inbound message. + + Handshake-committed value wins; else per-request `_meta`, else the + transport hint. Unsupported values fall through so surface validation + never sees them. + """ + if negotiated is not None: + return negotiated + if meta is not None: + v = meta.get(PROTOCOL_VERSION_META_KEY) + if isinstance(v, str) and v in SUPPORTED_PROTOCOL_VERSIONS: + return v + if isinstance(md, ServerMessageMetadata): + hint = md.protocol_version + if hint is not None and hint in SUPPORTED_PROTOCOL_VERSIONS: + return hint + return "2025-11-25" + + def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -218,11 +243,9 @@ async def _on_request( method: str, params: Mapping[str, Any] | None, ) -> dict[str, Any]: - ctx = self._make_context(dctx, _extract_meta(params)) - # Literal, not LATEST_PROTOCOL_VERSION: the fallback covers the initialize - # handshake (which only exists at <=2025) and stateless until the header - # is plumbed; its meaning is fixed regardless of LATEST bumps. - version = self.connection.protocol_version or "2025-11-25" + meta = _extract_meta(params) + version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) + ctx = self._make_context(dctx, meta, version) is_spec_method = method in _methods.SPEC_CLIENT_METHODS async def _inner() -> HandlerResult: @@ -289,9 +312,9 @@ async def _on_notify( method: str, params: Mapping[str, Any] | None, ) -> None: - ctx = self._make_context(dctx, _extract_meta(params)) - # Same fallback as `_on_request`: covers pre-handshake and stateless. - version = self.connection.protocol_version or "2025-11-25" + meta = _extract_meta(params) + version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) + ctx = self._make_context(dctx, meta, version) async def _inner() -> None: if method in _methods.SPEC_CLIENT_NOTIFICATION_METHODS: @@ -349,7 +372,7 @@ def _compose_server_middleware( return call def _make_context( - self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None + self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str ) -> ServerRequestContext[LifespanT, Any]: # TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request # data off the raw `dctx.message_metadata` carrier; replace with the @@ -366,6 +389,7 @@ def _make_context( lifespan_context=self.lifespan_state, request_id=dctx.request_id, meta=meta, + protocol_version=protocol_version, request=request, close_sse_stream=close_sse_stream, close_standalone_sse_stream=close_standalone_sse_stream, diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 2dba81abe..6254a01ee 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -55,10 +55,8 @@ def client_params(self) -> types.InitializeRequestParams | None: def protocol_version(self) -> str | None: """The protocol version negotiated during `initialize`. - `None` before initialization completes. Stateless connections don't - require the handshake, so this is normally `None` there (on streamable - HTTP the per-request version is the `MCP-Protocol-Version` header, - available via `ctx.request.headers`). + `None` before initialization, and normally `None` on stateless + connections. For the per-request value, read `ctx.protocol_version`. """ return self._connection.protocol_version diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 93904d6cc..9103996a5 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -248,11 +248,12 @@ async def close_standalone_stream_callback() -> None: metadata = ServerMessageMetadata( request_context=request, + protocol_version=protocol_version, close_sse_stream=close_stream_callback, close_standalone_sse_stream=close_standalone_stream_callback, ) else: - metadata = ServerMessageMetadata(request_context=request) + metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version) return SessionMessage(message, metadata=metadata) @@ -506,7 +507,10 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) # Process the message after sending the response - metadata = ServerMessageMetadata(request_context=request) + metadata = ServerMessageMetadata( + request_context=request, + protocol_version=request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION), + ) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) @@ -529,7 +533,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if self.is_json_response_enabled: # Process the message - metadata = ServerMessageMetadata(request_context=request) + metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) try: diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 1858eeac3..dba263ad5 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -35,6 +35,9 @@ class ServerMessageMetadata: # transports, None for stdio). Typed as Any because the server layer is # transport-agnostic. request_context: Any = None + # Per-message protocol version observed by the transport (e.g. the + # validated MCP-Protocol-Version header). + protocol_version: str | None = None # Callback to close SSE stream for the current request without terminating close_sse_stream: CloseSSEStreamCallback | None = None # Callback to close the standalone GET SSE stream (for unsolicited notifications) diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index bef44928a..ddd9c67c1 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -21,6 +21,7 @@ async def test_progress_token_zero_first_call(): session=mock_session, meta={"progress_token": 0}, lifespan_context=None, + protocol_version="2025-11-25", ) # Create context with our mocks diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 60d30342c..6ec060d20 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1501,6 +1501,7 @@ async def test_report_progress_passes_related_request_id(): session=mock_session, meta={"progress_token": "tok-1"}, lifespan_context=None, + protocol_version="2025-11-25", ) ctx = Context(request_context=request_context, mcp_server=MagicMock()) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index bc298185c..5d61a676a 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -18,11 +18,12 @@ from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions -from mcp.server.runner import ServerRunner, _extract_meta, otel_middleware +from mcp.server.runner import ServerRunner, _extract_meta, _resolve_protocol_version, otel_middleware from mcp.server.session import ServerSession from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata from mcp.shared.peer import dump_params from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -31,6 +32,7 @@ INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, CallToolRequestParams, ClientCapabilities, ErrorData, @@ -41,6 +43,7 @@ PaginatedRequestParams, ProgressNotificationParams, RequestParams, + RequestParamsMeta, SetLevelRequestParams, Tool, ) @@ -218,6 +221,7 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert isinstance(ctx.session, ServerSession) assert ctx.session is runner.session assert ctx.request_id is not None + assert ctx.protocol_version == LATEST_PROTOCOL_VERSION @pytest.mark.anyio @@ -650,6 +654,65 @@ async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: ] +def test_resolve_protocol_version_handshake_committed_value_wins(): + md = ServerMessageMetadata(protocol_version="2025-03-26") + meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "2025-03-26"} + assert _resolve_protocol_version("2025-06-18", meta, md) == "2025-06-18" + + +def test_resolve_protocol_version_reads_per_request_meta_when_no_handshake(): + md = ServerMessageMetadata(protocol_version="2025-03-26") + meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "2025-06-18"} + assert _resolve_protocol_version(None, meta, md) == "2025-06-18" + + +def test_resolve_protocol_version_skips_unsupported_meta_value(): + md = ServerMessageMetadata(protocol_version="2025-03-26") + meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "1900-01-01"} + assert _resolve_protocol_version(None, meta, md) == "2025-03-26" + + +def test_resolve_protocol_version_skips_non_string_meta_value(): + md = ServerMessageMetadata(protocol_version="2025-03-26") + meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: 42} + assert _resolve_protocol_version(None, meta, md) == "2025-03-26" + + +def test_resolve_protocol_version_reads_transport_hint_when_no_handshake_or_meta(): + md = ServerMessageMetadata(protocol_version="2025-06-18") + assert _resolve_protocol_version(None, None, md) == "2025-06-18" + assert _resolve_protocol_version(None, {}, md) == "2025-06-18" + + +def test_resolve_protocol_version_skips_unsupported_transport_hint(): + """The `initialize` params version reaches the metadata unvalidated; surface validation must never see it.""" + md = ServerMessageMetadata(protocol_version="1900-01-01") + assert _resolve_protocol_version(None, None, md) == "2025-11-25" + + +def test_resolve_protocol_version_terminal_default_with_no_signals(): + assert _resolve_protocol_version(None, None, None) == "2025-11-25" + assert _resolve_protocol_version(None, None, ServerMessageMetadata()) == "2025-11-25" + assert _resolve_protocol_version(None, None, ClientMessageMetadata()) == "2025-11-25" + + +@pytest.mark.anyio +async def test_runner_ctx_protocol_version_is_terminal_default_on_stateless_in_memory(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True) as (client, runner): + await client.send_raw_request("tools/list", None) + ctx = _seen_ctx[0] + assert ctx.protocol_version == "2025-11-25" + assert ctx.session.protocol_version is None + assert runner.connection.protocol_version is None + + +@pytest.mark.anyio +async def test_runner_ctx_protocol_version_tracks_per_request_meta_on_stateless(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True) as (client, _): + await client.send_raw_request("tools/list", {"_meta": {PROTOCOL_VERSION_META_KEY: "2025-06-18"}}) + assert _seen_ctx[0].protocol_version == "2025-06-18" + + def test_extract_meta_returns_none_for_absent_or_malformed(): """Context construction is independent of `_meta` validity; the params validation inside `call_next()` is what surfaces the error.""" diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 02976656e..6aadf6ff8 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -49,6 +49,7 @@ from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( + DEFAULT_NEGOTIATED_VERSION, CallToolRequestParams, CallToolResult, InitializeResult, @@ -81,14 +82,18 @@ # Helper functions -def extract_protocol_version_from_sse(response: httpx.Response) -> str: - """Extract the negotiated protocol version from an SSE initialization response.""" +def first_sse_data(response: httpx.Response) -> dict[str, Any]: + """Return the first SSE `data:` payload of a response, parsed as JSON.""" assert response.headers.get("Content-Type") == "text/event-stream" for line in response.text.splitlines(): if line.startswith("data: "): - init_data = json.loads(line[6:]) - return init_data["result"]["protocolVersion"] - raise ValueError("Could not extract protocol version from SSE response") # pragma: no cover + return json.loads(line.removeprefix("data: ")) + raise ValueError("No data event in SSE response") # pragma: no cover + + +def extract_protocol_version_from_sse(response: httpx.Response) -> str: + """Extract the negotiated protocol version from an SSE initialization response.""" + return first_sse_data(response)["result"]["protocolVersion"] # Simple in-memory event store for testing @@ -1318,13 +1323,14 @@ async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolR "headers": dict(ctx.request.headers), "method": ctx.request.method, "path": ctx.request.url.path, + "protocol_version": ctx.protocol_version, + "session_protocol_version": ctx.session.protocol_version, } return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) -@pytest.fixture -async def context_app() -> AsyncIterator[Starlette]: - """An app whose server echoes request context, served in process.""" +@asynccontextmanager +async def _run_context_app(*, stateless: bool) -> AsyncIterator[Starlette]: server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, @@ -1332,6 +1338,7 @@ async def context_app() -> AsyncIterator[Starlette]: ) session_manager = StreamableHTTPSessionManager( app=server, + stateless=stateless, security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) @@ -1339,6 +1346,51 @@ async def context_app() -> AsyncIterator[Starlette]: yield app +@pytest.fixture +async def context_app() -> AsyncIterator[Starlette]: + """An app whose server echoes request context, served in process.""" + async with _run_context_app(stateless=False) as app: + yield app + + +@pytest.fixture +async def stateless_context_app() -> AsyncIterator[Starlette]: + async with _run_context_app(stateless=True) as app: + yield app + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("header_value", "expected"), + [ + ("2025-06-18", "2025-06-18"), + ("2025-11-25", "2025-11-25"), + (None, DEFAULT_NEGOTIATED_VERSION), + ], +) +async def test_streamablehttp_stateless_ctx_protocol_version_tracks_the_header( + stateless_context_app: Starlette, header_value: str | None, expected: str +) -> None: + """No handshake on stateless: the header (or the spec's 2025-03-26 default) reaches `ctx.protocol_version`.""" + body = JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="tools/call", + params={"name": "echo_context", "arguments": {"request_id": "r"}}, + ) + headers = {"Accept": "application/json, text/event-stream", "Content-Type": "application/json"} + if header_value is not None: + headers[MCP_PROTOCOL_VERSION_HEADER] = header_value + async with make_client(stateless_context_app) as client: + response = await client.post( + f"{BASE_URL}/mcp", json=body.model_dump(by_alias=True, exclude_none=True), headers=headers + ) + assert response.status_code == 200 + echoed = json.loads(first_sse_data(response)["result"]["content"][0]["text"]) + assert echoed["protocol_version"] == expected + assert echoed["session_protocol_version"] is None + + @pytest.mark.anyio async def test_streamablehttp_request_context_propagation(context_app: Starlette) -> None: """Custom HTTP headers on the connection are visible to server handlers via ctx.request."""