Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ class MCPConfig(TypedDict):
default_tool_error_function.
"""

include_server_in_tool_names: NotRequired[bool]
"""If True, MCP tools are exposed with an unambiguous server-specific prefix to avoid
collisions across servers that publish the same tool names. Defaults to False.
"""


@dataclass
class AgentBase(Generic[TContext]):
Expand Down Expand Up @@ -186,12 +191,14 @@ async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[
failure_error_function = self.mcp_config.get(
"failure_error_function", default_tool_error_function
)
include_server_in_tool_names = self.mcp_config.get("include_server_in_tool_names", False)
return await MCPUtil.get_all_function_tools(
self.mcp_servers,
convert_schemas_to_strict,
run_context,
self,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
)

async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
Expand Down
84 changes: 73 additions & 11 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import copy
import functools
import hashlib
import inspect
import json
from collections.abc import Awaitable
Expand Down Expand Up @@ -207,17 +208,23 @@ async def get_all_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
tool_names: set[str] = set()
server_tool_name_prefixes = (
cls._server_tool_name_prefixes(servers) if include_server_in_tool_names else {}
)
for server in servers:
server_tools = await cls.get_function_tools(
server,
convert_schemas_to_strict,
run_context,
agent,
failure_error_function=failure_error_function,
include_server_in_tool_names=include_server_in_tool_names,
tool_name_prefix=server_tool_name_prefixes.get(id(server)),
)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
Expand All @@ -238,24 +245,73 @@ async def get_function_tools(
run_context: RunContextWrapper[Any],
agent: AgentBase,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
include_server_in_tool_names: bool = False,
tool_name_prefix: str | None = None,
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools(run_context, agent)
span.span_data.result = [tool.name for tool in tools]

if tool_name_prefix is None:
tool_name_prefix = (
cls._server_tool_name_prefix(server.name) if include_server_in_tool_names else ""
)
return [
cls.to_function_tool(
tool,
server,
convert_schemas_to_strict,
agent,
failure_error_function=failure_error_function,
tool_name_override=(
cls._prefixed_tool_name(tool_name_prefix, tool.name)
if tool_name_prefix
else None
),
)
for tool in tools
]

@staticmethod
def _server_tool_name_prefix(server_name: str) -> str:
normalized = "".join(
char if char.isalnum() or char in ("_", "-") else "_" for char in server_name
)
normalized = normalized.strip("_-")
if not normalized:
normalized = "server"
return f"{normalized}_"

@staticmethod
def _prefixed_tool_name(tool_name_prefix: str, tool_name: str) -> str:
return f"mcp_{len(tool_name_prefix)}_{tool_name_prefix}{tool_name}"

@classmethod
def _server_tool_name_prefixes(cls, servers: list[MCPServer]) -> dict[int, str]:
normalized_to_servers: dict[str, list[MCPServer]] = {}
for server in servers:
normalized_prefix = cls._server_tool_name_prefix(server.name)[:-1]
normalized_to_servers.setdefault(normalized_prefix, []).append(server)

prefixes: dict[int, str] = {}
for normalized_prefix, grouped_servers in normalized_to_servers.items():
if len(grouped_servers) == 1:
prefixes[id(grouped_servers[0])] = f"{normalized_prefix}_"
continue

seen_prefixes: set[str] = set()
for index, server in enumerate(grouped_servers, start=1):
hash_suffix = hashlib.sha1(server.name.encode("utf-8")).hexdigest()[:8]
prefix = f"{normalized_prefix}_{hash_suffix}"
if prefix in seen_prefixes:
prefix = f"{prefix}_{index}"
seen_prefixes.add(prefix)
prefixes[id(server)] = f"{prefix}_"

return prefixes

@classmethod
def to_function_tool(
cls,
Expand All @@ -264,6 +320,7 @@ def to_function_tool(
convert_schemas_to_strict: bool,
agent: AgentBase | None = None,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
tool_name_override: str | None = None,
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool.

Expand All @@ -273,11 +330,13 @@ def to_function_tool(
policies. If the server uses a callable approval policy, approvals default
to required to avoid bypassing dynamic checks.
"""
tool_name = tool_name_override or tool.name
static_meta = cls._extract_static_meta(tool)
invoke_func_impl = functools.partial(
cls.invoke_mcp_tool,
server,
tool,
tool_display_name=tool_name,
meta=static_meta,
)
effective_failure_error_function = server._get_failure_error_function(
Expand All @@ -301,7 +360,7 @@ def to_function_tool(
) = server._get_needs_approval_for_tool(tool, agent)

function_tool = _build_wrapped_function_tool(
name=tool.name,
name=tool_name,
description=resolve_mcp_tool_description_for_model(tool),
params_json_schema=schema,
invoke_tool_impl=invoke_func_impl,
Expand Down Expand Up @@ -367,25 +426,28 @@ async def invoke_mcp_tool(
input_json: str,
*,
meta: dict[str, Any] | None = None,
tool_display_name: str | None = None,
) -> ToolOutput:
"""Invoke an MCP tool and return the result as ToolOutput."""
tool_name = tool_display_name or tool.name
try:
json_data: dict[str, Any] = json.loads(input_json) if input_json else {}
except Exception as e:
if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invalid JSON input for tool {tool.name}")
logger.debug(f"Invalid JSON input for tool {tool_name}")
else:
logger.debug(f"Invalid JSON input for tool {tool.name}: {input_json}")
logger.debug(f"Invalid JSON input for tool {tool_name}: {input_json}")
raise ModelBehaviorError(
f"Invalid JSON input for tool {tool.name}: {input_json}"
f"Invalid JSON input for tool {tool_name}: {input_json}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"Invoking MCP tool {tool.name}")
logger.debug(f"Invoking MCP tool {tool_name}")
else:
logger.debug(f"Invoking MCP tool {tool.name} with input {input_json}")
logger.debug(f"Invoking MCP tool {tool_name} with input {input_json}")

try:
# Keep meta resolution and server routing keyed by the original MCP tool name.
resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data)
merged_meta = cls._merge_mcp_meta(resolved_meta, meta)
call_task = asyncio.create_task(
Expand Down Expand Up @@ -423,20 +485,20 @@ async def invoke_mcp_tool(
# failure_error_function=None will have the error raised as documented.
error_text = e.error.message if hasattr(e, "error") and e.error else str(e)
logger.warning(
f"MCP tool {tool.name} on server '{server.name}' returned an error: "
f"MCP tool {tool_name} on server '{server.name}' returned an error: "
f"{error_text}"
)
raise

logger.error(f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}")
logger.error(f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}")
raise AgentsException(
f"Error invoking MCP tool {tool.name} on server '{server.name}': {e}"
f"Error invoking MCP tool {tool_name} on server '{server.name}': {e}"
) from e

if _debug.DONT_LOG_TOOL_DATA:
logger.debug(f"MCP tool {tool.name} completed.")
logger.debug(f"MCP tool {tool_name} completed.")
else:
logger.debug(f"MCP tool {tool.name} returned {result}")
logger.debug(f"MCP tool {tool_name} returned {result}")

# If structured content is requested and available, use it exclusively
tool_output: ToolOutput
Expand Down
Loading