Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion packages/core/src/shared/transport.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/index.js';
import type { ClientCapabilities, Implementation, JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/index.js';

export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;

Expand Down Expand Up @@ -116,6 +116,20 @@ export interface Transport {
*/
onmessage?: (<T extends JSONRPCMessage>(message: T, extra?: MessageExtraInfo) => void) | undefined;

/**
* Callback invoked when session initialization state is restored by the transport.
*
* For transports that support stateless session replay (e.g.,
* {@linkcode @modelcontextprotocol/server!server/streamableHttp.WebStandardStreamableHTTPServerTransport | WebStandardStreamableHTTPServerTransport}),
* this is called when a non-initialize request triggers session restoration
* via the transport's `replayInitialization` callback.
*
* The {@linkcode @modelcontextprotocol/server!server/server.Server | Server} hooks this during
* {@linkcode @modelcontextprotocol/server!server/server.Server.connect | connect()} to
* seed client capabilities and version info.
*/
oninitializationreplay?: ((data: { clientCapabilities: ClientCapabilities; clientVersion: Implementation }) => void) | undefined;

/**
* The session ID generated for this connection.
*/
Expand Down
3 changes: 2 additions & 1 deletion packages/core/src/util/inMemory.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js';
import type { Transport } from '../shared/transport.js';
import type { AuthInfo, JSONRPCMessage, RequestId } from '../types/index.js';
import type { AuthInfo, ClientCapabilities, Implementation, JSONRPCMessage, RequestId } from '../types/index.js';

interface QueuedMessage {
message: JSONRPCMessage;
Expand All @@ -18,6 +18,7 @@ export class InMemoryTransport implements Transport {
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void;
oninitializationreplay?: (data: { clientCapabilities: ClientCapabilities; clientVersion: Implementation }) => void;
sessionId?: string;

/**
Expand Down
23 changes: 22 additions & 1 deletion packages/middleware/node/src/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
import type { IncomingMessage, ServerResponse } from 'node:http';

import { getRequestListener } from '@hono/node-server';
import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core';
import type {
AuthInfo,
ClientCapabilities,
Implementation,
JSONRPCMessage,
MessageExtraInfo,
RequestId,
Transport
} from '@modelcontextprotocol/core';
import type { WebStandardStreamableHTTPServerTransportOptions } from '@modelcontextprotocol/server';
import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server';

Expand Down Expand Up @@ -130,6 +138,19 @@ export class NodeStreamableHTTPServerTransport implements Transport {
return this._webStandardTransport.onmessage;
}

/**
* Sets callback for session initialization replay.
*/
set oninitializationreplay(
handler: ((data: { clientCapabilities: ClientCapabilities; clientVersion: Implementation }) => void) | undefined
) {
this._webStandardTransport.oninitializationreplay = handler;
}

get oninitializationreplay(): ((data: { clientCapabilities: ClientCapabilities; clientVersion: Implementation }) => void) | undefined {
return this._webStandardTransport.oninitializationreplay;
}

/**
* Starts the transport. This is required by the {@linkcode Transport} interface but is a no-op
* for the Streamable HTTP transport as connections are managed per-request.
Expand Down
18 changes: 17 additions & 1 deletion packages/server/src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import type {
ServerResult,
TaskManagerOptions,
ToolResultContent,
ToolUseContent
ToolUseContent,
Transport
} from '@modelcontextprotocol/core';
import {
assertClientRequestTaskCapability,
Expand Down Expand Up @@ -140,6 +141,21 @@ export class Server extends Protocol<ServerContext> {
}
}

/**
* Attaches to the given transport, hooking the `oninitializationreplay` callback
* to seed client capabilities and version info from replayed sessions.
*/
override async connect(transport: Transport): Promise<void> {
const _oninitializationreplay = transport.oninitializationreplay;
transport.oninitializationreplay = data => {
_oninitializationreplay?.(data);
this._clientCapabilities ??= data.clientCapabilities;
this._clientVersion ??= data.clientVersion;
};

await super.connect(transport);
}

private _registerLoggingHandler(): void {
this.setRequestHandler('logging/setLevel', async (request, ctx) => {
const transportSessionId: string | undefined =
Expand Down
86 changes: 82 additions & 4 deletions packages/server/src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@
* For Node.js Express/HTTP compatibility, use {@linkcode @modelcontextprotocol/node!NodeStreamableHTTPServerTransport | NodeStreamableHTTPServerTransport} which wraps this transport.
*/

import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core';
import type {
AuthInfo,
ClientCapabilities,
Implementation,
JSONRPCMessage,
MessageExtraInfo,
RequestId,
Transport
} from '@modelcontextprotocol/core';
import {
DEFAULT_NEGOTIATED_PROTOCOL_VERSION,
isInitializeRequest,
Expand Down Expand Up @@ -152,6 +160,30 @@ export interface WebStandardStreamableHTTPServerTransportOptions {
* @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS}
*/
supportedProtocolVersions?: string[];

/**
* Callback to restore session state for stateless deployments.
*
* Called when a non-initialize request arrives and the transport is not yet
* initialized. The transport reads the session ID from the `mcp-session-id`
* request header and passes it to this callback.
*
* Return cached client capabilities and version info to adopt the session,
* or `undefined` to reject the request.
*
* This callback is never called for `initialize` requests — those follow
* the normal handshake path. Note that {@linkcode WebStandardStreamableHTTPServerTransportOptions.onsessioninitialized | onsessioninitialized}
* is NOT called during replay — it only fires for new session creation.
*
* @param sessionId - The session ID from the `mcp-session-id` request header.
* @returns Cached client state to restore, or `undefined` if the session is unknown.
*/
replayInitialization?: (
sessionId: string
) =>
| { clientCapabilities: ClientCapabilities; clientVersion: Implementation }
| undefined
| Promise<{ clientCapabilities: ClientCapabilities; clientVersion: Implementation } | undefined>;
}

/**
Expand Down Expand Up @@ -240,11 +272,14 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
private _enableDnsRebindingProtection: boolean;
private _retryInterval?: number;
private _supportedProtocolVersions: string[];
private _replayInitialization?: WebStandardStreamableHTTPServerTransportOptions['replayInitialization'];
private _replayInProgress = false;

sessionId?: string;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
oninitializationreplay?: (data: { clientCapabilities: ClientCapabilities; clientVersion: Implementation }) => void;

constructor(options: WebStandardStreamableHTTPServerTransportOptions = {}) {
this.sessionIdGenerator = options.sessionIdGenerator;
Expand All @@ -257,6 +292,7 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false;
this._retryInterval = options.retryInterval;
this._supportedProtocolVersions = options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS;
this._replayInitialization = options.replayInitialization;
}

/**
Expand Down Expand Up @@ -351,6 +387,17 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
return validationError;
}

// Attempt stateless session replay before dispatching to method handlers.
let replayError: Response | undefined;
try {
replayError = await this._tryReplayInitialization(req);
} catch {
return this.createJsonErrorResponse(500, -32_603, 'Internal error: session replay failed');
}
if (replayError) {
return replayError;
}

switch (req.method) {
case 'POST': {
return this.handlePostRequest(req, options);
Expand Down Expand Up @@ -840,14 +887,45 @@ export class WebStandardStreamableHTTPServerTransport implements Transport {
return new Response(null, { status: 200 });
}

/**
* Attempts to restore session state via the `replayInitialization` callback.
* Called once in `handleRequest()` before method dispatch.
*
* No-op when already initialized, no callback provided, or no session ID header.
* On success, sets `sessionId` and `_initialized`, then invokes `oninitializationreplay`
* so the server can seed client capabilities and version info.
*/
private async _tryReplayInitialization(req: Request): Promise<Response | undefined> {
if (this._initialized || this._replayInProgress || !this._replayInitialization) return undefined;

const sessionId = req.headers.get('mcp-session-id');
if (!sessionId) return undefined;

this._replayInProgress = true;
try {
const result = await this._replayInitialization(sessionId);
if (!result) {
// Session unknown/expired — 404 tells the client to re-initialize per spec
this.onerror?.(new Error('Session not found'));
return this.createJsonErrorResponse(404, -32_001, 'Session not found');
}

this.sessionId = sessionId;
this._initialized = true;
this.oninitializationreplay?.(result);
return undefined;
} finally {
this._replayInProgress = false;
}
}

/**
* Validates session ID for non-initialization requests.
* Returns `Response` error if invalid, `undefined` otherwise
*/
private validateSession(req: Request): Response | undefined {
if (this.sessionIdGenerator === undefined) {
// If the sessionIdGenerator ID is not set, the session management is disabled
// and we don't need to validate the session ID
if (this.sessionIdGenerator === undefined && this.sessionId === undefined && !this._replayInitialization) {
// Session management is fully disabled (no generator, no adopted/replayed session, no replay callback)
return undefined;
}
if (!this._initialized) {
Expand Down
99 changes: 98 additions & 1 deletion packages/server/test/server/server.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { JSONRPCMessage } from '@modelcontextprotocol/core';
import type { ClientCapabilities, Implementation, JSONRPCMessage } from '@modelcontextprotocol/core';
import { InMemoryTransport, LATEST_PROTOCOL_VERSION } from '@modelcontextprotocol/core';
import { Server } from '../../src/server/server.js';

Expand Down Expand Up @@ -39,4 +39,101 @@ describe('Server', () => {
await server.close();
});
});

describe('connect — oninitializationreplay hook', () => {
const testCapabilities: ClientCapabilities = { sampling: {} };
const testVersion: Implementation = { name: 'test-client', version: '2.0.0' };

it('should seed getClientCapabilities() when transport.oninitializationreplay is called', async () => {
const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} });

const [, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);

// Simulate transport calling oninitializationreplay (as _tryReplayInitialization would)
serverTransport.oninitializationreplay?.({
clientCapabilities: testCapabilities,
clientVersion: testVersion
});

expect(server.getClientCapabilities()).toEqual(testCapabilities);
expect(server.getClientVersion()).toEqual(testVersion);

await server.close();
});

it('should return undefined for getClientCapabilities() when oninitializationreplay is not called', async () => {
const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} });

const [, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);

expect(server.getClientCapabilities()).toBeUndefined();
expect(server.getClientVersion()).toBeUndefined();

await server.close();
});

it('should be overwritten by a real initialize handshake', async () => {
const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} });

const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);

// First: seed via oninitializationreplay
serverTransport.oninitializationreplay?.({
clientCapabilities: testCapabilities,
clientVersion: testVersion
});

expect(server.getClientCapabilities()).toEqual(testCapabilities);

// Then: real initialize overwrites
const responsePromise = new Promise<JSONRPCMessage>(resolve => {
clientTransport.onmessage = msg => resolve(msg);
});
await clientTransport.start();

const realCapabilities: ClientCapabilities = { elicitation: { form: {} } };
const realVersion: Implementation = { name: 'real-client', version: '3.0.0' };

await clientTransport.send({
jsonrpc: '2.0',
id: 1,
method: 'initialize',
params: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: realCapabilities,
clientInfo: realVersion
}
} as JSONRPCMessage);

await responsePromise;

expect(server.getClientCapabilities()).toEqual(realCapabilities);
expect(server.getClientVersion()).toEqual(realVersion);

await server.close();
});

it('should chain with an existing transport.oninitializationreplay callback', async () => {
const server = new Server({ name: 'test', version: '1.0.0' }, { capabilities: {} });

const [, serverTransport] = InMemoryTransport.createLinkedPair();

const existingCallback = vi.fn();
serverTransport.oninitializationreplay = existingCallback;

await server.connect(serverTransport);

const data = { clientCapabilities: testCapabilities, clientVersion: testVersion };
serverTransport.oninitializationreplay?.(data);

// Both the existing callback and the server's hook should have fired
expect(existingCallback).toHaveBeenCalledWith(data);
expect(server.getClientCapabilities()).toEqual(testCapabilities);

await server.close();
});
});
});
Loading
Loading