diff --git a/auth.go b/auth.go index cb153ec..333c3ef 100644 --- a/auth.go +++ b/auth.go @@ -142,10 +142,12 @@ func ensureFreshToken( full = loadFull() // loadFull (loadConfig) builds a fresh store instance that could resolve // to a different backend/path than the one we loaded the token from. - // Pin the store and client ID to the originally loaded config so the - // refreshed token is saved exactly where it came from. + // Pin the store, client ID, and token-file path to the originally loaded + // config so the refreshed token is saved exactly where it came from and + // the cross-process refresh lock sits next to that same file. full.Store = cfg.Store full.ClientID = cfg.ClientID + full.TokenFile = cfg.TokenFile } // Resolve endpoints lazily too. Callers that pre-populate Endpoints skip @@ -154,7 +156,7 @@ func ensureFreshToken( resolveEndpoints(ctx, full) } - newTok, err := refreshAccessToken(ctx, full, tok.RefreshToken) + newTok, err := refreshAccessToken(ctx, full, tok) if err != nil { return reuseOrFail(err) } @@ -162,11 +164,43 @@ func ensureFreshToken( } // refreshAccessToken exchanges a refresh token for a new access token. +// +// It takes a cross-process advisory lock before the critical section so that +// concurrent CLI invocations cannot spend the same refresh token twice — +// which would cause one of them to receive invalid_grant on rotation servers. +// After acquiring the lock, it re-reads the store: if a peer process has +// already refreshed, the peer's fresh token is returned without a network +// call. func refreshAccessToken( ctx context.Context, cfg *AppConfig, - refreshToken string, + stale credstore.Token, ) (*credstore.Token, error) { + unlock, err := lockTokenStore(ctx, cfg.TokenFile, cfg.ClientID) + if err != nil { + return nil, fmt.Errorf("acquire refresh lock: %w", err) + } + defer unlock.Close() + + refreshToken := stale.RefreshToken + + // Peer check: another process may have refreshed while we waited for the lock. + if fresh, loadErr := cfg.Store.Load(cfg.ClientID); loadErr == nil { + // A peer refreshed if the stored token is now usable (present and not + // expired — same predicate the reuse paths use) and differs from the + // stale copy we were handed. + peerRefreshed := tokenUsable(fresh, time.Now()) && + (fresh.AccessToken != stale.AccessToken || + fresh.RefreshToken != stale.RefreshToken) + if peerRefreshed { + return &fresh, nil + } + // Peer may have rotated the refresh token without updating our view. + if fresh.RefreshToken != "" { + refreshToken = fresh.RefreshToken + } + } + ctx, cancel := context.WithTimeout(ctx, cfg.RefreshTokenTimeout) defer cancel() @@ -246,7 +280,7 @@ func makeAPICallWithAutoRefresh( if resp.StatusCode == http.StatusUnauthorized { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventAccessTokenRejected}) - newStorage, err := refreshAccessToken(ctx, cfg, storage.RefreshToken) + newStorage, err := refreshAccessToken(ctx, cfg, *storage) if err != nil { if errors.Is(err, ErrRefreshTokenExpired) { return ErrRefreshTokenExpired diff --git a/config.go b/config.go index 307c305..d3c54cd 100644 --- a/config.go +++ b/config.go @@ -100,6 +100,7 @@ type AppConfig struct { Scope string ForceDevice bool TokenStoreMode string // "auto", "file", or "keyring" + TokenFile string // path used for the file backend and for the cross-process refresh lock RetryClient *retry.Client Store credstore.Store[credstore.Token] @@ -191,7 +192,7 @@ func loadStoreConfig() *AppConfig { cfg.ClientID = getConfig(flagClientID, "CLIENT_ID", "") cfg.TokenStoreMode = getConfig(flagTokenStore, "TOKEN_STORE", "auto") - tokenFile := getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json") + cfg.TokenFile = getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json") // Resolved here (not loadConfig) so the offline `token get` path can decide // whether a refresh is due without building the full network config. @@ -207,7 +208,7 @@ func loadStoreConfig() *AppConfig { } var storeErr error - cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, tokenFile, defaultKeyringService) + cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, cfg.TokenFile, defaultKeyringService) if storeErr != nil { fmt.Fprintln(os.Stderr, storeErr) os.Exit(1) diff --git a/extra_claims_test.go b/extra_claims_test.go index f18f587..eaaeded 100644 --- a/extra_claims_test.go +++ b/extra_claims_test.go @@ -11,6 +11,8 @@ import ( "strings" "testing" "time" + + "github.com/go-authgate/sdk-go/credstore" ) // ----------------------------------------------------------------------- @@ -244,7 +246,11 @@ func TestRefreshAccessToken_SendsExtraClaims(t *testing.T) { cfg.Endpoints = defaultEndpoints(gotForm.serverURL) cfg.ExtraClaims = `{"project":"acme"}` - if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil { + if _, err := refreshAccessToken( + context.Background(), + cfg, + credstore.Token{RefreshToken: "old-refresh"}, + ); err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } @@ -265,7 +271,11 @@ func TestRefreshAccessToken_OmitsExtraClaimsWhenUnset(t *testing.T) { cfg.Endpoints = defaultEndpoints(gotForm.serverURL) cfg.ExtraClaims = "" - if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil { + if _, err := refreshAccessToken( + context.Background(), + cfg, + credstore.Token{RefreshToken: "old-refresh"}, + ); err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } @@ -348,7 +358,11 @@ func TestRefreshAccessToken_ExtraClaimsSurviveURLEncoding(t *testing.T) { } cfg.ExtraClaims = resolved - if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil { + if _, err := refreshAccessToken( + context.Background(), + cfg, + credstore.Token{RefreshToken: "old-refresh"}, + ); err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } got := gotForm.value(t).Get("extra_claims") diff --git a/go.mod b/go.mod index 3094ca4..6390e8f 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( charm.land/lipgloss/v2 v2.0.3 github.com/appleboy/go-httpretry v0.12.0 github.com/go-authgate/sdk-go v0.11.0 + github.com/gofrs/flock v0.13.0 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/mattn/go-isatty v0.0.22 diff --git a/go.sum b/go.sum index 576549e..3bd9269 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/go-authgate/sdk-go v0.11.0 h1:ZTfJ0rzeDn4QBqAmF9VKS3CqlKhE8+0tJxg8OGN github.com/go-authgate/sdk-go v0.11.0/go.mod h1:sa0ige5wtayj2WcnXlxa8wGuyi5z/c/chc0mXPJTl/Q= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/lock.go b/lock.go new file mode 100644 index 0000000..14f8193 --- /dev/null +++ b/lock.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "time" + + "github.com/gofrs/flock" +) + +const lockRetryInterval = 100 * time.Millisecond + +// lockTokenStore acquires a cross-process advisory lock scoped to +// (tokenFile, clientID). It serialises the "load → refresh → save" +// critical section so concurrent CLI invocations cannot spend the same +// refresh token twice (which would yield invalid_grant on rotation servers). +// +// The returned io.Closer releases the lock on Close (flock.Close is documented +// as equivalent to Unlock). +// +// The lock sits next to the token file regardless of the active backend — +// keyring-backed runs also need the coordination because the race is in the +// refresh flow, not the storage layer. +func lockTokenStore(ctx context.Context, tokenFile, clientID string) (io.Closer, error) { + if tokenFile == "" { + return nil, errors.New("lock: tokenFile is empty") + } + if clientID == "" { + return nil, errors.New("lock: clientID is empty") + } + + dir := filepath.Dir(tokenFile) + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("create lock directory %q: %w", dir, err) + } + + lockPath := filepath.Join(dir, filepath.Base(tokenFile)+"."+clientID+".lock") + fl := flock.New(lockPath) + locked, err := fl.TryLockContext(ctx, lockRetryInterval) + if err != nil { + return nil, fmt.Errorf("acquire lock %s: %w", lockPath, err) + } + if !locked { + return nil, fmt.Errorf("could not acquire lock %s", lockPath) + } + return fl, nil +} diff --git a/main.go b/main.go index 086b038..9f893bf 100644 --- a/main.go +++ b/main.go @@ -118,7 +118,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int { } else { ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenExpired}) } - newStorage, refreshErr := refreshAccessToken(ctx, cfg, existing.RefreshToken) + newStorage, refreshErr := refreshAccessToken(ctx, cfg, existing) if refreshErr != nil { // The refresh genuinely failed, so mark it failed in the UI. // Then degrade gracefully (reuse the still-valid token) or diff --git a/main_test.go b/main_test.go index 752fc8c..9249f2e 100644 --- a/main_test.go +++ b/main_test.go @@ -25,14 +25,14 @@ func testConfig(t *testing.T) *AppConfig { t.Fatalf("failed to create retry client: %v", err) } serverURL := "http://localhost:8080" + tokenFile := filepath.Join(t.TempDir(), "tokens.json") return &AppConfig{ - ServerURL: serverURL, - ClientID: "test-client", - Scope: "email profile", - RetryClient: rc, - Store: credstore.NewTokenFileStore( - filepath.Join(t.TempDir(), "tokens.json"), - ), + ServerURL: serverURL, + ClientID: "test-client", + Scope: "email profile", + RetryClient: rc, + TokenFile: tokenFile, + Store: credstore.NewTokenFileStore(tokenFile), Endpoints: defaultEndpoints(serverURL), TokenExchangeTimeout: defaultTokenExchangeTimeout, TokenVerificationTimeout: defaultTokenVerificationTimeout, @@ -317,7 +317,11 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { cfg.Endpoints = defaultEndpoints(srv.URL) cfg.ClientID = "test-client-rotation" - storage, err := refreshAccessToken(context.Background(), cfg, tt.oldRefreshToken) + storage, err := refreshAccessToken( + context.Background(), + cfg, + credstore.Token{RefreshToken: tt.oldRefreshToken}, + ) if err != nil { t.Fatalf("refreshAccessToken() error: %v", err) } @@ -332,6 +336,57 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) { } } +// TestRefreshAccessToken_PeerAlreadyRefreshed verifies that when a peer process +// has already refreshed and saved the tokens, refreshAccessToken returns the +// stored fresh token without making a network call. This guards against the +// refresh-token-rotation race that motivated the cross-process lock. +func TestRefreshAccessToken_PeerAlreadyRefreshed(t *testing.T) { + var serverCalled atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + serverCalled.Store(true) + http.Error(w, "should not be called", http.StatusInternalServerError) + })) + defer srv.Close() + + cfg := testConfig(t) + cfg.ServerURL = srv.URL + cfg.Endpoints = defaultEndpoints(srv.URL) + cfg.ClientID = "peer-refresh-test" + + peerFresh := credstore.Token{ + AccessToken: "peer-fresh-access", + RefreshToken: "peer-fresh-refresh", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(1 * time.Hour), + ClientID: cfg.ClientID, + } + if err := cfg.Store.Save(cfg.ClientID, peerFresh); err != nil { + t.Fatalf("setup save: %v", err) + } + + stale := credstore.Token{ + AccessToken: "stale-access", + RefreshToken: "stale-refresh", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(-1 * time.Minute), + ClientID: cfg.ClientID, + } + + got, err := refreshAccessToken(context.Background(), cfg, stale) + if err != nil { + t.Fatalf("refreshAccessToken() error: %v", err) + } + if serverCalled.Load() { + t.Fatalf("network refresh was performed; peer-refresh shortcut was expected") + } + if got.AccessToken != peerFresh.AccessToken { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, peerFresh.AccessToken) + } + if got.RefreshToken != peerFresh.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, peerFresh.RefreshToken) + } +} + // ----------------------------------------------------------------------- // Device code request with retry // ----------------------------------------------------------------------- diff --git a/token_cmd_test.go b/token_cmd_test.go index 69f2cbb..3b7e6c3 100644 --- a/token_cmd_test.go +++ b/token_cmd_test.go @@ -455,9 +455,8 @@ func tokenGetRefreshConfig(t *testing.T, tokenURL string, tok credstore.Token) * if err != nil { t.Fatal(err) } - store := credstore.NewTokenFileStore( - filepath.Join(t.TempDir(), "tokens.json"), - ) + tokenFile := filepath.Join(t.TempDir(), "tokens.json") + store := credstore.NewTokenFileStore(tokenFile) if err := store.Save("test-id", tok); err != nil { t.Fatal(err) } @@ -468,6 +467,7 @@ func tokenGetRefreshConfig(t *testing.T, tokenURL string, tok credstore.Token) * MaxResponseBodySize: defaultMaxResponseBodySize, RefreshThreshold: defaultRefreshThreshold, RetryClient: rc, + TokenFile: tokenFile, Store: store, } } @@ -742,7 +742,8 @@ func TestRunTokenGet_DefersDiscoveryUntilRefresh(t *testing.T) { if err != nil { t.Fatal(err) } - store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json")) + tokenFile := filepath.Join(t.TempDir(), "tokens.json") + store := credstore.NewTokenFileStore(tokenFile) if err := store.Save("test-id", credstore.Token{ AccessToken: "old-access-token", RefreshToken: "refresh-456", @@ -759,6 +760,7 @@ func TestRunTokenGet_DefersDiscoveryUntilRefresh(t *testing.T) { DiscoveryTimeout: defaultDiscoveryTimeout, MaxResponseBodySize: defaultMaxResponseBodySize, RetryClient: rc, + TokenFile: tokenFile, Store: store, } @@ -877,7 +879,8 @@ func TestRunTokenGet_LazyFullConfig(t *testing.T) { if err != nil { t.Fatal(err) } - store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json")) + tokenFile := filepath.Join(t.TempDir(), "tokens.json") + store := credstore.NewTokenFileStore(tokenFile) if err := store.Save("test-id", credstore.Token{ AccessToken: "old-access-token", RefreshToken: "refresh-456", @@ -891,6 +894,7 @@ func TestRunTokenGet_LazyFullConfig(t *testing.T) { partial := &AppConfig{ ClientID: "test-id", Store: store, + TokenFile: tokenFile, RefreshThreshold: defaultRefreshThreshold, }