diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 1aff22a1e..0cfaae790 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -73,6 +73,39 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] +class TasksConfig(TypedDict, total=False): + """Configuration for MCP Tasks (task-augmented tool execution). + + If this config is provided (not None), task-augmented execution is enabled. + When enabled, long-running tool calls use the MCP task workflow: + create task -> poll for completion -> get result. + + Attributes: + ttl_ms: Task time-to-live in milliseconds. Defaults to 60000 (1 minute). + poll_timeout_seconds: Timeout for polling task completion in seconds. + Defaults to 300.0 (5 minutes). + """ + + ttl_ms: int + poll_timeout_seconds: float + + +class ExperimentalConfig(TypedDict, total=False): + """Configuration for experimental MCPClient features. + + Warning: + Features under this configuration are experimental and subject to change + in future revisions without notice. + + Attributes: + tasks: Configuration for MCP Tasks (task-augmented tool execution). + If provided (not None), enables task-augmented execution for tools + that support it. + """ + + tasks: TasksConfig | None + + MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", @@ -120,6 +153,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + experimental: ExperimentalConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -130,6 +164,10 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + experimental: Configuration for experimental features. Currently supports: + - tasks: Enable MCP task-augmented execution for long-running tools. + If provided (not None), enables task-augmented execution for tools + that support it. See ExperimentalConfig and TasksConfig for details. """ self._startup_timeout = startup_timeout self._tool_filters = tool_filters @@ -154,6 +192,11 @@ def __init__( self._tool_provider_started = False self._consumers: set[Any] = set() + # Task support configuration and caching + self._experimental = experimental or {} + self._server_task_capable: bool | None = None + self._tool_task_support_cache: dict[str, str | None] = {} + def __enter__(self) -> "MCPClient": """Context manager entry point which initializes the MCP server connection. @@ -358,6 +401,8 @@ async def _set_close_event() -> None: self._loaded_tools = None self._tool_provider_started = False self._consumers = set() + self._server_task_capable = None + self._tool_task_support_cache = {} if self._close_exception: exception = self._close_exception @@ -400,6 +445,12 @@ async def _list_tools_async() -> ListToolsResult: mcp_tools = [] for tool in list_tools_response.tools: + # Cache taskSupport for task-augmented execution decisions + task_support = None + if tool.execution is not None and tool.execution.taskSupport is not None: + task_support = tool.execution.taskSupport + self._tool_task_support_cache[tool.name] = task_support + # Apply prefix if specified if effective_prefix: prefixed_name = f"{effective_prefix}_{tool.name}" @@ -539,6 +590,45 @@ async def _list_resource_templates_async() -> ListResourceTemplatesResult: return list_resource_templates_result + def _create_call_tool_coroutine( + self, + name: str, + arguments: dict[str, Any] | None, + read_timeout_seconds: timedelta | None, + ) -> Coroutine[Any, Any, MCPCallToolResult]: + """Create the appropriate coroutine for calling a tool. + + This method encapsulates the decision logic for whether to use task-augmented + execution or direct call_tool, returning the appropriate coroutine. + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + read_timeout_seconds: Optional timeout for the tool call. + + Returns: + A coroutine that will execute the tool call. + """ + use_task = self._should_use_task(name) + + if use_task: + self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + poll_timeout = self._convert_timeout_for_polling(read_timeout_seconds) + + async def _call_as_task() -> MCPCallToolResult: + return await self._call_tool_as_task_and_poll_async(name, arguments, poll_timeout_seconds=poll_timeout) + + return _call_as_task() + else: + self._log_debug_with_thread("tool=<%s> | using direct call_tool", name) + + async def _call_tool_direct() -> MCPCallToolResult: + return await cast(ClientSession, self._background_thread_session).call_tool( + name, arguments, read_timeout_seconds + ) + + return _call_tool_direct() + def call_tool_sync( self, tool_use_id: str, @@ -548,10 +638,8 @@ def call_tool_sync( ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the ToolResult format. If the MCP tool returns - structured content, it will be included as the last item in the content array - of the returned ToolResult. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use @@ -566,13 +654,9 @@ def call_tool_sync( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result() + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: logger.exception("tool execution failed") @@ -587,8 +671,8 @@ async def call_tool_async( ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. - This method calls the asynchronous call_tool method on the MCP session - and converts the result to the MCPToolResult format. + This method automatically uses task-augmented execution when appropriate, + based on server capabilities and tool-level taskSupport settings. Args: tool_use_id: Unique identifier for this tool use @@ -603,13 +687,9 @@ async def call_tool_async( if not self._is_session_active(): raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) - async def _call_tool_async() -> MCPCallToolResult: - return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds - ) - try: - future = self._invoke_on_background_thread(_call_tool_async()) + coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds) + future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -687,6 +767,21 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("session initialized successfully") # Store the session for use while we await the close event self._background_thread_session = session + + # Cache server task capability immediately after initialization + # Capabilities are exchanged during session.initialize(), so this is available now + caps = session.get_server_capabilities() + self._server_task_capable = ( + caps is not None + and caps.tasks is not None + and caps.tasks.requests is not None + and caps.tasks.requests.tools is not None + and caps.tasks.requests.tools.call is not None + ) + self._log_debug_with_thread( + "server_task_capable=<%s> | cached server task capability", self._server_task_capable + ) + # Signal that the session has been created and is ready for use self._init_future.set_result(None) @@ -898,3 +993,233 @@ def _is_session_active(self) -> bool: return False return True + + def _is_tasks_enabled(self) -> bool: + """Check if experimental tasks feature is enabled. + + Tasks are enabled if experimental.tasks is defined and not None. + + Returns: + True if task-augmented execution is enabled, False otherwise. + """ + return self._experimental.get("tasks") is not None + + def _get_task_ttl_ms(self) -> int: + """Get task TTL in milliseconds. + + Returns: + Task TTL from config, or default of 60000 (1 minute). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 60000 + return tasks_config.get("ttl_ms", 60000) + + def _get_task_poll_timeout_seconds(self) -> float: + """Get task polling timeout in seconds. + + Returns: + Polling timeout from config, or default of 300.0 (5 minutes). + """ + tasks_config = self._experimental.get("tasks") + if tasks_config is None: + return 300.0 + return tasks_config.get("poll_timeout_seconds", 300.0) + + def _has_server_task_support(self) -> bool: + """Check if the MCP server supports task-augmented tool calls. + + Returns the capability value that was cached immediately after session initialization. + Server capabilities are exchanged during the MCP handshake, so this is available + as soon as start() completes. + + Returns: + True if server supports task-augmented tool calls, False otherwise. + """ + return self._server_task_capable or False + + def _get_tool_task_support(self, tool_name: str) -> str | None: + """Get the taskSupport setting for a tool. + + Returns the cached taskSupport value for the given tool name. + The cache is populated during list_tools_sync(). + + Args: + tool_name: Name of the tool to look up. + + Returns: + The taskSupport value ('required', 'optional', 'forbidden') or None if not cached. + """ + return self._tool_task_support_cache.get(tool_name) + + def _should_use_task(self, tool_name: str) -> bool: + """Determine if task-augmented execution should be used for a tool. + + Task-augmented execution requires: + 1. experimental.tasks is enabled (opt-in check) + 2. Server supports tasks (capability check) + 3. Tool taskSupport is 'required' or 'optional' + + Args: + tool_name: Name of the tool to check. + + Returns: + True if task-augmented execution should be used, False otherwise. + """ + # Opt-in check: tasks must be explicitly enabled via experimental.tasks + if not self._is_tasks_enabled(): + return False + + # Server capability check (per MCP spec) + if not self._has_server_task_support(): + return False + + task_support = self._get_tool_task_support(tool_name) + + # Use tasks for 'required' or 'optional' when server supports + if task_support == "required" or task_support == "optional": + return True + + # Default: 'forbidden', None, or unknown -> don't use tasks + return False + + def _convert_timeout_for_polling(self, read_timeout_seconds: timedelta | None) -> float | None: + """Convert a timedelta timeout to seconds for task polling. + + When task-augmented execution is used, the read_timeout_seconds parameter + (which is a timedelta) needs to be converted to a float for the polling timeout. + + Args: + read_timeout_seconds: Optional timedelta timeout from the call_tool API. + + Returns: + Float seconds if timeout was specified, None to use default. + """ + return read_timeout_seconds.total_seconds() if read_timeout_seconds else None + + def _create_task_error_result(self, message: str) -> MCPCallToolResult: + """Create an error MCPCallToolResult with consistent formatting. + + This helper reduces duplication in task error handling paths. + + Args: + message: The error message to include in the result. + + Returns: + MCPCallToolResult with isError=True and the message as text content. + """ + return MCPCallToolResult( + isError=True, + content=[MCPTextContent(type="text", text=message)], + ) + + # ================================================================================== + # Task-Augmented Tool Execution + # ================================================================================== + # + # The MCP spec defines task-augmented execution for long-running tools. The flow is: + # + # 1. Check server capability (tasks.requests.tools.call) and tool setting (taskSupport) + # 2. If using tasks: call_tool_as_task() -> poll_task() -> get_task_result() + # 3. If not using tasks: call_tool() directly + # + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + # ================================================================================== + + async def _call_tool_as_task_and_poll_async( + self, + name: str, + arguments: dict[str, Any] | None = None, + ttl_ms: int | None = None, + poll_timeout_seconds: float | None = None, + ) -> MCPCallToolResult: + """Call a tool using task-augmented execution and poll until completion. + + This method implements the MCP task workflow: + 1. Creates a task via call_tool_as_task + 2. Polls using poll_task until terminal status (with timeout protection) + 3. Gets the final result using get_task_result + + Args: + name: Name of the tool to call. + arguments: Optional arguments to pass to the tool. + ttl_ms: Task time-to-live in milliseconds. Uses configured value if not specified. + poll_timeout_seconds: Timeout for polling in seconds. Uses configured value if not specified. + + Returns: + MCPCallToolResult: The final tool result after task completion. + """ + session = cast(ClientSession, self._background_thread_session) + ttl = ttl_ms or self._get_task_ttl_ms() + timeout = poll_timeout_seconds or self._get_task_poll_timeout_seconds() + + # Step 1: Create the task + self._log_debug_with_thread("tool=<%s> | calling tool as task with ttl=%d ms", name, ttl) + create_result = await session.experimental.call_tool_as_task( + name=name, + arguments=arguments, + ttl=ttl, + ) + task_id = create_result.task.taskId + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task created", name, task_id) + + # Step 2: Poll until terminal status (with timeout protection) + # Note: Using asyncio.wait_for() instead of asyncio.timeout() for Python 3.10 compatibility + async def _poll_until_terminal() -> Any: + """Inner function to poll task status until terminal state.""" + final = None + async for status in session.experimental.poll_task(task_id): + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | task status update", + name, + task_id, + status.status, + ) + final = status + return final + + try: + final_status = await asyncio.wait_for(_poll_until_terminal(), timeout=timeout) + except asyncio.TimeoutError: + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, timeout=<%s> | task polling timed out", name, task_id, timeout + ) + return self._create_task_error_result(f"Task {task_id} polling timed out after {timeout} seconds") + + # Step 3: Handle terminal status + if final_status is None: + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | polling completed without status", name, task_id) + return self._create_task_error_result(f"Task {task_id} polling completed without status") + + if final_status.status == "failed": + error_msg = final_status.statusMessage or "Task failed" + self._log_debug_with_thread("tool=<%s>, task_id=<%s>, error=<%s> | task failed", name, task_id, error_msg) + return self._create_task_error_result(error_msg) + + if final_status.status == "cancelled": + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task was cancelled", name, task_id) + return self._create_task_error_result("Task was cancelled") + + # Step 4: Get the actual result for completed tasks (with error handling for race conditions) + if final_status.status == "completed": + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task completed, fetching result", name, task_id) + try: + result = await session.experimental.get_task_result(task_id, MCPCallToolResult) + self._log_debug_with_thread("tool=<%s>, task_id=<%s> | task result retrieved", name, task_id) + return result + except Exception as e: + # Handle race condition: task completed but result retrieval failed + # (e.g., result expired, network error, server restarted) + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, error=<%s> | failed to retrieve task result", name, task_id, str(e) + ) + return self._create_task_error_result(f"Task completed but result retrieval failed: {str(e)}") + + # Unexpected status - return as error + self._log_debug_with_thread( + "tool=<%s>, task_id=<%s>, status=<%s> | unexpected task status", + name, + task_id, + final_status.status, + ) + return self._create_task_error_result(f"Unexpected task status: {final_status.status}") diff --git a/tests/strands/tools/mcp/conftest.py b/tests/strands/tools/mcp/conftest.py new file mode 100644 index 000000000..0cfce470a --- /dev/null +++ b/tests/strands/tools/mcp/conftest.py @@ -0,0 +1,59 @@ +"""Shared fixtures and helpers for MCP client tests.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_transport(): + """Create a mock MCP transport.""" + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + mock_transport_cm = AsyncMock() + mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) + mock_transport_callable = MagicMock(return_value=mock_transport_cm) + + return { + "read_stream": mock_read_stream, + "write_stream": mock_write_stream, + "transport_cm": mock_transport_cm, + "transport_callable": mock_transport_callable, + } + + +@pytest.fixture +def mock_session(): + """Create a mock MCP session.""" + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + # Default: no task support (get_server_capabilities is sync, not async!) + mock_session.get_server_capabilities = MagicMock(return_value=None) + + # Create a mock context manager for ClientSession + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__.return_value = mock_session + + # Patch ClientSession to return our mock session + with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): + yield mock_session + + +def create_server_capabilities(has_task_support: bool) -> MagicMock: + """Create mock server capabilities. + + Args: + has_task_support: Whether the server should advertise task support. + + Returns: + MagicMock representing server capabilities. + """ + caps = MagicMock() + if has_task_support: + caps.tasks = MagicMock() + caps.tasks.requests = MagicMock() + caps.tasks.requests.tools = MagicMock() + caps.tasks.requests.tools.call = MagicMock() + else: + caps.tasks = None + return caps diff --git a/tests/strands/tools/mcp/test_mcp_client.py b/tests/strands/tools/mcp/test_mcp_client.py index f784da414..4c9ca6752 100644 --- a/tests/strands/tools/mcp/test_mcp_client.py +++ b/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from mcp import ListToolsResult @@ -25,35 +25,7 @@ from strands.tools.mcp.mcp_types import MCPToolResult from strands.types.exceptions import MCPClientInitializationError - -@pytest.fixture -def mock_transport(): - mock_read_stream = AsyncMock() - mock_write_stream = AsyncMock() - mock_transport_cm = AsyncMock() - mock_transport_cm.__aenter__.return_value = (mock_read_stream, mock_write_stream) - mock_transport_callable = MagicMock(return_value=mock_transport_cm) - - return { - "read_stream": mock_read_stream, - "write_stream": mock_write_stream, - "transport_cm": mock_transport_cm, - "transport_callable": mock_transport_callable, - } - - -@pytest.fixture -def mock_session(): - mock_session = AsyncMock() - mock_session.initialize = AsyncMock() - - # Create a mock context manager for ClientSession - mock_session_cm = AsyncMock() - mock_session_cm.__aenter__.return_value = mock_session - - # Patch ClientSession to return our mock session - with patch("strands.tools.mcp.mcp_client.ClientSession", return_value=mock_session_cm): - yield mock_session +# Fixtures mock_transport and mock_session are imported from conftest.py @pytest.fixture diff --git a/tests/strands/tools/mcp/test_mcp_client_contextvar.py b/tests/strands/tools/mcp/test_mcp_client_contextvar.py index d95929b02..739796366 100644 --- a/tests/strands/tools/mcp/test_mcp_client_contextvar.py +++ b/tests/strands/tools/mcp/test_mcp_client_contextvar.py @@ -37,6 +37,8 @@ def mock_session(): """Create mock MCP session.""" mock_session = AsyncMock() mock_session.initialize = AsyncMock() + # get_server_capabilities is sync, not async + mock_session.get_server_capabilities = MagicMock(return_value=None) mock_session_cm = AsyncMock() mock_session_cm.__aenter__.return_value = mock_session diff --git a/tests/strands/tools/mcp/test_mcp_client_tasks.py b/tests/strands/tools/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..629163695 --- /dev/null +++ b/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -0,0 +1,214 @@ +"""Tests for MCP task-augmented execution support in MCPClient.""" + +import asyncio +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mcp import ListToolsResult +from mcp.types import CallToolResult as MCPCallToolResult +from mcp.types import TextContent as MCPTextContent +from mcp.types import Tool as MCPTool +from mcp.types import ToolExecution + +from strands.tools.mcp import MCPClient + +from .conftest import create_server_capabilities + + +class TestTasksOptIn: + """Tests for task opt-in behavior via experimental.tasks.""" + + @pytest.mark.parametrize( + "experimental,expected_enabled", + [ + (None, False), + ({}, False), + ({"tasks": None}, False), + ({"tasks": {}}, True), + ({"tasks": {"ttl_ms": 1000}}, True), + ], + ) + def test_tasks_enabled_state(self, mock_transport, mock_session, experimental, expected_enabled): + """Test _is_tasks_enabled based on experimental config.""" + with MCPClient(mock_transport["transport_callable"], experimental=experimental) as client: + assert client._is_tasks_enabled() is expected_enabled + + def test_should_use_task_requires_opt_in(self, mock_transport, mock_session): + """Test that _should_use_task returns False without opt-in even with server/tool support.""" + with MCPClient(mock_transport["transport_callable"]) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + assert client._should_use_task("test_tool") is False + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + assert client._should_use_task("test_tool") is True + + +class TestTaskConfiguration: + """Tests for task-related configuration options.""" + + @pytest.mark.parametrize( + "config,expected_ttl,expected_timeout", + [ + ({}, 60000, 300.0), + ({"ttl_ms": 120000}, 120000, 300.0), + ({"poll_timeout_seconds": 60.0}, 60000, 60.0), + ({"ttl_ms": 120000, "poll_timeout_seconds": 60.0}, 120000, 60.0), + ], + ) + def test_task_config_values(self, mock_transport, mock_session, config, expected_ttl, expected_timeout): + """Test task configuration values with various configs.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": config}) as client: + assert client._get_task_ttl_ms() == expected_ttl + assert client._get_task_poll_timeout_seconds() == expected_timeout + + def test_stop_resets_task_caches(self, mock_transport, mock_session): + """Test that stop() resets the task support caches.""" + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["tool1"] = "required" + assert client._server_task_capable is None + assert client._tool_task_support_cache == {} + + +class TestTaskExecution: + """Tests for task execution and error handling.""" + + def _setup_task_tool(self, mock_session, tool_name: str) -> None: + """Helper to set up a mock task-enabled tool.""" + mock_session.get_server_capabilities = MagicMock(return_value=create_server_capabilities(True)) + mock_tool = MCPTool( + name=tool_name, + description="A test tool", + inputSchema={"type": "object"}, + execution=ToolExecution(taskSupport="optional"), + ) + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=[mock_tool], nextCursor=None)) + mock_create_result = MagicMock() + mock_create_result.task.taskId = "test-task-id" + mock_session.experimental = MagicMock() + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + @pytest.mark.parametrize( + "status,status_message,expected_text", + [ + ("failed", "Something went wrong", "Something went wrong"), + ("cancelled", None, "cancelled"), + ("unknown_status", None, "unexpected task status"), + ], + ) + def test_terminal_status_handling(self, mock_transport, mock_session, status, status_message, expected_text): + """Test handling of terminal task statuses.""" + mock_create_result = MagicMock() + mock_create_result.task.taskId = f"task-{status}" + mock_session.experimental.call_tool_as_task = AsyncMock(return_value=mock_create_result) + + async def mock_poll_task(task_id): + yield MagicMock(status=status, statusMessage=status_message) + + mock_session.experimental.poll_task = mock_poll_task + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client._server_task_capable = True + client._tool_task_support_cache["test_tool"] = "required" + result = client.call_tool_sync(tool_use_id="test-id", name="test_tool", arguments={}) + assert result["status"] == "error" + assert expected_text.lower() in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_polling_timeout(self, mock_transport, mock_session): + """Test that task polling times out properly.""" + self._setup_task_tool(mock_session, "slow_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 0.1}} + ) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="slow_tool", arguments={}) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_explicit_timeout_overrides_default(self, mock_transport, mock_session): + """Test that read_timeout_seconds overrides the default poll timeout.""" + self._setup_task_tool(mock_session, "timeout_tool") + + async def infinite_poll(task_id): + while True: + await asyncio.sleep(1) + yield MagicMock(status="running") + + mock_session.experimental.poll_task = infinite_poll + + with MCPClient( + mock_transport["transport_callable"], experimental={"tasks": {"poll_timeout_seconds": 300.0}} + ) as client: + client.list_tools_sync() + result = await client.call_tool_async( + tool_use_id="t", name="timeout_tool", arguments={}, read_timeout_seconds=timedelta(seconds=0.1) + ) + assert result["status"] == "error" + assert "timed out" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_result_retrieval_failure(self, mock_transport, mock_session): + """Test that get_task_result failures are handled gracefully.""" + self._setup_task_tool(mock_session, "failing_tool") + + async def successful_poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = successful_poll + mock_session.experimental.get_task_result = AsyncMock(side_effect=Exception("Network error")) + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="failing_tool", arguments={}) + assert result["status"] == "error" + assert "result retrieval failed" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_empty_poll_result(self, mock_transport, mock_session): + """Test handling when poll_task yields nothing.""" + self._setup_task_tool(mock_session, "empty_poll_tool") + + async def empty_poll(task_id): + return + yield # noqa: B901 + + mock_session.experimental.poll_task = empty_poll + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="empty_poll_tool", arguments={}) + assert result["status"] == "error" + assert "without status" in result["content"][0].get("text", "").lower() + + @pytest.mark.asyncio + async def test_successful_completion(self, mock_transport, mock_session): + """Test successful task completion.""" + self._setup_task_tool(mock_session, "success_tool") + + async def poll(task_id): + yield MagicMock(status="completed", statusMessage=None) + + mock_session.experimental.poll_task = poll + mock_session.experimental.get_task_result = AsyncMock( + return_value=MCPCallToolResult(content=[MCPTextContent(type="text", text="Done")], isError=False) + ) + + with MCPClient(mock_transport["transport_callable"], experimental={"tasks": {}}) as client: + client.list_tools_sync() + result = await client.call_tool_async(tool_use_id="t", name="success_tool", arguments={}) + assert result["status"] == "success" + assert "Done" in result["content"][0].get("text", "") diff --git a/tests_integ/mcp/task_echo_server.py b/tests_integ/mcp/task_echo_server.py new file mode 100644 index 000000000..4a8edc97d --- /dev/null +++ b/tests_integ/mcp/task_echo_server.py @@ -0,0 +1,139 @@ +"""MCP server with task-augmented tool execution support for integration testing.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import click +import mcp.types as types +from mcp.server.experimental.task_context import ServerTaskContext +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount + + +def create_task_server() -> Server: + """Create and configure the task-supporting MCP server.""" + server = Server("task-echo-server") + server.experimental.enable_tasks() + + # Workaround: MCP Python SDK's enable_tasks() doesn't properly set tasks.requests.tools.call capability + original_update_capabilities = server.experimental.update_capabilities + + def patched_update_capabilities(capabilities: types.ServerCapabilities) -> None: + original_update_capabilities(capabilities) + if capabilities.tasks and capabilities.tasks.requests and capabilities.tasks.requests.tools: + capabilities.tasks.requests.tools.call = types.TasksCallCapability() + + server.experimental.update_capabilities = patched_update_capabilities # type: ignore[method-assign] + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="task_required_echo", + description="Echo that requires task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_REQUIRED), + ), + types.Tool( + name="task_optional_echo", + description="Echo that optionally supports task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_OPTIONAL), + ), + types.Tool( + name="task_forbidden_echo", + description="Echo that does not support task-augmented execution", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + execution=types.ToolExecution(taskSupport=types.TASK_FORBIDDEN), + ), + types.Tool( + name="echo", + description="Simple echo without task support setting", + inputSchema={"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}, + ), + ] + + async def handle_task_required_echo(arguments: dict[str, Any]) -> types.CreateTaskResult: + ctx = server.request_context + ctx.experimental.validate_task_mode(types.TASK_REQUIRED) + message = arguments.get("message", "") + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing echo...") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Task echo: {message}")]) + + return await ctx.experimental.run_task(work) + + async def handle_task_optional_echo(arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + ctx = server.request_context + message = arguments.get("message", "") + + if ctx.experimental.is_task: + + async def work(task: ServerTaskContext) -> types.CallToolResult: + await task.update_status("Processing optional task echo...") + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Task optional echo: {message}")] + ) + + return await ctx.experimental.run_task(work) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Direct optional echo: {message}")] + ) + + async def handle_task_forbidden_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Forbidden echo: {message}")]) + + async def handle_simple_echo(arguments: dict[str, Any]) -> types.CallToolResult: + message = arguments.get("message", "") + return types.CallToolResult(content=[types.TextContent(type="text", text=f"Simple echo: {message}")]) + + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict[str, Any]) -> types.CallToolResult | types.CreateTaskResult: + handlers = { + "task_required_echo": handle_task_required_echo, + "task_optional_echo": handle_task_optional_echo, + "task_forbidden_echo": handle_task_forbidden_echo, + "echo": handle_simple_echo, + } + if name in handlers: + return await handlers[name](arguments) + return types.CallToolResult( + content=[types.TextContent(type="text", text=f"Unknown tool: {name}")], isError=True + ) + + return server + + +def create_starlette_app(port: int) -> tuple[Starlette, StreamableHTTPSessionManager]: + """Create the Starlette app with MCP session manager.""" + server = create_task_server() + session_manager = StreamableHTTPSessionManager(app=server) + + @asynccontextmanager + async def app_lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + + return Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)], lifespan=app_lifespan), session_manager + + +@click.command() +@click.option("--port", default=8010, help="Port to listen on") +def main(port: int) -> int: + """Start the task echo server.""" + import uvicorn + + starlette_app, _ = create_starlette_app(port) + print(f"Starting task echo server on http://localhost:{port}/mcp") + uvicorn.run(starlette_app, host="127.0.0.1", port=port) + return 0 + + +if __name__ == "__main__": + main() diff --git a/tests_integ/mcp/test_mcp_client_tasks.py b/tests_integ/mcp/test_mcp_client_tasks.py new file mode 100644 index 000000000..5e398f6de --- /dev/null +++ b/tests_integ/mcp/test_mcp_client_tasks.py @@ -0,0 +1,154 @@ +"""Integration tests for MCP task-augmented tool execution.""" + +import os +import socket +import threading +import time +from typing import Any + +import pytest +from mcp.client.streamable_http import streamablehttp_client + +from strands.tools.mcp.mcp_client import MCPClient +from strands.tools.mcp.mcp_types import MCPTransport + + +def _find_available_port() -> int: + """Find an available port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + return s.getsockname()[1] + + +def start_task_server(port: int) -> None: + """Start the task echo server in a thread.""" + import uvicorn + + from tests_integ.mcp.task_echo_server import create_starlette_app + + starlette_app, _ = create_starlette_app(port) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="warning") + + +@pytest.fixture(scope="module") +def task_server_port() -> int: + return _find_available_port() + + +@pytest.fixture(scope="module") +def task_server(task_server_port: int) -> Any: + """Start the task server for the test module.""" + server_thread = threading.Thread(target=start_task_server, kwargs={"port": task_server_port}, daemon=True) + server_thread.start() + time.sleep(2) + yield + + +@pytest.fixture +def task_mcp_client(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks enabled.""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback, experimental={"tasks": {}}) + + +@pytest.fixture +def task_mcp_client_disabled(task_server: Any, task_server_port: int) -> MCPClient: + """Create an MCP client with tasks disabled (default).""" + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url=f"http://127.0.0.1:{task_server_port}/mcp") + + return MCPClient(transport_callback) + + +@pytest.mark.skipif(os.environ.get("GITHUB_ACTIONS") == "true", reason="streamable transport failing in CI") +class TestMCPTaskSupport: + """Integration tests for MCP task-augmented execution.""" + + def test_direct_call_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use direct call_tool (forbidden or no taskSupport).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='forbidden' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_forbidden_echo", arguments={"message": "Hello!"} + ) + assert r1["status"] == "success" + assert "Forbidden echo: Hello!" in r1["content"][0].get("text", "") + + # Tool without taskSupport + r2 = task_mcp_client.call_tool_sync(tool_use_id="t2", name="echo", arguments={"message": "Simple!"}) + assert r2["status"] == "success" + assert "Simple echo: Simple!" in r2["content"][0].get("text", "") + + def test_task_augmented_tools(self, task_mcp_client: MCPClient) -> None: + """Test tools that use task-augmented execution (required or optional).""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Tool with taskSupport='required' + r1 = task_mcp_client.call_tool_sync( + tool_use_id="t1", name="task_required_echo", arguments={"message": "Required!"} + ) + assert r1["status"] == "success" + assert "Task echo: Required!" in r1["content"][0].get("text", "") + + # Tool with taskSupport='optional' + r2 = task_mcp_client.call_tool_sync( + tool_use_id="t2", name="task_optional_echo", arguments={"message": "Optional!"} + ) + assert r2["status"] == "success" + assert "Task optional echo: Optional!" in r2["content"][0].get("text", "") + + def test_task_support_caching_and_decision(self, task_mcp_client: MCPClient) -> None: + """Test taskSupport caching and _should_use_task decision logic.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + + # Verify cached values + assert task_mcp_client._get_tool_task_support("task_required_echo") == "required" + assert task_mcp_client._get_tool_task_support("task_optional_echo") == "optional" + assert task_mcp_client._get_tool_task_support("task_forbidden_echo") == "forbidden" + assert task_mcp_client._get_tool_task_support("echo") is None + + # Verify decision logic + assert task_mcp_client._should_use_task("task_required_echo") is True + assert task_mcp_client._should_use_task("task_optional_echo") is True + assert task_mcp_client._should_use_task("task_forbidden_echo") is False + assert task_mcp_client._should_use_task("echo") is False + + def test_server_capabilities(self, task_mcp_client: MCPClient) -> None: + """Test server task capability detection.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + assert task_mcp_client._has_server_task_support() is True + + def test_tasks_disabled_by_default(self, task_mcp_client_disabled: MCPClient) -> None: + """Test that tasks are disabled when experimental.tasks is not configured.""" + with task_mcp_client_disabled: + task_mcp_client_disabled.list_tools_sync() + + assert task_mcp_client_disabled._is_tasks_enabled() is False + assert task_mcp_client_disabled._should_use_task("task_required_echo") is False + + # Tool calls still work via direct call_tool + result = task_mcp_client_disabled.call_tool_sync( + tool_use_id="t", name="task_required_echo", arguments={"message": "Direct!"} + ) + assert result["status"] == "success" + + @pytest.mark.asyncio + async def test_async_tool_call(self, task_mcp_client: MCPClient) -> None: + """Test async tool calls.""" + with task_mcp_client: + task_mcp_client.list_tools_sync() + result = await task_mcp_client.call_tool_async( + tool_use_id="t", name="task_forbidden_echo", arguments={"message": "Async!"} + ) + assert result["status"] == "success" + assert "Forbidden echo: Async!" in result["content"][0].get("text", "")