diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 9dfb29932f..fea224301b 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -1076,12 +1076,17 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: raise ToolExecutionException( "Tools are not loaded for this server, please set load_tools=True in the constructor." ) + # Extract user-supplied _meta before filtering so it is forwarded as + # MCP request metadata rather than as a tool argument. + user_meta = kwargs.pop("_meta", None) + # Filter out framework kwargs that cannot be serialized by the MCP SDK. # These are internal objects passed through the function invocation pipeline # that should not be forwarded to external MCP servers. # conversation_id is an internal tracking ID used by services like Azure AI. # options contains metadata/store used by AG-UI for Azure AI client requirements. # response_format is a Pydantic model class used for structured output (not serializable). + # _meta is handled separately and merged into the MCP request meta field. filtered_kwargs = { k: v for k, v in kwargs.items() @@ -1095,11 +1100,13 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: "conversation_id", "options", "response_format", + "_meta", } } - # Inject OpenTelemetry trace context into MCP _meta for distributed tracing. - otel_meta = _inject_otel_into_mcp_meta() + # Merge user-supplied _meta with OpenTelemetry trace context. + # User keys take precedence; OTel keys fill in non-conflicting slots. 
+ otel_meta = _inject_otel_into_mcp_meta(dict(user_meta) if user_meta else None) parser = self.parse_tool_results or self._parse_tool_result_from_mcp # Try the operation, reconnecting once if the connection is closed diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 01cf1717bd..6a0995c0f3 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -3804,6 +3804,388 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: assert meta is None +async def test_mcp_tool_call_tool_user_meta_propagated(): + """User-supplied _meta in kwargs is forwarded to session.call_tool as meta.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + from opentelemetry import trace + + # Use an invalid span so OTel does not inject anything — isolate user meta. 
+ with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="hello", _meta={"session_id": "abc", "locale": "fr-FR"}) + + call_kwargs = server.session.call_tool.call_args + meta = call_kwargs.kwargs.get("meta") + arguments = call_kwargs.kwargs.get("arguments") + + # _meta forwarded as meta + assert meta == {"session_id": "abc", "locale": "fr-FR"} + # _meta not leaked into arguments + assert "_meta" not in arguments + # regular arg preserved + assert arguments["param"] == "hello" + + +async def test_mcp_tool_call_tool_user_meta_merged_with_otel(span_exporter): + """User-supplied _meta is merged with OTel context; user keys win on conflicts.""" + from opentelemetry import trace + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("merge_span"): + await server.call_tool("test_tool", param="v", _meta={"custom_key": "custom_val"}) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + assert meta is not None + # User key present + assert meta["custom_key"] == "custom_val" + # OTel keys also present (at least one propagation header) + assert len(meta) > 1 + + +async def test_mcp_tool_call_tool_no_meta_backward_compat(): + """Without _meta, behavior is unchanged (OTel-only or None).""" + + class TestServer(MCPTool): + async def 
connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + from opentelemetry import trace + + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="hello") + + meta = server.session.call_tool.call_args.kwargs.get("meta") + # No _meta supplied and no active span → meta should be None + assert meta is None + + +async def test_mcp_tool_call_tool_meta_not_mutated(): + """The original _meta dict passed by the caller must not be mutated.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + from opentelemetry import trace + + original_meta = {"session_id": "abc"} + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("mutate_span"): + await server.call_tool("test_tool", param="v", _meta=original_meta) + + # Original dict should 
not have been modified by OTel injection + assert original_meta == {"session_id": "abc"} + + +async def test_mcp_streamable_http_tool_call_tool_meta_forwarded(): + """_meta flows through MCPStreamableHTTPTool.call_tool() override to super().""" + + server = MCPStreamableHTTPTool(name="test", url="http://example.com/mcp") + server.session = Mock(spec=ClientSession) + server.is_connected = True + server.load_tools_flag = True + server._tools_loaded = True + server.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + server.session.send_ping = AsyncMock() + + from opentelemetry import trace + + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="v", _meta={"req_id": "r1"}) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + arguments = server.session.call_tool.call_args.kwargs.get("arguments") + assert meta == {"req_id": "r1"} + assert "_meta" not in arguments + + +async def test_mcp_tool_call_tool_user_meta_wins_on_otel_conflict(span_exporter): + """When user _meta has a key that OTel would also inject, user value takes precedence.""" + from opentelemetry import trace + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + tracer = trace.get_tracer("test") + with tracer.start_as_current_span("conflict_span"): + # 
"traceparent" is the W3C key that OTel injects + await server.call_tool("test_tool", param="v", _meta={"traceparent": "user-override"}) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + assert meta is not None + # User value must win because _inject_otel_into_mcp_meta skips keys already present + assert meta["traceparent"] == "user-override" + + +async def test_mcp_tool_call_tool_empty_meta_dict(): + """Passing _meta={} should behave like no user meta (OTel-only if active, else None).""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + + from opentelemetry import trace + + # Empty dict with no active span → empty dict is falsy, so treated like None + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="v", _meta={}) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + # Empty dict is falsy, so user_meta is treated as None → no meta injected + assert meta is None + + arguments = server.session.call_tool.call_args.kwargs.get("arguments") + assert "_meta" not in arguments + + +async def test_mcp_tool_call_tool_meta_via_function_tool_wrapper(): + """_meta flows end-to-end through the FunctionTool wrapper created by load_tools.""" + + class TestServer(MCPTool): + async def connect(self): + self.session = Mock(spec=ClientSession) + self.session.list_tools = AsyncMock( + 
return_value=types.ListToolsResult( + tools=[ + types.Tool( + name="test_tool", + description="Test tool", + inputSchema={ + "type": "object", + "properties": {"param": {"type": "string"}}, + }, + ) + ] + ) + ) + self.session.call_tool = AsyncMock( + return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + + def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: + return None + + server = TestServer(name="test_server") + async with server: + await server.load_tools() + func = server.functions[0] + + from opentelemetry import trace + + # Invoke through FunctionTool with _meta in the FunctionInvocationContext.kwargs, + # which is how it arrives from function_invocation_kwargs in agent.run(). + ctx = FunctionInvocationContext( + function=func, + arguments={"param": "hello"}, + kwargs={"_meta": {"correlation_id": "c1"}}, + ) + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await func.invoke(arguments={"param": "hello"}, context=ctx) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + arguments = server.session.call_tool.call_args.kwargs.get("arguments") + assert meta is not None + # User key forwarded + assert meta["correlation_id"] == "c1" + assert "_meta" not in arguments + assert arguments["param"] == "hello" + + +async def test_mcp_streamable_http_tool_meta_with_header_provider(): + """_meta and header_provider work together without interference.""" + captured_headers: dict[str, str] = {} + + def header_provider(kwargs: dict) -> dict[str, str]: + captured_headers.update({"Authorization": f"Bearer {kwargs.get('token', '')}"}) + return captured_headers + + server = MCPStreamableHTTPTool( + name="test", + url="http://example.com/mcp", + header_provider=header_provider, + ) + server.session = Mock(spec=ClientSession) + server.is_connected = True + server.load_tools_flag = True + server._tools_loaded = True + server.session.call_tool = AsyncMock( + 
+ return_value=types.CallToolResult(content=[types.TextContent(type="text", text="ok")]) + ) + server.session.send_ping = AsyncMock() + + from opentelemetry import trace + + with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)): + await server.call_tool("test_tool", param="v", token="sk-xxx", _meta={"req_id": "r2"}) + + meta = server.session.call_tool.call_args.kwargs.get("meta") + arguments = server.session.call_tool.call_args.kwargs.get("arguments") + + # _meta forwarded as protocol-level meta + assert meta == {"req_id": "r2"} + # _meta not leaked into arguments + assert "_meta" not in arguments + # header_provider was called (headers captured) + assert captured_headers["Authorization"] == "Bearer sk-xxx" + + +# endregion + + async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client(): """Test that calling get_mcp_client multiple times does not accumulate duplicate hooks.""" tool = MCPStreamableHTTPTool(