diff --git a/src/strands/tools/mcp/__init__.py b/src/strands/tools/mcp/__init__.py index 8d2c1daa2..8bd621a22 100644 --- a/src/strands/tools/mcp/__init__.py +++ b/src/strands/tools/mcp/__init__.py @@ -6,9 +6,11 @@ - Docs: https://www.anthropic.com/news/model-context-protocol """ +from mcp.shared.session import ProgressFnT + from .mcp_agent_tool import MCPAgentTool from .mcp_client import MCPClient, ToolFilters from .mcp_tasks import TasksConfig from .mcp_types import MCPTransport -__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "TasksConfig", "ToolFilters"] +__all__ = ["MCPAgentTool", "MCPClient", "MCPTransport", "ProgressFnT", "TasksConfig", "ToolFilters"] diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 270012fde..53932fbb6 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -27,6 +27,7 @@ from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT from mcp.types import ( BlobResourceContents, ElicitationRequiredErrorData, @@ -121,6 +122,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + progress_callback: ProgressFnT | None = None, tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,6 +134,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + progress_callback: Optional callback to receive progress notifications during tool execution. + Called with ``(progress, total, message)`` as the server reports progress. The ``total`` + and ``message`` parameters may be ``None`` if the server does not provide them. tasks_config: Configuration for MCP task-augmented execution for long-running tools. If provided (not None), enables task-augmented execution for tools that support it. See TasksConfig for details. This feature is experimental and subject to change. @@ -140,6 +145,7 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix self._elicitation_callback = elicitation_callback + self._progress_callback = progress_callback mcp_instrumentation() self._session_id = uuid.uuid4() @@ -589,6 +595,7 @@ def _create_call_tool_coroutine( arguments: dict[str, Any] | None, read_timeout_seconds: timedelta | None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> Coroutine[Any, Any, MCPCallToolResult]: """Create the appropriate coroutine for calling a tool. @@ -600,11 +607,14 @@ def _create_call_tool_coroutine( arguments: Optional arguments to pass to the tool. read_timeout_seconds: Optional timeout for the tool call. meta: Optional metadata to pass to the tool call per MCP spec (_meta). + progress_callback: Optional callback to receive progress notifications. + If None, falls back to the instance-level callback set at construction time. Returns: A coroutine that will execute the tool call. """ use_task = self._should_use_task(name) + effective_callback = progress_callback if progress_callback is not None else self._progress_callback if use_task: self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) @@ -622,7 +632,7 @@ async def _call_as_task() -> MCPCallToolResult: async def _call_tool_direct() -> MCPCallToolResult: return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds, meta=meta + name, arguments, read_timeout_seconds, progress_callback=effective_callback, meta=meta ) return _call_tool_direct() @@ -634,6 +644,7 @@ def call_tool_sync( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. @@ -646,6 +657,8 @@ def call_tool_sync( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -655,7 +668,9 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -669,6 +684,7 @@ async def call_tool_async( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. @@ -681,6 +697,8 @@ async def call_tool_async( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -690,7 +708,9 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f270fa6fc..e6d6032e9 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp import ListToolsResult @@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -155,7 +155,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -193,10 +193,56 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session): tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta ) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=meta) assert result["status"] == "success" +def test_call_tool_sync_forwards_instance_progress_callback(mock_transport, mock_session): + """Test that call_tool_sync uses the instance-level progress callback when no per-call callback is given.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=cb) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=cb, meta=None + ) + assert result["status"] == "success" + + +def test_call_tool_sync_per_call_progress_callback_overrides_instance(mock_transport, mock_session): + """Test that a per-call progress callback overrides the instance-level one.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + instance_cb = AsyncMock() + per_call_cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=instance_cb) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={}, progress_callback=per_call_cb + ) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=per_call_cb, meta=None + ) + assert result["status"] == "success" + + +def test_call_tool_sync_no_progress_callback_by_default(mock_transport, mock_session): + """Test that progress_callback defaults to None when not set on instance or per-call.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with( + "test_tool", {}, None, progress_callback=None, meta=None + ) + + @pytest.mark.asyncio async def test_call_tool_async_forwards_meta(mock_transport, mock_session): """Test that call_tool_async forwards meta to ClientSession.call_tool.""" @@ -672,7 +718,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -697,7 +743,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -723,7 +769,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -748,7 +794,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -771,7 +817,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -798,7 +844,7 @@ def __init__(self): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -850,7 +896,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert result["toolUseId"] == "test-123"