diff --git a/internal/api/recover.go b/internal/api/recover.go index 7c967aaf0..70569c1e7 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -2,10 +2,12 @@ package api import ( "net/http" + "time" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" ) // RecoverParams holds the parameters for a password recovery request @@ -29,8 +31,15 @@ func (p *RecoverParams) Validate(a *API) error { return nil } -// Recover sends a recovery email func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { + start := time.Now() + const minResponseTime = 500 * time.Millisecond + defer func() { + if elapsed := time.Since(start); elapsed < minResponseTime { + time.Sleep(minResponseTime - elapsed) + } + }() + ctx := r.Context() db := a.db.WithContext(ctx) config := a.config @@ -52,6 +61,12 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { + // Mitigate rate-limit enumeration by using an in-memory cache for non-existent users + // Use a domain-separated secret to prevent key separation violations + secret := []byte("fake_rate_limit:" + config.JWT.Secret) + if lastReq := utilities.CheckFakeRateLimit(db, params.Email, config.SMTP.MaxFrequency, secret); lastReq != nil { + return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverEmailSendRateLimit, "%s", generateFrequencyLimitErrorMessage(lastReq, config.SMTP.MaxFrequency)) + } return sendJSON(w, http.StatusOK, map[string]string{}) } return apierrors.NewInternalServerError("Unable to process request").WithInternalError(err) diff --git a/internal/api/recover_test.go b/internal/api/recover_test.go index a7e655c59..cfcf2abba 100644 --- a/internal/api/recover_test.go +++ b/internal/api/recover_test.go @@ -130,7 +130,7 @@ func (ts *RecoverTestSuite) TestRecover_NewEmailSent() { assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) } -func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak() { +func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak_FirstRequest() { email := "doesntexist@example.com" _, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) @@ -151,3 +151,36 @@ func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak() { ts.API.handler.ServeHTTP(w, req) assert.Equal(ts.T(), http.StatusOK, w.Code) } + +func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak_RateLimit() { + email := "doesntexist_ratelimit@example.com" + + _, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + require.True(ts.T(), models.IsNotFoundError(err), "User with email %s does exist", email) + + // First Request + var buffer1 bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer1).Encode(map[string]interface{}{ + "email": email, + })) + req1 := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer1) + req1.Header.Set("Content-Type", "application/json") + + w1 := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w1, req1) + assert.Equal(ts.T(), http.StatusOK, w1.Code) + + // Second Request immediately after + var buffer2 bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer2).Encode(map[string]interface{}{ + "email": email, + })) + req2 := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer2) + req2.Header.Set("Content-Type", "application/json") + + w2 := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w2, req2) + + // Should be rate limited + assert.Equal(ts.T(), http.StatusTooManyRequests, w2.Code) +} diff --git a/internal/utilities/fake_rate_limiter.go b/internal/utilities/fake_rate_limiter.go new file mode 100644 index 000000000..bebd6fa7d --- /dev/null +++ b/internal/utilities/fake_rate_limiter.go @@ -0,0 +1,70 @@ +package utilities + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "math/rand" + "time" + + "github.com/supabase/auth/internal/storage" +) + +type FakeRateLimit struct { + EmailHash string `db:"email_hash"` + LastRequestAt time.Time `db:"last_request_at"` +} + +// TableName returns the table name +func (FakeRateLimit) TableName() string { + return "fake_rate_limits" +} + +// CheckFakeRateLimit simulates a rate limit check for a non-existent email. +// It returns the timestamp of the last request if it was rate limited, or nil if not. +func CheckFakeRateLimit(db *storage.Connection, email string, frequency time.Duration, secret []byte) *time.Time { + h := hmac.New(sha256.New, secret) + h.Write([]byte(email)) + hashStr := hex.EncodeToString(h.Sum(nil)) + + var lastReq *time.Time + _ = db.Transaction(func(tx *storage.Connection) error { + // Pre-insert a sentinel row so the row always exists before we lock it. + // This prevents two concurrent first-requests from both racing past FOR UPDATE. + epoch := time.Unix(0, 0).UTC() + _ = tx.RawQuery(`INSERT INTO fake_rate_limits (email_hash, last_request_at) VALUES (?, ?) ON CONFLICT DO NOTHING`, hashStr, epoch).Exec() + + // Lock the now-guaranteed-existing row + existing := &FakeRateLimit{} + if err := tx.RawQuery(`SELECT last_request_at FROM fake_rate_limits WHERE email_hash = ? FOR UPDATE`, hashStr).First(existing); err != nil { + return err + } + + now := time.Now() + if now.Sub(existing.LastRequestAt) < frequency { + // Rate limited! + last := existing.LastRequestAt + lastReq = &last + return nil + } + // Not rate limited, update the timestamp + _ = tx.RawQuery(`UPDATE fake_rate_limits SET last_request_at = ? WHERE email_hash = ?`, now, hashStr).Exec() + return nil + }) + + // Probabilistic cleanup (10% chance) to prevent table unbounded growth + if rand.Intn(10) == 0 { + go CleanupFakeRateLimitCache(db, frequency) + } + + return lastReq +} + +// CleanupFakeRateLimitCache removes expired entries from the cache. +// Call this periodically or when necessary to prevent unbounded memory growth. +func CleanupFakeRateLimitCache(db *storage.Connection, frequency time.Duration) { + _ = db.RawQuery( + `DELETE FROM fake_rate_limits WHERE EXTRACT(EPOCH FROM (NOW() - last_request_at)) > ?`, + frequency.Seconds(), + ).Exec() +} diff --git a/migrations/20260527000000_add_fake_rate_limits.up.sql b/migrations/20260527000000_add_fake_rate_limits.up.sql new file mode 100644 index 000000000..a4b19452e --- /dev/null +++ b/migrations/20260527000000_add_fake_rate_limits.up.sql @@ -0,0 +1,4 @@ +CREATE TABLE IF NOT EXISTS fake_rate_limits ( + email_hash VARCHAR(64) PRIMARY KEY, + last_request_at TIMESTAMP WITH TIME ZONE NOT NULL +); diff --git a/test_jwks.go b/test_jwks.go new file mode 100644 index 000000000..830d9c6b0 --- /dev/null +++ b/test_jwks.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "github.com/go-jose/go-jose/v3" +) + +func main() { + jwksStr := `{"keys":[{"kty":"EC","crv":"secp256k1","x":"1","y":"2"}]}` + var jwks jose.JSONWebKeySet + err := jwks.UnmarshalJSON([]byte(jwksStr)) + if err != nil { + fmt.Println("Error:", err) + return + } + fmt.Println("Parsed:", len(jwks.Keys)) +}