diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index fb4e992dfd..ec05372ac5 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -18,6 +18,7 @@ import base64 import logging import sys +import time from typing import Any from typing import Awaitable from typing import Callable @@ -114,6 +115,8 @@ def __init__( Union[ProgressFnT, ProgressCallbackFactory] ] = None, use_mcp_resources: Optional[bool] = False, + cache: bool = False, + cache_ttl_seconds: Optional[int] = None, ): """Initializes the McpToolset. @@ -150,6 +153,10 @@ def __init__( use_mcp_resources: Whether the agent should have access to MCP resources. This will add a `load_mcp_resource` tool to the toolset and include available resources in the agent context. Defaults to False. + cache: If True, the toolset will cache the response from the + first `list_tools` call and reuse it for subsequent calls. + cache_ttl_seconds: If set, the in-memory cache will expire + after this many seconds. """ super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix) @@ -170,6 +177,11 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential self._require_confirmation = require_confirmation + self._cache = cache + self._cache_ttl_seconds = cache_ttl_seconds + self._cached_tool_response: Optional[ListToolsResult] = None + self._cache_creation_time: Optional[float] = None + self._cache_lock = asyncio.Lock() # Store auth config as instance variable so ADK can populate # exchanged_auth_credential in-place before calling get_tools() self._auth_config: Optional[AuthConfig] = ( @@ -300,13 +312,40 @@ async def get_tools( Returns: List[BaseTool]: A list of tools available under the specified context. """ - # Fetch available tools from the MCP server - tools_response: ListToolsResult = await self._execute_with_session( - lambda session: session.list_tools(), - "Failed to get tools from MCP server", - readonly_context, - ) + def _is_cache_valid() -> bool: + if not self._cache or not self._cached_tool_response: + return False + + if self._cache_ttl_seconds is None: + return True # No TTL set, consider cache always valid + + if self._cache_creation_time is None: + # This should not happen in a well-initialized system + return False + + elapsed = time.monotonic() - self._cache_creation_time + return elapsed <= self._cache_ttl_seconds + + # First check without a lock for performance. + if _is_cache_valid(): + tools_response = self._cached_tool_response + else: + # If cache is invalid, acquire lock to prevent stampede. + async with self._cache_lock: + # Double-check if cache was populated while waiting for the lock. + if _is_cache_valid(): + tools_response = self._cached_tool_response + else: + fetched_tools: ListToolsResult = await self._execute_with_session( + lambda session: session.list_tools(), + "Failed to get tools from MCP server", + readonly_context, + ) + if self._cache: + self._cached_tool_response = fetched_tools + self._cache_creation_time = time.monotonic() + tools_response = fetched_tools # Apply filtering based on context and tool_filter tools = [] for tool in tools_response.tools: @@ -431,6 +470,7 @@ def from_config( auth_credential=mcp_toolset_config.auth_credential, use_mcp_resources=mcp_toolset_config.use_mcp_resources, ) + class MCPToolset(McpToolset): diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py new file mode 100644 index 0000000000..14cedc216b --- /dev/null +++ b/tests/unittests/tools/test_mcp_toolset.py @@ -0,0 +1,212 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for McpToolset.""" + +import asyncio +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset +import pytest + + +@pytest.mark.asyncio +async def test_mcp_toolset_with_prefix(): + """Test that McpToolset correctly applies the tool_name_prefix.""" + # Mock the connection parameters + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + + # Mock the MCPSessionManager and its create_session method + mock_session_manager = MagicMock() + mock_session = MagicMock() + + # Mock the list_tools response from the MCP server + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "tool 2 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1, mock_tool2] + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Create an instance of McpToolset with a prefix + toolset = McpToolset( + connection_params=mock_connection_params, + tool_name_prefix="my_prefix", + ) + + # Replace the internal session manager with our mock + toolset._mcp_session_manager = mock_session_manager + + # Get the tools from the toolset + tools = await toolset.get_tools() + + # The get_tools method in McpToolset returns MCPTool objects, which are + # instances of BaseTool. The prefixing is handled by the BaseToolset, + # so we need to call get_tools_with_prefix to get the prefixed tools. + prefixed_tools = await toolset.get_tools_with_prefix() + + # Assert that the tools are prefixed correctly + assert len(prefixed_tools) == 2 + assert prefixed_tools[0].name == "my_prefix_tool1" + assert prefixed_tools[1].name == "my_prefix_tool2" + + # Assert that the original tools are not modified + assert tools[0].name == "tool1" + assert tools[1].name == "tool2" + + +def _create_mock_session_manager(): + """Helper to create a mock MCPSessionManager.""" + mock_session_manager = MagicMock() + mock_session = MagicMock() + + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "tool 2 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1, mock_tool2] + + mock_session.list_tools = AsyncMock(return_value=list_tools_result) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + return mock_session_manager, mock_session + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_disabled(): + """Test that list_tools is called every time when cache is disabled.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset(connection_params=mock_connection_params, cache=False) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + await toolset.get_tools() + + assert mock_session.list_tools.call_count == 2 + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_enabled(): + """Test that list_tools is called only once when cache is enabled.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset(connection_params=mock_connection_params, cache=True) + toolset._mcp_session_manager = mock_session_manager + + tools1 = await toolset.get_tools() + tools2 = await toolset.get_tools() + + mock_session.list_tools.assert_called_once() + assert len(tools1) == 2 + assert len(tools2) == 2 + assert tools1[0].name == tools2[0].name + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_with_ttl_not_expired(): + """Test that cache is used when TTL has not expired.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset( + connection_params=mock_connection_params, cache=True, cache_ttl_seconds=10 + ) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + await toolset.get_tools() + + mock_session.list_tools.assert_called_once() + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_with_ttl_expired(): + """Test that list_tools is called again after TTL expires.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + mock_session_manager, mock_session = _create_mock_session_manager() + + toolset = McpToolset( + connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1 + ) + toolset._mcp_session_manager = mock_session_manager + + await toolset.get_tools() + mock_session.list_tools.assert_called_once() + + await asyncio.sleep(1.1) + + await toolset.get_tools() + assert mock_session.list_tools.call_count == 2 + + +@pytest.mark.asyncio +async def test_mcp_toolset_cache_concurrency(): + """Test that list_tools is called only once during concurrent requests.""" + mock_connection_params = MagicMock() + mock_connection_params.timeout = None + + # Create a mock session manager. Add a small delay to the mock call + # to simulate network latency and increase the chance of a race condition. + mock_session_manager = MagicMock() + mock_session = MagicMock() + + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "tool 1 desc" + list_tools_result = MagicMock() + list_tools_result.tools = [mock_tool1] + + async def delayed_list_tools(): + await asyncio.sleep(0.1) + return list_tools_result + + mock_session.list_tools = AsyncMock(side_effect=delayed_list_tools) + mock_session_manager.create_session = AsyncMock(return_value=mock_session) + + # Initialize the toolset with caching enabled + toolset = McpToolset(connection_params=mock_connection_params, cache=True) + toolset._mcp_session_manager = mock_session_manager + + # Create multiple concurrent tasks to call get_tools + tasks = [asyncio.create_task(toolset.get_tools()) for _ in range(5)] + + # Run all tasks concurrently + results = await asyncio.gather(*tasks) + + # Assert that list_tools was only called once, thanks to the lock + mock_session.list_tools.assert_called_once() + + # Assert that all results are the same and correct + assert len(results) == 5 + for result in results: + assert len(result) == 1 + assert result[0].name == "tool1" + + # Check that the first result is the same as the others + assert all(results[0][0].name == r[0].name for r in results[1:]) \ No newline at end of file