Skip to content

Commit fe8f75d

Browse files
author
g97iulio1609
committed
fix: handle tools/resources/prompts list_changed notifications in client
Previously, the client silently dropped ToolListChangedNotification, ResourceListChangedNotification, and PromptListChangedNotification from the server. They fell through to the catch-all case in _received_notification(). Add optional callbacks (tool_list_changed_callback, resource_list_changed_callback, prompt_list_changed_callback) following the same pattern as logging_callback. When set, the callbacks are invoked when the corresponding notification arrives. When not set, a no-op default is used (preserving backward compat). Updated both ClientSession and the high-level Client dataclass. Fixes #2107
1 parent 62575ed commit fe8f75d

File tree

3 files changed

+224
-1
lines changed

3 files changed

+224
-1
lines changed

src/mcp/client/client.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,17 @@
88

99
from mcp.client._memory import InMemoryTransport
1010
from mcp.client._transport import Transport
11-
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
11+
from mcp.client.session import (
12+
ClientSession,
13+
ElicitationFnT,
14+
ListRootsFnT,
15+
LoggingFnT,
16+
MessageHandlerFnT,
17+
PromptListChangedFnT,
18+
ResourceListChangedFnT,
19+
SamplingFnT,
20+
ToolListChangedFnT,
21+
)
1222
from mcp.client.streamable_http import streamable_http_client
1323
from mcp.server import Server
1424
from mcp.server.mcpserver import MCPServer
@@ -95,6 +105,15 @@ async def main():
95105
elicitation_callback: ElicitationFnT | None = None
96106
"""Callback for handling elicitation requests."""
97107

108+
tool_list_changed_callback: ToolListChangedFnT | None = None
109+
"""Callback invoked when the server sends a tools/list_changed notification."""
110+
111+
resource_list_changed_callback: ResourceListChangedFnT | None = None
112+
"""Callback invoked when the server sends a resources/list_changed notification."""
113+
114+
prompt_list_changed_callback: PromptListChangedFnT | None = None
115+
"""Callback invoked when the server sends a prompts/list_changed notification."""
116+
98117
_session: ClientSession | None = field(init=False, default=None)
99118
_exit_stack: AsyncExitStack | None = field(init=False, default=None)
100119
_transport: Transport = field(init=False)
@@ -126,6 +145,9 @@ async def __aenter__(self) -> Client:
126145
message_handler=self.message_handler,
127146
client_info=self.client_info,
128147
elicitation_callback=self.elicitation_callback,
148+
tool_list_changed_callback=self.tool_list_changed_callback,
149+
resource_list_changed_callback=self.resource_list_changed_callback,
150+
prompt_list_changed_callback=self.prompt_list_changed_callback,
129151
)
130152
)
131153

src/mcp/client/session.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ class LoggingFnT(Protocol):
4747
async def __call__(self, params: types.LoggingMessageNotificationParams) -> None: ... # pragma: no branch
4848

4949

50+
class ToolListChangedFnT(Protocol):
51+
async def __call__(self) -> None: ... # pragma: no branch
52+
53+
54+
class ResourceListChangedFnT(Protocol):
55+
async def __call__(self) -> None: ... # pragma: no branch
56+
57+
58+
class PromptListChangedFnT(Protocol):
59+
async def __call__(self) -> None: ... # pragma: no branch
60+
61+
5062
class MessageHandlerFnT(Protocol):
5163
async def __call__(
5264
self,
@@ -95,6 +107,18 @@ async def _default_logging_callback(
95107
pass
96108

97109

110+
async def _default_tool_list_changed_callback() -> None:
111+
pass
112+
113+
114+
async def _default_resource_list_changed_callback() -> None:
115+
pass
116+
117+
118+
async def _default_prompt_list_changed_callback() -> None:
119+
pass
120+
121+
98122
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
99123

100124

@@ -120,6 +144,9 @@ def __init__(
120144
client_info: types.Implementation | None = None,
121145
*,
122146
sampling_capabilities: types.SamplingCapability | None = None,
147+
tool_list_changed_callback: ToolListChangedFnT | None = None,
148+
resource_list_changed_callback: ResourceListChangedFnT | None = None,
149+
prompt_list_changed_callback: PromptListChangedFnT | None = None,
123150
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
124151
) -> None:
125152
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
@@ -129,6 +156,9 @@ def __init__(
129156
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
130157
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
131158
self._logging_callback = logging_callback or _default_logging_callback
159+
self._tool_list_changed_callback = tool_list_changed_callback or _default_tool_list_changed_callback
160+
self._resource_list_changed_callback = resource_list_changed_callback or _default_resource_list_changed_callback
161+
self._prompt_list_changed_callback = prompt_list_changed_callback or _default_prompt_list_changed_callback
132162
self._message_handler = message_handler or _default_message_handler
133163
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
134164
self._server_capabilities: types.ServerCapabilities | None = None
@@ -475,5 +505,11 @@ async def _received_notification(self, notification: types.ServerNotification) -
475505
# Clients MAY use this to retry requests or update UI
476506
# The notification contains the elicitationId of the completed elicitation
477507
pass
508+
case types.ToolListChangedNotification():
509+
await self._tool_list_changed_callback()
510+
case types.ResourceListChangedNotification():
511+
await self._resource_list_changed_callback()
512+
case types.PromptListChangedNotification():
513+
await self._prompt_list_changed_callback()
478514
case _:
479515
pass
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Tests for tools/resources/prompts list_changed notification callbacks."""
2+
3+
import pytest
4+
5+
from mcp import Client, types
6+
from mcp.server.mcpserver import MCPServer
7+
from mcp.shared.session import RequestResponder
8+
9+
10+
class ListChangedCollector:
11+
"""Collects list_changed notification invocations."""
12+
13+
def __init__(self):
14+
self.tool_changed_count = 0
15+
self.resource_changed_count = 0
16+
self.prompt_changed_count = 0
17+
18+
async def on_tool_list_changed(self) -> None:
19+
self.tool_changed_count += 1
20+
21+
async def on_resource_list_changed(self) -> None:
22+
self.resource_changed_count += 1
23+
24+
async def on_prompt_list_changed(self) -> None:
25+
self.prompt_changed_count += 1
26+
27+
28+
@pytest.mark.anyio
29+
async def test_tool_list_changed_callback():
30+
"""Client receives tools/list_changed notification and invokes callback."""
31+
server = MCPServer("test")
32+
collector = ListChangedCollector()
33+
34+
@server.tool("trigger_tool_change")
35+
async def trigger_tool_change() -> str:
36+
ctx = server.get_context()
37+
await ctx.session.send_notification(types.ToolListChangedNotification())
38+
return "ok"
39+
40+
async def message_handler(
41+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
42+
) -> None:
43+
if isinstance(message, Exception):
44+
raise message
45+
46+
async with Client(
47+
server,
48+
tool_list_changed_callback=collector.on_tool_list_changed,
49+
message_handler=message_handler,
50+
) as client:
51+
result = await client.call_tool("trigger_tool_change", {})
52+
assert result.is_error is False
53+
assert collector.tool_changed_count == 1
54+
55+
56+
@pytest.mark.anyio
57+
async def test_resource_list_changed_callback():
58+
"""Client receives resources/list_changed notification and invokes callback."""
59+
server = MCPServer("test")
60+
collector = ListChangedCollector()
61+
62+
@server.tool("trigger_resource_change")
63+
async def trigger_resource_change() -> str:
64+
ctx = server.get_context()
65+
await ctx.session.send_notification(types.ResourceListChangedNotification())
66+
return "ok"
67+
68+
async def message_handler(
69+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
70+
) -> None:
71+
if isinstance(message, Exception):
72+
raise message
73+
74+
async with Client(
75+
server,
76+
resource_list_changed_callback=collector.on_resource_list_changed,
77+
message_handler=message_handler,
78+
) as client:
79+
result = await client.call_tool("trigger_resource_change", {})
80+
assert result.is_error is False
81+
assert collector.resource_changed_count == 1
82+
83+
84+
@pytest.mark.anyio
85+
async def test_prompt_list_changed_callback():
86+
"""Client receives prompts/list_changed notification and invokes callback."""
87+
server = MCPServer("test")
88+
collector = ListChangedCollector()
89+
90+
@server.tool("trigger_prompt_change")
91+
async def trigger_prompt_change() -> str:
92+
ctx = server.get_context()
93+
await ctx.session.send_notification(types.PromptListChangedNotification())
94+
return "ok"
95+
96+
async def message_handler(
97+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
98+
) -> None:
99+
if isinstance(message, Exception):
100+
raise message
101+
102+
async with Client(
103+
server,
104+
prompt_list_changed_callback=collector.on_prompt_list_changed,
105+
message_handler=message_handler,
106+
) as client:
107+
result = await client.call_tool("trigger_prompt_change", {})
108+
assert result.is_error is False
109+
assert collector.prompt_changed_count == 1
110+
111+
112+
@pytest.mark.anyio
113+
async def test_list_changed_without_callback_does_not_crash():
114+
"""list_changed notifications are silently ignored when no callback is set."""
115+
server = MCPServer("test")
116+
117+
@server.tool("trigger_all_changes")
118+
async def trigger_all_changes() -> str:
119+
ctx = server.get_context()
120+
await ctx.session.send_notification(types.ToolListChangedNotification())
121+
await ctx.session.send_notification(types.ResourceListChangedNotification())
122+
await ctx.session.send_notification(types.PromptListChangedNotification())
123+
return "ok"
124+
125+
async def message_handler(
126+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
127+
) -> None:
128+
if isinstance(message, Exception):
129+
raise message
130+
131+
async with Client(
132+
server,
133+
message_handler=message_handler,
134+
) as client:
135+
result = await client.call_tool("trigger_all_changes", {})
136+
assert result.is_error is False
137+
138+
139+
@pytest.mark.anyio
140+
async def test_multiple_list_changed_notifications():
141+
"""Multiple list_changed notifications each invoke the callback."""
142+
server = MCPServer("test")
143+
collector = ListChangedCollector()
144+
145+
@server.tool("trigger_double")
146+
async def trigger_double() -> str:
147+
ctx = server.get_context()
148+
await ctx.session.send_notification(types.ToolListChangedNotification())
149+
await ctx.session.send_notification(types.ToolListChangedNotification())
150+
return "ok"
151+
152+
async def message_handler(
153+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
154+
) -> None:
155+
if isinstance(message, Exception):
156+
raise message
157+
158+
async with Client(
159+
server,
160+
tool_list_changed_callback=collector.on_tool_list_changed,
161+
message_handler=message_handler,
162+
) as client:
163+
result = await client.call_tool("trigger_double", {})
164+
assert result.is_error is False
165+
assert collector.tool_changed_count == 2

0 commit comments

Comments
 (0)