Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import base64
import logging
import sys
import time
from typing import Any
from typing import Awaitable
from typing import Callable
Expand Down Expand Up @@ -114,6 +115,8 @@ def __init__(
Union[ProgressFnT, ProgressCallbackFactory]
] = None,
use_mcp_resources: Optional[bool] = False,
cache: bool = False,
cache_ttl_seconds: Optional[int] = None,
):
"""Initializes the McpToolset.

Expand Down Expand Up @@ -150,6 +153,10 @@ def __init__(
use_mcp_resources: Whether the agent should have access to MCP resources.
This will add a `load_mcp_resource` tool to the toolset and include
available resources in the agent context. Defaults to False.
cache: If True, the toolset will cache the response from the
first `list_tools` call and reuse it for subsequent calls.
cache_ttl_seconds: If set, the in-memory cache will expire
after this many seconds.
"""

super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
Expand All @@ -170,6 +177,11 @@ def __init__(
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
self._cache = cache
self._cache_ttl_seconds = cache_ttl_seconds
self._cached_tool_response: Optional[ListToolsResult] = None
self._cache_creation_time: Optional[float] = None
self._cache_lock = asyncio.Lock()
# Store auth config as instance variable so ADK can populate
# exchanged_auth_credential in-place before calling get_tools()
self._auth_config: Optional[AuthConfig] = (
Expand Down Expand Up @@ -300,13 +312,40 @@ async def get_tools(
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
# Fetch available tools from the MCP server
tools_response: ListToolsResult = await self._execute_with_session(
lambda session: session.list_tools(),
"Failed to get tools from MCP server",
readonly_context,
)

def _is_cache_valid() -> bool:
if not self._cache or not self._cached_tool_response:
return False

if self._cache_ttl_seconds is None:
return True # No TTL set, consider cache always valid

if self._cache_creation_time is None:
# This should not happen in a well-initialized system
return False

elapsed = time.monotonic() - self._cache_creation_time
return elapsed <= self._cache_ttl_seconds

# First check without a lock for performance.
if _is_cache_valid():
tools_response = self._cached_tool_response
else:
# If cache is invalid, acquire lock to prevent stampede.
async with self._cache_lock:
# Double-check if cache was populated while waiting for the lock.
if _is_cache_valid():
tools_response = self._cached_tool_response
else:
fetched_tools: ListToolsResult = await self._execute_with_session(
lambda session: session.list_tools(),
"Failed to get tools from MCP server",
readonly_context,
)
if self._cache:
self._cached_tool_response = fetched_tools
self._cache_creation_time = time.monotonic()
tools_response = fetched_tools
# Apply filtering based on context and tool_filter
tools = []
for tool in tools_response.tools:
Expand Down Expand Up @@ -431,6 +470,7 @@ def from_config(
auth_credential=mcp_toolset_config.auth_credential,
use_mcp_resources=mcp_toolset_config.use_mcp_resources,
)



class MCPToolset(McpToolset):
Expand Down
212 changes: 212 additions & 0 deletions tests/unittests/tools/test_mcp_toolset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for McpToolset."""

import asyncio
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
import pytest


@pytest.mark.asyncio
async def test_mcp_toolset_with_prefix():
"""Test that McpToolset correctly applies the tool_name_prefix."""
# Mock the connection parameters
mock_connection_params = MagicMock()
mock_connection_params.timeout = None

# Mock the MCPSessionManager and its create_session method
mock_session_manager = MagicMock()
mock_session = MagicMock()

# Mock the list_tools response from the MCP server
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool1.description = "tool 1 desc"
mock_tool2 = MagicMock()
mock_tool2.name = "tool2"
mock_tool2.description = "tool 2 desc"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool1, mock_tool2]
mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Create an instance of McpToolset with a prefix
toolset = McpToolset(
connection_params=mock_connection_params,
tool_name_prefix="my_prefix",
)

# Replace the internal session manager with our mock
toolset._mcp_session_manager = mock_session_manager

# Get the tools from the toolset
tools = await toolset.get_tools()

# The get_tools method in McpToolset returns MCPTool objects, which are
# instances of BaseTool. The prefixing is handled by the BaseToolset,
# so we need to call get_tools_with_prefix to get the prefixed tools.
prefixed_tools = await toolset.get_tools_with_prefix()

# Assert that the tools are prefixed correctly
assert len(prefixed_tools) == 2
assert prefixed_tools[0].name == "my_prefix_tool1"
assert prefixed_tools[1].name == "my_prefix_tool2"

# Assert that the original tools are not modified
assert tools[0].name == "tool1"
assert tools[1].name == "tool2"


def _create_mock_session_manager():
"""Helper to create a mock MCPSessionManager."""
mock_session_manager = MagicMock()
mock_session = MagicMock()

mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool1.description = "tool 1 desc"
mock_tool2 = MagicMock()
mock_tool2.name = "tool2"
mock_tool2.description = "tool 2 desc"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool1, mock_tool2]

mock_session.list_tools = AsyncMock(return_value=list_tools_result)
mock_session_manager.create_session = AsyncMock(return_value=mock_session)
return mock_session_manager, mock_session


@pytest.mark.asyncio
async def test_mcp_toolset_cache_disabled():
"""Test that list_tools is called every time when cache is disabled."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(connection_params=mock_connection_params, cache=False)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
await toolset.get_tools()

assert mock_session.list_tools.call_count == 2


@pytest.mark.asyncio
async def test_mcp_toolset_cache_enabled():
"""Test that list_tools is called only once when cache is enabled."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(connection_params=mock_connection_params, cache=True)
toolset._mcp_session_manager = mock_session_manager

tools1 = await toolset.get_tools()
tools2 = await toolset.get_tools()

mock_session.list_tools.assert_called_once()
assert len(tools1) == 2
assert len(tools2) == 2
assert tools1[0].name == tools2[0].name


@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_not_expired():
"""Test that cache is used when TTL has not expired."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(
connection_params=mock_connection_params, cache=True, cache_ttl_seconds=10
)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
await toolset.get_tools()

mock_session.list_tools.assert_called_once()


@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_expired():
"""Test that list_tools is called again after TTL expires."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None
mock_session_manager, mock_session = _create_mock_session_manager()

toolset = McpToolset(
connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1
)
toolset._mcp_session_manager = mock_session_manager

await toolset.get_tools()
mock_session.list_tools.assert_called_once()

await asyncio.sleep(1.1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using asyncio.sleep() to test time-dependent logic can make tests slow and potentially flaky. A more robust approach is to mock the time source, time.monotonic. This gives you precise control over time in your test, making it faster and more reliable.

You can use mocker.patch from pytest-mock to do this. Here's an example of how you could rewrite this test:

@pytest.mark.asyncio
async def test_mcp_toolset_cache_with_ttl_expired(mocker):
    """Test that list_tools is called again after TTL expires."""
    mock_connection_params = MagicMock()
    mock_connection_params.timeout = None
    mock_session_manager, mock_session = _create_mock_session_manager()

    toolset = McpToolset(
        connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1
    )
    toolset._mcp_session_manager = mock_session_manager

    # Patch time.monotonic
    mock_time = mocker.patch('time.monotonic')

    # First call, populates cache
    mock_time.return_value = 1000.0
    await toolset.get_tools()
    mock_session.list_tools.assert_called_once()

    # Second call, after TTL expired
    mock_time.return_value = 1001.1  # More than 1 second later
    await toolset.get_tools()
    assert mock_session.list_tools.call_count == 2


await toolset.get_tools()
assert mock_session.list_tools.call_count == 2


@pytest.mark.asyncio
async def test_mcp_toolset_cache_concurrency():
"""Test that list_tools is called only once during concurrent requests."""
mock_connection_params = MagicMock()
mock_connection_params.timeout = None

# Create a mock session manager. Add a small delay to the mock call
# to simulate network latency and increase the chance of a race condition.
mock_session_manager = MagicMock()
mock_session = MagicMock()

mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool1.description = "tool 1 desc"
list_tools_result = MagicMock()
list_tools_result.tools = [mock_tool1]

async def delayed_list_tools():
await asyncio.sleep(0.1)
return list_tools_result

mock_session.list_tools = AsyncMock(side_effect=delayed_list_tools)
mock_session_manager.create_session = AsyncMock(return_value=mock_session)

# Initialize the toolset with caching enabled
toolset = McpToolset(connection_params=mock_connection_params, cache=True)
toolset._mcp_session_manager = mock_session_manager

# Create multiple concurrent tasks to call get_tools
tasks = [asyncio.create_task(toolset.get_tools()) for _ in range(5)]

# Run all tasks concurrently
results = await asyncio.gather(*tasks)

# Assert that list_tools was only called once, thanks to the lock
mock_session.list_tools.assert_called_once()

# Assert that all results are the same and correct
assert len(results) == 5
for result in results:
assert len(result) == 1
assert result[0].name == "tool1"

# Check that the first result is the same as the others
assert all(results[0][0].name == r[0].name for r in results[1:])