diff --git a/docs/features/auth-broker.md b/docs/features/auth-broker.md index 55e687c79..96b44a662 100644 --- a/docs/features/auth-broker.md +++ b/docs/features/auth-broker.md @@ -79,6 +79,17 @@ Config validation fails with `auth_broker.authorization_endpoint is required for A denied consent (`error=access_denied`) clears the pending flow and stores nothing. +## Credential resolution + +On each proxied request the broker resolves the per-user credential to inject, in a strict **per-user-only** order. There is **no shared or static fallback** — a request that cannot produce a per-user credential fails rather than borrowing another identity: + +1. A valid cached per-user credential is injected directly; if it is within the near-expiry window it is refreshed first (re-minted for `token_exchange`/`entra_obo`, or renewed from the stored refresh token for `oauth_connect`). +2. Otherwise, for `token_exchange`/`entra_obo`, a credential is minted from the user's stored IdP subject token. +3. Otherwise, for `oauth_connect` upstreams the user has not connected — or whose stored credential expired and could not be refreshed — the request fails with an **actionable error carrying the connect URL**, so the user is told to (re)connect rather than being silently denied. +4. Otherwise the request fails with "no per-user credential available". + +Concurrent requests for the same `(user, upstream)` are coalesced (single-flight) so a burst does not trigger duplicate upstream token flows. A policy-decision hook is evaluated per call immediately before the credential is returned; no policy engine ships yet, so it permits every injection by default. + ## See also - [OAuth Authentication](./oauth-authentication.md) — upstream OAuth for the personal edition. diff --git a/go.mod b/go.mod index 87b7ff96f..67846fa0f 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( go.opentelemetry.io/otel/trace v1.44.0 go.uber.org/zap v1.28.0 golang.org/x/mod v0.37.0 + golang.org/x/sync v0.20.0 golang.org/x/sys v0.46.0 golang.org/x/term v0.44.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -139,7 +140,6 @@ require ( go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.55.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/text v0.37.0 // indirect golang.org/x/tools v0.45.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa // indirect diff --git a/internal/serveredition/broker/credential_resolver.go b/internal/serveredition/broker/credential_resolver.go new file mode 100644 index 000000000..932cc75de --- /dev/null +++ b/internal/serveredition/broker/credential_resolver.go @@ -0,0 +1,349 @@ +//go:build server + +package broker + +import ( + "context" + "errors" + "fmt" + "time" + + "go.uber.org/zap" + "golang.org/x/sync/singleflight" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" +) + +// defaultRefreshThreshold is how close to expiry a cached credential may be +// before the resolver proactively refreshes it. A credential expiring within +// this window is treated as stale (FR-013). +const defaultRefreshThreshold = 60 * time.Second + +// Sentinel errors returned by the resolver. They are deliberately coarse and +// secret-free so they can be surfaced to callers and audited (FR-014/FR-019). +var ( + // ErrUnauthenticated is returned when Resolve is called without a user + // identity. Brokering is strictly per-user; an anonymous caller is rejected + // before any store or upstream access (FR-014). + ErrUnauthenticated = errors.New("credential resolver: unauthenticated caller") + + // ErrNoCredential is returned when no per-user credential can be produced and + // no actionable connect flow is available. There is deliberately no shared or + // static fallback (FR-014). + ErrNoCredential = errors.New("credential resolver: no per-user credential available") + + // ErrBrokerNotConfigured is returned when the target server has no auth_broker + // block. Such upstreams are not brokered and behave exactly as today. + ErrBrokerNotConfigured = errors.New("credential resolver: server has no auth_broker configuration") +) + +// Exchanger mints an upstream credential by exchanging the user's stored IdP +// subject token (token_exchange / entra_obo). *TokenExchanger satisfies it. +type Exchanger interface { + Exchange(ctx context.Context, userID, serverKey string, cfg *config.AuthBrokerConfig) (*UpstreamCredential, error) +} + +// Connector drives the per-user OAuth connect flow (Path B). *OAuthConnector +// satisfies it. The resolver uses Refresh to renew a near-expiry connect-flow +// credential and BuildAuthorizationURL to produce an actionable connect URL +// when the user has not yet connected the upstream. +type Connector interface { + ServerKey() string + BuildAuthorizationURL(userID string) (authURL, state string, err error) + Refresh(ctx context.Context, userID string) (*UpstreamCredential, error) +} + +// ConnectorProvider resolves the per-upstream OAuthConnector for a server. The +// REST layer (T8) supplies an implementation that assembles a ConnectorConfig +// from the server's auth_broker block plus the gateway's callback URL. It is +// only consulted for oauth_connect-mode upstreams. +type ConnectorProvider interface { + ConnectorFor(server *config.ServerConfig) (Connector, error) +} + +// NotConnectedError is returned when an oauth_connect upstream cannot produce a +// usable per-user credential and the user must (re)consent. It carries the +// authorize URL the caller redirects the user to (FR-013, actionable error) and +// a Reason that distinguishes a first-time connect from an expired credential +// whose refresh failed (so callers do not tell an already-connected user they +// have "never connected"). +type NotConnectedError struct { + ServerName string + ConnectURL string + Reason string +} + +func (e *NotConnectedError) Error() string { + if e.Reason != "" { + return fmt.Sprintf("credential resolver: upstream %q requires connection (%s); connect at: %s", + e.ServerName, e.Reason, e.ConnectURL) + } + return fmt.Sprintf("credential resolver: upstream %q is not connected for this user; connect at: %s", + e.ServerName, e.ConnectURL) +} + +// PolicyDecision is the verdict of the policy-decision seam evaluated before a +// resolved credential is returned. Allow=false blocks the injection. +type PolicyDecision struct { + Allow bool + Reason string +} + +// PolicyInput is the context handed to the policy seam. +type PolicyInput struct { + UserID string + ServerName string + ServerKey string + Credential *UpstreamCredential +} + +// PolicyHook is the policy-decision seam (FR-015). No policy engine ships now; +// the resolver defaults to an allow-all hook. A future engine implements this +// interface without changing the resolver. +type PolicyHook interface { + Evaluate(ctx context.Context, in PolicyInput) (PolicyDecision, error) +} + +// PolicyHookFunc adapts a function to the PolicyHook interface. +type PolicyHookFunc func(ctx context.Context, in PolicyInput) (PolicyDecision, error) + +// Evaluate implements PolicyHook. +func (f PolicyHookFunc) Evaluate(ctx context.Context, in PolicyInput) (PolicyDecision, error) { + return f(ctx, in) +} + +// allowAllPolicy is the default seam implementation: it permits every +// injection. It exists so the resolver always has a non-nil hook (FR-015). +type allowAllPolicy struct{} + +func (allowAllPolicy) Evaluate(_ context.Context, _ PolicyInput) (PolicyDecision, error) { + return PolicyDecision{Allow: true}, nil +} + +// PolicyDeniedError is returned when the policy seam blocks a resolved +// credential from being injected. +type PolicyDeniedError struct { + ServerName string + Reason string +} + +func (e *PolicyDeniedError) Error() string { + if e.Reason != "" { + return fmt.Sprintf("credential resolver: policy denied credential for %q: %s", e.ServerName, e.Reason) + } + return fmt.Sprintf("credential resolver: policy denied credential for %q", e.ServerName) +} + +// ResolverDeps are the collaborators a CredentialResolver needs. Store and +// Exchanger are required for token-exchange upstreams; Connectors is required +// only for oauth_connect upstreams. Policy and Logger are optional. +type ResolverDeps struct { + Store CredentialStore + Exchanger Exchanger + Connectors ConnectorProvider + Policy PolicyHook + Logger *zap.Logger + RefreshThreshold time.Duration +} + +// CredentialResolver produces the per-user upstream credential to inject on a +// proxied request. It applies a strict per-user-only ordering (FR-013/FR-014): +// +// 1. a valid cached per-user credential (refreshed if near-expiry); +// 2. else a freshly token-exchanged / OBO credential from the stored IdP +// subject token; +// 3. else, for oauth_connect upstreams the user has not connected, an +// actionable NotConnectedError carrying the connect URL; +// 4. else ErrNoCredential. +// +// There is no shared or static fallback. Concurrent acquisitions for the same +// (user, server) are coalesced via single-flight so the upstream authorization +// server is not hit with duplicate flows. +type CredentialResolver struct { + store CredentialStore + exchanger Exchanger + conns ConnectorProvider + policy PolicyHook + logger *zap.Logger + + refreshThreshold time.Duration + group singleflight.Group +} + +// NewCredentialResolver constructs a resolver from its dependencies, applying +// defaults for the optional fields. +func NewCredentialResolver(deps ResolverDeps) *CredentialResolver { + logger := deps.Logger + if logger == nil { + logger = zap.NewNop() + } + policy := deps.Policy + if policy == nil { + policy = allowAllPolicy{} + } + threshold := deps.RefreshThreshold + if threshold <= 0 { + threshold = defaultRefreshThreshold + } + return &CredentialResolver{ + store: deps.Store, + exchanger: deps.Exchanger, + conns: deps.Connectors, + policy: policy, + logger: logger.Named("credential-resolver"), + refreshThreshold: threshold, + } +} + +// Resolve returns the per-user credential to inject for (userID, server), +// applying the ordering described on CredentialResolver. The policy seam is +// evaluated per call after acquisition; credential acquisition itself is +// coalesced per (user, server) via single-flight. +func (r *CredentialResolver) Resolve(ctx context.Context, userID string, server *config.ServerConfig) (*UpstreamCredential, error) { + if userID == "" { + return nil, ErrUnauthenticated + } + if server == nil || server.AuthBroker == nil { + return nil, ErrBrokerNotConfigured + } + if r.store == nil || !r.store.Enabled() { + return nil, ErrStoreDisabled + } + + serverKey := oauth.GenerateServerKey(server.Name, server.URL) + + // Coalesce concurrent acquisitions for the same (user, server) so duplicate + // upstream token flows are not triggered (reuse the single-flight pattern). + // + // The flight runs the acquisition once for every co-pending caller. Detach + // the caller's cancellation with context.WithoutCancel so the in-flight + // acquisition is not aborted — and its error broadcast to all waiters — just + // because whichever caller happened to start the flight cancelled (client + // disconnect, timeout). Per-caller cancellation still applies below at the + // policy/return layer, which uses the caller's original ctx. + flightKey := userID + "\x00" + serverKey + v, err, _ := r.group.Do(flightKey, func() (interface{}, error) { + return r.acquire(context.WithoutCancel(ctx), userID, serverKey, server) + }) + if err != nil { + return nil, err + } + cred, ok := v.(*UpstreamCredential) + if !ok || cred == nil { + return nil, ErrNoCredential + } + + // Policy-decision seam: evaluated per call, before the credential is handed + // to the caller (FR-015). Default hook allows everything. + decision, perr := r.policy.Evaluate(ctx, PolicyInput{ + UserID: userID, + ServerName: server.Name, + ServerKey: serverKey, + Credential: cred, + }) + if perr != nil { + return nil, fmt.Errorf("credential resolver: policy evaluation failed: %w", perr) + } + if !decision.Allow { + return nil, &PolicyDeniedError{ServerName: server.Name, Reason: decision.Reason} + } + return cred, nil +} + +// acquire runs the per-user-only ordering for a single (user, server). It is +// invoked inside the single-flight group. +// +// Acquisition and refresh share a path per mode so a near-expiry cache miss does +// not trigger a redundant double acquisition. The Exchanger (T4) and Connector +// (T5) persist their results into the store themselves, so the resolver never +// calls store.Put — it only reads the cache via store.Get. +func (r *CredentialResolver) acquire(ctx context.Context, userID, serverKey string, server *config.ServerConfig) (*UpstreamCredential, error) { + cfg := server.AuthBroker + + // 1. Serve a still-valid, not-near-expiry cached credential directly. + cached, err := r.store.Get(userID, serverKey) + hasCache := err == nil && cached != nil + switch { + case hasCache: + if cached.IsValid() && !cached.ExpiresWithin(r.refreshThreshold) { + return cached, nil + } + // Stale / near-expiry: renewed by the per-mode path below. + case errors.Is(err, ErrNotFound): + // No cache: acquired by the per-mode path below. + default: + // Unexpected store error (not "missing"): surface it. + return nil, fmt.Errorf("credential resolver: load cached credential: %w", err) + } + + switch cfg.Mode { + case config.AuthBrokerModeTokenExchange, config.AuthBrokerModeEntraOBO: + // 2. Token-exchange / OBO: the first-acquisition and refresh paths are + // identical (re-mint from the stored IdP subject token), so a single + // Exchange call covers both the cache-miss and near-expiry cases. + if r.exchanger == nil { + return nil, fmt.Errorf("credential resolver: no token exchanger configured for mode %q", cfg.Mode) + } + return r.exchanger.Exchange(ctx, userID, serverKey, cfg) + + case config.AuthBrokerModeOAuthConnect: + conn, cerr := r.connectorFor(server) + if cerr != nil { + return nil, cerr + } + // A cached connect-flow credential means the user already connected: + // renew transparently via the stored refresh token. Only when that + // refresh fails do we ask the (already-connected) user to reconnect. + if hasCache && cached.RefreshToken != "" { + refreshed, rerr := conn.Refresh(ctx, userID) + if rerr == nil { + return refreshed, nil + } + r.logger.Warn("connect-flow credential refresh failed; user must reconnect", + zap.String("server", server.Name), zap.Error(rerr)) + return nil, r.notConnected(conn, server, userID, "stored credential expired and refresh failed; reconnect required") + } + // 3. Never connected, or connected without a usable refresh token and now + // expired — both require (re)consent through the connect flow. + reason := "not connected" + if hasCache { + reason = "stored credential expired; reconnect required" + } + return nil, r.notConnected(conn, server, userID, reason) + + default: + // 4. No recognised acquisition strategy and no per-user credential. + return nil, ErrNoCredential + } +} + +// notConnected builds the actionable NotConnectedError carrying the upstream +// authorize URL the caller must redirect the user to, tagged with reason. +func (r *CredentialResolver) notConnected(conn Connector, server *config.ServerConfig, userID, reason string) error { + authURL, _, aerr := conn.BuildAuthorizationURL(userID) + if aerr != nil { + return fmt.Errorf("credential resolver: build connect URL: %w", aerr) + } + return &NotConnectedError{ServerName: server.Name, ConnectURL: authURL, Reason: reason} +} + +// connectorFor resolves the per-upstream connector, guarding against a missing +// provider (only oauth_connect upstreams need one). +func (r *CredentialResolver) connectorFor(server *config.ServerConfig) (Connector, error) { + if r.conns == nil { + return nil, fmt.Errorf("credential resolver: no connector provider configured for oauth_connect upstream %q", server.Name) + } + conn, err := r.conns.ConnectorFor(server) + if err != nil { + return nil, fmt.Errorf("credential resolver: resolve connector: %w", err) + } + return conn, nil +} + +// Compile-time assertions that the concrete broker types satisfy the resolver's +// collaborator interfaces. +var ( + _ Exchanger = (*TokenExchanger)(nil) + _ Connector = (*OAuthConnector)(nil) +) diff --git a/internal/serveredition/broker/credential_resolver_test.go b/internal/serveredition/broker/credential_resolver_test.go new file mode 100644 index 000000000..0ae0fb6b2 --- /dev/null +++ b/internal/serveredition/broker/credential_resolver_test.go @@ -0,0 +1,485 @@ +//go:build server + +package broker + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" +) + +// --- test doubles ----------------------------------------------------------- + +// fakeStore is an in-memory CredentialStore keyed by (userID, serverKey), +// matching the real backend's keying so resolver key derivation is exercised. +type fakeStore struct { + mu sync.Mutex + enabled bool + data map[string]*UpstreamCredential + getErr error +} + +func newFakeStore() *fakeStore { + return &fakeStore{enabled: true, data: map[string]*UpstreamCredential{}} +} + +func storeKey(userID, serverKey string) string { return userID + "\x00" + serverKey } + +func (s *fakeStore) Enabled() bool { return s.enabled } + +func (s *fakeStore) Get(userID, serverKey string) (*UpstreamCredential, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.getErr != nil { + return nil, s.getErr + } + c, ok := s.data[storeKey(userID, serverKey)] + if !ok { + return nil, ErrNotFound + } + return c, nil +} + +func (s *fakeStore) Put(userID, serverKey string, cred *UpstreamCredential) error { + s.mu.Lock() + defer s.mu.Unlock() + s.data[storeKey(userID, serverKey)] = cred + return nil +} + +func (s *fakeStore) Delete(userID, serverKey string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.data, storeKey(userID, serverKey)) + return nil +} + +func (s *fakeStore) List(userID string) ([]CredentialEntry, error) { return nil, nil } + +func (s *fakeStore) seed(userID, serverKey string, cred *UpstreamCredential) { + s.mu.Lock() + defer s.mu.Unlock() + s.data[storeKey(userID, serverKey)] = cred +} + +// fakeExchanger records calls and returns a programmed credential/error. +type fakeExchanger struct { + calls int32 + cred *UpstreamCredential + err error + delay time.Duration + startWG *sync.WaitGroup + gotCtxErr error +} + +func (e *fakeExchanger) Exchange(ctx context.Context, userID, serverKey string, _ *config.AuthBrokerConfig) (*UpstreamCredential, error) { + e.gotCtxErr = ctx.Err() + if e.startWG != nil { + e.startWG.Done() + } + if e.delay > 0 { + time.Sleep(e.delay) + } + atomic.AddInt32(&e.calls, 1) + if e.err != nil { + return nil, e.err + } + return e.cred, nil +} + +// fakeConnector implements Connector for connect-flow paths. +type fakeConnector struct { + serverKey string + authURL string + buildErr error + refreshCred *UpstreamCredential + refreshErr error + buildCalls int32 + refreshCalls int32 +} + +func (c *fakeConnector) ServerKey() string { return c.serverKey } + +func (c *fakeConnector) BuildAuthorizationURL(_ string) (string, string, error) { + atomic.AddInt32(&c.buildCalls, 1) + if c.buildErr != nil { + return "", "", c.buildErr + } + return c.authURL, "state-xyz", nil +} + +func (c *fakeConnector) Refresh(_ context.Context, _ string) (*UpstreamCredential, error) { + atomic.AddInt32(&c.refreshCalls, 1) + if c.refreshErr != nil { + return nil, c.refreshErr + } + return c.refreshCred, nil +} + +type fakeConnectorProvider struct { + conn *fakeConnector + err error +} + +func (p *fakeConnectorProvider) ConnectorFor(_ *config.ServerConfig) (Connector, error) { + if p.err != nil { + return nil, p.err + } + return p.conn, nil +} + +// --- fixtures ---------------------------------------------------------------- + +func httpServer(name string, broker *config.AuthBrokerConfig) *config.ServerConfig { + return &config.ServerConfig{ + Name: name, + URL: "https://" + name + ".example.com/mcp", + Protocol: "http", + AuthBroker: broker, + } +} + +func tokenExchangeBroker() *config.AuthBrokerConfig { + b := &config.AuthBrokerConfig{Mode: config.AuthBrokerModeTokenExchange, TokenEndpoint: "https://idp/token", Scopes: []string{"api"}} + b.ApplyDefaults() + return b +} + +func connectBroker() *config.AuthBrokerConfig { + b := &config.AuthBrokerConfig{ + Mode: config.AuthBrokerModeOAuthConnect, + TokenEndpoint: "https://idp/token", + AuthorizationEndpoint: "https://idp/authorize", + ClientID: "client", + } + b.ApplyDefaults() + return b +} + +func validCred() *UpstreamCredential { + return &UpstreamCredential{Type: "oauth2", AccessToken: "cached-token", ExpiresAt: time.Now().Add(time.Hour), ObtainedVia: "token_exchange"} +} + +// --- tests ------------------------------------------------------------------- + +func TestResolve_ValidCachedCredential(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, validCred()) + + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "fresh"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "cached-token" { + t.Fatalf("expected cached token, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&ex.calls); c != 0 { + t.Fatalf("expected no exchange calls for valid cache, got %d", c) + } +} + +func TestResolve_NearExpiryRefresh_TokenExchange(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + nearExpiry := &UpstreamCredential{AccessToken: "old", ExpiresAt: time.Now().Add(10 * time.Second)} + store.seed("alice", key, nearExpiry) + + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "refreshed"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "refreshed" { + t.Fatalf("expected refreshed token, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("expected 1 exchange call, got %d", c) + } +} + +func TestResolve_NearExpiryRefresh_ConnectFlow(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", RefreshToken: "rt", ExpiresAt: time.Now().Add(5 * time.Second)}) + + conn := &fakeConnector{serverKey: key, refreshCred: &UpstreamCredential{AccessToken: "refreshed-connect"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "refreshed-connect" { + t.Fatalf("expected refreshed-connect, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&conn.refreshCalls); c != 1 { + t.Fatalf("expected 1 refresh call, got %d", c) + } +} + +func TestResolve_NoCache_TokenExchange(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "exchanged" { + t.Fatalf("expected exchanged token, got %q", got.AccessToken) + } +} + +func TestResolve_NoCache_EntraOBO(t *testing.T) { + store := newFakeStore() + b := &config.AuthBrokerConfig{Mode: config.AuthBrokerModeEntraOBO, TokenEndpoint: "https://login.microsoftonline.com/token"} + b.ApplyDefaults() + server := httpServer("graph", b) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "obo"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "obo" { + t.Fatalf("expected obo token, got %q", got.AccessToken) + } +} + +func TestResolve_ConnectUnconnected_ReturnsActionableConnectURL(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + conn := &fakeConnector{serverKey: key, authURL: "https://idp/authorize?client_id=client&state=state-xyz"} + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected NotConnectedError, got nil") + } + var nce *NotConnectedError + if !errors.As(err, &nce) { + t.Fatalf("expected *NotConnectedError, got %T: %v", err, err) + } + if nce.ConnectURL != conn.authURL { + t.Fatalf("expected connect URL %q in error, got %q", conn.authURL, nce.ConnectURL) + } + if !strings.Contains(err.Error(), conn.authURL) { + t.Fatalf("error message must surface the connect URL, got %q", err.Error()) + } +} + +func TestResolve_Unauthenticated_Rejected(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: validCred()} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + _, err := r.Resolve(context.Background(), "", server) + if !errors.Is(err, ErrUnauthenticated) { + t.Fatalf("expected ErrUnauthenticated, got %v", err) + } + if c := atomic.LoadInt32(&ex.calls); c != 0 { + t.Fatalf("expected no work for unauthenticated caller, got %d exchange calls", c) + } +} + +func TestResolve_StoreDisabled_DegradesGracefully(t *testing.T) { + store := newFakeStore() + store.enabled = false + server := httpServer("grafana", tokenExchangeBroker()) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{cred: validCred()}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if !errors.Is(err, ErrStoreDisabled) { + t.Fatalf("expected ErrStoreDisabled, got %v", err) + } +} + +func TestResolve_NoBrokerConfig_Rejected(t *testing.T) { + store := newFakeStore() + server := httpServer("plain", nil) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected error for server without auth_broker, got nil") + } +} + +func TestResolve_NoStaticFallback_OnExchangeFailure(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{err: errors.New("token exchange failed: status 401, error \"invalid_grant\"")} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected the exchange error to propagate (no static fallback), got nil") + } + if got != nil { + t.Fatalf("expected no credential on failure (FR-014, no shared fallback), got %+v", got) + } +} + +func TestResolve_PolicyHook_DeniesInjection(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, validCred()) + + policy := PolicyHookFunc(func(_ context.Context, in PolicyInput) (PolicyDecision, error) { + return PolicyDecision{Allow: false, Reason: "blocked by policy for " + in.ServerName}, nil + }) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{}, Policy: policy}) + + _, err := r.Resolve(context.Background(), "alice", server) + var pde *PolicyDeniedError + if !errors.As(err, &pde) { + t.Fatalf("expected *PolicyDeniedError, got %T: %v", err, err) + } + if !strings.Contains(pde.Reason, "grafana") { + t.Fatalf("expected reason to include server name, got %q", pde.Reason) + } +} + +func TestResolve_SingleFlight_CoalescesConcurrentAcquisitions(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + + const n = 12 + var start sync.WaitGroup + start.Add(1) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}, delay: 40 * time.Millisecond} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + var wg sync.WaitGroup + errs := make([]error, n) + toks := make([]string, n) + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + start.Wait() + cred, err := r.Resolve(context.Background(), "alice", server) + errs[idx] = err + if cred != nil { + toks[idx] = cred.AccessToken + } + }(i) + } + start.Done() // release all goroutines together + wg.Wait() + + for i := 0; i < n; i++ { + if errs[i] != nil { + t.Fatalf("goroutine %d errored: %v", i, errs[i]) + } + if toks[i] != "exchanged" { + t.Fatalf("goroutine %d got %q", i, toks[i]) + } + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("single-flight should coalesce to 1 upstream acquisition, got %d", c) + } +} + +// TestResolve_SingleFlight_DetachesCallerCancellation proves the must-fix from +// review: the in-flight acquisition must not inherit the calling request's +// cancellation, or a cancelled caller would broadcast its ctx error to every +// co-pending acquisition for the same (user, server). +func TestResolve_SingleFlight_DetachesCallerCancellation(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // caller's request is already cancelled before acquisition runs + + got, err := r.Resolve(ctx, "alice", server) + if err != nil { + t.Fatalf("acquisition should run despite caller cancellation, got error: %v", err) + } + if got.AccessToken != "exchanged" { + t.Fatalf("expected exchanged token, got %q", got.AccessToken) + } + if ex.gotCtxErr != nil { + t.Fatalf("flight context must be detached from caller cancellation, got ctx.Err()=%v", ex.gotCtxErr) + } +} + +// TestResolve_TokenExchange_NearExpiry_NoDoubleExchangeOnFailure proves the +// advisory fix: a near-expiry token-exchange credential whose re-mint fails must +// surface that single error, not retry Exchange a second time. +func TestResolve_TokenExchange_NearExpiry_NoDoubleExchangeOnFailure(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", ExpiresAt: time.Now().Add(5 * time.Second)}) + + ex := &fakeExchanger{err: errors.New("token exchange failed: status 401, error \"invalid_grant\"")} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected the exchange error to propagate, got nil") + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("near-expiry exchange failure must not double-call Exchange, got %d calls", c) + } +} + +// TestResolve_ConnectFlow_RefreshFails_ReturnsReconnectError proves the advisory +// fix: an already-connected user whose refresh fails gets an actionable +// reconnect error (with the connect URL), not a misleading "never connected". +func TestResolve_ConnectFlow_RefreshFails_ReturnsReconnectError(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", RefreshToken: "rt", ExpiresAt: time.Now().Add(5 * time.Second)}) + + conn := &fakeConnector{ + serverKey: key, + authURL: "https://idp/authorize?client_id=client&state=state-xyz", + refreshErr: errors.New("oauth connector: token endpoint returned 400: invalid_grant"), + } + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + _, err := r.Resolve(context.Background(), "alice", server) + var nce *NotConnectedError + if !errors.As(err, &nce) { + t.Fatalf("expected *NotConnectedError, got %T: %v", err, err) + } + if nce.Reason == "" || !strings.Contains(nce.Reason, "reconnect") { + t.Fatalf("expected a reconnect reason, got %q", nce.Reason) + } + if nce.ConnectURL != conn.authURL { + t.Fatalf("expected connect URL %q, got %q", conn.authURL, nce.ConnectURL) + } + if c := atomic.LoadInt32(&conn.refreshCalls); c != 1 { + t.Fatalf("expected exactly 1 refresh attempt, got %d", c) + } + if c := atomic.LoadInt32(&conn.buildCalls); c != 1 { + t.Fatalf("expected the connect URL to be built once, got %d", c) + } +}