Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,17 @@

from mcp.client._memory import InMemoryTransport
from mcp.client._transport import Transport
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
PromptListChangedFnT,
ResourceListChangedFnT,
SamplingFnT,
ToolListChangedFnT,
)
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server
from mcp.server.mcpserver import MCPServer
Expand Down Expand Up @@ -95,6 +105,15 @@ async def main():
elicitation_callback: ElicitationFnT | None = None
"""Callback for handling elicitation requests."""

tool_list_changed_callback: ToolListChangedFnT | None = None
"""Callback invoked when the server sends a tools/list_changed notification."""

resource_list_changed_callback: ResourceListChangedFnT | None = None
"""Callback invoked when the server sends a resources/list_changed notification."""

prompt_list_changed_callback: PromptListChangedFnT | None = None
"""Callback invoked when the server sends a prompts/list_changed notification."""

_session: ClientSession | None = field(init=False, default=None)
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
_transport: Transport = field(init=False)
Expand Down Expand Up @@ -126,6 +145,9 @@ async def __aenter__(self) -> Client:
message_handler=self.message_handler,
client_info=self.client_info,
elicitation_callback=self.elicitation_callback,
tool_list_changed_callback=self.tool_list_changed_callback,
resource_list_changed_callback=self.resource_list_changed_callback,
prompt_list_changed_callback=self.prompt_list_changed_callback,
)
)

Expand Down
36 changes: 36 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class LoggingFnT(Protocol):
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch


class ToolListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class ResourceListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class PromptListChangedFnT(Protocol):
async def __call__(self) -> None: ... # pragma: no branch


class MessageHandlerFnT(Protocol):
async def __call__(
self,
Expand Down Expand Up @@ -95,6 +107,18 @@ async def _default_logging_callback(
pass


async def _default_tool_list_changed_callback() -> None:
pass


async def _default_resource_list_changed_callback() -> None:
pass


async def _default_prompt_list_changed_callback() -> None:
pass


ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)


Expand All @@ -120,6 +144,9 @@ def __init__(
client_info: types.Implementation | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
tool_list_changed_callback: ToolListChangedFnT | None = None,
resource_list_changed_callback: ResourceListChangedFnT | None = None,
prompt_list_changed_callback: PromptListChangedFnT | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> None:
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
Expand All @@ -129,6 +156,9 @@ def __init__(
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback
self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback
self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
Expand Down Expand Up @@ -475,5 +505,11 @@ async def _received_notification(self, notification: types.ServerNotification) -
# Clients MAY use this to retry requests or update UI
# The notification contains the elicitationId of the completed elicitation
pass
case types.ToolListChangedNotification():
await self._tool_list_changed_callback()
case types.ResourceListChangedNotification():
await self._resource_list_changed_callback()
case types.PromptListChangedNotification():
await self._prompt_list_changed_callback()
case _:
pass
165 changes: 165 additions & 0 deletions tests/client/test_list_changed_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""Tests for tools/resources/prompts list_changed notification callbacks."""

import pytest

from mcp import Client, types
from mcp.server.mcpserver import MCPServer
from mcp.shared.session import RequestResponder


class ListChangedCollector:
"""Collects list_changed notification invocations."""

def __init__(self):
self.tool_changed_count = 0
self.resource_changed_count = 0
self.prompt_changed_count = 0

async def on_tool_list_changed(self) -> None:
self.tool_changed_count += 1

async def on_resource_list_changed(self) -> None:
self.resource_changed_count += 1

async def on_prompt_list_changed(self) -> None:
self.prompt_changed_count += 1


@pytest.mark.anyio
async def test_tool_list_changed_callback():
"""Client receives tools/list_changed notification and invokes callback."""
server = MCPServer("test")
collector = ListChangedCollector()

@server.tool("trigger_tool_change")
async def trigger_tool_change() -> str:
ctx = server.get_context()
await ctx.session.send_notification(types.ToolListChangedNotification())
return "ok"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message # pragma: no cover

async with Client(
server,
tool_list_changed_callback=collector.on_tool_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_tool_change", {})
assert result.is_error is False
assert collector.tool_changed_count == 1


@pytest.mark.anyio
async def test_resource_list_changed_callback():
"""Client receives resources/list_changed notification and invokes callback."""
server = MCPServer("test")
collector = ListChangedCollector()

@server.tool("trigger_resource_change")
async def trigger_resource_change() -> str:
ctx = server.get_context()
await ctx.session.send_notification(types.ResourceListChangedNotification())
return "ok"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message # pragma: no cover

async with Client(
server,
resource_list_changed_callback=collector.on_resource_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_resource_change", {})
assert result.is_error is False
assert collector.resource_changed_count == 1


@pytest.mark.anyio
async def test_prompt_list_changed_callback():
"""Client receives prompts/list_changed notification and invokes callback."""
server = MCPServer("test")
collector = ListChangedCollector()

@server.tool("trigger_prompt_change")
async def trigger_prompt_change() -> str:
ctx = server.get_context()
await ctx.session.send_notification(types.PromptListChangedNotification())
return "ok"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message # pragma: no cover

async with Client(
server,
prompt_list_changed_callback=collector.on_prompt_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_prompt_change", {})
assert result.is_error is False
assert collector.prompt_changed_count == 1


@pytest.mark.anyio
async def test_list_changed_without_callback_does_not_crash():
"""list_changed notifications are silently ignored when no callback is set."""
server = MCPServer("test")

@server.tool("trigger_all_changes")
async def trigger_all_changes() -> str:
ctx = server.get_context()
await ctx.session.send_notification(types.ToolListChangedNotification())
await ctx.session.send_notification(types.ResourceListChangedNotification())
await ctx.session.send_notification(types.PromptListChangedNotification())
return "ok"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message # pragma: no cover

async with Client(
server,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_all_changes", {})
assert result.is_error is False


@pytest.mark.anyio
async def test_multiple_list_changed_notifications():
"""Multiple list_changed notifications each invoke the callback."""
server = MCPServer("test")
collector = ListChangedCollector()

@server.tool("trigger_double")
async def trigger_double() -> str:
ctx = server.get_context()
await ctx.session.send_notification(types.ToolListChangedNotification())
await ctx.session.send_notification(types.ToolListChangedNotification())
return "ok"

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message # pragma: no cover

async with Client(
server,
tool_list_changed_callback=collector.on_tool_list_changed,
message_handler=message_handler,
) as client:
result = await client.call_tool("trigger_double", {})
assert result.is_error is False
assert collector.tool_changed_count == 2
Loading