Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/proapi/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,12 @@ func wireAccountHandler(a *app.Application, adminG *gin.RouterGroup, log *zap.Lo
middleware.SessionAuth(sessStore, a.Clock),
middleware.RoleGate(2),
)
// 普通 admin 可访问的 13 个 endpoint。
// OAuth callback 不带 session,依赖一次性 state 校验。
adminG.GET("/accounts/oauth/callback", h.OAuthCallback)
// 普通 admin 可访问的 endpoint。
accG.GET("/accounts", h.List)
accG.GET("/accounts/stats/overview", h.Stats)
accG.POST("/accounts/oauth/start", h.OAuthStart)
accG.GET("/accounts/:id", h.Get)
accG.POST("/accounts", h.Create)
accG.POST("/accounts/import", h.Import)
Expand Down
96 changes: 96 additions & 0 deletions docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Account OAuth Admin Endpoints Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Expose the existing account-pool OAuth PKCE flow through admin HTTP endpoints and enable the admin UI to launch it.

**Architecture:** Keep PKCE/state/token exchange inside `account.OAuthFlow`. `AccountHandler` owns HTTP validation, account persistence, events, probe trigger, and audit. The callback endpoint is mounted without session auth and relies on one-time state validation.

**Tech Stack:** Go, Gin, account facade, Redis-backed OAuth state store, Vue 3, Naive UI, TypeScript.

---

### Task 1: Backend Handler Tests

**Files:**
- Modify: `internal/server/handler/admin/account_test.go`

- [ ] **Step 1: Add a fake OAuthFlow to the existing account handler test harness**

Create a fake that records `Start` input and returns a prepared account from `Callback`. Add it to the `account.Facade`.

- [ ] **Step 2: Write failing tests for start and callback**

Add tests asserting:
- `POST /api/admin/accounts/oauth/start` with `provider=openai` and `channel_id=42` returns `auth_url` and `state`.
- missing `channel_id` returns HTTP 400.
- `GET /api/admin/accounts/oauth/callback?state=s&code=c` persists the returned account, appends an `oauth_callback` event, writes `account.oauth_callback` audit, and returns HTML containing `account_oauth_done`.

- [ ] **Step 3: Verify red**

Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'`

Expected: FAIL because `OAuthStart` and `OAuthCallback` handler methods and routes do not exist yet.

### Task 2: Backend Implementation

**Files:**
- Modify: `internal/server/handler/admin/account.go`
- Modify: `cmd/proapi/main.go`

- [ ] **Step 1: Add handler request/response code**

Implement:
- `OAuthStart`: validates JSON body, requires `channel_id > 0`, calls `h.Facade.OAuth.Start`, returns `{"data":{"auth_url": "...", "state": "..."}}`.
- `OAuthCallback`: validates `state` and `code`, calls `h.Facade.OAuth.Callback`, persists account through `Repo.Create`, appends `oauth_callback` event, optionally starts probe, audits `account.oauth_callback`, and returns a tiny HTML page that posts `account_oauth_done` to the opener and closes.

- [ ] **Step 2: Route normal admin start and no-auth callback**

In `wireAccountHandler`, mount `POST /accounts/oauth/start` inside the existing admin-authenticated account group. Mount `GET /accounts/oauth/callback` directly on `/api/admin` with only JSON error middleware already present on the parent group.

- [ ] **Step 3: Verify green**

Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'`

Expected: PASS.

### Task 3: Frontend Minimal Wiring

**Files:**
- Modify: `web/admin/src/api/account.ts`
- Modify: `web/admin/src/views/accounts/AddDialog.vue`
- Modify: `web/admin/src/i18n/zh.json`
- Modify: `web/admin/src/i18n/en.json`

- [ ] **Step 1: Add account API methods**

Add `oauthStart(payload)` returning `{ auth_url, state }`.

- [ ] **Step 2: Enable OAuth tab**

Replace the disabled info panel with channel/provider controls and a button that calls `accountApi.oauthStart`, opens the returned URL in a popup, and listens for `account_oauth_done` to refresh and close the dialog.

- [ ] **Step 3: Verify frontend type/build health**

Run the package build command available in `web/admin/package.json`.

### Task 4: Full Verification

**Files:**
- No edits.

- [ ] **Step 1: Backend full test**

Run: `go test ./...`

Expected: PASS.

- [ ] **Step 2: Build**

Run: `go build ./...`

Expected: PASS.

- [ ] **Step 3: Manual IdP smoke note**

If real OAuth config is present, open admin, select a provider/channel, launch OAuth, complete provider auth, and verify a new account row appears. If config is absent, document the exact missing config keys.
2 changes: 1 addition & 1 deletion internal/account/oauth/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (a *Anthropic) token(ctx context.Context, form url.Values) (*account.Accoun
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("anthropic token exchange: status %d", resp.StatusCode)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/account/oauth/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (o *OpenAI) token(ctx context.Context, form url.Values) (*account.AccountCr
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
return nil, fmt.Errorf("openai token exchange: status %d", resp.StatusCode)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/account/probe/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (a *Anthropic) Probe(ctx context.Context, cred account.AccountCred) (http.H
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
_, _ = io.Copy(io.Discard, resp.Body)
if resp.StatusCode >= 400 {
return resp.Header, fmt.Errorf("anthropic probe: status %d", resp.StatusCode)
Expand Down
2 changes: 1 addition & 1 deletion internal/account/probe/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (o *OpenAI) Probe(ctx context.Context, cred account.AccountCred) (http.Head
if err != nil {
return nil, err
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
_, _ = io.Copy(io.Discard, resp.Body)
if resp.StatusCode >= 400 {
return resp.Header, fmt.Errorf("openai probe: status %d", resp.StatusCode)
Expand Down
31 changes: 24 additions & 7 deletions internal/account/repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,36 @@ func newTestDB(t *testing.T) *gorm.DB {
}
t.Cleanup(func() { _ = pool.Purge(res) })

dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC",
dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC&timeout=5s&readTimeout=5s&writeTimeout=5s",
res.GetPort("3306/tcp"))

var db *gorm.DB
pool.MaxWait = 90 * time.Second
pool.MaxWait = 120 * time.Second
if err := pool.Retry(func() error {
var openErr error
db, openErr = gorm.Open(mysql.Open(dsn), &gorm.Config{})
candidate, openErr := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if openErr != nil {
return openErr
}
sqlDB, _ := db.DB()
return sqlDB.Ping()
sqlDB, dbErr := candidate.DB()
if dbErr != nil {
return dbErr
}
if pingErr := sqlDB.Ping(); pingErr != nil {
_ = sqlDB.Close()
return pingErr
}
db = candidate
return nil
}); err != nil {
t.Fatalf("could not connect to mysql: %v", err)
}
t.Cleanup(func() {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
})

// Run account migrations via raw SQL.
sqls := []string{
Expand Down Expand Up @@ -162,8 +176,11 @@ func TestRepo_ListByChannel(t *testing.T) {
r := account.NewRepository(db, cr, idg, clock.Real, nil)
for i := 0; i < 3; i++ {
require.NoError(t, r.Create(ctx, &account.Account{
ChannelID: 200, Provider: "anthropic", CredType: "apikey",
Status: account.StatusActive, Weight: 100,
ChannelID: 200,
Provider: "anthropic",
ExternalAccountID: fmt.Sprintf("list-%d", i),
CredType: "apikey",
Status: account.StatusActive, Weight: 100,
Extra: json.RawMessage("{}"),
Cred: account.AccountCred{APIKey: "sk-x"},
}))
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/oauth/dingtalk/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t
if err != nil {
return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk token http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
var tok tokenResponse
if err := json.Unmarshal(respBody, &tok); err != nil {
Expand All @@ -166,7 +166,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt
if err != nil {
return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk userinfo http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("dingtalk /users/me http %d", resp.StatusCode))
}
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/oauth/discord/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t
if err != nil {
return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord token http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
var tok tokenResponse
if err := json.Unmarshal(body, &tok); err != nil {
Expand All @@ -153,7 +153,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt
if err != nil {
return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord userinfo http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("discord /users/@me http %d", resp.StatusCode))
}
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/oauth/feishu/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse,
if err != nil {
return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu token http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
var tok tokenResponse
if err := json.Unmarshal(respBody, &tok); err != nil {
Expand All @@ -170,7 +170,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt
if err != nil {
return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu userinfo http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("feishu /user_info http %d", resp.StatusCode))
}
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/oauth/google/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (p *provider) Exchange(ctx context.Context, code, redirectURL string) (*oau
if err != nil {
return nil, "", apierr.Wrap(apierr.CodeUpstreamUnavail, "google token http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
tokBody, _ := io.ReadAll(resp.Body)
var tok tokenResponse
if err := json.Unmarshal(tokBody, &tok); err != nil {
Expand Down Expand Up @@ -140,7 +140,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt
if err != nil {
return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "google userinfo http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("google /userinfo http %d", resp.StatusCode))
}
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/oauth/wechat/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse,
if err != nil {
return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat token http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
var tok tokenResponse
if err := json.Unmarshal(body, &tok); err != nil {
Expand All @@ -154,7 +154,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken, openID string
if err != nil {
return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat userinfo http", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode/100 != 2 {
return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("wechat /userinfo http %d", resp.StatusCode))
}
Expand Down
7 changes: 5 additions & 2 deletions internal/invite/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type WalletCredit interface {
// Deps holds dependencies for NewService.
type Deps struct {
Repo Repository
DB *gorm.DB // raw DB for cross-package queries without import cycles
DB *gorm.DB // raw DB for cross-package queries without import cycles
Wallet WalletCredit
Setting setting.Store
IDGen *idgen.Generator
Expand Down Expand Up @@ -51,7 +51,10 @@ func (s *Service) OnOrderPaid(ctx context.Context, orderID, userID int64) error
Table("users").
Select("invited_by").
Where("id = ?", userID).
Scan(&invitedBy).Error; err != nil || invitedBy == 0 {
Scan(&invitedBy).Error; err != nil {
return fmt.Errorf("invite: query inviter for user %d: %w", userID, err)
}
if invitedBy == 0 {
return nil // no inviter
}

Expand Down
18 changes: 15 additions & 3 deletions internal/invite/service_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ func (f *fakeSetting) GetString(_ context.Context, key, def string) string {
}
return v.(string)
}
func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def }
func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def }
func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def }
func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def }
func (f *fakeSetting) GetFloat(_ context.Context, key string, def float64) float64 {
v, ok := f.kv[key]
if !ok {
return def
}
return v.(float64)
}
func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil }
func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil }
func (f *fakeSetting) Put(_ context.Context, _ string, _ any, _ int64) error { return nil }
func (f *fakeSetting) Close() error { return nil }
func (f *fakeSetting) GetSecret(_ context.Context, _ string, _ setting.Decryptor) (string, error) {
Expand Down Expand Up @@ -143,6 +143,18 @@ func TestMaskEmail(t *testing.T) {
}
}

func TestOnOrderPaidReturnsUserLookupError(t *testing.T) {
db, svc := setupQueryDB(t)
ctx := context.Background()

if err := db.Exec(`DROP TABLE users`).Error; err != nil {
t.Fatal(err)
}
if err := svc.OnOrderPaid(ctx, 1, 20); err == nil {
t.Fatal("expected user lookup error")
}
}

func TestListInvitees(t *testing.T) {
db, svc := setupQueryDB(t)
ctx := context.Background()
Expand Down
4 changes: 2 additions & 2 deletions internal/notify/email/smtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (m *smtpMailer) sendTLS(addr string, raw []byte, to string) error {
if err != nil {
return fmt.Errorf("smtp new client: %w", err)
}
defer c.Close()
defer func() { _ = c.Close() }()
return m.doSend(c, raw, to)
}

Expand All @@ -75,7 +75,7 @@ func (m *smtpMailer) sendSTARTTLS(addr string, raw []byte, to string) error {
if err != nil {
return fmt.Errorf("smtp dial: %w", err)
}
defer c.Close()
defer func() { _ = c.Close() }()

if ok, _ := c.Extension("STARTTLS"); ok {
tlsCfg := &tls.Config{
Expand Down
Loading
Loading