Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions internal/handlers/sns_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ import (
"strings"
"sync"
"time"

"golang.org/x/sync/singleflight"
)

// snsSigningCertHostRegex enforces "sns.<region>.amazonaws.com" hostnames
Expand Down Expand Up @@ -92,6 +94,13 @@ type snsVerifier struct {

mu sync.RWMutex
certCache map[string]snsCertCacheEntry

// sfGroup collapses concurrent cache-miss fetches for the same certURL
// into a single in-flight network call. Without it a burst of SNS
// deliveries that all miss the cache (cold start, or just after a TTL
// expiry) each fire their own fetchCert and then race to overwrite the
// map — a thundering herd against the AWS cert endpoint. See getCert.
sfGroup singleflight.Group
}

// newSNSVerifier returns a verifier with production defaults.
Expand Down Expand Up @@ -216,6 +225,14 @@ func (v *snsVerifier) verify(msg snsMessage) error {
}

// getCert returns a cached certificate or fetches it via fetchCert.
//
// The miss path is collapsed through a singleflight group keyed by certURL so
// a burst of concurrent SNS deliveries that all miss the cache (cold start, or
// the instant after a TTL expiry) issue exactly ONE fetchCert call between
// them rather than a thundering herd against the AWS cert endpoint. The fetch
// closure re-checks the cache under the lock (double-checked locking) before
// fetching, so a request that joins just after the leader populated the cache
// returns the fresh entry without a redundant network call.
func (v *snsVerifier) getCert(certURL string) (*x509.Certificate, error) {
v.mu.RLock()
entry, ok := v.certCache[certURL]
Expand All @@ -224,15 +241,20 @@ func (v *snsVerifier) getCert(certURL string) (*x509.Certificate, error) {
return entry.cert, nil
}

cert, err := v.fetchCert("sns", certURL)
cert, err, _ := v.sfGroup.Do(certURL, func() (any, error) {
c, err := v.fetchCert("sns", certURL)
if err != nil {
return nil, err
}
v.mu.Lock()
v.certCache[certURL] = snsCertCacheEntry{cert: c, fetched: time.Now()}
v.mu.Unlock()
return c, nil
})
if err != nil {
return nil, err
}

v.mu.Lock()
v.certCache[certURL] = snsCertCacheEntry{cert: cert, fetched: time.Now()}
v.mu.Unlock()
return cert, nil
return cert.(*x509.Certificate), nil
}

// defaultFetchCert fetches the PEM cert at certURL and returns the first
Expand Down
46 changes: 46 additions & 0 deletions internal/handlers/sns_verify_final2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"encoding/base64"
"errors"
"math/big"
"sync"
"sync/atomic"
"testing"
"time"
)
Expand Down Expand Up @@ -183,3 +185,47 @@ func TestSNSVerifyFinal2_GetCert_CacheHitAndMiss(t *testing.T) {
t.Error("getCert must propagate fetch error")
}
}

// TestSNSVerifyFinal2_GetCert_SingleflightCollapsesConcurrentMisses proves the
// bug-bash #1/#9 fix: a burst of concurrent cache-miss getCert calls for the
// same certURL collapses into exactly ONE fetchCert call (no thundering herd
// against the AWS cert endpoint). The fetch is held open on a channel so every
// goroutine joins the same in-flight singleflight call before it completes.
func TestSNSVerifyFinal2_GetCert_SingleflightCollapsesConcurrentMisses(t *testing.T) {
cert, _ := final2GenCertKey(t)
v := newSNSVerifier()

var calls int32
proceed := make(chan struct{})
v.fetchCert = func(_ string, _ string) (*x509.Certificate, error) {
atomic.AddInt32(&calls, 1)
<-proceed // hold the flight open so concurrent callers collapse into it
return cert, nil
}

const n = 20
var wg sync.WaitGroup
errs := make([]error, n)
for i := 0; i < n; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
_, errs[i] = v.getCert(final2AWSCertURL)
}(i)
}

// Give all goroutines time to enter the singleflight Do and collapse onto
// the leader's still-open flight, then release the fetch.
time.Sleep(50 * time.Millisecond)
close(proceed)
wg.Wait()

for i, err := range errs {
if err != nil {
t.Fatalf("getCert[%d] returned error: %v", i, err)
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Errorf("singleflight should collapse %d concurrent misses into 1 fetch, got %d", n, got)
}
}
Loading