Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ private static class DefaultInitialization implements Initialization {
*/
private final AtomicReference<McpClientSession> mcpClientSession;

private final AtomicReference<Object> initializeRequestId = new AtomicReference<>();

private DefaultInitialization() {
this.initSink = Sinks.one();
this.result = new AtomicReference<>();
Expand Down Expand Up @@ -235,6 +237,16 @@ public boolean isInitialized() {
return this.currentInitializationResult() != null;
}

/**
* Returns the request ID of the initialize request, if one has been issued. Used to
* prevent cancellation of the initialize request per spec.
* @return the initialize request ID, or null
*/
public Object getInitializeRequestId() {
DefaultInitialization current = this.initializationRef.get();
return current != null ? current.initializeRequestId.get() : null;
}

public McpSchema.InitializeResult currentInitializationResult() {
DefaultInitialization current = this.initializationRef.get();
McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null;
Expand Down Expand Up @@ -305,8 +317,11 @@ private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization init
McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion,
this.clientCapabilities, this.clientInfo);

Mono<McpSchema.InitializeResult> result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE,
initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF);
McpClientSession.RequestMono<McpSchema.InitializeResult> requestMono = mcpClientSession.sendRequestWithId(
McpSchema.METHOD_INITIALIZE, initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF);
initialization.initializeRequestId.set(requestMono.requestId());

Mono<McpSchema.InitializeResult> result = requestMono.response();

return result.flatMap(initializeResult -> {
logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import io.modelcontextprotocol.client.LifecycleInitializer.Initialization;
Expand All @@ -22,6 +23,7 @@
import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpRequestHandle;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
Expand Down Expand Up @@ -449,6 +451,61 @@ public Mono<Object> ping() {
init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF));
}

// --------------------------
// Cancellation
// --------------------------

/**
* Cancels a previously issued request by its ID. Sends a
* {@code notifications/cancelled} notification to the server and errors the pending
* response locally.
* @param requestId The ID of the request to cancel
* @param reason An optional human-readable reason for the cancellation
* @return A Mono that completes when the cancellation notification is sent
*/
public Mono<Void> cancelRequest(Object requestId, String reason) {
if (!this.isInitialized()) {
return Mono.error(new IllegalStateException("Cannot cancel request before initialization"));
}
Object initId = this.initializer.getInitializeRequestId();
if (initId != null && initId.equals(requestId)) {
return Mono.error(new IllegalArgumentException("The initialize request MUST NOT be cancelled"));
}
return this.initializer.withInitialization("cancelling request",
init -> init.mcpSession().sendCancellation(requestId, reason));
}

/**
* Calls a tool and returns a handle that can be used to cancel the request.
* @param callToolRequest The request containing the tool name and input parameters
* @return A McpRequestHandle containing the request ID, response Mono, and a cancel
* function
*/
public McpRequestHandle<McpSchema.CallToolResult> callToolWithHandle(McpSchema.CallToolRequest callToolRequest) {
AtomicReference<String> requestIdRef = new AtomicReference<>();

Mono<McpSchema.CallToolResult> responseMono = this.initializer.withInitialization("calling tool with handle",
init -> {
if (init.initializeResult().capabilities().tools() == null) {
return Mono.error(new IllegalStateException("Server does not provide tools capability"));
}
McpClientSession.RequestMono<McpSchema.CallToolResult> rm = init.mcpSession()
.sendRequestWithId(McpSchema.METHOD_TOOLS_CALL, callToolRequest, CALL_TOOL_RESULT_TYPE_REF);
requestIdRef.set(rm.requestId());
return rm.response()
.flatMap(result -> Mono.just(validateToolResult(callToolRequest.name(), result)));
});

return McpRequestHandle.lazy(requestIdRef, responseMono, reason -> {
String id = requestIdRef.get();
if (id == null) {
return Mono.error(new IllegalStateException(
"Cannot cancel: request has not been issued yet. Subscribe to the response Mono first."));
}
return this.cancelRequest(id, reason);
});
}

// --------------------------
// Roots
// --------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.slf4j.LoggerFactory;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.spec.McpRequestHandle;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest;
Expand Down Expand Up @@ -218,6 +219,20 @@ public Object ping() {
return withProvidedContext(this.delegate.ping()).block();
}

// --------------------------
// Cancellation
// --------------------------

/**
* Cancels a previously issued request. IMPORTANT: This method MUST be called from a
* different thread than the one blocked waiting on the request result.
* @param requestId The ID of the request to cancel
* @param reason An optional human-readable reason for the cancellation
*/
public void cancelRequest(Object requestId, String reason) {
withProvidedContext(this.delegate.cancelRequest(requestId, reason)).block();
}

// --------------------------
// Tools
// --------------------------
Expand All @@ -234,7 +249,16 @@ public Object ping() {
*/
public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest) {
return withProvidedContext(this.delegate.callTool(callToolRequest)).block();
}

/**
* Calls a tool with a specific timeout.
* @param callToolRequest The request containing the tool name and input parameters
* @param timeout The maximum duration to wait for the result
* @return The tool execution result
*/
public McpSchema.CallToolResult callTool(McpSchema.CallToolRequest callToolRequest, Duration timeout) {
return withProvidedContext(this.delegate.callTool(callToolRequest)).timeout(timeout).block();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;

import java.util.HashMap;
import java.util.Map;

class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler {
Expand All @@ -24,7 +25,11 @@ class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler {
public DefaultMcpStatelessServerHandler(Map<String, McpStatelessRequestHandler<?>> requestHandlers,
Map<String, McpStatelessNotificationHandler> notificationHandlers) {
this.requestHandlers = requestHandlers;
this.notificationHandlers = notificationHandlers;
this.notificationHandlers = new HashMap<>(notificationHandlers);
this.notificationHandlers.putIfAbsent(McpSchema.METHOD_NOTIFICATION_CANCELLED, (ctx, params) -> {
logger.debug("Ignoring cancellation in stateless mode");
return Mono.empty();
});
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ private Map<String, McpNotificationHandler> prepareNotificationHandlers(McpServe

notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty());

notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_CANCELLED, (exchange, params) -> {
if (features.cancellationConsumer() != null) {
McpSchema.CancelledNotification cancelled = jsonMapper.convertValue(params,
new TypeRef<McpSchema.CancelledNotification>() {
});
return features.cancellationConsumer().apply(exchange, cancelled);
}
return Mono.empty();
});

List<BiFunction<McpAsyncServerExchange, List<McpSchema.Root>, Mono<Void>>> rootsChangeConsumers = features
.rootsChangeConsumers();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,17 @@ public Mono<Object> ping() {
return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF);
}

/**
* Cancels a previously issued request to the client (server-to-client direction).
* Sends a {@code notifications/cancelled} notification.
* @param requestId The ID of the request to cancel
* @param reason An optional human-readable reason for the cancellation
* @return A Mono that completes when the cancellation notification is sent
*/
public Mono<Void> cancelRequest(Object requestId, String reason) {
return this.session.sendCancellation(requestId, reason);
}

/**
* Set the minimum logging level for the client. Messages below this level will be
* filtered out.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ private SingleSessionAsyncSpecification(McpServerTransportProvider transportProv
public McpAsyncServer build() {
var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools,
this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers,
this.instructions);
this.instructions, this.cancellationConsumer);

var jsonSchemaValidator = (this.jsonSchemaValidator != null) ? this.jsonSchemaValidator
: McpJsonDefaults.getSchemaValidator();
Expand Down Expand Up @@ -265,7 +265,7 @@ public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider t
public McpAsyncServer build() {
var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools,
this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers,
this.instructions);
this.instructions, this.cancellationConsumer);
var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator
: McpJsonDefaults.getSchemaValidator();
return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper,
Expand Down Expand Up @@ -333,6 +333,8 @@ abstract class AsyncSpecification<S extends AsyncSpecification<S>> {

final List<BiFunction<McpAsyncServerExchange, List<McpSchema.Root>, Mono<Void>>> rootsChangeHandlers = new ArrayList<>();

BiFunction<McpAsyncServerExchange, McpSchema.CancelledNotification, Mono<Void>> cancellationConsumer;

Duration requestTimeout = Duration.ofHours(10); // Default timeout

public abstract McpAsyncServer build();
Expand Down Expand Up @@ -805,6 +807,20 @@ public AsyncSpecification<S> rootsChangeHandlers(
return this.rootsChangeHandlers(Arrays.asList(handlers));
}

/**
* Registers an optional consumer that is invoked when a
* {@code notifications/cancelled} notification is received from the client. The
* session layer already handles the core cancellation logic; this consumer is for
* application-level side-effects (e.g. logging, metrics, UI updates).
* @param consumer The cancellation consumer
* @return This builder instance for method chaining
*/
public AsyncSpecification<S> cancellationConsumer(
BiFunction<McpAsyncServerExchange, McpSchema.CancelledNotification, Mono<Void>> consumer) {
this.cancellationConsumer = consumer;
return this;
}

/**
* Sets the JsonMapper to use for serializing and deserializing JSON messages.
* @param jsonMapper the mapper to use. Must not be null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,11 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
Map<String, McpServerFeatures.AsyncPromptSpecification> prompts,
Map<McpSchema.CompleteReference, McpServerFeatures.AsyncCompletionSpecification> completions,
List<BiFunction<McpAsyncServerExchange, List<McpSchema.Root>, Mono<Void>>> rootsChangeConsumers,
String instructions) {
String instructions,
BiFunction<McpAsyncServerExchange, McpSchema.CancelledNotification, Mono<Void>> cancellationConsumer) {

/**
* Create an instance and validate the arguments.
* @param serverInfo The server implementation details
* @param serverCapabilities The server capabilities
* @param tools The list of tool specifications
* @param resources The map of resource specifications
* @param resourceTemplates The map of resource templates
* @param prompts The map of prompt specifications
* @param rootsChangeConsumers The list of consumers that will be notified when
* the roots list changes
* @param instructions The server instructions text
* Backwards-compatible constructor without cancellationConsumer.
*/
Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
List<McpServerFeatures.AsyncToolSpecification> tools, Map<String, AsyncResourceSpecification> resources,
Expand All @@ -66,6 +58,21 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
Map<McpSchema.CompleteReference, McpServerFeatures.AsyncCompletionSpecification> completions,
List<BiFunction<McpAsyncServerExchange, List<McpSchema.Root>, Mono<Void>>> rootsChangeConsumers,
String instructions) {
this(serverInfo, serverCapabilities, tools, resources, resourceTemplates, prompts, completions,
rootsChangeConsumers, instructions, null);
}

/**
* Create an instance and validate the arguments.
*/
Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities,
List<McpServerFeatures.AsyncToolSpecification> tools, Map<String, AsyncResourceSpecification> resources,
Map<String, McpServerFeatures.AsyncResourceTemplateSpecification> resourceTemplates,
Map<String, McpServerFeatures.AsyncPromptSpecification> prompts,
Map<McpSchema.CompleteReference, McpServerFeatures.AsyncCompletionSpecification> completions,
List<BiFunction<McpAsyncServerExchange, List<McpSchema.Root>, Mono<Void>>> rootsChangeConsumers,
String instructions,
BiFunction<McpAsyncServerExchange, McpSchema.CancelledNotification, Mono<Void>> cancellationConsumer) {

Assert.notNull(serverInfo, "Server info must not be null");

Expand All @@ -89,6 +96,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s
this.completions = (completions != null) ? completions : Map.of();
this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of();
this.instructions = instructions;
this.cancellationConsumer = cancellationConsumer;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,13 @@ public Object ping() {
return this.exchange.ping().block();
}

/**
* Cancels a previously issued request to the client (server-to-client direction).
* @param requestId The ID of the request to cancel
* @param reason An optional human-readable reason for the cancellation
*/
public void cancelRequest(Object requestId, String reason) {
this.exchange.cancelRequest(requestId, reason).block();
}

}
Loading