diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index 07d86f40e..57beb9bf3 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -177,6 +177,8 @@ private static class DefaultInitialization implements Initialization { */ private final AtomicReference mcpClientSession; + private final AtomicReference initializeRequestId = new AtomicReference<>(); + private DefaultInitialization() { this.initSink = Sinks.one(); this.result = new AtomicReference<>(); @@ -235,6 +237,16 @@ public boolean isInitialized() { return this.currentInitializationResult() != null; } + /** + * Returns the request ID of the initialize request, if one has been issued. Used to + * prevent cancellation of the initialize request per spec. + * @return the initialize request ID, or null + */ + public Object getInitializeRequestId() { + DefaultInitialization current = this.initializationRef.get(); + return current != null ? current.initializeRequestId.get() : null; + } + public McpSchema.InitializeResult currentInitializationResult() { DefaultInitialization current = this.initializationRef.get(); McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; @@ -305,8 +317,11 @@ private Mono doInitialize(DefaultInitialization init McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, this.clientCapabilities, this.clientInfo); - Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF); + McpClientSession.RequestMono requestMono = mcpClientSession.sendRequestWithId( + McpSchema.METHOD_INITIALIZE, initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF); + initialization.initializeRequestId.set(requestMono.requestId()); + + Mono result = requestMono.response(); return result.flatMap(initializeResult -> { logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 93fcc332a..2fcd93d52 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -13,6 +13,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import io.modelcontextprotocol.client.LifecycleInitializer.Initialization; @@ -22,6 +23,7 @@ import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpRequestHandle; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; @@ -449,6 +451,61 @@ public Mono ping() { init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } + // -------------------------- + // Cancellation + // -------------------------- + + /** + * Cancels a previously issued request by its ID. Sends a + * {@code notifications/cancelled} notification to the server and errors the pending + * response locally. + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason for the cancellation + * @return A Mono that completes when the cancellation notification is sent + */ + public Mono cancelRequest(Object requestId, String reason) { + if (!this.isInitialized()) { + return Mono.error(new IllegalStateException("Cannot cancel request before initialization")); + } + Object initId = this.initializer.getInitializeRequestId(); + if (initId != null && initId.equals(requestId)) { + return Mono.error(new IllegalArgumentException("The initialize request MUST NOT be cancelled")); + } + return this.initializer.withInitialization("cancelling request", + init -> init.mcpSession().sendCancellation(requestId, reason)); + } + + /** + * Calls a tool and returns a handle that can be used to cancel the request. + * @param callToolRequest The request containing the tool name and input parameters + * @return A McpRequestHandle containing the request ID, response Mono, and a cancel + * function + */ + public McpRequestHandle callToolWithHandle(McpSchema.CallToolRequest callToolRequest) { + AtomicReference requestIdRef = new AtomicReference<>(); + + Mono responseMono = this.initializer.withInitialization("calling tool with handle", + init -> { + if (init.initializeResult().capabilities().tools() == null) { + return Mono.error(new IllegalStateException("Server does not provide tools capability")); + } + McpClientSession.RequestMono rm = init.mcpSession() + .sendRequestWithId(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF); + requestIdRef.set(rm.requestId()); + return rm.response() + .flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result))); + }); + + return McpRequestHandle.lazy(requestIdRef, responseMono, reason -> { + String id = requestIdRef.get(); + if (id == null) { + return Mono.error(new IllegalStateException( + "Cannot cancel: request has not been issued yet. Subscribe to the response Mono first.")); + } + return this.cancelRequest(id, reason); + }); + } + // -------------------------- // Roots // -------------------------- diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 7fdaa8941..c658797dd 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -11,6 +11,7 @@ import org.slf4j.LoggerFactory; import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.spec.McpRequestHandle; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; @@ -218,6 +219,20 @@ public Object ping() { return withProvidedContext(this.delegate.ping()).block(); } + // -------------------------- + // Cancellation + // -------------------------- + + /** + * Cancels a previously issued request. IMPORTANT: This method MUST be called from a + * different thread than the one blocked waiting on the request result. + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason for the cancellation + */ + public void cancelRequest(Object requestId, String reason) { + withProvidedContext(this.delegate.cancelRequest(requestId, reason)).block(); + } + // -------------------------- // Tools // -------------------------- @@ -234,7 +249,16 @@ public Object ping() { */ public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) { return withProvidedContext(this.delegate.callTool(callToolRequest)).block(); + } + /** + * Calls a tool with a specific timeout. + * @param callToolRequest The request containing the tool name and input parameters + * @param timeout The maximum duration to wait for the result + * @return The tool execution result + */ + public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest, Duration timeout) { + return withProvidedContext(this.delegate.callTool(callToolRequest)).timeout(timeout).block(); } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java index d1b55f594..846769e94 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -11,6 +11,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import java.util.HashMap; import java.util.Map; class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler { @@ -24,7 +25,11 @@ class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler { public DefaultMcpStatelessServerHandler(Map> requestHandlers, Map notificationHandlers) { this.requestHandlers = requestHandlers; - this.notificationHandlers = notificationHandlers; + this.notificationHandlers = new HashMap<>(notificationHandlers); + this.notificationHandlers.putIfAbsent(McpSchema.METHOD_NOTIFICATION_CANCELLED, (ctx, params) -> { + logger.debug("Ignoring cancellation in stateless mode"); + return Mono.empty(); + }); } @Override diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 23285d514..7253399db 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,6 +183,16 @@ private Map prepareNotificationHandlers(McpServe notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_CANCELLED, (exchange, params) -> { + if (features.cancellationConsumer() != null) { + McpSchema.CancelledNotification cancelled = jsonMapper.convertValue(params, + new TypeRef() { + }); + return features.cancellationConsumer().apply(exchange, cancelled); + } + return Mono.empty(); + }); + List, Mono>> rootsChangeConsumers = features .rootsChangeConsumers(); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index a15c58cd5..aa6e911a9 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -247,6 +247,17 @@ public Mono ping() { return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF); } + /** + * Cancels a previously issued request to the client (server-to-client direction). + * Sends a {@code notifications/cancelled} notification. + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason for the cancellation + * @return A Mono that completes when the cancellation notification is sent + */ + public Mono cancelRequest(Object requestId, String reason) { + return this.session.sendCancellation(requestId, reason); + } + /** * Set the minimum logging level for the client. Messages below this level will be * filtered out. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index 7fe9ef2a2..acf1a2b4b 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -237,7 +237,7 @@ private SingleSessionAsyncSpecification(McpServerTransportProvider transportProv public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + this.instructions, this.cancellationConsumer); var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(); @@ -265,7 +265,7 @@ public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider t public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + this.instructions, this.cancellationConsumer); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, @@ -333,6 +333,8 @@ abstract class AsyncSpecification> { final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + BiFunction> cancellationConsumer; + Duration requestTimeout = Duration.ofHours(10); // Default timeout public abstract McpAsyncServer build(); @@ -805,6 +807,20 @@ public AsyncSpecification rootsChangeHandlers( return this.rootsChangeHandlers(Arrays.asList(handlers)); } + /** + * Registers an optional consumer that is invoked when a + * {@code notifications/cancelled} notification is received from the client. The + * session layer already handles the core cancellation logic; this consumer is for + * application-level side-effects (e.g. logging, metrics, UI updates). + * @param consumer The cancellation consumer + * @return This builder instance for method chaining + */ + public AsyncSpecification cancellationConsumer( + BiFunction> consumer) { + this.cancellationConsumer = consumer; + return this; + } + /** * Sets the JsonMapper to use for serializing and deserializing JSON messages. * @param jsonMapper the mapper to use. Must not be null. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index fe0608b1c..cd66bf015 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -45,19 +45,11 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, + BiFunction> cancellationConsumer) { /** - * Create an instance and validate the arguments. - * @param serverInfo The server implementation details - * @param serverCapabilities The server capabilities - * @param tools The list of tool specifications - * @param resources The map of resource specifications - * @param resourceTemplates The map of resource templates - * @param prompts The map of prompt specifications - * @param rootsChangeConsumers The list of consumers that will be notified when - * the roots list changes - * @param instructions The server instructions text + * Backwards-compatible constructor without cancellationConsumer. */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, @@ -66,6 +58,21 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map completions, List, Mono>> rootsChangeConsumers, String instructions) { + this(serverInfo, serverCapabilities, tools, resources, resourceTemplates, prompts, completions, + rootsChangeConsumers, instructions, null); + } + + /** + * Create an instance and validate the arguments. + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + Map resourceTemplates, + Map prompts, + Map completions, + List, Mono>> rootsChangeConsumers, + String instructions, + BiFunction> cancellationConsumer) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -89,6 +96,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; + this.cancellationConsumer = cancellationConsumer; } /** diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 0b9115b79..6cf22ccfe 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -143,4 +143,13 @@ public Object ping() { return this.exchange.ping().block(); } + /** + * Cancels a previously issued request to the client (server-to-client direction). + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason for the cancellation + */ + public void cancelRequest(Object requestId, String reason) { + this.exchange.cancelRequest(requestId, reason).block(); + } + } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 0ba7ab3b8..d7931298e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -9,6 +9,8 @@ import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; @@ -56,6 +58,15 @@ public class McpClientSession implements McpSession { /** Map of notification handlers keyed by method name */ private final ConcurrentHashMap notificationHandlers = new ConcurrentHashMap<>(); + /** + * Tracks in-progress inbound requests (requests the server sent to us) for + * cancellation + */ + private final ConcurrentHashMap inProgressInbound = new ConcurrentHashMap<>(); + + private static final TypeRef CANCELLED_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + /** Session-specific prefix for request IDs */ private final String sessionPrefix = UUID.randomUUID().toString().substring(0, 8); @@ -134,6 +145,22 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); + this.notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_CANCELLED, params -> { + McpSchema.CancelledNotification cancelled = transport.unmarshalFrom(params, + CANCELLED_NOTIFICATION_TYPE_REF); + if (cancelled != null) { + String normalizedId = String.valueOf(cancelled.requestId()); + logger.debug("Received cancellation for request {}: {}", normalizedId, cancelled.reason()); + + var inbound = this.inProgressInbound.remove(normalizedId); + if (inbound != null && !inbound.isDisposed()) { + inbound.dispose(); + } + } + + return Mono.empty(); + }); + this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(); } @@ -176,7 +203,15 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, jsonRpcError); return Mono.just(errorResponse); - }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { + }).flatMap(this.transport::sendMessage).doOnSubscribe(sub -> { + if (request.id() != null) { + inProgressInbound.put(String.valueOf(request.id()), Disposables.composite(sub::cancel)); + } + }).doFinally(signal -> { + if (request.id() != null) { + inProgressInbound.remove(String.valueOf(request.id())); + } + }).onErrorComplete(t -> { logger.warn("Issue sending response to the client, ", t); return true; }).subscribe(); @@ -289,6 +324,79 @@ public Mono sendRequest(String method, Object requestParams, TypeRef t }); } + /** + * Composite holding a request ID and the response Mono, enabling cancellation by ID. + * + * @param The response type + * @param requestId The generated request ID + * @param response The Mono that will emit the response + */ + public record RequestMono(String requestId, Mono response) { + } + + /** + * Sends a request and exposes the generated request ID for cancellation support. + * @param The expected response type + * @param method The method name to call + * @param requestParams The request parameters + * @param typeRef Type reference for response deserialization + * @return A RequestMono containing both the request ID and the response Mono + */ + public RequestMono sendRequestWithId(String method, Object requestParams, TypeRef typeRef) { + String requestId = this.generateRequestId(); + Mono mono = Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { + logger.debug("Sending message for method {}", method); + this.pendingResponses.put(requestId, pendingResponseSink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).contextWrite(ctx).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + pendingResponseSink.error(error); + }); + })).timeout(this.requestTimeout).handle((jsonRpcResponse, deliveredResponseSink) -> { + if (jsonRpcResponse.error() != null) { + logger.error("Error handling request: {}", jsonRpcResponse.error()); + deliveredResponseSink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + deliveredResponseSink.complete(); + } + else { + deliveredResponseSink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + return new RequestMono<>(requestId, mono); + } + + /** + * Cancels a previously issued outbound request (client-to-server direction). The + * pending response is errored locally and a cancellation notification is sent to the + * server. + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason + * @return A Mono that completes when the cancellation notification is sent + */ + public Mono sendCancellation(Object requestId, String reason) { + return Mono.defer(() -> { + var pending = this.pendingResponses.remove(requestId); + if (pending != null) { + pending.error( + new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.REQUEST_CANCELLED, + "Request cancelled locally" + (reason != null ? ": " + reason : ""), null))); + } + return this + .sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, + new McpSchema.CancelledNotification(requestId, reason)) + .onErrorResume(e -> { + logger.warn("Failed to send cancellation notification for request {}", requestId, e); + return Mono.empty(); + }); + }); + } + /** * Sends a JSON-RPC notification. * @param method The method name for the notification diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpRequestHandle.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpRequestHandle.java new file mode 100644 index 000000000..d6db0d3c7 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpRequestHandle.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Supplier; + +import reactor.core.publisher.Mono; + +/** + * A handle to a pending MCP request that allows cancellation without leaking session + * internals. The cancel function is a closure over the session's sendCancellation method. + * + * @param the response type of the request + */ +public final class McpRequestHandle { + + private final Supplier requestIdSupplier; + + private final Mono responseMono; + + private final Function> cancelFunction; + + /** + * Creates a handle with a known request ID. + * @param requestId the request ID (may be null) + * @param responseMono the Mono that will emit the response + * @param cancelFunction the function to invoke for cancellation + */ + public McpRequestHandle(Object requestId, Mono responseMono, Function> cancelFunction) { + this.requestIdSupplier = () -> requestId; + this.responseMono = responseMono; + this.cancelFunction = cancelFunction; + } + + private McpRequestHandle(Supplier requestIdSupplier, Mono responseMono, + Function> cancelFunction) { + this.requestIdSupplier = requestIdSupplier; + this.responseMono = responseMono; + this.cancelFunction = cancelFunction; + } + + /** + * Creates a handle with a lazily-resolved request ID. The ID may not be available + * until the response Mono is subscribed to. + * @param the response type + * @param requestIdRef an AtomicReference that will be populated with the request ID + * @param responseMono the Mono that will emit the response + * @param cancelFunction the function to invoke for cancellation + * @return the handle + */ + public static McpRequestHandle lazy(AtomicReference requestIdRef, Mono responseMono, + Function> cancelFunction) { + return new McpRequestHandle<>(requestIdRef::get, responseMono, cancelFunction); + } + + /** + * Returns the ID of the underlying request. May return {@code null} if the request + * has not yet been issued (i.e. the response Mono has not been subscribed to). + * @return the request ID, or null if not yet available + */ + public Object requestId() { + return requestIdSupplier.get(); + } + + /** + * Returns the Mono that will emit the response. + * @return the response Mono + */ + public Mono response() { + return responseMono; + } + + /** + * Cancel this request. Sends a cancellation notification to the other party. + * @param reason an optional human-readable reason for the cancellation + * @return a Mono that completes when the cancellation notification is sent + */ + public Mono cancel(String reason) { + return this.cancelFunction.apply(reason); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 97bde0b10..ad9d5ef5d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -102,6 +102,9 @@ private McpSchema() { public static final String METHOD_NOTIFICATION_ROOTS_LIST_CHANGED = "notifications/roots/list_changed"; + // Cancellation + public static final String METHOD_NOTIFICATION_CANCELLED = "notifications/cancelled"; + // Sampling Methods public static final String METHOD_SAMPLING_CREATE_MESSAGE = "sampling/createMessage"; @@ -146,6 +149,11 @@ public static final class ErrorCodes { */ public static final int RESOURCE_NOT_FOUND = -32002; + /** + * The request was cancelled. + */ + public static final int REQUEST_CANCELLED = -32800; + } /** @@ -182,8 +190,8 @@ public sealed interface Result extends Meta permits InitializeResult, ListResour } - public sealed interface Notification extends Meta - permits ProgressNotification, LoggingMessageNotification, ResourcesUpdatedNotification { + public sealed interface Notification extends Meta permits ProgressNotification, LoggingMessageNotification, + ResourcesUpdatedNotification, CancelledNotification { } @@ -2302,6 +2310,26 @@ public ResourcesUpdatedNotification(String uri) { } } + /** + * Notification sent by either side to cancel a previously-issued request. Per the MCP + * spec, cancellation MUST only reference requests issued in the same direction. + * + * @param requestId The ID of the request to cancel. + * @param reason An optional human-readable reason for the cancellation. + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record CancelledNotification(// @formatter:off + @JsonProperty("requestId") Object requestId, + @JsonProperty("reason") String reason, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + + public CancelledNotification(Object requestId, String reason) { + this(requestId, reason, null); + } + } + /** * The Model Context Protocol (MCP) provides a standardized way for servers to send * structured log messages to clients. Clients can control logging verbosity by diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 241f7d8b5..aa54346ea 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -20,6 +20,8 @@ import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; @@ -63,6 +65,11 @@ public class McpServerSession implements McpLoggableSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private final ConcurrentHashMap inProgressInbound = new ConcurrentHashMap<>(); + + private static final TypeRef CANCELLED_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; /** @@ -232,7 +239,15 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { jsonRpcError); // TODO: Should the error go to SSE or back as POST return? return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); + }).flatMap(this.transport::sendMessage).doOnSubscribe(sub -> { + if (request.id() != null) { + inProgressInbound.put(String.valueOf(request.id()), Disposables.composite(sub::cancel)); + } + }).doFinally(signal -> { + if (request.id() != null) { + inProgressInbound.remove(String.valueOf(request.id())); + } + }); } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization @@ -314,6 +329,20 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti clientInfo.get(), transportContext)); } + if (McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(notification.method())) { + McpSchema.CancelledNotification cancelled = transport.unmarshalFrom(notification.params(), + CANCELLED_NOTIFICATION_TYPE_REF); + if (cancelled != null) { + String normalizedId = String.valueOf(cancelled.requestId()); + logger.debug("Client cancelled request {}: {}", normalizedId, cancelled.reason()); + var inbound = this.inProgressInbound.remove(normalizedId); + if (inbound != null && !inbound.isDisposed()) { + logger.debug("Disposing in-progress request pipeline for {}", normalizedId); + inbound.dispose(); + } + } + } + var handler = notificationHandlers.get(notification.method()); if (handler == null) { logger.warn("No handler registered for notification method: {}", notification); @@ -343,6 +372,32 @@ private MethodNotFoundError getMethodNotFoundError(String method) { return new MethodNotFoundError(method, "Method not found: " + method, null); } + /** + * Cancels a previously issued outbound request (server-to-client direction). The + * pending response is errored locally and a cancellation notification is sent to the + * client. + * @param requestId The ID of the request to cancel + * @param reason An optional human-readable reason + * @return A Mono that completes when the cancellation notification is sent + */ + public Mono sendCancellation(Object requestId, String reason) { + return Mono.defer(() -> { + var pending = this.pendingResponses.remove(requestId); + if (pending != null) { + pending.error( + new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.REQUEST_CANCELLED, + "Request cancelled locally" + (reason != null ? ": " + reason : ""), null))); + } + return this + .sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, + new McpSchema.CancelledNotification(requestId, reason)) + .onErrorResume(e -> { + logger.warn("Failed to send cancellation notification for request {}", requestId, e); + return Mono.empty(); + }); + }); + } + @Override public Mono closeGracefully() { // TODO: clear pendingResponses and emit errors? diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 767ed673e..d49d36ed7 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -66,6 +66,24 @@ default Mono sendNotification(String method) { */ Mono sendNotification(String method, Object params); + /** + * Cancels a previously issued outbound request. The pending response is errored + * locally and a {@code notifications/cancelled} notification is sent to the other + * party. + * + *

+ * Implementations that track pending responses should override this to also error the + * pending response before sending the notification. + *

+ * @param requestId the ID of the request to cancel + * @param reason an optional human-readable reason for the cancellation + * @return a Mono that completes when the cancellation notification has been sent + */ + default Mono sendCancellation(Object requestId, String reason) { + return sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, + new McpSchema.CancelledNotification(requestId, reason)); + } + /** * Closes the session and releases any associated resources asynchronously. * @return a {@link Mono} that completes when the session has been closed. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 95f8959f5..3783e977c 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -23,6 +23,8 @@ import io.modelcontextprotocol.server.McpRequestHandler; import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; import io.modelcontextprotocol.util.Assert; +import reactor.core.Disposable; +import reactor.core.Disposables; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; @@ -60,6 +62,11 @@ public class McpStreamableServerSession implements McpLoggableSession { private final MissingMcpTransportSession missingMcpTransportSession; + private final ConcurrentHashMap inProgressInbound = new ConcurrentHashMap<>(); + + private static final TypeRef CANCELLED_NOTIFICATION_TYPE_REF = new TypeRef<>() { + }; + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; /** @@ -173,7 +180,7 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, error.message(), error.data()))); } - return requestHandler + Mono pipeline = requestHandler .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), transportContext), jsonrpcRequest.params()) .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, @@ -190,6 +197,16 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr }) .flatMap(transport::sendMessage) .then(transport.closeGracefully()); + + return pipeline.doOnSubscribe(sub -> { + if (jsonrpcRequest.id() != null) { + inProgressInbound.put(String.valueOf(jsonrpcRequest.id()), Disposables.composite(sub::cancel)); + } + }).doFinally(signal -> { + if (jsonrpcRequest.id() != null) { + inProgressInbound.remove(String.valueOf(jsonrpcRequest.id())); + } + }); }); } @@ -201,6 +218,19 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStr public Mono accept(McpSchema.JSONRPCNotification notification) { return Mono.deferContextual(ctx -> { McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + if (McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(notification.method())) { + McpSchema.CancelledNotification cancelled = unmarshalCancelled(notification.params()); + if (cancelled != null) { + String normalizedId = String.valueOf(cancelled.requestId()); + logger.debug("Client cancelled request {}: {}", normalizedId, cancelled.reason()); + var inProgress = this.inProgressInbound.remove(normalizedId); + if (inProgress != null && !inProgress.isDisposed()) { + inProgress.dispose(); + } + } + } + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); if (notificationHandler == null) { logger.warn("No handler registered for notification method: {}", notification); @@ -213,6 +243,16 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { } + private McpSchema.CancelledNotification unmarshalCancelled(Object params) { + if (params instanceof Map map) { + Object requestId = map.get("requestId"); + Object reasonValue = map.get("reason"); + String reason = reasonValue instanceof String s ? s : null; + return new McpSchema.CancelledNotification(requestId, reason); + } + return null; + } + /** * Handle the MCP response. * @param response MCP response to the server-initiated request diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java index 6f7390f19..04d269e9d 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerPostInitializationHookTests.java @@ -33,7 +33,7 @@ /** * Tests for {@link LifecycleInitializer} postInitializationHook functionality. - * + * * @author Christian Tzolov */ class LifecycleInitializerPostInitializationHookTests { @@ -70,6 +70,8 @@ void setUp() { when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", Mono.just(MOCK_INIT_RESULT))); when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) .thenReturn(Mono.empty()); when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); @@ -163,8 +165,9 @@ void shouldFailInitializationWhenPostInitializationHookFails() { @Test void shouldNotInvokePostInitializationHookWhenInitializationFails() { - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.error(new RuntimeException("Initialization failed"))); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", + Mono.error(new RuntimeException("Initialization failed")))); StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java index 787ee9480..fb70ecd7f 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -71,6 +71,8 @@ void setUp() { when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", Mono.just(MOCK_INIT_RESULT))); when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) .thenReturn(Mono.empty()); when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); @@ -122,8 +124,8 @@ void shouldInitializeSuccessfully() { }) .verifyComplete(); - verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class), - any()); + verify(mockClientSession).sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), + any(McpSchema.InitializeRequest.class), any()); verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); } @@ -131,10 +133,11 @@ void shouldInitializeSuccessfully() { void shouldUseLatestProtocolVersionInInitializeRequest() { AtomicReference capturedRequest = new AtomicReference<>(); - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { - capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); - return Mono.just(MOCK_INIT_RESULT); - }); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return new McpClientSession.RequestMono<>("init-1", Mono.just(MOCK_INIT_RESULT)); + }); StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { @@ -153,8 +156,8 @@ void shouldFailForUnsupportedProtocolVersion() { McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions"); - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.just(unsupportedResult)); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", Mono.just(unsupportedResult))); StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) @@ -173,8 +176,9 @@ void shouldTimeoutOnSlowInitialization() { LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier, mockPostInitializationHook); - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", + Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler))); StepVerifier .withVirtualTime(() -> shortTimeoutInitializer.withInitialization("test", @@ -199,7 +203,7 @@ void shouldReuseExistingInitialization() { // Verify session was created only once verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); - verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession, times(1)).sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any()); } @Test @@ -228,16 +232,15 @@ void shouldHandleConcurrentInitializationRequests() { // Should only create one session despite concurrent requests assertThat(sessionCreationCount.get()).isEqualTo(1); - verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession, times(1)).sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any()); } @Test void shouldHandleInitializationFailure() { - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) - // fail once - .thenReturn(Mono.error(new RuntimeException("Connection failed"))) - // succeeds on the second call - .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-1", + Mono.error(new RuntimeException("Connection failed")))) + .thenAnswer(inv -> new McpClientSession.RequestMono<>("init-2", Mono.just(MOCK_INIT_RESULT))); StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .expectError(RuntimeException.class) @@ -340,11 +343,14 @@ void shouldSetProtocolVersionsForTesting() { AtomicReference capturedRequest = new AtomicReference<>(); - when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { - capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); - return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(), - new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); - }); + when(mockClientSession.sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return new McpClientSession.RequestMono<>("init-1", + Mono.just( + new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions"))); + }); StepVerifier.create(initializer.withInitialization("test", init -> Mono.just(init.initializeResult()))) .assertNext(result -> { @@ -394,7 +400,7 @@ void shouldHandleNotificationFailure() { .expectError(RuntimeException.class) .verify(); - verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession).sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any()); verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); } @@ -421,7 +427,7 @@ void shouldReinitializeAfterTransportSessionException() { // Verify two separate initializations occurred verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); - verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession, times(2)).sendRequestWithId(eq(McpSchema.METHOD_INITIALIZE), any(), any()); } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 640d34c9c..5d8cc18b7 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -695,4 +695,46 @@ void testPingMultipleCalls() { verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } + // --------------------------------------- + // Cancel Request Tests + // --------------------------------------- + + @Test + void testCancelRequestDelegatesToSendCancellation() { + when(mockSession.sendCancellation(any(), any())).thenReturn(Mono.empty()); + + StepVerifier.create(exchange.cancelRequest("req-123", "user aborted")).verifyComplete(); + + verify(mockSession, times(1)).sendCancellation("req-123", "user aborted"); + } + + @Test + void testCancelRequestWithNullReason() { + when(mockSession.sendCancellation(any(), any())).thenReturn(Mono.empty()); + + StepVerifier.create(exchange.cancelRequest("req-456", null)).verifyComplete(); + + verify(mockSession, times(1)).sendCancellation("req-456", null); + } + + @Test + void testCancelRequestWithSessionError() { + when(mockSession.sendCancellation(any(), any())) + .thenReturn(Mono.error(new RuntimeException("Transport error"))); + + StepVerifier.create(exchange.cancelRequest("req-789", "timeout")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("Transport error"); + }); + } + + @Test + void testCancelRequestMultipleTimes() { + when(mockSession.sendCancellation(any(), any())).thenReturn(Mono.empty()); + + StepVerifier.create(exchange.cancelRequest("req-1", "first cancel")).verifyComplete(); + StepVerifier.create(exchange.cancelRequest("req-1", "second cancel")).verifyComplete(); + + verify(mockSession, times(2)).sendCancellation(eq("req-1"), any()); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index 069d0f896..e88065972 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -689,4 +689,35 @@ void testPingMultipleCalls() { verify(mockSession, times(2)).sendRequest(eq(McpSchema.METHOD_PING), eq(null), any(TypeRef.class)); } + // --------------------------------------- + // Cancel Request Tests + // --------------------------------------- + + @Test + void testCancelRequestDelegatesToSendCancellation() { + when(mockSession.sendCancellation(any(), any())).thenReturn(Mono.empty()); + + exchange.cancelRequest("req-sync-1", "test reason"); + + verify(mockSession, times(1)).sendCancellation("req-sync-1", "test reason"); + } + + @Test + void testCancelRequestWithNullReason() { + when(mockSession.sendCancellation(any(), any())).thenReturn(Mono.empty()); + + exchange.cancelRequest("req-sync-2", null); + + verify(mockSession, times(1)).sendCancellation("req-sync-2", null); + } + + @Test + void testCancelRequestWithSessionError() { + when(mockSession.sendCancellation(any(), any())) + .thenReturn(Mono.error(new RuntimeException("Transport error"))); + + assertThatThrownBy(() -> exchange.cancelRequest("req-sync-3", "timeout")).isInstanceOf(RuntimeException.class) + .hasMessage("Transport error"); + } + } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionCancellationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionCancellationTests.java new file mode 100644 index 000000000..72a80706e --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpClientSessionCancellationTests.java @@ -0,0 +1,271 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.json.TypeRef; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for cancellation support in {@link McpClientSession}. + */ +class McpClientSessionCancellationTests { + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + private static final String ECHO_METHOD = "echo"; + + TypeRef responseType = new TypeRef<>() { + }; + + // ------------------------------------------------------------------ + // sendRequestWithId + // ------------------------------------------------------------------ + + @Test + void sendRequestWithIdReturnsRequestIdAndResponseMono() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpClientSession.RequestMono rm = session.sendRequestWithId("test.method", "param", responseType); + + assertThat(rm.requestId()).isNotNull(); + assertThat(rm.requestId()).isNotEmpty(); + assertThat(rm.response()).isNotNull(); + + StepVerifier.create(rm.response()).then(() -> { + McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); + assertThat(request.id()).isEqualTo(rm.requestId()); + assertThat(request.method()).isEqualTo("test.method"); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), "response-data", null)); + }).expectNext("response-data").verifyComplete(); + + session.close(); + } + + @Test + void sendRequestWithIdHandlesErrorResponse() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpClientSession.RequestMono rm = session.sendRequestWithId("test.method", "param", responseType); + + StepVerifier.create(rm.response()).then(() -> { + McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INTERNAL_ERROR, "Server error", null); + transport.simulateIncomingMessage( + new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, rm.requestId(), null, error)); + }).expectError(McpError.class).verify(); + + session.close(); + } + + @Test + void sendRequestWithIdGeneratesUniqueIds() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpClientSession.RequestMono rm1 = session.sendRequestWithId("m1", "p1", responseType); + McpClientSession.RequestMono rm2 = session.sendRequestWithId("m2", "p2", responseType); + + assertThat(rm1.requestId()).isNotEqualTo(rm2.requestId()); + + session.close(); + } + + // ------------------------------------------------------------------ + // sendCancellation – outbound cancellation (client cancels own request) + // ------------------------------------------------------------------ + + @Test + void sendCancellationErrorsPendingResponseAndSendsNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpClientSession.RequestMono rm = session.sendRequestWithId("test.method", "param", responseType); + + StepVerifier.create(rm.response()).then(() -> { + session.sendCancellation(rm.requestId(), "user aborted").block(Duration.ofSeconds(2)); + }).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + McpError mcpError = (McpError) error; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(McpSchema.ErrorCodes.REQUEST_CANCELLED); + }).verify(TIMEOUT); + + McpSchema.JSONRPCMessage lastSent = transport.getLastSentMessage(); + assertThat(lastSent).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) lastSent; + assertThat(notification.method()).isEqualTo(McpSchema.METHOD_NOTIFICATION_CANCELLED); + + session.close(); + } + + @Test + void sendCancellationForUnknownRequestStillSendsNotification() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + StepVerifier.create(session.sendCancellation("non-existent-id", "cleanup")).verifyComplete(); + + McpSchema.JSONRPCMessage lastSent = transport.getLastSentMessage(); + assertThat(lastSent).isInstanceOf(McpSchema.JSONRPCNotification.class); + McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) lastSent; + assertThat(notification.method()).isEqualTo(McpSchema.METHOD_NOTIFICATION_CANCELLED); + + session.close(); + } + + @Test + void sendCancellationWithNullReasonFormatsMessageCorrectly() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpClientSession.RequestMono rm = session.sendRequestWithId("test.method", "param", responseType); + + StepVerifier.create(rm.response()).then(() -> { + session.sendCancellation(rm.requestId(), null).block(Duration.ofSeconds(2)); + }).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + McpError mcpError = (McpError) error; + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(McpSchema.ErrorCodes.REQUEST_CANCELLED); + assertThat(mcpError.getMessage()).doesNotContain("null"); + }).verify(TIMEOUT); + + session.close(); + } + + // ------------------------------------------------------------------ + // Inbound cancellation – server cancels a request it sent to us + // ------------------------------------------------------------------ + + @Test + void inboundCancellationDisposesInProgressRequest() throws InterruptedException { + CountDownLatch handlerStarted = new CountDownLatch(1); + AtomicBoolean handlerCompleted = new AtomicBoolean(false); + + Map> requestHandlers = Map.of(ECHO_METHOD, params -> { + handlerStarted.countDown(); + return Mono.delay(Duration.ofSeconds(10)).map(l -> { + handlerCompleted.set(true); + return params; + }); + }); + + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + McpSchema.JSONRPCRequest incomingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "server-req-1", "hello"); + transport.simulateIncomingMessage(incomingRequest); + + assertThat(handlerStarted.await(2, TimeUnit.SECONDS)).isTrue(); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("server-req-1", + "server timeout"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + transport.simulateIncomingMessage(notification); + + Thread.sleep(200); + assertThat(handlerCompleted.get()).isFalse(); + + session.close(); + } + + @Test + void inboundCancellationForNonExistentRequestDoesNotThrow() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("unknown-id", + "cleanup"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + + transport.simulateIncomingMessage(notification); + + session.close(); + } + + @Test + void inboundCancellationWithNullDeserialization() { + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(), Function.identity()); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, null); + + transport.simulateIncomingMessage(notification); + + session.close(); + } + + // ------------------------------------------------------------------ + // RequestMono record + // ------------------------------------------------------------------ + + @Test + void requestMonoRecordFieldAccess() { + Mono mono = Mono.just("test"); + McpClientSession.RequestMono rm = new McpClientSession.RequestMono<>("req-123", mono); + + assertThat(rm.requestId()).isEqualTo("req-123"); + assertThat(rm.response()).isSameAs(mono); + } + + @Test + void requestMonoRecordEquality() { + Mono mono = Mono.just("test"); + McpClientSession.RequestMono rm1 = new McpClientSession.RequestMono<>("req-1", mono); + McpClientSession.RequestMono rm2 = new McpClientSession.RequestMono<>("req-1", mono); + McpClientSession.RequestMono rm3 = new McpClientSession.RequestMono<>("req-2", mono); + + assertThat(rm1).isEqualTo(rm2); + assertThat(rm1).isNotEqualTo(rm3); + } + + // ------------------------------------------------------------------ + // doOnSubscribe / doFinally lifecycle tracking + // ------------------------------------------------------------------ + + @Test + void completedRequestIsClearedFromTracking() throws InterruptedException { + Sinks.One responseSent = Sinks.one(); + + Map> requestHandlers = Map.of(ECHO_METHOD, + params -> Mono.just(params)); + var transport = new MockMcpClientTransport(); + var session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of(), Function.identity()); + + McpSchema.JSONRPCRequest incomingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, + "req-track-1", "data"); + transport.simulateIncomingMessage(incomingRequest); + + Thread.sleep(200); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("req-track-1", + "late cancel"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + transport.simulateIncomingMessage(notification); + + session.close(); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpRequestHandleTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpRequestHandleTests.java new file mode 100644 index 000000000..9f347aff9 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpRequestHandleTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link McpRequestHandle}. + */ +class McpRequestHandleTests { + + @Test + void requestIdIsExposed() { + McpRequestHandle handle = new McpRequestHandle<>("req-1", Mono.just("result"), reason -> Mono.empty()); + + assertThat(handle.requestId()).isEqualTo("req-1"); + } + + @Test + void responseMonoEmitsValue() { + McpRequestHandle handle = new McpRequestHandle<>("req-1", Mono.just("result"), reason -> Mono.empty()); + + StepVerifier.create(handle.response()).expectNext("result").verifyComplete(); + } + + @Test + void cancelInvokesCancelFunction() { + AtomicReference capturedReason = new AtomicReference<>(); + McpRequestHandle handle = new McpRequestHandle<>("req-1", Mono.just("result"), reason -> { + capturedReason.set(reason); + return Mono.empty(); + }); + + StepVerifier.create(handle.cancel("user requested")).verifyComplete(); + assertThat(capturedReason.get()).isEqualTo("user requested"); + } + + @Test + void cancelWithNullReason() { + AtomicReference capturedReason = new AtomicReference<>("not-null"); + McpRequestHandle handle = new McpRequestHandle<>("req-1", Mono.just("result"), reason -> { + capturedReason.set(reason); + return Mono.empty(); + }); + + StepVerifier.create(handle.cancel(null)).verifyComplete(); + assertThat(capturedReason.get()).isNull(); + } + + @Test + void cancelPropagatesError() { + McpRequestHandle handle = new McpRequestHandle<>("req-1", Mono.just("result"), + reason -> Mono.error(new RuntimeException("cancel failed"))); + + StepVerifier.create(handle.cancel("reason")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("cancel failed"); + }); + } + + @Test + void handleWithNullRequestId() { + McpRequestHandle handle = new McpRequestHandle<>(null, Mono.just("result"), reason -> Mono.empty()); + + assertThat(handle.requestId()).isNull(); + StepVerifier.create(handle.response()).expectNext("result").verifyComplete(); + } + + @Test + void handleWithErrorResponse() { + McpRequestHandle handle = new McpRequestHandle<>("req-err", Mono.error(new McpError("server error")), + reason -> Mono.empty()); + + StepVerifier.create(handle.response()).expectError(McpError.class).verify(); + } + + @Test + void requestIdResolvesLazilyFromAtomicReference() { + AtomicReference idRef = new AtomicReference<>(); + McpRequestHandle handle = McpRequestHandle.lazy(idRef, Mono.just("result"), reason -> Mono.empty()); + + assertThat(handle.requestId()).isNull(); + + idRef.set("lazy-req-1"); + assertThat(handle.requestId()).isEqualTo("lazy-req-1"); + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpServerSessionCancellationTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpServerSessionCancellationTests.java new file mode 100644 index 000000000..dcf3147c5 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/spec/McpServerSessionCancellationTests.java @@ -0,0 +1,244 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for cancellation support in {@link McpServerSession}. + */ +class McpServerSessionCancellationTests { + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + @Mock + private McpServerTransport mockTransport; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + when(mockTransport.sendMessage(any())).thenReturn(Mono.empty()); + when(mockTransport.closeGracefully()).thenReturn(Mono.empty()); + when(mockTransport.unmarshalFrom(any(), any())).thenAnswer(inv -> inv.getArgument(0)); + } + + private McpServerSession createSession(Map> requestHandlers, + Map notificationHandlers) { + McpInitRequestHandler initHandler = initializeRequest -> Mono.just(new McpSchema.InitializeResult( + McpSchema.LATEST_PROTOCOL_VERSION, McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), null)); + + return new McpServerSession("test-session", TIMEOUT, mockTransport, initHandler, requestHandlers, + notificationHandlers); + } + + private void performInitialization(McpServerSession session) { + McpSchema.InitializeRequest initReq = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().build(), new McpSchema.Implementation("client", "1.0")); + McpSchema.JSONRPCRequest initRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_INITIALIZE, "init-1", initReq); + session.handle(initRequest).block(Duration.ofSeconds(2)); + + McpSchema.JSONRPCNotification initializedNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, McpSchema.METHOD_NOTIFICATION_INITIALIZED, null); + session.handle(initializedNotification).block(Duration.ofSeconds(2)); + } + + // ------------------------------------------------------------------ + // Inbound cancellation – client cancels a request it sent to the server + // ------------------------------------------------------------------ + + @Test + void cancellationNotificationIsAcceptedBySession() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("req-1", + "user cancelled"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + + StepVerifier.create(session.handle(notification)).verifyComplete(); + } + + @Test + void preEmptiveCancellationIsIgnoredAndResponseStillSent() throws InterruptedException { + Map> requestHandlers = new HashMap<>(); + requestHandlers.put("fast.method", (exchange, params) -> Mono.just("result")); + + McpServerSession session = createSession(requestHandlers, new HashMap<>()); + performInitialization(session); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("req-1", + "pre-emptive cancel"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + session.handle(notification).block(Duration.ofSeconds(2)); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "fast.method", + "req-1", "param"); + session.handle(request).block(Duration.ofSeconds(2)); + + Thread.sleep(100); + + ArgumentCaptor captor = ArgumentCaptor.forClass(McpSchema.JSONRPCMessage.class); + verify(mockTransport, atLeastOnce()).sendMessage(captor.capture()); + + boolean responseSent = captor.getAllValues().stream().anyMatch(msg -> { + if (msg instanceof McpSchema.JSONRPCResponse r) { + return "req-1".equals(r.id()) && r.error() == null; + } + return false; + }); + assertThat(responseSent) + .as("Pre-emptive cancellation for non-existent request is ignored per MCP spec; response should still be sent") + .isTrue(); + } + + @Test + void cancelledRequestWithNullDeserialization() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, null); + + StepVerifier.create(session.handle(notification)).verifyComplete(); + } + + @Test + void cancellationForNonExistentRequestDoesNotError() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("non-existent", + "cleanup"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + + StepVerifier.create(session.handle(notification)).verifyComplete(); + } + + // ------------------------------------------------------------------ + // sendCancellation – server cancels its own outbound request + // ------------------------------------------------------------------ + + @Test + void sendCancellationSendsNotificationToClient() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + StepVerifier.create(session.sendCancellation("outbound-req-1", "timeout")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(McpSchema.JSONRPCMessage.class); + verify(mockTransport, atLeastOnce()).sendMessage(captor.capture()); + + boolean foundCancellation = captor.getAllValues().stream().anyMatch(msg -> { + if (msg instanceof McpSchema.JSONRPCNotification n) { + return McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(n.method()); + } + return false; + }); + assertThat(foundCancellation).isTrue(); + } + + @Test + void sendCancellationErrorsPendingOutboundResponse() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + Mono outboundResponse = session.sendRequest("sampling/createMessage", Map.of("prompt", "test"), + new TypeRef<>() { + }); + + StepVerifier.create(outboundResponse).then(() -> { + session.sendCancellation("test-session-0", "no longer needed").block(Duration.ofSeconds(2)); + }).expectErrorSatisfies(error -> { + assertThat(error).isInstanceOf(McpError.class); + McpError mcpError = (McpError) error; + assertThat(mcpError.getJsonRpcError()).isNotNull(); + assertThat(mcpError.getJsonRpcError().code()).isEqualTo(McpSchema.ErrorCodes.REQUEST_CANCELLED); + }).verify(TIMEOUT); + } + + @Test + void sendCancellationWithNullReason() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + StepVerifier.create(session.sendCancellation("some-id", null)).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(McpSchema.JSONRPCMessage.class); + verify(mockTransport, atLeastOnce()).sendMessage(captor.capture()); + + boolean foundCancellation = captor.getAllValues().stream().anyMatch(msg -> { + if (msg instanceof McpSchema.JSONRPCNotification n) { + return McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(n.method()); + } + return false; + }); + assertThat(foundCancellation).isTrue(); + } + + @Test + void sendCancellationForUnknownIdStillSendsNotification() { + McpServerSession session = createSession(new HashMap<>(), new HashMap<>()); + + StepVerifier.create(session.sendCancellation("non-existent-outbound", "cleanup")).verifyComplete(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(McpSchema.JSONRPCMessage.class); + verify(mockTransport, atLeastOnce()).sendMessage(captor.capture()); + + boolean foundCancellation = captor.getAllValues().stream().anyMatch(msg -> { + if (msg instanceof McpSchema.JSONRPCNotification n) { + return McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(n.method()); + } + return false; + }); + assertThat(foundCancellation).isTrue(); + } + + // ------------------------------------------------------------------ + // Notification handler delegation + // ------------------------------------------------------------------ + + @Test + void cancellationNotificationStillInvokesRegisteredHandler() { + ConcurrentHashMap handlerCalled = new ConcurrentHashMap<>(); + Map notificationHandlers = new HashMap<>(); + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_CANCELLED, (exchange, params) -> { + handlerCalled.put("called", true); + return Mono.empty(); + }); + + McpServerSession session = createSession(new HashMap<>(), notificationHandlers); + performInitialization(session); + + McpSchema.CancelledNotification cancelledNotification = new McpSchema.CancelledNotification("req-42", + "test reason"); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_CANCELLED, cancelledNotification); + + StepVerifier.create(session.handle(notification)).verifyComplete(); + + assertThat(handlerCalled).containsKey("called"); + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index c732b1cc1..c726cc46c 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -1768,4 +1768,97 @@ void testProgressNotificationWithoutMessage() throws Exception { {"progressToken":"progress-token-789","progress":0.25}""")); } + // Cancelled Notification Tests + + @Test + void testCancelledNotificationSerialization() throws Exception { + McpSchema.CancelledNotification notification = new McpSchema.CancelledNotification("req-123", "user cancelled"); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER).isObject().isEqualTo(json(""" + {"requestId":"req-123","reason":"user cancelled"}""")); + } + + @Test + void testCancelledNotificationDeserialization() throws Exception { + McpSchema.CancelledNotification notification = JSON_MAPPER.readValue(""" + {"requestId":"req-456","reason":"timeout exceeded"} + """, McpSchema.CancelledNotification.class); + + assertThat(notification.requestId()).isEqualTo("req-456"); + assertThat(notification.reason()).isEqualTo("timeout exceeded"); + assertThat(notification.meta()).isNull(); + } + + @Test + void testCancelledNotificationWithNullReason() throws Exception { + McpSchema.CancelledNotification notification = new McpSchema.CancelledNotification("req-789", null); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER).isObject().isEqualTo(json(""" + {"requestId":"req-789"}""")); + + assertThat(notification.reason()).isNull(); + assertThat(notification.meta()).isNull(); + } + + @Test + void testCancelledNotificationWithMeta() throws Exception { + Map meta = Map.of("key", "value"); + McpSchema.CancelledNotification notification = new McpSchema.CancelledNotification("req-meta", "reason", meta); + + String value = JSON_MAPPER.writeValueAsString(notification); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER).isObject().isEqualTo(json(""" + {"requestId":"req-meta","reason":"reason","_meta":{"key":"value"}}""")); + } + + @Test + void testCancelledNotificationDeserializationIgnoresUnknownFields() throws Exception { + McpSchema.CancelledNotification notification = JSON_MAPPER.readValue(""" + {"requestId":"req-extra","reason":"test","unknownField":"ignored"} + """, McpSchema.CancelledNotification.class); + + assertThat(notification.requestId()).isEqualTo("req-extra"); + assertThat(notification.reason()).isEqualTo("test"); + } + + @Test + void testCancelledNotificationWithNumericRequestId() throws Exception { + McpSchema.CancelledNotification notification = JSON_MAPPER.readValue(""" + {"requestId":42,"reason":"numeric id test"} + """, McpSchema.CancelledNotification.class); + + assertThat(notification.requestId()).isEqualTo(42); + assertThat(notification.reason()).isEqualTo("numeric id test"); + } + + @Test + void testCancelledNotificationRoundTrip() throws Exception { + McpSchema.CancelledNotification original = new McpSchema.CancelledNotification("round-trip-id", + "round trip test"); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.CancelledNotification deserialized = JSON_MAPPER.readValue(json, + McpSchema.CancelledNotification.class); + + assertThat(deserialized.requestId()).isEqualTo(original.requestId()); + assertThat(deserialized.reason()).isEqualTo(original.reason()); + } + + @Test + void testCancelledNotificationImplementsNotificationInterface() { + McpSchema.CancelledNotification notification = new McpSchema.CancelledNotification("test", "test"); + assertThat(notification).isInstanceOf(McpSchema.Notification.class); + } + + @Test + void testErrorCodesRequestCancelled() { + assertThat(McpSchema.ErrorCodes.REQUEST_CANCELLED).isEqualTo(-32800); + } + + @Test + void testMethodNotificationCancelledConstant() { + assertThat(McpSchema.METHOD_NOTIFICATION_CANCELLED).isEqualTo("notifications/cancelled"); + } + }