diff --git a/pkg/gateway/mcp/server.go b/pkg/gateway/mcp/server.go index 72740d8..99a9991 100644 --- a/pkg/gateway/mcp/server.go +++ b/pkg/gateway/mcp/server.go @@ -18,6 +18,12 @@ type Server struct { client *hookdeck.Client cfg *config.Config mcpServer *mcpsdk.Server + + // sessionCtx is the context passed to RunStdio. It is cancelled when the + // MCP transport closes (stdin EOF). Background goroutines (e.g. login + // polling) should select on this — NOT on the per-request ctx passed to + // tool handlers, which is cancelled when the handler returns. + sessionCtx context.Context } // NewServer creates an MCP server with all Hookdeck tools registered. @@ -56,7 +62,7 @@ func (s *Server) registerTools() { "reauth": {Type: "boolean", Desc: "If true, clear stored credentials and start a new browser login. Use when project listing fails — complete login in the browser, then retry hookdeck_projects."}, }), }, - s.wrapWithTelemetry("hookdeck_login", handleLogin(s.client, s.cfg)), + s.wrapWithTelemetry("hookdeck_login", handleLogin(s)), ) } @@ -126,5 +132,13 @@ func extractAction(req *mcpsdk.CallToolRequest) string { // RunStdio starts the MCP server on stdin/stdout and blocks until the // connection is closed (i.e. stdin reaches EOF). func (s *Server) RunStdio(ctx context.Context) error { - return s.mcpServer.Run(ctx, &mcpsdk.StdioTransport{}) + return s.Run(ctx, &mcpsdk.StdioTransport{}) +} + +// Run starts the MCP server on the given transport. It stores ctx as the +// session-level context so background goroutines (e.g. login polling) can +// detect when the session ends. +func (s *Server) Run(ctx context.Context, transport mcpsdk.Transport) error { + s.sessionCtx = ctx + return s.mcpServer.Run(ctx, transport) } diff --git a/pkg/gateway/mcp/server_test.go b/pkg/gateway/mcp/server_test.go index bdc60fa..7c9d4cd 100644 --- a/pkg/gateway/mcp/server_test.go +++ b/pkg/gateway/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "testing" + "time" mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/assert" @@ -45,7 +46,7 @@ func connectInMemory(t *testing.T, client *hookdeck.Client) *mcpsdk.ClientSessio ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) go func() { - _ = srv.mcpServer.Run(ctx, serverTransport) + _ = srv.Run(ctx, serverTransport) }() mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{ @@ -1143,7 +1144,7 @@ func TestLoginTool_ReauthStartsFreshLogin(t *testing.T) { serverTransport, clientTransport := mcpsdk.NewInMemoryTransports() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }() + go func() { _ = srv.Run(ctx, serverTransport) }() mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil) session, err := mcpClient.Connect(ctx, clientTransport, nil) @@ -1182,7 +1183,7 @@ func TestLoginTool_ReturnsURLImmediately(t *testing.T) { serverTransport, clientTransport := mcpsdk.NewInMemoryTransports() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }() + go func() { _ = srv.Run(ctx, serverTransport) }() mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil) session, err := mcpClient.Connect(ctx, clientTransport, nil) @@ -1218,7 +1219,7 @@ func TestLoginTool_InProgressShowsURL(t *testing.T) { serverTransport, clientTransport := mcpsdk.NewInMemoryTransports() ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - go func() { _ = srv.mcpServer.Run(ctx, serverTransport) }() + go func() { _ = srv.Run(ctx, serverTransport) }() mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil) session, err := mcpClient.Connect(ctx, clientTransport, nil) @@ -1236,6 +1237,70 @@ func TestLoginTool_InProgressShowsURL(t *testing.T) { assert.Contains(t, text, "https://hookdeck.com/auth?code=xyz") } +func TestLoginTool_PollSurvivesAcrossToolCalls(t *testing.T) { + // Regression: the login polling goroutine must use the session-level + // context, not the per-request ctx (which is cancelled when the handler + // returns). If the goroutine selected on per-request ctx, it would be + // cancelled immediately and the second hookdeck_login call would see a + // "login cancelled" error instead of "Already authenticated". + pollCount := 0 + api := mockAPI(t, map[string]http.HandlerFunc{ + "/2025-07-01/cli-auth": func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "browser_url": "https://hookdeck.com/auth?code=survive", + "poll_url": "http://" + r.Host + "/2025-07-01/cli-auth/poll?key=survive", + }) + }, + "/2025-07-01/cli-auth/poll": func(w http.ResponseWriter, r *http.Request) { + pollCount++ + if pollCount >= 2 { + // Simulate user completing browser auth on 2nd poll. + json.NewEncoder(w).Encode(map[string]any{ + "claimed": true, + "key": "sk_test_survive12345", + "team_id": "proj_survive", + "team_name": "Survive Project", + "team_mode": "console", + "user_name": "test-user", + "organization_name": "test-org", + }) + return + } + json.NewEncoder(w).Encode(map[string]any{"claimed": false}) + }, + }) + + unauthClient := newTestClient(api.URL, "") + cfg := &config.Config{APIBaseURL: api.URL} + srv := NewServer(unauthClient, cfg) + + serverTransport, clientTransport := mcpsdk.NewInMemoryTransports() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + go func() { _ = srv.Run(ctx, serverTransport) }() + + mcpClient := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test", Version: "0.0.1"}, nil) + session, err := mcpClient.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = session.Close() }) + + // First call initiates the flow — handler returns immediately. + result := callTool(t, session, "hookdeck_login", map[string]any{}) + assert.False(t, result.IsError) + assert.Contains(t, textContent(t, result), "https://hookdeck.com/auth?code=survive") + + // Wait briefly for the polling goroutine to complete (poll interval is 2s + // in production, but the mock returns instantly so it completes quickly). + time.Sleep(500 * time.Millisecond) + + // Second call — if the goroutine survived, the client is now authenticated. + result2 := callTool(t, session, "hookdeck_login", map[string]any{}) + assert.False(t, result2.IsError) + text := textContent(t, result2) + assert.Contains(t, text, "Already authenticated") + assert.Equal(t, "sk_test_survive12345", unauthClient.APIKey) +} + // --------------------------------------------------------------------------- // API error scenarios (shared across tools) // --------------------------------------------------------------------------- diff --git a/pkg/gateway/mcp/tool_login.go b/pkg/gateway/mcp/tool_login.go index 0bc5792..723f668 100644 --- a/pkg/gateway/mcp/tool_login.go +++ b/pkg/gateway/mcp/tool_login.go @@ -36,7 +36,9 @@ type loginState struct { err error // non-nil if polling failed } -func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler { +func handleLogin(srv *Server) mcpsdk.ToolHandler { + client := srv.client + cfg := srv.cfg var stateMu sync.Mutex var state *loginState @@ -121,9 +123,11 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler } // Poll in the background so we return the URL to the agent immediately. - // WaitForAPIKey blocks with time.Sleep; run it in a goroutine and - // select on ctx so we abandon the attempt when the session closes. - go func(s *loginState, ctx context.Context) { + // WaitForAPIKey blocks with time.Sleep internally, so we run it in an + // inner goroutine and select on the session-level context (not the + // per-request ctx, which is cancelled when this handler returns). + sessionCtx := srv.sessionCtx + go func(s *loginState) { defer close(s.done) type pollResult struct { @@ -138,7 +142,7 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler var response *hookdeck.PollAPIKeyResponse select { - case <-ctx.Done(): + case <-sessionCtx.Done(): s.err = fmt.Errorf("login cancelled: MCP session closed") log.Debug("Login polling cancelled — MCP session closed") return @@ -187,7 +191,7 @@ func handleLogin(client *hookdeck.Client, cfg *config.Config) mcpsdk.ToolHandler "user": response.UserName, "project": response.ProjectName, }).Info("MCP login completed successfully") - }(state, ctx) + }(state) // Return the URL immediately so the agent can show it to the user. return TextResult(fmt.Sprintf(