diff --git a/cmd/proapi/main.go b/cmd/proapi/main.go index c836aa4..162abd3 100644 --- a/cmd/proapi/main.go +++ b/cmd/proapi/main.go @@ -388,9 +388,12 @@ func wireAccountHandler(a *app.Application, adminG *gin.RouterGroup, log *zap.Lo middleware.SessionAuth(sessStore, a.Clock), middleware.RoleGate(2), ) - // 普通 admin 可访问的 13 个 endpoint。 + // OAuth callback 不带 session,依赖一次性 state 校验。 + adminG.GET("/accounts/oauth/callback", h.OAuthCallback) + // 普通 admin 可访问的 endpoint。 accG.GET("/accounts", h.List) accG.GET("/accounts/stats/overview", h.Stats) + accG.POST("/accounts/oauth/start", h.OAuthStart) accG.GET("/accounts/:id", h.Get) accG.POST("/accounts", h.Create) accG.POST("/accounts/import", h.Import) diff --git a/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md b/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md new file mode 100644 index 0000000..e2d0028 --- /dev/null +++ b/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md @@ -0,0 +1,96 @@ +# Account OAuth Admin Endpoints Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Expose the existing account-pool OAuth PKCE flow through admin HTTP endpoints and enable the admin UI to launch it. + +**Architecture:** Keep PKCE/state/token exchange inside `account.OAuthFlow`. `AccountHandler` owns HTTP validation, account persistence, events, probe trigger, and audit. The callback endpoint is mounted without session auth and relies on one-time state validation. + +**Tech Stack:** Go, Gin, account facade, Redis-backed OAuth state store, Vue 3, Naive UI, TypeScript. + +--- + +### Task 1: Backend Handler Tests + +**Files:** +- Modify: `internal/server/handler/admin/account_test.go` + +- [ ] **Step 1: Add a fake OAuthFlow to the existing account handler test harness** + +Create a fake that records `Start` input and returns a prepared account from `Callback`. Add it to the `account.Facade`. + +- [ ] **Step 2: Write failing tests for start and callback** + +Add tests asserting: +- `POST /api/admin/accounts/oauth/start` with `provider=openai` and `channel_id=42` returns `auth_url` and `state`. +- missing `channel_id` returns HTTP 400. +- `GET /api/admin/accounts/oauth/callback?state=s&code=c` persists the returned account, appends an `oauth_callback` event, writes `account.oauth_callback` audit, and returns HTML containing `account_oauth_done`. + +- [ ] **Step 3: Verify red** + +Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'` + +Expected: FAIL because `OAuthStart` and `OAuthCallback` handler methods and routes do not exist yet. + +### Task 2: Backend Implementation + +**Files:** +- Modify: `internal/server/handler/admin/account.go` +- Modify: `cmd/proapi/main.go` + +- [ ] **Step 1: Add handler request/response code** + +Implement: +- `OAuthStart`: validates JSON body, requires `channel_id > 0`, calls `h.Facade.OAuth.Start`, returns `{"data":{"auth_url": "...", "state": "..."}}`. +- `OAuthCallback`: validates `state` and `code`, calls `h.Facade.OAuth.Callback`, persists account through `Repo.Create`, appends `oauth_callback` event, optionally starts probe, audits `account.oauth_callback`, and returns a tiny HTML page that posts `account_oauth_done` to the opener and closes. + +- [ ] **Step 2: Route normal admin start and no-auth callback** + +In `wireAccountHandler`, mount `POST /accounts/oauth/start` inside the existing admin-authenticated account group. Mount `GET /accounts/oauth/callback` directly on `/api/admin` with only JSON error middleware already present on the parent group. + +- [ ] **Step 3: Verify green** + +Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'` + +Expected: PASS. + +### Task 3: Frontend Minimal Wiring + +**Files:** +- Modify: `web/admin/src/api/account.ts` +- Modify: `web/admin/src/views/accounts/AddDialog.vue` +- Modify: `web/admin/src/i18n/zh.json` +- Modify: `web/admin/src/i18n/en.json` + +- [ ] **Step 1: Add account API methods** + +Add `oauthStart(payload)` returning `{ auth_url, state }`. + +- [ ] **Step 2: Enable OAuth tab** + +Replace the disabled info panel with channel/provider controls and a button that calls `accountApi.oauthStart`, opens the returned URL in a popup, and listens for `account_oauth_done` to refresh and close the dialog. + +- [ ] **Step 3: Verify frontend type/build health** + +Run the package build command available in `web/admin/package.json`. + +### Task 4: Full Verification + +**Files:** +- No edits. + +- [ ] **Step 1: Backend full test** + +Run: `go test ./...` + +Expected: PASS. + +- [ ] **Step 2: Build** + +Run: `go build ./...` + +Expected: PASS. + +- [ ] **Step 3: Manual IdP smoke note** + +If real OAuth config is present, open admin, select a provider/channel, launch OAuth, complete provider auth, and verify a new account row appears. If config is absent, document the exact missing config keys. diff --git a/internal/account/oauth/anthropic.go b/internal/account/oauth/anthropic.go index 50d65e9..e407241 100644 --- a/internal/account/oauth/anthropic.go +++ b/internal/account/oauth/anthropic.go @@ -58,7 +58,7 @@ func (a *Anthropic) token(ctx context.Context, form url.Values) (*account.Accoun if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("anthropic token exchange: status %d", resp.StatusCode) } diff --git a/internal/account/oauth/openai.go b/internal/account/oauth/openai.go index 716e873..301698f 100644 --- a/internal/account/oauth/openai.go +++ b/internal/account/oauth/openai.go @@ -58,7 +58,7 @@ func (o *OpenAI) token(ctx context.Context, form url.Values) (*account.AccountCr if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("openai token exchange: status %d", resp.StatusCode) } diff --git a/internal/account/probe/anthropic.go b/internal/account/probe/anthropic.go index 7fa5a0c..06cc05b 100644 --- a/internal/account/probe/anthropic.go +++ b/internal/account/probe/anthropic.go @@ -34,7 +34,7 @@ func (a *Anthropic) Probe(ctx context.Context, cred account.AccountCred) (http.H if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode >= 400 { return resp.Header, fmt.Errorf("anthropic probe: status %d", resp.StatusCode) diff --git a/internal/account/probe/openai.go b/internal/account/probe/openai.go index 2b03062..0e36acc 100644 --- a/internal/account/probe/openai.go +++ b/internal/account/probe/openai.go @@ -34,7 +34,7 @@ func (o *OpenAI) Probe(ctx context.Context, cred account.AccountCred) (http.Head if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode >= 400 { return resp.Header, fmt.Errorf("openai probe: status %d", resp.StatusCode) diff --git a/internal/account/repo_test.go b/internal/account/repo_test.go index 4611641..b6584e0 100644 --- a/internal/account/repo_test.go +++ b/internal/account/repo_test.go @@ -38,22 +38,36 @@ func newTestDB(t *testing.T) *gorm.DB { } t.Cleanup(func() { _ = pool.Purge(res) }) - dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC", + dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC&timeout=5s&readTimeout=5s&writeTimeout=5s", res.GetPort("3306/tcp")) var db *gorm.DB - pool.MaxWait = 90 * time.Second + pool.MaxWait = 120 * time.Second if err := pool.Retry(func() error { var openErr error - db, openErr = gorm.Open(mysql.Open(dsn), &gorm.Config{}) + candidate, openErr := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if openErr != nil { return openErr } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := candidate.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + db = candidate + return nil }); err != nil { t.Fatalf("could not connect to mysql: %v", err) } + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) // Run account migrations via raw SQL. sqls := []string{ @@ -162,8 +176,11 @@ func TestRepo_ListByChannel(t *testing.T) { r := account.NewRepository(db, cr, idg, clock.Real, nil) for i := 0; i < 3; i++ { require.NoError(t, r.Create(ctx, &account.Account{ - ChannelID: 200, Provider: "anthropic", CredType: "apikey", - Status: account.StatusActive, Weight: 100, + ChannelID: 200, + Provider: "anthropic", + ExternalAccountID: fmt.Sprintf("list-%d", i), + CredType: "apikey", + Status: account.StatusActive, Weight: 100, Extra: json.RawMessage("{}"), Cred: account.AccountCred{APIKey: "sk-x"}, })) diff --git a/internal/auth/oauth/dingtalk/provider.go b/internal/auth/oauth/dingtalk/provider.go index 1c05fe6..3bea358 100644 --- a/internal/auth/oauth/dingtalk/provider.go +++ b/internal/auth/oauth/dingtalk/provider.go @@ -142,7 +142,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(respBody, &tok); err != nil { @@ -166,7 +166,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("dingtalk /users/me http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/discord/provider.go b/internal/auth/oauth/discord/provider.go index b417026..6e36c9d 100644 --- a/internal/auth/oauth/discord/provider.go +++ b/internal/auth/oauth/discord/provider.go @@ -129,7 +129,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(body, &tok); err != nil { @@ -153,7 +153,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("discord /users/@me http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/feishu/provider.go b/internal/auth/oauth/feishu/provider.go index 670f5b5..5a042c7 100644 --- a/internal/auth/oauth/feishu/provider.go +++ b/internal/auth/oauth/feishu/provider.go @@ -146,7 +146,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse, if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(respBody, &tok); err != nil { @@ -170,7 +170,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("feishu /user_info http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/google/provider.go b/internal/auth/oauth/google/provider.go index 1a26002..cde399a 100644 --- a/internal/auth/oauth/google/provider.go +++ b/internal/auth/oauth/google/provider.go @@ -106,7 +106,7 @@ func (p *provider) Exchange(ctx context.Context, code, redirectURL string) (*oau if err != nil { return nil, "", apierr.Wrap(apierr.CodeUpstreamUnavail, "google token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() tokBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(tokBody, &tok); err != nil { @@ -140,7 +140,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "google userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("google /userinfo http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/wechat/provider.go b/internal/auth/oauth/wechat/provider.go index 4a93a73..df3826b 100644 --- a/internal/auth/oauth/wechat/provider.go +++ b/internal/auth/oauth/wechat/provider.go @@ -127,7 +127,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse, if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(body, &tok); err != nil { @@ -154,7 +154,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken, openID string if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("wechat /userinfo http %d", resp.StatusCode)) } diff --git a/internal/invite/service.go b/internal/invite/service.go index e3fd6f6..21b235d 100644 --- a/internal/invite/service.go +++ b/internal/invite/service.go @@ -21,7 +21,7 @@ type WalletCredit interface { // Deps holds dependencies for NewService. type Deps struct { Repo Repository - DB *gorm.DB // raw DB for cross-package queries without import cycles + DB *gorm.DB // raw DB for cross-package queries without import cycles Wallet WalletCredit Setting setting.Store IDGen *idgen.Generator @@ -51,7 +51,10 @@ func (s *Service) OnOrderPaid(ctx context.Context, orderID, userID int64) error Table("users"). Select("invited_by"). Where("id = ?", userID). - Scan(&invitedBy).Error; err != nil || invitedBy == 0 { + Scan(&invitedBy).Error; err != nil { + return fmt.Errorf("invite: query inviter for user %d: %w", userID, err) + } + if invitedBy == 0 { return nil // no inviter } diff --git a/internal/invite/service_query_test.go b/internal/invite/service_query_test.go index 2d87d47..6f25f3c 100644 --- a/internal/invite/service_query_test.go +++ b/internal/invite/service_query_test.go @@ -33,8 +33,8 @@ func (f *fakeSetting) GetString(_ context.Context, key, def string) string { } return v.(string) } -func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def } -func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def } +func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def } +func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def } func (f *fakeSetting) GetFloat(_ context.Context, key string, def float64) float64 { v, ok := f.kv[key] if !ok { @@ -42,7 +42,7 @@ func (f *fakeSetting) GetFloat(_ context.Context, key string, def float64) float } return v.(float64) } -func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil } +func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil } func (f *fakeSetting) Put(_ context.Context, _ string, _ any, _ int64) error { return nil } func (f *fakeSetting) Close() error { return nil } func (f *fakeSetting) GetSecret(_ context.Context, _ string, _ setting.Decryptor) (string, error) { @@ -143,6 +143,18 @@ func TestMaskEmail(t *testing.T) { } } +func TestOnOrderPaidReturnsUserLookupError(t *testing.T) { + db, svc := setupQueryDB(t) + ctx := context.Background() + + if err := db.Exec(`DROP TABLE users`).Error; err != nil { + t.Fatal(err) + } + if err := svc.OnOrderPaid(ctx, 1, 20); err == nil { + t.Fatal("expected user lookup error") + } +} + func TestListInvitees(t *testing.T) { db, svc := setupQueryDB(t) ctx := context.Background() diff --git a/internal/notify/email/smtp.go b/internal/notify/email/smtp.go index 41ad410..edd792a 100644 --- a/internal/notify/email/smtp.go +++ b/internal/notify/email/smtp.go @@ -66,7 +66,7 @@ func (m *smtpMailer) sendTLS(addr string, raw []byte, to string) error { if err != nil { return fmt.Errorf("smtp new client: %w", err) } - defer c.Close() + defer func() { _ = c.Close() }() return m.doSend(c, raw, to) } @@ -75,7 +75,7 @@ func (m *smtpMailer) sendSTARTTLS(addr string, raw []byte, to string) error { if err != nil { return fmt.Errorf("smtp dial: %w", err) } - defer c.Close() + defer func() { _ = c.Close() }() if ok, _ := c.Extension("STARTTLS"); ok { tlsCfg := &tls.Config{ diff --git a/internal/orm/orm_integration_test.go b/internal/orm/orm_integration_test.go index 47b1a4a..3c8a31f 100644 --- a/internal/orm/orm_integration_test.go +++ b/internal/orm/orm_integration_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/ory/dockertest/v3" "github.com/ijry/pro-api/internal/app/config" + "github.com/ory/dockertest/v3" ) func mustPool(t *testing.T) *dockertest.Pool { @@ -34,10 +34,10 @@ func TestOpen_MySQL(t *testing.T) { } t.Cleanup(func() { _ = pool.Purge(res) }) - dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC", + dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC&timeout=5s&readTimeout=5s&writeTimeout=5s", res.GetPort("3306/tcp")) - pool.MaxWait = 90 * time.Second + pool.MaxWait = 120 * time.Second if err := pool.Retry(func() error { db, err := Open(config.DatabaseConfig{ Driver: "mysql", DSN: dsn, @@ -46,8 +46,15 @@ func TestOpen_MySQL(t *testing.T) { if err != nil { return err } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := db.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + return sqlDB.Close() }); err != nil { t.Fatal(err) } @@ -76,10 +83,16 @@ func TestOpen_Postgres(t *testing.T) { if err != nil { return err } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := db.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + return sqlDB.Close() }); err != nil { t.Fatal(err) } } - diff --git a/internal/server/handler/admin/account.go b/internal/server/handler/admin/account.go index ef6d593..5ff04b3 100644 --- a/internal/server/handler/admin/account.go +++ b/internal/server/handler/admin/account.go @@ -15,6 +15,7 @@ package admin import ( "encoding/json" + "errors" "io" "net/http" "strconv" @@ -22,6 +23,7 @@ import ( "github.com/gin-gonic/gin" "github.com/ijry/pro-api/internal/account" + accountoauth "github.com/ijry/pro-api/internal/account/oauth" "github.com/ijry/pro-api/internal/audit" "github.com/ijry/pro-api/internal/server/middleware" "github.com/ijry/pro-api/pkg/apierr" @@ -51,6 +53,8 @@ func NewAccountHandler(f *account.Facade, a audit.Logger, actorOf func(*gin.Cont func (h *AccountHandler) Register(r gin.IRouter) { r.GET("/accounts", h.List) r.GET("/accounts/stats/overview", h.Stats) + r.POST("/accounts/oauth/start", h.OAuthStart) + r.GET("/accounts/oauth/callback", h.OAuthCallback) r.GET("/accounts/:id", h.Get) r.POST("/accounts", h.Create) r.POST("/accounts/import", h.Import) @@ -119,26 +123,26 @@ func (h *AccountHandler) auditOne(c *gin.Context, action string, targetID int64, // listItem 是 List 返回的 item。绝不包含明文凭证或 Credentials 密文。 type listItem struct { - ID int64 `json:"id"` - ChannelID int64 `json:"channel_id"` - ShareTag string `json:"share_tag"` - Name string `json:"name"` - Provider string `json:"provider"` - Tier string `json:"tier"` - CredType string `json:"cred_type"` - Email string `json:"email"` - Status int8 `json:"status"` - Priority int16 `json:"priority"` - Weight int `json:"weight"` - ConsecFailures int `json:"consec_failures"` - RefreshTokenValid int8 `json:"refresh_token_valid"` - CooldownUntil *time.Time `json:"cooldown_until,omitempty"` - LastUsedAt *time.Time `json:"last_used_at,omitempty"` - LastSuccessAt *time.Time `json:"last_success_at,omitempty"` - LastFailureAt *time.Time `json:"last_failure_at,omitempty"` - AccessTokenExpiresAt *time.Time `json:"access_token_expires_at,omitempty"` - Quota5h quotaItem `json:"quota_5h"` - QuotaWeek quotaItem `json:"quota_week"` + ID int64 `json:"id"` + ChannelID int64 `json:"channel_id"` + ShareTag string `json:"share_tag"` + Name string `json:"name"` + Provider string `json:"provider"` + Tier string `json:"tier"` + CredType string `json:"cred_type"` + Email string `json:"email"` + Status int8 `json:"status"` + Priority int16 `json:"priority"` + Weight int `json:"weight"` + ConsecFailures int `json:"consec_failures"` + RefreshTokenValid int8 `json:"refresh_token_valid"` + CooldownUntil *time.Time `json:"cooldown_until,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + LastSuccessAt *time.Time `json:"last_success_at,omitempty"` + LastFailureAt *time.Time `json:"last_failure_at,omitempty"` + AccessTokenExpiresAt *time.Time `json:"access_token_expires_at,omitempty"` + Quota5h quotaItem `json:"quota_5h"` + QuotaWeek quotaItem `json:"quota_week"` } // detailItem 在 listItem 之外再加几个详情字段。同样不含明文凭证。 @@ -289,6 +293,11 @@ type acctCreateReq struct { Weight int `json:"weight"` } +type acctOAuthStartReq struct { + Provider string `json:"provider"` + ChannelID int64 `json:"channel_id"` +} + // Create POST /accounts — 单条创建,与 Import 类似但只取第一条。dry_run 时不落库。 func (h *AccountHandler) Create(c *gin.Context) { var req acctCreateReq @@ -368,6 +377,90 @@ func (h *AccountHandler) Create(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"data": gin.H{"id": a.ID}}) } +// OAuthStart POST /accounts/oauth/start — 启动账号池 OAuth PKCE 流程。 +func (h *AccountHandler) OAuthStart(c *gin.Context) { + if h.Facade.OAuth == nil { + middleware.SetErr(c, apierr.New(apierr.CodeInternal, "oauth 未就绪")) + return + } + var req acctOAuthStartReq + if err := c.ShouldBindJSON(&req); err != nil { + middleware.SetErr(c, apierr.New(apierr.CodeInvalidParam, "请求体不合法")) + return + } + if req.Provider == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "provider 必填")) + return + } + if req.ChannelID <= 0 { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "channel_id 必填")) + return + } + authURL, state, err := h.Facade.OAuth.Start(c.Request.Context(), req.Provider, req.ChannelID) + if err != nil { + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthRejected, "oauth start failed", err)) + return + } + c.JSON(http.StatusOK, gin.H{"data": gin.H{"auth_url": authURL, "state": state}}) +} + +// OAuthCallback GET /accounts/oauth/callback — OAuth provider 回调入口(no-auth,state 校验)。 +func (h *AccountHandler) OAuthCallback(c *gin.Context) { + if h.Facade.OAuth == nil { + middleware.SetErr(c, apierr.New(apierr.CodeInternal, "oauth 未就绪")) + return + } + state := c.Query("state") + code := c.Query("code") + if state == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "state 必填")) + return + } + if code == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "code 必填")) + return + } + a, err := h.Facade.OAuth.Callback(c.Request.Context(), state, code) + if err != nil { + if errors.Is(err, accountoauth.ErrStateNotFound) { + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthState, "oauth state invalid", err)) + return + } + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthRejected, "oauth callback failed", err)) + return + } + if err := h.Facade.Repo.Create(c.Request.Context(), a); err != nil { + writeAcctErr(c, err) + return + } + _ = h.Facade.Repo.AppendEvent(c.Request.Context(), a.ID, "oauth_callback", + map[string]any{"provider": a.Provider, "channel_id": a.ChannelID}) + if h.Facade.Probe != nil { + cc := c.Copy() + go func(acc *account.Account) { + _ = h.Facade.Probe.Run(cc.Request.Context(), acc) + }(a) + } + h.auditOne(c, "account.oauth_callback", a.ID, nil, map[string]any{ + "id": a.ID, + "channel_id": a.ChannelID, + "provider": a.Provider, + }) + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(accountOAuthDoneHTML)) +} + +const accountOAuthDoneHTML = ` +OAuth complete + + +OAuth complete. You can close this window. +` + // Import POST /accounts/import — 批量导入。支持 JSON body 或 multipart file。 func (h *AccountHandler) Import(c *gin.Context) { var req acctCreateReq @@ -384,7 +477,7 @@ func (h *AccountHandler) Import(c *gin.Context) { if err == nil && fh != nil { f, err := fh.Open() if err == nil { - defer f.Close() + defer func() { _ = f.Close() }() buf, _ := io.ReadAll(f) text = string(buf) } @@ -696,4 +789,3 @@ func (h *AccountHandler) Stats(c *gin.Context) { "by_provider": gin.H{}, }}) } - diff --git a/internal/server/handler/admin/account_test.go b/internal/server/handler/admin/account_test.go index 1335b8a..5ee1862 100644 --- a/internal/server/handler/admin/account_test.go +++ b/internal/server/handler/admin/account_test.go @@ -251,6 +251,51 @@ func (f *fakeRefresher) RefreshOne(_ context.Context, id int64) error { } func (f *fakeRefresher) Close() error { return nil } +// fakeOAuth records Start calls and returns a prepared account from Callback. +type fakeOAuth struct { + mu sync.Mutex + startProvider string + startChannelID int64 + callbackState string + callbackCode string + callbackAccount *account.Account +} + +func (f *fakeOAuth) Start(_ context.Context, provider string, channelID int64) (string, string, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.startProvider = provider + f.startChannelID = channelID + return "https://oauth.example/authorize?state=st-1", "st-1", nil +} + +func (f *fakeOAuth) Callback(_ context.Context, state, code string) (*account.Account, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.callbackState = state + f.callbackCode = code + if f.callbackAccount != nil { + return f.callbackAccount, nil + } + return &account.Account{ + ChannelID: 42, + Provider: "openai", + CredType: "oauth", + Email: "oauth@example.com", + Status: account.StatusActive, + Weight: 100, + RefreshTokenValid: 1, + Cred: account.AccountCred{ + AccessToken: "at", + RefreshToken: "rt", + }, + }, nil +} + +func (f *fakeOAuth) ExchangeRefreshToken(context.Context, string, string) (*account.AccountCred, error) { + return nil, nil +} + // memAudit captures audit entries for assertions. type memAudit struct { mu sync.Mutex @@ -277,16 +322,23 @@ func (m *memAudit) actions() []string { // --- harness --- func newAccountTestHarness() (*gin.Engine, *fakeRepo, *memAudit, *fakeProbe, *fakeRefresher) { + r, repo, aud, probe, ref, _ := newAccountTestHarnessWithOAuth() + return r, repo, aud, probe, ref +} + +func newAccountTestHarnessWithOAuth() (*gin.Engine, *fakeRepo, *memAudit, *fakeProbe, *fakeRefresher, *fakeOAuth) { gin.SetMode(gin.TestMode) repo := newFakeRepo() imp := &fakeImporter{format: "raw_api_key"} probe := &fakeProbe{} ref := &fakeRefresher{} + oauth := &fakeOAuth{} facade := &account.Facade{ Repo: repo, Importer: imp, Probe: probe, Refresher: ref, + OAuth: oauth, } aud := &memAudit{} h := NewAccountHandler(facade, aud, func(c *gin.Context) int64 { return 7 }) @@ -294,7 +346,7 @@ func newAccountTestHarness() (*gin.Engine, *fakeRepo, *memAudit, *fakeProbe, *fa r.Use(middleware.ErrorResponse("json")) g := r.Group("/api/admin") h.Register(g) - return r, repo, aud, probe, ref + return r, repo, aud, probe, ref, oauth } func doAcctReq(t *testing.T, r http.Handler, method, path, body string) *httptest.ResponseRecorder { @@ -423,6 +475,81 @@ func TestAccountHandler_Create_Persists(t *testing.T) { } } +func TestAccountHandler_OAuthStart_OK(t *testing.T) { + r, _, _, _, _, oauth := newAccountTestHarnessWithOAuth() + + rec := doAcctReq(t, r, http.MethodPost, "/api/admin/accounts/oauth/start", + `{"provider":"openai","channel_id":42}`) + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d body=%s", rec.Code, rec.Body.String()) + } + var body struct { + Data struct { + AuthURL string `json:"auth_url"` + State string `json:"state"` + } `json:"data"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal: %v body=%s", err, rec.Body.String()) + } + if body.Data.AuthURL == "" || body.Data.State != "st-1" { + t.Fatalf("unexpected oauth start response: %+v", body.Data) + } + if oauth.startProvider != "openai" || oauth.startChannelID != 42 { + t.Fatalf("oauth start inputs: provider=%q channel=%d", oauth.startProvider, oauth.startChannelID) + } +} + +func TestAccountHandler_OAuthStart_MissingChannelID(t *testing.T) { + r, _, _, _, _ := newAccountTestHarness() + + rec := doAcctReq(t, r, http.MethodPost, "/api/admin/accounts/oauth/start", + `{"provider":"openai"}`) + if rec.Code != http.StatusBadRequest { + t.Fatalf("want 400, got %d body=%s", rec.Code, rec.Body.String()) + } +} + +func TestAccountHandler_OAuthCallback_PersistsAndReturnsDoneHTML(t *testing.T) { + r, repo, aud, _, _, oauth := newAccountTestHarnessWithOAuth() + + rec := doAcctReq(t, r, http.MethodGet, "/api/admin/accounts/oauth/callback?state=st-1&code=code-1", "") + if rec.Code != http.StatusOK { + t.Fatalf("want 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "account_oauth_done") { + t.Fatalf("callback should return popup completion HTML, got %s", rec.Body.String()) + } + if len(repo.items) != 1 { + t.Fatalf("want 1 persisted account, got %d", len(repo.items)) + } + var got *account.Account + for _, a := range repo.items { + got = a + } + if got.Provider != "openai" || got.CredType != "oauth" || got.ChannelID != 42 { + t.Fatalf("unexpected persisted account: %+v", got) + } + if oauth.callbackState != "st-1" || oauth.callbackCode != "code-1" { + t.Fatalf("oauth callback inputs: state=%q code=%q", oauth.callbackState, oauth.callbackCode) + } + if len(repo.events) != 1 || repo.events[0].eventType != "oauth_callback" || repo.events[0].accountID != got.ID { + t.Fatalf("oauth callback event missing: %+v", repo.events) + } + found := false + for _, e := range aud.entries { + if e.Action == "account.oauth_callback" { + found = true + if e.TargetID == nil || *e.TargetID != got.ID { + t.Fatalf("TargetID mismatch: %+v", e.TargetID) + } + } + } + if !found { + t.Fatalf("audit account.oauth_callback missing, got %v", aud.actions()) + } +} + // Patch: update name + priority; row updated; audit account.patch written. func TestAccountHandler_Patch_OK(t *testing.T) { r, repo, aud, _, _ := newAccountTestHarness() @@ -460,7 +587,7 @@ func TestAccountHandler_PeekCredentials_AuditWritten(t *testing.T) { a := &account.Account{ ChannelID: 1, Provider: "anthropic", CredType: "apikey", Status: account.StatusActive, Weight: 100, Extra: json.RawMessage("{}"), - Cred: account.AccountCred{APIKey: "sk-secret"}, + Cred: account.AccountCred{APIKey: "sk-secret"}, } _ = repo.Create(context.Background(), a) diff --git a/web/admin/src/api/account.ts b/web/admin/src/api/account.ts index b963f61..87c4248 100644 --- a/web/admin/src/api/account.ts +++ b/web/admin/src/api/account.ts @@ -101,11 +101,23 @@ export interface StatsResp { by_provider: Record } +export interface OAuthStartPayload { + provider: 'anthropic' | 'openai' + channel_id: number +} + +export interface OAuthStartResp { + auth_url: string + state: string +} + export const accountApi = { list: (p: ListParams) => get('/api/admin/accounts', p as Record), get: (id: number) => get(`/api/admin/accounts/${id}`), create:(payload: CreatePayload) => post<{ id?: number; preview?: Account }>('/api/admin/accounts', payload), import:(payload: ImportPayload) => post('/api/admin/accounts/import', payload), + oauthStart: async (payload: OAuthStartPayload) => + (await post<{ data: OAuthStartResp }>('/api/admin/accounts/oauth/start', payload)).data, patch: (id: number, p: Partial) => patch(`/api/admin/accounts/${id}`, p), delete:(id: number) => del<{ id: number }>(`/api/admin/accounts/${id}`), enable:(id: number) => post<{ id: number; status: number }>(`/api/admin/accounts/${id}/enable`, {}), diff --git a/web/admin/src/i18n/en.json b/web/admin/src/i18n/en.json index 42207f1..49f95c1 100644 --- a/web/admin/src/i18n/en.json +++ b/web/admin/src/i18n/en.json @@ -142,6 +142,12 @@ "name_label": "Account Name (optional)", "name_placeholder": "Parsed from credentials by default", "oauth_disabled": "Direct OAuth authorization will ship in M2", + "provider_label": "OAuth Provider", + "oauth_hint": "Select a channel and provider, then complete authorization in the popup. The account is saved automatically; edit account name and share tag after it appears in the list.", + "oauth_start": "Start Authorization", + "oauth_started": "Authorization popup opened. Complete authorization in the new window.", + "oauth_done": "OAuth authorization completed and account saved", + "oauth_popup_blocked": "The browser blocked the authorization popup. Allow popups and retry.", "token_text_label": "Token / JSON", "token_text_placeholder": "Paste access_token, refresh_token or full JSON", "apikey_label": "API Key", diff --git a/web/admin/src/i18n/zh.json b/web/admin/src/i18n/zh.json index c818e22..30dc0ff 100644 --- a/web/admin/src/i18n/zh.json +++ b/web/admin/src/i18n/zh.json @@ -142,6 +142,12 @@ "name_label": "账号名(可选)", "name_placeholder": "默认从凭证解析", "oauth_disabled": "OAuth 直连授权将在 M2 上线", + "provider_label": "OAuth Provider", + "oauth_hint": "选择渠道和 Provider 后会打开授权窗口;授权完成后账号自动入库,账号名和共享 tag 可在列表中编辑。", + "oauth_start": "开始授权", + "oauth_started": "已打开授权窗口,请在新窗口完成授权", + "oauth_done": "OAuth 授权完成,账号已入库", + "oauth_popup_blocked": "浏览器拦截了授权窗口,请允许弹窗后重试", "token_text_label": "Token 文本 / JSON", "token_text_placeholder": "粘贴 access_token、refresh_token 或完整 JSON", "apikey_label": "API Key", diff --git a/web/admin/src/views/accounts/AddDialog.vue b/web/admin/src/views/accounts/AddDialog.vue index 50917de..67ec99b 100644 --- a/web/admin/src/views/accounts/AddDialog.vue +++ b/web/admin/src/views/accounts/AddDialog.vue @@ -1,5 +1,5 @@