Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions internal/services/toolkit/tools/xtramcp/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,45 @@ func (loader *XtraMCPLoader) LoadToolsFromBackend(toolRegistry *registry.ToolReg

// Register each tool dynamically, passing the session ID
for _, toolSchema := range toolSchemas {
dynamicTool := NewDynamicTool(loader.db, loader.projectService, toolSchema, loader.baseURL, loader.sessionID)
// some tools require secrutiy context injection e.g. user_id to authenticate
requiresInjection := loader.requiresSecurityInjection(toolSchema)

dynamicTool := NewDynamicTool(
loader.db,
loader.projectService,
toolSchema,
loader.baseURL,
loader.sessionID,
requiresInjection,
)

// Register the tool with the registry
toolRegistry.Register(toolSchema.Name, dynamicTool.Description, dynamicTool.Call)

fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
if requiresInjection {
fmt.Printf("Registered dynamic tool with security injection: %s\n", toolSchema.Name)
} else {
fmt.Printf("Registered dynamic tool: %s\n", toolSchema.Name)
}
}

return nil
}

// checks if a tool schema contains parameters that should be inejected instead of LLM-generated
func (loader *XtraMCPLoader) requiresSecurityInjection(schema ToolSchema) bool {
properties, ok := schema.InputSchema["properties"].(map[string]interface{})
if !ok {
return false
}

// injected parameters
_, hasUserId := properties["user_id"]
_, hasProjectId := properties["project_id"]

return hasUserId || hasProjectId
}

// InitializeMCP performs the full MCP initialization handshake, stores session ID, and returns it
func (loader *XtraMCPLoader) InitializeMCP() (string, error) {
// Step 1: Initialize
Expand Down
65 changes: 65 additions & 0 deletions internal/services/toolkit/tools/xtramcp/schema_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package xtramcp

import "encoding/json"

// parameters that should be injected server-side
var securityParameters = []string{"user_id", "project_id"}

// removes security parameters from schema shown to LLM so LLM does not need to generate / fill
func filterSecurityParameters(schema map[string]interface{}) map[string]interface{} {
filtered := deepCopySchema(schema)

// Remove from properties
if properties, ok := filtered["properties"].(map[string]interface{}); ok {
for _, param := range securityParameters {
delete(properties, param)
}
}

// Remove from required array
if required, ok := filtered["required"].([]interface{}); ok {
filtered["required"] = filterRequiredArray(required, securityParameters)
}

return filtered
}

// creates a deep copy of the schema using JSON marshal/unmarshal
func deepCopySchema(schema map[string]interface{}) map[string]interface{} {
// Use JSON marshal/unmarshal for deep copy
jsonBytes, err := json.Marshal(schema)
if err != nil {
// If marshaling fails, return original schema
return schema
}

var copy map[string]interface{}
err = json.Unmarshal(jsonBytes, &copy)
if err != nil {
// If unmarshaling fails, return original schema
return schema
}

return copy
}

// removes security parameters from the required array
func filterRequiredArray(required []interface{}, toRemove []string) []interface{} {
filtered := []interface{}{}
removeMap := make(map[string]bool)

for _, r := range toRemove {
removeMap[r] = true
}

// filter our security params
for _, item := range required {
if str, ok := item.(string); ok {
if !removeMap[str] {
filtered = append(filtered, item)
}
}
}

return filtered
}
97 changes: 75 additions & 22 deletions internal/services/toolkit/tools/xtramcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ import (
"net/http"
"paperdebugger/internal/libs/db"
"paperdebugger/internal/services"
"paperdebugger/internal/services/toolkit"
toolCallRecordDB "paperdebugger/internal/services/toolkit/db"
"time"

"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/packages/param"
"github.com/openai/openai-go/v2/responses"
"go.mongodb.org/mongo-driver/v2/mongo"
)

// ToolSchema represents the schema from your backend
Expand All @@ -41,39 +43,47 @@ type MCPParams struct {

// DynamicTool represents a generic tool that can handle any schema
type DynamicTool struct {
Name string
Description responses.ToolUnionParam
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
projectService *services.ProjectService
coolDownTime time.Duration
baseURL string
client *http.Client
schema map[string]interface{}
sessionID string // Reuse the session ID from initialization
Name string
Description responses.ToolUnionParam
toolCallRecordDB *toolCallRecordDB.ToolCallRecordDB
projectService *services.ProjectService
coolDownTime time.Duration
baseURL string
client *http.Client
schema map[string]interface{}
sessionID string // Reuse the session ID from initialization
requiresInjection bool // Indicates if this tool needs user/project injection
}

// NewDynamicTool creates a new dynamic tool from a schema
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string) *DynamicTool {
// Create tool description with the schema
func NewDynamicTool(db *db.DB, projectService *services.ProjectService, toolSchema ToolSchema, baseURL string, sessionID string, requiresInjection bool) *DynamicTool {
// filter schema if injection is required (hide security context like user_id/project_id from LLM)
schemaForLLM := toolSchema.InputSchema
if requiresInjection {
schemaForLLM = filterSecurityParameters(toolSchema.InputSchema)
}

description := responses.ToolUnionParam{
OfFunction: &responses.FunctionToolParam{
Name: toolSchema.Name,
Description: param.NewOpt(toolSchema.Description),
Parameters: openai.FunctionParameters(toolSchema.InputSchema),
Parameters: openai.FunctionParameters(schemaForLLM), // Use filtered schema
},
}

toolCallRecordDB := toolCallRecordDB.NewToolCallRecordDB(db)
//TODO: consider letting llm client know of output schema too
return &DynamicTool{
Name: toolSchema.Name,
Description: description,
toolCallRecordDB: toolCallRecordDB,
projectService: projectService,
coolDownTime: 5 * time.Minute,
baseURL: baseURL,
client: &http.Client{},
schema: toolSchema.InputSchema,
sessionID: sessionID, // Store the session ID for reuse
Name: toolSchema.Name,
Description: description,
toolCallRecordDB: toolCallRecordDB,
projectService: projectService,
coolDownTime: 5 * time.Minute,
baseURL: baseURL,
client: &http.Client{},
schema: toolSchema.InputSchema, // Store original schema for validation
sessionID: sessionID, // Store the session ID for reuse
requiresInjection: requiresInjection,
}
}

Expand All @@ -86,7 +96,14 @@ func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.Raw
return "", "", err
}

// Create function call record
// inject user/project context if required
if t.requiresInjection {
err := t.injectSecurityContext(ctx, argsMap)
if err != nil {
return "", "", fmt.Errorf("security context injection failed: %w", err)
}
}

record, err := t.toolCallRecordDB.Create(ctx, toolCallId, t.Name, argsMap)
if err != nil {
return "", "", err
Expand All @@ -111,6 +128,42 @@ func (t *DynamicTool) Call(ctx context.Context, toolCallId string, args json.Raw
return respStr, "", nil
}

// extracts user/project from context and injects into arguments
func (t *DynamicTool) injectSecurityContext(ctx context.Context, argsMap map[string]interface{}) error {
// 1. Extract from context
actor, projectId, _ := toolkit.GetActorProjectConversationID(ctx)
if actor == nil || projectId == "" {
return fmt.Errorf("authentication required: user context not found")
}

// 2. Validate user owns the project
_, err := t.projectService.GetProject(ctx, actor.ID, projectId)
if err != nil {
if err == mongo.ErrNoDocuments {
return fmt.Errorf("authorization failed: project not found or access denied")
}
return fmt.Errorf("authorization check failed: %w", err)
}

// 3. Check if tool schema expects these parameters
properties, ok := t.schema["properties"].(map[string]interface{})
if !ok {
return fmt.Errorf("invalid tool schema: properties not found")
}

// 4. Inject user_id if expected by tool
if _, hasUserId := properties["user_id"]; hasUserId {
argsMap["user_id"] = actor.ID.Hex()
}

// 5. Inject project_id if expected by tool
if _, hasProjectId := properties["project_id"]; hasProjectId {
argsMap["project_id"] = projectId
}

return nil
}

// executeTool makes the MCP request (generic for any tool)
func (t *DynamicTool) executeTool(args map[string]interface{}) (string, error) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const XTRA_MCP_TOOL_NAMES = [
// "deep_research",
// REVIEWER TOOLS
"review_paper",
// "verify_citations"
"verify_citations",
// ENHANCER TOOLS
// "enhance_academic_writing",
// OPENREVIEW ONLINE TOOLS
Expand Down