From b7dd124144a9c6fae9284ddab06953976ebe07f8 Mon Sep 17 00:00:00 2001 From: Sebastian Rath Date: Sun, 15 Feb 2026 23:21:45 -0500 Subject: [PATCH 1/4] Add a MCP server with debug tools and a WebSocket bridge including e2e tests --- cmd/cmd_mcp.go | 25 ++ examples/mcp-debug-example.act | 144 +++++++++ go.mod | 8 + go.sum | 19 ++ mcp/debug/bridge.go | 230 ++++++++++++++ mcp/debug/register.go | 125 ++++++++ mcp/debug/tools.go | 281 ++++++++++++++++++ mcp/server.go | 23 ++ mcp/tools_graph.go | 240 +++++++++++++++ .../reference_mcp_debug_bridge.sh_l8 | 32 ++ tests_e2e/scripts/mcp_debug_bridge.py | 158 ++++++++++ tests_e2e/scripts/mcp_debug_bridge.sh | 8 + 12 files changed, 1293 insertions(+) create mode 100644 cmd/cmd_mcp.go create mode 100644 examples/mcp-debug-example.act create mode 100644 mcp/debug/bridge.go create mode 100644 mcp/debug/register.go create mode 100644 mcp/debug/tools.go create mode 100644 mcp/server.go create mode 100644 mcp/tools_graph.go create mode 100644 tests_e2e/references/reference_mcp_debug_bridge.sh_l8 create mode 100644 tests_e2e/scripts/mcp_debug_bridge.py create mode 100644 tests_e2e/scripts/mcp_debug_bridge.sh diff --git a/cmd/cmd_mcp.go b/cmd/cmd_mcp.go new file mode 100644 index 0000000..ea1d790 --- /dev/null +++ b/cmd/cmd_mcp.go @@ -0,0 +1,25 @@ +package cmd + +import ( + "fmt" + "os" + + mcpserver "github.com/actionforge/actrun-cli/mcp" + "github.com/spf13/cobra" +) + +var cmdMcp = &cobra.Command{ + Use: "mcp", + Short: "Start the MCP server (stdio transport).", + Long: `Starts an MCP server over stdio that exposes graph tools (validate, schema, node types) and debug tools for bridging between an AI agent and an actrun local debug session (WebSocket). Configure this as an MCP server in your AI tool with: {"command": "actrun", "args": ["mcp"]}`, + Run: func(cmd *cobra.Command, args []string) { + if err := mcpserver.RunMCPServer(ActfileSchema); err != nil { + fmt.Fprintf(os.Stderr, "MCP server error: %v\n", err) + os.Exit(1) + } + }, +} + +func init() { + cmdRoot.AddCommand(cmdMcp) +} diff --git a/examples/mcp-debug-example.act b/examples/mcp-debug-example.act new file mode 100644 index 0000000..d80cfcd --- /dev/null +++ b/examples/mcp-debug-example.act @@ -0,0 +1,144 @@ +editor: + version: + created: v1.34.0 +entry: start +type: generic +nodes: + - id: start + type: core/start@v1 + position: + x: -200 + y: 0 + - id: greeting-text + type: core/const-string@v1 + position: + x: -200 + y: 200 + inputs: + value: "Hello from MCP debug session!" + - id: print-greeting + type: core/print@v1 + position: + x: 200 + y: 0 + inputs: + values[0]: null + color: fg_green + # Group node: contains two inner print steps to demonstrate step-into / step-out + - id: my-group + type: core/group@v1 + position: + x: 500 + y: 0 + graph: + entry: group-inputs + type: group + nodes: + - id: group-inputs + type: core/group-inputs@v1 + position: + x: 10 + y: 60 + - id: group-outputs + type: core/group-outputs@v1 + position: + x: 600 + y: 60 + - id: inner-text + type: core/const-string@v1 + position: + x: 200 + y: 200 + inputs: + value: "Inside the group!" + - id: inner-print + type: core/print@v1 + position: + x: 400 + y: 60 + inputs: + values[0]: null + color: fg_cyan + connections: + - src: + node: inner-text + port: result + dst: + node: inner-print + port: values[0] + executions: + - src: + node: group-inputs + port: exec-in + dst: + node: inner-print + port: exec + - src: + node: inner-print + port: exec + dst: + node: group-outputs + port: exec-out + inputs: + exec-in: + type: '' + index: 0 + exec: true + outputs: + exec-out: + name: '' + type: '' + index: 0 + exec: true + - id: done-text + type: core/const-string@v1 + position: + x: 800 + y: 200 + inputs: + value: "Done! Debug session complete." + - id: print-done + type: core/print@v1 + position: + x: 1000 + y: 0 + inputs: + values[0]: null + color: fg_yellow +connections: + - src: + node: greeting-text + port: result + dst: + node: print-greeting + port: values[0] + isLoop: false + - src: + node: done-text + port: result + dst: + node: print-done + port: values[0] + isLoop: false +executions: + - src: + node: start + port: exec + dst: + node: print-greeting + port: exec + isLoop: false + - src: + node: print-greeting + port: exec + dst: + node: my-group + port: exec-in + isLoop: false + - src: + node: my-group + port: exec-out + dst: + node: print-done + port: exec + isLoop: false diff --git a/go.mod b/go.mod index 8c21be6..aad2c03 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/inconshreveable/mousetrap v1.1.0 github.com/joho/godotenv v1.5.1 + github.com/mark3labs/mcp-go v0.44.0 github.com/pkg/errors v0.9.1 github.com/rhysd/actionlint v1.7.10 github.com/rossmacarthur/cases v0.3.0 @@ -66,7 +67,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect github.com/aws/smithy-go v1.24.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/bmatcuk/doublestar/v4 v4.9.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect @@ -91,10 +94,12 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect github.com/googleapis/gax-go/v2 v2.16.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/kevinburke/ssh_config v1.4.0 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect @@ -110,10 +115,13 @@ require ( github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sergi/go-diff v1.4.0 // indirect github.com/skeema/knownhosts v1.3.2 // indirect + github.com/spf13/cast v1.7.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.2.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.64.0 // indirect diff --git a/go.sum b/go.sum index 84db7ce..d2f4b05 100644 --- a/go.sum +++ b/go.sum @@ -73,8 +73,12 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE= github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= @@ -124,6 +128,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/gage-technologies/mistral-go v1.1.0 h1:POv1wM9jA/9OBXGV2YdPi9Y/h09+MjCbUF+9hRYlVUI= github.com/gage-technologies/mistral-go v1.1.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28= github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= @@ -168,10 +174,13 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLW github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ= github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= @@ -185,6 +194,10 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.44.0 h1:OlYfcVviAnwNN40QZUrrzU0QZjq3En7rCU5X09a/B7I= +github.com/mark3labs/mcp-go v0.44.0/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -240,6 +253,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skeema/knownhosts v1.3.2 h1:EDL9mgf4NzwMXCTfaxSD/o/a5fxDw/xL9nkU28JjdBg= github.com/skeema/knownhosts v1.3.2/go.mod h1:bEg3iQAuw+jyiw+484wwFJoKSLwcfd7fqRy+N0QTiow= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -255,6 +270,8 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tmc/langchaingo v0.1.14 h1:o1qWBPigAIuFvrG6cjTFo0cZPFEZ47ZqpOYMjM15yZc= github.com/tmc/langchaingo v0.1.14/go.mod h1:aKKYXYoqhIDEv7WKdpnnCLRaqXic69cX9MnDUk72378= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= @@ -263,6 +280,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs= github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/mcp/debug/bridge.go b/mcp/debug/bridge.go new file mode 100644 index 0000000..a24c002 --- /dev/null +++ b/mcp/debug/bridge.go @@ -0,0 +1,230 @@ +package mcpdebug + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// LogEntry represents a buffered log message from the debug session. +type LogEntry struct { + Level string `json:"level"` + Message string `json:"message"` +} + +// IncomingMessage represents a parsed message from the WebSocket. +type IncomingMessage struct { + Type string `json:"type"` + FullPath string `json:"fullPath,omitempty"` + ExecutionContext any `json:"executionContext,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} + +// Bridge is a stateful WebSocket client that converts the push-based WS +// stream into pull-based tool results for the MCP server. +type Bridge struct { + mu sync.Mutex + conn *websocket.Conn + connected bool + readErr error + + logBuffer []LogEntry + lastState *IncomingMessage + + waiter chan IncomingMessage + done chan struct{} +} + +// NewBridge creates a new Bridge instance. +func NewBridge() *Bridge { + return &Bridge{} +} + +// Connect dials the local debug server and starts the read loop. +func (b *Bridge) Connect(port int) error { + b.mu.Lock() + defer b.mu.Unlock() + + if b.connected { + return fmt.Errorf("already connected") + } + + url := fmt.Sprintf("ws://127.0.0.1:%d/ws", port) + conn, _, err := websocket.DefaultDialer.Dial(url, nil) + if err != nil { + return fmt.Errorf("failed to connect to %s: %w", url, err) + } + + b.conn = conn + b.connected = true + b.readErr = nil + b.logBuffer = nil + b.lastState = nil + b.waiter = make(chan IncomingMessage, 1) + b.done = make(chan struct{}) + + go b.readLoop() + + return nil +} + +// readLoop reads messages from the WebSocket and dispatches them. +func (b *Bridge) readLoop() { + defer func() { + b.mu.Lock() + b.connected = false + close(b.done) + b.mu.Unlock() + }() + + for { + _, msgBytes, err := b.conn.ReadMessage() + if err != nil { + b.mu.Lock() + b.readErr = err + b.mu.Unlock() + return + } + + var msg IncomingMessage + if err := json.Unmarshal(msgBytes, &msg); err != nil { + continue + } + + b.mu.Lock() + switch msg.Type { + case "log": + b.logBuffer = append(b.logBuffer, LogEntry{ + Level: "log", + Message: msg.Message, + }) + case "log_error": + b.logBuffer = append(b.logBuffer, LogEntry{ + Level: "error", + Message: msg.Message, + }) + case "warning": + b.logBuffer = append(b.logBuffer, LogEntry{ + Level: "warning", + Message: msg.Message, + }) + case "debug_state": + b.lastState = &msg + // Deliver to waiter if someone is waiting + select { + case b.waiter <- msg: + default: + } + case "job_finished", "job_error": + select { + case b.waiter <- msg: + default: + } + case "control": + // Control messages (e.g. runner_connected) are informational + b.logBuffer = append(b.logBuffer, LogEntry{ + Level: "info", + Message: fmt.Sprintf("control: %s", msg.Message), + }) + } + b.mu.Unlock() + } +} + +// Send marshals payload as JSON and writes it to the WebSocket. +func (b *Bridge) Send(payload any) error { + b.mu.Lock() + if !b.connected { + b.mu.Unlock() + return fmt.Errorf("not connected") + } + conn := b.conn + b.mu.Unlock() + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + return fmt.Errorf("failed to write message: %w", err) + } + return nil +} + +// SendAndWait sends a payload and blocks until a debug_state, job_finished, +// or job_error message is received, or until the timeout expires. +func (b *Bridge) SendAndWait(payload any, timeout time.Duration) (*IncomingMessage, []LogEntry, error) { + b.mu.Lock() + if !b.connected { + b.mu.Unlock() + return nil, nil, fmt.Errorf("not connected") + } + // Drain any stale message from waiter channel + select { + case <-b.waiter: + default: + } + done := b.done + b.mu.Unlock() + + if err := b.Send(payload); err != nil { + return nil, nil, err + } + + select { + case msg := <-b.waiter: + logs := b.DrainLogs() + return &msg, logs, nil + case <-done: + return nil, nil, fmt.Errorf("connection closed while waiting for response") + case <-time.After(timeout): + return nil, nil, fmt.Errorf("timeout waiting for response after %s", timeout) + } +} + +// DrainLogs returns all buffered log entries and clears the buffer. +func (b *Bridge) DrainLogs() []LogEntry { + b.mu.Lock() + defer b.mu.Unlock() + + logs := b.logBuffer + b.logBuffer = nil + return logs +} + +// LastState returns the last received debug_state message. +func (b *Bridge) LastState() *IncomingMessage { + b.mu.Lock() + defer b.mu.Unlock() + return b.lastState +} + +// Connected returns whether the bridge is connected. +func (b *Bridge) Connected() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.connected +} + +// Disconnect closes the WebSocket connection. +func (b *Bridge) Disconnect() error { + b.mu.Lock() + defer b.mu.Unlock() + + if !b.connected { + return fmt.Errorf("not connected") + } + + err := b.conn.WriteMessage( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) + b.conn.Close() + b.connected = false + return err +} diff --git a/mcp/debug/register.go b/mcp/debug/register.go new file mode 100644 index 0000000..62f351b --- /dev/null +++ b/mcp/debug/register.go @@ -0,0 +1,125 @@ +package mcpdebug + +import ( + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// RegisterDebugTools creates a WebSocket bridge and registers all debug +// tools on the given MCP server. +func RegisterDebugTools(s *server.MCPServer) { + bridge := NewBridge() + + s.AddTool( + mcp.NewTool("debug_connect", + mcp.WithDescription("Connect to a running actrun local debug server"), + mcp.WithNumber("port", + mcp.Description("The port number of the local debug server (printed as LOCAL_WS_PORT when starting actrun --local)"), + mcp.Required(), + ), + ), + handleDebugConnect(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_run", + mcp.WithDescription("Start executing a graph in the debug session. Sends the graph YAML content and waits for the first pause or completion."), + mcp.WithString("graph", + mcp.Description("The full YAML content of the .act graph file"), + mcp.Required(), + ), + mcp.WithArray("breakpoints", + mcp.Description("List of node IDs to set as breakpoints before execution"), + ), + mcp.WithBoolean("start_paused", + mcp.Description("Whether to pause at the first node (default: true)"), + ), + ), + handleDebugRun(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_step", + mcp.WithDescription("Step over: execute the current node and pause at the next node at the same depth or shallower"), + ), + handleDebugStep(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_step_into", + mcp.WithDescription("Step into: if the current node is a group, pause at the first node inside it; otherwise behaves like step"), + ), + handleDebugStepInto(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_step_out", + mcp.WithDescription("Step out: resume execution and pause when returning to a shallower depth (parent group)"), + ), + handleDebugStepOut(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_resume", + mcp.WithDescription("Resume execution until the next breakpoint is hit or the graph completes"), + ), + handleDebugResume(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_pause", + mcp.WithDescription("Pause execution at the next node visit"), + ), + handleDebugPause(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_set_breakpoint", + mcp.WithDescription("Set a breakpoint at a node. Execution will pause when this node is visited."), + mcp.WithString("node_id", + mcp.Description("The full path or ID of the node to set a breakpoint on"), + mcp.Required(), + ), + ), + handleDebugSetBreakpoint(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_remove_breakpoint", + mcp.WithDescription("Remove a previously set breakpoint from a node"), + mcp.WithString("node_id", + mcp.Description("The full path or ID of the node to remove the breakpoint from"), + mcp.Required(), + ), + ), + handleDebugRemoveBreakpoint(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_inspect", + mcp.WithDescription("Return the last debug state including current node, visited nodes, and execution context (variables, outputs, etc.)"), + ), + handleDebugInspect(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_logs", + mcp.WithDescription("Return and clear buffered log messages (stdout, stderr, warnings) from the debug session"), + ), + handleDebugLogs(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_stop", + mcp.WithDescription("Stop the currently running graph execution"), + ), + handleDebugStop(bridge), + ) + + s.AddTool( + mcp.NewTool("debug_disconnect", + mcp.WithDescription("Disconnect from the debug server and close the WebSocket connection"), + ), + handleDebugDisconnect(bridge), + ) +} diff --git a/mcp/debug/tools.go b/mcp/debug/tools.go new file mode 100644 index 0000000..896829a --- /dev/null +++ b/mcp/debug/tools.go @@ -0,0 +1,281 @@ +package mcpdebug + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func textResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: text, + }, + }, + } +} + +func errorResult(msg string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: msg, + }, + }, + IsError: true, + } +} + +func jsonResult(v any) *mcp.CallToolResult { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal result: %v", err)) + } + return textResult(string(data)) +} + +// blockingResponse builds the standard response for blocking tools. +func blockingResponse(msg *IncomingMessage, logs []LogEntry) *mcp.CallToolResult { + status := "unknown" + switch msg.Type { + case "debug_state": + status = "paused" + case "job_finished": + status = "finished" + case "job_error": + status = "error" + } + + resp := map[string]any{ + "status": status, + } + if msg.FullPath != "" { + resp["current_node"] = msg.FullPath + } + if msg.ExecutionContext != nil { + resp["execution_context"] = msg.ExecutionContext + } + if msg.Error != "" { + resp["error"] = msg.Error + } + if len(logs) > 0 { + resp["logs"] = logs + } + + return jsonResult(resp) +} + +func requireConnected(b *Bridge) error { + if !b.Connected() { + return fmt.Errorf("not connected to debug session — call debug_connect first") + } + return nil +} + +// handleDebugConnect connects to a running local debug server. +func handleDebugConnect(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + port, err := req.RequireInt("port") + if err != nil { + return errorResult("missing required parameter: port"), nil + } + + if err := b.Connect(port); err != nil { + return errorResult(fmt.Sprintf("connect failed: %v", err)), nil + } + return textResult(fmt.Sprintf("Connected to debug server on port %d", port)), nil + } +} + +// handleDebugRun sends a run command with graph content and waits for pause/finish. +func handleDebugRun(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + + graph, err := req.RequireString("graph") + if err != nil { + return errorResult("missing required parameter: graph"), nil + } + + startPaused := req.GetBool("start_paused", true) + breakpoints := req.GetStringSlice("breakpoints", nil) + + payload := map[string]any{ + "type": "run", + "payload": graph, + "start_paused": startPaused, + } + if len(breakpoints) > 0 { + payload["breakpoints"] = breakpoints + } + + msg, logs, err := b.SendAndWait(payload, 120*time.Second) + if err != nil { + return errorResult(fmt.Sprintf("run failed: %v", err)), nil + } + return blockingResponse(msg, logs), nil + } +} + +// handleDebugStep sends a step-over command and waits. +func handleDebugStep(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step"}, 60*time.Second) + if err != nil { + return errorResult(fmt.Sprintf("step failed: %v", err)), nil + } + return blockingResponse(msg, logs), nil + } +} + +// handleDebugStepInto sends a step-into command and waits. +func handleDebugStepInto(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step_into"}, 60*time.Second) + if err != nil { + return errorResult(fmt.Sprintf("step into failed: %v", err)), nil + } + return blockingResponse(msg, logs), nil + } +} + +// handleDebugStepOut sends a step-out command and waits. +func handleDebugStepOut(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step_out"}, 60*time.Second) + if err != nil { + return errorResult(fmt.Sprintf("step out failed: %v", err)), nil + } + return blockingResponse(msg, logs), nil + } +} + +// handleDebugResume sends a resume command and waits for next pause/finish. +func handleDebugResume(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_resume"}, 120*time.Second) + if err != nil { + return errorResult(fmt.Sprintf("resume failed: %v", err)), nil + } + return blockingResponse(msg, logs), nil + } +} + +// handleDebugPause sends a pause command (fire-and-forget). +func handleDebugPause(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + if err := b.Send(map[string]string{"type": "debug_pause"}); err != nil { + return errorResult(fmt.Sprintf("pause failed: %v", err)), nil + } + return textResult("Pause signal sent"), nil + } +} + +// handleDebugSetBreakpoint adds a breakpoint at a node. +func handleDebugSetBreakpoint(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + nodeID, err := req.RequireString("node_id") + if err != nil { + return errorResult("missing required parameter: node_id"), nil + } + if err := b.Send(map[string]string{"type": "debug_add_breakpoint", "nodeId": nodeID}); err != nil { + return errorResult(fmt.Sprintf("set breakpoint failed: %v", err)), nil + } + return textResult(fmt.Sprintf("Breakpoint set at %s", nodeID)), nil + } +} + +// handleDebugRemoveBreakpoint removes a breakpoint from a node. +func handleDebugRemoveBreakpoint(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + nodeID, err := req.RequireString("node_id") + if err != nil { + return errorResult("missing required parameter: node_id"), nil + } + if err := b.Send(map[string]string{"type": "debug_remove_breakpoint", "nodeId": nodeID}); err != nil { + return errorResult(fmt.Sprintf("remove breakpoint failed: %v", err)), nil + } + return textResult(fmt.Sprintf("Breakpoint removed from %s", nodeID)), nil + } +} + +// handleDebugInspect returns the last debug state. +func handleDebugInspect(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + state := b.LastState() + if state == nil { + return textResult("No debug state available yet — the graph may not have paused."), nil + } + return jsonResult(state), nil + } +} + +// handleDebugLogs drains and returns buffered log entries. +func handleDebugLogs(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + logs := b.DrainLogs() + if len(logs) == 0 { + return textResult("No new log entries."), nil + } + return jsonResult(logs), nil + } +} + +// handleDebugStop sends a stop command (fire-and-forget). +func handleDebugStop(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := requireConnected(b); err != nil { + return errorResult(err.Error()), nil + } + if err := b.Send(map[string]string{"type": "stop"}); err != nil { + return errorResult(fmt.Sprintf("stop failed: %v", err)), nil + } + return textResult("Stop signal sent"), nil + } +} + +// handleDebugDisconnect closes the WebSocket connection. +func handleDebugDisconnect(b *Bridge) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if err := b.Disconnect(); err != nil { + return errorResult(fmt.Sprintf("disconnect failed: %v", err)), nil + } + return textResult("Disconnected from debug server"), nil + } +} diff --git a/mcp/server.go b/mcp/server.go new file mode 100644 index 0000000..13e9b1a --- /dev/null +++ b/mcp/server.go @@ -0,0 +1,23 @@ +package mcpserver + +import ( + "github.com/actionforge/actrun-cli/build" + mcpdebug "github.com/actionforge/actrun-cli/mcp/debug" + "github.com/mark3labs/mcp-go/server" +) + +// RunMCPServer creates the MCP server, registers all tools (graph + debug), +// and serves over stdio. It blocks until the stdio transport closes. +func RunMCPServer(actfileSchema []byte) error { + version := build.GetAppVersion() + s := server.NewMCPServer( + "actrun", + version, + server.WithToolCapabilities(false), + ) + + registerGraphTools(s, actfileSchema) + mcpdebug.RegisterDebugTools(s) + + return server.ServeStdio(s) +} diff --git a/mcp/tools_graph.go b/mcp/tools_graph.go new file mode 100644 index 0000000..f2dc143 --- /dev/null +++ b/mcp/tools_graph.go @@ -0,0 +1,240 @@ +package mcpserver + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/actionforge/actrun-cli/core" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/santhosh-tekuri/jsonschema/v6" + "go.yaml.in/yaml/v4" +) + +func textResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: text, + }, + }, + } +} + +func errorResult(msg string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: msg, + }, + }, + IsError: true, + } +} + +func jsonResult(v any) *mcp.CallToolResult { + data, err := json.MarshalIndent(v, "", " ") + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal result: %v", err)) + } + return textResult(string(data)) +} + +// registerGraphTools registers the non-debug graph tools on the server. +func registerGraphTools(s *server.MCPServer, actfileSchema []byte) { + s.AddTool( + mcp.NewTool("validate_graph", + mcp.WithDescription("Validate a graph YAML string against the JSON schema and structural rules. Returns a list of errors or a success message."), + mcp.WithString("graph", + mcp.Description("The full YAML content of the .act graph file"), + mcp.Required(), + ), + ), + handleValidateGraph(actfileSchema), + ) + + s.AddTool( + mcp.NewTool("get_schema", + mcp.WithDescription("Return the JSON schema for ActionForge .act graph files"), + ), + handleGetSchema(actfileSchema), + ) + + s.AddTool( + mcp.NewTool("list_node_types", + mcp.WithDescription("List all registered node types with id, name, version, category, and short description"), + mcp.WithString("category", + mcp.Description("Optional category filter (e.g. 'processing', 'control-flow')"), + ), + ), + handleListNodeTypes(), + ) + + s.AddTool( + mcp.NewTool("get_node_type", + mcp.WithDescription("Get full details for a specific node type including inputs, outputs, and descriptions"), + mcp.WithString("node_type_id", + mcp.Description("The node type ID (e.g. 'core/start@v1')"), + mcp.Required(), + ), + ), + handleGetNodeType(), + ) +} + +// convertToJSONCompatible recursively converts YAML-unmarshalled data into +// types that the JSON schema validator accepts. +func convertToJSONCompatible(v any) any { + switch val := v.(type) { + case map[string]any: + result := make(map[string]any, len(val)) + for k, v := range val { + result[k] = convertToJSONCompatible(v) + } + return result + case []any: + result := make([]any, len(val)) + for i, v := range val { + result[i] = convertToJSONCompatible(v) + } + return result + case int: + return float64(val) + case int64: + return float64(val) + case float32: + return float64(val) + default: + return val + } +} + +func handleValidateGraph(actfileSchema []byte) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + graph, err := req.RequireString("graph") + if err != nil { + return errorResult("missing required parameter: graph"), nil + } + + var graphYaml map[string]any + if err := yaml.Unmarshal([]byte(graph), &graphYaml); err != nil { + return jsonResult(map[string]any{ + "valid": false, + "errors": []string{fmt.Sprintf("YAML parse error: %v", err)}, + }), nil + } + + var allErrors []string + + // Schema validation + if len(actfileSchema) > 0 { + if err := validateSchema(actfileSchema, graphYaml); err != nil { + allErrors = append(allErrors, fmt.Sprintf("schema: %v", err)) + } + } + + // Structural validation + _, errs := core.LoadGraph(graphYaml, nil, "", true, core.RunOpts{}) + for _, e := range errs { + allErrors = append(allErrors, e.Error()) + } + + if len(allErrors) > 0 { + return jsonResult(map[string]any{ + "valid": false, + "errors": allErrors, + }), nil + } + + return jsonResult(map[string]any{ + "valid": true, + }), nil + } +} + +func validateSchema(schemaBytes []byte, data any) error { + var schemaObj any + if err := json.Unmarshal(schemaBytes, &schemaObj); err != nil { + return fmt.Errorf("failed to parse schema JSON: %w", err) + } + + compiler := jsonschema.NewCompiler() + if err := compiler.AddResource("actfile-schema.json", schemaObj); err != nil { + return fmt.Errorf("failed to add schema resource: %w", err) + } + + schema, err := compiler.Compile("actfile-schema.json") + if err != nil { + return fmt.Errorf("failed to compile schema: %w", err) + } + + return schema.Validate(convertToJSONCompatible(data)) +} + +func handleGetSchema(actfileSchema []byte) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if len(actfileSchema) == 0 { + return errorResult("schema not available"), nil + } + return textResult(string(actfileSchema)), nil + } +} + +type nodeTypeSummary struct { + Id string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Category string `json:"category"` + ShortDesc string `json:"short_desc"` +} + +func handleListNodeTypes() server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + category := req.GetString("category", "") + + registries := core.GetRegistries() + result := make([]nodeTypeSummary, 0, len(registries)) + + for _, def := range registries { + if category != "" && !strings.EqualFold(def.Category, category) { + continue + } + result = append(result, nodeTypeSummary{ + Id: def.Id, + Name: def.Name, + Version: def.Version, + Category: def.Category, + ShortDesc: def.ShortDesc, + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Id < result[j].Id + }) + + return jsonResult(result), nil + } +} + +func handleGetNodeType() server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + nodeTypeId, err := req.RequireString("node_type_id") + if err != nil { + return errorResult("missing required parameter: node_type_id"), nil + } + + registries := core.GetRegistries() + def, ok := registries[nodeTypeId] + if !ok { + return errorResult(fmt.Sprintf("node type %q not found", nodeTypeId)), nil + } + + // Return the full definition minus the factory function (already excluded via json:"-") + return jsonResult(def), nil + } +} diff --git a/tests_e2e/references/reference_mcp_debug_bridge.sh_l8 b/tests_e2e/references/reference_mcp_debug_bridge.sh_l8 new file mode 100644 index 0000000..aba0c84 --- /dev/null +++ b/tests_e2e/references/reference_mcp_debug_bridge.sh_l8 @@ -0,0 +1,32 @@ +Cleaning up +Connecting to WebSocket +DEBUG PAUSED #1 at node: start (depth=0) +DEBUG PAUSED #2 at node: print-greeting (depth=0) +DEBUG PAUSED #3 at node: greeting-text (depth=0) +DEBUG PAUSED #4 at node: my-group (depth=0) +DEBUG PAUSED #5 at node: my-group/group-inputs (depth=1) +DEBUG PAUSED #6 at node: my-group (depth=0) +DEBUG PAUSED #7 at node: print-done (depth=0) +Job Finished Successfully! +Launching local runner +Log: Done! Debug session complete. +Log: Hello from MCP debug session! +Log: Inside the group! +Log: created temp working directory for debug session: [REDACTED]/actrun-debug-[REDACTED] +Log: debugging paused at node: greeting-text +Log: debugging paused at node: my-group +Log: debugging paused at node: my-group +Log: debugging paused at node: my-group/group-inputs +Log: debugging paused at node: print-done +Log: debugging paused at node: print-greeting +Log: debugging paused at node: start +Log: ✅ Job succeeded. (Total time: ) +Log: 🚀 Task started... +Runner connected! Sending Graph (Paused) +Sending RESUME command +Sending STEP command +Sending STEP command +Sending STEP command +Sending STEP command +Sending STEP_INTO command +Sending STEP_OUT command diff --git a/tests_e2e/scripts/mcp_debug_bridge.py b/tests_e2e/scripts/mcp_debug_bridge.py new file mode 100644 index 0000000..d09c3dc --- /dev/null +++ b/tests_e2e/scripts/mcp_debug_bridge.py @@ -0,0 +1,158 @@ +import asyncio +import json +import os +import re +import websockets + + +ACTRUN_PATH = "actrun" + + +def clean_and_print(text): + if not text: + return + + timestamp_pattern = r'\[?\d{4}[/-]\d{2}[/-]\d{2}\s+\d{2}:\d{2}:\d{2}\]?' + duration_pattern = r'\d+(?:\.\d+)?s' + + text = re.sub(timestamp_pattern, "", text) + text = re.sub(duration_pattern, "", text) + text = re.sub(r'actrun-debug-\d+', 'actrun-debug-[REDACTED]', text) + + # remove empty lines left over from the redaction + lines = [line.strip() for line in text.splitlines() if line.strip()] + + print("\n".join(lines)) + + +async def drain_stream(stream): + """Read and discard stream output to prevent buffer blocking.""" + while True: + line = await stream.readline() + if not line: + break + + +async def main(): + root_dir = os.environ.get("ACT_ROOT", os.path.join(os.environ.get("ACT_GRAPH_FILES_DIR", "."), "..", "..")) + graph_path = os.path.join(root_dir, "examples", "mcp-debug-example.act") + + with open(graph_path, "r") as f: + graph_content = f.read() + + clean_and_print("Launching local runner") + + env = os.environ.copy() + env["ACT_NOCOLOR"] = "true" + env["ACT_LOGLEVEL"] = "warn" + + process = await asyncio.create_subprocess_exec( + ACTRUN_PATH, "--local", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + drain_out = None + drain_err = None + + try: + # Read stdout lines until we find LOCAL_WS_PORT + port = None + while True: + line = await asyncio.wait_for(process.stdout.readline(), timeout=10) + if not line: + clean_and_print("ERROR: Runner exited before printing port") + return + text = line.decode().strip() + match = re.search(r'LOCAL_WS_PORT=(\d+)', text) + if match: + port = int(match.group(1)) + break + + # Drain remaining subprocess output in background + drain_out = asyncio.create_task(drain_stream(process.stdout)) + drain_err = asyncio.create_task(drain_stream(process.stderr)) + + clean_and_print("Connecting to WebSocket") + + pause_count = 0 + entered_group = False + + async with websockets.connect(f"ws://127.0.0.1:{port}/ws") as websocket: + async for message in websocket: + msg = json.loads(message) + msg_type = msg.get("type") + + if msg_type == "control": + if msg["message"] == "runner_connected": + clean_and_print("Runner connected! Sending Graph (Paused)") + + run_payload = { + "type": "run", + "payload": graph_content, + "start_paused": True, + "ignore_breakpoints": False, + "breakpoints": [], + } + await websocket.send(json.dumps(run_payload)) + + elif msg_type == "log": + clean_and_print(f"Log: {msg['message']}") + + elif msg_type == "log_error": + clean_and_print(f"LogError: {msg['message']}") + + elif msg_type == "debug_state": + pause_count += 1 + node = msg.get("fullPath", "unknown") + depth = node.count("/") + clean_and_print(f"DEBUG PAUSED #{pause_count} at node: {node} (depth={depth})") + + await asyncio.sleep(0.1) + + if node == "my-group" and not entered_group: + # First time at the group node — step INTO it + entered_group = True + clean_and_print("Sending STEP_INTO command") + await websocket.send(json.dumps({"type": "debug_step_into"})) + + elif "/" in node: + # Inside the group (depth > 0) — step to see inner nodes, + # then step OUT on the second inner node + clean_and_print("Sending STEP_OUT command") + await websocket.send(json.dumps({"type": "debug_step_out"})) + + elif node == "print-done": + # Past the group — resume to finish + clean_and_print("Sending RESUME command") + await websocket.send(json.dumps({"type": "debug_resume"})) + + else: + clean_and_print("Sending STEP command") + await websocket.send(json.dumps({"type": "debug_step"})) + + elif msg_type == "job_finished": + clean_and_print("Job Finished Successfully!") + break + + elif msg_type == "job_error": + clean_and_print(f"Job Error: {msg.get('error', 'unknown')}") + break + + finally: + clean_and_print("Cleaning up") + try: + process.terminate() + await process.wait() + except ProcessLookupError: + # process already exited + pass + if drain_out: + drain_out.cancel() + if drain_err: + drain_err.cancel() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests_e2e/scripts/mcp_debug_bridge.sh b/tests_e2e/scripts/mcp_debug_bridge.sh new file mode 100644 index 0000000..9196acd --- /dev/null +++ b/tests_e2e/scripts/mcp_debug_bridge.sh @@ -0,0 +1,8 @@ +echo "Test MCP Debug Bridge" + +set -o pipefail + +$PYTHON_EXECUTABLE -m pip install websockets + +# sort the output to make test stable +#! test $PYTHON_EXECUTABLE $ACT_GRAPH_FILES_DIR/mcp_debug_bridge.py | sort From 6844f803b766933a215c10ed52f30d1497512b31 Mon Sep 17 00:00:00 2001 From: Sebastian Rath Date: Mon, 16 Feb 2026 19:38:39 -0500 Subject: [PATCH 2/4] Improve MCP server with detailed instructions and improved debug tool descriptions --- cmd/cmd_mcp.go | 27 ++++++++++++++++++++-- mcp/debug/register.go | 54 ++++++++++++++++++++++++++++--------------- mcp/server.go | 19 +++++++++------ 3 files changed, 72 insertions(+), 28 deletions(-) diff --git a/cmd/cmd_mcp.go b/cmd/cmd_mcp.go index ea1d790..bcff7cd 100644 --- a/cmd/cmd_mcp.go +++ b/cmd/cmd_mcp.go @@ -3,17 +3,40 @@ package cmd import ( "fmt" "os" + "strings" mcpserver "github.com/actionforge/actrun-cli/mcp" "github.com/spf13/cobra" ) +// buildMCPInstructions generates the MCP server instructions from the +// actual flags registered on cmdRoot so they stay in sync automatically. +func buildMCPInstructions() string { + var b strings.Builder + + b.WriteString("This MCP server provides tools for debugging and running ActionForge graph (.act) files interactively. ") + b.WriteString("Use the debug_* tools to step through graph execution node by node, set breakpoints, and inspect state.\n\n") + + b.WriteString("If you just need to run a graph without debugging, use the actrun CLI directly instead of this MCP server:\n\n") + fmt.Fprintf(&b, " %s\n\n", cmdRoot.Use) + + b.WriteString("Available flags:\n") + // Include both persistent and local flags from the root command. + b.WriteString(cmdRoot.PersistentFlags().FlagUsages()) + b.WriteString(cmdRoot.LocalNonPersistentFlags().FlagUsages()) + + b.WriteString("\nTo pass arguments to the graph itself, append them after the file: actrun file.act arg1 arg2\n") + b.WriteString("Use '--' to separate actrun flags from graph arguments: actrun --env-file .env file.act -- --graph-flag") + + return b.String() +} + var cmdMcp = &cobra.Command{ Use: "mcp", Short: "Start the MCP server (stdio transport).", - Long: `Starts an MCP server over stdio that exposes graph tools (validate, schema, node types) and debug tools for bridging between an AI agent and an actrun local debug session (WebSocket). Configure this as an MCP server in your AI tool with: {"command": "actrun", "args": ["mcp"]}`, + Long: `Starts an MCP server over stdio that exposes debug tools for bridging between an AI agent and an actrun local debug session (WebSocket). Configure this as an MCP server in your AI tool with: {"command": "actrun", "args": ["mcp"]}`, Run: func(cmd *cobra.Command, args []string) { - if err := mcpserver.RunMCPServer(ActfileSchema); err != nil { + if err := mcpserver.RunMCPServer(buildMCPInstructions()); err != nil { fmt.Fprintf(os.Stderr, "MCP server error: %v\n", err) os.Exit(1) } diff --git a/mcp/debug/register.go b/mcp/debug/register.go index 62f351b..5dcac98 100644 --- a/mcp/debug/register.go +++ b/mcp/debug/register.go @@ -12,9 +12,21 @@ func RegisterDebugTools(s *server.MCPServer) { s.AddTool( mcp.NewTool("debug_connect", - mcp.WithDescription("Connect to a running actrun local debug server"), + mcp.WithDescription( + "Connect to an actrun local debug server via WebSocket. "+ + "This is the entry point for debugging ActionForge graph (.act) files. "+ + "The entire debug flow is automated — the user provides an .act file path and optional flags, you handle the rest. "+ + "\n\nWorkflow: "+ + "(1) Start 'actrun --local' in the background with NO file argument and capture LOCAL_WS_PORT from stdout. "+ + "Pass any user-provided flags (e.g. --env-file , --concurrency , --local-gh-server) to this command. "+ + "Run 'actrun --help' if you need to discover available flags. "+ + "(2) Call debug_connect with the captured port. "+ + "(3) Read the .act file from disk and pass its YAML to debug_run to start execution. "+ + "(4) Use debug_step / debug_step_into / debug_resume to walk through nodes. "+ + "(5) Call debug_disconnect when done, then kill the background actrun process. "+ + "\n\nSource code: https://github.com/actionforge/actrun-cli (see CLAUDE.md for project structure)."), mcp.WithNumber("port", - mcp.Description("The port number of the local debug server (printed as LOCAL_WS_PORT when starting actrun --local)"), + mcp.Description("The LOCAL_WS_PORT printed by 'actrun --local' on startup."), mcp.Required(), ), ), @@ -23,16 +35,20 @@ func RegisterDebugTools(s *server.MCPServer) { s.AddTool( mcp.NewTool("debug_run", - mcp.WithDescription("Start executing a graph in the debug session. Sends the graph YAML content and waits for the first pause or completion."), + mcp.WithDescription( + "Send a graph to the debug server and start execution. "+ + "Read the .act file from disk and pass its full YAML content as the 'graph' parameter. "+ + "The server must have been started with 'actrun --local' (no file argument) — this tool sends the graph over the debug protocol. "+ + "Blocks until the graph pauses (if start_paused is true) or completes."), mcp.WithString("graph", - mcp.Description("The full YAML content of the .act graph file"), + mcp.Description("The full YAML content of the .act graph file. Read the file from disk and pass the contents verbatim. Do NOT fabricate or modify the YAML."), mcp.Required(), ), mcp.WithArray("breakpoints", - mcp.Description("List of node IDs to set as breakpoints before execution"), + mcp.Description("Optional list of node IDs to set as breakpoints before execution. Node IDs can be found in the 'nodes[].id' fields of the .act YAML."), ), mcp.WithBoolean("start_paused", - mcp.Description("Whether to pause at the first node (default: true)"), + mcp.Description("Whether to pause at the first node (default: true). Set to false to run until a breakpoint or completion."), ), ), handleDebugRun(bridge), @@ -40,44 +56,44 @@ func RegisterDebugTools(s *server.MCPServer) { s.AddTool( mcp.NewTool("debug_step", - mcp.WithDescription("Step over: execute the current node and pause at the next node at the same depth or shallower"), + mcp.WithDescription("Step over: execute the current node and pause at the next node at the same depth or shallower. Use this to walk through nodes sequentially without entering group nodes."), ), handleDebugStep(bridge), ) s.AddTool( mcp.NewTool("debug_step_into", - mcp.WithDescription("Step into: if the current node is a group, pause at the first node inside it; otherwise behaves like step"), + mcp.WithDescription("Step into: if the current node is a group, pause at the first node inside it; otherwise behaves like step. Use this to inspect execution within group nodes."), ), handleDebugStepInto(bridge), ) s.AddTool( mcp.NewTool("debug_step_out", - mcp.WithDescription("Step out: resume execution and pause when returning to a shallower depth (parent group)"), + mcp.WithDescription("Step out: resume execution and pause when returning to a shallower depth (parent group). Use this to exit a group node and return to the parent graph level."), ), handleDebugStepOut(bridge), ) s.AddTool( mcp.NewTool("debug_resume", - mcp.WithDescription("Resume execution until the next breakpoint is hit or the graph completes"), + mcp.WithDescription("Resume execution until the next breakpoint is hit or the graph completes. Use this to skip ahead when you don't need to inspect every node."), ), handleDebugResume(bridge), ) s.AddTool( mcp.NewTool("debug_pause", - mcp.WithDescription("Pause execution at the next node visit"), + mcp.WithDescription("Pause execution at the next node visit. Use this after debug_resume if you want to stop and inspect again."), ), handleDebugPause(bridge), ) s.AddTool( mcp.NewTool("debug_set_breakpoint", - mcp.WithDescription("Set a breakpoint at a node. Execution will pause when this node is visited."), + mcp.WithDescription("Set a breakpoint at a node. Execution will pause when this node is visited. Node IDs are the 'id' fields from the .act file's nodes section. For nodes inside groups, use the full path (e.g. 'group-id/node-id')."), mcp.WithString("node_id", - mcp.Description("The full path or ID of the node to set a breakpoint on"), + mcp.Description("The full path or ID of the node to set a breakpoint on."), mcp.Required(), ), ), @@ -86,9 +102,9 @@ func RegisterDebugTools(s *server.MCPServer) { s.AddTool( mcp.NewTool("debug_remove_breakpoint", - mcp.WithDescription("Remove a previously set breakpoint from a node"), + mcp.WithDescription("Remove a previously set breakpoint from a node."), mcp.WithString("node_id", - mcp.Description("The full path or ID of the node to remove the breakpoint from"), + mcp.Description("The full path or ID of the node to remove the breakpoint from."), mcp.Required(), ), ), @@ -97,28 +113,28 @@ func RegisterDebugTools(s *server.MCPServer) { s.AddTool( mcp.NewTool("debug_inspect", - mcp.WithDescription("Return the last debug state including current node, visited nodes, and execution context (variables, outputs, etc.)"), + mcp.WithDescription("Return the last debug state including current node, visited nodes, and execution context (variables, outputs, caches). Use this to examine state without advancing execution."), ), handleDebugInspect(bridge), ) s.AddTool( mcp.NewTool("debug_logs", - mcp.WithDescription("Return and clear buffered log messages (stdout, stderr, warnings) from the debug session"), + mcp.WithDescription("Return and clear buffered log messages (stdout, stderr, warnings) from the debug session. Logs accumulate between calls, so call this periodically to see output from executed nodes."), ), handleDebugLogs(bridge), ) s.AddTool( mcp.NewTool("debug_stop", - mcp.WithDescription("Stop the currently running graph execution"), + mcp.WithDescription("Stop the currently running graph execution. The graph will be cancelled and the debug session returns to idle. You can start a new execution with debug_run afterwards."), ), handleDebugStop(bridge), ) s.AddTool( mcp.NewTool("debug_disconnect", - mcp.WithDescription("Disconnect from the debug server and close the WebSocket connection"), + mcp.WithDescription("Disconnect from the debug server and close the WebSocket connection. Call this when done debugging. After disconnecting, kill the background 'actrun --local' process to clean up."), ), handleDebugDisconnect(bridge), ) diff --git a/mcp/server.go b/mcp/server.go index 13e9b1a..643f513 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -6,17 +6,22 @@ import ( "github.com/mark3labs/mcp-go/server" ) -// RunMCPServer creates the MCP server, registers all tools (graph + debug), +// RunMCPServer creates the MCP server, registers debug tools, // and serves over stdio. It blocks until the stdio transport closes. -func RunMCPServer(actfileSchema []byte) error { +// The instructions parameter is optional; when non-empty it is sent to +// the client in the initialize response. +func RunMCPServer(instructions string) error { version := build.GetAppVersion() - s := server.NewMCPServer( - "actrun", - version, + + opts := []server.ServerOption{ server.WithToolCapabilities(false), - ) + } + if instructions != "" { + opts = append(opts, server.WithInstructions(instructions)) + } + + s := server.NewMCPServer("actrun", version, opts...) - registerGraphTools(s, actfileSchema) mcpdebug.RegisterDebugTools(s) return server.ServeStdio(s) From 9c620fa95099238f8bf8da3411fb35fa5082b43c Mon Sep 17 00:00:00 2001 From: Sebastian Rath Date: Mon, 16 Feb 2026 19:38:50 -0500 Subject: [PATCH 3/4] Remove graph tools since thats reserved for the api gateway --- mcp/tools_graph.go | 240 --------------------------------------------- 1 file changed, 240 deletions(-) delete mode 100644 mcp/tools_graph.go diff --git a/mcp/tools_graph.go b/mcp/tools_graph.go deleted file mode 100644 index f2dc143..0000000 --- a/mcp/tools_graph.go +++ /dev/null @@ -1,240 +0,0 @@ -package mcpserver - -import ( - "context" - "encoding/json" - "fmt" - "sort" - "strings" - - "github.com/actionforge/actrun-cli/core" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - "github.com/santhosh-tekuri/jsonschema/v6" - "go.yaml.in/yaml/v4" -) - -func textResult(text string) *mcp.CallToolResult { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: text, - }, - }, - } -} - -func errorResult(msg string) *mcp.CallToolResult { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: msg, - }, - }, - IsError: true, - } -} - -func jsonResult(v any) *mcp.CallToolResult { - data, err := json.MarshalIndent(v, "", " ") - if err != nil { - return errorResult(fmt.Sprintf("failed to marshal result: %v", err)) - } - return textResult(string(data)) -} - -// registerGraphTools registers the non-debug graph tools on the server. -func registerGraphTools(s *server.MCPServer, actfileSchema []byte) { - s.AddTool( - mcp.NewTool("validate_graph", - mcp.WithDescription("Validate a graph YAML string against the JSON schema and structural rules. Returns a list of errors or a success message."), - mcp.WithString("graph", - mcp.Description("The full YAML content of the .act graph file"), - mcp.Required(), - ), - ), - handleValidateGraph(actfileSchema), - ) - - s.AddTool( - mcp.NewTool("get_schema", - mcp.WithDescription("Return the JSON schema for ActionForge .act graph files"), - ), - handleGetSchema(actfileSchema), - ) - - s.AddTool( - mcp.NewTool("list_node_types", - mcp.WithDescription("List all registered node types with id, name, version, category, and short description"), - mcp.WithString("category", - mcp.Description("Optional category filter (e.g. 'processing', 'control-flow')"), - ), - ), - handleListNodeTypes(), - ) - - s.AddTool( - mcp.NewTool("get_node_type", - mcp.WithDescription("Get full details for a specific node type including inputs, outputs, and descriptions"), - mcp.WithString("node_type_id", - mcp.Description("The node type ID (e.g. 'core/start@v1')"), - mcp.Required(), - ), - ), - handleGetNodeType(), - ) -} - -// convertToJSONCompatible recursively converts YAML-unmarshalled data into -// types that the JSON schema validator accepts. -func convertToJSONCompatible(v any) any { - switch val := v.(type) { - case map[string]any: - result := make(map[string]any, len(val)) - for k, v := range val { - result[k] = convertToJSONCompatible(v) - } - return result - case []any: - result := make([]any, len(val)) - for i, v := range val { - result[i] = convertToJSONCompatible(v) - } - return result - case int: - return float64(val) - case int64: - return float64(val) - case float32: - return float64(val) - default: - return val - } -} - -func handleValidateGraph(actfileSchema []byte) server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - graph, err := req.RequireString("graph") - if err != nil { - return errorResult("missing required parameter: graph"), nil - } - - var graphYaml map[string]any - if err := yaml.Unmarshal([]byte(graph), &graphYaml); err != nil { - return jsonResult(map[string]any{ - "valid": false, - "errors": []string{fmt.Sprintf("YAML parse error: %v", err)}, - }), nil - } - - var allErrors []string - - // Schema validation - if len(actfileSchema) > 0 { - if err := validateSchema(actfileSchema, graphYaml); err != nil { - allErrors = append(allErrors, fmt.Sprintf("schema: %v", err)) - } - } - - // Structural validation - _, errs := core.LoadGraph(graphYaml, nil, "", true, core.RunOpts{}) - for _, e := range errs { - allErrors = append(allErrors, e.Error()) - } - - if len(allErrors) > 0 { - return jsonResult(map[string]any{ - "valid": false, - "errors": allErrors, - }), nil - } - - return jsonResult(map[string]any{ - "valid": true, - }), nil - } -} - -func validateSchema(schemaBytes []byte, data any) error { - var schemaObj any - if err := json.Unmarshal(schemaBytes, &schemaObj); err != nil { - return fmt.Errorf("failed to parse schema JSON: %w", err) - } - - compiler := jsonschema.NewCompiler() - if err := compiler.AddResource("actfile-schema.json", schemaObj); err != nil { - return fmt.Errorf("failed to add schema resource: %w", err) - } - - schema, err := compiler.Compile("actfile-schema.json") - if err != nil { - return fmt.Errorf("failed to compile schema: %w", err) - } - - return schema.Validate(convertToJSONCompatible(data)) -} - -func handleGetSchema(actfileSchema []byte) server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - if len(actfileSchema) == 0 { - return errorResult("schema not available"), nil - } - return textResult(string(actfileSchema)), nil - } -} - -type nodeTypeSummary struct { - Id string `json:"id"` - Name string `json:"name"` - Version string `json:"version"` - Category string `json:"category"` - ShortDesc string `json:"short_desc"` -} - -func handleListNodeTypes() server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - category := req.GetString("category", "") - - registries := core.GetRegistries() - result := make([]nodeTypeSummary, 0, len(registries)) - - for _, def := range registries { - if category != "" && !strings.EqualFold(def.Category, category) { - continue - } - result = append(result, nodeTypeSummary{ - Id: def.Id, - Name: def.Name, - Version: def.Version, - Category: def.Category, - ShortDesc: def.ShortDesc, - }) - } - - sort.Slice(result, func(i, j int) bool { - return result[i].Id < result[j].Id - }) - - return jsonResult(result), nil - } -} - -func handleGetNodeType() server.ToolHandlerFunc { - return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - nodeTypeId, err := req.RequireString("node_type_id") - if err != nil { - return errorResult("missing required parameter: node_type_id"), nil - } - - registries := core.GetRegistries() - def, ok := registries[nodeTypeId] - if !ok { - return errorResult(fmt.Sprintf("node type %q not found", nodeTypeId)), nil - } - - // Return the full definition minus the factory function (already excluded via json:"-") - return jsonResult(def), nil - } -} From b574606e8b80e3015555657a872b14890e5dccb6 Mon Sep 17 00:00:00 2001 From: Sebastian Rath Date: Mon, 16 Feb 2026 19:59:54 -0500 Subject: [PATCH 4/4] Handle warnings and improvements --- cmd/cmd_mcp.go | 1 + mcp/debug/bridge.go | 45 +++++++++++++++++++++++++++++++++------------ mcp/debug/tools.go | 25 +++++++++++++------------ 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/cmd/cmd_mcp.go b/cmd/cmd_mcp.go index bcff7cd..98850e6 100644 --- a/cmd/cmd_mcp.go +++ b/cmd/cmd_mcp.go @@ -35,6 +35,7 @@ var cmdMcp = &cobra.Command{ Use: "mcp", Short: "Start the MCP server (stdio transport).", Long: `Starts an MCP server over stdio that exposes debug tools for bridging between an AI agent and an actrun local debug session (WebSocket). Configure this as an MCP server in your AI tool with: {"command": "actrun", "args": ["mcp"]}`, + Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { if err := mcpserver.RunMCPServer(buildMCPInstructions()); err != nil { fmt.Fprintf(os.Stderr, "MCP server error: %v\n", err) diff --git a/mcp/debug/bridge.go b/mcp/debug/bridge.go index a24c002..4f58a1d 100644 --- a/mcp/debug/bridge.go +++ b/mcp/debug/bridge.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/actionforge/actrun-cli/sessions" "github.com/gorilla/websocket" ) @@ -24,10 +25,15 @@ type IncomingMessage struct { Error string `json:"error,omitempty"` } +// writeDeadline is the timeout applied to every WebSocket write, +// matching the convention in sessions/protocol.go. +const writeDeadline = 10 * time.Second + // Bridge is a stateful WebSocket client that converts the push-based WS // stream into pull-based tool results for the MCP server. type Bridge struct { mu sync.Mutex + writeMu sync.Mutex // serialises all WebSocket writes (gorilla/websocket allows one concurrent writer) conn *websocket.Conn connected bool readErr error @@ -97,34 +103,34 @@ func (b *Bridge) readLoop() { b.mu.Lock() switch msg.Type { - case "log": + case sessions.MsgTypeLog: b.logBuffer = append(b.logBuffer, LogEntry{ Level: "log", Message: msg.Message, }) - case "log_error": + case sessions.MsgTypeLogError: b.logBuffer = append(b.logBuffer, LogEntry{ Level: "error", Message: msg.Message, }) - case "warning": + case sessions.MsgTypeWarning: b.logBuffer = append(b.logBuffer, LogEntry{ Level: "warning", Message: msg.Message, }) - case "debug_state": + case sessions.MsgTypeDebugState: b.lastState = &msg // Deliver to waiter if someone is waiting select { case b.waiter <- msg: default: } - case "job_finished", "job_error": + case sessions.MsgTypeJobFinished, sessions.MsgTypeJobError: select { case b.waiter <- msg: default: } - case "control": + case sessions.MsgTypeControl: // Control messages (e.g. runner_connected) are informational b.logBuffer = append(b.logBuffer, LogEntry{ Level: "info", @@ -150,6 +156,12 @@ func (b *Bridge) Send(payload any) error { return fmt.Errorf("failed to marshal payload: %w", err) } + b.writeMu.Lock() + defer b.writeMu.Unlock() + + if err := conn.SetWriteDeadline(time.Now().Add(writeDeadline)); err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) + } if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { return fmt.Errorf("failed to write message: %w", err) } @@ -211,20 +223,29 @@ func (b *Bridge) Connected() bool { return b.connected } -// Disconnect closes the WebSocket connection. +// Disconnect closes the WebSocket connection and waits for the read loop to exit. func (b *Bridge) Disconnect() error { b.mu.Lock() - defer b.mu.Unlock() - if !b.connected { + b.mu.Unlock() return fmt.Errorf("not connected") } + done := b.done + b.mu.Unlock() - err := b.conn.WriteMessage( + // Send close frame under the write mutex. + b.writeMu.Lock() + _ = b.conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + _ = b.conn.WriteMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), ) + b.writeMu.Unlock() + b.conn.Close() - b.connected = false - return err + + // Wait for readLoop to finish so the Bridge is fully idle before returning. + <-done + + return nil } diff --git a/mcp/debug/tools.go b/mcp/debug/tools.go index 896829a..fd85499 100644 --- a/mcp/debug/tools.go +++ b/mcp/debug/tools.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/actionforge/actrun-cli/sessions" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -45,11 +46,11 @@ func jsonResult(v any) *mcp.CallToolResult { func blockingResponse(msg *IncomingMessage, logs []LogEntry) *mcp.CallToolResult { status := "unknown" switch msg.Type { - case "debug_state": + case sessions.MsgTypeDebugState: status = "paused" - case "job_finished": + case sessions.MsgTypeJobFinished: status = "finished" - case "job_error": + case sessions.MsgTypeJobError: status = "error" } @@ -110,7 +111,7 @@ func handleDebugRun(b *Bridge) server.ToolHandlerFunc { breakpoints := req.GetStringSlice("breakpoints", nil) payload := map[string]any{ - "type": "run", + "type": sessions.MsgTypeRun, "payload": graph, "start_paused": startPaused, } @@ -132,7 +133,7 @@ func handleDebugStep(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step"}, 60*time.Second) + msg, logs, err := b.SendAndWait(map[string]string{"type": sessions.MsgTypeDebugStep}, 60*time.Second) if err != nil { return errorResult(fmt.Sprintf("step failed: %v", err)), nil } @@ -146,7 +147,7 @@ func handleDebugStepInto(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step_into"}, 60*time.Second) + msg, logs, err := b.SendAndWait(map[string]string{"type": sessions.MsgTypeDebugStepInto}, 60*time.Second) if err != nil { return errorResult(fmt.Sprintf("step into failed: %v", err)), nil } @@ -160,7 +161,7 @@ func handleDebugStepOut(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_step_out"}, 60*time.Second) + msg, logs, err := b.SendAndWait(map[string]string{"type": sessions.MsgTypeDebugStepOut}, 60*time.Second) if err != nil { return errorResult(fmt.Sprintf("step out failed: %v", err)), nil } @@ -174,7 +175,7 @@ func handleDebugResume(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - msg, logs, err := b.SendAndWait(map[string]string{"type": "debug_resume"}, 120*time.Second) + msg, logs, err := b.SendAndWait(map[string]string{"type": sessions.MsgTypeDebugResume}, 120*time.Second) if err != nil { return errorResult(fmt.Sprintf("resume failed: %v", err)), nil } @@ -188,7 +189,7 @@ func handleDebugPause(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - if err := b.Send(map[string]string{"type": "debug_pause"}); err != nil { + if err := b.Send(map[string]string{"type": sessions.MsgTypeDebugPause}); err != nil { return errorResult(fmt.Sprintf("pause failed: %v", err)), nil } return textResult("Pause signal sent"), nil @@ -205,7 +206,7 @@ func handleDebugSetBreakpoint(b *Bridge) server.ToolHandlerFunc { if err != nil { return errorResult("missing required parameter: node_id"), nil } - if err := b.Send(map[string]string{"type": "debug_add_breakpoint", "nodeId": nodeID}); err != nil { + if err := b.Send(map[string]string{"type": sessions.MsgTypeDebugAddBreakpoint, "nodeId": nodeID}); err != nil { return errorResult(fmt.Sprintf("set breakpoint failed: %v", err)), nil } return textResult(fmt.Sprintf("Breakpoint set at %s", nodeID)), nil @@ -222,7 +223,7 @@ func handleDebugRemoveBreakpoint(b *Bridge) server.ToolHandlerFunc { if err != nil { return errorResult("missing required parameter: node_id"), nil } - if err := b.Send(map[string]string{"type": "debug_remove_breakpoint", "nodeId": nodeID}); err != nil { + if err := b.Send(map[string]string{"type": sessions.MsgTypeDebugRemoveBreakpoint, "nodeId": nodeID}); err != nil { return errorResult(fmt.Sprintf("remove breakpoint failed: %v", err)), nil } return textResult(fmt.Sprintf("Breakpoint removed from %s", nodeID)), nil @@ -263,7 +264,7 @@ func handleDebugStop(b *Bridge) server.ToolHandlerFunc { if err := requireConnected(b); err != nil { return errorResult(err.Error()), nil } - if err := b.Send(map[string]string{"type": "stop"}); err != nil { + if err := b.Send(map[string]string{"type": sessions.MsgTypeStop}); err != nil { return errorResult(fmt.Sprintf("stop failed: %v", err)), nil } return textResult("Stop signal sent"), nil