diff --git a/internal/services/toolkit/tools/xtramcp/loader.go b/internal/services/toolkit/tools/xtramcp/loader.go index a47b95c..3f33f16 100644 --- a/internal/services/toolkit/tools/xtramcp/loader.go +++ b/internal/services/toolkit/tools/xtramcp/loader.go @@ -53,17 +53,45 @@ func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolReg // Register each tool dynamically, passing the session ID for _, toolSchema := range toolSchemas { - dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID) + // some tools require secrutiy context injection e.g. user_id to authenticate + requiresInjection := loader.requiresSecurityInjection(toolSchema) + + dynamicTool := NewDynamicTool( + loader.db, + loader.projectService, + toolSchema, + loader.baseURL, + loader.sessionID, + requiresInjection, + ) // Register the tool with the registry toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call) - fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name) + if requiresInjection { + fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name) + } else { + fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name) + } } return nil } +// checks if a tool schema contains parameters that should be inejected instead of LLM-generated +func (loader *XtraMCPLoader) requiresSecurityInjection(schema ToolSchema) bool { + properties, ok := schema.InputSchema["properties"].(map[string]interface{}) + if !ok { + return false + } + + // injected parameters + _, hasUserId := properties["user_id"] + _, hasProjectId := properties["project_id"] + + return hasUserId || hasProjectId +} + // InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it func (loader *XtraMCPLoader) InitializeMCP() (string, error) { // Step 1: Initialize diff --git a/internal/services/toolkit/tools/xtramcp/schema_filter.go b/internal/services/toolkit/tools/xtramcp/schema_filter.go new file mode 100644 index 0000000..08ecd9b --- /dev/null +++ b/internal/services/toolkit/tools/xtramcp/schema_filter.go @@ -0,0 +1,65 @@ +package xtramcp + +import "encoding/json" + +// parameters that should be injected server-side +var securityParameters = []string{"user_id", "project_id"} + +// removes security parameters from schema shown to LLM so LLM does not need to generate / fill +func filterSecurityParameters(schema map[string]interface{}) map[string]interface{} { + filtered := deepCopySchema(schema) + + // Remove from properties + if properties, ok := filtered["properties"].(map[string]interface{}); ok { + for _, param := range securityParameters { + delete(properties, param) + } + } + + // Remove from required array + if required, ok := filtered["required"].([]interface{}); ok { + filtered["required"] = filterRequiredArray(required, securityParameters) + } + + return filtered +} + +// creates a deep copy of the schema using JSON marshal/unmarshal +func deepCopySchema(schema map[string]interface{}) map[string]interface{} { + // Use JSON marshal/unmarshal for deep copy + jsonBytes, err := json.Marshal(schema) + if err != nil { + // If marshaling fails, return original schema + return schema + } + + var copy map[string]interface{} + err = json.Unmarshal(jsonBytes, ©) + if err != nil { + // If unmarshaling fails, return original schema + return schema + } + + return copy +} + +// removes security parameters from the required array +func filterRequiredArray(required []interface{}, toRemove []string) []interface{} { + filtered := []interface{}{} + removeMap := make(map[string]bool) + + for _, r := range toRemove { + removeMap[r] = true + } + + // filter our security params + for _, item := range required { + if str, ok := item.(string); ok { + if !removeMap[str] { + filtered = append(filtered, item) + } + } + } + + return filtered +} diff --git a/internal/services/toolkit/tools/xtramcp/tool.go b/internal/services/toolkit/tools/xtramcp/tool.go index f9a4e47..2b04c1e 100644 --- a/internal/services/toolkit/tools/xtramcp/tool.go +++ b/internal/services/toolkit/tools/xtramcp/tool.go @@ -9,12 +9,14 @@ import ( "net/http" "paperdebugger/internal/libs/db" "paperdebugger/internal/services" + "paperdebugger/internal/services/toolkit" toolCallRecordDB "paperdebugger/internal/services/toolkit/db" "time" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/packages/param" "github.com/openai/openai-go/v2/responses" + "go.mongodb.org/mongo-driver/v2/mongo" ) // ToolSchema represents the schema from your backend @@ -41,39 +43,47 @@ type MCPParams struct { // DynamicTool represents a generic tool that can handle any schema type DynamicTool struct { - Name string - Description responses.ToolUnionParam - toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB - projectService *services.ProjectService - coolDownTime time.Duration - baseURL string - client *http.Client - schema map[string]interface{} - sessionID string // Reuse the session ID from initialization + Name string + Description responses.ToolUnionParam + toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB + projectService *services.ProjectService + coolDownTime time.Duration + baseURL string + client *http.Client + schema map[string]interface{} + sessionID string // Reuse the session ID from initialization + requiresInjection bool // Indicates if this tool needs user/project injection } // NewDynamicTool creates a new dynamic tool from a schema -func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool { - // Create tool description with the schema +func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string, requiresInjection bool) *DynamicTool { + // filter schema if injection is required (hide security context like user_id/project_id from LLM) + schemaForLLM := toolSchema.InputSchema + if requiresInjection { + schemaForLLM = filterSecurityParameters(toolSchema.InputSchema) + } + description := responses.ToolUnionParam{ OfFunction: &responses.FunctionToolParam{ Name: toolSchema.Name, Description: param.NewOpt(toolSchema.Description), - Parameters: openai.FunctionParameters(toolSchema.InputSchema), + Parameters: openai.FunctionParameters(schemaForLLM), // Use filtered schema }, } toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db) + //TODO: consider letting llm client know of output schema too return &DynamicTool{ - Name: toolSchema.Name, - Description: description, - toolCallRecordDB: toolCallRecordDB, - projectService: projectService, - coolDownTime: 5 * time.Minute, - baseURL: baseURL, - client: &http.Client{}, - schema: toolSchema.InputSchema, - sessionID: sessionID, // Store the session ID for reuse + Name: toolSchema.Name, + Description: description, + toolCallRecordDB: toolCallRecordDB, + projectService: projectService, + coolDownTime: 5 * time.Minute, + baseURL: baseURL, + client: &http.Client{}, + schema: toolSchema.InputSchema, // Store original schema for validation + sessionID: sessionID, // Store the session ID for reuse + requiresInjection: requiresInjection, } } @@ -86,7 +96,14 @@ func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.Raw return "", "", err } - // Create function call record + // inject user/project context if required + if t.requiresInjection { + err := t.injectSecurityContext(ctx, argsMap) + if err != nil { + return "", "", fmt.Errorf("security context injection failed: %w", err) + } + } + record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap) if err != nil { return "", "", err @@ -111,6 +128,42 @@ func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.Raw return respStr, "", nil } +// extracts user/project from context and injects into arguments +func (t *DynamicTool) injectSecurityContext(ctx context.Context, argsMap map[string]interface{}) error { + // 1. Extract from context + actor, projectId, _ := toolkit.GetActorProjectConversationID(ctx) + if actor == nil || projectId == "" { + return fmt.Errorf("authentication required: user context not found") + } + + // 2. Validate user owns the project + _, err := t.projectService.GetProject(ctx, actor.ID, projectId) + if err != nil { + if err == mongo.ErrNoDocuments { + return fmt.Errorf("authorization failed: project not found or access denied") + } + return fmt.Errorf("authorization check failed: %w", err) + } + + // 3. Check if tool schema expects these parameters + properties, ok := t.schema["properties"].(map[string]interface{}) + if !ok { + return fmt.Errorf("invalid tool schema: properties not found") + } + + // 4. Inject user_id if expected by tool + if _, hasUserId := properties["user_id"]; hasUserId { + argsMap["user_id"] = actor.ID.Hex() + } + + // 5. Inject project_id if expected by tool + if _, hasProjectId := properties["project_id"]; hasProjectId { + argsMap["project_id"] = projectId + } + + return nil +} + // executeTool makes the MCP request (generic for any tool) func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) { diff --git a/webapp/_webapp/src/components/message-entry-container/tools/tools.tsx b/webapp/_webapp/src/components/message-entry-container/tools/tools.tsx index 3f4b4c8..dc43d54 100644 --- a/webapp/_webapp/src/components/message-entry-container/tools/tools.tsx +++ b/webapp/_webapp/src/components/message-entry-container/tools/tools.tsx @@ -25,7 +25,7 @@ const XTRA_MCP_TOOL_NAMES = [ // "deep_research", // REVIEWER TOOLS "review_paper", - // "verify_citations" + "verify_citations", // ENHANCER TOOLS // "enhance_academic_writing", // OPENREVIEW ONLINE TOOLS