Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions src/mcp/server/lowlevel/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from collections.abc import Awaitable, Callable
from typing import Any, Generic
from typing import TYPE_CHECKING, Any, Generic

from typing_extensions import TypeVar

Expand Down Expand Up @@ -38,6 +38,9 @@
TasksToolsCapability,
)

if TYPE_CHECKING:
from mcp.server.lowlevel.server import Server

logger = logging.getLogger(__name__)

LifespanResultT = TypeVar("LifespanResultT", default=Any)
Expand All @@ -51,13 +54,9 @@ class ExperimentalHandlers(Generic[LifespanResultT]):

def __init__(
self,
add_request_handler: Callable[
[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]], None
],
has_handler: Callable[[str], bool],
server: Server[LifespanResultT, Any],
) -> None:
self._add_request_handler = add_request_handler
self._has_handler = has_handler
self._server = server
self._task_support: TaskSupport | None = None

@property
Expand All @@ -67,13 +66,15 @@ def task_support(self) -> TaskSupport | None:

def update_capabilities(self, capabilities: ServerCapabilities) -> None:
# Only add tasks capability if handlers are registered
if not any(self._has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]):
if not any(
self._server.has_handler(method) for method in ["tasks/get", "tasks/list", "tasks/cancel", "tasks/result"]
):
return

capabilities.tasks = ServerTasksCapability()
if self._has_handler("tasks/list"):
if self._server.has_handler("tasks/list"):
capabilities.tasks.list = TasksListCapability()
if self._has_handler("tasks/cancel"):
if self._server.has_handler("tasks/cancel"):
capabilities.tasks.cancel = TasksCancelCapability()

capabilities.tasks.requests = ServerTasksRequestsCapability(
Expand Down Expand Up @@ -145,16 +146,16 @@ def enable_tasks(

# Register user-provided handlers
if on_get_task is not None:
self._add_request_handler("tasks/get", on_get_task)
self._server.add_request_handler("tasks/get", on_get_task)
if on_task_result is not None:
self._add_request_handler("tasks/result", on_task_result)
self._server.add_request_handler("tasks/result", on_task_result)
if on_list_tasks is not None:
self._add_request_handler("tasks/list", on_list_tasks)
self._server.add_request_handler("tasks/list", on_list_tasks)
if on_cancel_task is not None:
self._add_request_handler("tasks/cancel", on_cancel_task)
self._server.add_request_handler("tasks/cancel", on_cancel_task)

# Fill in defaults for any not provided
if not self._has_handler("tasks/get"):
if not self._server.has_handler("tasks/get"):

async def _default_get_task(
ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams
Expand All @@ -172,9 +173,9 @@ async def _default_get_task(
poll_interval=task.poll_interval,
)

self._add_request_handler("tasks/get", _default_get_task)
self._server.add_request_handler("tasks/get", _default_get_task)

if not self._has_handler("tasks/result"):
if not self._server.has_handler("tasks/result"):

async def _default_get_task_result(
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
Expand All @@ -184,9 +185,9 @@ async def _default_get_task_result(
result = await task_support.handler.handle(req, ctx.session, ctx.request_id)
return result

self._add_request_handler("tasks/result", _default_get_task_result)
self._server.add_request_handler("tasks/result", _default_get_task_result)

if not self._has_handler("tasks/list"):
if not self._server.has_handler("tasks/list"):

async def _default_list_tasks(
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
Expand All @@ -195,16 +196,16 @@ async def _default_list_tasks(
tasks, next_cursor = await task_support.store.list_tasks(cursor)
return ListTasksResult(tasks=tasks, next_cursor=next_cursor)

self._add_request_handler("tasks/list", _default_list_tasks)
self._server.add_request_handler("tasks/list", _default_list_tasks)

if not self._has_handler("tasks/cancel"):
if not self._server.has_handler("tasks/cancel"):

async def _default_cancel_task(
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
) -> CancelTaskResult:
result = await cancel_task(task_support.store, params.task_id)
return result

self._add_request_handler("tasks/cancel", _default_cancel_task)
self._server.add_request_handler("tasks/cancel", _default_cancel_task)

return task_support
71 changes: 67 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,72 @@ def _has_handler(self, method: str) -> bool:
"""Check if a handler is registered for the given method."""
return method in self._request_handlers or method in self._notification_handlers

def add_request_handler(
self,
method: str,
handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]],
) -> None:
"""Register a request handler for the given method.

If a handler is already registered for this method, it will be replaced.

Args:
method: The JSON-RPC method name (e.g., "tools/list", "myextension/query").
handler: An async callable that takes (ServerRequestContext, params) and
returns the result.
"""
self._request_handlers[method] = handler

def remove_request_handler(self, method: str) -> None:
"""Remove the request handler for the given method.

Args:
method: The JSON-RPC method name to deregister.

Raises:
KeyError: If no handler is registered for this method.
"""
del self._request_handlers[method]

def add_notification_handler(
self,
method: str,
handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]],
) -> None:
"""Register a notification handler for the given method.

If a handler is already registered for this method, it will be replaced.

Args:
method: The JSON-RPC notification method name
(e.g., "notifications/progress").
handler: An async callable that takes (ServerRequestContext, params) and
returns None.
"""
self._notification_handlers[method] = handler

def remove_notification_handler(self, method: str) -> None:
"""Remove the notification handler for the given method.

Args:
method: The JSON-RPC notification method name to deregister.

Raises:
KeyError: If no handler is registered for this method.
"""
del self._notification_handlers[method]

def has_handler(self, method: str) -> bool:
"""Check if a handler is registered for the given request or notification method.

Args:
method: The JSON-RPC method name to check.

Returns:
True if a handler is registered, False otherwise.
"""
return method in self._request_handlers or method in self._notification_handlers

# TODO: Rethink capabilities API. Currently capabilities are derived from registered
# handlers but require NotificationOptions to be passed externally for list_changed
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities
Expand Down Expand Up @@ -336,10 +402,7 @@ def experimental(self) -> ExperimentalHandlers[LifespanResultT]:

# We create this inline so we only add these capabilities _if_ they're actually used
if self._experimental_handlers is None:
self._experimental_handlers = ExperimentalHandlers(
add_request_handler=self._add_request_handler,
has_handler=self._has_handler,
)
self._experimental_handlers = ExperimentalHandlers(server=self)
return self._experimental_handlers

@property
Expand Down
94 changes: 94 additions & 0 deletions tests/server/lowlevel/test_handler_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for public handler registration/deregistration API on low-level Server."""

import pytest

from mcp.server.lowlevel.server import Server


@pytest.fixture
def server():
return Server(name="test-server")


async def _dummy_request_handler(ctx, params):
return {"result": "ok"}


async def _dummy_notification_handler(ctx, params):
pass


class TestAddRequestHandler:
def test_add_request_handler(self, server):
server.add_request_handler("custom/method", _dummy_request_handler)
assert server.has_handler("custom/method")

def test_add_request_handler_replaces_existing(self, server):
async def handler_a(ctx, params):
return "a"

async def handler_b(ctx, params):
return "b"

server.add_request_handler("custom/method", handler_a)
server.add_request_handler("custom/method", handler_b)
# The second handler should replace the first
assert server._request_handlers["custom/method"] is handler_b


class TestRemoveRequestHandler:
def test_remove_request_handler(self, server):
server.add_request_handler("custom/method", _dummy_request_handler)
assert server.has_handler("custom/method")
server.remove_request_handler("custom/method")
assert not server.has_handler("custom/method")

def test_remove_request_handler_not_found(self, server):
with pytest.raises(KeyError):
server.remove_request_handler("nonexistent/method")


class TestAddNotificationHandler:
def test_add_notification_handler(self, server):
server.add_notification_handler("custom/notify", _dummy_notification_handler)
assert server.has_handler("custom/notify")

def test_add_notification_handler_replaces_existing(self, server):
async def handler_a(ctx, params):
pass

async def handler_b(ctx, params):
pass

server.add_notification_handler("custom/notify", handler_a)
server.add_notification_handler("custom/notify", handler_b)
assert server._notification_handlers["custom/notify"] is handler_b


class TestRemoveNotificationHandler:
def test_remove_notification_handler(self, server):
server.add_notification_handler("custom/notify", _dummy_notification_handler)
assert server.has_handler("custom/notify")
server.remove_notification_handler("custom/notify")
assert not server.has_handler("custom/notify")

def test_remove_notification_handler_not_found(self, server):
with pytest.raises(KeyError):
server.remove_notification_handler("nonexistent/notify")


class TestHasHandler:
def test_has_handler_request(self, server):
server.add_request_handler("custom/method", _dummy_request_handler)
assert server.has_handler("custom/method")

def test_has_handler_notification(self, server):
server.add_notification_handler("custom/notify", _dummy_notification_handler)
assert server.has_handler("custom/notify")

def test_has_handler_unregistered(self, server):
assert not server.has_handler("nonexistent/method")

def test_has_handler_default_ping(self, server):
"""The ping handler is registered by default."""
assert server.has_handler("ping")
Loading