diff --git a/cmd/proapi/main.go b/cmd/proapi/main.go index c836aa4..162abd3 100644 --- a/cmd/proapi/main.go +++ b/cmd/proapi/main.go @@ -388,9 +388,12 @@ func wireAccountHandler(a *app.Application, adminG *gin.RouterGroup, log *zap.Lo middleware.SessionAuth(sessStore, a.Clock), middleware.RoleGate(2), ) - // 普通 admin 可访问的 13 个 endpoint。 + // OAuth callback 不带 session,依赖一次性 state 校验。 + adminG.GET("/accounts/oauth/callback", h.OAuthCallback) + // 普通 admin 可访问的 endpoint。 accG.GET("/accounts", h.List) accG.GET("/accounts/stats/overview", h.Stats) + accG.POST("/accounts/oauth/start", h.OAuthStart) accG.GET("/accounts/:id", h.Get) accG.POST("/accounts", h.Create) accG.POST("/accounts/import", h.Import) diff --git a/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md b/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md new file mode 100644 index 0000000..e2d0028 --- /dev/null +++ b/docs/superpowers/plans/2026-06-08-account-oauth-admin-endpoints.md @@ -0,0 +1,96 @@ +# Account OAuth Admin Endpoints Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Expose the existing account-pool OAuth PKCE flow through admin HTTP endpoints and enable the admin UI to launch it. + +**Architecture:** Keep PKCE/state/token exchange inside `account.OAuthFlow`. `AccountHandler` owns HTTP validation, account persistence, events, probe trigger, and audit. The callback endpoint is mounted without session auth and relies on one-time state validation. + +**Tech Stack:** Go, Gin, account facade, Redis-backed OAuth state store, Vue 3, Naive UI, TypeScript. + +--- + +### Task 1: Backend Handler Tests + +**Files:** +- Modify: `internal/server/handler/admin/account_test.go` + +- [ ] **Step 1: Add a fake OAuthFlow to the existing account handler test harness** + +Create a fake that records `Start` input and returns a prepared account from `Callback`. Add it to the `account.Facade`. + +- [ ] **Step 2: Write failing tests for start and callback** + +Add tests asserting: +- `POST /api/admin/accounts/oauth/start` with `provider=openai` and `channel_id=42` returns `auth_url` and `state`. +- missing `channel_id` returns HTTP 400. +- `GET /api/admin/accounts/oauth/callback?state=s&code=c` persists the returned account, appends an `oauth_callback` event, writes `account.oauth_callback` audit, and returns HTML containing `account_oauth_done`. + +- [ ] **Step 3: Verify red** + +Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'` + +Expected: FAIL because `OAuthStart` and `OAuthCallback` handler methods and routes do not exist yet. + +### Task 2: Backend Implementation + +**Files:** +- Modify: `internal/server/handler/admin/account.go` +- Modify: `cmd/proapi/main.go` + +- [ ] **Step 1: Add handler request/response code** + +Implement: +- `OAuthStart`: validates JSON body, requires `channel_id > 0`, calls `h.Facade.OAuth.Start`, returns `{"data":{"auth_url": "...", "state": "..."}}`. +- `OAuthCallback`: validates `state` and `code`, calls `h.Facade.OAuth.Callback`, persists account through `Repo.Create`, appends `oauth_callback` event, optionally starts probe, audits `account.oauth_callback`, and returns a tiny HTML page that posts `account_oauth_done` to the opener and closes. + +- [ ] **Step 2: Route normal admin start and no-auth callback** + +In `wireAccountHandler`, mount `POST /accounts/oauth/start` inside the existing admin-authenticated account group. Mount `GET /accounts/oauth/callback` directly on `/api/admin` with only JSON error middleware already present on the parent group. + +- [ ] **Step 3: Verify green** + +Run: `go test ./internal/server/handler/admin -run 'TestAccountHandler_OAuth'` + +Expected: PASS. + +### Task 3: Frontend Minimal Wiring + +**Files:** +- Modify: `web/admin/src/api/account.ts` +- Modify: `web/admin/src/views/accounts/AddDialog.vue` +- Modify: `web/admin/src/i18n/zh.json` +- Modify: `web/admin/src/i18n/en.json` + +- [ ] **Step 1: Add account API methods** + +Add `oauthStart(payload)` returning `{ auth_url, state }`. + +- [ ] **Step 2: Enable OAuth tab** + +Replace the disabled info panel with channel/provider controls and a button that calls `accountApi.oauthStart`, opens the returned URL in a popup, and listens for `account_oauth_done` to refresh and close the dialog. + +- [ ] **Step 3: Verify frontend type/build health** + +Run the package build command available in `web/admin/package.json`. + +### Task 4: Full Verification + +**Files:** +- No edits. + +- [ ] **Step 1: Backend full test** + +Run: `go test ./...` + +Expected: PASS. + +- [ ] **Step 2: Build** + +Run: `go build ./...` + +Expected: PASS. + +- [ ] **Step 3: Manual IdP smoke note** + +If real OAuth config is present, open admin, select a provider/channel, launch OAuth, complete provider auth, and verify a new account row appears. If config is absent, document the exact missing config keys. diff --git a/internal/account/oauth/anthropic.go b/internal/account/oauth/anthropic.go index 50d65e9..e407241 100644 --- a/internal/account/oauth/anthropic.go +++ b/internal/account/oauth/anthropic.go @@ -58,7 +58,7 @@ func (a *Anthropic) token(ctx context.Context, form url.Values) (*account.Accoun if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("anthropic token exchange: status %d", resp.StatusCode) } diff --git a/internal/account/oauth/openai.go b/internal/account/oauth/openai.go index 716e873..301698f 100644 --- a/internal/account/oauth/openai.go +++ b/internal/account/oauth/openai.go @@ -58,7 +58,7 @@ func (o *OpenAI) token(ctx context.Context, form url.Values) (*account.AccountCr if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { return nil, fmt.Errorf("openai token exchange: status %d", resp.StatusCode) } diff --git a/internal/account/probe/anthropic.go b/internal/account/probe/anthropic.go index 7fa5a0c..06cc05b 100644 --- a/internal/account/probe/anthropic.go +++ b/internal/account/probe/anthropic.go @@ -34,7 +34,7 @@ func (a *Anthropic) Probe(ctx context.Context, cred account.AccountCred) (http.H if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode >= 400 { return resp.Header, fmt.Errorf("anthropic probe: status %d", resp.StatusCode) diff --git a/internal/account/probe/openai.go b/internal/account/probe/openai.go index 2b03062..0e36acc 100644 --- a/internal/account/probe/openai.go +++ b/internal/account/probe/openai.go @@ -34,7 +34,7 @@ func (o *OpenAI) Probe(ctx context.Context, cred account.AccountCred) (http.Head if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() _, _ = io.Copy(io.Discard, resp.Body) if resp.StatusCode >= 400 { return resp.Header, fmt.Errorf("openai probe: status %d", resp.StatusCode) diff --git a/internal/account/repo_test.go b/internal/account/repo_test.go index 4611641..b6584e0 100644 --- a/internal/account/repo_test.go +++ b/internal/account/repo_test.go @@ -38,22 +38,36 @@ func newTestDB(t *testing.T) *gorm.DB { } t.Cleanup(func() { _ = pool.Purge(res) }) - dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC", + dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC&timeout=5s&readTimeout=5s&writeTimeout=5s", res.GetPort("3306/tcp")) var db *gorm.DB - pool.MaxWait = 90 * time.Second + pool.MaxWait = 120 * time.Second if err := pool.Retry(func() error { var openErr error - db, openErr = gorm.Open(mysql.Open(dsn), &gorm.Config{}) + candidate, openErr := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if openErr != nil { return openErr } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := candidate.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + db = candidate + return nil }); err != nil { t.Fatalf("could not connect to mysql: %v", err) } + t.Cleanup(func() { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + }) // Run account migrations via raw SQL. sqls := []string{ @@ -162,8 +176,11 @@ func TestRepo_ListByChannel(t *testing.T) { r := account.NewRepository(db, cr, idg, clock.Real, nil) for i := 0; i < 3; i++ { require.NoError(t, r.Create(ctx, &account.Account{ - ChannelID: 200, Provider: "anthropic", CredType: "apikey", - Status: account.StatusActive, Weight: 100, + ChannelID: 200, + Provider: "anthropic", + ExternalAccountID: fmt.Sprintf("list-%d", i), + CredType: "apikey", + Status: account.StatusActive, Weight: 100, Extra: json.RawMessage("{}"), Cred: account.AccountCred{APIKey: "sk-x"}, })) diff --git a/internal/auth/oauth/dingtalk/provider.go b/internal/auth/oauth/dingtalk/provider.go index 1c05fe6..3bea358 100644 --- a/internal/auth/oauth/dingtalk/provider.go +++ b/internal/auth/oauth/dingtalk/provider.go @@ -142,7 +142,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(respBody, &tok); err != nil { @@ -166,7 +166,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "dingtalk userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("dingtalk /users/me http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/discord/provider.go b/internal/auth/oauth/discord/provider.go index b417026..6e36c9d 100644 --- a/internal/auth/oauth/discord/provider.go +++ b/internal/auth/oauth/discord/provider.go @@ -129,7 +129,7 @@ func (p *provider) fetchToken(ctx context.Context, code, redirectURL string) (*t if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(body, &tok); err != nil { @@ -153,7 +153,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "discord userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("discord /users/@me http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/feishu/provider.go b/internal/auth/oauth/feishu/provider.go index 670f5b5..5a042c7 100644 --- a/internal/auth/oauth/feishu/provider.go +++ b/internal/auth/oauth/feishu/provider.go @@ -146,7 +146,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse, if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(respBody, &tok); err != nil { @@ -170,7 +170,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "feishu userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("feishu /user_info http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/google/provider.go b/internal/auth/oauth/google/provider.go index 1a26002..cde399a 100644 --- a/internal/auth/oauth/google/provider.go +++ b/internal/auth/oauth/google/provider.go @@ -106,7 +106,7 @@ func (p *provider) Exchange(ctx context.Context, code, redirectURL string) (*oau if err != nil { return nil, "", apierr.Wrap(apierr.CodeUpstreamUnavail, "google token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() tokBody, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(tokBody, &tok); err != nil { @@ -140,7 +140,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken string) ([]byt if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "google userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("google /userinfo http %d", resp.StatusCode)) } diff --git a/internal/auth/oauth/wechat/provider.go b/internal/auth/oauth/wechat/provider.go index 4a93a73..df3826b 100644 --- a/internal/auth/oauth/wechat/provider.go +++ b/internal/auth/oauth/wechat/provider.go @@ -127,7 +127,7 @@ func (p *provider) fetchToken(ctx context.Context, code string) (*tokenResponse, if err != nil { return nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat token http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) var tok tokenResponse if err := json.Unmarshal(body, &tok); err != nil { @@ -154,7 +154,7 @@ func (p *provider) fetchUserInfo(ctx context.Context, accessToken, openID string if err != nil { return nil, nil, apierr.Wrap(apierr.CodeUpstreamUnavail, "wechat userinfo http", err) } - defer resp.Body.Close() + defer func() { _ = resp.Body.Close() }() if resp.StatusCode/100 != 2 { return nil, nil, apierr.New(apierr.CodeUpstreamError, fmt.Sprintf("wechat /userinfo http %d", resp.StatusCode)) } diff --git a/internal/invite/service.go b/internal/invite/service.go index e3fd6f6..21b235d 100644 --- a/internal/invite/service.go +++ b/internal/invite/service.go @@ -21,7 +21,7 @@ type WalletCredit interface { // Deps holds dependencies for NewService. type Deps struct { Repo Repository - DB *gorm.DB // raw DB for cross-package queries without import cycles + DB *gorm.DB // raw DB for cross-package queries without import cycles Wallet WalletCredit Setting setting.Store IDGen *idgen.Generator @@ -51,7 +51,10 @@ func (s *Service) OnOrderPaid(ctx context.Context, orderID, userID int64) error Table("users"). Select("invited_by"). Where("id = ?", userID). - Scan(&invitedBy).Error; err != nil || invitedBy == 0 { + Scan(&invitedBy).Error; err != nil { + return fmt.Errorf("invite: query inviter for user %d: %w", userID, err) + } + if invitedBy == 0 { return nil // no inviter } diff --git a/internal/invite/service_query_test.go b/internal/invite/service_query_test.go index 2d87d47..6f25f3c 100644 --- a/internal/invite/service_query_test.go +++ b/internal/invite/service_query_test.go @@ -33,8 +33,8 @@ func (f *fakeSetting) GetString(_ context.Context, key, def string) string { } return v.(string) } -func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def } -func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def } +func (f *fakeSetting) GetBool(_ context.Context, _ string, def bool) bool { return def } +func (f *fakeSetting) GetInt(_ context.Context, _ string, def int) int { return def } func (f *fakeSetting) GetFloat(_ context.Context, key string, def float64) float64 { v, ok := f.kv[key] if !ok { @@ -42,7 +42,7 @@ func (f *fakeSetting) GetFloat(_ context.Context, key string, def float64) float } return v.(float64) } -func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil } +func (f *fakeSetting) GetJSON(_ context.Context, _ string, _ any) error { return nil } func (f *fakeSetting) Put(_ context.Context, _ string, _ any, _ int64) error { return nil } func (f *fakeSetting) Close() error { return nil } func (f *fakeSetting) GetSecret(_ context.Context, _ string, _ setting.Decryptor) (string, error) { @@ -143,6 +143,18 @@ func TestMaskEmail(t *testing.T) { } } +func TestOnOrderPaidReturnsUserLookupError(t *testing.T) { + db, svc := setupQueryDB(t) + ctx := context.Background() + + if err := db.Exec(`DROP TABLE users`).Error; err != nil { + t.Fatal(err) + } + if err := svc.OnOrderPaid(ctx, 1, 20); err == nil { + t.Fatal("expected user lookup error") + } +} + func TestListInvitees(t *testing.T) { db, svc := setupQueryDB(t) ctx := context.Background() diff --git a/internal/notify/email/smtp.go b/internal/notify/email/smtp.go index 41ad410..edd792a 100644 --- a/internal/notify/email/smtp.go +++ b/internal/notify/email/smtp.go @@ -66,7 +66,7 @@ func (m *smtpMailer) sendTLS(addr string, raw []byte, to string) error { if err != nil { return fmt.Errorf("smtp new client: %w", err) } - defer c.Close() + defer func() { _ = c.Close() }() return m.doSend(c, raw, to) } @@ -75,7 +75,7 @@ func (m *smtpMailer) sendSTARTTLS(addr string, raw []byte, to string) error { if err != nil { return fmt.Errorf("smtp dial: %w", err) } - defer c.Close() + defer func() { _ = c.Close() }() if ok, _ := c.Extension("STARTTLS"); ok { tlsCfg := &tls.Config{ diff --git a/internal/orm/orm_integration_test.go b/internal/orm/orm_integration_test.go index 47b1a4a..3c8a31f 100644 --- a/internal/orm/orm_integration_test.go +++ b/internal/orm/orm_integration_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/ory/dockertest/v3" "github.com/ijry/pro-api/internal/app/config" + "github.com/ory/dockertest/v3" ) func mustPool(t *testing.T) *dockertest.Pool { @@ -34,10 +34,10 @@ func TestOpen_MySQL(t *testing.T) { } t.Cleanup(func() { _ = pool.Purge(res) }) - dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC", + dsn := fmt.Sprintf("root:proapi@tcp(127.0.0.1:%s)/proapi?charset=utf8mb4&parseTime=True&loc=UTC&timeout=5s&readTimeout=5s&writeTimeout=5s", res.GetPort("3306/tcp")) - pool.MaxWait = 90 * time.Second + pool.MaxWait = 120 * time.Second if err := pool.Retry(func() error { db, err := Open(config.DatabaseConfig{ Driver: "mysql", DSN: dsn, @@ -46,8 +46,15 @@ func TestOpen_MySQL(t *testing.T) { if err != nil { return err } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := db.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + return sqlDB.Close() }); err != nil { t.Fatal(err) } @@ -76,10 +83,16 @@ func TestOpen_Postgres(t *testing.T) { if err != nil { return err } - sqlDB, _ := db.DB() - return sqlDB.Ping() + sqlDB, dbErr := db.DB() + if dbErr != nil { + return dbErr + } + if pingErr := sqlDB.Ping(); pingErr != nil { + _ = sqlDB.Close() + return pingErr + } + return sqlDB.Close() }); err != nil { t.Fatal(err) } } - diff --git a/internal/server/handler/admin/account.go b/internal/server/handler/admin/account.go index ef6d593..5ff04b3 100644 --- a/internal/server/handler/admin/account.go +++ b/internal/server/handler/admin/account.go @@ -15,6 +15,7 @@ package admin import ( "encoding/json" + "errors" "io" "net/http" "strconv" @@ -22,6 +23,7 @@ import ( "github.com/gin-gonic/gin" "github.com/ijry/pro-api/internal/account" + accountoauth "github.com/ijry/pro-api/internal/account/oauth" "github.com/ijry/pro-api/internal/audit" "github.com/ijry/pro-api/internal/server/middleware" "github.com/ijry/pro-api/pkg/apierr" @@ -51,6 +53,8 @@ func NewAccountHandler(f *account.Facade, a audit.Logger, actorOf func(*gin.Cont func (h *AccountHandler) Register(r gin.IRouter) { r.GET("/accounts", h.List) r.GET("/accounts/stats/overview", h.Stats) + r.POST("/accounts/oauth/start", h.OAuthStart) + r.GET("/accounts/oauth/callback", h.OAuthCallback) r.GET("/accounts/:id", h.Get) r.POST("/accounts", h.Create) r.POST("/accounts/import", h.Import) @@ -119,26 +123,26 @@ func (h *AccountHandler) auditOne(c *gin.Context, action string, targetID int64, // listItem 是 List 返回的 item。绝不包含明文凭证或 Credentials 密文。 type listItem struct { - ID int64 `json:"id"` - ChannelID int64 `json:"channel_id"` - ShareTag string `json:"share_tag"` - Name string `json:"name"` - Provider string `json:"provider"` - Tier string `json:"tier"` - CredType string `json:"cred_type"` - Email string `json:"email"` - Status int8 `json:"status"` - Priority int16 `json:"priority"` - Weight int `json:"weight"` - ConsecFailures int `json:"consec_failures"` - RefreshTokenValid int8 `json:"refresh_token_valid"` - CooldownUntil *time.Time `json:"cooldown_until,omitempty"` - LastUsedAt *time.Time `json:"last_used_at,omitempty"` - LastSuccessAt *time.Time `json:"last_success_at,omitempty"` - LastFailureAt *time.Time `json:"last_failure_at,omitempty"` - AccessTokenExpiresAt *time.Time `json:"access_token_expires_at,omitempty"` - Quota5h quotaItem `json:"quota_5h"` - QuotaWeek quotaItem `json:"quota_week"` + ID int64 `json:"id"` + ChannelID int64 `json:"channel_id"` + ShareTag string `json:"share_tag"` + Name string `json:"name"` + Provider string `json:"provider"` + Tier string `json:"tier"` + CredType string `json:"cred_type"` + Email string `json:"email"` + Status int8 `json:"status"` + Priority int16 `json:"priority"` + Weight int `json:"weight"` + ConsecFailures int `json:"consec_failures"` + RefreshTokenValid int8 `json:"refresh_token_valid"` + CooldownUntil *time.Time `json:"cooldown_until,omitempty"` + LastUsedAt *time.Time `json:"last_used_at,omitempty"` + LastSuccessAt *time.Time `json:"last_success_at,omitempty"` + LastFailureAt *time.Time `json:"last_failure_at,omitempty"` + AccessTokenExpiresAt *time.Time `json:"access_token_expires_at,omitempty"` + Quota5h quotaItem `json:"quota_5h"` + QuotaWeek quotaItem `json:"quota_week"` } // detailItem 在 listItem 之外再加几个详情字段。同样不含明文凭证。 @@ -289,6 +293,11 @@ type acctCreateReq struct { Weight int `json:"weight"` } +type acctOAuthStartReq struct { + Provider string `json:"provider"` + ChannelID int64 `json:"channel_id"` +} + // Create POST /accounts — 单条创建,与 Import 类似但只取第一条。dry_run 时不落库。 func (h *AccountHandler) Create(c *gin.Context) { var req acctCreateReq @@ -368,6 +377,90 @@ func (h *AccountHandler) Create(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"data": gin.H{"id": a.ID}}) } +// OAuthStart POST /accounts/oauth/start — 启动账号池 OAuth PKCE 流程。 +func (h *AccountHandler) OAuthStart(c *gin.Context) { + if h.Facade.OAuth == nil { + middleware.SetErr(c, apierr.New(apierr.CodeInternal, "oauth 未就绪")) + return + } + var req acctOAuthStartReq + if err := c.ShouldBindJSON(&req); err != nil { + middleware.SetErr(c, apierr.New(apierr.CodeInvalidParam, "请求体不合法")) + return + } + if req.Provider == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "provider 必填")) + return + } + if req.ChannelID <= 0 { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "channel_id 必填")) + return + } + authURL, state, err := h.Facade.OAuth.Start(c.Request.Context(), req.Provider, req.ChannelID) + if err != nil { + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthRejected, "oauth start failed", err)) + return + } + c.JSON(http.StatusOK, gin.H{"data": gin.H{"auth_url": authURL, "state": state}}) +} + +// OAuthCallback GET /accounts/oauth/callback — OAuth provider 回调入口(no-auth,state 校验)。 +func (h *AccountHandler) OAuthCallback(c *gin.Context) { + if h.Facade.OAuth == nil { + middleware.SetErr(c, apierr.New(apierr.CodeInternal, "oauth 未就绪")) + return + } + state := c.Query("state") + code := c.Query("code") + if state == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "state 必填")) + return + } + if code == "" { + middleware.SetErr(c, apierr.New(apierr.CodeMissingParam, "code 必填")) + return + } + a, err := h.Facade.OAuth.Callback(c.Request.Context(), state, code) + if err != nil { + if errors.Is(err, accountoauth.ErrStateNotFound) { + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthState, "oauth state invalid", err)) + return + } + middleware.SetErr(c, apierr.Wrap(apierr.CodeAccountOAuthRejected, "oauth callback failed", err)) + return + } + if err := h.Facade.Repo.Create(c.Request.Context(), a); err != nil { + writeAcctErr(c, err) + return + } + _ = h.Facade.Repo.AppendEvent(c.Request.Context(), a.ID, "oauth_callback", + map[string]any{"provider": a.Provider, "channel_id": a.ChannelID}) + if h.Facade.Probe != nil { + cc := c.Copy() + go func(acc *account.Account) { + _ = h.Facade.Probe.Run(cc.Request.Context(), acc) + }(a) + } + h.auditOne(c, "account.oauth_callback", a.ID, nil, map[string]any{ + "id": a.ID, + "channel_id": a.ChannelID, + "provider": a.Provider, + }) + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(accountOAuthDoneHTML)) +} + +const accountOAuthDoneHTML = ` +