diff --git a/internal/server/e2e_test.go b/internal/server/e2e_test.go index 035064c8..5d7e82a9 100644 --- a/internal/server/e2e_test.go +++ b/internal/server/e2e_test.go @@ -3010,3 +3010,173 @@ func TestE2E_ServerDeleteReaddDifferentTools(t *testing.T) { t.Log("Phase 3 & 4 Complete: ONLY Tool Set B (new_tool_gamma) searchable and callable") t.Log("SUCCESS: Stale index entries cleaned up correctly on server re-add") } + +// Test: retrieve_tools returns correct annotations and call_with based on tool hints (Issue #306) +func TestE2E_RetrieveToolsAnnotationsAndCallWith(t *testing.T) { + env := NewTestEnvironment(t) + defer env.Cleanup() + + trueVal := true + falseVal := false + + // Create mock upstream server with tools that have different annotation hints + mockTools := []mcp.Tool{ + { + Name: "delete_records", + Description: "Delete records from the database permanently", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{}, + }, + Annotations: mcp.ToolAnnotation{ + DestructiveHint: &trueVal, + ReadOnlyHint: &falseVal, + }, + }, + { + Name: "update_config", + Description: "Update server configuration settings", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{}, + }, + Annotations: mcp.ToolAnnotation{ + ReadOnlyHint: &falseVal, + }, + }, + { + Name: "list_items", + Description: "List all items in the inventory", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{}, + }, + Annotations: mcp.ToolAnnotation{ + ReadOnlyHint: &trueVal, + }, + }, + } + + mockServer := env.CreateMockUpstreamServer("annotated", mockTools) + + mcpClient := env.CreateProxyClient() + defer mcpClient.Close() + env.ConnectClient(mcpClient) + + ctx := context.Background() + + // Add and unquarantine the server + addRequest := mcp.CallToolRequest{} + addRequest.Params.Name = "upstream_servers" + addRequest.Params.Arguments = map[string]interface{}{ + "operation": "add", + "name": "annotated", + "url": mockServer.addr, + "protocol": "streamable-http", + "enabled": true, + } + + result, err := mcpClient.CallTool(ctx, addRequest) + require.NoError(t, err) + assert.False(t, result.IsError) + + serverConfig, err := env.proxyServer.runtime.StorageManager().GetUpstreamServer("annotated") + require.NoError(t, err) + serverConfig.Quarantined = false + err = env.proxyServer.runtime.StorageManager().SaveUpstreamServer(serverConfig) + require.NoError(t, err) + + servers, err := env.proxyServer.runtime.StorageManager().ListUpstreamServers() + require.NoError(t, err) + + cfg := env.proxyServer.runtime.Config() + cfg.Servers = servers + err = env.proxyServer.runtime.LoadConfiguredServers(cfg) + require.NoError(t, err) + + // Wait for connection and trigger tool discovery + time.Sleep(4 * time.Second) + err = env.proxyServer.runtime.DiscoverAndIndexTools(ctx) + require.NoError(t, err) + time.Sleep(2 * time.Second) + + // Do a broad search to find all annotated tools + searchRequest := mcp.CallToolRequest{} + searchRequest.Params.Name = "retrieve_tools" + searchRequest.Params.Arguments = map[string]interface{}{ + "query": "delete update list records config items", + "limit": 20, + } + searchResult, err := mcpClient.CallTool(ctx, searchRequest) + require.NoError(t, err) + assert.False(t, searchResult.IsError) + + require.Greater(t, len(searchResult.Content), 0) + contentBytes, err := json.Marshal(searchResult.Content[0]) + require.NoError(t, err) + var contentMap map[string]interface{} + err = json.Unmarshal(contentBytes, &contentMap) + require.NoError(t, err) + contentText, ok := contentMap["text"].(string) + require.True(t, ok) + + var searchResponse map[string]interface{} + err = json.Unmarshal([]byte(contentText), &searchResponse) + require.NoError(t, err) + + tools, ok := searchResponse["tools"].([]interface{}) + require.True(t, ok, "tools should be an array") + t.Logf("Found %d tools in search", len(tools)) + for _, toolRaw := range tools { + if tool, ok := toolRaw.(map[string]interface{}); ok { + t.Logf(" Tool: name=%v, server=%v, call_with=%v, annotations=%v", + tool["name"], tool["server"], tool["call_with"], tool["annotations"]) + } + } + require.GreaterOrEqual(t, len(tools), 3, "Should find all 3 annotated tools") + + // Verify call_with for each tool + // Note: tool names may be bare ("delete_records") or prefixed ("annotated:delete_records") + // depending on how they were indexed. Match by bare name + server field. + expectedCallWith := map[string]string{ + "delete_records": "call_tool_destructive", + "update_config": "call_tool_write", + "list_items": "call_tool_read", + } + + for _, toolRaw := range tools { + tool, ok := toolRaw.(map[string]interface{}) + if !ok { + continue + } + toolName, _ := tool["name"].(string) + serverField, _ := tool["server"].(string) + + // Strip server prefix if present for matching + bareName := toolName + if parts := strings.SplitN(toolName, ":", 2); len(parts) == 2 { + bareName = parts[1] + } + + expected, isOurs := expectedCallWith[bareName] + if !isOurs || serverField != "annotated" { + continue + } + + callWith, ok := tool["call_with"].(string) + assert.True(t, ok, "call_with should be a string for %s", toolName) + assert.Equal(t, expected, callWith, + "call_with mismatch for %s: expected %s, got %s", + toolName, expected, callWith) + + // Verify annotations are present for destructive/write tools + if expected != "call_tool_read" { + assert.NotNil(t, tool["annotations"], + "annotations should be present for %s", toolName) + } + + delete(expectedCallWith, bareName) + } + + assert.Empty(t, expectedCallWith, "Not all expected tools were found: %v", expectedCallWith) +} diff --git a/internal/server/mcp.go b/internal/server/mcp.go index c8c4393b..e96debc7 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -847,17 +847,26 @@ func (p *MCPProxyServer) handleRetrieveTools(ctx context.Context, request mcp.Ca } // Look up tool annotations and derive recommended call_with variant (Spec 018) - // Parse tool name to get just the tool part (format: server:tool) - parts := strings.SplitN(result.Tool.Name, ":", 2) - if len(parts) == 2 { - annotations := p.lookupToolAnnotations(parts[0], parts[1]) + // Use ServerName directly - result.Tool.Name may or may not have "server:" prefix + // depending on how tools were indexed (Issue #306) + serverName := result.Tool.ServerName + toolName := result.Tool.Name + if serverName == "" { + // Fallback: try to extract from "server:tool" format + if parts := strings.SplitN(result.Tool.Name, ":", 2); len(parts) == 2 { + serverName = parts[0] + toolName = parts[1] + } + } + + if serverName != "" { + annotations := p.lookupToolAnnotations(serverName, toolName) if annotations != nil { mcpTool["annotations"] = annotations } // Add call_with recommendation based on annotations mcpTool["call_with"] = contracts.DeriveCallWith(annotations) } else { - // Fallback for tools without server prefix (shouldn't happen normally) mcpTool["call_with"] = contracts.ToolVariantRead // Default to read - safest option } @@ -3984,7 +3993,9 @@ func (p *MCPProxyServer) lookupToolAnnotations(serverName, toolName string) *con } for _, tool := range serverStatus.Tools { - if tool.Name == toolName { + // tool.Name may be in "server:tool" format (from ToolMetadata.Name), + // while toolName is just the tool part. Match both formats. + if tool.Name == toolName || tool.Name == serverName+":"+toolName { return tool.Annotations } } diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index b985cc6f..c65e0816 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -12,6 +12,7 @@ import ( "go.uber.org/zap" "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/contracts" "github.com/smart-mcp-proxy/mcpproxy-go/internal/secret" "github.com/smart-mcp-proxy/mcpproxy-go/internal/upstream" ) @@ -307,6 +308,163 @@ func TestToolFormatConversion(t *testing.T) { assert.Contains(t, properties, "market_data") } +// TestAnnotationLookupNameMatching tests that lookupToolAnnotations correctly +// matches tool names regardless of whether StateView stores them as "tool" +// or "server:tool" format. This is the bug reported in Issue #306. +func TestAnnotationLookupNameMatching(t *testing.T) { + trueVal := true + falseVal := false + + tests := []struct { + name string + serverName string + toolName string + stateViewName string // How the tool name is stored in StateView + annotations *config.ToolAnnotations + expectedCallWith string + }{ + { + name: "StateView stores full name (server:tool), lookup uses bare tool name", + serverName: "github", + toolName: "delete_repo", + stateViewName: "github:delete_repo", + annotations: &config.ToolAnnotations{ + DestructiveHint: &trueVal, + }, + expectedCallWith: "call_tool_destructive", + }, + { + name: "StateView stores bare tool name, lookup uses bare tool name", + serverName: "github", + toolName: "delete_repo", + stateViewName: "delete_repo", + annotations: &config.ToolAnnotations{ + DestructiveHint: &trueVal, + }, + expectedCallWith: "call_tool_destructive", + }, + { + name: "write tool with full name in StateView", + serverName: "myserver", + toolName: "update_config", + stateViewName: "myserver:update_config", + annotations: &config.ToolAnnotations{ + ReadOnlyHint: &falseVal, + }, + expectedCallWith: "call_tool_write", + }, + { + name: "read-only tool with full name in StateView", + serverName: "myserver", + toolName: "list_items", + stateViewName: "myserver:list_items", + annotations: &config.ToolAnnotations{ + ReadOnlyHint: &trueVal, + }, + expectedCallWith: "call_tool_read", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Simulate the name matching logic from lookupToolAnnotations + // This tests the fix for Issue #306 where tool.Name in StateView + // is "server:tool" but toolName passed in is just "tool" + matched := false + toolNameInStateView := tt.stateViewName + if toolNameInStateView == tt.toolName || toolNameInStateView == tt.serverName+":"+tt.toolName { + matched = true + } + + assert.True(t, matched, "Tool name matching failed: stateview=%q, lookup=%q", + tt.stateViewName, tt.toolName) + + // Verify DeriveCallWith returns correct variant when annotations are found + if matched { + callWith := contracts.DeriveCallWith(tt.annotations) + assert.Equal(t, tt.expectedCallWith, callWith) + } + }) + } +} + +// TestRetrieveToolsCallWithAnnotations verifies that the handleRetrieveTools +// code path correctly splits tool names and derives call_with from annotations. +// This is a regression test for Issue #306. +func TestRetrieveToolsCallWithAnnotations(t *testing.T) { + trueVal := true + falseVal := false + + // Simulate search results as returned by the index + mockResults := []*config.SearchResult{ + { + Tool: &config.ToolMetadata{ + Name: "myserver:delete_data", + ServerName: "myserver", + Annotations: &config.ToolAnnotations{ + DestructiveHint: &trueVal, + }, + }, + Score: 0.9, + }, + { + Tool: &config.ToolMetadata{ + Name: "myserver:update_config", + ServerName: "myserver", + Annotations: &config.ToolAnnotations{ + ReadOnlyHint: &falseVal, + }, + }, + Score: 0.8, + }, + { + Tool: &config.ToolMetadata{ + Name: "myserver:list_items", + ServerName: "myserver", + Annotations: &config.ToolAnnotations{ + ReadOnlyHint: &trueVal, + }, + }, + Score: 0.7, + }, + { + Tool: &config.ToolMetadata{ + Name: "myserver:unknown_tool", + ServerName: "myserver", + // No annotations + }, + Score: 0.6, + }, + } + + // Simulate the annotation lookup + call_with derivation from handleRetrieveTools + // In production, lookupToolAnnotations queries the StateView, but we can test + // the name-splitting logic and DeriveCallWith here. + for _, result := range mockResults { + parts := strings.SplitN(result.Tool.Name, ":", 2) + require.Len(t, parts, 2, "Tool name should be in server:tool format: %s", result.Tool.Name) + + // The fix for #306: even if lookupToolAnnotations can't find annotations + // via StateView (because of name mismatch), we can verify the split is correct + assert.Equal(t, result.Tool.ServerName, parts[0], + "Server name from split should match ServerName field") + + // Verify DeriveCallWith with the tool's own annotations + callWith := contracts.DeriveCallWith(result.Tool.Annotations) + + switch result.Tool.Name { + case "myserver:delete_data": + assert.Equal(t, "call_tool_destructive", callWith) + case "myserver:update_config": + assert.Equal(t, "call_tool_write", callWith) + case "myserver:list_items": + assert.Equal(t, "call_tool_read", callWith) + case "myserver:unknown_tool": + assert.Equal(t, "call_tool_read", callWith) // nil annotations → safe default + } + } +} + func TestUpstreamServerOperations(t *testing.T) { // Test basic server operations parsing t.Run("BasicServerOperations", func(t *testing.T) {