diff --git a/docs/features/idp-token-storage.md b/docs/features/idp-token-storage.md index e558ba58a..79aefef2b 100644 --- a/docs/features/idp-token-storage.md +++ b/docs/features/idp-token-storage.md @@ -78,6 +78,63 @@ Secret, etc.) and inject it as `MCPPROXY_CRED_KEY` at runtime. 4. **Re-auth** — when no refresh token is available, or the refresh fails, the user is required to sign in again (`ErrReauthRequired`). +## Per-user credentials REST API (server edition) + +Brokered upstreams expose a per-user credential surface under the session/JWT +auth middleware. Every endpoint is scoped to the authenticated caller — a user +can only see and manage their own credentials, never another user's. + +| Endpoint | Description | +|----------|-------------| +| `GET /api/v1/user/credentials` | List the connection status of every brokered upstream for the caller. | +| `DELETE /api/v1/user/credentials/{server}` | Disconnect (revoke) the caller's credential for an upstream. | +| `GET /api/v1/user/credentials/{server}/connect` | Initiate the per-user OAuth connect flow (Path B); 302-redirects to the upstream authorization server. | +| `GET /api/v1/user/credentials/{server}/callback` | OAuth connect callback; exchanges the code, stores the per-user credential, and redirects back to the Web UI. | + +### Connection status (`GET /api/v1/user/credentials`) + +The list returns **non-secret metadata only** — access and refresh tokens are +never serialized. Each entry carries a `status`: + +- `connected` — a valid, non-expired per-user credential exists. +- `expired` — a credential exists but its access token has expired. +- `not_connected` — no per-user credential exists for this upstream. +- `unavailable` — the credential store is disabled (no encryption key configured). + +For `oauth_connect` upstreams that are `not_connected` or `expired`, the entry +includes an actionable `connect_path` pointing at the connect endpoint. + +```json +{ + "credentials": [ + { + "server": "github-shared", + "mode": "oauth_connect", + "status": "not_connected", + "connect_path": "/api/v1/user/credentials/github-shared/connect" + }, + { + "server": "internal-api", + "mode": "token_exchange", + "status": "connected", + "token_type": "Bearer", + "scopes": ["read"], + "expires_at": "2026-06-15T20:00:00Z" + } + ] +} +``` + +### Connect flow (Path B) + +`connect` builds an authorization-code + PKCE URL bound to the authenticated +user and redirects there. After consent, the upstream redirects to `callback`, +which validates the one-time `state`, exchanges the code, and persists the +per-user credential (encrypted, `obtained_via=connect_flow`). The credential is +always stored under the **initiating** user, so the callback cannot be used to +write into another user's record. The browser then lands on `/ui/` with a +`credential_connected` / `credential_error` query flag. + ## Operational notes - **Key rotation** is not yet supported. Rotating the key requires clearing the diff --git a/internal/serveredition/api/connector_provider.go b/internal/serveredition/api/connector_provider.go new file mode 100644 index 000000000..9a64103c8 --- /dev/null +++ b/internal/serveredition/api/connector_provider.go @@ -0,0 +1,135 @@ +//go:build server + +package api + +import ( + "fmt" + "net/http" + "net/url" + "strings" + "sync" + + "go.uber.org/zap" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/serveredition/broker" +) + +// connectorProvider builds and caches one broker.OAuthConnector per +// oauth_connect upstream (keyed by serverKey). The same connector instance must +// serve both the connect redirect and the callback because the connector holds +// the in-memory PKCE/state for each pending flow; rebuilding it per request +// would lose that state. It satisfies broker.ConnectorProvider so the T6 +// CredentialResolver can reuse the same connectors when it needs to produce a +// connect URL for an unconnected user. +type connectorProvider struct { + store broker.CredentialStore + logger *zap.Logger + + mu sync.Mutex + baseURL string // gateway public origin, e.g. "https://gw.example.com" + cache map[string]*broker.OAuthConnector +} + +// newConnectorProvider constructs an empty provider. +func newConnectorProvider(store broker.CredentialStore, logger *zap.Logger) *connectorProvider { + if logger == nil { + logger = zap.NewNop() + } + return &connectorProvider{ + store: store, + logger: logger, + cache: make(map[string]*broker.OAuthConnector), + } +} + +// observeBaseURL records the gateway's public origin the first time it is seen +// (from an incoming request). The connect callback URL registered with the +// upstream authorization server is derived from it, and OAuth requires the +// redirect_uri to be byte-identical between the authorize request and the token +// exchange — so it is fixed once and reused for the lifetime of a connector. +func (p *connectorProvider) observeBaseURL(r *http.Request) { + base := baseURLFromRequest(r) + p.mu.Lock() + defer p.mu.Unlock() + if p.baseURL == "" { + p.baseURL = base + } +} + +// connector returns the cached connector for an oauth_connect upstream, building +// it on first use. It errors for non-oauth_connect or unbrokered servers. +func (p *connectorProvider) connector(server *config.ServerConfig) (*broker.OAuthConnector, error) { + if server == nil || server.AuthBroker == nil { + return nil, fmt.Errorf("connector provider: server has no auth_broker configuration") + } + if server.AuthBroker.Mode != config.AuthBrokerModeOAuthConnect { + return nil, fmt.Errorf("connector provider: server %q is not an oauth_connect upstream", server.Name) + } + + key := oauth.GenerateServerKey(server.Name, server.URL) + + p.mu.Lock() + defer p.mu.Unlock() + if c, ok := p.cache[key]; ok { + return c, nil + } + + ab := server.AuthBroker + cfg := broker.ConnectorConfig{ + ServerName: server.Name, + ServerURL: server.URL, + AuthorizationEndpoint: ab.AuthorizationEndpoint, + TokenEndpoint: ab.TokenEndpoint, + ClientID: ab.ClientID, + ClientSecret: ab.ClientSecret, + Scopes: ab.Scopes, + RedirectURI: p.callbackURLLocked(server.Name), + Resource: ab.Resource, + } + conn, err := broker.NewOAuthConnector(p.store, cfg, p.logger) + if err != nil { + return nil, err + } + p.cache[key] = conn + return conn, nil +} + +// ConnectorFor satisfies broker.ConnectorProvider for the credential resolver. +func (p *connectorProvider) ConnectorFor(server *config.ServerConfig) (broker.Connector, error) { + return p.connector(server) +} + +// callbackURLLocked builds the gateway callback URL for a server. Caller holds p.mu. +func (p *connectorProvider) callbackURLLocked(serverName string) string { + base := strings.TrimSuffix(p.baseURL, "/") + return base + connectCallbackPath(serverName) +} + +// connectCallbackPath is the relative callback route for a server's connect flow. +func connectCallbackPath(serverName string) string { + return "/api/v1/user/credentials/" + url.PathEscape(serverName) + "/callback" +} + +// connectInitiatePath is the relative connect route for a server. +func connectInitiatePath(serverName string) string { + return "/api/v1/user/credentials/" + url.PathEscape(serverName) + "/connect" +} + +// baseURLFromRequest derives the gateway's public origin (scheme://host), +// honoring X-Forwarded-Proto for reverse-proxy deployments. Mirrors the OAuth +// login handler's buildCallbackURL scheme detection. +func baseURLFromRequest(r *http.Request) string { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + scheme = proto + } + return scheme + "://" + r.Host +} + +// Compile-time assertion that the provider satisfies the resolver's interface. +var _ broker.ConnectorProvider = (*connectorProvider)(nil) diff --git a/internal/serveredition/api/credential_handlers.go b/internal/serveredition/api/credential_handlers.go new file mode 100644 index 000000000..42a5c55ae --- /dev/null +++ b/internal/serveredition/api/credential_handlers.go @@ -0,0 +1,351 @@ +//go:build server + +package api + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "go.uber.org/zap" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/serveredition/broker" +) + +// Connection-status values surfaced by GET /api/v1/user/credentials. They carry +// no secret material (FR-026). +const ( + credStatusConnected = "connected" // a valid, non-expired per-user credential exists + credStatusExpired = "expired" // a credential exists but its access token has expired + credStatusNotConnected = "not_connected" // no per-user credential exists for this upstream + credStatusUnavailable = "unavailable" // the credential store is disabled (no encryption key) +) + +// credentialConnectSuccessRedirect is where the browser lands after a successful +// or denied connect flow. The Web UI surfaces the resulting state. +const credentialConnectSuccessRedirect = "/ui/" + +// CredentialHandlers exposes per-user brokered-credential surfaces for the +// server edition: listing connection status, disconnecting, and driving the +// per-user OAuth connect flow (Path B, spec 074 T5). Every operation is scoped +// to the authenticated user (FR-027) and never returns secret values (FR-026). +type CredentialHandlers struct { + store broker.CredentialStore + brokerServers []*config.ServerConfig // admin-configured shared servers; broker ones are filtered at use + connectors *connectorProvider + logger *zap.SugaredLogger +} + +// NewCredentialHandlers builds the handlers over a credential store and the set +// of shared servers (only those carrying an auth_broker block are brokered). +func NewCredentialHandlers(store broker.CredentialStore, sharedServers []*config.ServerConfig, logger *zap.SugaredLogger) *CredentialHandlers { + if logger == nil { + logger = zap.NewNop().Sugar() + } + return &CredentialHandlers{ + store: store, + brokerServers: sharedServers, + connectors: newConnectorProvider(store, logger.Desugar()), + logger: logger, + } +} + +// ConnectorProvider exposes the shared, connector cache so the credential +// resolver (T6) can mint connect URLs through the same connectors that serve the +// REST connect/callback flow. +func (h *CredentialHandlers) ConnectorProvider() broker.ConnectorProvider { + return h.connectors +} + +// RegisterRoutes registers credential routes on the provided router. +func (h *CredentialHandlers) RegisterRoutes(r chi.Router) { + r.Route("/user/credentials", func(r chi.Router) { + r.Get("/", h.listCredentials) + r.Delete("/{server}", h.deleteCredential) + r.Get("/{server}/connect", h.connect) + r.Get("/{server}/callback", h.callback) + }) +} + +// RegisterRoutesWithPrefix registers credential routes with a path prefix. +func (h *CredentialHandlers) RegisterRoutesWithPrefix(r chi.Router, prefix string) { + r.Get(prefix+"/user/credentials", h.listCredentials) + r.Delete(prefix+"/user/credentials/{server}", h.deleteCredential) + r.Get(prefix+"/user/credentials/{server}/connect", h.connect) + r.Get(prefix+"/user/credentials/{server}/callback", h.callback) +} + +// --- Response types --- + +// CredentialStatus is the non-secret connection view for one brokered upstream. +// It deliberately omits access_token / refresh_token (FR-026). +type CredentialStatus struct { + Server string `json:"server"` + Mode string `json:"mode"` + Status string `json:"status"` + TokenType string `json:"token_type,omitempty"` + Scopes []string `json:"scopes,omitempty"` + Audience string `json:"audience,omitempty"` + ObtainedVia string `json:"obtained_via,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + // ConnectPath is the actionable connect endpoint for oauth_connect upstreams + // that are not currently connected (or whose credential expired). + ConnectPath string `json:"connect_path,omitempty"` +} + +// CredentialListResponse wraps the per-user credential statuses. +type CredentialListResponse struct { + Credentials []CredentialStatus `json:"credentials"` +} + +// --- Handlers --- + +// listCredentials returns the connection status of every brokered upstream for +// the authenticated user. Secret values are never included (FR-026). +func (h *CredentialHandlers) listCredentials(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "Authentication required") + return + } + + storeEnabled := h.store != nil && h.store.Enabled() + + out := make([]CredentialStatus, 0) + for _, srv := range h.brokerServerList() { + status := CredentialStatus{ + Server: srv.Name, + Mode: srv.AuthBroker.Mode, + } + + switch { + case !storeEnabled: + status.Status = credStatusUnavailable + default: + cred, gerr := h.store.Get(userID, oauth.GenerateServerKey(srv.Name, srv.URL)) + switch { + case errors.Is(gerr, broker.ErrNotFound): + status.Status = credStatusNotConnected + case gerr != nil: + h.logger.Warnw("failed to load brokered credential", "user_id", userID, "server", srv.Name, "error", gerr) + status.Status = credStatusUnavailable + default: + status.populateFromCredential(cred) + } + } + + // Offer an actionable connect path for connect-flow upstreams that are + // not currently usable. + if srv.AuthBroker.Mode == config.AuthBrokerModeOAuthConnect && + (status.Status == credStatusNotConnected || status.Status == credStatusExpired) { + status.ConnectPath = connectInitiatePath(srv.Name) + } + + out = append(out, status) + } + + sort.Slice(out, func(i, j int) bool { return out[i].Server < out[j].Server }) + writeJSON(w, http.StatusOK, CredentialListResponse{Credentials: out}) +} + +// populateFromCredential fills the non-secret metadata and the connected/expired +// status from a stored credential. It never copies token material (FR-026). +func (s *CredentialStatus) populateFromCredential(cred *broker.UpstreamCredential) { + if cred.IsExpired() { + s.Status = credStatusExpired + } else { + s.Status = credStatusConnected + } + s.TokenType = cred.TokenType + s.Scopes = cred.Scopes + s.Audience = cred.Audience + s.ObtainedVia = cred.ObtainedVia + if !cred.ExpiresAt.IsZero() { + t := cred.ExpiresAt + s.ExpiresAt = &t + } + if !cred.UpdatedAt.IsZero() { + t := cred.UpdatedAt + s.UpdatedAt = &t + } +} + +// deleteCredential disconnects (revokes) the authenticated user's credential for +// a brokered upstream. Only the caller's own record is affected (FR-027). +func (h *CredentialHandlers) deleteCredential(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "Authentication required") + return + } + + srv, ok := h.lookupServer(w, r) + if !ok { + return + } + + if h.store == nil || !h.store.Enabled() { + writeError(w, http.StatusServiceUnavailable, "Credential broker is not enabled") + return + } + + if err := h.store.Delete(userID, oauth.GenerateServerKey(srv.Name, srv.URL)); err != nil { + h.logger.Errorw("failed to delete brokered credential", "user_id", userID, "server", srv.Name, "error", err) + writeError(w, http.StatusInternalServerError, "Failed to disconnect credential") + return + } + + h.logger.Infow("brokered credential disconnected", "user_id", userID, "server", srv.Name) + writeJSON(w, http.StatusOK, map[string]string{ + "message": fmt.Sprintf("Disconnected credential for %q", srv.Name), + }) +} + +// connect initiates Path B: it builds the upstream authorize URL bound to the +// authenticated user and redirects the browser there (spec 074 T5, FR-011). +func (h *CredentialHandlers) connect(w http.ResponseWriter, r *http.Request) { + userID, err := getUserID(r) + if err != nil { + writeError(w, http.StatusUnauthorized, "Authentication required") + return + } + + srv, ok := h.lookupServer(w, r) + if !ok { + return + } + if srv.AuthBroker.Mode != config.AuthBrokerModeOAuthConnect { + writeError(w, http.StatusBadRequest, fmt.Sprintf("Server %q does not use the OAuth connect flow", srv.Name)) + return + } + + h.connectors.observeBaseURL(r) + conn, err := h.connectors.connector(srv) + if err != nil { + h.logger.Errorw("failed to build connector", "user_id", userID, "server", srv.Name, "error", err) + writeError(w, http.StatusInternalServerError, "Failed to initiate connect flow") + return + } + + authURL, _, err := conn.BuildAuthorizationURL(userID) + if err != nil { + h.logger.Errorw("failed to build authorize URL", "user_id", userID, "server", srv.Name, "error", err) + writeError(w, http.StatusInternalServerError, "Failed to initiate connect flow") + return + } + + h.logger.Infow("brokered credential connect initiated", "user_id", userID, "server", srv.Name) + http.Redirect(w, r, authURL, http.StatusFound) +} + +// callback completes Path B: it validates the state, exchanges the code for a +// per-user upstream credential (persisted by the connector under the initiating +// user), and redirects back to the Web UI. A denied/failed authorization clears +// the pending flow and stores nothing. +func (h *CredentialHandlers) callback(w http.ResponseWriter, r *http.Request) { + if _, err := getUserID(r); err != nil { + writeError(w, http.StatusUnauthorized, "Authentication required") + return + } + + srv, ok := h.lookupServer(w, r) + if !ok { + return + } + if srv.AuthBroker.Mode != config.AuthBrokerModeOAuthConnect { + writeError(w, http.StatusBadRequest, fmt.Sprintf("Server %q does not use the OAuth connect flow", srv.Name)) + return + } + + conn, err := h.connectors.connector(srv) + if err != nil { + h.logger.Errorw("failed to resolve connector for callback", "server", srv.Name, "error", err) + writeError(w, http.StatusInternalServerError, "Failed to complete connect flow") + return + } + + q := r.URL.Query() + state := q.Get("state") + + // Authorization-server-side error (e.g. user denied consent). + if asErr := q.Get("error"); asErr != "" { + _ = conn.Deny(state, asErr) + h.logger.Infow("brokered credential connect denied", "server", srv.Name, "reason", asErr) + http.Redirect(w, r, credentialConnectRedirect(srv.Name, asErr), http.StatusFound) + return + } + + code := q.Get("code") + if _, err := conn.Complete(r.Context(), state, code); err != nil { + h.logger.Warnw("brokered credential connect callback failed", "server", srv.Name, "error", err) + http.Redirect(w, r, credentialConnectRedirect(srv.Name, "connect_failed"), http.StatusFound) + return + } + + h.logger.Infow("brokered credential connected", "server", srv.Name) + http.Redirect(w, r, credentialConnectRedirect(srv.Name, ""), http.StatusFound) +} + +// --- Helpers --- + +// lookupServer resolves the {server} path param to a brokered upstream, writing +// a 4xx and returning ok=false when missing/unknown. +func (h *CredentialHandlers) lookupServer(w http.ResponseWriter, r *http.Request) (*config.ServerConfig, bool) { + name := chi.URLParam(r, "server") + if decoded, err := url.PathUnescape(name); err == nil { + name = decoded + } + name = strings.TrimSpace(name) + if name == "" { + writeError(w, http.StatusBadRequest, "Server name is required") + return nil, false + } + srv := h.brokerServerByName(name) + if srv == nil { + writeError(w, http.StatusNotFound, fmt.Sprintf("Brokered server %q not found", name)) + return nil, false + } + return srv, true +} + +// brokerServerList returns the shared servers that carry an auth_broker block. +func (h *CredentialHandlers) brokerServerList() []*config.ServerConfig { + out := make([]*config.ServerConfig, 0, len(h.brokerServers)) + for _, s := range h.brokerServers { + if s != nil && s.AuthBroker != nil { + out = append(out, s) + } + } + return out +} + +// brokerServerByName finds a brokered upstream by case-insensitive name. +func (h *CredentialHandlers) brokerServerByName(name string) *config.ServerConfig { + for _, s := range h.brokerServers { + if s != nil && s.AuthBroker != nil && strings.EqualFold(s.Name, name) { + return s + } + } + return nil +} + +// credentialConnectRedirect builds the post-callback Web UI redirect, tagging +// the outcome so the UI can surface success or an error without exposing secrets. +func credentialConnectRedirect(serverName, errCode string) string { + v := url.Values{} + v.Set("credential_server", serverName) + if errCode != "" { + v.Set("credential_error", errCode) + } else { + v.Set("credential_connected", "1") + } + return credentialConnectSuccessRedirect + "?" + v.Encode() +} diff --git a/internal/serveredition/api/credential_handlers_test.go b/internal/serveredition/api/credential_handlers_test.go new file mode 100644 index 000000000..f7cbfbe71 --- /dev/null +++ b/internal/serveredition/api/credential_handlers_test.go @@ -0,0 +1,361 @@ +//go:build server + +package api + +import ( + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "testing" + "time" + + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.etcd.io/bbolt" + "go.uber.org/zap" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/auth" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/serveredition/broker" +) + +const testUserB = "01HTEST00000000000000USERB" + +// credTestStore builds an enabled AES credential store backed by a temp BBolt DB. +func credTestStore(t *testing.T) broker.CredentialStore { + t.Helper() + tmp := filepath.Join(t.TempDir(), "cred.db") + db, err := bbolt.Open(tmp, 0600, &bbolt.Options{Timeout: time.Second}) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + // 32 zero bytes is a valid AES-256 key for tests. + key := base64.StdEncoding.EncodeToString(make([]byte, 32)) + store, err := broker.NewBBoltAESStore(db, key, zap.NewNop()) + require.NoError(t, err) + require.True(t, store.Enabled()) + return store +} + +// credRouter wires the credential handlers behind an auth context injector. +func credRouter(h *CredentialHandlers, authCtx *auth.AuthContext) *chi.Mux { + r := chi.NewRouter() + r.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := auth.WithAuthContext(req.Context(), authCtx) + next.ServeHTTP(w, req.WithContext(ctx)) + }) + }) + h.RegisterRoutesWithPrefix(r, "/api/v1") + return r +} + +// brokerHTTPServer builds an HTTP-family broker upstream config for the given mode. +func brokerHTTPServer(name, mode string) *config.ServerConfig { + ab := &config.AuthBrokerConfig{ + Mode: mode, + TokenEndpoint: "https://idp.example.com/token", + ClientID: "client-" + name, + Scopes: []string{"repo"}, + } + if mode == config.AuthBrokerModeOAuthConnect { + ab.AuthorizationEndpoint = "https://as.example.com/authorize" + } + return &config.ServerConfig{ + Name: name, + URL: "https://" + name + ".example.com/mcp", + Protocol: "http", + Shared: true, + AuthBroker: ab, + } +} + +func serverKeyFor(s *config.ServerConfig) string { + return oauth.GenerateServerKey(s.Name, s.URL) +} + +func TestCredentialsList_RedactsSecrets(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + require.NoError(t, store.Put(testUserID, serverKeyFor(srv), &broker.UpstreamCredential{ + Type: "oauth2", + AccessToken: "SECRET-ACCESS-TOKEN", + RefreshToken: "SECRET-REFRESH-TOKEN", + ExpiresAt: time.Now().Add(time.Hour), + Scopes: []string{"repo"}, + TokenType: "Bearer", + ObtainedVia: "token_exchange", + })) + + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + body := w.Body.String() + // FR-026: secret material must never appear in the response. + assert.NotContains(t, body, "SECRET-ACCESS-TOKEN") + assert.NotContains(t, body, "SECRET-REFRESH-TOKEN") + + var resp CredentialListResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Len(t, resp.Credentials, 1) + got := resp.Credentials[0] + assert.Equal(t, "shared-gh", got.Server) + assert.Equal(t, credStatusConnected, got.Status) + assert.Equal(t, config.AuthBrokerModeTokenExchange, got.Mode) + assert.Equal(t, []string{"repo"}, got.Scopes) + assert.NotNil(t, got.ExpiresAt) +} + +func TestCredentialsList_Statuses(t *testing.T) { + store := credTestStore(t) + connected := brokerHTTPServer("connected-srv", config.AuthBrokerModeTokenExchange) + expired := brokerHTTPServer("expired-srv", config.AuthBrokerModeTokenExchange) + fresh := brokerHTTPServer("fresh-srv", config.AuthBrokerModeOAuthConnect) + + require.NoError(t, store.Put(testUserID, serverKeyFor(connected), &broker.UpstreamCredential{ + Type: "oauth2", AccessToken: "a", ExpiresAt: time.Now().Add(time.Hour), + })) + require.NoError(t, store.Put(testUserID, serverKeyFor(expired), &broker.UpstreamCredential{ + Type: "oauth2", AccessToken: "b", ExpiresAt: time.Now().Add(-time.Hour), + })) + + h := NewCredentialHandlers(store, []*config.ServerConfig{connected, expired, fresh}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var resp CredentialListResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + byName := map[string]CredentialStatus{} + for _, c := range resp.Credentials { + byName[c.Server] = c + } + require.Len(t, byName, 3) + assert.Equal(t, credStatusConnected, byName["connected-srv"].Status) + assert.Equal(t, credStatusExpired, byName["expired-srv"].Status) + assert.Equal(t, credStatusNotConnected, byName["fresh-srv"].Status) + // oauth_connect upstreams expose an actionable connect path. + assert.Equal(t, "/api/v1/user/credentials/fresh-srv/connect", byName["fresh-srv"].ConnectPath) + assert.Empty(t, byName["connected-srv"].ConnectPath) +} + +func TestCredentialsList_StoreDisabled(t *testing.T) { + // Disabled store (no key): every broker upstream reports "unavailable". + tmp := filepath.Join(t.TempDir(), "cred.db") + db, err := bbolt.Open(tmp, 0600, &bbolt.Options{Timeout: time.Second}) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + store, err := broker.NewBBoltAESStore(db, "", zap.NewNop()) + require.NoError(t, err) + require.False(t, store.Enabled()) + + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var resp CredentialListResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Len(t, resp.Credentials, 1) + assert.Equal(t, credStatusUnavailable, resp.Credentials[0].Status) +} + +func TestCredentialsDelete_Removes(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + sk := serverKeyFor(srv) + require.NoError(t, store.Put(testUserID, sk, &broker.UpstreamCredential{ + Type: "oauth2", AccessToken: "a", ExpiresAt: time.Now().Add(time.Hour), + })) + + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/credentials/shared-gh", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + _, err := store.Get(testUserID, sk) + assert.ErrorIs(t, err, broker.ErrNotFound) +} + +func TestCredentialsDelete_UnknownServer404(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodDelete, "/api/v1/user/credentials/does-not-exist", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) +} + +func TestCredentials_CrossUserIsolation(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + sk := serverKeyFor(srv) + // User B has a valid credential. + require.NoError(t, store.Put(testUserB, sk, &broker.UpstreamCredential{ + Type: "oauth2", AccessToken: "B-SECRET", ExpiresAt: time.Now().Add(time.Hour), + })) + + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + // Act as user A. + r := credRouter(h, auth.UserContext(testUserID, "a@example.com", "A", "google")) + + // FR-027: A must not see B's credential — A sees not_connected. + listReq := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials", http.NoBody) + listW := httptest.NewRecorder() + r.ServeHTTP(listW, listReq) + require.Equal(t, http.StatusOK, listW.Code) + assert.NotContains(t, listW.Body.String(), "B-SECRET") + var resp CredentialListResponse + require.NoError(t, json.Unmarshal(listW.Body.Bytes(), &resp)) + require.Len(t, resp.Credentials, 1) + assert.Equal(t, credStatusNotConnected, resp.Credentials[0].Status) + + // FR-027: A deleting must not remove B's credential. + delReq := httptest.NewRequest(http.MethodDelete, "/api/v1/user/credentials/shared-gh", http.NoBody) + delW := httptest.NewRecorder() + r.ServeHTTP(delW, delReq) + require.Equal(t, http.StatusOK, delW.Code) + + bCred, err := store.Get(testUserB, sk) + require.NoError(t, err) + assert.Equal(t, "B-SECRET", bCred.AccessToken) +} + +func TestCredentialsConnect_Redirects(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("connect-srv", config.AuthBrokerModeOAuthConnect) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials/connect-srv/connect", http.NoBody) + req.Host = "gw.example.com" + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusFound, w.Code) + loc, err := url.Parse(w.Header().Get("Location")) + require.NoError(t, err) + assert.Equal(t, "as.example.com", loc.Host) + q := loc.Query() + assert.NotEmpty(t, q.Get("state")) + assert.Equal(t, "client-connect-srv", q.Get("client_id")) + assert.NotEmpty(t, q.Get("code_challenge")) + assert.Contains(t, q.Get("redirect_uri"), "/api/v1/user/credentials/connect-srv/callback") +} + +func TestCredentialsConnect_NonConnectMode400(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("xchg-srv", config.AuthBrokerModeTokenExchange) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials/xchg-srv/connect", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + +func TestCredentialsConnectCallback_StoresCredential(t *testing.T) { + // Upstream token endpoint that mints a credential for any code. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"NEW-ACCESS","refresh_token":"NEW-REFRESH","token_type":"Bearer","expires_in":3600,"scope":"repo"}`) + })) + defer ts.Close() + + store := credTestStore(t) + srv := brokerHTTPServer("connect-srv", config.AuthBrokerModeOAuthConnect) + srv.AuthBroker.TokenEndpoint = ts.URL + sk := serverKeyFor(srv) + + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + // Step 1: connect → capture state from the redirect. + connReq := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials/connect-srv/connect", http.NoBody) + connReq.Host = "gw.example.com" + connW := httptest.NewRecorder() + r.ServeHTTP(connW, connReq) + require.Equal(t, http.StatusFound, connW.Code) + loc, err := url.Parse(connW.Header().Get("Location")) + require.NoError(t, err) + state := loc.Query().Get("state") + require.NotEmpty(t, state) + + // Step 2: callback with code+state on the same handler instance. + cbURL := "/api/v1/user/credentials/connect-srv/callback?code=auth-code&state=" + url.QueryEscape(state) + cbReq := httptest.NewRequest(http.MethodGet, cbURL, http.NoBody) + cbReq.Host = "gw.example.com" + cbW := httptest.NewRecorder() + r.ServeHTTP(cbW, cbReq) + require.Equal(t, http.StatusFound, cbW.Code) + + // The per-user credential is now persisted. + cred, err := store.Get(testUserID, sk) + require.NoError(t, err) + assert.Equal(t, "NEW-ACCESS", cred.AccessToken) + assert.Equal(t, "NEW-REFRESH", cred.RefreshToken) +} + +func TestCredentialsCallback_DeniedByUpstream(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("connect-srv", config.AuthBrokerModeOAuthConnect) + sk := serverKeyFor(srv) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + r := credRouter(h, defaultAuthContext()) + + // Begin a flow to register a state. + connReq := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials/connect-srv/connect", http.NoBody) + connReq.Host = "gw.example.com" + connW := httptest.NewRecorder() + r.ServeHTTP(connW, connReq) + require.Equal(t, http.StatusFound, connW.Code) + loc, _ := url.Parse(connW.Header().Get("Location")) + state := loc.Query().Get("state") + + cbURL := "/api/v1/user/credentials/connect-srv/callback?error=access_denied&state=" + url.QueryEscape(state) + cbReq := httptest.NewRequest(http.MethodGet, cbURL, http.NoBody) + cbW := httptest.NewRecorder() + r.ServeHTTP(cbW, cbReq) + // Denied flow redirects back to the UI and stores nothing. + require.Equal(t, http.StatusFound, cbW.Code) + _, err := store.Get(testUserID, sk) + assert.ErrorIs(t, err, broker.ErrNotFound) +} + +func TestCredentials_Unauthenticated(t *testing.T) { + store := credTestStore(t) + srv := brokerHTTPServer("shared-gh", config.AuthBrokerModeTokenExchange) + h := NewCredentialHandlers(store, []*config.ServerConfig{srv}, zap.NewNop().Sugar()) + // Empty auth context → unauthenticated. + r := credRouter(h, &auth.AuthContext{}) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/user/credentials", http.NoBody) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusUnauthorized, w.Code) +} diff --git a/internal/serveredition/setup.go b/internal/serveredition/setup.go index 7f684cd3b..8201c6a34 100644 --- a/internal/serveredition/setup.go +++ b/internal/serveredition/setup.go @@ -87,6 +87,10 @@ func setupMultiUserOAuth(deps Dependencies) error { adminHandlers := teamsapi.NewAdminHandlers(userStore, nil, sessionManager, cfg.AdminEmails, sharedServers, deps.Config, configPath, deps.ManagementService, deps.Logger) userHandlers := teamsapi.NewUserHandlers(userStore, sharedServers, deps.StorageManager, hmacKey, deps.Logger) userActivityHandlers := teamsapi.NewUserActivityHandlers(nil, userStore, sharedServers, deps.Logger) + // Per-user brokered-credential surfaces (spec 074 T8): list connection + // status, disconnect, and the Path B connect/callback flow. Reuses the same + // credential store wired into the OAuth login handler above. + credentialHandlers := teamsapi.NewCredentialHandlers(credStore, sharedServers, deps.Logger) deps.Router.Group(func(r chi.Router) { r.Use(authMiddleware.Middleware()) @@ -95,6 +99,7 @@ func setupMultiUserOAuth(deps Dependencies) error { adminHandlers.RegisterRoutesWithPrefix(r, "/api/v1") userHandlers.RegisterRoutesWithPrefix(r, "/api/v1") userActivityHandlers.RegisterRoutesWithPrefix(r, "/api/v1") + credentialHandlers.RegisterRoutesWithPrefix(r, "/api/v1") }) deps.Logger.Infow("Server multi-user OAuth initialized",