-
Notifications
You must be signed in to change notification settings - Fork 2.9k
feat: Add in-memory cache with TTL support #3831
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
peaaceChoi
wants to merge
6
commits into
google:main
Choose a base branch
from
peaaceChoi:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+258
−6
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
52572e7
feat: Add in-memory cache with TTL support
peaaceChoi 1b3d370
Add test case
peaaceChoi e219590
fix: Prevent cache stampede in McpToolset
peaaceChoi 08c6145
Merge branch 'main' into main
peaaceChoi 39422d0
chore: rebase main branch with upstream
peaaceChoi 0682f6f
Merge branch 'main' of https://github.com/peaaceChoi/adk-python
peaaceChoi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| # Copyright 2025 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Unit tests for McpToolset.""" | ||
|
|
||
| import asyncio | ||
| from unittest.mock import AsyncMock | ||
| from unittest.mock import MagicMock | ||
|
|
||
| from google.adk.tools.mcp_tool.mcp_toolset import McpToolset | ||
| import pytest | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_with_prefix(): | ||
| """Test that McpToolset correctly applies the tool_name_prefix.""" | ||
| # Mock the connection parameters | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
|
|
||
| # Mock the MCPSessionManager and its create_session method | ||
| mock_session_manager = MagicMock() | ||
| mock_session = MagicMock() | ||
|
|
||
| # Mock the list_tools response from the MCP server | ||
| mock_tool1 = MagicMock() | ||
| mock_tool1.name = "tool1" | ||
| mock_tool1.description = "tool 1 desc" | ||
| mock_tool2 = MagicMock() | ||
| mock_tool2.name = "tool2" | ||
| mock_tool2.description = "tool 2 desc" | ||
| list_tools_result = MagicMock() | ||
| list_tools_result.tools = [mock_tool1, mock_tool2] | ||
| mock_session.list_tools = AsyncMock(return_value=list_tools_result) | ||
| mock_session_manager.create_session = AsyncMock(return_value=mock_session) | ||
|
|
||
| # Create an instance of McpToolset with a prefix | ||
| toolset = McpToolset( | ||
| connection_params=mock_connection_params, | ||
| tool_name_prefix="my_prefix", | ||
| ) | ||
|
|
||
| # Replace the internal session manager with our mock | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| # Get the tools from the toolset | ||
| tools = await toolset.get_tools() | ||
|
|
||
| # The get_tools method in McpToolset returns MCPTool objects, which are | ||
| # instances of BaseTool. The prefixing is handled by the BaseToolset, | ||
| # so we need to call get_tools_with_prefix to get the prefixed tools. | ||
| prefixed_tools = await toolset.get_tools_with_prefix() | ||
|
|
||
| # Assert that the tools are prefixed correctly | ||
| assert len(prefixed_tools) == 2 | ||
| assert prefixed_tools[0].name == "my_prefix_tool1" | ||
| assert prefixed_tools[1].name == "my_prefix_tool2" | ||
|
|
||
| # Assert that the original tools are not modified | ||
| assert tools[0].name == "tool1" | ||
| assert tools[1].name == "tool2" | ||
|
|
||
|
|
||
| def _create_mock_session_manager(): | ||
| """Helper to create a mock MCPSessionManager.""" | ||
| mock_session_manager = MagicMock() | ||
| mock_session = MagicMock() | ||
|
|
||
| mock_tool1 = MagicMock() | ||
| mock_tool1.name = "tool1" | ||
| mock_tool1.description = "tool 1 desc" | ||
| mock_tool2 = MagicMock() | ||
| mock_tool2.name = "tool2" | ||
| mock_tool2.description = "tool 2 desc" | ||
| list_tools_result = MagicMock() | ||
| list_tools_result.tools = [mock_tool1, mock_tool2] | ||
|
|
||
| mock_session.list_tools = AsyncMock(return_value=list_tools_result) | ||
| mock_session_manager.create_session = AsyncMock(return_value=mock_session) | ||
| return mock_session_manager, mock_session | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_cache_disabled(): | ||
| """Test that list_tools is called every time when cache is disabled.""" | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
| mock_session_manager, mock_session = _create_mock_session_manager() | ||
|
|
||
| toolset = McpToolset(connection_params=mock_connection_params, cache=False) | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| await toolset.get_tools() | ||
| await toolset.get_tools() | ||
|
|
||
| assert mock_session.list_tools.call_count == 2 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_cache_enabled(): | ||
| """Test that list_tools is called only once when cache is enabled.""" | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
| mock_session_manager, mock_session = _create_mock_session_manager() | ||
|
|
||
| toolset = McpToolset(connection_params=mock_connection_params, cache=True) | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| tools1 = await toolset.get_tools() | ||
| tools2 = await toolset.get_tools() | ||
|
|
||
| mock_session.list_tools.assert_called_once() | ||
| assert len(tools1) == 2 | ||
| assert len(tools2) == 2 | ||
| assert tools1[0].name == tools2[0].name | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_cache_with_ttl_not_expired(): | ||
| """Test that cache is used when TTL has not expired.""" | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
| mock_session_manager, mock_session = _create_mock_session_manager() | ||
|
|
||
| toolset = McpToolset( | ||
| connection_params=mock_connection_params, cache=True, cache_ttl_seconds=10 | ||
| ) | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| await toolset.get_tools() | ||
| await toolset.get_tools() | ||
|
|
||
| mock_session.list_tools.assert_called_once() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_cache_with_ttl_expired(): | ||
| """Test that list_tools is called again after TTL expires.""" | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
| mock_session_manager, mock_session = _create_mock_session_manager() | ||
|
|
||
| toolset = McpToolset( | ||
| connection_params=mock_connection_params, cache=True, cache_ttl_seconds=1 | ||
| ) | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| await toolset.get_tools() | ||
| mock_session.list_tools.assert_called_once() | ||
|
|
||
| await asyncio.sleep(1.1) | ||
|
|
||
| await toolset.get_tools() | ||
| assert mock_session.list_tools.call_count == 2 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_mcp_toolset_cache_concurrency(): | ||
| """Test that list_tools is called only once during concurrent requests.""" | ||
| mock_connection_params = MagicMock() | ||
| mock_connection_params.timeout = None | ||
|
|
||
| # Create a mock session manager. Add a small delay to the mock call | ||
| # to simulate network latency and increase the chance of a race condition. | ||
| mock_session_manager = MagicMock() | ||
| mock_session = MagicMock() | ||
|
|
||
| mock_tool1 = MagicMock() | ||
| mock_tool1.name = "tool1" | ||
| mock_tool1.description = "tool 1 desc" | ||
| list_tools_result = MagicMock() | ||
| list_tools_result.tools = [mock_tool1] | ||
|
|
||
| async def delayed_list_tools(): | ||
| await asyncio.sleep(0.1) | ||
| return list_tools_result | ||
|
|
||
| mock_session.list_tools = AsyncMock(side_effect=delayed_list_tools) | ||
| mock_session_manager.create_session = AsyncMock(return_value=mock_session) | ||
|
|
||
| # Initialize the toolset with caching enabled | ||
| toolset = McpToolset(connection_params=mock_connection_params, cache=True) | ||
| toolset._mcp_session_manager = mock_session_manager | ||
|
|
||
| # Create multiple concurrent tasks to call get_tools | ||
| tasks = [asyncio.create_task(toolset.get_tools()) for _ in range(5)] | ||
|
|
||
| # Run all tasks concurrently | ||
| results = await asyncio.gather(*tasks) | ||
|
|
||
| # Assert that list_tools was only called once, thanks to the lock | ||
| mock_session.list_tools.assert_called_once() | ||
|
|
||
| # Assert that all results are the same and correct | ||
| assert len(results) == 5 | ||
| for result in results: | ||
| assert len(result) == 1 | ||
| assert result[0].name == "tool1" | ||
|
|
||
| # Check that the first result is the same as the others | ||
| assert all(results[0][0].name == r[0].name for r in results[1:]) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
asyncio.sleep()to test time-dependent logic can make tests slow and potentially flaky. A more robust approach is to mock the time source,time.monotonic. This gives you precise control over time in your test, making it faster and more reliable.You can use
mocker.patchfrompytest-mockto do this. Here's an example of how you could rewrite this test: