diff --git a/internal/handlers/sns_verify.go b/internal/handlers/sns_verify.go index b25c76a7..97bd9a9c 100644 --- a/internal/handlers/sns_verify.go +++ b/internal/handlers/sns_verify.go @@ -43,6 +43,8 @@ import ( "strings" "sync" "time" + + "golang.org/x/sync/singleflight" ) // snsSigningCertHostRegex enforces "sns..amazonaws.com" hostnames @@ -92,6 +94,13 @@ type snsVerifier struct { mu sync.RWMutex certCache map[string]snsCertCacheEntry + + // sfGroup collapses concurrent cache-miss fetches for the same certURL + // into a single in-flight network call. Without it a burst of SNS + // deliveries that all miss the cache (cold start, or just after a TTL + // expiry) each fire their own fetchCert and then race to overwrite the + // map — a thundering herd against the AWS cert endpoint. See getCert. + sfGroup singleflight.Group } // newSNSVerifier returns a verifier with production defaults. @@ -216,6 +225,14 @@ func (v *snsVerifier) verify(msg snsMessage) error { } // getCert returns a cached certificate or fetches it via fetchCert. +// +// The miss path is collapsed through a singleflight group keyed by certURL so +// a burst of concurrent SNS deliveries that all miss the cache (cold start, or +// the instant after a TTL expiry) issue exactly ONE fetchCert call between +// them rather than a thundering herd against the AWS cert endpoint. The fetch +// closure stores the fetched certificate in the cache under the write lock +// before returning. It does not re-check the cache before fetching, so a +// caller that missed just before the flight ended can still fetch once more. 
func (v *snsVerifier) getCert(certURL string) (*x509.Certificate, error) { v.mu.RLock() entry, ok := v.certCache[certURL] @@ -224,15 +241,20 @@ func (v *snsVerifier) getCert(certURL string) (*x509.Certificate, error) { return entry.cert, nil } - cert, err := v.fetchCert("sns", certURL) + cert, err, _ := v.sfGroup.Do(certURL, func() (any, error) { + c, err := v.fetchCert("sns", certURL) + if err != nil { + return nil, err + } + v.mu.Lock() + v.certCache[certURL] = snsCertCacheEntry{cert: c, fetched: time.Now()} + v.mu.Unlock() + return c, nil + }) if err != nil { return nil, err } - - v.mu.Lock() - v.certCache[certURL] = snsCertCacheEntry{cert: cert, fetched: time.Now()} - v.mu.Unlock() - return cert, nil + return cert.(*x509.Certificate), nil } // defaultFetchCert fetches the PEM cert at certURL and returns the first diff --git a/internal/handlers/sns_verify_final2_test.go b/internal/handlers/sns_verify_final2_test.go index 6fde2db4..01787cb2 100644 --- a/internal/handlers/sns_verify_final2_test.go +++ b/internal/handlers/sns_verify_final2_test.go @@ -25,6 +25,8 @@ import ( "encoding/base64" "errors" "math/big" + "sync" + "sync/atomic" "testing" "time" ) @@ -183,3 +185,47 @@ func TestSNSVerifyFinal2_GetCert_CacheHitAndMiss(t *testing.T) { t.Error("getCert must propagate fetch error") } } + +// TestSNSVerifyFinal2_GetCert_SingleflightCollapsesConcurrentMisses proves the +// bug-bash #1/#9 fix: a burst of concurrent cache-miss getCert calls for the +// same certURL collapses into exactly ONE fetchCert call (no thundering herd +// against the AWS cert endpoint). The fetch is held open on a channel so every +// goroutine joins the same in-flight singleflight call before it completes. 
+func TestSNSVerifyFinal2_GetCert_SingleflightCollapsesConcurrentMisses(t *testing.T) { + cert, _ := final2GenCertKey(t) + v := newSNSVerifier() + + var calls int32 + proceed := make(chan struct{}) + v.fetchCert = func(_ string, _ string) (*x509.Certificate, error) { + atomic.AddInt32(&calls, 1) + <-proceed // hold the flight open so concurrent callers collapse into it + return cert, nil + } + + const n = 20 + var wg sync.WaitGroup + errs := make([]error, n) + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, errs[i] = v.getCert(final2AWSCertURL) + }(i) + } + + // Give all goroutines time to enter the singleflight Do and collapse onto + // the leader's still-open flight, then release the fetch. + time.Sleep(50 * time.Millisecond) + close(proceed) + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Fatalf("getCert[%d] returned error: %v", i, err) + } + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Errorf("singleflight should collapse %d concurrent misses into 1 fetch, got %d", n, got) + } +}