diff --git a/src/agents/agent.py b/src/agents/agent.py index 5d700ebaa3..0b81d410fd 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -150,6 +150,11 @@ class MCPConfig(TypedDict): default_tool_error_function. """ + include_server_in_tool_names: NotRequired[bool] + """If True, MCP tools are exposed with an unambiguous server-specific prefix to avoid + collisions across servers that publish the same tool names. Defaults to False. + """ + @dataclass class AgentBase(Generic[TContext]): @@ -186,12 +191,14 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[ failure_error_function = self.mcp_config.get( "failure_error_function", default_tool_error_function ) + include_server_in_tool_names = self.mcp_config.get("include_server_in_tool_names", False) return await MCPUtil.get_all_function_tools( self.mcp_servers, convert_schemas_to_strict, run_context, self, failure_error_function=failure_error_function, + include_server_in_tool_names=include_server_in_tool_names, ) async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 33bea065c5..0075ecaf2d 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -3,6 +3,7 @@ import asyncio import copy import functools +import hashlib import inspect import json from collections.abc import Awaitable @@ -207,10 +208,14 @@ async def get_all_function_tools( run_context: RunContextWrapper[Any], agent: AgentBase, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + include_server_in_tool_names: bool = False, ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() + server_tool_name_prefixes = ( + cls._server_tool_name_prefixes(servers) if include_server_in_tool_names else {} + ) for server in servers: server_tools = await cls.get_function_tools( server, @@ -218,6 +223,8 @@ async def get_all_function_tools( run_context, agent, failure_error_function=failure_error_function, + include_server_in_tool_names=include_server_in_tool_names, + tool_name_prefix=server_tool_name_prefixes.get(id(server)), ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: @@ -238,6 +245,8 @@ async def get_function_tools( run_context: RunContextWrapper[Any], agent: AgentBase, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + include_server_in_tool_names: bool = False, + tool_name_prefix: str | None = None, ) -> list[Tool]: """Get all function tools from a single MCP server.""" @@ -245,6 +254,10 @@ async def get_function_tools( tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] + if tool_name_prefix is None: + tool_name_prefix = ( + cls._server_tool_name_prefix(server.name) if include_server_in_tool_names else "" + ) return [ cls.to_function_tool( tool, @@ -252,10 +265,53 @@ async def get_function_tools( convert_schemas_to_strict, agent, failure_error_function=failure_error_function, + tool_name_override=( + cls._prefixed_tool_name(tool_name_prefix, tool.name) + if tool_name_prefix + else None + ), ) for tool in tools ] + @staticmethod + def _server_tool_name_prefix(server_name: str) -> str: + normalized = "".join( + char if char.isalnum() or char in ("_", "-") else "_" for char in server_name + ) + normalized = normalized.strip("_-") + if not normalized: + normalized = "server" + return f"{normalized}_" + + @staticmethod + def _prefixed_tool_name(tool_name_prefix: str, tool_name: str) -> str: + return f"mcp_{len(tool_name_prefix)}_{tool_name_prefix}{tool_name}" + + @classmethod + def _server_tool_name_prefixes(cls, servers: list[MCPServer]) -> dict[int, str]: + normalized_to_servers: dict[str, list[MCPServer]] = {} + for server in servers: + normalized_prefix = cls._server_tool_name_prefix(server.name)[:-1] + normalized_to_servers.setdefault(normalized_prefix, []).append(server) + + prefixes: dict[int, str] = {} + for normalized_prefix, grouped_servers in normalized_to_servers.items(): + if len(grouped_servers) == 1: + prefixes[id(grouped_servers[0])] = f"{normalized_prefix}_" + continue + + seen_prefixes: set[str] = set() + for index, server in enumerate(grouped_servers, start=1): + hash_suffix = hashlib.sha1(server.name.encode("utf-8")).hexdigest()[:8] + prefix = f"{normalized_prefix}_{hash_suffix}" + if prefix in seen_prefixes: + prefix = f"{prefix}_{index}" + seen_prefixes.add(prefix) + prefixes[id(server)] = f"{prefix}_" + + return prefixes + @classmethod def to_function_tool( cls, @@ -264,6 +320,7 @@ def to_function_tool( convert_schemas_to_strict: bool, agent: AgentBase | None = None, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + tool_name_override: str | None = None, ) -> FunctionTool: """Convert an MCP tool to an Agents SDK function tool. @@ -273,11 +330,13 @@ def to_function_tool( policies. If the server uses a callable approval policy, approvals default to required to avoid bypassing dynamic checks. """ + tool_name = tool_name_override or tool.name static_meta = cls._extract_static_meta(tool) invoke_func_impl = functools.partial( cls.invoke_mcp_tool, server, tool, + tool_display_name=tool_name, meta=static_meta, ) effective_failure_error_function = server._get_failure_error_function( @@ -301,7 +360,7 @@ def to_function_tool( ) = server._get_needs_approval_for_tool(tool, agent) function_tool = _build_wrapped_function_tool( - name=tool.name, + name=tool_name, description=resolve_mcp_tool_description_for_model(tool), params_json_schema=schema, invoke_tool_impl=invoke_func_impl, @@ -367,25 +426,28 @@ async def invoke_mcp_tool( input_json: str, *, meta: dict[str, Any] | None = None, + tool_display_name: str | None = None, ) -> ToolOutput: """Invoke an MCP tool and return the result as ToolOutput.""" + tool_name = tool_display_name or tool.name try: json_data: dict[str, Any] = json.loads(input_json) if input_json else {} except Exception as e: if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"Invalid JSON input for tool {tool.name}") + logger.debug(f"Invalid JSON input for tool {tool_name}") else: - logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}") + logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}") raise ModelBehaviorError( - f"Invalid JSON input for tool {tool.name}: {input_json}" + f"Invalid JSON input for tool {tool_name}: {input_json}" ) from e if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"Invoking MCP tool {tool.name}") + logger.debug(f"Invoking MCP tool {tool_name}") else: - logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}") + logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}") try: + # Keep meta resolution and server routing keyed by the original MCP tool name. resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data) merged_meta = cls._merge_mcp_meta(resolved_meta, meta) call_task = asyncio.create_task( @@ -423,20 +485,20 @@ async def invoke_mcp_tool( # failure_error_function=None will have the error raised as documented. error_text = e.error.message if hasattr(e, "error") and e.error else str(e) logger.warning( - f"MCP tool {tool.name} on server '{server.name}' returned an error: " + f"MCP tool {tool_name} on server '{server.name}' returned an error: " f"{error_text}" ) raise - logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}") + logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}") raise AgentsException( - f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}" + f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}" ) from e if _debug.DONT_LOG_TOOL_DATA: - logger.debug(f"MCP tool {tool.name} completed.") + logger.debug(f"MCP tool {tool_name} completed.") else: - logger.debug(f"MCP tool {tool.name} returned {result}") + logger.debug(f"MCP tool {tool_name} returned {result}") # If structured content is requested and available, use it exclusively tool_output: ToolOutput diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index 5a9cbd140c..28da2545b6 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -10,7 +10,12 @@ from pydantic import BaseModel, TypeAdapter from agents import Agent, FunctionTool, RunContextWrapper, default_tool_error_function -from agents.exceptions import AgentsException, MCPToolCancellationError, ModelBehaviorError +from agents.exceptions import ( + AgentsException, + MCPToolCancellationError, + ModelBehaviorError, + UserError, +) from agents.mcp import MCPServer, MCPUtil from agents.tool_context import ToolContext @@ -82,6 +87,136 @@ async def test_get_all_function_tools(): assert all(tool.name in names for tool in tools) +@pytest.mark.asyncio +async def test_get_all_function_tools_duplicate_names_raise_by_default(): + server1 = FakeMCPServer(server_name="github") + server1.add_tool("create_issue", {}) + + server2 = FakeMCPServer(server_name="linear") + server2.add_tool("create_issue", {}) + + run_context = RunContextWrapper(context=None) + agent = Agent(name="test_agent", instructions="Test agent") + + with pytest.raises(UserError, match="Duplicate tool names found across MCP servers"): + await MCPUtil.get_all_function_tools([server1, server2], False, run_context, agent) + + +@pytest.mark.asyncio +async def test_get_all_function_tools_can_prefix_with_server_name(): + server1 = FakeMCPServer(server_name="GitHub MCP Server") + server1.add_tool("create_issue", {}) + + server2 = FakeMCPServer(server_name="linear") + server2.add_tool("create_issue", {}) + + run_context = RunContextWrapper(context=None) + agent = Agent( + name="test_agent", + instructions="Test agent", + mcp_servers=[server1, server2], + mcp_config={"include_server_in_tool_names": True}, + ) + + tools = await agent.get_mcp_tools(run_context) + tool_names = {tool.name for tool in tools} + assert tool_names == {"mcp_18_GitHub_MCP_Server_create_issue", "mcp_7_linear_create_issue"} + + github_tool = next( + tool for tool in tools if tool.name == "mcp_18_GitHub_MCP_Server_create_issue" + ) + linear_tool = next(tool for tool in tools if tool.name == "mcp_7_linear_create_issue") + assert isinstance(github_tool, FunctionTool) + assert isinstance(linear_tool, FunctionTool) + + github_ctx = ToolContext( + context=None, + tool_name=github_tool.name, + tool_call_id="prefixed_call_1", + tool_arguments='{"title":"a"}', + ) + linear_ctx = ToolContext( + context=None, + tool_name=linear_tool.name, + tool_call_id="prefixed_call_2", + tool_arguments='{"title":"b"}', + ) + + github_result = await github_tool.on_invoke_tool(github_ctx, '{"title":"a"}') + linear_result = await linear_tool.on_invoke_tool(linear_ctx, '{"title":"b"}') + assert isinstance(github_result, dict) + assert isinstance(linear_result, dict) + assert server1.tool_calls == ["create_issue"] + assert server2.tool_calls == ["create_issue"] + + +@pytest.mark.asyncio +async def test_get_all_function_tools_prefix_falls_back_for_empty_server_name_slug(): + server = FakeMCPServer(server_name="!!!") + server.add_tool("search", {}) + + run_context = RunContextWrapper(context=None) + agent = Agent( + name="test_agent", + instructions="Test agent", + mcp_servers=[server], + mcp_config={"include_server_in_tool_names": True}, + ) + + tools = await agent.get_mcp_tools(run_context) + assert len(tools) == 1 + prefixed_tool = tools[0] + assert isinstance(prefixed_tool, FunctionTool) + assert prefixed_tool.name == "mcp_7_server_search" + + tool_context = ToolContext( + context=None, + tool_name=prefixed_tool.name, + tool_call_id="prefixed_call_3", + tool_arguments='{"query":"docs"}', + ) + result = await prefixed_tool.on_invoke_tool(tool_context, '{"query":"docs"}') + assert isinstance(result, dict) + assert server.tool_calls == ["search"] + + +@pytest.mark.asyncio +async def test_get_all_function_tools_disambiguates_colliding_server_name_prefixes(): + server1 = FakeMCPServer(server_name="GitHub MCP Server") + server1.add_tool("create_issue", {}) + + server2 = FakeMCPServer(server_name="GitHub_MCP_Server") + server2.add_tool("create_issue", {}) + + run_context = RunContextWrapper(context=None) + agent = Agent( + name="test_agent", + instructions="Test agent", + mcp_servers=[server1, server2], + mcp_config={"include_server_in_tool_names": True}, + ) + + tools = await agent.get_mcp_tools(run_context) + tool_names = {tool.name for tool in tools} + assert len(tool_names) == 2 + assert all(tool_name.startswith("mcp_27_GitHub_MCP_Server_") for tool_name in tool_names) + assert all(tool_name.endswith("_create_issue") for tool_name in tool_names) + + for idx, tool in enumerate(tools, start=1): + assert isinstance(tool, FunctionTool) + tool_context = ToolContext( + context=None, + tool_name=tool.name, + tool_call_id=f"prefixed_collision_{idx}", + tool_arguments='{"title":"collision"}', + ) + result = await tool.on_invoke_tool(tool_context, '{"title":"collision"}') + assert isinstance(result, dict) + + assert server1.tool_calls == ["create_issue"] + assert server2.tool_calls == ["create_issue"] + + @pytest.mark.asyncio async def test_invoke_mcp_tool(): """Test that the invoke_mcp_tool function invokes an MCP tool and returns the result.""" @@ -127,6 +262,48 @@ def resolve_meta(context): assert captured["arguments"] == {} +@pytest.mark.asyncio +async def test_mcp_meta_resolver_uses_original_tool_name_with_prefixed_display_name(): + captured: dict[str, Any] = {} + + def resolve_meta(context): + captured["tool_name"] = context.tool_name + return {"scope": "meta"} + + server = FakeMCPServer( + server_name="GitHub MCP Server", + tool_meta_resolver=resolve_meta, + ) + server.add_tool("create_issue", {}) + + run_context = RunContextWrapper(context=None) + agent = Agent( + name="test_agent", + instructions="Test agent", + mcp_servers=[server], + mcp_config={"include_server_in_tool_names": True}, + ) + + tools = await agent.get_mcp_tools(run_context) + assert len(tools) == 1 + + prefixed_tool = tools[0] + assert isinstance(prefixed_tool, FunctionTool) + assert prefixed_tool.name == "mcp_18_GitHub_MCP_Server_create_issue" + + tool_context = ToolContext( + context=None, + tool_name=prefixed_tool.name, + tool_call_id="prefixed_call_meta_1", + tool_arguments='{"title":"a"}', + ) + await prefixed_tool.on_invoke_tool(tool_context, '{"title":"a"}') + + assert captured["tool_name"] == "create_issue" + assert server.tool_calls == ["create_issue"] + assert server.tool_metas[-1] == {"scope": "meta"} + + @pytest.mark.asyncio async def test_mcp_meta_resolver_does_not_mutate_arguments(): def resolve_meta(context): @@ -547,6 +724,40 @@ async def test_mcp_tool_graceful_error_handling(caplog: pytest.LogCaptureFixture ) +@pytest.mark.asyncio +async def test_mcp_tool_failure_logs_prefixed_name_when_tool_data_logging_enabled( + caplog: pytest.LogCaptureFixture, monkeypatch: pytest.MonkeyPatch +): + import agents._debug as debug_settings + + caplog.set_level(logging.ERROR) + monkeypatch.setattr(debug_settings, "DONT_LOG_TOOL_DATA", False) + + server = CrashingFakeMCPServer() + server.add_tool("crashing_tool", {}) + + mcp_tool = MCPTool(name="crashing_tool", inputSchema={}) + agent = Agent(name="test-agent") + function_tool = MCPUtil.to_function_tool( + mcp_tool, + server, + convert_schemas_to_strict=False, + agent=agent, + tool_name_override="prefixed_crashing_tool", + ) + + tool_context = ToolContext( + context=None, + tool_name="prefixed_crashing_tool", + tool_call_id="test_call_prefixed_log", + tool_arguments="{}", + ) + result = await function_tool.on_invoke_tool(tool_context, "{}") + + assert isinstance(result, str) + assert "MCP tool prefixed_crashing_tool failed" in caplog.text + + @pytest.mark.asyncio async def test_mcp_tool_timeout_handling(): """Test that MCP tool timeouts are handled gracefully.