From 8dd631283367d1008fde46ec4b45e4de8a434e49 Mon Sep 17 00:00:00 2001 From: Algis Dumbris Date: Mon, 15 Jun 2026 15:42:36 +0300 Subject: [PATCH 1/3] =?UTF-8?q?feat(broker):=20CredentialResolver=20?= =?UTF-8?q?=E2=80=94=20per-user-only=20ordering=20+=20policy=20seam=20(spe?= =?UTF-8?q?c=20074,=20MCP-1039)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the per-user credential resolver (T6) that selects which brokered credential to inject on a proxied request. Strict per-user-only ordering (FR-013/FR-014), with no shared or static fallback: 1. valid cached per-user credential (refreshed if near-expiry); 2. else token-exchange / Entra OBO from the stored IdP subject token; 3. else, for oauth_connect upstreams the user has not connected, an actionable NotConnectedError carrying the connect URL; 4. else ErrNoCredential. - Single-flight coalescing per (user, server) so concurrent acquisitions do not trigger duplicate upstream token flows (golang.org/x/sync). - PolicyHook seam (FR-015) evaluated per call before a credential is returned; ships with an allow-all default (no policy engine yet). - Unauthenticated callers and a disabled store are rejected up front. - Exchanger/Connector/ConnectorProvider interfaces decouple the resolver from the concrete TokenExchanger (T4) and OAuthConnector (T5); both satisfy the interfaces via compile-time assertions. TDD: each ordering branch, near-expiry refresh (token-exchange + connect), unconnected -> connect-URL error, unauthenticated reject, store-disabled, policy-denied, no-static-fallback, and single-flight under -race. Related: spec 074 (MCP-1033). Builds on T3 (#601), T4 (#600), T5 (#602), all merged to main. 
Co-Authored-By: Paperclip --- go.mod | 2 +- .../broker/credential_resolver.go | 334 +++++++++++++++ .../broker/credential_resolver_test.go | 402 ++++++++++++++++++ 3 files changed, 737 insertions(+), 1 deletion(-) create mode 100644 internal/serveredition/broker/credential_resolver.go create mode 100644 internal/serveredition/broker/credential_resolver_test.go diff --git a/go.mod b/go.mod index 87b7ff96f..67846fa0f 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( go.opentelemetry.io/otel/trace v1.44.0 go.uber.org/zap v1.28.0 golang.org/x/mod v0.37.0 + golang.org/x/sync v0.20.0 golang.org/x/sys v0.46.0 golang.org/x/term v0.44.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 @@ -139,7 +140,6 @@ require ( go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/net v0.55.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/text v0.37.0 // indirect golang.org/x/tools v0.45.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa // indirect diff --git a/internal/serveredition/broker/credential_resolver.go b/internal/serveredition/broker/credential_resolver.go new file mode 100644 index 000000000..e9038120e --- /dev/null +++ b/internal/serveredition/broker/credential_resolver.go @@ -0,0 +1,334 @@ +//go:build server + +package broker + +import ( + "context" + "errors" + "fmt" + "time" + + "go.uber.org/zap" + "golang.org/x/sync/singleflight" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" +) + +// defaultRefreshThreshold is how close to expiry a cached credential may be +// before the resolver proactively refreshes it. A credential expiring within +// this window is treated as stale (FR-013). +const defaultRefreshThreshold = 60 * time.Second + +// Sentinel errors returned by the resolver. They are deliberately coarse and +// secret-free so they can be surfaced to callers and audited (FR-014/FR-019). 
+var ( + // ErrUnauthenticated is returned when Resolve is called without a user + // identity. Brokering is strictly per-user; an anonymous caller is rejected + // before any store or upstream access (FR-014). + ErrUnauthenticated = errors.New("credential resolver: unauthenticated caller") + + // ErrNoCredential is returned when no per-user credential can be produced and + // no actionable connect flow is available. There is deliberately no shared or + // static fallback (FR-014). + ErrNoCredential = errors.New("credential resolver: no per-user credential available") + + // ErrBrokerNotConfigured is returned when the target server has no auth_broker + // block. Such upstreams are not brokered and behave exactly as today. + ErrBrokerNotConfigured = errors.New("credential resolver: server has no auth_broker configuration") +) + +// Exchanger mints an upstream credential by exchanging the user's stored IdP +// subject token (token_exchange / entra_obo). *TokenExchanger satisfies it. +type Exchanger interface { + Exchange(ctx context.Context, userID, serverKey string, cfg *config.AuthBrokerConfig) (*UpstreamCredential, error) +} + +// Connector drives the per-user OAuth connect flow (Path B). *OAuthConnector +// satisfies it. The resolver uses Refresh to renew a near-expiry connect-flow +// credential and BuildAuthorizationURL to produce an actionable connect URL +// when the user has not yet connected the upstream. +type Connector interface { + ServerKey() string + BuildAuthorizationURL(userID string) (authURL, state string, err error) + Refresh(ctx context.Context, userID string) (*UpstreamCredential, error) +} + +// ConnectorProvider resolves the per-upstream OAuthConnector for a server. The +// REST layer (T8) supplies an implementation that assembles a ConnectorConfig +// from the server's auth_broker block plus the gateway's callback URL. It is +// only consulted for oauth_connect-mode upstreams. 
+type ConnectorProvider interface { + ConnectorFor(server *config.ServerConfig) (Connector, error) +} + +// NotConnectedError is returned when an oauth_connect upstream has no per-user +// credential yet. It carries the authorize URL the caller must redirect the +// user to in order to connect the upstream (FR-013, actionable error). +type NotConnectedError struct { + ServerName string + ConnectURL string +} + +func (e *NotConnectedError) Error() string { + return fmt.Sprintf("credential resolver: upstream %q is not connected for this user; connect at: %s", + e.ServerName, e.ConnectURL) +} + +// PolicyDecision is the verdict of the policy-decision seam evaluated before a +// resolved credential is returned. Allow=false blocks the injection. +type PolicyDecision struct { + Allow bool + Reason string +} + +// PolicyInput is the context handed to the policy seam. +type PolicyInput struct { + UserID string + ServerName string + ServerKey string + Credential *UpstreamCredential +} + +// PolicyHook is the policy-decision seam (FR-015). No policy engine ships now; +// the resolver defaults to an allow-all hook. A future engine implements this +// interface without changing the resolver. +type PolicyHook interface { + Evaluate(ctx context.Context, in PolicyInput) (PolicyDecision, error) +} + +// PolicyHookFunc adapts a function to the PolicyHook interface. +type PolicyHookFunc func(ctx context.Context, in PolicyInput) (PolicyDecision, error) + +// Evaluate implements PolicyHook. +func (f PolicyHookFunc) Evaluate(ctx context.Context, in PolicyInput) (PolicyDecision, error) { + return f(ctx, in) +} + +// allowAllPolicy is the default seam implementation: it permits every +// injection. It exists so the resolver always has a non-nil hook (FR-015). 
+type allowAllPolicy struct{} + +func (allowAllPolicy) Evaluate(_ context.Context, _ PolicyInput) (PolicyDecision, error) { + return PolicyDecision{Allow: true}, nil +} + +// PolicyDeniedError is returned when the policy seam blocks a resolved +// credential from being injected. +type PolicyDeniedError struct { + ServerName string + Reason string +} + +func (e *PolicyDeniedError) Error() string { + if e.Reason != "" { + return fmt.Sprintf("credential resolver: policy denied credential for %q: %s", e.ServerName, e.Reason) + } + return fmt.Sprintf("credential resolver: policy denied credential for %q", e.ServerName) +} + +// ResolverDeps are the collaborators a CredentialResolver needs. Store and +// Exchanger are required for token-exchange upstreams; Connectors is required +// only for oauth_connect upstreams. Policy and Logger are optional. +type ResolverDeps struct { + Store CredentialStore + Exchanger Exchanger + Connectors ConnectorProvider + Policy PolicyHook + Logger *zap.Logger + RefreshThreshold time.Duration +} + +// CredentialResolver produces the per-user upstream credential to inject on a +// proxied request. It applies a strict per-user-only ordering (FR-013/FR-014): +// +// 1. a valid cached per-user credential (refreshed if near-expiry); +// 2. else a freshly token-exchanged / OBO credential from the stored IdP +// subject token; +// 3. else, for oauth_connect upstreams the user has not connected, an +// actionable NotConnectedError carrying the connect URL; +// 4. else ErrNoCredential. +// +// There is no shared or static fallback. Concurrent acquisitions for the same +// (user, server) are coalesced via single-flight so the upstream authorization +// server is not hit with duplicate flows. 
+type CredentialResolver struct { + store CredentialStore + exchanger Exchanger + conns ConnectorProvider + policy PolicyHook + logger *zap.Logger + + refreshThreshold time.Duration + group singleflight.Group +} + +// NewCredentialResolver constructs a resolver from its dependencies, applying +// defaults for the optional fields. +func NewCredentialResolver(deps ResolverDeps) *CredentialResolver { + logger := deps.Logger + if logger == nil { + logger = zap.NewNop() + } + policy := deps.Policy + if policy == nil { + policy = allowAllPolicy{} + } + threshold := deps.RefreshThreshold + if threshold <= 0 { + threshold = defaultRefreshThreshold + } + return &CredentialResolver{ + store: deps.Store, + exchanger: deps.Exchanger, + conns: deps.Connectors, + policy: policy, + logger: logger.Named("credential-resolver"), + refreshThreshold: threshold, + } +} + +// Resolve returns the per-user credential to inject for (userID, server), +// applying the ordering described on CredentialResolver. The policy seam is +// evaluated per call after acquisition; credential acquisition itself is +// coalesced per (user, server) via single-flight. +func (r *CredentialResolver) Resolve(ctx context.Context, userID string, server *config.ServerConfig) (*UpstreamCredential, error) { + if userID == "" { + return nil, ErrUnauthenticated + } + if server == nil || server.AuthBroker == nil { + return nil, ErrBrokerNotConfigured + } + if r.store == nil || !r.store.Enabled() { + return nil, ErrStoreDisabled + } + + serverKey := oauth.GenerateServerKey(server.Name, server.URL) + + // Coalesce concurrent acquisitions for the same (user, server) so duplicate + // upstream token flows are not triggered (reuse the single-flight pattern). 
+ flightKey := userID + "\x00" + serverKey + v, err, _ := r.group.Do(flightKey, func() (interface{}, error) { + return r.acquire(ctx, userID, serverKey, server) + }) + if err != nil { + return nil, err + } + cred, ok := v.(*UpstreamCredential) + if !ok || cred == nil { + return nil, ErrNoCredential + } + + // Policy-decision seam: evaluated per call, before the credential is handed + // to the caller (FR-015). Default hook allows everything. + decision, perr := r.policy.Evaluate(ctx, PolicyInput{ + UserID: userID, + ServerName: server.Name, + ServerKey: serverKey, + Credential: cred, + }) + if perr != nil { + return nil, fmt.Errorf("credential resolver: policy evaluation failed: %w", perr) + } + if !decision.Allow { + return nil, &PolicyDeniedError{ServerName: server.Name, Reason: decision.Reason} + } + return cred, nil +} + +// acquire runs the per-user-only ordering for a single (user, server). It is +// invoked inside the single-flight group. +func (r *CredentialResolver) acquire(ctx context.Context, userID, serverKey string, server *config.ServerConfig) (*UpstreamCredential, error) { + cfg := server.AuthBroker + + // 1. Cached per-user credential. + cached, err := r.store.Get(userID, serverKey) + switch { + case err == nil && cached != nil: + if cached.IsValid() && !cached.ExpiresWithin(r.refreshThreshold) { + return cached, nil + } + // Near-expiry or expired: refresh in place. If refresh fails, fall + // through to a fresh acquisition rather than serving a stale credential. + if refreshed, rerr := r.refresh(ctx, userID, serverKey, server, cached); rerr == nil { + return refreshed, nil + } else { + r.logger.Warn("near-expiry credential refresh failed; re-acquiring", + zap.String("server", server.Name), zap.Error(rerr)) + } + case errors.Is(err, ErrNotFound): + // No cache: fall through to acquisition. + case err != nil: + // Unexpected store error (not "missing"): surface it. 
+ return nil, fmt.Errorf("credential resolver: load cached credential: %w", err) + } + + // 2 / 3. Acquire fresh based on mode. + switch cfg.Mode { + case config.AuthBrokerModeTokenExchange, config.AuthBrokerModeEntraOBO: + if r.exchanger == nil { + return nil, fmt.Errorf("credential resolver: no token exchanger configured for mode %q", cfg.Mode) + } + return r.exchanger.Exchange(ctx, userID, serverKey, cfg) + case config.AuthBrokerModeOAuthConnect: + // 3. Configured but the user has not connected: return an actionable + // error carrying the connect URL. + conn, cerr := r.connectorFor(server) + if cerr != nil { + return nil, cerr + } + authURL, _, aerr := conn.BuildAuthorizationURL(userID) + if aerr != nil { + return nil, fmt.Errorf("credential resolver: build connect URL: %w", aerr) + } + return nil, &NotConnectedError{ServerName: server.Name, ConnectURL: authURL} + default: + // 4. No recognised acquisition strategy and no per-user credential. + return nil, ErrNoCredential + } +} + +// refresh renews a near-expiry cached credential according to the upstream's +// mode: token-exchange upstreams re-exchange the IdP subject token; connect-flow +// upstreams use the stored refresh token. 
+func (r *CredentialResolver) refresh(ctx context.Context, userID, serverKey string, server *config.ServerConfig, cached *UpstreamCredential) (*UpstreamCredential, error) { + cfg := server.AuthBroker + switch cfg.Mode { + case config.AuthBrokerModeTokenExchange, config.AuthBrokerModeEntraOBO: + if r.exchanger == nil { + return nil, fmt.Errorf("credential resolver: no token exchanger configured for mode %q", cfg.Mode) + } + return r.exchanger.Exchange(ctx, userID, serverKey, cfg) + case config.AuthBrokerModeOAuthConnect: + if cached.RefreshToken == "" { + return nil, fmt.Errorf("credential resolver: connect-flow credential has no refresh token") + } + conn, cerr := r.connectorFor(server) + if cerr != nil { + return nil, cerr + } + return conn.Refresh(ctx, userID) + default: + return nil, fmt.Errorf("credential resolver: cannot refresh mode %q", cfg.Mode) + } +} + +// connectorFor resolves the per-upstream connector, guarding against a missing +// provider (only oauth_connect upstreams need one). +func (r *CredentialResolver) connectorFor(server *config.ServerConfig) (Connector, error) { + if r.conns == nil { + return nil, fmt.Errorf("credential resolver: no connector provider configured for oauth_connect upstream %q", server.Name) + } + conn, err := r.conns.ConnectorFor(server) + if err != nil { + return nil, fmt.Errorf("credential resolver: resolve connector: %w", err) + } + return conn, nil +} + +// Compile-time assertions that the concrete broker types satisfy the resolver's +// collaborator interfaces. 
+var ( + _ Exchanger = (*TokenExchanger)(nil) + _ Connector = (*OAuthConnector)(nil) +) diff --git a/internal/serveredition/broker/credential_resolver_test.go b/internal/serveredition/broker/credential_resolver_test.go new file mode 100644 index 000000000..9885628bb --- /dev/null +++ b/internal/serveredition/broker/credential_resolver_test.go @@ -0,0 +1,402 @@ +//go:build server + +package broker + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" + "github.com/smart-mcp-proxy/mcpproxy-go/internal/oauth" +) + +// --- test doubles ----------------------------------------------------------- + +// fakeStore is an in-memory CredentialStore keyed by (userID, serverKey), +// matching the real backend's keying so resolver key derivation is exercised. +type fakeStore struct { + mu sync.Mutex + enabled bool + data map[string]*UpstreamCredential + getErr error +} + +func newFakeStore() *fakeStore { + return &fakeStore{enabled: true, data: map[string]*UpstreamCredential{}} +} + +func storeKey(userID, serverKey string) string { return userID + "\x00" + serverKey } + +func (s *fakeStore) Enabled() bool { return s.enabled } + +func (s *fakeStore) Get(userID, serverKey string) (*UpstreamCredential, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.getErr != nil { + return nil, s.getErr + } + c, ok := s.data[storeKey(userID, serverKey)] + if !ok { + return nil, ErrNotFound + } + return c, nil +} + +func (s *fakeStore) Put(userID, serverKey string, cred *UpstreamCredential) error { + s.mu.Lock() + defer s.mu.Unlock() + s.data[storeKey(userID, serverKey)] = cred + return nil +} + +func (s *fakeStore) Delete(userID, serverKey string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.data, storeKey(userID, serverKey)) + return nil +} + +func (s *fakeStore) List(userID string) ([]CredentialEntry, error) { return nil, nil } + +func (s *fakeStore) seed(userID, serverKey string, 
cred *UpstreamCredential) { + s.mu.Lock() + defer s.mu.Unlock() + s.data[storeKey(userID, serverKey)] = cred +} + +// fakeExchanger records calls and returns a programmed credential/error. +type fakeExchanger struct { + calls int32 + cred *UpstreamCredential + err error + delay time.Duration + startWG *sync.WaitGroup +} + +func (e *fakeExchanger) Exchange(_ context.Context, userID, serverKey string, _ *config.AuthBrokerConfig) (*UpstreamCredential, error) { + if e.startWG != nil { + e.startWG.Done() + } + if e.delay > 0 { + time.Sleep(e.delay) + } + atomic.AddInt32(&e.calls, 1) + if e.err != nil { + return nil, e.err + } + return e.cred, nil +} + +// fakeConnector implements Connector for connect-flow paths. +type fakeConnector struct { + serverKey string + authURL string + buildErr error + refreshCred *UpstreamCredential + refreshErr error + buildCalls int32 + refreshCalls int32 +} + +func (c *fakeConnector) ServerKey() string { return c.serverKey } + +func (c *fakeConnector) BuildAuthorizationURL(_ string) (string, string, error) { + atomic.AddInt32(&c.buildCalls, 1) + if c.buildErr != nil { + return "", "", c.buildErr + } + return c.authURL, "state-xyz", nil +} + +func (c *fakeConnector) Refresh(_ context.Context, _ string) (*UpstreamCredential, error) { + atomic.AddInt32(&c.refreshCalls, 1) + if c.refreshErr != nil { + return nil, c.refreshErr + } + return c.refreshCred, nil +} + +type fakeConnectorProvider struct { + conn *fakeConnector + err error +} + +func (p *fakeConnectorProvider) ConnectorFor(_ *config.ServerConfig) (Connector, error) { + if p.err != nil { + return nil, p.err + } + return p.conn, nil +} + +// --- fixtures ---------------------------------------------------------------- + +func httpServer(name string, broker *config.AuthBrokerConfig) *config.ServerConfig { + return &config.ServerConfig{ + Name: name, + URL: "https://" + name + ".example.com/mcp", + Protocol: "http", + AuthBroker: broker, + } +} + +func tokenExchangeBroker() 
*config.AuthBrokerConfig { + b := &config.AuthBrokerConfig{Mode: config.AuthBrokerModeTokenExchange, TokenEndpoint: "https://idp/token", Scopes: []string{"api"}} + b.ApplyDefaults() + return b +} + +func connectBroker() *config.AuthBrokerConfig { + b := &config.AuthBrokerConfig{ + Mode: config.AuthBrokerModeOAuthConnect, + TokenEndpoint: "https://idp/token", + AuthorizationEndpoint: "https://idp/authorize", + ClientID: "client", + } + b.ApplyDefaults() + return b +} + +func validCred() *UpstreamCredential { + return &UpstreamCredential{Type: "oauth2", AccessToken: "cached-token", ExpiresAt: time.Now().Add(time.Hour), ObtainedVia: "token_exchange"} +} + +// --- tests ------------------------------------------------------------------- + +func TestResolve_ValidCachedCredential(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, validCred()) + + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "fresh"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "cached-token" { + t.Fatalf("expected cached token, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&ex.calls); c != 0 { + t.Fatalf("expected no exchange calls for valid cache, got %d", c) + } +} + +func TestResolve_NearExpiryRefresh_TokenExchange(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + nearExpiry := &UpstreamCredential{AccessToken: "old", ExpiresAt: time.Now().Add(10 * time.Second)} + store.seed("alice", key, nearExpiry) + + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "refreshed"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := 
r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "refreshed" { + t.Fatalf("expected refreshed token, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("expected 1 exchange call, got %d", c) + } +} + +func TestResolve_NearExpiryRefresh_ConnectFlow(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", RefreshToken: "rt", ExpiresAt: time.Now().Add(5 * time.Second)}) + + conn := &fakeConnector{serverKey: key, refreshCred: &UpstreamCredential{AccessToken: "refreshed-connect"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "refreshed-connect" { + t.Fatalf("expected refreshed-connect, got %q", got.AccessToken) + } + if c := atomic.LoadInt32(&conn.refreshCalls); c != 1 { + t.Fatalf("expected 1 refresh call, got %d", c) + } +} + +func TestResolve_NoCache_TokenExchange(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "exchanged" { + t.Fatalf("expected exchanged token, got %q", got.AccessToken) + } +} + +func TestResolve_NoCache_EntraOBO(t *testing.T) { + store := newFakeStore() + b := &config.AuthBrokerConfig{Mode: config.AuthBrokerModeEntraOBO, TokenEndpoint: "https://login.microsoftonline.com/token"} + b.ApplyDefaults() + server := 
httpServer("graph", b) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "obo"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.AccessToken != "obo" { + t.Fatalf("expected obo token, got %q", got.AccessToken) + } +} + +func TestResolve_ConnectUnconnected_ReturnsActionableConnectURL(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + conn := &fakeConnector{serverKey: key, authURL: "https://idp/authorize?client_id=client&state=state-xyz"} + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected NotConnectedError, got nil") + } + var nce *NotConnectedError + if !errors.As(err, &nce) { + t.Fatalf("expected *NotConnectedError, got %T: %v", err, err) + } + if nce.ConnectURL != conn.authURL { + t.Fatalf("expected connect URL %q in error, got %q", conn.authURL, nce.ConnectURL) + } + if !strings.Contains(err.Error(), conn.authURL) { + t.Fatalf("error message must surface the connect URL, got %q", err.Error()) + } +} + +func TestResolve_Unauthenticated_Rejected(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: validCred()} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + _, err := r.Resolve(context.Background(), "", server) + if !errors.Is(err, ErrUnauthenticated) { + t.Fatalf("expected ErrUnauthenticated, got %v", err) + } + if c := atomic.LoadInt32(&ex.calls); c != 0 { + t.Fatalf("expected no work for unauthenticated caller, got %d exchange calls", c) + } +} + +func TestResolve_StoreDisabled_DegradesGracefully(t *testing.T) { + store := newFakeStore() + 
store.enabled = false + server := httpServer("grafana", tokenExchangeBroker()) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{cred: validCred()}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if !errors.Is(err, ErrStoreDisabled) { + t.Fatalf("expected ErrStoreDisabled, got %v", err) + } +} + +func TestResolve_NoBrokerConfig_Rejected(t *testing.T) { + store := newFakeStore() + server := httpServer("plain", nil) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{}}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected error for server without auth_broker, got nil") + } +} + +func TestResolve_NoStaticFallback_OnExchangeFailure(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{err: errors.New("token exchange failed: status 401, error \"invalid_grant\"")} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + got, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected the exchange error to propagate (no static fallback), got nil") + } + if got != nil { + t.Fatalf("expected no credential on failure (FR-014, no shared fallback), got %+v", got) + } +} + +func TestResolve_PolicyHook_DeniesInjection(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, validCred()) + + policy := PolicyHookFunc(func(_ context.Context, in PolicyInput) (PolicyDecision, error) { + return PolicyDecision{Allow: false, Reason: "blocked by policy for " + in.ServerName}, nil + }) + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: &fakeExchanger{}, Policy: policy}) + + _, err := r.Resolve(context.Background(), "alice", server) + var pde *PolicyDeniedError + if !errors.As(err, &pde) { + 
t.Fatalf("expected *PolicyDeniedError, got %T: %v", err, err) + } + if !strings.Contains(pde.Reason, "grafana") { + t.Fatalf("expected reason to include server name, got %q", pde.Reason) + } +} + +func TestResolve_SingleFlight_CoalescesConcurrentAcquisitions(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + + const n = 12 + var start sync.WaitGroup + start.Add(1) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}, delay: 40 * time.Millisecond} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + var wg sync.WaitGroup + errs := make([]error, n) + toks := make([]string, n) + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + start.Wait() + cred, err := r.Resolve(context.Background(), "alice", server) + errs[idx] = err + if cred != nil { + toks[idx] = cred.AccessToken + } + }(i) + } + start.Done() // release all goroutines together + wg.Wait() + + for i := 0; i < n; i++ { + if errs[i] != nil { + t.Fatalf("goroutine %d errored: %v", i, errs[i]) + } + if toks[i] != "exchanged" { + t.Fatalf("goroutine %d got %q", i, toks[i]) + } + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("single-flight should coalesce to 1 upstream acquisition, got %d", c) + } +} From 3cfb19fcbf22a53f327cfaa4829f52d58552d16a Mon Sep 17 00:00:00 2001 From: Algis Dumbris Date: Mon, 15 Jun 2026 17:03:16 +0300 Subject: [PATCH 2/3] =?UTF-8?q?fix(broker):=20address=20review=20on=20Cred?= =?UTF-8?q?entialResolver=20=E2=80=94=20singleflight=20ctx,=20double-excha?= =?UTF-8?q?nge,=20reconnect=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review (MCP-2490 / Critic) on PR #688: - must-fix: detach the caller's context inside the single-flight closure with context.WithoutCancel. 
The flight runs the acquisition once for all co-pending callers; inheriting the first caller's cancellation let a single client disconnect/timeout broadcast ctx.Err() to every waiter for the same (user, server). Per-caller cancellation still applies at the policy/return layer (caller's original ctx). - advisory: collapse the per-mode acquire/refresh paths so a near-expiry token-exchange miss no longer calls Exchange twice (refresh-then-fallthrough); a single Exchange now covers both cache-miss and near-expiry. - advisory: an already-connected oauth_connect user whose refresh fails now gets an actionable reconnect error (NotConnectedError.Reason) instead of a misleading "never connected" message; the connect URL is still surfaced. - advisory: document that the Exchanger (T4) and Connector (T5) persist their results themselves, so the resolver never calls store.Put. Tests: add single-flight caller-cancellation detach, no-double-exchange on near-expiry failure, and connect-flow refresh-fail -> reconnect. Full broker suite green under -tags server -race; golangci-lint v2.5.0 clean. Co-Authored-By: Paperclip --- .../broker/credential_resolver.go | 105 ++++++++++-------- .../broker/credential_resolver_test.go | 95 +++++++++++++++- 2 files changed, 149 insertions(+), 51 deletions(-) diff --git a/internal/serveredition/broker/credential_resolver.go b/internal/serveredition/broker/credential_resolver.go index e9038120e..932cc75de 100644 --- a/internal/serveredition/broker/credential_resolver.go +++ b/internal/serveredition/broker/credential_resolver.go @@ -62,15 +62,23 @@ type ConnectorProvider interface { ConnectorFor(server *config.ServerConfig) (Connector, error) } -// NotConnectedError is returned when an oauth_connect upstream has no per-user -// credential yet. It carries the authorize URL the caller must redirect the -// user to in order to connect the upstream (FR-013, actionable error). 
+// NotConnectedError is returned when an oauth_connect upstream cannot produce a +// usable per-user credential and the user must (re)consent. It carries the +// authorize URL the caller redirects the user to (FR-013, actionable error) and +// a Reason that distinguishes a first-time connect from an expired credential +// whose refresh failed (so callers do not tell an already-connected user they +// have "never connected"). type NotConnectedError struct { ServerName string ConnectURL string + Reason string } func (e *NotConnectedError) Error() string { + if e.Reason != "" { + return fmt.Sprintf("credential resolver: upstream %q requires connection (%s); connect at: %s", + e.ServerName, e.Reason, e.ConnectURL) + } return fmt.Sprintf("credential resolver: upstream %q is not connected for this user; connect at: %s", e.ServerName, e.ConnectURL) } @@ -207,9 +215,16 @@ func (r *CredentialResolver) Resolve(ctx context.Context, userID string, server // Coalesce concurrent acquisitions for the same (user, server) so duplicate // upstream token flows are not triggered (reuse the single-flight pattern). + // + // The flight runs the acquisition once for every co-pending caller. Detach + // the caller's cancellation with context.WithoutCancel so the in-flight + // acquisition is not aborted — and its error broadcast to all waiters — just + // because whichever caller happened to start the flight cancelled (client + // disconnect, timeout). Per-caller cancellation still applies below at the + // policy/return layer, which uses the caller's original ctx. 
flightKey := userID + "\x00" + serverKey v, err, _ := r.group.Do(flightKey, func() (interface{}, error) { - return r.acquire(ctx, userID, serverKey, server) + return r.acquire(context.WithoutCancel(ctx), userID, serverKey, server) }) if err != nil { return nil, err @@ -238,79 +253,79 @@ func (r *CredentialResolver) Resolve(ctx context.Context, userID string, server // acquire runs the per-user-only ordering for a single (user, server). It is // invoked inside the single-flight group. +// +// Acquisition and refresh share a path per mode so a stale or near-expiry cached credential does +// not trigger a redundant double acquisition. The Exchanger (T4) and Connector +// (T5) persist their results into the store themselves, so the resolver never +// calls store.Put — it only reads the cache via store.Get. func (r *CredentialResolver) acquire(ctx context.Context, userID, serverKey string, server *config.ServerConfig) (*UpstreamCredential, error) { cfg := server.AuthBroker - // 1. Cached per-user credential. + // 1. Serve a still-valid, not-near-expiry cached credential directly. cached, err := r.store.Get(userID, serverKey) + hasCache := err == nil && cached != nil switch { - case err == nil && cached != nil: + case hasCache: if cached.IsValid() && !cached.ExpiresWithin(r.refreshThreshold) { return cached, nil } - // Near-expiry or expired: refresh in place. If refresh fails, fall - // through to a fresh acquisition rather than serving a stale credential. - if refreshed, rerr := r.refresh(ctx, userID, serverKey, server, cached); rerr == nil { - return refreshed, nil - } else { - r.logger.Warn("near-expiry credential refresh failed; re-acquiring", - zap.String("server", server.Name), zap.Error(rerr)) - } + // Stale / near-expiry: renewed by the per-mode path below. case errors.Is(err, ErrNotFound): - // No cache: fall through to acquisition. - case err != nil: + // No cache: acquired by the per-mode path below. + default: // Unexpected store error (not "missing"): surface it.
return nil, fmt.Errorf("credential resolver: load cached credential: %w", err) } - // 2 / 3. Acquire fresh based on mode. switch cfg.Mode { case config.AuthBrokerModeTokenExchange, config.AuthBrokerModeEntraOBO: + // 2. Token-exchange / OBO: the first-acquisition and refresh paths are + // identical (re-mint from the stored IdP subject token), so a single + // Exchange call covers both the cache-miss and near-expiry cases. if r.exchanger == nil { return nil, fmt.Errorf("credential resolver: no token exchanger configured for mode %q", cfg.Mode) } return r.exchanger.Exchange(ctx, userID, serverKey, cfg) + case config.AuthBrokerModeOAuthConnect: - // 3. Configured but the user has not connected: return an actionable - // error carrying the connect URL. conn, cerr := r.connectorFor(server) if cerr != nil { return nil, cerr } - authURL, _, aerr := conn.BuildAuthorizationURL(userID) - if aerr != nil { - return nil, fmt.Errorf("credential resolver: build connect URL: %w", aerr) + // A cached connect-flow credential means the user already connected: + // renew transparently via the stored refresh token. Only when that + // refresh fails do we ask the (already-connected) user to reconnect. + if hasCache && cached.RefreshToken != "" { + refreshed, rerr := conn.Refresh(ctx, userID) + if rerr == nil { + return refreshed, nil + } + r.logger.Warn("connect-flow credential refresh failed; user must reconnect", + zap.String("server", server.Name), zap.Error(rerr)) + return nil, r.notConnected(conn, server, userID, "stored credential expired and refresh failed; reconnect required") + } + // 3. Never connected, or connected without a usable refresh token and now + // expired — both require (re)consent through the connect flow. + reason := "not connected" + if hasCache { + reason = "stored credential expired; reconnect required" } - return nil, &NotConnectedError{ServerName: server.Name, ConnectURL: authURL} + return nil, r.notConnected(conn, server, userID, reason) + default: // 4. 
No recognised acquisition strategy and no per-user credential. return nil, ErrNoCredential } } -// refresh renews a near-expiry cached credential according to the upstream's -// mode: token-exchange upstreams re-exchange the IdP subject token; connect-flow -// upstreams use the stored refresh token. -func (r *CredentialResolver) refresh(ctx context.Context, userID, serverKey string, server *config.ServerConfig, cached *UpstreamCredential) (*UpstreamCredential, error) { - cfg := server.AuthBroker - switch cfg.Mode { - case config.AuthBrokerModeTokenExchange, config.AuthBrokerModeEntraOBO: - if r.exchanger == nil { - return nil, fmt.Errorf("credential resolver: no token exchanger configured for mode %q", cfg.Mode) - } - return r.exchanger.Exchange(ctx, userID, serverKey, cfg) - case config.AuthBrokerModeOAuthConnect: - if cached.RefreshToken == "" { - return nil, fmt.Errorf("credential resolver: connect-flow credential has no refresh token") - } - conn, cerr := r.connectorFor(server) - if cerr != nil { - return nil, cerr - } - return conn.Refresh(ctx, userID) - default: - return nil, fmt.Errorf("credential resolver: cannot refresh mode %q", cfg.Mode) +// notConnected builds the actionable NotConnectedError carrying the upstream +// authorize URL the caller must redirect the user to, tagged with reason. 
+func (r *CredentialResolver) notConnected(conn Connector, server *config.ServerConfig, userID, reason string) error { + authURL, _, aerr := conn.BuildAuthorizationURL(userID) + if aerr != nil { + return fmt.Errorf("credential resolver: build connect URL: %w", aerr) } + return &NotConnectedError{ServerName: server.Name, ConnectURL: authURL, Reason: reason} } // connectorFor resolves the per-upstream connector, guarding against a missing diff --git a/internal/serveredition/broker/credential_resolver_test.go b/internal/serveredition/broker/credential_resolver_test.go index 9885628bb..0ae0fb6b2 100644 --- a/internal/serveredition/broker/credential_resolver_test.go +++ b/internal/serveredition/broker/credential_resolver_test.go @@ -71,14 +71,16 @@ func (s *fakeStore) seed(userID, serverKey string, cred *UpstreamCredential) { // fakeExchanger records calls and returns a programmed credential/error. type fakeExchanger struct { - calls int32 - cred *UpstreamCredential - err error - delay time.Duration - startWG *sync.WaitGroup + calls int32 + cred *UpstreamCredential + err error + delay time.Duration + startWG *sync.WaitGroup + gotCtxErr error } -func (e *fakeExchanger) Exchange(_ context.Context, userID, serverKey string, _ *config.AuthBrokerConfig) (*UpstreamCredential, error) { +func (e *fakeExchanger) Exchange(ctx context.Context, userID, serverKey string, _ *config.AuthBrokerConfig) (*UpstreamCredential, error) { + e.gotCtxErr = ctx.Err() if e.startWG != nil { e.startWG.Done() } @@ -400,3 +402,84 @@ func TestResolve_SingleFlight_CoalescesConcurrentAcquisitions(t *testing.T) { t.Fatalf("single-flight should coalesce to 1 upstream acquisition, got %d", c) } } + +// TestResolve_SingleFlight_DetachesCallerCancellation proves the must-fix from +// review: the in-flight acquisition must not inherit the calling request's +// cancellation, or a cancelled caller would broadcast its ctx error to every +// co-pending acquisition for the same (user, server). 
+func TestResolve_SingleFlight_DetachesCallerCancellation(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + ex := &fakeExchanger{cred: &UpstreamCredential{AccessToken: "exchanged"}} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // caller's request is already cancelled before acquisition runs + + got, err := r.Resolve(ctx, "alice", server) + if err != nil { + t.Fatalf("acquisition should run despite caller cancellation, got error: %v", err) + } + if got.AccessToken != "exchanged" { + t.Fatalf("expected exchanged token, got %q", got.AccessToken) + } + if ex.gotCtxErr != nil { + t.Fatalf("flight context must be detached from caller cancellation, got ctx.Err()=%v", ex.gotCtxErr) + } +} + +// TestResolve_TokenExchange_NearExpiry_NoDoubleExchangeOnFailure proves the +// advisory fix: a near-expiry token-exchange credential whose re-mint fails must +// surface that single error, not retry Exchange a second time. 
+func TestResolve_TokenExchange_NearExpiry_NoDoubleExchangeOnFailure(t *testing.T) { + store := newFakeStore() + server := httpServer("grafana", tokenExchangeBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", ExpiresAt: time.Now().Add(5 * time.Second)}) + + ex := &fakeExchanger{err: errors.New("token exchange failed: status 401, error \"invalid_grant\"")} + r := NewCredentialResolver(ResolverDeps{Store: store, Exchanger: ex}) + + _, err := r.Resolve(context.Background(), "alice", server) + if err == nil { + t.Fatal("expected the exchange error to propagate, got nil") + } + if c := atomic.LoadInt32(&ex.calls); c != 1 { + t.Fatalf("near-expiry exchange failure must not double-call Exchange, got %d calls", c) + } +} + +// TestResolve_ConnectFlow_RefreshFails_ReturnsReconnectError proves the advisory +// fix: an already-connected user whose refresh fails gets an actionable +// reconnect error (with the connect URL), not a misleading "never connected". 
+func TestResolve_ConnectFlow_RefreshFails_ReturnsReconnectError(t *testing.T) { + store := newFakeStore() + server := httpServer("github", connectBroker()) + key := oauth.GenerateServerKey(server.Name, server.URL) + store.seed("alice", key, &UpstreamCredential{AccessToken: "old", RefreshToken: "rt", ExpiresAt: time.Now().Add(5 * time.Second)}) + + conn := &fakeConnector{ + serverKey: key, + authURL: "https://idp/authorize?client_id=client&state=state-xyz", + refreshErr: errors.New("oauth connector: token endpoint returned 400: invalid_grant"), + } + r := NewCredentialResolver(ResolverDeps{Store: store, Connectors: &fakeConnectorProvider{conn: conn}}) + + _, err := r.Resolve(context.Background(), "alice", server) + var nce *NotConnectedError + if !errors.As(err, &nce) { + t.Fatalf("expected *NotConnectedError, got %T: %v", err, err) + } + if nce.Reason == "" || !strings.Contains(nce.Reason, "reconnect") { + t.Fatalf("expected a reconnect reason, got %q", nce.Reason) + } + if nce.ConnectURL != conn.authURL { + t.Fatalf("expected connect URL %q, got %q", conn.authURL, nce.ConnectURL) + } + if c := atomic.LoadInt32(&conn.refreshCalls); c != 1 { + t.Fatalf("expected exactly 1 refresh attempt, got %d", c) + } + if c := atomic.LoadInt32(&conn.buildCalls); c != 1 { + t.Fatalf("expected the connect URL to be built once, got %d", c) + } +} From 18215de024407762da4305823e8cb03aa94fbf72 Mon Sep 17 00:00:00 2001 From: Algis Dumbris Date: Mon, 15 Jun 2026 17:18:51 +0300 Subject: [PATCH 3/3] docs(auth-broker): document per-user credential resolution ordering (spec 074, MCP-1039) Add a "Credential resolution" section describing the resolver's strict per-user-only ordering (cached/refresh -> token-exchange/OBO -> actionable connect-URL error -> no-credential), the no-shared/static-fallback guarantee, single-flight coalescing, and the policy-decision seam. Keeps the feature doc consistent with the CredentialResolver added in this PR (review follow-up). 
Co-Authored-By: Paperclip --- docs/features/auth-broker.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/features/auth-broker.md b/docs/features/auth-broker.md index 55e687c79..96b44a662 100644 --- a/docs/features/auth-broker.md +++ b/docs/features/auth-broker.md @@ -79,6 +79,17 @@ Config validation fails with `auth_broker.authorization_endpoint is required for A denied consent (`error=access_denied`) clears the pending flow and stores nothing. +## Credential resolution + +On each proxied request the broker resolves the per-user credential to inject, in a strict **per-user-only** order. There is **no shared or static fallback** — a request that cannot produce a per-user credential fails rather than borrowing another identity: + +1. A valid cached per-user credential is injected directly; if it is within the near-expiry window it is refreshed first (re-minted for `token_exchange`/`entra_obo`, or renewed from the stored refresh token for `oauth_connect`). +2. Otherwise, for `token_exchange`/`entra_obo`, a credential is minted from the user's stored IdP subject token. +3. Otherwise, for `oauth_connect` upstreams the user has not connected — or whose stored credential expired and could not be refreshed — the request fails with an **actionable error carrying the connect URL**, so the user is told to (re)connect rather than being silently denied. +4. Otherwise the request fails with "no per-user credential available". + +Concurrent requests for the same `(user, upstream)` are coalesced (single-flight) so a burst does not trigger duplicate upstream token flows. A policy-decision hook is evaluated per call immediately before the credential is returned; no policy engine ships yet, so it permits every injection by default. + ## See also - [OAuth Authentication](./oauth-authentication.md) — upstream OAuth for the personal edition.