Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,12 @@ func ensureFreshToken(
full = loadFull()
// loadFull (loadConfig) builds a fresh store instance that could resolve
// to a different backend/path than the one we loaded the token from.
// Pin the store and client ID to the originally loaded config so the
// refreshed token is saved exactly where it came from.
// Pin the store, client ID, and token-file path to the originally loaded
// config so the refreshed token is saved exactly where it came from and
// the cross-process refresh lock sits next to that same file.
full.Store = cfg.Store
full.ClientID = cfg.ClientID
full.TokenFile = cfg.TokenFile
}

// Resolve endpoints lazily too. Callers that pre-populate Endpoints skip
Expand All @@ -154,19 +156,51 @@ func ensureFreshToken(
resolveEndpoints(ctx, full)
}

newTok, err := refreshAccessToken(ctx, full, tok.RefreshToken)
newTok, err := refreshAccessToken(ctx, full, tok)
if err != nil {
return reuseOrFail(err)
}
return *newTok, true, nil
}

// refreshAccessToken exchanges a refresh token for a new access token.
//
// It takes a cross-process advisory lock before the critical section so that
// concurrent CLI invocations cannot spend the same refresh token twice —
// which would cause one of them to receive invalid_grant on rotation servers.
// After acquiring the lock, it re-reads the store: if a peer process has
// already refreshed, the peer's fresh token is returned without a network
// call.
func refreshAccessToken(
ctx context.Context,
cfg *AppConfig,
refreshToken string,
stale credstore.Token,
) (*credstore.Token, error) {
unlock, err := lockTokenStore(ctx, cfg.TokenFile, cfg.ClientID)
if err != nil {
return nil, fmt.Errorf("acquire refresh lock: %w", err)
}
defer unlock.Close()

refreshToken := stale.RefreshToken

// Peer check: another process may have refreshed while we waited for the lock.
if fresh, loadErr := cfg.Store.Load(cfg.ClientID); loadErr == nil {
// A peer refreshed if the stored token is now usable (present and not
// expired — same predicate the reuse paths use) and differs from the
// stale copy we were handed.
peerRefreshed := tokenUsable(fresh, time.Now()) &&
(fresh.AccessToken != stale.AccessToken ||
fresh.RefreshToken != stale.RefreshToken)
if peerRefreshed {
return &fresh, nil
}
// Peer may have rotated the refresh token without updating our view.
if fresh.RefreshToken != "" {
refreshToken = fresh.RefreshToken
}
}

ctx, cancel := context.WithTimeout(ctx, cfg.RefreshTokenTimeout)
defer cancel()

Expand Down Expand Up @@ -246,7 +280,7 @@ func makeAPICallWithAutoRefresh(
if resp.StatusCode == http.StatusUnauthorized {
ui.ShowStatus(tui.StatusUpdate{Event: tui.EventAccessTokenRejected})

newStorage, err := refreshAccessToken(ctx, cfg, storage.RefreshToken)
newStorage, err := refreshAccessToken(ctx, cfg, *storage)
if err != nil {
if errors.Is(err, ErrRefreshTokenExpired) {
return ErrRefreshTokenExpired
Expand Down
5 changes: 3 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ type AppConfig struct {
Scope string
ForceDevice bool
TokenStoreMode string // "auto", "file", or "keyring"
TokenFile string // path used for the file backend and for the cross-process refresh lock
RetryClient *retry.Client
Store credstore.Store[credstore.Token]

Expand Down Expand Up @@ -191,7 +192,7 @@ func loadStoreConfig() *AppConfig {

cfg.ClientID = getConfig(flagClientID, "CLIENT_ID", "")
cfg.TokenStoreMode = getConfig(flagTokenStore, "TOKEN_STORE", "auto")
tokenFile := getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json")
cfg.TokenFile = getConfig(flagTokenFile, "TOKEN_FILE", ".authgate-tokens.json")

// Resolved here (not loadConfig) so the offline `token get` path can decide
// whether a refresh is due without building the full network config.
Expand All @@ -207,7 +208,7 @@ func loadStoreConfig() *AppConfig {
}

var storeErr error
cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, tokenFile, defaultKeyringService)
cfg.Store, storeErr = newTokenStore(cfg.TokenStoreMode, cfg.TokenFile, defaultKeyringService)
if storeErr != nil {
fmt.Fprintln(os.Stderr, storeErr)
os.Exit(1)
Expand Down
20 changes: 17 additions & 3 deletions extra_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strings"
"testing"
"time"

"github.com/go-authgate/sdk-go/credstore"
)

// -----------------------------------------------------------------------
Expand Down Expand Up @@ -244,7 +246,11 @@ func TestRefreshAccessToken_SendsExtraClaims(t *testing.T) {
cfg.Endpoints = defaultEndpoints(gotForm.serverURL)
cfg.ExtraClaims = `{"project":"acme"}`

if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil {
if _, err := refreshAccessToken(
context.Background(),
cfg,
credstore.Token{RefreshToken: "old-refresh"},
); err != nil {
t.Fatalf("refreshAccessToken() error: %v", err)
}

Expand All @@ -265,7 +271,11 @@ func TestRefreshAccessToken_OmitsExtraClaimsWhenUnset(t *testing.T) {
cfg.Endpoints = defaultEndpoints(gotForm.serverURL)
cfg.ExtraClaims = ""

if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil {
if _, err := refreshAccessToken(
context.Background(),
cfg,
credstore.Token{RefreshToken: "old-refresh"},
); err != nil {
t.Fatalf("refreshAccessToken() error: %v", err)
}

Expand Down Expand Up @@ -348,7 +358,11 @@ func TestRefreshAccessToken_ExtraClaimsSurviveURLEncoding(t *testing.T) {
}
cfg.ExtraClaims = resolved

if _, err := refreshAccessToken(context.Background(), cfg, "old-refresh"); err != nil {
if _, err := refreshAccessToken(
context.Background(),
cfg,
credstore.Token{RefreshToken: "old-refresh"},
); err != nil {
t.Fatalf("refreshAccessToken() error: %v", err)
}
got := gotForm.value(t).Get("extra_claims")
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
charm.land/lipgloss/v2 v2.0.3
github.com/appleboy/go-httpretry v0.12.0
github.com/go-authgate/sdk-go v0.11.0
github.com/gofrs/flock v0.13.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/mattn/go-isatty v0.0.22
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ github.com/go-authgate/sdk-go v0.11.0 h1:ZTfJ0rzeDn4QBqAmF9VKS3CqlKhE8+0tJxg8OGN
github.com/go-authgate/sdk-go v0.11.0/go.mod h1:sa0ige5wtayj2WcnXlxa8wGuyi5z/c/chc0mXPJTl/Q=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
Expand Down
51 changes: 51 additions & 0 deletions lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package main

import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"time"

"github.com/gofrs/flock"
)

const lockRetryInterval = 100 * time.Millisecond

// lockTokenStore acquires a cross-process advisory lock scoped to
// (tokenFile, clientID). It serialises the "load → refresh → save"
// critical section so concurrent CLI invocations cannot spend the same
// refresh token twice (which would yield invalid_grant on rotation servers).
//
// The returned io.Closer releases the lock on Close (flock.Close is documented
// as equivalent to Unlock).
//
// The lock sits next to the token file regardless of the active backend —
// keyring-backed runs also need the coordination because the race is in the
// refresh flow, not the storage layer.
func lockTokenStore(ctx context.Context, tokenFile, clientID string) (io.Closer, error) {
if tokenFile == "" {
return nil, errors.New("lock: tokenFile is empty")
}
if clientID == "" {
return nil, errors.New("lock: clientID is empty")
}

dir := filepath.Dir(tokenFile)
if err := os.MkdirAll(dir, 0o700); err != nil {
return nil, fmt.Errorf("create lock directory %q: %w", dir, err)
}
Comment on lines +28 to +39
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lock placement currently depends on cfg.TokenFile, whose default is a relative path (.authgate-tokens.json). For keyring/auto (keyring available) this introduces a new requirement that the current working directory be writable (so the lock file can be created), otherwise refresh will fail even though the keyring backend itself doesn’t need filesystem writes. Consider choosing a lock location that’s reliably writable (e.g., under os.UserCacheDir()/os.UserConfigDir()) when the token file path is relative or when using keyring storage.

Copilot uses AI. Check for mistakes.

lockPath := filepath.Join(dir, filepath.Base(tokenFile)+"."+clientID+".lock")
fl := flock.New(lockPath)
locked, err := fl.TryLockContext(ctx, lockRetryInterval)
Comment on lines +41 to +43
Copy link

Copilot AI Apr 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lockPath is built by concatenating clientID directly into the filename. If clientID contains path separators or .. segments, filepath.Join will clean the path and can place the lock file outside the intended directory (and on Windows some characters can make the filename invalid). Consider deriving a filesystem-safe lock name (e.g., hash/hex of clientID, or base64-url encoding, and/or rejecting/escaping path separators) before constructing the path.

Copilot uses AI. Check for mistakes.
if err != nil {
return nil, fmt.Errorf("acquire lock %s: %w", lockPath, err)
}
if !locked {
return nil, fmt.Errorf("could not acquire lock %s", lockPath)
}
return fl, nil
}
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func run(ctx context.Context, ui tui.Manager, cfg *AppConfig) int {
} else {
ui.ShowStatus(tui.StatusUpdate{Event: tui.EventTokenExpired})
}
newStorage, refreshErr := refreshAccessToken(ctx, cfg, existing.RefreshToken)
newStorage, refreshErr := refreshAccessToken(ctx, cfg, existing)
if refreshErr != nil {
// The refresh genuinely failed, so mark it failed in the UI.
// Then degrade gracefully (reuse the still-valid token) or
Expand Down
71 changes: 63 additions & 8 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ func testConfig(t *testing.T) *AppConfig {
t.Fatalf("failed to create retry client: %v", err)
}
serverURL := "http://localhost:8080"
tokenFile := filepath.Join(t.TempDir(), "tokens.json")
return &AppConfig{
ServerURL: serverURL,
ClientID: "test-client",
Scope: "email profile",
RetryClient: rc,
Store: credstore.NewTokenFileStore(
filepath.Join(t.TempDir(), "tokens.json"),
),
ServerURL: serverURL,
ClientID: "test-client",
Scope: "email profile",
RetryClient: rc,
TokenFile: tokenFile,
Store: credstore.NewTokenFileStore(tokenFile),
Endpoints: defaultEndpoints(serverURL),
TokenExchangeTimeout: defaultTokenExchangeTimeout,
TokenVerificationTimeout: defaultTokenVerificationTimeout,
Expand Down Expand Up @@ -317,7 +317,11 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) {
cfg.Endpoints = defaultEndpoints(srv.URL)
cfg.ClientID = "test-client-rotation"

storage, err := refreshAccessToken(context.Background(), cfg, tt.oldRefreshToken)
storage, err := refreshAccessToken(
context.Background(),
cfg,
credstore.Token{RefreshToken: tt.oldRefreshToken},
)
if err != nil {
t.Fatalf("refreshAccessToken() error: %v", err)
}
Expand All @@ -332,6 +336,57 @@ func TestRefreshAccessToken_RotationMode(t *testing.T) {
}
}

// TestRefreshAccessToken_PeerAlreadyRefreshed verifies that when a peer process
// has already refreshed and saved the tokens, refreshAccessToken returns the
// stored fresh token without making a network call. This guards against the
// refresh-token-rotation race that motivated the cross-process lock.
func TestRefreshAccessToken_PeerAlreadyRefreshed(t *testing.T) {
var serverCalled atomic.Bool
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
serverCalled.Store(true)
http.Error(w, "should not be called", http.StatusInternalServerError)
}))
defer srv.Close()

cfg := testConfig(t)
cfg.ServerURL = srv.URL
cfg.Endpoints = defaultEndpoints(srv.URL)
cfg.ClientID = "peer-refresh-test"

peerFresh := credstore.Token{
AccessToken: "peer-fresh-access",
RefreshToken: "peer-fresh-refresh",
TokenType: "Bearer",
ExpiresAt: time.Now().Add(1 * time.Hour),
ClientID: cfg.ClientID,
}
if err := cfg.Store.Save(cfg.ClientID, peerFresh); err != nil {
t.Fatalf("setup save: %v", err)
}

stale := credstore.Token{
AccessToken: "stale-access",
RefreshToken: "stale-refresh",
TokenType: "Bearer",
ExpiresAt: time.Now().Add(-1 * time.Minute),
ClientID: cfg.ClientID,
}

got, err := refreshAccessToken(context.Background(), cfg, stale)
if err != nil {
t.Fatalf("refreshAccessToken() error: %v", err)
}
if serverCalled.Load() {
t.Fatalf("network refresh was performed; peer-refresh shortcut was expected")
}
if got.AccessToken != peerFresh.AccessToken {
t.Errorf("AccessToken = %q, want %q", got.AccessToken, peerFresh.AccessToken)
}
if got.RefreshToken != peerFresh.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", got.RefreshToken, peerFresh.RefreshToken)
}
}

// -----------------------------------------------------------------------
// Device code request with retry
// -----------------------------------------------------------------------
Expand Down
14 changes: 9 additions & 5 deletions token_cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,8 @@ func tokenGetRefreshConfig(t *testing.T, tokenURL string, tok credstore.Token) *
if err != nil {
t.Fatal(err)
}
store := credstore.NewTokenFileStore(
filepath.Join(t.TempDir(), "tokens.json"),
)
tokenFile := filepath.Join(t.TempDir(), "tokens.json")
store := credstore.NewTokenFileStore(tokenFile)
if err := store.Save("test-id", tok); err != nil {
t.Fatal(err)
}
Expand All @@ -468,6 +467,7 @@ func tokenGetRefreshConfig(t *testing.T, tokenURL string, tok credstore.Token) *
MaxResponseBodySize: defaultMaxResponseBodySize,
RefreshThreshold: defaultRefreshThreshold,
RetryClient: rc,
TokenFile: tokenFile,
Store: store,
}
}
Expand Down Expand Up @@ -742,7 +742,8 @@ func TestRunTokenGet_DefersDiscoveryUntilRefresh(t *testing.T) {
if err != nil {
t.Fatal(err)
}
store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json"))
tokenFile := filepath.Join(t.TempDir(), "tokens.json")
store := credstore.NewTokenFileStore(tokenFile)
if err := store.Save("test-id", credstore.Token{
AccessToken: "old-access-token",
RefreshToken: "refresh-456",
Expand All @@ -759,6 +760,7 @@ func TestRunTokenGet_DefersDiscoveryUntilRefresh(t *testing.T) {
DiscoveryTimeout: defaultDiscoveryTimeout,
MaxResponseBodySize: defaultMaxResponseBodySize,
RetryClient: rc,
TokenFile: tokenFile,
Store: store,
}

Expand Down Expand Up @@ -877,7 +879,8 @@ func TestRunTokenGet_LazyFullConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
store := credstore.NewTokenFileStore(filepath.Join(t.TempDir(), "tokens.json"))
tokenFile := filepath.Join(t.TempDir(), "tokens.json")
store := credstore.NewTokenFileStore(tokenFile)
if err := store.Save("test-id", credstore.Token{
AccessToken: "old-access-token",
RefreshToken: "refresh-456",
Expand All @@ -891,6 +894,7 @@ func TestRunTokenGet_LazyFullConfig(t *testing.T) {
partial := &AppConfig{
ClientID: "test-id",
Store: store,
TokenFile: tokenFile,
RefreshThreshold: defaultRefreshThreshold,
}

Expand Down
Loading