From ac9577b9f42f1fe777ae94b00d7e3e7ffd514f70 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Tue, 13 Jan 2026 12:53:52 -0800 Subject: [PATCH 1/4] feat(mcp): Implement basic support for Tasks Implements client support for MCP Tasks via an adapter model that handles both tool execution modes identically. This can later be hooked into event handlers in a more intelligent way, but this unblocks support for simply invoking task-augmented tools. Keep error handling and edge case tests (timeout, failure status, config). Also remove unused create_tool_with_task_support helper and trim task_echo_server. Reduces PR diff from 1433 to 969 lines (under 1000 limit). --- src/strands/tools/mcp/mcp_client.py | 310 ++++++++++++++++-- tests/strands/tools/mcp/conftest.py | 59 ++++ tests/strands/tools/mcp/test_mcp_client.py | 32 +- .../tools/mcp/test_mcp_client_tasks.py | 223 +++++++++++++ tests_integ/mcp/task_echo_server.py | 139 ++++++++ tests_integ/mcp/test_mcp_client_tasks.py | 188 +++++++++++ 6 files changed, 900 insertions(+), 51 deletions(-) create mode 100644 tests/strands/tools/mcp/conftest.py create mode 100644 tests/strands/tools/mcp/test_mcp_client_tasks.py create mode 100644 tests_integ/mcp/task_echo_server.py create mode 100644 tests_integ/mcp/test_mcp_client_tasks.py diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1aff22a1e..2ca284407 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -120,6 +120,8 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + default_task_ttl_ms: int = 60000, + default_task_poll_timeout_seconds: float = 300.0, ) -> None: """Initialize a new MCP Server connection. @@ -130,6 +132,10 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. 
elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + default_task_ttl_ms: Default time-to-live in milliseconds for task-augmented tool calls. + Defaults to 60000 (1 minute). + default_task_poll_timeout_seconds: Default timeout in seconds for polling task completion. + Defaults to 300.0 (5 minutes). """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -154,6 +160,12 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() + # Task support caching + self._default_task_ttl_ms = default_task_ttl_ms + self._default_task_poll_timeout_seconds = default_task_poll_timeout_seconds + self._server_task_capable: bool | None = None + self._tool_task_support_cache: dict[str, str | None] = {} + def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -358,6 +370,8 @@ async def _set_close_event() -> None: self._loaded_tools = None self._tool_provider_started = False self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} if self._close_exception: exception = self._close_exception @@ -392,14 +406,37 @@ def list_tools_sync( effective_prefix = self._prefix if prefix is None else prefix effective_filters = self._tool_filters if tool_filters is None else tool_filters - async def _list_tools_async() -> ListToolsResult: - return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) + async def _list_tools_and_cache_capabilities_async() -> ListToolsResult: + session = cast(ClientSession, self._background_thread_session) + list_tools_result = await session.list_tools(cursor=pagination_token) + + # Cache server task capability while we have an active session + # This avoids needing a separate async call later during call_tool_* + if self._server_task_capable is None: + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps 
is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + + return list_tools_result - list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() + list_tools_response: ListToolsResult = self._invoke_on_background_thread( + _list_tools_and_cache_capabilities_async() + ).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) mcp_tools = [] for tool in list_tools_response.tools: + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support + # Apply prefix if specified if effective_prefix: prefixed_name = f"{effective_prefix}_{tool.name}" @@ -539,6 +576,45 @@ async def _list_resource_templates_async() -> ListResourceTemplatesResult: return list_resource_templates_result + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + + Returns: + A coroutine that will execute the tool call. 
+ """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + poll_timeout = self._convert_timeout_for_polling(read_timeout_seconds) + + async def _call_as_task() -> MCPCallToolResult: + return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout_seconds=poll_timeout) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) + + return _call_tool_direct() + def call_tool_sync( self, tool_use_id: str, @@ -548,10 +624,8 @@ def call_tool_sync( ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. If the MCP tool returns - structured content, it will be included as the last item in the content array - of the returned ToolResult. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. 
Args: tool_use_id: Unique identifier for this tool use @@ -566,13 +640,9 @@ def call_tool_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: logger.exception("tool execution failed") @@ -587,8 +657,8 @@ async def call_tool_async( ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the MCPToolResult format. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. 
Args: tool_use_id: Unique identifier for this tool use @@ -603,13 +673,9 @@ async def call_tool_async( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - future = self._invoke_on_background_thread(_call_tool_async()) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -898,3 +964,205 @@ def _is_session_active(self) -> bool: return False return True + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the cached capability value that was populated during list_tools_sync(). + If list_tools_sync() hasn't been called yet, returns False (conservative default). + + The capability is cached during list_tools_sync() to avoid needing a separate + async call during call_tool_*() operations. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + # Return cached value, defaulting to False if not yet populated + # The cache is populated during list_tools_sync() + return self._server_task_capable or False + + def _get_tool_task_support(self, tool_name: str) -> str | None: + """Get the taskSupport setting for a tool. + + Returns the cached taskSupport value for the given tool name. + The cache is populated during list_tools_sync(). + + Args: + tool_name: Name of the tool to look up. + + Returns: + The taskSupport value ('required', 'optional', 'forbidden') or None if not cached. 
+ """ + return self._tool_task_support_cache.get(tool_name) + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Implements the MCP spec decision matrix: + - If server doesn't support tasks: MUST NOT use tasks (returns False) + - If tool taskSupport is None or 'forbidden': MUST NOT use tasks (returns False) + - If tool taskSupport is 'required' and server supports: use tasks (returns True) + - If tool taskSupport is 'optional' and server supports: prefer tasks (returns True) + + Per MCP spec, server capability check takes precedence over tool-level settings. + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. + """ + # Server capability check comes first (per MCP spec) + if not self._has_server_task_support(): + return False + + task_support = self._get_tool_task_support(tool_name) + + # Use tasks for 'required' or 'optional' when server supports + if task_support == "required" or task_support == "optional": + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _convert_timeout_for_polling(self, read_timeout_seconds: timedelta | None) -> float | None: + """Convert a timedelta timeout to seconds for task polling. + + When task-augmented execution is used, the read_timeout_seconds parameter + (which is a timedelta) needs to be converted to a float for the polling timeout. + + Args: + read_timeout_seconds: Optional timedelta timeout from the call_tool API. + + Returns: + Float seconds if timeout was specified, None to use default. + """ + return read_timeout_seconds.total_seconds() if read_timeout_seconds else None + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. 
+ + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl_ms: int | None = None, + poll_timeout_seconds: float | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl_ms: Task time-to-live in milliseconds. Uses default_task_ttl_ms if not specified. + poll_timeout_seconds: Timeout for polling in seconds. Uses default_task_poll_timeout_seconds if not + specified. + + Returns: + MCPCallToolResult: The final tool result after task completion. 
+ """ + session = cast(ClientSession, self._background_thread_session) + ttl = ttl_ms or self._default_task_ttl_ms + timeout = poll_timeout_seconds or self._default_task_poll_timeout_seconds + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> Any: + """Inner function to poll task status until terminal state.""" + final = None + async for status in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + status.status, + ) + final = status + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout=<%s> | task polling timed out", name, task_id, timeout + ) + return self._create_task_error_result(f"Task {task_id} polling timed out after {timeout} seconds") + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == "failed": + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == "cancelled": + 
self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == "completed": + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py new file mode 100644 index 000000000..0cfce470a --- /dev/null +++ b/tests/strands/tools/mcp/conftest.py @@ -0,0 +1,59 @@ +"""Shared fixtures and helpers for MCP client tests.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_transport(): + """Create a mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": 
mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create a mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + # Default: no task support (get_server_capabilities is sync, not async!) + mock_session.get_server_capabilities = MagicMock(return_value=None) + + # Create a mock context manager for ClientSession + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + # Patch ClientSession to return our mock session + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +def create_server_capabilities(has_task_support: bool) -> MagicMock: + """Create mock server capabilities. + + Args: + has_task_support: Whether the server should advertise task support. + + Returns: + MagicMock representing server capabilities. + """ + caps = MagicMock() + if has_task_support: + caps.tasks = MagicMock() + caps.tasks.requests = MagicMock() + caps.tasks.requests.tools = MagicMock() + caps.tasks.requests.tools.call = MagicMock() + else: + caps.tasks = None + return caps diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f784da414..4c9ca6752 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from mcp import ListToolsResult @@ -25,35 +25,7 @@ from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError - -@pytest.fixture -def mock_transport(): - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_transport_cm = AsyncMock() - mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) - 
mock_transport_callable = MagicMock(return_value=mock_transport_cm) - - return { - "read_stream": mock_read_stream, - "write_stream": mock_write_stream, - "transport_cm": mock_transport_cm, - "transport_callable": mock_transport_callable, - } - - -@pytest.fixture -def mock_session(): - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - - # Create a mock context manager for ClientSession - mock_session_cm = AsyncMock() - mock_session_cm.__aenter__.return_value = mock_session - - # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): - yield mock_session +# Fixtures mock_transport and mock_session are provided by conftest.py (auto-discovered by pytest) @pytest.fixture diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..6ce53f292 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -0,0 +1,223 @@ +"""Tests for MCP task-augmented execution support in MCPClient. + +These unit tests focus on error handling and edge cases that are not easily +testable through integration tests. Happy-path flows are covered by +integration tests in tests_integ/mcp/test_mcp_client_tasks.py.
+""" + +import asyncio +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mcp import ListToolsResult +from mcp.types import Tool as MCPTool +from mcp.types import ToolExecution + +from strands.tools.mcp import MCPClient + +from .conftest import create_server_capabilities + + +class TestTaskExecutionFailures: + """Tests for task execution failure handling.""" + + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ], + ) + def test_task_execution_terminal_status(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses (failed, cancelled).""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + mock_status = MagicMock() + mock_status.status = status + mock_status.statusMessage = status_message + + async def mock_poll_task(task_id): + yield mock_status + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + + +class TestStopResetCache: + """Tests for cache reset in stop().""" + + def test_stop_resets_task_caches(self, mock_transport, mock_session): + """Test that stop() resets the task support caches.""" + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + client._tool_task_support_cache["tool1"] = "required" + + assert client._server_task_capable is None + assert 
client._tool_task_support_cache == {} + + +class TestTaskConfiguration: + """Tests for task-related configuration options.""" + + def test_default_task_config_values(self, mock_transport, mock_session): + """Test default configuration values.""" + with MCPClient(mock_transport["transport_callable"]) as client: + assert client._default_task_ttl_ms == 60000 + assert client._default_task_poll_timeout_seconds == 300.0 + + def test_custom_task_config_values(self, mock_transport, mock_session): + """Test custom configuration values.""" + with MCPClient( + mock_transport["transport_callable"], + default_task_ttl_ms=120000, + default_task_poll_timeout_seconds=60.0, + ) as client: + assert client._default_task_ttl_ms == 120000 + assert client._default_task_poll_timeout_seconds == 60.0 + + +class TestTaskExecutionTimeout: + """Tests for task execution timeout and error handling.""" + + def _setup_task_tool(self, mock_session, tool_name: str) -> None: + """Helper to set up a mock task-enabled tool.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + @pytest.mark.asyncio + async def test_task_polling_timeout(self, mock_transport, mock_session): + """Test that task polling times out properly.""" + self._setup_task_tool(mock_session, "slow_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with 
MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=0.1) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) + + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_result_retrieval_failure(self, mock_transport, mock_session): + """Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) + + assert result["status"] == "error" + assert "result retrieval failed" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): + """Test that read_timeout_seconds overrides the default poll timeout.""" + self._setup_task_tool(mock_session, "timeout_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + # Long default timeout, but short explicit timeout + with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=300.0) as client: + client.list_tools_sync() + result = await client.call_tool_async( + tool_use_id="test-timeout", + name="timeout_tool", + arguments={}, + read_timeout_seconds=timedelta(seconds=0.1), + ) + + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + 
+ @pytest.mark.asyncio + async def test_task_polling_yields_no_status(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing (final_status is None).""" + self._setup_task_tool(mock_session, "empty_poll_tool") + + async def empty_poll(task_id): + return + yield # noqa: B901 - makes this an async generator + + mock_session.experimental.poll_task = empty_poll + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + assert result["status"] == "error" + assert "without status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_unexpected_terminal_status(self, mock_transport, mock_session): + """Test handling of unexpected task status (not completed/failed/cancelled).""" + self._setup_task_tool(mock_session, "weird_tool") + + async def poll(task_id): + yield MagicMock(status="unknown_status", statusMessage=None) + + mock_session.experimental.poll_task = poll + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) + assert result["status"] == "error" + assert "unexpected task status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_task_successful_completion(self, mock_transport, mock_session): + """Test successful task completion with result retrieval (happy path).""" + from mcp.types import CallToolResult as MCPCallToolResult + from mcp.types import TextContent as MCPTextContent + + self._setup_task_tool(mock_session, "success_tool") + + async def poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", 
text="Done")], isError=False) + ) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) + assert result["status"] == "success" + assert "Done" in result["content"][0].get("text", "") diff --git a/tests_integ/mcp/task_echo_server.py b/tests_integ/mcp/task_echo_server.py new file mode 100644 index 000000000..4a8edc97d --- /dev/null +++ b/tests_integ/mcp/task_echo_server.py @@ -0,0 +1,139 @@ +"""MCP server with task-augmented tool execution support for integration testing.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + + +def create_task_server() -> Server: + """Create and configure the task-supporting MCP server.""" + server = Server("task-echo-server") + server.experimental.enable_tasks() + + # Workaround: MCP Python SDK's enable_tasks() doesn't properly set tasks.requests.tools.call capability + original_update_capabilities = server.experimental.update_capabilities + + def patched_update_capabilities(capabilities: types.ServerCapabilities) -> None: + original_update_capabilities(capabilities) + if capabilities.tasks and capabilities.tasks.requests and capabilities.tasks.requests.tools: + capabilities.tasks.requests.tools.call = types.TasksCallCapability() + + server.experimental.update_capabilities = patched_update_capabilities # type: ignore[method-assign] + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="task_required_echo", + description="Echo that requires task-augmented execution", + 
inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="task_optional_echo", + description="Echo that optionally supports task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_OPTIONAL), + ), + types.Tool( + name="task_forbidden_echo", + description="Echo that does not support task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_FORBIDDEN), + ), + types.Tool( + name="echo", + description="Simple echo without task support setting", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ), + ] + + async def handle_task_required_echo(arguments: dict[str, Any]) -> types.CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + message = arguments.get("message", "") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing echo...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Task echo: {message}")]) + + return await ctx.experimental.run_task(work) + + async def handle_task_optional_echo(arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + ctx = server.request_context + message = arguments.get("message", "") + + if ctx.experimental.is_task: + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing optional task echo...") + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Task optional echo: {message}")] + ) + + return await ctx.experimental.run_task(work) + else: + return 
types.CallToolResult( + content=[types.TextContent(type="text", text=f"Direct optional echo: {message}")] + ) + + async def handle_task_forbidden_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Forbidden echo: {message}")]) + + async def handle_simple_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Simple echo: {message}")]) + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + handlers = { + "task_required_echo": handle_task_required_echo, + "task_optional_echo": handle_task_optional_echo, + "task_forbidden_echo": handle_task_forbidden_echo, + "echo": handle_simple_echo, + } + if name in handlers: + return await handlers[name](arguments) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], isError=True + ) + + return server + + +def create_starlette_app(port: int) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Create the Starlette app with MCP session manager.""" + server = create_task_server() + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=app_lifespan), session_manager + + +@click.command() +@click.option("--port", default=8010, help="Port to listen on") +def main(port: int) -> int: + """Start the task echo server.""" + import uvicorn + + starlette_app, _ = create_starlette_app(port) + print(f"Starting task echo server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if 
__name__ == "__main__": + main() diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..a294246f4 --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -0,0 +1,188 @@ +"""Integration tests for MCP task-augmented tool execution. + +These tests verify that our MCPClient correctly handles tools with taskSupport settings +and integrates with MCP servers that support task-augmented execution. + +The test server (task_echo_server.py) includes a workaround for an MCP Python SDK bug +where `enable_tasks()` doesn't properly set `tasks.requests.tools.call` capability. +""" + +import os +import socket +import threading +import time +from typing import Any + +import pytest +from mcp.client.streamable_http import streamablehttp_client + +from strands.tools.mcp.mcp_client import MCPClient +from strands.tools.mcp.mcp_types import MCPTransport + + +def _find_available_port() -> int: + """Find an available port by binding to port 0 and letting the OS assign one.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +def start_task_server(port: int) -> None: + """Start the task echo server in a thread.""" + import uvicorn + + from tests_integ.mcp.task_echo_server import create_starlette_app + + starlette_app, _ = create_starlette_app(port) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="warning") + + +@pytest.fixture(scope="module") +def task_server_port() -> int: + """Get a dynamically allocated port for the task server.""" + return _find_available_port() + + +@pytest.fixture(scope="module") +def task_server(task_server_port: int) -> Any: + """Start the task server for the test module.""" + server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) + server_thread.start() + time.sleep(2) # Wait for server to start + yield + # 
Server thread is daemon, will be cleaned up automatically + + +@pytest.fixture +def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client connected to the task server.""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) + + +@pytest.mark.skipif( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="streamable transport is failing in GitHub actions", +) +class TestMCPTaskSupport: + """Integration tests for MCP task-augmented execution. + + These tests verify our client correctly: + 1. Detects server task capability and uses task-augmented execution when appropriate + 2. Caches taskSupport settings from tools + 3. Falls back to direct call_tool for tools that don't support tasks + 4. Handles the full task workflow (call_tool_as_task -> poll_task -> get_task_result) + """ + + def test_task_forbidden_tool_uses_direct_call(self, task_mcp_client: MCPClient) -> None: + """Test that a tool with taskSupport='forbidden' uses direct call_tool.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_forbidden_echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-1", name="task_forbidden_echo", arguments={"message": "Hello forbidden!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Hello forbidden!" 
in result["content"][0].get("text", "") + + def test_tool_without_task_support_uses_direct_call(self, task_mcp_client: MCPClient) -> None: + """Test that a tool without taskSupport setting uses direct call_tool.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-2", name="echo", arguments={"message": "Hello simple!"} + ) + assert result["status"] == "success" + assert "Simple echo: Hello simple!" in result["content"][0].get("text", "") + + def test_tool_task_support_caching(self, task_mcp_client: MCPClient) -> None: + """Test that tool taskSupport values are cached during list_tools.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" + assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" + assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" + assert task_mcp_client._get_tool_task_support("echo") is None + + def test_server_capabilities_advertised(self, task_mcp_client: MCPClient) -> None: + """Test that server properly advertises task capabilities.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + session = task_mcp_client._background_thread_session + if session: + caps = session.get_server_capabilities() + assert caps is not None and caps.tasks is not None + assert caps.tasks.requests is not None and caps.tasks.requests.tools is not None + assert caps.tasks.requests.tools.call is not None + assert task_mcp_client._has_server_task_support() is True + + def test_task_required_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: + """Test that task-required tools use task-augmented execution.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_required_echo" in [t.tool_name for t in tools] + + result = 
task_mcp_client.call_tool_sync( + tool_use_id="test-3", name="task_required_echo", arguments={"message": "Hello from task!"} + ) + assert result["status"] == "success" + assert "Task echo: Hello from task!" in result["content"][0].get("text", "") + + def test_task_optional_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: + """Test that task-optional tools use task-augmented execution when server supports it.""" + with task_mcp_client: + tools = task_mcp_client.list_tools_sync() + assert "task_optional_echo" in [t.tool_name for t in tools] + + result = task_mcp_client.call_tool_sync( + tool_use_id="test-4", name="task_optional_echo", arguments={"message": "Hello optional task!"} + ) + assert result["status"] == "success" + assert "Task optional echo: Hello optional task!" in result["content"][0].get("text", "") + + def test_should_use_task_logic_with_server_support(self, task_mcp_client: MCPClient) -> None: + """Test that _should_use_task returns correct values based on tool taskSupport.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._should_use_task("task_required_echo") is True + assert task_mcp_client._should_use_task("task_optional_echo") is True + assert task_mcp_client._should_use_task("task_forbidden_echo") is False + assert task_mcp_client._should_use_task("echo") is False + + def test_multiple_tool_calls_in_sequence(self, task_mcp_client: MCPClient) -> None: + """Test calling multiple tools in sequence with different task modes.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + r1 = task_mcp_client.call_tool_sync( + tool_use_id="s1", name="task_forbidden_echo", arguments={"message": "1"} + ) + assert r1["status"] == "success" and "Forbidden echo: 1" in r1["content"][0].get("text", "") + + r2 = task_mcp_client.call_tool_sync(tool_use_id="s2", name="echo", arguments={"message": "2"}) + assert r2["status"] == "success" and "Simple echo: 2" in r2["content"][0].get("text", "") + + r3 = 
task_mcp_client.call_tool_sync(tool_use_id="s3", name="task_optional_echo", arguments={"message": "3"}) + assert r3["status"] == "success" and "Task optional echo: 3" in r3["content"][0].get("text", "") + + r4 = task_mcp_client.call_tool_sync(tool_use_id="s4", name="task_required_echo", arguments={"message": "4"}) + assert r4["status"] == "success" and "Task echo: 4" in r4["content"][0].get("text", "") + + @pytest.mark.asyncio + async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: + """Test async tool calls work correctly.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="test-async", name="task_forbidden_echo", arguments={"message": "Async hello!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") From a4a5ac7aedf7ea3e5607ddef9c097b917c5248e5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Thu, 22 Jan 2026 14:47:15 -0800 Subject: [PATCH 2/4] chore: cache server task capability immediately --- src/strands/tools/mcp/mcp_client.py | 48 ++++++++----------- .../tools/mcp/test_mcp_client_contextvar.py | 2 + 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2ca284407..ac16d4800 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -406,27 +406,10 @@ def list_tools_sync( effective_prefix = self._prefix if prefix is None else prefix effective_filters = self._tool_filters if tool_filters is None else tool_filters - async def _list_tools_and_cache_capabilities_async() -> ListToolsResult: - session = cast(ClientSession, self._background_thread_session) - list_tools_result = await session.list_tools(cursor=pagination_token) - - # Cache server task capability while we have an active session - # This avoids needing a separate async call later during call_tool_* - if 
self._server_task_capable is None: - caps = session.get_server_capabilities() - self._server_task_capable = ( - caps is not None - and caps.tasks is not None - and caps.tasks.requests is not None - and caps.tasks.requests.tools is not None - and caps.tasks.requests.tools.call is not None - ) - - return list_tools_result + async def _list_tools_async() -> ListToolsResult: + return await cast(ClientSession, self._background_thread_session).list_tools(cursor=pagination_token) - list_tools_response: ListToolsResult = self._invoke_on_background_thread( - _list_tools_and_cache_capabilities_async() - ).result() + list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result() self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools)) mcp_tools = [] @@ -753,6 +736,21 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("session initialized successfully") # Store the session for use while we await the close event self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + # Signal that the session has been created and is ready for use self._init_future.set_result(None) @@ -968,17 +966,13 @@ def _is_session_active(self) -> bool: def _has_server_task_support(self) -> bool: """Check if the MCP server supports task-augmented tool calls. - Returns the cached capability value that was populated during list_tools_sync(). 
- If list_tools_sync() hasn't been called yet, returns False (conservative default). - - The capability is cached during list_tools_sync() to avoid needing a separate - async call during call_tool_*() operations. + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. Returns: True if server supports task-augmented tool calls, False otherwise. """ - # Return cached value, defaulting to False if not yet populated - # The cache is populated during list_tools_sync() return self._server_task_capable or False def _get_tool_task_support(self, tool_name: str) -> str | None: diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py index d95929b02..739796366 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -37,6 +37,8 @@ def mock_session(): """Create mock MCP session.""" mock_session = AsyncMock() mock_session.initialize = AsyncMock() + # get_server_capabilities is sync, not async + mock_session.get_server_capabilities = MagicMock(return_value=None) mock_session_cm = AsyncMock() mock_session_cm.__aenter__.return_value = mock_session From 6801cdf5da9026c4b742b2e494e28f5af2758ec3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 23 Jan 2026 14:24:40 -0800 Subject: [PATCH 3/4] chore: add experimental.tasks feature gate --- src/strands/tools/mcp/mcp_client.py | 107 ++++++++++++++---- .../tools/mcp/test_mcp_client_tasks.py | 74 +++++++++--- tests_integ/mcp/test_mcp_client_tasks.py | 32 +++++- 3 files changed, 174 insertions(+), 39 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index ac16d4800..0cfaae790 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -73,6 +73,39 @@ class 
ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + If this config is provided (not None), task-augmented execution is enabled. + When enabled, long-running tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Attributes: + ttl_ms: Task time-to-live in milliseconds. Defaults to 60000 (1 minute). + poll_timeout_seconds: Timeout for polling task completion in seconds. + Defaults to 300.0 (5 minutes). + """ + + ttl_ms: int + poll_timeout_seconds: float + + +class ExperimentalConfig(TypedDict, total=False): + """Configuration for experimental MCPClient features. + + Warning: + Features under this configuration are experimental and subject to change + in future revisions without notice. + + Attributes: + tasks: Configuration for MCP Tasks (task-augmented tool execution). + If provided (not None), enables task-augmented execution for tools + that support it. + """ + + tasks: TasksConfig | None + + MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -120,8 +153,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, - default_task_ttl_ms: int = 60000, - default_task_poll_timeout_seconds: float = 300.0, + experimental: ExperimentalConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,10 +164,10 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. - default_task_ttl_ms: Default time-to-live in milliseconds for task-augmented tool calls. - Defaults to 60000 (1 minute). - default_task_poll_timeout_seconds: Default timeout in seconds for polling task completion. - Defaults to 300.0 (5 minutes). 
+ experimental: Configuration for experimental features. Currently supports: + - tasks: Enable MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools + that support it. See ExperimentalConfig and TasksConfig for details. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -160,9 +192,8 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() - # Task support caching - self._default_task_ttl_ms = default_task_ttl_ms - self._default_task_poll_timeout_seconds = default_task_poll_timeout_seconds + # Task support configuration and caching + self._experimental = experimental or {} self._server_task_capable: bool | None = None self._tool_task_support_cache: dict[str, str | None] = {} @@ -963,6 +994,38 @@ def _is_session_active(self) -> bool: return True + def _is_tasks_enabled(self) -> bool: + """Check if experimental tasks feature is enabled. + + Tasks are enabled if experimental.tasks is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._experimental.get("tasks") is not None + + def _get_task_ttl_ms(self) -> int: + """Get task TTL in milliseconds. + + Returns: + Task TTL from config, or default of 60000 (1 minute). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 60000 + return tasks_config.get("ttl_ms", 60000) + + def _get_task_poll_timeout_seconds(self) -> float: + """Get task polling timeout in seconds. + + Returns: + Polling timeout from config, or default of 300.0 (5 minutes). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 300.0 + return tasks_config.get("poll_timeout_seconds", 300.0) + def _has_server_task_support(self) -> bool: """Check if the MCP server supports task-augmented tool calls. 
@@ -992,13 +1055,10 @@ def _get_tool_task_support(self, tool_name: str) -> str | None: def _should_use_task(self, tool_name: str) -> bool: """Determine if task-augmented execution should be used for a tool. - Implements the MCP spec decision matrix: - - If server doesn't support tasks: MUST NOT use tasks (returns False) - - If tool taskSupport is None or 'forbidden': MUST NOT use tasks (returns False) - - If tool taskSupport is 'required' and server supports: use tasks (returns True) - - If tool taskSupport is 'optional' and server supports: prefer tasks (returns True) - - Per MCP spec, server capability check takes precedence over tool-level settings. + Task-augmented execution requires: + 1. experimental.tasks is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' Args: tool_name: Name of the tool to check. @@ -1006,7 +1066,11 @@ def _should_use_task(self, tool_name: str) -> bool: Returns: True if task-augmented execution should be used, False otherwise. """ - # Server capability check comes first (per MCP spec) + # Opt-in check: tasks must be explicitly enabled via experimental.tasks + if not self._is_tasks_enabled(): + return False + + # Server capability check (per MCP spec) if not self._has_server_task_support(): return False @@ -1079,16 +1143,15 @@ async def _call_tool_as_task_and_poll_async( Args: name: Name of the tool to call. arguments: Optional arguments to pass to the tool. - ttl_ms: Task time-to-live in milliseconds. Uses default_task_ttl_ms if not specified. - poll_timeout_seconds: Timeout for polling in seconds. Uses default_task_poll_timeout_seconds if not - specified. + ttl_ms: Task time-to-live in milliseconds. Uses configured value if not specified. + poll_timeout_seconds: Timeout for polling in seconds. Uses configured value if not specified. Returns: MCPCallToolResult: The final tool result after task completion. 
""" session = cast(ClientSession, self._background_thread_session) - ttl = ttl_ms or self._default_task_ttl_ms - timeout = poll_timeout_seconds or self._default_task_poll_timeout_seconds + ttl = ttl_ms or self._get_task_ttl_ms() + timeout = poll_timeout_seconds or self._get_task_poll_timeout_seconds() # Step 1: Create the task self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 6ce53f292..2dd4908a5 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -19,6 +19,38 @@ from .conftest import create_server_capabilities +class TestTasksDisabledByDefault: + """Tests that tasks are disabled by default.""" + + def test_tasks_disabled_when_no_experimental_config(self, mock_transport, mock_session): + """Test that _should_use_task returns False when experimental.tasks is not configured.""" + with MCPClient(mock_transport["transport_callable"]) as client: + # Even with server capability and tool support, tasks should be disabled + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is False + assert client._should_use_task("test_tool") is False + + def test_tasks_disabled_when_experimental_tasks_is_none(self, mock_transport, mock_session): + """Test that _should_use_task returns False when experimental.tasks is explicitly None.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": None}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is False + assert client._should_use_task("test_tool") is False + + def test_tasks_enabled_when_experimental_tasks_is_empty_dict(self, mock_transport, mock_session): + """Test that tasks are enabled when experimental.tasks 
is an empty dict.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + + assert client._is_tasks_enabled() is True + assert client._should_use_task("test_tool") is True + + class TestTaskExecutionFailures: """Tests for task execution failure handling.""" @@ -44,7 +76,7 @@ async def mock_poll_task(task_id): mock_session.experimental.poll_task = mock_poll_task - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) @@ -58,7 +90,7 @@ class TestStopResetCache: def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" @@ -71,19 +103,27 @@ class TestTaskConfiguration: def test_default_task_config_values(self, mock_transport, mock_session): """Test default configuration values.""" - with MCPClient(mock_transport["transport_callable"]) as client: - assert client._default_task_ttl_ms == 60000 - assert client._default_task_poll_timeout_seconds == 300.0 + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + assert client._get_task_ttl_ms() == 60000 + assert client._get_task_poll_timeout_seconds() == 300.0 def test_custom_task_config_values(self, mock_transport, mock_session): """Test custom configuration values.""" with MCPClient( mock_transport["transport_callable"], - default_task_ttl_ms=120000, - 
default_task_poll_timeout_seconds=60.0, + experimental={"tasks": {"ttl_ms": 120000, "poll_timeout_seconds": 60.0}}, + ) as client: + assert client._get_task_ttl_ms() == 120000 + assert client._get_task_poll_timeout_seconds() == 60.0 + + def test_partial_task_config_uses_defaults(self, mock_transport, mock_session): + """Test that partial config uses defaults for unspecified values.""" + with MCPClient( + mock_transport["transport_callable"], + experimental={"tasks": {"ttl_ms": 120000}}, ) as client: - assert client._default_task_ttl_ms == 120000 - assert client._default_task_poll_timeout_seconds == 60.0 + assert client._get_task_ttl_ms() == 120000 + assert client._get_task_poll_timeout_seconds() == 300.0 # default class TestTaskExecutionTimeout: @@ -117,7 +157,9 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll - with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=0.1) as client: + with MCPClient( + mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} + ) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) @@ -135,7 +177,7 @@ async def successful_poll(task_id): mock_session.experimental.poll_task = successful_poll mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) @@ -155,7 +197,9 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll # Long default timeout, but short explicit timeout - with MCPClient(mock_transport["transport_callable"], default_task_poll_timeout_seconds=300.0) as client: + with MCPClient( + 
mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} + ) as client: client.list_tools_sync() result = await client.call_tool_async( tool_use_id="test-timeout", @@ -178,7 +222,7 @@ async def empty_poll(task_id): mock_session.experimental.poll_task = empty_poll - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert result["status"] == "error" @@ -194,7 +238,7 @@ async def poll(task_id): mock_session.experimental.poll_task = poll - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) assert result["status"] == "error" @@ -216,7 +260,7 @@ async def poll(task_id): return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) ) - with MCPClient(mock_transport["transport_callable"]) as client: + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) assert result["status"] == "success" diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index a294246f4..892dfb90e 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -57,12 +57,22 @@ def task_server(task_server_port: int) -> Any: @pytest.fixture def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server.""" + """Create an MCP client connected to the task server with tasks enabled.""" def 
transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback) + return MCPClient(transport_callback, experimental={"tasks": {}}) + + +@pytest.fixture +def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client connected to the task server with tasks disabled (default).""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) # No experimental config - tasks disabled @pytest.mark.skipif( @@ -186,3 +196,21 @@ async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: ) assert result["status"] == "success" assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") + + def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: + """Test that tasks are disabled when experimental.tasks is not configured.""" + with task_mcp_client_disabled: + task_mcp_client_disabled.list_tools_sync() + + # Even though server supports tasks and tool has taskSupport='required', + # tasks should NOT be used because experimental.tasks is not configured + assert task_mcp_client_disabled._is_tasks_enabled() is False + assert task_mcp_client_disabled._should_use_task("task_required_echo") is False + assert task_mcp_client_disabled._should_use_task("task_optional_echo") is False + + # Tool calls should still work via direct call_tool + result = task_mcp_client_disabled.call_tool_sync( + tool_use_id="test-disabled", name="task_required_echo", arguments={"message": "Direct call!"} + ) + assert result["status"] == "success" + assert "Task echo: Direct call!" 
in result["content"][0].get("text", "") From ceaea6ac5fd54fe3376c82ecd3ec8719e280eaa3 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Fri, 23 Jan 2026 14:32:11 -0800 Subject: [PATCH 4/4] chore: parameterize tests --- .../tools/mcp/test_mcp_client_tasks.py | 223 +++++++----------- tests_integ/mcp/test_mcp_client_tasks.py | 178 +++++--------- 2 files changed, 143 insertions(+), 258 deletions(-) diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py index 2dd4908a5..629163695 100644 --- a/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -1,9 +1,4 @@ -"""Tests for MCP task-augmented execution support in MCPClient. - -These unit tests focus on error handling and edge cases that are not easily -testable through integration tests. Happy-path flows are covered by -integration tests in tests_integ/mcp/test_mcp_client_tasks.py. -""" +"""Tests for MCP task-augmented execution support in MCPClient.""" import asyncio from datetime import timedelta @@ -11,6 +6,8 @@ import pytest from mcp import ListToolsResult +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import TextContent as MCPTextContent from mcp.types import Tool as MCPTool from mcp.types import ToolExecution @@ -19,115 +16,66 @@ from .conftest import create_server_capabilities -class TestTasksDisabledByDefault: - """Tests that tasks are disabled by default.""" +class TestTasksOptIn: + """Tests for task opt-in behavior via experimental.tasks.""" - def test_tasks_disabled_when_no_experimental_config(self, mock_transport, mock_session): - """Test that _should_use_task returns False when experimental.tasks is not configured.""" - with MCPClient(mock_transport["transport_callable"]) as client: - # Even with server capability and tool support, tasks should be disabled - client._server_task_capable = True - client._tool_task_support_cache["test_tool"] = "required" - - assert 
client._is_tasks_enabled() is False - assert client._should_use_task("test_tool") is False + @pytest.mark.parametrize( + "experimental,expected_enabled", + [ + (None, False), + ({}, False), + ({"tasks": None}, False), + ({"tasks": {}}, True), + ({"tasks": {"ttl_ms": 1000}}, True), + ], + ) + def test_tasks_enabled_state(self, mock_transport, mock_session, experimental, expected_enabled): + """Test _is_tasks_enabled based on experimental config.""" + with MCPClient(mock_transport["transport_callable"], experimental=experimental) as client: + assert client._is_tasks_enabled() is expected_enabled - def test_tasks_disabled_when_experimental_tasks_is_none(self, mock_transport, mock_session): - """Test that _should_use_task returns False when experimental.tasks is explicitly None.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": None}) as client: + def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): + """Test that _should_use_task returns False without opt-in even with server/tool support.""" + with MCPClient(mock_transport["transport_callable"]) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" - - assert client._is_tasks_enabled() is False assert client._should_use_task("test_tool") is False - def test_tasks_enabled_when_experimental_tasks_is_empty_dict(self, mock_transport, mock_session): - """Test that tasks are enabled when experimental.tasks is an empty dict.""" with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["test_tool"] = "required" - - assert client._is_tasks_enabled() is True assert client._should_use_task("test_tool") is True -class TestTaskExecutionFailures: - """Tests for task execution failure handling.""" +class TestTaskConfiguration: + """Tests for task-related configuration options.""" @pytest.mark.parametrize( - 
"status,status_message,expected_text", + "config,expected_ttl,expected_timeout", [ - ("failed", "Something went wrong", "Something went wrong"), - ("cancelled", None, "cancelled"), + ({}, 60000, 300.0), + ({"ttl_ms": 120000}, 120000, 300.0), + ({"poll_timeout_seconds": 60.0}, 60000, 60.0), + ({"ttl_ms": 120000, "poll_timeout_seconds": 60.0}, 120000, 60.0), ], ) - def test_task_execution_terminal_status(self, mock_transport, mock_session, status, status_message, expected_text): - """Test handling of terminal task statuses (failed, cancelled).""" - mock_create_result = MagicMock() - mock_create_result.task.taskId = f"task-{status}" - mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) - - mock_status = MagicMock() - mock_status.status = status - mock_status.statusMessage = status_message - - async def mock_poll_task(task_id): - yield mock_status - - mock_session.experimental.poll_task = mock_poll_task - - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - client._server_task_capable = True - client._tool_task_support_cache["test_tool"] = "required" - result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) - - assert result["status"] == "error" - assert expected_text.lower() in result["content"][0].get("text", "").lower() - - -class TestStopResetCache: - """Tests for cache reset in stop().""" + def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): + """Test task configuration values with various configs.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: + assert client._get_task_ttl_ms() == expected_ttl + assert client._get_task_poll_timeout_seconds() == expected_timeout def test_stop_resets_task_caches(self, mock_transport, mock_session): """Test that stop() resets the task support caches.""" with MCPClient(mock_transport["transport_callable"], 
experimental={"tasks": {}}) as client: client._server_task_capable = True client._tool_task_support_cache["tool1"] = "required" - assert client._server_task_capable is None assert client._tool_task_support_cache == {} -class TestTaskConfiguration: - """Tests for task-related configuration options.""" - - def test_default_task_config_values(self, mock_transport, mock_session): - """Test default configuration values.""" - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - assert client._get_task_ttl_ms() == 60000 - assert client._get_task_poll_timeout_seconds() == 300.0 - - def test_custom_task_config_values(self, mock_transport, mock_session): - """Test custom configuration values.""" - with MCPClient( - mock_transport["transport_callable"], - experimental={"tasks": {"ttl_ms": 120000, "poll_timeout_seconds": 60.0}}, - ) as client: - assert client._get_task_ttl_ms() == 120000 - assert client._get_task_poll_timeout_seconds() == 60.0 - - def test_partial_task_config_uses_defaults(self, mock_transport, mock_session): - """Test that partial config uses defaults for unspecified values.""" - with MCPClient( - mock_transport["transport_callable"], - experimental={"tasks": {"ttl_ms": 120000}}, - ) as client: - assert client._get_task_ttl_ms() == 120000 - assert client._get_task_poll_timeout_seconds() == 300.0 # default - - -class TestTaskExecutionTimeout: - """Tests for task execution timeout and error handling.""" +class TestTaskExecution: + """Tests for task execution and error handling.""" def _setup_task_tool(self, mock_session, tool_name: str) -> None: """Helper to set up a mock task-enabled tool.""" @@ -139,14 +87,39 @@ def _setup_task_tool(self, mock_session, tool_name: str) -> None: execution=ToolExecution(taskSupport="optional"), ) mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) - mock_create_result = MagicMock() mock_create_result.task.taskId = "test-task-id" 
mock_session.experimental = MagicMock() mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ("unknown_status", None, "unexpected task status"), + ], + ) + def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses.""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def mock_poll_task(task_id): + yield MagicMock(status=status, statusMessage=status_message) + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + @pytest.mark.asyncio - async def test_task_polling_timeout(self, mock_transport, mock_session): + async def test_polling_timeout(self, mock_transport, mock_session): """Test that task polling times out properly.""" self._setup_task_tool(mock_session, "slow_tool") @@ -161,29 +134,10 @@ async def infinite_poll(task_id): mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} ) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="test-123", name="slow_tool", arguments={}) - + result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) assert result["status"] == "error" assert "timed out" in result["content"][0].get("text", "").lower() 
- @pytest.mark.asyncio - async def test_task_result_retrieval_failure(self, mock_transport, mock_session): - """Test that get_task_result failures are handled gracefully.""" - self._setup_task_tool(mock_session, "failing_tool") - - async def successful_poll(task_id): - yield MagicMock(status="completed", statusMessage=None) - - mock_session.experimental.poll_task = successful_poll - mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) - - with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: - client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="test-456", name="failing_tool", arguments={}) - - assert result["status"] == "error" - assert "result retrieval failed" in result["content"][0].get("text", "").lower() - @pytest.mark.asyncio async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): """Test that read_timeout_seconds overrides the default poll timeout.""" @@ -196,60 +150,53 @@ async def infinite_poll(task_id): mock_session.experimental.poll_task = infinite_poll - # Long default timeout, but short explicit timeout with MCPClient( mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} ) as client: client.list_tools_sync() result = await client.call_tool_async( - tool_use_id="test-timeout", - name="timeout_tool", - arguments={}, - read_timeout_seconds=timedelta(seconds=0.1), + tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1) ) - assert result["status"] == "error" assert "timed out" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_polling_yields_no_status(self, mock_transport, mock_session): - """Test handling when poll_task yields nothing (final_status is None).""" - self._setup_task_tool(mock_session, "empty_poll_tool") + async def test_result_retrieval_failure(self, mock_transport, mock_session): + 
"""Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") - async def empty_poll(task_id): - return - yield # noqa: B901 - makes this an async generator + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) - mock_session.experimental.poll_task = empty_poll + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) assert result["status"] == "error" - assert "without status" in result["content"][0].get("text", "").lower() + assert "result retrieval failed" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_unexpected_terminal_status(self, mock_transport, mock_session): - """Test handling of unexpected task status (not completed/failed/cancelled).""" - self._setup_task_tool(mock_session, "weird_tool") + async def test_empty_poll_result(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing.""" + self._setup_task_tool(mock_session, "empty_poll_tool") - async def poll(task_id): - yield MagicMock(status="unknown_status", statusMessage=None) + async def empty_poll(task_id): + return + yield # noqa: B901 - mock_session.experimental.poll_task = poll + mock_session.experimental.poll_task = empty_poll with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: client.list_tools_sync() - result = await client.call_tool_async(tool_use_id="t", name="weird_tool", arguments={}) + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) assert 
result["status"] == "error" - assert "unexpected task status" in result["content"][0].get("text", "").lower() + assert "without status" in result["content"][0].get("text", "").lower() @pytest.mark.asyncio - async def test_task_successful_completion(self, mock_transport, mock_session): - """Test successful task completion with result retrieval (happy path).""" - from mcp.types import CallToolResult as MCPCallToolResult - from mcp.types import TextContent as MCPTextContent - + async def test_successful_completion(self, mock_transport, mock_session): + """Test successful task completion.""" self._setup_task_tool(mock_session, "success_tool") async def poll(task_id): diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py index 892dfb90e..5e398f6de 100644 --- a/tests_integ/mcp/test_mcp_client_tasks.py +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -1,11 +1,4 @@ -"""Integration tests for MCP task-augmented tool execution. - -These tests verify that our MCPClient correctly handles tools with taskSupport settings -and integrates with MCP servers that support task-augmented execution. - -The test server (task_echo_server.py) includes a workaround for an MCP Python SDK bug -where `enable_tasks()` doesn't properly set `tasks.requests.tools.call` capability. 
-""" +"""Integration tests for MCP task-augmented tool execution.""" import os import socket @@ -21,12 +14,11 @@ def _find_available_port() -> int: - """Find an available port by binding to port 0 and letting the OS assign one.""" + """Find an available port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) s.listen(1) - port = s.getsockname()[1] - return port + return s.getsockname()[1] def start_task_server(port: int) -> None: @@ -41,7 +33,6 @@ def start_task_server(port: int) -> None: @pytest.fixture(scope="module") def task_server_port() -> int: - """Get a dynamically allocated port for the task server.""" return _find_available_port() @@ -50,14 +41,13 @@ def task_server(task_server_port: int) -> Any: """Start the task server for the test module.""" server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) server_thread.start() - time.sleep(2) # Wait for server to start + time.sleep(2) yield - # Server thread is daemon, will be cleaned up automatically @pytest.fixture def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server with tasks enabled.""" + """Create an MCP client with tasks enabled.""" def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") @@ -67,150 +57,98 @@ def transport_callback() -> MCPTransport: @pytest.fixture def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: - """Create an MCP client connected to the task server with tasks disabled (default).""" + """Create an MCP client with tasks disabled (default).""" def transport_callback() -> MCPTransport: return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") - return MCPClient(transport_callback) # No experimental config - tasks disabled + return MCPClient(transport_callback) -@pytest.mark.skipif( - 
condition=os.environ.get("GITHUB_ACTIONS") == "true", - reason="streamable transport is failing in GitHub actions", -) +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport failing in CI") class TestMCPTaskSupport: - """Integration tests for MCP task-augmented execution. + """Integration tests for MCP task-augmented execution.""" - These tests verify our client correctly: - 1. Detects server task capability and uses task-augmented execution when appropriate - 2. Caches taskSupport settings from tools - 3. Falls back to direct call_tool for tools that don't support tasks - 4. Handles the full task workflow (call_tool_as_task -> poll_task -> get_task_result) - """ - - def test_task_forbidden_tool_uses_direct_call(self, task_mcp_client: MCPClient) -> None: - """Test that a tool with taskSupport='forbidden' uses direct call_tool.""" + def test_direct_call_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use direct call_tool (forbidden or no taskSupport).""" with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_forbidden_echo" in [t.tool_name for t in tools] + task_mcp_client.list_tools_sync() - result = task_mcp_client.call_tool_sync( - tool_use_id="test-1", name="task_forbidden_echo", arguments={"message": "Hello forbidden!"} + # Tool with taskSupport='forbidden' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_forbidden_echo", arguments={"message": "Hello!"} ) - assert result["status"] == "success" - assert "Forbidden echo: Hello forbidden!" in result["content"][0].get("text", "") + assert r1["status"] == "success" + assert "Forbidden echo: Hello!" in r1["content"][0].get("text", "") + + # Tool without taskSupport + r2 = task_mcp_client.call_tool_sync(tool_use_id="t2", name="echo", arguments={"message": "Simple!"}) + assert r2["status"] == "success" + assert "Simple echo: Simple!" 
in r2["content"][0].get("text", "") - def test_tool_without_task_support_uses_direct_call(self, task_mcp_client: MCPClient) -> None: - """Test that a tool without taskSupport setting uses direct call_tool.""" + def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use task-augmented execution (required or optional).""" with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "echo" in [t.tool_name for t in tools] + task_mcp_client.list_tools_sync() - result = task_mcp_client.call_tool_sync( - tool_use_id="test-2", name="echo", arguments={"message": "Hello simple!"} + # Tool with taskSupport='required' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_required_echo", arguments={"message": "Required!"} ) - assert result["status"] == "success" - assert "Simple echo: Hello simple!" in result["content"][0].get("text", "") + assert r1["status"] == "success" + assert "Task echo: Required!" in r1["content"][0].get("text", "") - def test_tool_task_support_caching(self, task_mcp_client: MCPClient) -> None: - """Test that tool taskSupport values are cached during list_tools.""" + # Tool with taskSupport='optional' + r2 = task_mcp_client.call_tool_sync( + tool_use_id="t2", name="task_optional_echo", arguments={"message": "Optional!"} + ) + assert r2["status"] == "success" + assert "Task optional echo: Optional!" 
in r2["content"][0].get("text", "") + + def test_task_support_caching_and_decision(self, task_mcp_client: MCPClient) -> None: + """Test taskSupport caching and _should_use_task decision logic.""" with task_mcp_client: task_mcp_client.list_tools_sync() + + # Verify cached values assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" assert task_mcp_client._get_tool_task_support("echo") is None - def test_server_capabilities_advertised(self, task_mcp_client: MCPClient) -> None: - """Test that server properly advertises task capabilities.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() - session = task_mcp_client._background_thread_session - if session: - caps = session.get_server_capabilities() - assert caps is not None and caps.tasks is not None - assert caps.tasks.requests is not None and caps.tasks.requests.tools is not None - assert caps.tasks.requests.tools.call is not None - assert task_mcp_client._has_server_task_support() is True - - def test_task_required_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: - """Test that task-required tools use task-augmented execution.""" - with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_required_echo" in [t.tool_name for t in tools] - - result = task_mcp_client.call_tool_sync( - tool_use_id="test-3", name="task_required_echo", arguments={"message": "Hello from task!"} - ) - assert result["status"] == "success" - assert "Task echo: Hello from task!" 
in result["content"][0].get("text", "") - - def test_task_optional_tool_uses_task_execution(self, task_mcp_client: MCPClient) -> None: - """Test that task-optional tools use task-augmented execution when server supports it.""" - with task_mcp_client: - tools = task_mcp_client.list_tools_sync() - assert "task_optional_echo" in [t.tool_name for t in tools] - - result = task_mcp_client.call_tool_sync( - tool_use_id="test-4", name="task_optional_echo", arguments={"message": "Hello optional task!"} - ) - assert result["status"] == "success" - assert "Task optional echo: Hello optional task!" in result["content"][0].get("text", "") - - def test_should_use_task_logic_with_server_support(self, task_mcp_client: MCPClient) -> None: - """Test that _should_use_task returns correct values based on tool taskSupport.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() + # Verify decision logic assert task_mcp_client._should_use_task("task_required_echo") is True assert task_mcp_client._should_use_task("task_optional_echo") is True assert task_mcp_client._should_use_task("task_forbidden_echo") is False assert task_mcp_client._should_use_task("echo") is False - def test_multiple_tool_calls_in_sequence(self, task_mcp_client: MCPClient) -> None: - """Test calling multiple tools in sequence with different task modes.""" + def test_server_capabilities(self, task_mcp_client: MCPClient) -> None: + """Test server task capability detection.""" with task_mcp_client: task_mcp_client.list_tools_sync() - - r1 = task_mcp_client.call_tool_sync( - tool_use_id="s1", name="task_forbidden_echo", arguments={"message": "1"} - ) - assert r1["status"] == "success" and "Forbidden echo: 1" in r1["content"][0].get("text", "") - - r2 = task_mcp_client.call_tool_sync(tool_use_id="s2", name="echo", arguments={"message": "2"}) - assert r2["status"] == "success" and "Simple echo: 2" in r2["content"][0].get("text", "") - - r3 = task_mcp_client.call_tool_sync(tool_use_id="s3", 
name="task_optional_echo", arguments={"message": "3"}) - assert r3["status"] == "success" and "Task optional echo: 3" in r3["content"][0].get("text", "") - - r4 = task_mcp_client.call_tool_sync(tool_use_id="s4", name="task_required_echo", arguments={"message": "4"}) - assert r4["status"] == "success" and "Task echo: 4" in r4["content"][0].get("text", "") - - @pytest.mark.asyncio - async def test_async_tool_calls(self, task_mcp_client: MCPClient) -> None: - """Test async tool calls work correctly.""" - with task_mcp_client: - task_mcp_client.list_tools_sync() - result = await task_mcp_client.call_tool_async( - tool_use_id="test-async", name="task_forbidden_echo", arguments={"message": "Async hello!"} - ) - assert result["status"] == "success" - assert "Forbidden echo: Async hello!" in result["content"][0].get("text", "") + assert task_mcp_client._has_server_task_support() is True def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: """Test that tasks are disabled when experimental.tasks is not configured.""" with task_mcp_client_disabled: task_mcp_client_disabled.list_tools_sync() - # Even though server supports tasks and tool has taskSupport='required', - # tasks should NOT be used because experimental.tasks is not configured assert task_mcp_client_disabled._is_tasks_enabled() is False assert task_mcp_client_disabled._should_use_task("task_required_echo") is False - assert task_mcp_client_disabled._should_use_task("task_optional_echo") is False - # Tool calls should still work via direct call_tool + # Tool calls still work via direct call_tool result = task_mcp_client_disabled.call_tool_sync( - tool_use_id="test-disabled", name="task_required_echo", arguments={"message": "Direct call!"} + tool_use_id="t", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: + """Test async tool 
calls.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="t", name="task_forbidden_echo", arguments={"message": "Async!"} ) assert result["status"] == "success" - assert "Task echo: Direct call!" in result["content"][0].get("text", "") + assert "Forbidden echo: Async!" in result["content"][0].get("text", "")