diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f025385..8aa9f96a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ All notable changes to `mcp/sdk` will be documented in this file. ----- * Allow overriding the default name pattern for Discovery +* Add `CorsMiddleware`, `DnsRebindingProtectionMiddleware`, and `ProtocolVersionMiddleware` for `StreamableHttpTransport`, composed automatically as the default stack via `StreamableHttpTransport::defaultMiddleware()` +* **[BC BREAK]** `StreamableHttpTransport` constructor: `$corsHeaders` parameter removed; CORS is now configured via `CorsMiddleware`. The `$middleware` parameter is nullable — `null` (or omitted) installs the default stack; `[]` disables all defaults. Default `Access-Control-Allow-Origin` is no longer set (was `*`). 0.5.0 ----- diff --git a/docs/transports.md b/docs/transports.md index a68875d9..3708e350 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -110,8 +110,8 @@ $transport = new StreamableHttpTransport( - **`request`** (required): `ServerRequestInterface` - The incoming PSR-7 HTTP request - **`responseFactory`** (optional): `ResponseFactoryInterface` - PSR-17 factory for creating HTTP responses. Auto-discovered if not provided. - **`streamFactory`** (optional): `StreamFactoryInterface` - PSR-17 factory for creating response body streams. Auto-discovered if not provided. -- **`corsHeaders`** (optional): `array` - Custom CORS headers to override defaults. Merges with secure defaults. Defaults to `[]`. - **`logger`** (optional): `LoggerInterface` - PSR-3 logger for debugging. Defaults to `NullLogger`. +- **`middleware`** (optional): `iterable|null` - PSR-15 middleware chain. `null` (omitted) installs the [default stack](#default-middleware). `[]` disables all defaults — useful when the surrounding application already handles CORS, host validation, etc. ### PSR-17 Auto-Discovery @@ -137,56 +137,109 @@ $psr17Factory = new Psr17Factory(); $transport = new StreamableHttpTransport($request, $psr17Factory, $psr17Factory); ``` -### CORS Configuration +### Default Middleware + +When the `middleware` argument is omitted (or set to `null`), the transport installs a secure default stack: -The transport sets secure CORS defaults that can be customized or disabled: +| Order | Middleware | Purpose | +|-------|------------|---------| +| 1 | `CorsMiddleware` | Applies CORS headers to every response. By default does **not** set `Access-Control-Allow-Origin` (cross-origin requests are blocked). | +| 2 | `DnsRebindingProtectionMiddleware` | Validates `Origin`/`Host` against an allowlist. Defaults to localhost variants only. | +| 3 | `ProtocolVersionMiddleware` | Rejects requests carrying an unsupported `MCP-Protocol-Version` header with `400 Bad Request`. | ```php -// Default CORS headers (backward compatible) -$transport = new StreamableHttpTransport($request, $responseFactory, $streamFactory); +// Zero-config, secure-by-default — local servers get full protection automatically. +$transport = new StreamableHttpTransport($request); +``` -// Restrict to specific origin -$transport = new StreamableHttpTransport( - $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => 'https://myapp.com'] -); +The default stack can be inspected and recomposed via the public factory: + +```php +$middleware = StreamableHttpTransport::defaultMiddleware(); +``` + +### CORS Configuration + +CORS is handled by `CorsMiddleware`. To enable cross-origin browser requests, configure it explicitly and pass it +in place of (or alongside) the defaults: -// Disable CORS for proxy scenarios +```php +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; +use Mcp\Server\Transport\StreamableHttpTransport; + +// Reflect a specific origin $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - ['Access-Control-Allow-Origin' => ''] + middleware: [ + new CorsMiddleware(allowedOrigins: ['https://myapp.com']), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), + ], ); -// Custom headers with logger +// Allow all origins (development only) $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [ - 'Access-Control-Allow-Origin' => 'https://api.example.com', - 'Access-Control-Max-Age' => '86400' + middleware: [ + new CorsMiddleware(allowedOrigins: ['*']), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), ], - $logger ); ``` -Default CORS headers: -- `Access-Control-Allow-Origin: *` -- `Access-Control-Allow-Methods: GET, POST, DELETE, OPTIONS` -- `Access-Control-Allow-Headers: Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept` +When the allowlist is a concrete set of origins (not `['*']`), `CorsMiddleware` automatically adds `Vary: Origin` +so shared caches/CDNs do not serve a response generated for one origin to a request from another. + +Headers already present on a response (e.g. set by inner middleware) are preserved — `CorsMiddleware` only adds +defaults when they are absent. + +> [!IMPORTANT] +> `Access-Control-Allow-Origin: *` is incompatible with credentialed browser requests (those carrying +> `Authorization`, cookies, or client certificates). If your MCP server runs OAuth/Bearer auth and serves +> a browser client, configure `allowedOrigins` with the explicit origin(s) you trust rather than `['*']`. +> The middleware reflects the matching origin verbatim, which is the form browsers accept with credentials. -### PSR-15 Middleware +### DNS Rebinding Protection -`StreamableHttpTransport` can run a PSR-15 middleware chain before it processes the request. Middleware can log, -enforce auth, or short-circuit with a response for any HTTP method. +`DnsRebindingProtectionMiddleware` validates the `Origin` header against an allowlist (falling back to `Host` +when `Origin` is absent). The default allowlist is localhost-only: + +```php +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; + +new DnsRebindingProtectionMiddleware(allowedHosts: ['myapp.local', 'mcp.internal']); +``` + +If the server is fronted by a reverse proxy that already validates `Host`, drop this middleware from the chain +or supply a permissive allowlist. + +### Protocol Version Validation + +`ProtocolVersionMiddleware` rejects requests whose `MCP-Protocol-Version` header is not in the SDK's supported +set with `400 Bad Request`. Requests without the header pass through, since the `initialize` round-trip and some +legacy clients do not send it. + +```php +use Mcp\Schema\Enum\ProtocolVersion; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; + +// Only accept the latest spec version +new ProtocolVersionMiddleware(supportedVersions: [ProtocolVersion::V2025_11_25]); +``` + +### Custom PSR-15 Middleware + +`StreamableHttpTransport` accepts any PSR-15 middleware chain. To extend the defaults, spread them and append +your own middleware — the defaults stay outermost so CORS headers are applied to every response, including +short-circuited ones: ```php use Mcp\Server\Transport\StreamableHttpTransport; use Psr\Http\Message\ResponseFactoryInterface; +use Psr\Http\Message\ResponseInterface; use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Server\MiddlewareInterface; use Psr\Http\Server\RequestHandlerInterface; @@ -197,7 +250,7 @@ final class AuthMiddleware implements MiddlewareInterface { } - public function process(ServerRequestInterface $request, RequestHandlerInterface $handler) + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface { if (!$request->hasHeader('Authorization')) { return $this->responses->createResponse(401); @@ -209,15 +262,40 @@ final class AuthMiddleware implements MiddlewareInterface $transport = new StreamableHttpTransport( $request, - $responseFactory, - $streamFactory, - [], - $logger, - [new AuthMiddleware($responseFactory)], + logger: $logger, + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + new AuthMiddleware($responseFactory), + ], ); ``` -If middleware returns a response, the transport will still ensure CORS headers are present unless you set them yourself. +To selectively drop one default (for example DNS rebinding when running behind a proxy), filter the default list: + +```php +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\StreamableHttpTransport; + +$transport = new StreamableHttpTransport( + $request, + middleware: [ + ...array_filter( + StreamableHttpTransport::defaultMiddleware(), + fn ($m) => !$m instanceof DnsRebindingProtectionMiddleware, + ), + new AuthMiddleware($responseFactory), + ], +); +``` + +Pass `middleware: []` to disable every default and run only your own chain: + +```php +$transport = new StreamableHttpTransport( + $request, + middleware: [new AuthMiddleware($responseFactory)], +); +``` ### Architecture diff --git a/examples/server/oauth-keycloak/server.php b/examples/server/oauth-keycloak/server.php index bdd22b90..fdaae7a0 100644 --- a/examples/server/oauth-keycloak/server.php +++ b/examples/server/oauth-keycloak/server.php @@ -58,7 +58,12 @@ $transport = new StreamableHttpTransport( (new Psr17Factory())->createServerRequestFromGlobals(), logger: logger(), - middleware: [$metadataMiddleware, $authMiddleware, new OAuthRequestMetaMiddleware()], + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + $metadataMiddleware, + $authMiddleware, + new OAuthRequestMetaMiddleware(), + ], ); $response = $server->run($transport); diff --git a/examples/server/oauth-microsoft/server.php b/examples/server/oauth-microsoft/server.php index 419817cc..c4fae598 100644 --- a/examples/server/oauth-microsoft/server.php +++ b/examples/server/oauth-microsoft/server.php @@ -81,7 +81,13 @@ $transport = new StreamableHttpTransport( (new Psr17Factory())->createServerRequestFromGlobals(), logger: logger(), - middleware: [$oauthProxyMiddleware, $metadataMiddleware, $authMiddleware, new OAuthRequestMetaMiddleware()], + middleware: [ + ...StreamableHttpTransport::defaultMiddleware(), + $oauthProxyMiddleware, + $metadataMiddleware, + $authMiddleware, + new OAuthRequestMetaMiddleware(), + ], ); $response = $server->run($transport); diff --git a/src/Server/Transport/Http/JsonRpcErrorResponse.php b/src/Server/Transport/Http/JsonRpcErrorResponse.php new file mode 100644 index 00000000..ad4ebc83 --- /dev/null +++ b/src/Server/Transport/Http/JsonRpcErrorResponse.php @@ -0,0 +1,42 @@ +createResponse($statusCode) + ->withHeader('Content-Type', 'application/json') + ->withBody($streamFactory->createStream($body)); + } +} diff --git a/src/Server/Transport/Http/Middleware/CorsMiddleware.php b/src/Server/Transport/Http/Middleware/CorsMiddleware.php new file mode 100644 index 00000000..b4680031 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/CorsMiddleware.php @@ -0,0 +1,125 @@ + + */ +final class CorsMiddleware implements MiddlewareInterface +{ + private readonly bool $isWildcard; + private readonly bool $varyOnOrigin; + private readonly string $allowedMethodsHeader; + private readonly string $allowedHeadersHeader; + private readonly ?string $exposedHeadersHeader; + + /** + * @param list $allowedOrigins Origins permitted for cross-origin requests. Empty disables `Access-Control-Allow-Origin`. Use `['*']` to allow any origin. + * @param list $allowedMethods Methods advertised via `Access-Control-Allow-Methods` + * @param list $allowedHeaders Headers advertised via `Access-Control-Allow-Headers` + * @param list $exposedHeaders Headers advertised via `Access-Control-Expose-Headers` + */ + public function __construct( + private readonly array $allowedOrigins = [], + array $allowedMethods = ['GET', 'POST', 'DELETE', 'OPTIONS'], + array $allowedHeaders = [ + 'Accept', + 'Authorization', + 'Content-Type', + 'Last-Event-ID', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + StreamableHttpTransport::SESSION_HEADER, + ], + array $exposedHeaders = [StreamableHttpTransport::SESSION_HEADER], + ) { + $this->isWildcard = \in_array('*', $allowedOrigins, true); + $this->varyOnOrigin = [] !== $allowedOrigins && !$this->isWildcard; + $this->allowedMethodsHeader = implode(', ', $allowedMethods); + $this->allowedHeadersHeader = implode(', ', $allowedHeaders); + $this->exposedHeadersHeader = [] === $exposedHeaders ? null : implode(', ', $exposedHeaders); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $response = $handler->handle($request); + + $allowedOrigin = $this->resolveAllowedOrigin($request->getHeaderLine('Origin')); + if (null !== $allowedOrigin && !$response->hasHeader('Access-Control-Allow-Origin')) { + $response = $response->withHeader('Access-Control-Allow-Origin', $allowedOrigin); + } + + if ($this->varyOnOrigin) { + $response = $this->ensureVaryOrigin($response); + } + + if (!$response->hasHeader('Access-Control-Allow-Methods')) { + $response = $response->withHeader('Access-Control-Allow-Methods', $this->allowedMethodsHeader); + } + + if (!$response->hasHeader('Access-Control-Allow-Headers')) { + $response = $response->withHeader('Access-Control-Allow-Headers', $this->allowedHeadersHeader); + } + + if (null !== $this->exposedHeadersHeader && !$response->hasHeader('Access-Control-Expose-Headers')) { + $response = $response->withHeader('Access-Control-Expose-Headers', $this->exposedHeadersHeader); + } + + return $response; + } + + private function resolveAllowedOrigin(string $origin): ?string + { + if ([] === $this->allowedOrigins) { + return null; + } + + if ($this->isWildcard) { + return '*'; + } + + if ('' !== $origin && \in_array($origin, $this->allowedOrigins, true)) { + return $origin; + } + + return null; + } + + private function ensureVaryOrigin(ResponseInterface $response): ResponseInterface + { + $current = $response->getHeaderLine('Vary'); + + if ('' === $current) { + return $response->withHeader('Vary', 'Origin'); + } + + if ('*' === trim($current) || false !== stripos($current, 'origin')) { + return $response; + } + + return $response->withHeader('Vary', $current.', Origin'); + } +} diff --git a/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php new file mode 100644 index 00000000..9e279849 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddleware.php @@ -0,0 +1,108 @@ + + */ +final class DnsRebindingProtectionMiddleware implements MiddlewareInterface +{ + private ResponseFactoryInterface $responseFactory; + private StreamFactoryInterface $streamFactory; + + /** @var list */ + private readonly array $allowedHosts; + + /** + * @param list $allowedHosts Hostnames (without port) that are permitted. Defaults to localhost variants. + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory (auto-discovered if null) + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory (auto-discovered if null) + */ + public function __construct( + array $allowedHosts = ['localhost', '127.0.0.1', '[::1]', '::1'], + ?ResponseFactoryInterface $responseFactory = null, + ?StreamFactoryInterface $streamFactory = null, + ) { + $this->allowedHosts = array_values(array_map('strtolower', $allowedHosts)); + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $origin = $request->getHeaderLine('Origin'); + if ('' !== $origin) { + if (!$this->isAllowedOrigin($origin)) { + return $this->createForbiddenResponse('Forbidden: Invalid Origin header.'); + } + + return $handler->handle($request); + } + + $host = $request->getHeaderLine('Host'); + if ('' !== $host && !$this->isAllowedHost($host)) { + return $this->createForbiddenResponse('Forbidden: Invalid Host header.'); + } + + return $handler->handle($request); + } + + private function isAllowedOrigin(string $origin): bool + { + $parsed = parse_url($origin); + if (false === $parsed || !isset($parsed['host'])) { + return false; + } + + return \in_array(strtolower($parsed['host']), $this->allowedHosts, true); + } + + private function isAllowedHost(string $host): bool + { + if (str_starts_with($host, '[')) { + $closingBracket = strpos($host, ']'); + if (false === $closingBracket) { + return false; + } + $hostname = substr($host, 0, $closingBracket + 1); + } else { + $hostname = explode(':', $host, 2)[0]; + } + + return \in_array(strtolower($hostname), $this->allowedHosts, true); + } + + private function createForbiddenResponse(string $message): ResponseInterface + { + return JsonRpcErrorResponse::create($this->responseFactory, $this->streamFactory, 403, $message); + } +} diff --git a/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php b/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php new file mode 100644 index 00000000..1cb4ecf2 --- /dev/null +++ b/src/Server/Transport/Http/Middleware/ProtocolVersionMiddleware.php @@ -0,0 +1,97 @@ + + */ +final class ProtocolVersionMiddleware implements MiddlewareInterface +{ + private ResponseFactoryInterface $responseFactory; + private StreamFactoryInterface $streamFactory; + + /** @var list */ + private readonly array $supportedVersions; + + /** + * @param list|null $supportedVersions Versions the server accepts. Defaults to all values of {@see ProtocolVersion}. + * @param ResponseFactoryInterface|null $responseFactory PSR-17 response factory (auto-discovered if null) + * @param StreamFactoryInterface|null $streamFactory PSR-17 stream factory (auto-discovered if null) + */ + public function __construct( + ?array $supportedVersions = null, + ?ResponseFactoryInterface $responseFactory = null, + ?StreamFactoryInterface $streamFactory = null, + ) { + $versions = $supportedVersions ?? ProtocolVersion::cases(); + $this->supportedVersions = array_values(array_map(static fn (ProtocolVersion $v): string => $v->value, $versions)); + $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); + $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $headerValue = $request->getHeaderLine(StreamableHttpTransport::PROTOCOL_VERSION_HEADER); + + // Spec backwards-compat: when the header is absent, the server SHOULD assume + // protocol version 2025-03-26 — the release in which Streamable HTTP and the + // header itself were introduced. This is deliberately lower than the SDK's + // own default (V2025_06_18) so clients predating the header convention still + // get a deterministic protocol version applied. Servers that whitelist only + // newer versions in $supportedVersions will reject such requests with 400. + $version = '' === $headerValue ? ProtocolVersion::V2025_03_26->value : $headerValue; + + if (\in_array($version, $this->supportedVersions, true)) { + return $handler->handle($request); + } + + $message = '' === $headerValue + ? \sprintf( + 'Missing %s header; backwards-compat default %s is not accepted. Supported versions: %s.', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + $version, + implode(', ', $this->supportedVersions), + ) + : \sprintf( + 'Unsupported %s header value: %s. Supported versions: %s.', + StreamableHttpTransport::PROTOCOL_VERSION_HEADER, + $headerValue, + implode(', ', $this->supportedVersions), + ); + + return JsonRpcErrorResponse::create($this->responseFactory, $this->streamFactory, 400, $message); + } +} diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index 62e82ae4..dd40af40 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -14,6 +14,9 @@ use Http\Discovery\Psr17FactoryDiscovery; use Mcp\Exception\InvalidArgumentException; use Mcp\Schema\JsonRpc\Error; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; use Mcp\Server\Transport\Http\MiddlewareRequestHandler; use Psr\Http\Message\ResponseFactoryInterface; use Psr\Http\Message\ResponseInterface; @@ -30,16 +33,8 @@ */ class StreamableHttpTransport extends BaseTransport { - private const SESSION_HEADER = 'Mcp-Session-Id'; - - private const ALLOWED_HEADER = [ - 'Accept', - 'Authorization', - 'Content-Type', - 'Last-Event-ID', - 'Mcp-Protocol-Version', - self::SESSION_HEADER, - ]; + public const SESSION_HEADER = 'Mcp-Session-Id'; + public const PROTOCOL_VERSION_HEADER = 'Mcp-Protocol-Version'; private ResponseFactoryInterface $responseFactory; private StreamFactoryInterface $streamFactory; @@ -47,42 +42,39 @@ class StreamableHttpTransport extends BaseTransport private ?string $immediateResponse = null; private ?int $immediateStatusCode = null; - /** @var array */ - private array $corsHeaders; - /** @var list */ - private array $middleware = []; + private array $middleware; /** - * @param array $corsHeaders - * @param iterable $middleware + * @param iterable|null $middleware `null` installs {@see self::defaultMiddleware()}; `[]` disables all middleware */ public function __construct( private ServerRequestInterface $request, ?ResponseFactoryInterface $responseFactory = null, ?StreamFactoryInterface $streamFactory = null, - array $corsHeaders = [], ?LoggerInterface $logger = null, - iterable $middleware = [], + ?iterable $middleware = null, ) { parent::__construct($logger); $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); - $this->corsHeaders = array_merge([ - 'Access-Control-Allow-Origin' => '*', - 'Access-Control-Allow-Methods' => 'GET, POST, DELETE, OPTIONS', - 'Access-Control-Allow-Headers' => implode(',', self::ALLOWED_HEADER), - 'Access-Control-Expose-Headers' => self::SESSION_HEADER, - ], $corsHeaders); + $this->middleware = self::normalizeMiddleware($middleware ?? self::defaultMiddleware()); + } - foreach ($middleware as $m) { - if (!$m instanceof MiddlewareInterface) { - throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); - } - $this->middleware[] = $m; - } + /** + * Secure default middleware stack applied when no `$middleware` is provided to the constructor. + * + * @return list + */ + public static function defaultMiddleware(): array + { + return [ + new CorsMiddleware(), + new DnsRebindingProtectionMiddleware(), + new ProtocolVersionMiddleware(), + ]; } public function send(string $data, array $context): void @@ -98,7 +90,7 @@ public function listen(): ResponseInterface \Closure::fromCallable([$this, 'handleRequest']), ); - return $this->withCorsHeaders($handler->handle($this->request)); + return $handler->handle($this->request); } protected function handleOptionsRequest(): ResponseInterface @@ -274,15 +266,22 @@ protected function createErrorResponse(Error $jsonRpcError, int $statusCode): Re return $response; } - protected function withCorsHeaders(ResponseInterface $response): ResponseInterface + /** + * @param iterable $middleware + * + * @return list + */ + private static function normalizeMiddleware(iterable $middleware): array { - foreach ($this->corsHeaders as $name => $value) { - if (!$response->hasHeader($name)) { - $response = $response->withHeader($name, $value); + $normalized = []; + foreach ($middleware as $m) { + if (!$m instanceof MiddlewareInterface) { + throw new InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); } + $normalized[] = $m; } - return $response; + return $normalized; } private function handleRequest(ServerRequestInterface $request): ResponseInterface diff --git a/tests/Conformance/conformance-baseline.yml b/tests/Conformance/conformance-baseline.yml index 61f9783f..efda80ab 100644 --- a/tests/Conformance/conformance-baseline.yml +++ b/tests/Conformance/conformance-baseline.yml @@ -1,5 +1,4 @@ -server: - - dns-rebinding-protection +server: [] client: - elicitation-sep1034-client-defaults diff --git a/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php new file mode 100644 index 00000000..ba82c73b --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/CorsMiddlewareTest.php @@ -0,0 +1,176 @@ +factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); + $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); + } + + #[TestDox('wildcard allowedOrigins sets Access-Control-Allow-Origin to *')] + public function testWildcardOrigin(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('matching Origin is reflected back')] + public function testMatchingOriginIsReflected(): void + { + $middleware = new CorsMiddleware( + allowedOrigins: ['https://app.example.com', 'https://staging.example.com'], + ); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('https://app.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + } + + #[TestDox('non-matching Origin is not echoed')] + public function testNonMatchingOriginIsBlocked(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + } + + #[TestDox('does not overwrite headers set by inner middleware')] + public function testPreExistingHeadersAreNotOverwritten(): void + { + $inner = $this->handlerReturning(200, [ + 'Access-Control-Allow-Origin' => 'https://override.example.com', + 'Access-Control-Allow-Methods' => 'POST', + ]); + + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('https://override.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('POST', $response->getHeaderLine('Access-Control-Allow-Methods')); + } + + #[TestDox('exposed headers can be omitted')] + public function testEmptyExposedHeadersAreNotSet(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*'], exposedHeaders: []); + $request = $this->factory->createServerRequest('POST', 'https://example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Expose-Headers')); + } + + #[TestDox('adds Vary: Origin when reflecting a specific origin to protect caches')] + public function testVaryOriginIsAddedForReflectedOrigin(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('adds Vary: Origin even when origin is rejected so caches do not poison')] + public function testVaryOriginIsAddedEvenWhenOriginDoesNotMatch(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertSame('Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('does not add Vary when Access-Control-Allow-Origin is wildcard')] + public function testVaryOriginIsNotAddedForWildcard(): void + { + $middleware = new CorsMiddleware(allowedOrigins: ['*']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertFalse($response->hasHeader('Vary')); + } + + #[TestDox('does not add Vary when no allowed origins are configured')] + public function testVaryOriginIsNotAddedWhenAllowedOriginsEmpty(): void + { + $middleware = new CorsMiddleware(); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertFalse($response->hasHeader('Vary')); + } + + #[TestDox('preserves existing Vary value when appending Origin')] + public function testVaryOriginAppendsToExistingVary(): void + { + $inner = $this->handlerReturning(200, ['Vary' => 'Accept-Encoding']); + + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('Accept-Encoding, Origin', $response->getHeaderLine('Vary')); + } + + #[TestDox('does not duplicate Origin in existing Vary header')] + public function testVaryOriginIsNotDuplicated(): void + { + $inner = $this->handlerReturning(200, ['Vary' => 'Accept-Encoding, Origin']); + + $middleware = new CorsMiddleware(allowedOrigins: ['https://app.example.com']); + $request = $this->factory->createServerRequest('POST', 'https://example.com') + ->withHeader('Origin', 'https://app.example.com'); + + $response = $middleware->process($request, $inner); + + $this->assertSame('Accept-Encoding, Origin', $response->getHeaderLine('Vary')); + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php new file mode 100644 index 00000000..a7dfa0c2 --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/DnsRebindingProtectionMiddlewareTest.php @@ -0,0 +1,144 @@ + ['http://localhost:8000']; + yield 'IPv4 loopback' => ['http://127.0.0.1:3000']; + yield 'IPv6 loopback (bracketed)' => ['http://[::1]:8000']; + } + + #[DataProvider('allowedOriginProvider')] + #[TestDox('allows request with localhost Origin variant: $origin')] + public function testAllowsLocalhostOrigin(string $origin): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', $origin); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('rejects non-allowed Origin with 403')] + public function testRejectsForeignOrigin(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(403, $response->getStatusCode()); + $this->assertSame('application/json', $response->getHeaderLine('Content-Type')); + } + + #[TestDox('Origin header takes precedence over Host')] + public function testOriginPrecedenceOverHost(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Origin', 'http://localhost:8000') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('validates Host header when Origin is absent')] + public function testFallbackToHostValidation(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://evil/') + ->withHeader('Host', 'evil.example.com'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('strips port from Host header when validating')] + public function testHostPortIsStripped(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost:8000'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('IPv6 Host with port is parsed correctly')] + public function testIpv6HostWithPort(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', '[::1]:8080'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('custom allowed hosts permit non-localhost names')] + public function testCustomAllowedHosts(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['myapp.local'], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://myapp.local/') + ->withHeader('Origin', 'http://myapp.local:3000'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('host comparison is case-insensitive')] + public function testCaseInsensitive(): void + { + $middleware = new DnsRebindingProtectionMiddleware( + allowedHosts: ['MyApp.Local'], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://myapp.local/') + ->withHeader('Origin', 'http://MYAPP.LOCAL:80'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('request without Origin or Host is allowed')] + public function testNoOriginNoHostPasses(): void + { + $middleware = new DnsRebindingProtectionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/')->withoutHeader('Host'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php b/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php new file mode 100644 index 00000000..3f5ba2fe --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/MiddlewareTestCase.php @@ -0,0 +1,57 @@ +factory = new Psr17Factory(); + $this->passthroughHandler = $this->handlerReturning(200); + } + + /** + * @param array $headers extra headers to set on the response (already-set CORS headers etc.) + */ + protected function handlerReturning(int $status, array $headers = []): RequestHandlerInterface + { + return new class($this->factory, $status, $headers) implements RequestHandlerInterface { + /** @param array $headers */ + public function __construct( + private ResponseFactoryInterface $factory, + private int $status, + private array $headers, + ) { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + $response = $this->factory->createResponse($this->status); + foreach ($this->headers as $name => $value) { + $response = $response->withHeader($name, $value); + } + + return $response; + } + }; + } +} diff --git a/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php b/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php new file mode 100644 index 00000000..ad216d56 --- /dev/null +++ b/tests/Unit/Server/Transport/Http/Middleware/ProtocolVersionMiddlewareTest.php @@ -0,0 +1,104 @@ +factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode()); + } + + #[TestDox('rejects missing header when 2025-03-26 backwards-compat default is not in supportedVersions')] + public function testMissingHeaderRejectedByStrictServer(): void + { + $middleware = new ProtocolVersionMiddleware( + supportedVersions: [ProtocolVersion::V2025_11_25], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + } + + #[TestDox('accepts every version declared in the ProtocolVersion enum')] + public function testAcceptsSupportedVersions(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + + foreach (ProtocolVersion::cases() as $version) { + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, $version->value); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(200, $response->getStatusCode(), 'Expected '.$version->value.' to be accepted.'); + } + } + + #[TestDox('rejects unsupported well-formed version with 400')] + public function testRejectsUnsupportedVersion(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, '1900-01-01'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + $this->assertSame('application/json', $response->getHeaderLine('Content-Type')); + } + + #[TestDox('rejects malformed version with 400')] + public function testRejectsMalformedVersion(): void + { + $middleware = new ProtocolVersionMiddleware(responseFactory: $this->factory, streamFactory: $this->factory); + $request = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, 'not-a-version'); + + $response = $middleware->process($request, $this->passthroughHandler); + + $this->assertSame(400, $response->getStatusCode()); + } + + #[TestDox('accepts only the supportedVersions whitelist when provided')] + public function testRestrictedSupportedVersions(): void + { + $middleware = new ProtocolVersionMiddleware( + supportedVersions: [ProtocolVersion::V2025_11_25], + responseFactory: $this->factory, + streamFactory: $this->factory, + ); + + $accepted = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, ProtocolVersion::V2025_11_25->value); + $rejected = $this->factory->createServerRequest('POST', 'http://localhost/') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, ProtocolVersion::V2024_11_05->value); + + $this->assertSame(200, $middleware->process($accepted, $this->passthroughHandler)->getStatusCode()); + $this->assertSame(400, $middleware->process($rejected, $this->passthroughHandler)->getStatusCode()); + } +} diff --git a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php index 7d9cd484..7d4fdd0d 100644 --- a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php +++ b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php @@ -11,9 +11,12 @@ namespace Mcp\Tests\Unit\Server\Transport; +use Mcp\Exception\InvalidArgumentException; +use Mcp\Server\Transport\Http\Middleware\CorsMiddleware; +use Mcp\Server\Transport\Http\Middleware\DnsRebindingProtectionMiddleware; +use Mcp\Server\Transport\Http\Middleware\ProtocolVersionMiddleware; use Mcp\Server\Transport\StreamableHttpTransport; use Nyholm\Psr7\Factory\Psr17Factory; -use PHPUnit\Framework\Attributes\DataProvider; use PHPUnit\Framework\Attributes\TestDox; use PHPUnit\Framework\TestCase; use Psr\Http\Message\ResponseFactoryInterface; @@ -24,117 +27,204 @@ final class StreamableHttpTransportTest extends TestCase { - public static function corsHeaderProvider(): iterable + private Psr17Factory $factory; + + protected function setUp(): void + { + $this->factory = new Psr17Factory(); + } + + #[TestDox('default middleware is applied when none is passed')] + public function testDefaultMiddlewareIsAppliedWhenOmitted(): void + { + $request = $this->factory + ->createServerRequest('OPTIONS', 'http://localhost/') + ->withHeader('Host', 'localhost'); + + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); + + $response = $transport->listen(); + + // Default CORS middleware exposes Methods/Headers/Expose but no Allow-Origin (secure-by-default). + $this->assertSame(204, $response->getStatusCode()); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertNotSame('', $response->getHeaderLine('Access-Control-Allow-Headers')); + $this->assertNotSame('', $response->getHeaderLine('Access-Control-Expose-Headers')); + } + + #[TestDox('default middleware blocks non-localhost Origin')] + public function testDefaultMiddlewareBlocksRebindingAttempt(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader('Origin', 'http://evil.example.com'); + + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); + + $response = $transport->listen(); + + $this->assertSame(403, $response->getStatusCode()); + } + + #[TestDox('default middleware rejects unsupported MCP-Protocol-Version')] + public function testDefaultMiddlewareRejectsUnsupportedProtocolVersion(): void { - yield 'GET (middleware returns 401)' => ['GET', false, 401]; - yield 'POST (middleware returns 401)' => ['POST', false, 401]; - yield 'DELETE (middleware returns 401)' => ['DELETE', false, 401]; - yield 'OPTIONS (middleware delegates -> transport handles preflight)' => ['OPTIONS', true, 204]; - yield 'GET (middleware delegates -> transport handles preflight)' => ['GET', true, 405]; - yield 'POST (middleware delegates -> transport handles preflight)' => ['POST', true, 202]; - yield 'DELETE (middleware delegates -> transport handles preflight)' => ['DELETE', true, 400]; + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader(StreamableHttpTransport::PROTOCOL_VERSION_HEADER, '1900-01-01'); + + $transport = new StreamableHttpTransport($request, $this->factory, $this->factory); + + $response = $transport->listen(); + + $this->assertSame(400, $response->getStatusCode()); } - #[DataProvider('corsHeaderProvider')] - #[TestDox('CORS headers are always applied')] - public function testCorsHeader(string $method, bool $middlewareDelegatesToTransport, int $expectedStatusCode): void + #[TestDox('explicit empty middleware list disables all defaults')] + public function testEmptyMiddlewareListDisablesDefaults(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest($method, 'https://example.com'); - - $middleware = new class($factory, $expectedStatusCode, $middlewareDelegatesToTransport) implements MiddlewareInterface { - public function __construct( - private ResponseFactoryInterface $responseFactory, - private int $expectedStatusCode, - private bool $middlewareDelegatesToTransport, - ) { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'evil.example.com') + ->withHeader('Origin', 'http://evil.example.com'); + + $transport = new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [], + ); + + $response = $transport->listen(); + + // No CORS, no DNS rebinding check — transport just answers. + $this->assertNotSame(403, $response->getStatusCode()); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertFalse($response->hasHeader('Access-Control-Allow-Methods')); + } + + #[TestDox('custom middleware composes with default stack via spread')] + public function testDefaultsCanBeSpreadAndExtended(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost'); + + $authStub = new class($this->factory) implements MiddlewareInterface { + public function __construct(private ResponseFactoryInterface $factory) + { } public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface { - if ($this->middlewareDelegatesToTransport) { - return $handler->handle($request); - } - - return $this->responseFactory->createResponse($this->expectedStatusCode); + return $this->factory->createResponse(401); } }; $transport = new StreamableHttpTransport( $request, - $factory, - $factory, - [], + $this->factory, + $this->factory, null, - [$middleware], + [ + ...StreamableHttpTransport::defaultMiddleware(), + $authStub, + ], ); $response = $transport->listen(); - $this->assertSame($expectedStatusCode, $response->getStatusCode(), $response->getBody()->getContents()); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Origin')); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); - $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); - $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); - - $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame(401, $response->getStatusCode()); + // CORS middleware is outermost — its headers must still be applied to the 401. $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') - ); - $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); } - #[TestDox('transport replaces existing CORS headers on the response')] - public function testCorsHeadersAreReplacedWhenAlreadyPresent(): void + #[TestDox('defaults can be filtered to drop DNS rebinding for proxy deployments')] + public function testDefaultsCanBeFilteredToDropDnsRebinding(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest('GET', 'https://example.com'); + // Behind a reverse proxy: real Host is api.myapp.com, browser Origin is myapp.com. + // DnsRebindingProtectionMiddleware default (localhost-only) would 403 this — drop it. + $request = $this->factory + ->createServerRequest('POST', 'http://api.myapp.com/') + ->withHeader('Host', 'api.myapp.com') + ->withHeader('Origin', 'https://myapp.com'); - $middleware = new class($factory) implements MiddlewareInterface { - public function __construct(private ResponseFactoryInterface $responses) + $authStub = new class($this->factory) implements MiddlewareInterface { + public function __construct(private ResponseFactoryInterface $factory) { } public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface { - return $this->responses->createResponse(200) - ->withHeader('Access-Control-Allow-Origin', 'https://another.com'); + if ('' === $request->getHeaderLine('Authorization')) { + return $this->factory->createResponse(401); + } + + return $handler->handle($request); } }; $transport = new StreamableHttpTransport( $request, - $factory, - $factory, - [], + $this->factory, + $this->factory, null, - [$middleware], + [ + ...array_filter( + StreamableHttpTransport::defaultMiddleware(), + static fn (MiddlewareInterface $m): bool => !$m instanceof DnsRebindingProtectionMiddleware, + ), + $authStub, + ], ); $response = $transport->listen(); - $this->assertSame(200, $response->getStatusCode()); - - $this->assertSame('https://another.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + // Auth short-circuits with 401 — proves DNS rebinding didn't reject the request first. + $this->assertSame(401, $response->getStatusCode()); + // CORS middleware is still in the chain — Methods header attached to the 401. $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); - $this->assertSame( - 'Accept,Authorization,Content-Type,Last-Event-ID,Mcp-Protocol-Version,Mcp-Session-Id', - $response->getHeaderLine('Access-Control-Allow-Headers') + // ProtocolVersionMiddleware also still in the chain — would have rejected a bad header. + } + + #[TestDox('configured CorsMiddleware reflects matching Origin')] + public function testConfiguredCorsReflectsMatchingOrigin(): void + { + $request = $this->factory + ->createServerRequest('POST', 'http://localhost/') + ->withHeader('Host', 'localhost') + ->withHeader('Origin', 'https://myapp.example.com'); + + $transport = new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [ + new CorsMiddleware(allowedOrigins: ['https://myapp.example.com']), + new DnsRebindingProtectionMiddleware(allowedHosts: ['localhost']), + new ProtocolVersionMiddleware(), + ], ); - $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); + + $response = $transport->listen(); + + $this->assertSame('https://myapp.example.com', $response->getHeaderLine('Access-Control-Allow-Origin')); } #[TestDox('middleware runs before transport handles the request')] public function testMiddlewareRunsBeforeTransportHandlesRequest(): void { - $factory = new Psr17Factory(); - $request = $factory->createServerRequest('OPTIONS', 'https://example.com'); + $request = $this->factory->createServerRequest('OPTIONS', 'http://localhost/') + ->withHeader('Host', 'localhost'); $state = new \stdClass(); $state->called = false; - $middleware = new class($state) implements MiddlewareInterface { + $spy = new class($state) implements MiddlewareInterface { public function __construct(private \stdClass $state) { } @@ -149,11 +239,10 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $transport = new StreamableHttpTransport( $request, - $factory, - $factory, - [], + $this->factory, + $this->factory, null, - [$middleware], + [$spy], ); $response = $transport->listen(); @@ -161,4 +250,20 @@ public function process(ServerRequestInterface $request, RequestHandlerInterface $this->assertTrue($state->called); $this->assertSame(204, $response->getStatusCode()); } + + #[TestDox('non-middleware entries are rejected')] + public function testInvalidMiddlewareEntryThrows(): void + { + $request = $this->factory->createServerRequest('POST', 'http://localhost/'); + + $this->expectException(InvalidArgumentException::class); + + new StreamableHttpTransport( + $request, + $this->factory, + $this->factory, + null, + [new \stdClass()], // @phpstan-ignore-line argument.type + ); + } }