From 82cb09c3c34c51c34437c57a5cb5d1bd118a44de Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 11:20:06 -0700 Subject: [PATCH 1/8] discovery/resolver/static: Revamp the static implementation and make it updatable --- discovery/resolver/static/resolver.go | 124 +++++++++++++++++-- discovery/resolver/static/resolver_test.go | 134 +++++++++++++++++++++ 2 files changed, 249 insertions(+), 9 deletions(-) diff --git a/discovery/resolver/static/resolver.go b/discovery/resolver/static/resolver.go index cdf8b4c..4ab80ef 100644 --- a/discovery/resolver/static/resolver.go +++ b/discovery/resolver/static/resolver.go @@ -3,6 +3,7 @@ package static import ( "context" "fmt" + "sync" "sync/atomic" "github.com/upfluence/pkg/v2/closer" @@ -14,7 +15,7 @@ import ( type Builder[T peer.Peer] map[string][]T func (b Builder[T]) Build(n string) resolver.Resolver[T] { - return &Resolver[T]{Peers: b[n]} + return NewResolver(b[n]) } func PeersFromStrings(addrs ...string) []Peer { @@ -30,7 +31,10 @@ func PeersFromStrings(addrs ...string) []Peer { type Resolver[T peer.Peer] struct { closer.Monitor - Peers []T + mu sync.Mutex + chs []chan resolver.Update[T] + + peers []T } type Peer string @@ -43,13 +47,94 @@ func NewResolverFromStrings(addrs []string) *Resolver[Peer] { } func NewResolver[T peer.Peer](peers []T) *Resolver[T] { - return &Resolver[T]{Peers: peers} + return &Resolver[T]{peers: peers} +} + +func (r *Resolver[T]) Peers() []T { + r.mu.Lock() + defer r.mu.Unlock() + + out := make([]T, len(r.peers)) + copy(out, r.peers) + + return out +} + +func (r *Resolver[T]) UpdatePeers(peers []T) { + r.mu.Lock() + + old := make(map[string]T, len(r.peers)) + for _, p := range r.peers { + old[p.Addr()] = p + } + + cur := make(map[string]T, len(peers)) + for _, p := range peers { + cur[p.Addr()] = p + } + + var u resolver.Update[T] + + for addr, p := range cur { + if _, ok := old[addr]; !ok { + u.Additions = append(u.Additions, p) + } + } + + for addr, p := range old { + if _, ok 
:= cur[addr]; !ok { + u.Deletions = append(u.Deletions, p) + } + } + + r.peers = peers + + if len(u.Additions) == 0 && len(u.Deletions) == 0 { + r.mu.Unlock() + return + } + + chs := make([]chan resolver.Update[T], len(r.chs)) + copy(chs, r.chs) + r.mu.Unlock() + + for _, ch := range chs { + ch <- u + } +} + +func (r *Resolver[T]) subscribe() (chan resolver.Update[T], []T) { + ch := make(chan resolver.Update[T], 1) + + r.mu.Lock() + r.chs = append(r.chs, ch) + + peers := make([]T, len(r.peers)) + copy(peers, r.peers) + r.mu.Unlock() + + return ch, peers +} + +func (r *Resolver[T]) unsubscribe(ch chan resolver.Update[T]) { + r.mu.Lock() + defer r.mu.Unlock() + + for i, c := range r.chs { + if c == ch { + r.chs = append(r.chs[:i], r.chs[i+1:]...) + return + } + } } func (r *Resolver[T]) String() string { - var addrs = make([]string, len(r.Peers)) + r.mu.Lock() + defer r.mu.Unlock() - for i, peer := range r.Peers { + var addrs = make([]string, len(r.peers)) + + for i, peer := range r.peers { addrs[i] = peer.Addr() } @@ -68,24 +153,37 @@ type watcher[T peer.Peer] struct { closer.Monitor r *Resolver[T] + ch chan resolver.Update[T] initial int32 } func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { ok := atomic.CompareAndSwapInt32(&w.initial, 0, 1) - if opts.NoWait && (!ok || len(w.r.Peers) == 0) { - return resolver.Update[T]{}, resolver.ErrNoUpdates + if ok { + ch, peers := w.r.subscribe() + w.ch = ch + + if len(peers) > 0 { + return resolver.Update[T]{Additions: peers}, nil + } } - if ok && len(w.r.Peers) > 0 { - return resolver.Update[T]{Additions: w.r.Peers}, nil + if opts.NoWait { + select { + case u := <-w.ch: + return u, nil + default: + return resolver.Update[T]{}, resolver.ErrNoUpdates + } } wctx := w.Context() rctx := w.r.Context() select { + case u := <-w.ch: + return u, nil case <-ctx.Done(): return resolver.Update[T]{}, ctx.Err() case <-wctx.Done(): @@ -95,3 +193,11 @@ func (w *watcher[T]) Next(ctx 
context.Context, opts resolver.ResolveOptions) (re return resolver.Update[T]{}, wctx.Err() } } + +func (w *watcher[T]) Close() error { + if w.ch != nil { + w.r.unsubscribe(w.ch) + } + + return w.Monitor.Close() +} diff --git a/discovery/resolver/static/resolver_test.go b/discovery/resolver/static/resolver_test.go index f1c0030..b67ca29 100644 --- a/discovery/resolver/static/resolver_test.go +++ b/discovery/resolver/static/resolver_test.go @@ -41,3 +41,137 @@ func TestResolve(t *testing.T) { assert.Equal(t, err, context.Canceled) assert.Equal(t, resolver.Update[Peer]{}, u) } + +func TestPeers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + peers := r.Peers() + + assert.Equal(t, []Peer{Peer("localhost:1"), Peer("localhost:2")}, peers) + + // Mutating the returned slice should not affect the resolver + peers[0] = Peer("localhost:99") + assert.Equal(t, []Peer{Peer("localhost:1"), Peer("localhost:2")}, r.Peers()) +} + +func TestPeersEmpty(t *testing.T) { + r := NewResolver[Peer](nil) + assert.Equal(t, []Peer{}, r.Peers()) +} + +func TestUpdatePeers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + w := r.Resolve() + + // Consume initial update + u, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + assert.Equal( + t, + resolver.Update[Peer]{ + Additions: []Peer{ + Peer("localhost:1"), + Peer("localhost:2"), + }, + }, + u, + ) + + // Update peers: remove localhost:1, add localhost:3 + r.UpdatePeers(PeersFromStrings("localhost:2", "localhost:3")) + + assert.Equal( + t, + []Peer{Peer("localhost:2"), Peer("localhost:3")}, + r.Peers(), + ) + + // The watcher should receive the diff + u, err = w.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + assert.Nil(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:3")}, u.Additions) + assert.ElementsMatch(t, []Peer{Peer("localhost:1")}, u.Deletions) +} + +func TestUpdatePeersNoChange(t 
*testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Update with same peers — no diff + r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) + + // Should have no update + _, err = w.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + assert.Equal(t, resolver.ErrNoUpdates, err) +} + +func TestUpdatePeersMultipleWatchers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w1 := r.Resolve() + w2 := r.Resolve() + + // Consume initial updates + _, err := w1.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + _, err = w2.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + + r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) + + u1, err := w1.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + assert.Nil(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u1.Additions) + + u2, err := w2.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + assert.Nil(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u2.Additions) +} + +func TestUpdatePeersClosedWatcher(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Close the watcher — it should unsubscribe + err = w.Close() + assert.Nil(t, err) + + // This should not block or panic + r.UpdatePeers(PeersFromStrings("localhost:2")) +} + +func TestUpdatePeersBlockingWatcher(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Start a blocking Next call in a goroutine + done := 
make(chan resolver.Update[Peer], 1) + go func() { + u, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.Nil(t, err) + done <- u + }() + + // Update peers — should unblock the watcher + r.UpdatePeers(PeersFromStrings("localhost:2")) + + u := <-done + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u.Additions) + assert.ElementsMatch(t, []Peer{Peer("localhost:1")}, u.Deletions) +} From 39ed4865f0ea2bebd37001f982632f234123f4ca Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 11:21:04 -0700 Subject: [PATCH 2/8] discovery/resolver/filter: Implement resolver that can filter peers that pass thru --- discovery/resolver/filter/resolver.go | 77 +++++++++++++++++++++ discovery/resolver/filter/resolver_test.go | 80 ++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 discovery/resolver/filter/resolver.go create mode 100644 discovery/resolver/filter/resolver_test.go diff --git a/discovery/resolver/filter/resolver.go b/discovery/resolver/filter/resolver.go new file mode 100644 index 0000000..fcb8197 --- /dev/null +++ b/discovery/resolver/filter/resolver.go @@ -0,0 +1,77 @@ +package filter + +import ( + "context" + + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +type filterResolver[T peer.Peer] struct { + inner resolver.Resolver[T] + allow func(T) bool +} + +func WrapResolver[T peer.Peer](r resolver.Resolver[T], allow func(T) bool) resolver.Resolver[T] { + return &filterResolver[T]{inner: r, allow: allow} +} + +func (r *filterResolver[T]) Open(ctx context.Context) error { return r.inner.Open(ctx) } +func (r *filterResolver[T]) Close() error { return r.inner.Close() } + +func (r *filterResolver[T]) Resolve() resolver.Watcher[T] { + return &watcher[T]{inner: r.inner.Resolve(), allow: r.allow} +} + +type watcher[T peer.Peer] struct { + inner resolver.Watcher[T] + allow func(T) bool + admitted map[string]struct{} +} + +func (w *watcher[T]) Close() error { 
return w.inner.Close() } + +func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { + for { + u, err := w.inner.Next(ctx, opts) + if err != nil { + return resolver.Update[T]{}, err + } + + filtered := w.filter(u) + + if len(filtered.Additions) == 0 && len(filtered.Deletions) == 0 { + if opts.NoWait { + return resolver.Update[T]{}, resolver.ErrNoUpdates + } + + continue + } + + return filtered, nil + } +} + +func (w *watcher[T]) filter(u resolver.Update[T]) resolver.Update[T] { + var out resolver.Update[T] + + for _, p := range u.Additions { + if w.allow(p) { + if w.admitted == nil { + w.admitted = make(map[string]struct{}) + } + + w.admitted[p.Addr()] = struct{}{} + out.Additions = append(out.Additions, p) + } + } + + for _, p := range u.Deletions { + if _, ok := w.admitted[p.Addr()]; ok { + delete(w.admitted, p.Addr()) + out.Deletions = append(out.Deletions, p) + } + } + + return out +} diff --git a/discovery/resolver/filter/resolver_test.go b/discovery/resolver/filter/resolver_test.go new file mode 100644 index 0000000..af0f584 --- /dev/null +++ b/discovery/resolver/filter/resolver_test.go @@ -0,0 +1,80 @@ +package filter + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/static" +) + +func TestFilterResolverAllowsAdditions(t *testing.T) { + ctx := context.Background() + + inner := static.NewResolverFromStrings([]string{"allow:1", "deny:1"}) + r := WrapResolver(inner, func(p static.Peer) bool { + return strings.HasPrefix(p.Addr(), "allow") + }) + + w := r.Resolve() + + u, err := w.Next(ctx, resolver.ResolveOptions{}) + + assert.Nil(t, err) + assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions) + assert.Empty(t, u.Deletions) + + u, err = w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + + assert.Equal(t, resolver.ErrNoUpdates, err) + assert.Equal(t, 
resolver.Update[static.Peer]{}, u) +} + +func TestFilterResolverTracksDeletions(t *testing.T) { + ctx := context.Background() + + inner := static.NewResolverFromStrings([]string{"allow:1", "deny:1"}) + r := WrapResolver(inner, func(p static.Peer) bool { + return strings.HasPrefix(p.Addr(), "allow") + }) + + w := r.Resolve() + + _, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + + inner.UpdatePeers(static.PeersFromStrings("allow:2", "deny:2")) + + u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + + assert.Nil(t, err) + assert.ElementsMatch(t, []static.Peer{static.Peer("allow:2")}, u.Additions) + assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Deletions) +} + +func TestFilterResolverNoWaitFilteredEmpty(t *testing.T) { + ctx := context.Background() + + inner := static.NewResolverFromStrings([]string{"deny:1"}) + r := WrapResolver(inner, func(p static.Peer) bool { + return strings.HasPrefix(p.Addr(), "allow") + }) + + w := r.Resolve() + + u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + + assert.Equal(t, resolver.ErrNoUpdates, err) + assert.Equal(t, resolver.Update[static.Peer]{}, u) + + inner.UpdatePeers(static.PeersFromStrings("allow:1")) + + u, err = w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + + assert.Nil(t, err) + assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Additions) + assert.Empty(t, u.Deletions) +} From fad1ba2b3cccb06cff69b378ef2af58439509ff3 Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 11:21:58 -0700 Subject: [PATCH 3/8] discovery/balancer: Add a sub concept to the balancer package: Policy --- discovery/balancer/policy.go | 70 ++++++++++++ discovery/balancer/policy_test.go | 184 ++++++++++++++++++++++++++++++ discovery/resolver/puller.go | 7 +- discovery/resolver/puller_test.go | 99 ++++++++++++++++ 4 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 discovery/balancer/policy.go create mode 100644 
discovery/balancer/policy_test.go create mode 100644 discovery/resolver/puller_test.go diff --git a/discovery/balancer/policy.go b/discovery/balancer/policy.go new file mode 100644 index 0000000..4bc70d5 --- /dev/null +++ b/discovery/balancer/policy.go @@ -0,0 +1,70 @@ +package balancer + +import ( + "context" + "sync" + + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +type Policy[T peer.Peer] interface { + Get(context.Context, GetOptions) (T, func(error), error) + Update(resolver.Update[T]) +} + +type policyBalancer[S, T peer.Peer] struct { + resolver.Puller[S] + Policy[T] + + mu sync.Mutex + peers map[string]T + builder func(S) (T, error) +} + +func WrapPolicy[S, T peer.Peer](r resolver.Resolver[S], p Policy[T], build func(S) (T, error)) Balancer[T] { + b := &policyBalancer[S, T]{ + Policy: p, + peers: make(map[string]T), + builder: build, + } + + b.Puller = resolver.Puller[S]{ + Resolver: r, + UpdateFunc: b.handleUpdate, + } + + return b +} + +func PolicyBalancerFunc[T peer.Peer](p Policy[T]) func(resolver.Resolver[T]) Balancer[T] { + return func(r resolver.Resolver[T]) Balancer[T] { + return WrapPolicy(r, p, func(t T) (T, error) { return t, nil }) + } +} + +func (b *policyBalancer[S, T]) handleUpdate(u resolver.Update[S]) { + b.mu.Lock() + defer b.mu.Unlock() + + var mapped resolver.Update[T] + + for _, sp := range u.Additions { + tp, err := b.builder(sp) + if err != nil { + continue + } + + b.peers[sp.Addr()] = tp + mapped.Additions = append(mapped.Additions, tp) + } + + for _, sp := range u.Deletions { + if tp, ok := b.peers[sp.Addr()]; ok { + delete(b.peers, sp.Addr()) + mapped.Deletions = append(mapped.Deletions, tp) + } + } + + b.Policy.Update(mapped) +} diff --git a/discovery/balancer/policy_test.go b/discovery/balancer/policy_test.go new file mode 100644 index 0000000..d435c6a --- /dev/null +++ b/discovery/balancer/policy_test.go @@ -0,0 +1,184 @@ +package balancer_test + +import ( + "context" + 
"errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/static" + "github.com/upfluence/pkg/v2/metadata" +) + +type wrappedPeer struct { + addr string +} + +func (p wrappedPeer) Addr() string { return p.addr } +func (p wrappedPeer) Metadata() metadata.Metadata { return nil } + +type testPolicy struct { + mu sync.Mutex + peers []wrappedPeer + updates []resolver.Update[wrappedPeer] +} + +func (p *testPolicy) Get(ctx context.Context, opts balancer.GetOptions) (wrappedPeer, func(error), error) { + p.mu.Lock() + defer p.mu.Unlock() + + if len(p.peers) == 0 { + if opts.NoWait { + return wrappedPeer{}, nil, balancer.ErrNoPeerAvailable + } + <-ctx.Done() + return wrappedPeer{}, nil, ctx.Err() + } + + peer := p.peers[0] + p.peers = p.peers[1:] + return peer, func(error) {}, nil +} + +func (p *testPolicy) Update(u resolver.Update[wrappedPeer]) { + p.mu.Lock() + defer p.mu.Unlock() + + p.updates = append(p.updates, u) + + for _, peer := range u.Additions { + p.peers = append(p.peers, peer) + } +} + +func (p *testPolicy) getUpdates() []resolver.Update[wrappedPeer] { + p.mu.Lock() + defer p.mu.Unlock() + + updates := make([]resolver.Update[wrappedPeer], len(p.updates)) + copy(updates, p.updates) + return updates +} + +func TestWrapPolicyMapsAdditions(t *testing.T) { + ctx := context.Background() + r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + policy := &testPolicy{} + + b := balancer.WrapPolicy( + r, + policy, + func(sp static.Peer) (wrappedPeer, error) { + return wrappedPeer{addr: sp.Addr()}, nil + }, + ) + + assert.Nil(t, b.Open(ctx)) + + time.Sleep(10 * time.Millisecond) + + updates := policy.getUpdates() + assert.Len(t, updates, 1) + assert.ElementsMatch(t, []wrappedPeer{ + {addr: "localhost:1"}, + {addr: "localhost:2"}, + }, updates[0].Additions) + assert.Empty(t, 
updates[0].Deletions) + + peer, done, err := b.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.Equal(t, "localhost:1", peer.Addr()) + done(nil) + + assert.Nil(t, b.Close()) +} + +func TestWrapPolicyMapsDeletions(t *testing.T) { + ctx := context.Background() + r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + policy := &testPolicy{} + + b := balancer.WrapPolicy( + r, + policy, + func(sp static.Peer) (wrappedPeer, error) { + return wrappedPeer{addr: sp.Addr()}, nil + }, + ) + + assert.Nil(t, b.Open(ctx)) + time.Sleep(10 * time.Millisecond) + + r.UpdatePeers(static.PeersFromStrings("localhost:2", "localhost:3")) + time.Sleep(10 * time.Millisecond) + + updates := policy.getUpdates() + assert.Len(t, updates, 2) + + assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:3"}}, updates[1].Additions) + assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:1"}}, updates[1].Deletions) + + assert.Nil(t, b.Close()) +} + +func TestWrapPolicySkipsFailedBuilds(t *testing.T) { + ctx := context.Background() + r := static.NewResolverFromStrings([]string{"localhost:1", "fail:2", "localhost:3"}) + policy := &testPolicy{} + + b := balancer.WrapPolicy( + r, + policy, + func(sp static.Peer) (wrappedPeer, error) { + if sp.Addr() == "fail:2" { + return wrappedPeer{}, errors.New("build failed") + } + return wrappedPeer{addr: sp.Addr()}, nil + }, + ) + + assert.Nil(t, b.Open(ctx)) + time.Sleep(10 * time.Millisecond) + + updates := policy.getUpdates() + assert.Len(t, updates, 1) + assert.ElementsMatch(t, []wrappedPeer{ + {addr: "localhost:1"}, + {addr: "localhost:3"}, + }, updates[0].Additions) + + assert.Nil(t, b.Close()) +} + +func TestWrapPolicyDelegatesGetToPolicy(t *testing.T) { + ctx := context.Background() + r := static.NewResolverFromStrings([]string{"localhost:1"}) + policy := &testPolicy{} + + b := balancer.WrapPolicy( + r, + policy, + func(sp static.Peer) (wrappedPeer, error) { + return wrappedPeer{addr: sp.Addr()}, nil + }, + ) + + 
assert.Nil(t, b.Open(ctx)) + time.Sleep(10 * time.Millisecond) + + peer, done, err := b.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.Equal(t, "localhost:1", peer.Addr()) + done(nil) + + peer, done, err = b.Get(ctx, balancer.GetOptions{NoWait: true}) + assert.Equal(t, balancer.ErrNoPeerAvailable, err) + assert.Nil(t, done) + + assert.Nil(t, b.Close()) +} diff --git a/discovery/resolver/puller.go b/discovery/resolver/puller.go index bbcaa7c..517f2c4 100644 --- a/discovery/resolver/puller.go +++ b/discovery/resolver/puller.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "github.com/upfluence/errors" @@ -20,6 +21,7 @@ type Puller[T peer.Peer] struct { openErr error openOnce sync.Once + opened atomic.Bool } func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func()) { @@ -32,11 +34,12 @@ func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func } func (p *Puller[T]) Close() error { + p.opened.Store(false) return errors.Combine(p.Monitor.Close(), p.Resolver.Close()) } func (p *Puller[T]) IsOpen() bool { - return p.openErr == nil && p.openOnce != sync.Once{} + return p.opened.Load() } func (p *Puller[T]) String() string { @@ -48,6 +51,7 @@ func (p *Puller[T]) Open(ctx context.Context) error { p.openErr = p.Resolver.Open(ctx) if p.openErr == nil { + p.opened.Store(true) p.Monitor.Run(p.pull) } }) @@ -65,6 +69,7 @@ func (p *Puller[T]) pull(ctx context.Context) { ) for { + err = nil w = p.Resolver.Resolve() for err == nil { diff --git a/discovery/resolver/puller_test.go b/discovery/resolver/puller_test.go new file mode 100644 index 0000000..bbb832a --- /dev/null +++ b/discovery/resolver/puller_test.go @@ -0,0 +1,99 @@ +package resolver_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/static" +) + +type 
transientErrorStaticResolver struct { + inner *static.Resolver[static.Peer] + failed atomic.Bool + failWith error +} + +func newTransientErrorStaticResolver(peers []string, err error) *transientErrorStaticResolver { + return &transientErrorStaticResolver{ + inner: static.NewResolverFromStrings(peers), + failWith: err, + } +} + +func (r *transientErrorStaticResolver) Open(ctx context.Context) error { + return r.inner.Open(ctx) +} + +func (r *transientErrorStaticResolver) Close() error { + return r.inner.Close() +} + +func (r *transientErrorStaticResolver) Resolve() resolver.Watcher[static.Peer] { + return &transientErrorWatcher{ + inner: r.inner.Resolve(), + failed: &r.failed, + failWith: r.failWith, + } +} + +func (r *transientErrorStaticResolver) UpdatePeers(peers []static.Peer) { + r.inner.UpdatePeers(peers) +} + +type transientErrorWatcher struct { + inner resolver.Watcher[static.Peer] + failed *atomic.Bool + failWith error +} + +func (w *transientErrorWatcher) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[static.Peer], error) { + if w.failed.CompareAndSwap(false, true) { + return resolver.Update[static.Peer]{}, w.failWith + } + + return w.inner.Next(ctx, opts) +} + +func (w *transientErrorWatcher) Close() error { + return w.inner.Close() +} + +func TestPullerRecoversAfterWatcherError(t *testing.T) { + r := newTransientErrorStaticResolver([]string{"allow:1"}, errors.New("boom")) + + updates := make(chan resolver.Update[static.Peer], 1) + p, _ := resolver.NewPuller(r, func(u resolver.Update[static.Peer]) { + if len(u.Additions) > 0 || len(u.Deletions) > 0 { + updates <- u + } + }) + + assert.Nil(t, p.Open(context.Background())) + + select { + case u := <-updates: + assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions) + case <-time.After(200 * time.Millisecond): + t.Fatal("timed out waiting for update after watcher error") + } + + assert.Nil(t, p.Close()) +} + +func TestPullerIsOpenTracksClose(t *testing.T) { + r := 
static.NewResolverFromStrings([]string{"allow:1"}) + p, _ := resolver.NewPuller(r, func(resolver.Update[static.Peer]) {}) + assert.False(t, p.IsOpen()) + + assert.Nil(t, p.Open(context.Background())) + assert.True(t, p.IsOpen()) + + assert.Nil(t, p.Close()) + assert.False(t, p.IsOpen()) +} From 79aace786ad33e65faf65a82ee52c8815d96c106 Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 11:22:30 -0700 Subject: [PATCH 4/8] discovery/balancer/*: Refactor and improve testing for balancer implementations --- discovery/balancer/balancertest/doc.go | 17 ++ discovery/balancer/balancertest/policy.go | 183 ++++++++++++++++++ discovery/balancer/dialer_test.go | 2 +- discovery/balancer/random/balancer.go | 130 +++++-------- discovery/balancer/random/balancer_test.go | 44 +++++ discovery/balancer/roundrobin/balancer.go | 105 +++++----- .../balancer/roundrobin/balancer_test.go | 61 +++--- 7 files changed, 374 insertions(+), 168 deletions(-) create mode 100644 discovery/balancer/balancertest/doc.go create mode 100644 discovery/balancer/balancertest/policy.go create mode 100644 discovery/balancer/random/balancer_test.go diff --git a/discovery/balancer/balancertest/doc.go b/discovery/balancer/balancertest/doc.go new file mode 100644 index 0000000..37443b8 --- /dev/null +++ b/discovery/balancer/balancertest/doc.go @@ -0,0 +1,17 @@ +// Package balancertest provides testing utilities for balancer.Policy implementations. 
+// +// Usage: +// +// func TestMyPolicy(t *testing.T) { +// balancertest.PolicyTest(t, func() balancer.Policy[static.Peer, static.Peer] { +// return NewPolicy[static.Peer]() +// }) +// } +// +// PolicyTest runs a comprehensive test suite that covers: +// - Empty policy (no peers) - tests NoWait and context cancellation behavior +// - Single peer - verifies Get() returns the same peer consistently +// - Adding and removing peers - tests dynamic peer updates +// - Removing all peers - verifies transition back to empty state +// - Adding peers to empty policy - tests that waiting Get() calls unblock +package balancertest diff --git a/discovery/balancer/balancertest/policy.go b/discovery/balancer/balancertest/policy.go new file mode 100644 index 0000000..df1e98d --- /dev/null +++ b/discovery/balancer/balancertest/policy.go @@ -0,0 +1,183 @@ +package balancertest + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/static" +) + +// PolicyFactory creates a new Policy instance for testing. +type PolicyFactory func() balancer.Policy[static.Peer] + +// PolicyTest runs a comprehensive test suite for a Policy implementation. 
+// It tests common policy behaviors including: +// - Empty policy (no peers) +// - Single peer +// - Multiple peers +// - Adding and removing peers +func PolicyTest(t *testing.T, factory PolicyFactory) { + for _, tt := range []struct { + name string + test func(*testing.T, PolicyFactory) + }{ + {"NoPeers", testNoPeers}, + {"SinglePeer", testSinglePeer}, + {"AddAndRemovePeers", testAddAndRemovePeers}, + {"RemoveAllPeers", testRemoveAllPeers}, + {"AddPeersToEmpty", testAddPeersToEmpty}, + } { + t.Run(tt.name, func(t *testing.T) { + tt.test(t, factory) + }) + } +} + +func testNoPeers(t *testing.T, factory PolicyFactory) { + ctx := context.Background() + policy := factory() + + // NoWait should return ErrNoPeerAvailable immediately + p, done, err := policy.Get(ctx, balancer.GetOptions{NoWait: true}) + assert.Equal(t, balancer.ErrNoPeerAvailable, err) + assert.Nil(t, done) + assert.Empty(t, p.Addr()) + + // Canceled context should return context.Canceled + cctx, cancel := context.WithCancel(ctx) + cancel() + + p, done, err = policy.Get(cctx, balancer.GetOptions{}) + assert.Equal(t, context.Canceled, err) + assert.Nil(t, done) + assert.Empty(t, p.Addr()) +} + +func testSinglePeer(t *testing.T, factory PolicyFactory) { + ctx := context.Background() + policy := factory() + + // Add a single peer + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:1")}, + }) + + // Should get the same peer repeatedly + for i := 0; i < 5; i++ { + p, done, err := policy.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.NotNil(t, done) + assert.Equal(t, "localhost:1", p.Addr()) + done(nil) + } +} + +func testAddAndRemovePeers(t *testing.T, factory PolicyFactory) { + ctx := context.Background() + policy := factory() + + // Add initial peers + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{ + static.Peer("localhost:1"), + static.Peer("localhost:2"), + }, + }) + + // Verify we can get peers + seen := 
make(map[string]bool) + for i := 0; i < 50; i++ { + p, done, err := policy.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.NotNil(t, done) + seen[p.Addr()] = true + done(nil) + } + + assert.Contains(t, seen, "localhost:1") + assert.Contains(t, seen, "localhost:2") + + // Update peers: remove localhost:1, add localhost:3 + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:3")}, + Deletions: []static.Peer{static.Peer("localhost:1")}, + }) + + // Verify we only see localhost:2 and localhost:3 + seen = make(map[string]bool) + for i := 0; i < 50; i++ { + p, done, err := policy.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.NotNil(t, done) + seen[p.Addr()] = true + done(nil) + } + + assert.Contains(t, seen, "localhost:2") + assert.Contains(t, seen, "localhost:3") + assert.NotContains(t, seen, "localhost:1") +} + +func testRemoveAllPeers(t *testing.T, factory PolicyFactory) { + ctx := context.Background() + policy := factory() + + // Add a peer + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:1")}, + }) + + // Verify we can get it + p, done, err := policy.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.Equal(t, "localhost:1", p.Addr()) + done(nil) + + // Remove the peer + policy.Update(resolver.Update[static.Peer]{ + Deletions: []static.Peer{static.Peer("localhost:1")}, + }) + + // NoWait should return ErrNoPeerAvailable + p, done, err = policy.Get(ctx, balancer.GetOptions{NoWait: true}) + assert.Equal(t, balancer.ErrNoPeerAvailable, err) + assert.Nil(t, done) + assert.Empty(t, p.Addr()) +} + +func testAddPeersToEmpty(t *testing.T, factory PolicyFactory) { + ctx := context.Background() + policy := factory() + + // Start with no peers, spawn a goroutine that will wait + done := make(chan struct{}) + go func() { + p, doneFn, err := policy.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.NotNil(t, doneFn) + assert.NotEmpty(t, 
p.Addr()) + doneFn(nil) + close(done) + }() + + // Give the goroutine time to start waiting + time.Sleep(10 * time.Millisecond) + + // Add a peer + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:1")}, + }) + + // The waiting goroutine should complete + select { + case <-done: + // Success + case <-time.After(100 * time.Millisecond): + t.Fatal("Get() did not unblock after adding peers") + } +} diff --git a/discovery/balancer/dialer_test.go b/discovery/balancer/dialer_test.go index 98f9eca..a089d6c 100644 --- a/discovery/balancer/dialer_test.go +++ b/discovery/balancer/dialer_test.go @@ -33,7 +33,7 @@ func TestDialer(t *testing.T) { d := balancer.Dialer[static.Peer]{ Builder: balancer.ResolverBuilder[static.Peer]{ Builder: r, - BalancerFunc: roundrobin.BalancerFunc[static.Peer], + BalancerFunc: balancer.PolicyBalancerFunc(roundrobin.NewPolicy[static.Peer]()), }, } diff --git a/discovery/balancer/random/balancer.go b/discovery/balancer/random/balancer.go index 4fa383a..37eedb1 100644 --- a/discovery/balancer/random/balancer.go +++ b/discovery/balancer/random/balancer.go @@ -2,8 +2,8 @@ package random import ( "context" - "fmt" "math/rand" + "slices" "sync" "time" @@ -16,118 +16,82 @@ type Rand interface { Intn(int) int } -type Balancer[T peer.Peer] struct { - *resolver.Puller[T] - - peers []T - peersMu *sync.RWMutex - rand Rand - - notifier chan interface{} - closeFn func() +type Policy[T peer.Peer] struct { + mu sync.RWMutex + peers []T + rand Rand + notifier chan struct{} } -func NewBalancer[T peer.Peer](r resolver.Resolver[T]) *Balancer[T] { - var b = &Balancer[T]{ +func NewPolicy[T peer.Peer]() *Policy[T] { + return &Policy[T]{ rand: rand.New(rand.NewSource(time.Now().UnixNano())), - peersMu: &sync.RWMutex{}, - notifier: make(chan interface{}), + notifier: make(chan struct{}), } - - b.Puller, b.closeFn = resolver.NewPuller(r, b.updatePeers) - - return b -} - -func (b *Balancer[T]) String() string { - return 
fmt.Sprintf("loadbalancer/random [resolver: %v]", b.Puller) } -func (b *Balancer[T]) updatePeers(u resolver.Update[T]) { - b.peersMu.Lock() - defer b.peersMu.Unlock() +func (p *Policy[T]) Update(u resolver.Update[T]) { + p.mu.Lock() + defer p.mu.Unlock() - var newPeers = make(map[T]interface{}) + wasEmpty := len(p.peers) == 0 - for _, p := range b.peers { - var found bool - - for _, peer := range u.Deletions { - if p.Addr() == peer.Addr() { - found = true - } - } - - if !found { - newPeers[p] = nil - } + peerMap := make(map[string]T) + for _, peer := range p.peers { + peerMap[peer.Addr()] = peer } - for _, p := range u.Additions { - var found bool - - for _, peer := range b.peers { - if p.Addr() == peer.Addr() { - found = true - } - } - - if !found { - newPeers[p] = nil - } + for _, peer := range u.Deletions { + delete(peerMap, peer.Addr()) } - var ( - i = 0 - empty = len(b.peers) == 0 - ) - - b.peers = make([]T, len(newPeers)) - - for p, _ := range newPeers { - b.peers[i] = p - i++ + for _, peer := range u.Additions { + peerMap[peer.Addr()] = peer } - if empty && (len(b.peers) > 0) { - for { - select { - case <-b.notifier: - default: - return - } - } + p.peers = make([]T, 0, len(peerMap)) + for _, peer := range peerMap { + p.peers = append(p.peers, peer) } -} -func (b *Balancer[T]) hasPeers() bool { - b.peersMu.RLock() - defer b.peersMu.RUnlock() - - return len(b.peers) > 0 + if wasEmpty && len(p.peers) > 0 { + close(p.notifier) + p.notifier = make(chan struct{}) + } } -func (b *Balancer[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { +func (p *Policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { var zero T - if !b.hasPeers() { + p.mu.RLock() + hasPeers := len(p.peers) > 0 + notifier := p.notifier + peers := slices.Clone(p.peers) + p.mu.RUnlock() + + if !hasPeers { if opts.NoWait { return zero, nil, balancer.ErrNoPeerAvailable } select { - case b.notifier <- true: + case <-notifier: case 
<-ctx.Done(): return zero, nil, ctx.Err() } + + p.mu.RLock() + peers = slices.Clone(p.peers) + p.mu.RUnlock() + } + + if len(peers) == 0 { + return zero, nil, balancer.ErrNoPeerAvailable } - b.peersMu.RLock() - defer b.peersMu.RUnlock() - return b.peers[b.rand.Intn(len(b.peers))], func(error) {}, nil + return peers[p.rand.Intn(len(peers))], func(error) {}, nil } -func (b *Balancer[T]) Close() error { - b.closeFn() - return nil +func NewBalancer[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] { + return balancer.WrapPolicy(r, NewPolicy[T](), func(p T) (T, error) { return p, nil }) } diff --git a/discovery/balancer/random/balancer_test.go b/discovery/balancer/random/balancer_test.go new file mode 100644 index 0000000..b953b87 --- /dev/null +++ b/discovery/balancer/random/balancer_test.go @@ -0,0 +1,44 @@ +package random + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" + "github.com/upfluence/pkg/v2/discovery/resolver/static" +) + +func TestPolicy(t *testing.T) { + balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] { + return NewPolicy[static.Peer]() + }) +} + +func TestBalancerWithPeers(t *testing.T) { + ctx := context.Background() + b := NewBalancer( + static.NewResolverFromStrings([]string{"localhost:0", "localhost:1", "localhost:2"}), + ) + + err := b.Open(ctx) + assert.Nil(t, err) + + seen := make(map[string]int) + for i := 0; i < 100; i++ { + p, done, err := b.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + assert.NotEmpty(t, p.Addr()) + seen[p.Addr()]++ + done(nil) + } + + // With 100 requests across 3 peers, all should be selected at least once + assert.Contains(t, seen, "localhost:0") + assert.Contains(t, seen, "localhost:1") + assert.Contains(t, seen, "localhost:2") + + b.Close() +} diff --git a/discovery/balancer/roundrobin/balancer.go b/discovery/balancer/roundrobin/balancer.go index 
23dbe11..1cd3dd7 100644 --- a/discovery/balancer/roundrobin/balancer.go +++ b/discovery/balancer/roundrobin/balancer.go @@ -3,7 +3,6 @@ package roundrobin import ( "container/ring" "context" - "fmt" "sync" "github.com/upfluence/pkg/v2/discovery/balancer" @@ -11,112 +10,106 @@ import ( "github.com/upfluence/pkg/v2/discovery/resolver" ) -type Balancer[T peer.Peer] struct { - resolver.Puller[T] - - addrs map[string]*ring.Ring - ring *ring.Ring - ringMu sync.RWMutex - +type Policy[T peer.Peer] struct { + mu sync.Mutex + addrs map[string]*ring.Ring + ring *ring.Ring notifier chan struct{} } -func BalancerFunc[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] { - return NewBalancer[T](r) -} - -func NewBalancer[T peer.Peer](r resolver.Resolver[T]) *Balancer[T] { - var b = Balancer[T]{ +func NewPolicy[T peer.Peer]() *Policy[T] { + return &Policy[T]{ addrs: make(map[string]*ring.Ring), notifier: make(chan struct{}), } - - b.Puller = resolver.Puller[T]{Resolver: r, UpdateFunc: b.updateRing} - - return &b } -func (b *Balancer[T]) String() string { - return fmt.Sprintf("loadbalancer/roundrobin [resolver: %v]", &b.Puller) -} - -func (b *Balancer[T]) updateRing(update resolver.Update[T]) { - b.ringMu.Lock() - defer b.ringMu.Unlock() +func (p *Policy[T]) Update(u resolver.Update[T]) { + p.mu.Lock() + defer p.mu.Unlock() - wasEmpty := b.ring == nil + wasEmpty := p.ring == nil - for _, p := range update.Additions { - r := &ring.Ring{Value: p} - b.addrs[p.Addr()] = r + for _, peer := range u.Additions { + r := &ring.Ring{Value: peer} + p.addrs[peer.Addr()] = r - if b.ring == nil { - b.ring = r + if p.ring == nil { + p.ring = r continue } - b.ring.Link(r) + p.ring.Link(r) } - for _, p := range update.Deletions { - addr := p.Addr() - r, ok := b.addrs[addr] + for _, peer := range u.Deletions { + addr := peer.Addr() + r, ok := p.addrs[addr] if !ok { continue } - delete(b.addrs, addr) + delete(p.addrs, addr) - if p := r.Prev(); p != nil { - b.ring = p.Unlink(1) + // If this is 
the only element in the ring, set ring to nil + if r.Len() == 1 { + p.ring = nil continue } - b.ring = nil + // Otherwise, unlink this element from the ring + // Unlink returns the removed subring, so we keep prev as the new ring position + prev := r.Prev() + prev.Unlink(1) // Remove r from the ring + p.ring = prev } - isEmpty := b.ring == nil + isEmpty := p.ring == nil if wasEmpty && !isEmpty { - close(b.notifier) + close(p.notifier) + p.notifier = make(chan struct{}) } else if !wasEmpty && isEmpty { - b.notifier = make(chan struct{}) + p.notifier = make(chan struct{}) } } -func (b *Balancer[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { +func (p *Policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { var zero T - b.ringMu.RLock() - r := b.ring - n := b.notifier - b.ringMu.RUnlock() + p.mu.Lock() + r := p.ring + notifier := p.notifier + p.mu.Unlock() if r == nil { if opts.NoWait { return zero, nil, balancer.ErrNoPeerAvailable } - pctx := b.Puller.Monitor.Context() - select { - case <-n: + case <-notifier: case <-ctx.Done(): return zero, nil, ctx.Err() - case <-pctx.Done(): - return zero, nil, pctx.Err() } } - b.ringMu.Lock() - defer b.ringMu.Unlock() + p.mu.Lock() + defer p.mu.Unlock() - if v := b.ring.Value; v != nil { - b.ring = b.ring.Next() + if p.ring == nil { + return zero, nil, balancer.ErrNoPeerAvailable + } + if v := p.ring.Value; v != nil { + p.ring = p.ring.Next() return v.(T), func(error) {}, nil } return zero, nil, balancer.ErrNoPeerAvailable } + +func NewBalancer[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] { + return balancer.PolicyBalancerFunc(NewPolicy[T]())(r) +} diff --git a/discovery/balancer/roundrobin/balancer_test.go b/discovery/balancer/roundrobin/balancer_test.go index 5010084..5739558 100644 --- a/discovery/balancer/roundrobin/balancer_test.go +++ b/discovery/balancer/roundrobin/balancer_test.go @@ -3,43 +3,22 @@ package roundrobin import ( "context" "testing" 
+ "time" "github.com/stretchr/testify/assert" "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" "github.com/upfluence/pkg/v2/discovery/resolver/static" ) -func TestBalanceEmpty(t *testing.T) { - ctx := context.Background() - b := NewBalancer(&static.Resolver[static.Peer]{}) - - p, done, err := b.Get(ctx, balancer.GetOptions{NoWait: true}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, balancer.ErrNoPeerAvailable, err) - - cctx, cancel := context.WithCancel(ctx) - cancel() - - p, done, err = b.Get(cctx, balancer.GetOptions{}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, err, context.Canceled) - - err = b.Close() - assert.NoError(t, err) - - p, done, err = b.Get(ctx, balancer.GetOptions{}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, err, context.Canceled) +func TestPolicy(t *testing.T) { + balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] { + return NewPolicy[static.Peer]() + }) } -func TestBalanceWithPerrs(t *testing.T) { +func TestBalanceWithPeers(t *testing.T) { ctx := context.Background() b := NewBalancer( static.NewResolverFromStrings([]string{"localhost:0", "localhost:1"}), @@ -48,6 +27,8 @@ func TestBalanceWithPerrs(t *testing.T) { err := b.Open(ctx) assert.Nil(t, err) + time.Sleep(10 * time.Millisecond) + p, done, err := b.Get(ctx, balancer.GetOptions{}) done(nil) @@ -68,3 +49,27 @@ func TestBalanceWithPerrs(t *testing.T) { b.Close() } + +func TestBalanceRoundRobinOrder(t *testing.T) { + ctx := context.Background() + b := NewBalancer( + static.NewResolverFromStrings([]string{"localhost:0", "localhost:1", "localhost:2"}), + ) + + err := b.Open(ctx) + assert.Nil(t, err) + + time.Sleep(10 * time.Millisecond) + + // Verify round-robin cycles through all peers in order + for cycle := 0; cycle < 3; cycle++ { + for i := 0; i < 3; i++ { + p, done, err := b.Get(ctx, balancer.GetOptions{}) + assert.Nil(t, err) + 
assert.Contains(t, []string{"localhost:0", "localhost:1", "localhost:2"}, p.Addr()) + done(nil) + } + } + + b.Close() +} From 0c2d55f7e6f294ea30063aff0c556be54a39e774 Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 11:52:48 -0700 Subject: [PATCH 5/8] discovery/resolver/transform: Implement a middleware that switch from one type to another easily --- discovery/resolver/filter/resolver_test.go | 20 +++ discovery/resolver/resolvertest/doc.go | 18 +++ discovery/resolver/resolvertest/resolver.go | 153 ++++++++++++++++++ discovery/resolver/static/resolver_test.go | 7 + discovery/resolver/transform/doc.go | 20 +++ discovery/resolver/transform/resolver.go | 68 ++++++++ discovery/resolver/transform/resolver_test.go | 136 ++++++++++++++++ 7 files changed, 422 insertions(+) create mode 100644 discovery/resolver/resolvertest/doc.go create mode 100644 discovery/resolver/resolvertest/resolver.go create mode 100644 discovery/resolver/transform/doc.go create mode 100644 discovery/resolver/transform/resolver.go create mode 100644 discovery/resolver/transform/resolver_test.go diff --git a/discovery/resolver/filter/resolver_test.go b/discovery/resolver/filter/resolver_test.go index af0f584..6322adc 100644 --- a/discovery/resolver/filter/resolver_test.go +++ b/discovery/resolver/filter/resolver_test.go @@ -8,9 +8,29 @@ import ( "github.com/stretchr/testify/assert" "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" "github.com/upfluence/pkg/v2/discovery/resolver/static" ) +func TestResolver(t *testing.T) { + resolvertest.ResolverTest(t, func(peers []static.Peer) (resolver.Resolver[static.Peer], []static.Peer) { + inner := static.NewResolver(peers) + r := WrapResolver(inner, func(p static.Peer) bool { + return strings.HasPrefix(p.Addr(), "localhost") + }) + + // Filter the expected peers + var expected []static.Peer + for _, p := range peers { + if strings.HasPrefix(p.Addr(), "localhost") { + expected = 
append(expected, p) + } + } + + return r, expected + }, static.PeersFromStrings) +} + func TestFilterResolverAllowsAdditions(t *testing.T) { ctx := context.Background() diff --git a/discovery/resolver/resolvertest/doc.go b/discovery/resolver/resolvertest/doc.go new file mode 100644 index 0000000..f4334e7 --- /dev/null +++ b/discovery/resolver/resolvertest/doc.go @@ -0,0 +1,18 @@ +// Package resolvertest provides reusable test helpers for resolver implementations. +// +// The ResolverTest function runs a comprehensive test suite that validates common +// resolver behaviors including: +// - Empty resolver (no initial peers) +// - Initial peers resolution +// - NoWait behavior with no updates +// - Context cancellation +// - Watcher closure +// +// Example usage: +// +// func TestMyResolver(t *testing.T) { +// resolvertest.ResolverTest(t, func() resolver.Resolver[static.Peer] { +// return myresolver.New() +// }) +// } +package resolvertest diff --git a/discovery/resolver/resolvertest/resolver.go b/discovery/resolver/resolvertest/resolver.go new file mode 100644 index 0000000..ecc1500 --- /dev/null +++ b/discovery/resolver/resolvertest/resolver.go @@ -0,0 +1,153 @@ +package resolvertest + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +// ResolverFactory creates a new Resolver instance for testing. +// The argument is the peers that should pre-exist. +// Returns the resolver and the peers that should actually be returned (after filtering/transformation). +type ResolverFactory[T peer.Peer] func([]T) (resolver.Resolver[T], []T) + +// ResolverTest runs a comprehensive test suite for a Resolver implementation. 
+// It tests common resolver behaviors including: +// - Empty resolver (no initial peers) +// - Initial peers resolution +// - NoWait behavior with no updates +// - Context cancellation +// - Watcher closure +func ResolverTest[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + for _, tt := range []struct { + name string + test func(*testing.T, ResolverFactory[T], func(...string) []T) + }{ + {"NoSeeds", testNoSeeds[T]}, + {"InitialPeers", testInitialPeers[T]}, + {"NoWaitNoUpdates", testNoWaitNoUpdates[T]}, + {"ContextCancellation", testContextCancellation[T]}, + {"WatcherClose", testWatcherClose[T]}, + } { + t.Run(tt.name, func(t *testing.T) { + tt.test(t, factory, makePeers) + }) + } +} + +func testNoSeeds[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + ctx := context.Background() + r, expected := factory(nil) + + assert.Empty(t, expected) + + assert.Nil(t, r.Open(ctx)) + defer r.Close() + + w := r.Resolve() + defer w.Close() + + // NoWait should return ErrNoUpdates immediately when no peers exist + u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + assert.Equal(t, resolver.ErrNoUpdates, err) + assert.Empty(t, u.Additions) + assert.Empty(t, u.Deletions) +} + +func testInitialPeers[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + ctx := context.Background() + peers := makePeers("localhost:1", "localhost:2") + r, expected := factory(peers) + + assert.Nil(t, r.Open(ctx)) + defer r.Close() + + w := r.Resolve() + defer w.Close() + + // Should get initial peers + u, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + assert.ElementsMatch(t, expected, u.Additions) + assert.Empty(t, u.Deletions) +} + +func testNoWaitNoUpdates[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + ctx := context.Background() + peers := makePeers("localhost:1") + r, _ := factory(peers) + + assert.Nil(t, 
r.Open(ctx)) + defer r.Close() + + w := r.Resolve() + defer w.Close() + + // Consume initial update + _, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + + // NoWait should return ErrNoUpdates when no updates are available + u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) + assert.Equal(t, resolver.ErrNoUpdates, err) + assert.Empty(t, u.Additions) + assert.Empty(t, u.Deletions) +} + +func testContextCancellation[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + ctx := context.Background() + peers := makePeers("localhost:1") + r, _ := factory(peers) + + assert.Nil(t, r.Open(ctx)) + defer r.Close() + + w := r.Resolve() + defer w.Close() + + // Consume initial update + _, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Cancel context and try to wait for updates + cctx, cancel := context.WithCancel(ctx) + cancel() + + u, err := w.Next(cctx, resolver.ResolveOptions{}) + assert.Equal(t, context.Canceled, err) + assert.Empty(t, u.Additions) + assert.Empty(t, u.Deletions) +} + +func testWatcherClose[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) { + ctx := context.Background() + peers := makePeers("localhost:1") + r, _ := factory(peers) + + assert.Nil(t, r.Open(ctx)) + defer r.Close() + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Close the watcher + err = w.Close() + assert.Nil(t, err) + + // Trying to get next update should fail after close + // Give a small timeout to avoid blocking forever + cctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + + _, err = w.Next(cctx, resolver.ResolveOptions{}) + assert.NotNil(t, err) + // Should be either context.Canceled, context.DeadlineExceeded, or a close error +} diff --git a/discovery/resolver/static/resolver_test.go b/discovery/resolver/static/resolver_test.go index b67ca29..41783d3 
100644 --- a/discovery/resolver/static/resolver_test.go +++ b/discovery/resolver/static/resolver_test.go @@ -6,8 +6,15 @@ import ( "github.com/stretchr/testify/assert" "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" ) +func TestResolver(t *testing.T) { + resolvertest.ResolverTest(t, func(peers []Peer) (resolver.Resolver[Peer], []Peer) { + return NewResolver(peers), peers + }, PeersFromStrings) +} + func TestResolve(t *testing.T) { ctx := context.Background() r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) diff --git a/discovery/resolver/transform/doc.go b/discovery/resolver/transform/doc.go new file mode 100644 index 0000000..28d053d --- /dev/null +++ b/discovery/resolver/transform/doc.go @@ -0,0 +1,20 @@ +// Package transform provides a resolver wrapper that transforms peers from one type to another. +// +// The transform resolver wraps an existing resolver and applies a transformation function +// to convert source peers (S) into target peers (T). 
This is useful for: +// - Adding metadata or wrapping peers with additional information +// - Converting between different peer implementations +// - Applying prefixes or transformations to peer addresses +// +// Example: +// +// source := static.NewResolverFromStrings([]string{"host1:80", "host2:80"}) +// +// // Wrap peers with a prefix +// tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer { +// return wrappedPeer{addr: p.Addr(), prefix: "service-"} +// }) +// +// // The transformed resolver will emit wrappedPeer instances +// // with addresses like "service-host1:80", "service-host2:80" +package transform diff --git a/discovery/resolver/transform/resolver.go b/discovery/resolver/transform/resolver.go new file mode 100644 index 0000000..a9e6b70 --- /dev/null +++ b/discovery/resolver/transform/resolver.go @@ -0,0 +1,68 @@ +package transform + +import ( + "context" + + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +type transformResolver[S, T peer.Peer] struct { + source resolver.Resolver[S] + transform func(S) T +} + +func WrapResolver[S, T peer.Peer](r resolver.Resolver[S], fn func(S) T) resolver.Resolver[T] { + return &transformResolver[S, T]{ + source: r, + transform: fn, + } +} + +func (r *transformResolver[S, T]) Open(ctx context.Context) error { + return r.source.Open(ctx) +} + +func (r *transformResolver[S, T]) Close() error { + return r.source.Close() +} + +func (r *transformResolver[S, T]) Resolve() resolver.Watcher[T] { + return &watcher[S, T]{ + inner: r.source.Resolve(), + transform: r.transform, + } +} + +type watcher[S, T peer.Peer] struct { + inner resolver.Watcher[S] + transform func(S) T +} + +func (w *watcher[S, T]) Close() error { + return w.inner.Close() +} + +func (w *watcher[S, T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { + u, err := w.inner.Next(ctx, opts) + + if err != nil { + return resolver.Update[T]{}, err + } + + var 
result resolver.Update[T] + + result.Additions = make([]T, len(u.Additions)) + + for i, p := range u.Additions { + result.Additions[i] = w.transform(p) + } + + result.Deletions = make([]T, len(u.Deletions)) + + for i, p := range u.Deletions { + result.Deletions[i] = w.transform(p) + } + + return result, nil +} diff --git a/discovery/resolver/transform/resolver_test.go b/discovery/resolver/transform/resolver_test.go new file mode 100644 index 0000000..3acd7a0 --- /dev/null +++ b/discovery/resolver/transform/resolver_test.go @@ -0,0 +1,136 @@ +package transform_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" + "github.com/upfluence/pkg/v2/discovery/resolver/static" + "github.com/upfluence/pkg/v2/discovery/resolver/transform" + "github.com/upfluence/pkg/v2/metadata" +) + +type wrappedPeer struct { + addr string + prefix string +} + +func (p wrappedPeer) Addr() string { return p.prefix + p.addr } +func (p wrappedPeer) Metadata() metadata.Metadata { return nil } + +func makeWrappedPeers(addrs ...string) []wrappedPeer { + peers := make([]wrappedPeer, len(addrs)) + for i, addr := range addrs { + peers[i] = wrappedPeer{addr: addr, prefix: "prefix-"} + } + return peers +} + +func TestResolver(t *testing.T) { + resolvertest.ResolverTest(t, func(peers []wrappedPeer) (resolver.Resolver[wrappedPeer], []wrappedPeer) { + // Extract source peers from wrapped peers + sourcePeers := make([]static.Peer, len(peers)) + for i, p := range peers { + // Remove the prefix to get back the original address + sourcePeers[i] = static.Peer(p.addr) + } + + source := static.NewResolver(sourcePeers) + r := transform.WrapResolver(source, func(p static.Peer) wrappedPeer { + return wrappedPeer{addr: p.Addr(), prefix: "prefix-"} + }) + + return r, peers + }, makeWrappedPeers) +} + +func TestTransformResolverTransformsUpdates(t *testing.T) { + 
ctx := context.Background() + source := static.NewResolverFromStrings([]string{"host1:80"}) + + tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer { + return wrappedPeer{addr: p.Addr(), prefix: "transformed-"} + }) + + assert.Nil(t, tr.Open(ctx)) + defer tr.Close() + + w := tr.Resolve() + defer w.Close() + + // Get initial peers + u, err := w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + assert.Len(t, u.Additions, 1) + assert.Equal(t, "transformed-host1:80", u.Additions[0].Addr()) + + // Update peers + source.UpdatePeers(static.PeersFromStrings("host2:80", "host3:80")) + + u, err = w.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + + // Should have additions and deletions + assert.Len(t, u.Additions, 2) + assert.Len(t, u.Deletions, 1) + + additionAddrs := []string{u.Additions[0].Addr(), u.Additions[1].Addr()} + assert.Contains(t, additionAddrs, "transformed-host2:80") + assert.Contains(t, additionAddrs, "transformed-host3:80") + + assert.Equal(t, "transformed-host1:80", u.Deletions[0].Addr()) +} + +func TestTransformResolverMultipleWatchers(t *testing.T) { + ctx := context.Background() + source := static.NewResolverFromStrings([]string{"host1:80"}) + + tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer { + return wrappedPeer{addr: p.Addr(), prefix: "watcher-"} + }) + + assert.Nil(t, tr.Open(ctx)) + defer tr.Close() + + w1 := tr.Resolve() + defer w1.Close() + + w2 := tr.Resolve() + defer w2.Close() + + // Both watchers should get initial peers + u1, err := w1.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + assert.Len(t, u1.Additions, 1) + assert.Equal(t, "watcher-host1:80", u1.Additions[0].Addr()) + + u2, err := w2.Next(ctx, resolver.ResolveOptions{}) + assert.Nil(t, err) + assert.Len(t, u2.Additions, 1) + assert.Equal(t, "watcher-host1:80", u2.Additions[0].Addr()) + + // Update peers + source.UpdatePeers(static.PeersFromStrings("host2:80")) + + // Give time for update to propagate + 
time.Sleep(10 * time.Millisecond) + + // Both watchers should receive the update + u1, err = w1.Next(ctx, resolver.ResolveOptions{NoWait: true}) + assert.Nil(t, err) + assert.Len(t, u1.Additions, 1) + assert.Equal(t, "watcher-host2:80", u1.Additions[0].Addr()) + assert.Len(t, u1.Deletions, 1) + assert.Equal(t, "watcher-host1:80", u1.Deletions[0].Addr()) + + u2, err = w2.Next(ctx, resolver.ResolveOptions{NoWait: true}) + assert.Nil(t, err) + assert.Len(t, u2.Additions, 1) + assert.Equal(t, "watcher-host2:80", u2.Additions[0].Addr()) + assert.Len(t, u2.Deletions, 1) + assert.Equal(t, "watcher-host1:80", u2.Deletions[0].Addr()) +} From 99a52edbec5b8ec24c8134d19afb7636c3ae382d Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 12:07:18 -0700 Subject: [PATCH 6/8] discovery/balancer/simple: Have a simple balancer that allow the usage of simple picker interface --- discovery/balancer/random/balancer.go | 81 ++------------ discovery/balancer/roundrobin/balancer.go | 105 ++---------------- .../balancer/roundrobin/balancer_test.go | 32 ------ discovery/balancer/simple/doc.go | 15 +++ discovery/balancer/simple/policy.go | 97 ++++++++++++++++ discovery/balancer/simple/policy_test.go | 103 +++++++++++++++++ 6 files changed, 233 insertions(+), 200 deletions(-) create mode 100644 discovery/balancer/simple/doc.go create mode 100644 discovery/balancer/simple/policy.go create mode 100644 discovery/balancer/simple/policy_test.go diff --git a/discovery/balancer/random/balancer.go b/discovery/balancer/random/balancer.go index 37eedb1..caed6d6 100644 --- a/discovery/balancer/random/balancer.go +++ b/discovery/balancer/random/balancer.go @@ -3,11 +3,10 @@ package random import ( "context" "math/rand" - "slices" - "sync" "time" "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/simple" "github.com/upfluence/pkg/v2/discovery/peer" "github.com/upfluence/pkg/v2/discovery/resolver" ) @@ -16,80 +15,18 @@ type Rand interface { 
Intn(int) int } -type Policy[T peer.Peer] struct { - mu sync.RWMutex - peers []T - rand Rand - notifier chan struct{} +type picker[T peer.Peer] struct { + rand Rand } -func NewPolicy[T peer.Peer]() *Policy[T] { - return &Policy[T]{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - notifier: make(chan struct{}), - } +func (p *picker[T]) Pick(ctx context.Context, peers []T) (T, error) { + return peers[p.rand.Intn(len(peers))], nil } -func (p *Policy[T]) Update(u resolver.Update[T]) { - p.mu.Lock() - defer p.mu.Unlock() - - wasEmpty := len(p.peers) == 0 - - peerMap := make(map[string]T) - for _, peer := range p.peers { - peerMap[peer.Addr()] = peer - } - - for _, peer := range u.Deletions { - delete(peerMap, peer.Addr()) - } - - for _, peer := range u.Additions { - peerMap[peer.Addr()] = peer - } - - p.peers = make([]T, 0, len(peerMap)) - for _, peer := range peerMap { - p.peers = append(p.peers, peer) - } - - if wasEmpty && len(p.peers) > 0 { - close(p.notifier) - p.notifier = make(chan struct{}) - } -} - -func (p *Policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { - var zero T - - p.mu.RLock() - hasPeers := len(p.peers) > 0 - notifier := p.notifier - peers := slices.Clone(p.peers) - p.mu.RUnlock() - - if !hasPeers { - if opts.NoWait { - return zero, nil, balancer.ErrNoPeerAvailable - } - - select { - case <-notifier: - case <-ctx.Done(): - return zero, nil, ctx.Err() - } - - p.mu.RLock() - peers = slices.Clone(p.peers) - p.mu.RUnlock() - } - - if len(peers) == 0 { - return zero, nil, balancer.ErrNoPeerAvailable - } - - return peers[p.rand.Intn(len(peers))], func(error) {}, nil +func NewPolicy[T peer.Peer]() balancer.Policy[T] { + return simple.NewPolicy(&picker[T]{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + }) } func NewBalancer[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] { diff --git a/discovery/balancer/roundrobin/balancer.go b/discovery/balancer/roundrobin/balancer.go index 
1cd3dd7..01c0083 100644 --- a/discovery/balancer/roundrobin/balancer.go +++ b/discovery/balancer/roundrobin/balancer.go @@ -1,113 +1,26 @@ package roundrobin import ( - "container/ring" "context" - "sync" + "sync/atomic" "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/simple" "github.com/upfluence/pkg/v2/discovery/peer" "github.com/upfluence/pkg/v2/discovery/resolver" ) -type Policy[T peer.Peer] struct { - mu sync.Mutex - addrs map[string]*ring.Ring - ring *ring.Ring - notifier chan struct{} +type picker[T peer.Peer] struct { + index atomic.Uint64 } -func NewPolicy[T peer.Peer]() *Policy[T] { - return &Policy[T]{ - addrs: make(map[string]*ring.Ring), - notifier: make(chan struct{}), - } +func (p *picker[T]) Pick(ctx context.Context, peers []T) (T, error) { + idx := p.index.Add(1) - 1 + return peers[idx%uint64(len(peers))], nil } -func (p *Policy[T]) Update(u resolver.Update[T]) { - p.mu.Lock() - defer p.mu.Unlock() - - wasEmpty := p.ring == nil - - for _, peer := range u.Additions { - r := &ring.Ring{Value: peer} - p.addrs[peer.Addr()] = r - - if p.ring == nil { - p.ring = r - continue - } - - p.ring.Link(r) - } - - for _, peer := range u.Deletions { - addr := peer.Addr() - r, ok := p.addrs[addr] - - if !ok { - continue - } - - delete(p.addrs, addr) - - // If this is the only element in the ring, set ring to nil - if r.Len() == 1 { - p.ring = nil - continue - } - - // Otherwise, unlink this element from the ring - // Unlink returns the removed subring, so we keep prev as the new ring position - prev := r.Prev() - prev.Unlink(1) // Remove r from the ring - p.ring = prev - } - - isEmpty := p.ring == nil - - if wasEmpty && !isEmpty { - close(p.notifier) - p.notifier = make(chan struct{}) - } else if !wasEmpty && isEmpty { - p.notifier = make(chan struct{}) - } -} - -func (p *Policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { - var zero T - - p.mu.Lock() - r := p.ring - notifier := 
p.notifier - p.mu.Unlock() - - if r == nil { - if opts.NoWait { - return zero, nil, balancer.ErrNoPeerAvailable - } - - select { - case <-notifier: - case <-ctx.Done(): - return zero, nil, ctx.Err() - } - } - - p.mu.Lock() - defer p.mu.Unlock() - - if p.ring == nil { - return zero, nil, balancer.ErrNoPeerAvailable - } - - if v := p.ring.Value; v != nil { - p.ring = p.ring.Next() - return v.(T), func(error) {}, nil - } - - return zero, nil, balancer.ErrNoPeerAvailable +func NewPolicy[T peer.Peer]() balancer.Policy[T] { + return simple.NewPolicy(&picker[T]{}) } func NewBalancer[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] { diff --git a/discovery/balancer/roundrobin/balancer_test.go b/discovery/balancer/roundrobin/balancer_test.go index 5739558..6bbf5c1 100644 --- a/discovery/balancer/roundrobin/balancer_test.go +++ b/discovery/balancer/roundrobin/balancer_test.go @@ -18,38 +18,6 @@ func TestPolicy(t *testing.T) { }) } -func TestBalanceWithPeers(t *testing.T) { - ctx := context.Background() - b := NewBalancer( - static.NewResolverFromStrings([]string{"localhost:0", "localhost:1"}), - ) - - err := b.Open(ctx) - assert.Nil(t, err) - - time.Sleep(10 * time.Millisecond) - - p, done, err := b.Get(ctx, balancer.GetOptions{}) - done(nil) - - assert.Nil(t, err) - assert.Equal(t, "localhost:0", p.Addr()) - - p, done, err = b.Get(ctx, balancer.GetOptions{}) - done(nil) - - assert.Nil(t, err) - assert.Equal(t, "localhost:1", p.Addr()) - - p, done, err = b.Get(ctx, balancer.GetOptions{}) - done(nil) - - assert.Nil(t, err) - assert.Equal(t, "localhost:0", p.Addr()) - - b.Close() -} - func TestBalanceRoundRobinOrder(t *testing.T) { ctx := context.Background() b := NewBalancer( diff --git a/discovery/balancer/simple/doc.go b/discovery/balancer/simple/doc.go new file mode 100644 index 0000000..daec72f --- /dev/null +++ b/discovery/balancer/simple/doc.go @@ -0,0 +1,15 @@ +// Package simple provides a flexible balancer policy that delegates peer +// selection to a Picker 
implementation. +// +// The Picker interface allows for custom balancing strategies without +// implementing the full Policy interface: +// +// type CustomPicker struct{} +// +// func (p *CustomPicker) Pick(ctx context.Context, peers []peer.Peer) (peer.Peer, error) { +// // Custom selection logic +// return peers[0], nil +// } +// +// policy := simple.NewPolicy[*peer.Peer](new(CustomPicker)) +package simple diff --git a/discovery/balancer/simple/policy.go b/discovery/balancer/simple/policy.go new file mode 100644 index 0000000..5a47e54 --- /dev/null +++ b/discovery/balancer/simple/policy.go @@ -0,0 +1,97 @@ +package simple + +import ( + "context" + "slices" + "sync" + + "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +// Picker selects a peer from a list of available peers. +type Picker[T peer.Peer] interface { + // Pick selects a peer from the provided list. The implementation should + // return an error if no suitable peer can be selected. + Pick(context.Context, []T) (T, error) +} + +type policy[T peer.Peer] struct { + picker Picker[T] + + mu sync.RWMutex + peers []T + notifier chan struct{} +} + +// NewPolicy creates a new Policy that delegates peer selection to the provided Picker. 
+func NewPolicy[T peer.Peer](picker Picker[T]) balancer.Policy[T] { + return &policy[T]{ + picker: picker, + notifier: make(chan struct{}), + } +} + +func (p *policy[T]) Update(u resolver.Update[T]) { + p.mu.Lock() + defer p.mu.Unlock() + + wasEmpty := len(p.peers) == 0 + + peerMap := make(map[string]T) + for _, peer := range p.peers { + peerMap[peer.Addr()] = peer + } + + for _, peer := range u.Deletions { + delete(peerMap, peer.Addr()) + } + + for _, peer := range u.Additions { + peerMap[peer.Addr()] = peer + } + + p.peers = make([]T, 0, len(peerMap)) + for _, peer := range peerMap { + p.peers = append(p.peers, peer) + } + + if wasEmpty && len(p.peers) > 0 { + close(p.notifier) + p.notifier = make(chan struct{}) + } +} + +func (p *policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { + var zero T + + p.mu.RLock() + hasPeers := len(p.peers) > 0 + notifier := p.notifier + peers := slices.Clone(p.peers) + p.mu.RUnlock() + + if !hasPeers { + if opts.NoWait { + return zero, nil, balancer.ErrNoPeerAvailable + } + + select { + case <-notifier: + case <-ctx.Done(): + return zero, nil, ctx.Err() + } + + p.mu.RLock() + peers = slices.Clone(p.peers) + p.mu.RUnlock() + } + + if len(peers) == 0 { + return zero, nil, balancer.ErrNoPeerAvailable + } + + peer, err := p.picker.Pick(ctx, peers) + return peer, func(error) {}, err +} diff --git a/discovery/balancer/simple/policy_test.go b/discovery/balancer/simple/policy_test.go new file mode 100644 index 0000000..d4405af --- /dev/null +++ b/discovery/balancer/simple/policy_test.go @@ -0,0 +1,103 @@ +package simple + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/static" +) + +func TestPolicy(t *testing.T) { + balancertest.PolicyTest(t, func() 
balancer.Policy[static.Peer] { + return NewPolicy(&roundRobinPicker{}) + }) +} + +// roundRobinPicker cycles through peers sequentially +type roundRobinPicker struct { + mu sync.Mutex + index int +} + +func (p *roundRobinPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { + if len(peers) == 0 { + return static.Peer(""), errors.New("no peers available") + } + + p.mu.Lock() + idx := p.index % len(peers) + p.index++ + p.mu.Unlock() + + return peers[idx], nil +} + +func TestPickerDelegation(t *testing.T) { + policy := NewPolicy(&lastPicker{}) + + peers := []static.Peer{ + static.Peer("peer1"), + static.Peer("peer2"), + static.Peer("peer3"), + } + + policy.Update(resolver.Update[static.Peer]{Additions: peers}) + time.Sleep(10 * time.Millisecond) + + peer, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true}) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + + // Verify we got one of the peers (order from map is not guaranteed) + found := false + for _, p := range peers { + if peer.Addr() == p.Addr() { + found = true + break + } + } + if !found { + t.Errorf("Get() returned unexpected peer %q", peer.Addr()) + } +} + +func TestPickerError(t *testing.T) { + policy := NewPolicy(&errorPicker{}) + + peers := []static.Peer{static.Peer("peer1")} + policy.Update(resolver.Update[static.Peer]{Additions: peers}) + time.Sleep(10 * time.Millisecond) + + _, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true}) + if err == nil { + t.Fatal("Get() expected error, got nil") + } + + if err.Error() != "picker error" { + t.Errorf("Get() error = %q, want %q", err.Error(), "picker error") + } +} + +// lastPicker picks the last peer from the list +type lastPicker struct{} + +func (p *lastPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { + if len(peers) == 0 { + return static.Peer(""), errors.New("no peers available") + } + return peers[len(peers)-1], nil +} + +// errorPicker always 
returns an error +type errorPicker struct{} + +func (p *errorPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { + return static.Peer(""), errors.New("picker error") +} From 1fdcf89b4bfa1ceb55586c74fda06de1addd2f6c Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Thu, 9 Apr 2026 15:26:17 -0700 Subject: [PATCH 7/8] discovery/balancer/simple: Fix the flakiness for the first get --- discovery/balancer/simple/policy.go | 45 +++++++++---------- discovery/balancer/simple/policy_test.go | 55 ++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/discovery/balancer/simple/policy.go b/discovery/balancer/simple/policy.go index 5a47e54..513ceec 100644 --- a/discovery/balancer/simple/policy.go +++ b/discovery/balancer/simple/policy.go @@ -66,32 +66,29 @@ func (p *policy[T]) Update(u resolver.Update[T]) { func (p *policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) { var zero T - p.mu.RLock() - hasPeers := len(p.peers) > 0 - notifier := p.notifier - peers := slices.Clone(p.peers) - p.mu.RUnlock() - - if !hasPeers { - if opts.NoWait { - return zero, nil, balancer.ErrNoPeerAvailable - } - - select { - case <-notifier: - case <-ctx.Done(): - return zero, nil, ctx.Err() - } - + for { p.mu.RLock() - peers = slices.Clone(p.peers) + notifier := p.notifier + peers := slices.Clone(p.peers) p.mu.RUnlock() - } - if len(peers) == 0 { - return zero, nil, balancer.ErrNoPeerAvailable - } + if len(peers) == 0 { + if opts.NoWait { + return zero, nil, balancer.ErrNoPeerAvailable + } + + select { + case <-notifier: + // Notifier closed, peers may be available now. + // Loop back to re-check in case peers were removed + // between the notification and now. 
+ continue + case <-ctx.Done(): + return zero, nil, ctx.Err() + } + } - peer, err := p.picker.Pick(ctx, peers) - return peer, func(error) {}, err + peer, err := p.picker.Pick(ctx, peers) + return peer, func(error) {}, err + } } diff --git a/discovery/balancer/simple/policy_test.go b/discovery/balancer/simple/policy_test.go index d4405af..0eccb21 100644 --- a/discovery/balancer/simple/policy_test.go +++ b/discovery/balancer/simple/policy_test.go @@ -101,3 +101,58 @@ type errorPicker struct{} func (p *errorPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { return static.Peer(""), errors.New("picker error") } + +// TestRaceConditionPeerRemovalAfterWakeup tests the race condition where +// peers are removed after Get() wakes up from the notifier but before it +// re-reads the peer list. The fix should retry waiting in this case. +func TestRaceConditionPeerRemovalAfterWakeup(t *testing.T) { + policy := NewPolicy(&roundRobinPicker{}) + ctx := context.Background() + + // Start a goroutine that will wait for peers + gotPeer := make(chan static.Peer) + gotErr := make(chan error) + go func() { + peer, _, err := policy.Get(ctx, balancer.GetOptions{}) + if err != nil { + gotErr <- err + } else { + gotPeer <- peer + } + }() + + // Give the goroutine time to start waiting + time.Sleep(10 * time.Millisecond) + + // Add a peer (this will close the notifier) + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:1")}, + }) + + // Immediately remove the peer to simulate the race condition + // This happens after the notifier is closed but potentially before + // Get() re-reads the peer list + policy.Update(resolver.Update[static.Peer]{ + Deletions: []static.Peer{static.Peer("localhost:1")}, + }) + + // Add the peer back so Get() can eventually succeed + time.Sleep(5 * time.Millisecond) + policy.Update(resolver.Update[static.Peer]{ + Additions: []static.Peer{static.Peer("localhost:2")}, + }) + + // The Get() call 
should eventually succeed (not return error) + // It could get either localhost:1 (if fast enough) or localhost:2 (after retry) + select { + case peer := <-gotPeer: + addr := peer.Addr() + if addr != "localhost:1" && addr != "localhost:2" { + t.Errorf("Get() returned unexpected peer %q, want localhost:1 or localhost:2", addr) + } + case err := <-gotErr: + t.Fatalf("Get() returned error %v, expected to retry and succeed", err) + case <-time.After(200 * time.Millisecond): + t.Fatal("Get() did not complete in time") + } +} From 6f65ec932a7df4545ab0a95051666731bd81940a Mon Sep 17 00:00:00 2001 From: Alexis Montagne Date: Mon, 13 Apr 2026 14:22:00 -0700 Subject: [PATCH 8/8] discovery/*: Remove time.Sleep coordination in the tests --- discovery/balancer/balancertest/policy.go | 50 ++++++---- discovery/balancer/dialer.go | 27 ++++-- discovery/balancer/dialer_test.go | 15 +-- discovery/balancer/policy.go | 10 +- discovery/balancer/policy_test.go | 92 +++++++++++++------ discovery/balancer/random/balancer.go | 28 ++++-- discovery/balancer/random/balancer_test.go | 10 +- discovery/balancer/roundrobin/balancer.go | 3 +- .../balancer/roundrobin/balancer_test.go | 34 +++++-- discovery/balancer/simple/policy.go | 35 ++++--- discovery/balancer/simple/policy_test.go | 60 ++++++------ discovery/resolver/filter/resolver.go | 38 ++++++-- discovery/resolver/filter/resolver_test.go | 10 +- discovery/resolver/puller.go | 33 +++++-- discovery/resolver/puller_test.go | 12 +-- discovery/resolver/resolvertest/resolver.go | 25 ++--- discovery/resolver/static/resolver.go | 47 ++++++++-- discovery/resolver/static/resolver_test.go | 32 ++++--- discovery/resolver/sync_resolver.go | 36 +++++--- discovery/resolver/sync_resolver_test.go | 13 +-- discovery/resolver/transform/resolver_test.go | 25 +++-- 21 files changed, 418 insertions(+), 217 deletions(-) diff --git a/discovery/balancer/balancertest/policy.go b/discovery/balancer/balancertest/policy.go index df1e98d..76fd087 100644 --- 
a/discovery/balancer/balancertest/policy.go +++ b/discovery/balancer/balancertest/policy.go @@ -2,10 +2,11 @@ package balancertest import ( "context" + "runtime" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/resolver" @@ -68,9 +69,10 @@ func testSinglePeer(t *testing.T, factory PolicyFactory) { }) // Should get the same peer repeatedly - for i := 0; i < 5; i++ { + for range 5 { p, done, err := policy.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + + require.NoError(t, err) assert.NotNil(t, done) assert.Equal(t, "localhost:1", p.Addr()) done(nil) @@ -91,11 +93,15 @@ func testAddAndRemovePeers(t *testing.T, factory PolicyFactory) { // Verify we can get peers seen := make(map[string]bool) - for i := 0; i < 50; i++ { + + for range 50 { p, done, err := policy.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + + require.NoError(t, err) assert.NotNil(t, done) + seen[p.Addr()] = true + done(nil) } @@ -110,11 +116,15 @@ func testAddAndRemovePeers(t *testing.T, factory PolicyFactory) { // Verify we only see localhost:2 and localhost:3 seen = make(map[string]bool) - for i := 0; i < 50; i++ { + + for range 50 { p, done, err := policy.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + + require.NoError(t, err) assert.NotNil(t, done) + seen[p.Addr()] = true + done(nil) } @@ -134,7 +144,7 @@ func testRemoveAllPeers(t *testing.T, factory PolicyFactory) { // Verify we can get it p, done, err := policy.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "localhost:1", p.Addr()) done(nil) @@ -151,33 +161,39 @@ func testRemoveAllPeers(t *testing.T, factory PolicyFactory) { } func testAddPeersToEmpty(t *testing.T, factory PolicyFactory) { - ctx := context.Background() + ctx := t.Context() policy := factory() - // Start with no peers, spawn a goroutine that will wait + // started is closed just 
before the goroutine enters Get's select, + // giving us a deterministic signal instead of a sleep. + started := make(chan struct{}) done := make(chan struct{}) + go func() { + close(started) + p, doneFn, err := policy.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, doneFn) assert.NotEmpty(t, p.Addr()) doneFn(nil) close(done) }() - // Give the goroutine time to start waiting - time.Sleep(10 * time.Millisecond) + // Wait until the goroutine has been scheduled, then yield once more so + // it reaches the select inside Get before we call Update. + <-started + runtime.Gosched() - // Add a peer + // Add a peer — this closes the notifier and unblocks Get. policy.Update(resolver.Update[static.Peer]{ Additions: []static.Peer{static.Peer("localhost:1")}, }) - // The waiting goroutine should complete select { case <-done: - // Success - case <-time.After(100 * time.Millisecond): + // success + case <-ctx.Done(): t.Fatal("Get() did not unblock after adding peers") } } diff --git a/discovery/balancer/dialer.go b/discovery/balancer/dialer.go index 10fd6ba..f892779 100644 --- a/discovery/balancer/dialer.go +++ b/discovery/balancer/dialer.go @@ -4,6 +4,7 @@ import ( "context" "net" "sync" + "sync/atomic" "github.com/upfluence/errors" @@ -16,16 +17,23 @@ type Dialer[T peer.Peer] struct { Dialer *net.Dialer Options GetOptions + dialerOnce sync.Once + netDialer *net.Dialer + mu sync.Mutex lds map[string]*localDialer[T] } func (d *Dialer[T]) dialer() *net.Dialer { - if d.Dialer == nil { - d.Dialer = &net.Dialer{} - } + d.dialerOnce.Do(func() { + if d.Dialer != nil { + d.netDialer = d.Dialer + } else { + d.netDialer = &net.Dialer{} + } + }) - return d.Dialer + return d.netDialer } func (d *Dialer[T]) Dial(network, addr string) (net.Conn, error) { @@ -71,7 +79,7 @@ type localDialer[T peer.Peer] struct { d *Dialer[T] b Balancer[T] - opened bool + opened atomic.Bool sf syncutil.Singleflight[struct{}] } @@ -82,13 +90,13 @@ func (ld 
*localDialer[T]) open(ctx context.Context) (struct{}, error) { } } - ld.opened = true + ld.opened.Store(true) return struct{}{}, nil } func (ld *localDialer[T]) dial(ctx context.Context, network string) (net.Conn, error) { - if !ld.opened { + if !ld.opened.Load() { if _, _, err := ld.sf.Do(ctx, ld.open); err != nil { return nil, err } @@ -118,13 +126,14 @@ func (ld *localDialer[T]) close() error { type doneCloserConn struct { net.Conn - done func(error) + doneOnce sync.Once + done func(error) } func (dcc *doneCloserConn) Close() error { err := dcc.Conn.Close() - dcc.done(err) + dcc.doneOnce.Do(func() { dcc.done(err) }) return err } diff --git a/discovery/balancer/dialer_test.go b/discovery/balancer/dialer_test.go index a089d6c..d6cbc51 100644 --- a/discovery/balancer/dialer_test.go +++ b/discovery/balancer/dialer_test.go @@ -2,12 +2,13 @@ package balancer_test import ( "io" - "io/ioutil" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/balancer/roundrobin" "github.com/upfluence/pkg/v2/discovery/resolver/static" @@ -16,13 +17,13 @@ import ( func TestDialer(t *testing.T) { s1 := httptest.NewServer( http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - io.WriteString(rw, "s1") + io.WriteString(rw, "s1") //nolint:errcheck }), ) s2 := httptest.NewServer( http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - io.WriteString(rw, "s2") + io.WriteString(rw, "s2") //nolint:errcheck }), ) @@ -44,14 +45,14 @@ func TestDialer(t *testing.T) { cl := http.Client{Transport: &http.Transport{DialContext: d.DialContext}} req, err := http.NewRequest("GET", "http://example.com/foo", http.NoBody) - assert.Nil(t, err) + require.NoError(t, err) for _, want := range []string{"s1", "s2", "s1"} { resp, err := cl.Do(req) - assert.Nil(t, err) + require.NoError(t, err) - buf, err := ioutil.ReadAll(resp.Body) - 
assert.Nil(t, err) + buf, err := io.ReadAll(resp.Body) + require.NoError(t, err) resp.Body.Close() assert.Equal(t, want, string(buf)) diff --git a/discovery/balancer/policy.go b/discovery/balancer/policy.go index 4bc70d5..9a31c90 100644 --- a/discovery/balancer/policy.go +++ b/discovery/balancer/policy.go @@ -6,6 +6,7 @@ import ( "github.com/upfluence/pkg/v2/discovery/peer" "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/log" ) type Policy[T peer.Peer] interface { @@ -14,7 +15,7 @@ type Policy[T peer.Peer] interface { } type policyBalancer[S, T peer.Peer] struct { - resolver.Puller[S] + *resolver.Puller[S] Policy[T] mu sync.Mutex @@ -29,7 +30,7 @@ func WrapPolicy[S, T peer.Peer](r resolver.Resolver[S], p Policy[T], build func( builder: build, } - b.Puller = resolver.Puller[S]{ + b.Puller = &resolver.Puller[S]{ Resolver: r, UpdateFunc: b.handleUpdate, } @@ -52,6 +53,8 @@ func (b *policyBalancer[S, T]) handleUpdate(u resolver.Update[S]) { for _, sp := range u.Additions { tp, err := b.builder(sp) if err != nil { + log.WithError(err).Errorf("balancer: failed to build peer for %q, skipping", sp.Addr()) + continue } @@ -62,9 +65,10 @@ func (b *policyBalancer[S, T]) handleUpdate(u resolver.Update[S]) { for _, sp := range u.Deletions { if tp, ok := b.peers[sp.Addr()]; ok { delete(b.peers, sp.Addr()) + mapped.Deletions = append(mapped.Deletions, tp) } } - b.Policy.Update(mapped) + b.Update(mapped) } diff --git a/discovery/balancer/policy_test.go b/discovery/balancer/policy_test.go index d435c6a..3d80aaa 100644 --- a/discovery/balancer/policy_test.go +++ b/discovery/balancer/policy_test.go @@ -5,9 +5,9 @@ import ( "errors" "sync" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/resolver" @@ -22,10 +22,20 @@ type wrappedPeer struct { func (p wrappedPeer) Addr() string { return p.addr } func (p wrappedPeer) 
Metadata() metadata.Metadata { return nil } +// testPolicy is a Policy implementation used in tests that records all +// updates it receives and exposes waitForUpdate to synchronise on them +// without relying on time.Sleep. type testPolicy struct { mu sync.Mutex peers []wrappedPeer updates []resolver.Update[wrappedPeer] + // updatec is closed on each Update call and immediately replaced so + // that waitForUpdate returns exactly once per Update. + updatec chan struct{} +} + +func newTestPolicy() *testPolicy { + return &testPolicy{updatec: make(chan struct{})} } func (p *testPolicy) Get(ctx context.Context, opts balancer.GetOptions) (wrappedPeer, func(error), error) { @@ -36,23 +46,45 @@ func (p *testPolicy) Get(ctx context.Context, opts balancer.GetOptions) (wrapped if opts.NoWait { return wrappedPeer{}, nil, balancer.ErrNoPeerAvailable } + <-ctx.Done() + return wrappedPeer{}, nil, ctx.Err() } peer := p.peers[0] p.peers = p.peers[1:] + return peer, func(error) {}, nil } func (p *testPolicy) Update(u resolver.Update[wrappedPeer]) { p.mu.Lock() - defer p.mu.Unlock() p.updates = append(p.updates, u) + p.peers = append(p.peers, u.Additions...) + + // Close the current channel and replace it before releasing the lock so + // waitForUpdate never misses a notification. + ch := p.updatec + p.updatec = make(chan struct{}) - for _, peer := range u.Additions { - p.peers = append(p.peers, peer) + p.mu.Unlock() + + close(ch) +} + +// waitForUpdate blocks until the next Update call completes or ctx is done. 
+func (p *testPolicy) waitForUpdate(ctx context.Context) error { + p.mu.Lock() + ch := p.updatec + p.mu.Unlock() + + select { + case <-ch: + return nil + case <-ctx.Done(): + return ctx.Err() } } @@ -62,13 +94,14 @@ func (p *testPolicy) getUpdates() []resolver.Update[wrappedPeer] { updates := make([]resolver.Update[wrappedPeer], len(p.updates)) copy(updates, p.updates) + return updates } func TestWrapPolicyMapsAdditions(t *testing.T) { - ctx := context.Background() + ctx := t.Context() r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) - policy := &testPolicy{} + policy := newTestPolicy() b := balancer.WrapPolicy( r, @@ -78,9 +111,8 @@ func TestWrapPolicyMapsAdditions(t *testing.T) { }, ) - assert.Nil(t, b.Open(ctx)) - - time.Sleep(10 * time.Millisecond) + require.NoError(t, b.Open(ctx)) + require.NoError(t, policy.waitForUpdate(ctx)) updates := policy.getUpdates() assert.Len(t, updates, 1) @@ -91,17 +123,17 @@ func TestWrapPolicyMapsAdditions(t *testing.T) { assert.Empty(t, updates[0].Deletions) peer, done, err := b.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "localhost:1", peer.Addr()) done(nil) - assert.Nil(t, b.Close()) + assert.NoError(t, b.Close()) } func TestWrapPolicyMapsDeletions(t *testing.T) { - ctx := context.Background() + ctx := t.Context() r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) - policy := &testPolicy{} + policy := newTestPolicy() b := balancer.WrapPolicy( r, @@ -111,11 +143,11 @@ func TestWrapPolicyMapsDeletions(t *testing.T) { }, ) - assert.Nil(t, b.Open(ctx)) - time.Sleep(10 * time.Millisecond) + require.NoError(t, b.Open(ctx)) + require.NoError(t, policy.waitForUpdate(ctx)) // initial peers r.UpdatePeers(static.PeersFromStrings("localhost:2", "localhost:3")) - time.Sleep(10 * time.Millisecond) + require.NoError(t, policy.waitForUpdate(ctx)) // diff update updates := policy.getUpdates() assert.Len(t, updates, 2) @@ -123,13 +155,13 @@ func 
TestWrapPolicyMapsDeletions(t *testing.T) { assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:3"}}, updates[1].Additions) assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:1"}}, updates[1].Deletions) - assert.Nil(t, b.Close()) + assert.NoError(t, b.Close()) } func TestWrapPolicySkipsFailedBuilds(t *testing.T) { - ctx := context.Background() + ctx := t.Context() r := static.NewResolverFromStrings([]string{"localhost:1", "fail:2", "localhost:3"}) - policy := &testPolicy{} + policy := newTestPolicy() b := balancer.WrapPolicy( r, @@ -138,12 +170,13 @@ func TestWrapPolicySkipsFailedBuilds(t *testing.T) { if sp.Addr() == "fail:2" { return wrappedPeer{}, errors.New("build failed") } + return wrappedPeer{addr: sp.Addr()}, nil }, ) - assert.Nil(t, b.Open(ctx)) - time.Sleep(10 * time.Millisecond) + require.NoError(t, b.Open(ctx)) + require.NoError(t, policy.waitForUpdate(ctx)) updates := policy.getUpdates() assert.Len(t, updates, 1) @@ -152,13 +185,13 @@ func TestWrapPolicySkipsFailedBuilds(t *testing.T) { {addr: "localhost:3"}, }, updates[0].Additions) - assert.Nil(t, b.Close()) + assert.NoError(t, b.Close()) } func TestWrapPolicyDelegatesGetToPolicy(t *testing.T) { - ctx := context.Background() + ctx := t.Context() r := static.NewResolverFromStrings([]string{"localhost:1"}) - policy := &testPolicy{} + policy := newTestPolicy() b := balancer.WrapPolicy( r, @@ -168,17 +201,16 @@ func TestWrapPolicyDelegatesGetToPolicy(t *testing.T) { }, ) - assert.Nil(t, b.Open(ctx)) - time.Sleep(10 * time.Millisecond) + require.NoError(t, b.Open(ctx)) + require.NoError(t, policy.waitForUpdate(ctx)) peer, done, err := b.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, "localhost:1", peer.Addr()) done(nil) - peer, done, err = b.Get(ctx, balancer.GetOptions{NoWait: true}) + _, _, err = b.Get(ctx, balancer.GetOptions{NoWait: true}) assert.Equal(t, balancer.ErrNoPeerAvailable, err) - assert.Nil(t, done) - assert.Nil(t, b.Close()) + 
assert.NoError(t, b.Close()) } diff --git a/discovery/balancer/random/balancer.go b/discovery/balancer/random/balancer.go index caed6d6..fcfbdb8 100644 --- a/discovery/balancer/random/balancer.go +++ b/discovery/balancer/random/balancer.go @@ -2,8 +2,8 @@ package random import ( "context" - "math/rand" - "time" + "math/rand/v2" + "sync" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/balancer/simple" @@ -11,21 +11,37 @@ import ( "github.com/upfluence/pkg/v2/discovery/resolver" ) +// Rand is the interface consumed by the random picker, allowing tests to +// inject a deterministic source. type Rand interface { - Intn(int) int + IntN(int) int +} + +// lockedRand wraps a *rand.Rand with a mutex so it is safe for concurrent use. +type lockedRand struct { + mu sync.Mutex + r *rand.Rand +} + +func (lr *lockedRand) IntN(n int) int { + lr.mu.Lock() + v := lr.r.IntN(n) + lr.mu.Unlock() + + return v } type picker[T peer.Peer] struct { rand Rand } -func (p *picker[T]) Pick(ctx context.Context, peers []T) (T, error) { - return peers[p.rand.Intn(len(peers))], nil +func (p *picker[T]) Pick(_ context.Context, peers []T) (T, error) { + return peers[p.rand.IntN(len(peers))], nil } func NewPolicy[T peer.Peer]() balancer.Policy[T] { return simple.NewPolicy(&picker[T]{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + rand: &lockedRand{r: rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64()))}, //nolint:gosec }) } diff --git a/discovery/balancer/random/balancer_test.go b/discovery/balancer/random/balancer_test.go index b953b87..4287856 100644 --- a/discovery/balancer/random/balancer_test.go +++ b/discovery/balancer/random/balancer_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" @@ -24,14 +25,17 @@ func TestBalancerWithPeers(t *testing.T) { ) err := b.Open(ctx) 
- assert.Nil(t, err) + require.NoError(t, err) seen := make(map[string]int) - for i := 0; i < 100; i++ { + + for range 100 { p, done, err := b.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) + + require.NoError(t, err) assert.NotEmpty(t, p.Addr()) seen[p.Addr()]++ + done(nil) } diff --git a/discovery/balancer/roundrobin/balancer.go b/discovery/balancer/roundrobin/balancer.go index 01c0083..ff44406 100644 --- a/discovery/balancer/roundrobin/balancer.go +++ b/discovery/balancer/roundrobin/balancer.go @@ -14,8 +14,9 @@ type picker[T peer.Peer] struct { index atomic.Uint64 } -func (p *picker[T]) Pick(ctx context.Context, peers []T) (T, error) { +func (p *picker[T]) Pick(_ context.Context, peers []T) (T, error) { idx := p.index.Add(1) - 1 + return peers[idx%uint64(len(peers))], nil } diff --git a/discovery/balancer/roundrobin/balancer_test.go b/discovery/balancer/roundrobin/balancer_test.go index 6bbf5c1..4f92f81 100644 --- a/discovery/balancer/roundrobin/balancer_test.go +++ b/discovery/balancer/roundrobin/balancer_test.go @@ -3,9 +3,9 @@ package roundrobin import ( "context" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" @@ -25,16 +25,34 @@ func TestBalanceRoundRobinOrder(t *testing.T) { ) err := b.Open(ctx) - assert.Nil(t, err) + require.NoError(t, err) - time.Sleep(10 * time.Millisecond) + // Collect the first full cycle. The initial Get blocks until the Puller + // goroutine delivers peers, so this also acts as the synchronisation point. + // The policy preserves insertion order, so subsequent cycles are identical. 
+ var firstCycle [3]string - // Verify round-robin cycles through all peers in order - for cycle := 0; cycle < 3; cycle++ { - for i := 0; i < 3; i++ { + for i := range 3 { + p, done, err := b.Get(ctx, balancer.GetOptions{}) + + require.NoError(t, err) + + firstCycle[i] = p.Addr() + + done(nil) + } + + // All three peers must appear in the first cycle. + assert.ElementsMatch(t, []string{"localhost:0", "localhost:1", "localhost:2"}, firstCycle[:]) + + // Subsequent cycles must repeat in the exact same order. + for range 2 { + for i := range 3 { p, done, err := b.Get(ctx, balancer.GetOptions{}) - assert.Nil(t, err) - assert.Contains(t, []string{"localhost:0", "localhost:1", "localhost:2"}, p.Addr()) + + require.NoError(t, err) + assert.Equal(t, firstCycle[i], p.Addr()) + done(nil) } } diff --git a/discovery/balancer/simple/policy.go b/discovery/balancer/simple/policy.go index 513ceec..b2320a3 100644 --- a/discovery/balancer/simple/policy.go +++ b/discovery/balancer/simple/policy.go @@ -20,8 +20,11 @@ type Picker[T peer.Peer] interface { type policy[T peer.Peer] struct { picker Picker[T] - mu sync.RWMutex - peers []T + mu sync.RWMutex + peers []T + // peerSet is an auxiliary index for O(1) presence checks; the + // authoritative ordering lives in peers. + peerSet map[string]struct{} notifier chan struct{} } @@ -29,6 +32,7 @@ type policy[T peer.Peer] struct { func NewPolicy[T peer.Peer](picker Picker[T]) balancer.Policy[T] { return &policy[T]{ picker: picker, + peerSet: make(map[string]struct{}), notifier: make(chan struct{}), } } @@ -39,22 +43,26 @@ func (p *policy[T]) Update(u resolver.Update[T]) { wasEmpty := len(p.peers) == 0 - peerMap := make(map[string]T) - for _, peer := range p.peers { - peerMap[peer.Addr()] = peer - } - + // Apply deletions: remove from both the set and the ordered slice. 
for _, peer := range u.Deletions { - delete(peerMap, peer.Addr()) + addr := peer.Addr() + + if _, ok := p.peerSet[addr]; ok { + delete(p.peerSet, addr) + p.peers = slices.DeleteFunc(p.peers, func(q T) bool { + return q.Addr() == addr + }) + } } + // Apply additions: append new peers while preserving existing order. for _, peer := range u.Additions { - peerMap[peer.Addr()] = peer - } + addr := peer.Addr() - p.peers = make([]T, 0, len(peerMap)) - for _, peer := range peerMap { - p.peers = append(p.peers, peer) + if _, ok := p.peerSet[addr]; !ok { + p.peerSet[addr] = struct{}{} + p.peers = append(p.peers, peer) + } } if wasEmpty && len(p.peers) > 0 { @@ -89,6 +97,7 @@ func (p *policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func( } peer, err := p.picker.Pick(ctx, peers) + return peer, func(error) {}, err } } diff --git a/discovery/balancer/simple/policy_test.go b/discovery/balancer/simple/policy_test.go index 0eccb21..902f790 100644 --- a/discovery/balancer/simple/policy_test.go +++ b/discovery/balancer/simple/policy_test.go @@ -3,9 +3,9 @@ package simple import ( "context" "errors" + "runtime" "sync" "testing" - "time" "github.com/upfluence/pkg/v2/discovery/balancer" "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" @@ -25,7 +25,7 @@ type roundRobinPicker struct { index int } -func (p *roundRobinPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { +func (p *roundRobinPicker) Pick(_ context.Context, peers []static.Peer) (static.Peer, error) { if len(peers) == 0 { return static.Peer(""), errors.New("no peers available") } @@ -47,22 +47,24 @@ func TestPickerDelegation(t *testing.T) { static.Peer("peer3"), } + // Update is synchronous: peers are visible to Get immediately after return. 
policy.Update(resolver.Update[static.Peer]{Additions: peers}) - time.Sleep(10 * time.Millisecond) peer, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true}) if err != nil { t.Fatalf("Get() error = %v", err) } - // Verify we got one of the peers (order from map is not guaranteed) found := false + for _, p := range peers { if peer.Addr() == p.Addr() { found = true + break } } + if !found { t.Errorf("Get() returned unexpected peer %q", peer.Addr()) } @@ -72,8 +74,8 @@ func TestPickerError(t *testing.T) { policy := NewPolicy(&errorPicker{}) peers := []static.Peer{static.Peer("peer1")} + // Update is synchronous: no sleep needed before Get. policy.Update(resolver.Update[static.Peer]{Additions: peers}) - time.Sleep(10 * time.Millisecond) _, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true}) if err == nil { @@ -88,31 +90,37 @@ func TestPickerError(t *testing.T) { // lastPicker picks the last peer from the list type lastPicker struct{} -func (p *lastPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { +func (p *lastPicker) Pick(_ context.Context, peers []static.Peer) (static.Peer, error) { if len(peers) == 0 { return static.Peer(""), errors.New("no peers available") } + return peers[len(peers)-1], nil } // errorPicker always returns an error type errorPicker struct{} -func (p *errorPicker) Pick(ctx context.Context, peers []static.Peer) (static.Peer, error) { +func (p *errorPicker) Pick(_ context.Context, _ []static.Peer) (static.Peer, error) { return static.Peer(""), errors.New("picker error") } -// TestRaceConditionPeerRemovalAfterWakeup tests the race condition where -// peers are removed after Get() wakes up from the notifier but before it -// re-reads the peer list. The fix should retry waiting in this case. +// TestRaceConditionPeerRemovalAfterWakeup verifies that Get() retries when +// peers are removed between the notifier being closed and Get() re-reading +// the peer list. 
func TestRaceConditionPeerRemovalAfterWakeup(t *testing.T) {
 	policy := NewPolicy(&roundRobinPicker{})
-	ctx := context.Background()
+	ctx := t.Context()
+
+	// started is closed as soon as the goroutine runs, just before it calls
+	// Get, giving the test a wake-up signal to proceed rather than sleeping.
+	started := make(chan struct{})
+	gotPeer := make(chan static.Peer, 1)
+	gotErr := make(chan error, 1)
-	// Start a goroutine that will wait for peers
-	gotPeer := make(chan static.Peer)
-	gotErr := make(chan error)
 	go func() {
+		close(started)
+
 		peer, _, err := policy.Get(ctx, balancer.GetOptions{})
 		if err != nil {
 			gotErr <- err
@@ -121,29 +129,29 @@ func TestRaceConditionPeerRemovalAfterWakeup(t *testing.T) {
 		}
 	}()
 
-	// Give the goroutine time to start waiting
-	time.Sleep(10 * time.Millisecond)
+	// Wait until the goroutine has started, then yield once so it is likely
+	// (though not guaranteed) to reach the select inside Get before updates.
+	<-started
+	runtime.Gosched()
 
-	// Add a peer (this will close the notifier)
+	// Add a peer — closes the notifier, waking the goroutine.
 	policy.Update(resolver.Update[static.Peer]{
 		Additions: []static.Peer{static.Peer("localhost:1")},
 	})
 
-	// Immediately remove the peer to simulate the race condition
-	// This happens after the notifier is closed but potentially before
-	// Get() re-reads the peer list
+	// Immediately remove it to exercise the retry path: the goroutine may have
+	// woken from the notifier but not yet re-read the peer slice.
	policy.Update(resolver.Update[static.Peer]{
 		Deletions: []static.Peer{static.Peer("localhost:1")},
 	})
 
-	// Add the peer back so Get() can eventually succeed
-	time.Sleep(5 * time.Millisecond)
+	// Re-add a peer so Get() can eventually succeed.
+	// No sleep needed: Update() is synchronous and the goroutine will pick up
+	// the new notifier on its next iteration.
policy.Update(resolver.Update[static.Peer]{ Additions: []static.Peer{static.Peer("localhost:2")}, }) - // The Get() call should eventually succeed (not return error) - // It could get either localhost:1 (if fast enough) or localhost:2 (after retry) select { case peer := <-gotPeer: addr := peer.Addr() @@ -152,7 +160,7 @@ func TestRaceConditionPeerRemovalAfterWakeup(t *testing.T) { } case err := <-gotErr: t.Fatalf("Get() returned error %v, expected to retry and succeed", err) - case <-time.After(200 * time.Millisecond): - t.Fatal("Get() did not complete in time") + case <-ctx.Done(): + t.Fatal("Get() did not complete before test deadline") } } diff --git a/discovery/resolver/filter/resolver.go b/discovery/resolver/filter/resolver.go index fcb8197..0527785 100644 --- a/discovery/resolver/filter/resolver.go +++ b/discovery/resolver/filter/resolver.go @@ -16,8 +16,13 @@ func WrapResolver[T peer.Peer](r resolver.Resolver[T], allow func(T) bool) resol return &filterResolver[T]{inner: r, allow: allow} } -func (r *filterResolver[T]) Open(ctx context.Context) error { return r.inner.Open(ctx) } -func (r *filterResolver[T]) Close() error { return r.inner.Close() } +func (r *filterResolver[T]) Open(ctx context.Context) error { + return r.inner.Open(ctx) +} + +func (r *filterResolver[T]) Close() error { + return r.inner.Close() +} func (r *filterResolver[T]) Resolve() resolver.Watcher[T] { return &watcher[T]{inner: r.inner.Resolve(), allow: r.allow} @@ -29,22 +34,39 @@ type watcher[T peer.Peer] struct { admitted map[string]struct{} } -func (w *watcher[T]) Close() error { return w.inner.Close() } +func (w *watcher[T]) Close() error { + return w.inner.Close() +} func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { - for { + // When NoWait is requested we only attempt one read from the inner watcher. + // If that read is empty (ErrNoUpdates) or entirely filtered out, we return + // ErrNoUpdates immediately. 
We must not keep consuming inner updates in a + // loop under NoWait — doing so would silently drain the inner channel and + // discard real updates the caller may want to inspect later. + if opts.NoWait { u, err := w.inner.Next(ctx, opts) if err != nil { return resolver.Update[T]{}, err } filtered := w.filter(u) - if len(filtered.Additions) == 0 && len(filtered.Deletions) == 0 { - if opts.NoWait { - return resolver.Update[T]{}, resolver.ErrNoUpdates - } + return resolver.Update[T]{}, resolver.ErrNoUpdates + } + return filtered, nil + } + + // Blocking path: loop until we get at least one non-empty filtered update. + for { + u, err := w.inner.Next(ctx, opts) + if err != nil { + return resolver.Update[T]{}, err + } + + filtered := w.filter(u) + if len(filtered.Additions) == 0 && len(filtered.Deletions) == 0 { continue } diff --git a/discovery/resolver/filter/resolver_test.go b/discovery/resolver/filter/resolver_test.go index 6322adc..7d6c192 100644 --- a/discovery/resolver/filter/resolver_test.go +++ b/discovery/resolver/filter/resolver_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/resolver" "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" @@ -21,6 +22,7 @@ func TestResolver(t *testing.T) { // Filter the expected peers var expected []static.Peer + for _, p := range peers { if strings.HasPrefix(p.Addr(), "localhost") { expected = append(expected, p) @@ -43,7 +45,7 @@ func TestFilterResolverAllowsAdditions(t *testing.T) { u, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions) assert.Empty(t, u.Deletions) @@ -64,13 +66,13 @@ func TestFilterResolverTracksDeletions(t *testing.T) { w := r.Resolve() _, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) 
inner.UpdatePeers(static.PeersFromStrings("allow:2", "deny:2")) u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, []static.Peer{static.Peer("allow:2")}, u.Additions) assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Deletions) } @@ -94,7 +96,7 @@ func TestFilterResolverNoWaitFilteredEmpty(t *testing.T) { u, err = w.Next(ctx, resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Additions) assert.Empty(t, u.Deletions) } diff --git a/discovery/resolver/puller.go b/discovery/resolver/puller.go index 517f2c4..386daed 100644 --- a/discovery/resolver/puller.go +++ b/discovery/resolver/puller.go @@ -19,23 +19,30 @@ type Puller[T peer.Peer] struct { Monitor closer.Monitor NoWait bool - openErr error - openOnce sync.Once - opened atomic.Bool + openErr error + openOnce sync.Once + closeOnce sync.Once + opened atomic.Bool } -func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func()) { +func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func() error) { var p = &Puller[T]{ Resolver: r, UpdateFunc: fn, } - return p, func() { p.Close() } + return p, p.Close } func (p *Puller[T]) Close() error { - p.opened.Store(false) - return errors.Combine(p.Monitor.Close(), p.Resolver.Close()) + var err error + + p.closeOnce.Do(func() { + p.opened.Store(false) + err = errors.Combine(p.Monitor.Close(), p.Resolver.Close()) + }) + + return err } func (p *Puller[T]) IsOpen() bool { @@ -75,9 +82,19 @@ func (p *Puller[T]) pull(ctx context.Context) { for err == nil { u, err = w.Next(ctx, ResolveOptions{NoWait: noWait}) - if err == nil || err == ErrNoUpdates { + if err == nil { noWait = false + p.UpdateFunc(u) + + continue + } + + if errors.Is(err, ErrNoUpdates) { + // No update available right now; reset noWait and keep going + // without calling UpdateFunc 
with a meaningless empty update. + noWait = false + continue } diff --git a/discovery/resolver/puller_test.go b/discovery/resolver/puller_test.go index bbb832a..acb5862 100644 --- a/discovery/resolver/puller_test.go +++ b/discovery/resolver/puller_test.go @@ -5,9 +5,9 @@ import ( "errors" "sync/atomic" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/resolver" "github.com/upfluence/pkg/v2/discovery/resolver/static" @@ -74,16 +74,16 @@ func TestPullerRecoversAfterWatcherError(t *testing.T) { } }) - assert.Nil(t, p.Open(context.Background())) + require.NoError(t, p.Open(context.Background())) select { case u := <-updates: assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions) - case <-time.After(200 * time.Millisecond): + case <-t.Context().Done(): t.Fatal("timed out waiting for update after watcher error") } - assert.Nil(t, p.Close()) + assert.NoError(t, p.Close()) } func TestPullerIsOpenTracksClose(t *testing.T) { @@ -91,9 +91,9 @@ func TestPullerIsOpenTracksClose(t *testing.T) { p, _ := resolver.NewPuller(r, func(resolver.Update[static.Peer]) {}) assert.False(t, p.IsOpen()) - assert.Nil(t, p.Open(context.Background())) + require.NoError(t, p.Open(context.Background())) assert.True(t, p.IsOpen()) - assert.Nil(t, p.Close()) + require.NoError(t, p.Close()) assert.False(t, p.IsOpen()) } diff --git a/discovery/resolver/resolvertest/resolver.go b/discovery/resolver/resolvertest/resolver.go index ecc1500..6e640d3 100644 --- a/discovery/resolver/resolvertest/resolver.go +++ b/discovery/resolver/resolvertest/resolver.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/peer" "github.com/upfluence/pkg/v2/discovery/resolver" @@ -40,13 +41,13 @@ func ResolverTest[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePee } } -func testNoSeeds[T peer.Peer](t *testing.T, factory 
ResolverFactory[T], makePeers func(...string) []T) { +func testNoSeeds[T peer.Peer](t *testing.T, factory ResolverFactory[T], _ func(...string) []T) { ctx := context.Background() r, expected := factory(nil) assert.Empty(t, expected) - assert.Nil(t, r.Open(ctx)) + require.NoError(t, r.Open(ctx)) defer r.Close() w := r.Resolve() @@ -64,7 +65,7 @@ func testInitialPeers[T peer.Peer](t *testing.T, factory ResolverFactory[T], mak peers := makePeers("localhost:1", "localhost:2") r, expected := factory(peers) - assert.Nil(t, r.Open(ctx)) + require.NoError(t, r.Open(ctx)) defer r.Close() w := r.Resolve() @@ -72,7 +73,7 @@ func testInitialPeers[T peer.Peer](t *testing.T, factory ResolverFactory[T], mak // Should get initial peers u, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, expected, u.Additions) assert.Empty(t, u.Deletions) } @@ -82,7 +83,7 @@ func testNoWaitNoUpdates[T peer.Peer](t *testing.T, factory ResolverFactory[T], peers := makePeers("localhost:1") r, _ := factory(peers) - assert.Nil(t, r.Open(ctx)) + require.NoError(t, r.Open(ctx)) defer r.Close() w := r.Resolve() @@ -90,7 +91,7 @@ func testNoWaitNoUpdates[T peer.Peer](t *testing.T, factory ResolverFactory[T], // Consume initial update _, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // NoWait should return ErrNoUpdates when no updates are available u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true}) @@ -104,7 +105,7 @@ func testContextCancellation[T peer.Peer](t *testing.T, factory ResolverFactory[ peers := makePeers("localhost:1") r, _ := factory(peers) - assert.Nil(t, r.Open(ctx)) + require.NoError(t, r.Open(ctx)) defer r.Close() w := r.Resolve() @@ -112,7 +113,7 @@ func testContextCancellation[T peer.Peer](t *testing.T, factory ResolverFactory[ // Consume initial update _, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // Cancel context 
and try to wait for updates cctx, cancel := context.WithCancel(ctx) @@ -129,18 +130,18 @@ func testWatcherClose[T peer.Peer](t *testing.T, factory ResolverFactory[T], mak peers := makePeers("localhost:1") r, _ := factory(peers) - assert.Nil(t, r.Open(ctx)) + require.NoError(t, r.Open(ctx)) defer r.Close() w := r.Resolve() // Consume initial update _, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // Close the watcher err = w.Close() - assert.Nil(t, err) + require.NoError(t, err) // Trying to get next update should fail after close // Give a small timeout to avoid blocking forever @@ -148,6 +149,6 @@ func testWatcherClose[T peer.Peer](t *testing.T, factory ResolverFactory[T], mak defer cancel() _, err = w.Next(cctx, resolver.ResolveOptions{}) - assert.NotNil(t, err) + assert.Error(t, err) // Should be either context.Canceled, context.DeadlineExceeded, or a close error } diff --git a/discovery/resolver/static/resolver.go b/discovery/resolver/static/resolver.go index 4ab80ef..f75c06b 100644 --- a/discovery/resolver/static/resolver.go +++ b/discovery/resolver/static/resolver.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "sync" - "sync/atomic" "github.com/upfluence/pkg/v2/closer" "github.com/upfluence/pkg/v2/discovery/peer" @@ -91,6 +90,7 @@ func (r *Resolver[T]) UpdatePeers(peers []T) { if len(u.Additions) == 0 && len(u.Deletions) == 0 { r.mu.Unlock() + return } @@ -99,7 +99,23 @@ func (r *Resolver[T]) UpdatePeers(peers []T) { r.mu.Unlock() for _, ch := range chs { - ch <- u + // Non-blocking send: if the watcher's buffer is already full (slow + // consumer), merge the pending update with the new one so no update + // is silently dropped and the caller is never blocked. + select { + case ch <- u: + default: + select { + case pending := <-ch: + pending.Additions = append(pending.Additions, u.Additions...) + + pending.Deletions = append(pending.Deletions, u.Deletions...) 
+ ch <- pending + default: + // Channel was drained concurrently; just send the new update. + ch <- u + } + } } } @@ -123,6 +139,7 @@ func (r *Resolver[T]) unsubscribe(ch chan resolver.Update[T]) { for i, c := range r.chs { if c == ch { r.chs = append(r.chs[:i], r.chs[i+1:]...) + return } } @@ -152,21 +169,26 @@ func (r *Resolver[T]) Resolve() resolver.Watcher[T] { type watcher[T peer.Peer] struct { closer.Monitor - r *Resolver[T] + r *Resolver[T] + + mu sync.Mutex ch chan resolver.Update[T] - initial int32 + initial bool } func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { - ok := atomic.CompareAndSwapInt32(&w.initial, 0, 1) - - if ok { + w.mu.Lock() + if !w.initial { + w.initial = true ch, peers := w.r.subscribe() w.ch = ch + w.mu.Unlock() if len(peers) > 0 { return resolver.Update[T]{Additions: peers}, nil } + } else { + w.mu.Unlock() } if opts.NoWait { @@ -190,13 +212,18 @@ func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (re return resolver.Update[T]{}, wctx.Err() case <-rctx.Done(): w.Close() - return resolver.Update[T]{}, wctx.Err() + + return resolver.Update[T]{}, rctx.Err() } } func (w *watcher[T]) Close() error { - if w.ch != nil { - w.r.unsubscribe(w.ch) + w.mu.Lock() + ch := w.ch + w.mu.Unlock() + + if ch != nil { + w.r.unsubscribe(ch) } return w.Monitor.Close() diff --git a/discovery/resolver/static/resolver_test.go b/discovery/resolver/static/resolver_test.go index 41783d3..ecfeb23 100644 --- a/discovery/resolver/static/resolver_test.go +++ b/discovery/resolver/static/resolver_test.go @@ -5,6 +5,8 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/upfluence/pkg/v2/discovery/resolver" "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" ) @@ -23,7 +25,7 @@ func TestResolve(t *testing.T) { u, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal( t, 
resolver.Update[Peer]{ @@ -41,7 +43,7 @@ func TestResolve(t *testing.T) { assert.Equal(t, resolver.Update[Peer]{}, u) err = w.Close() - assert.Nil(t, err) + require.NoError(t, err) u, err = w.Next(ctx, resolver.ResolveOptions{}) @@ -58,6 +60,7 @@ func TestPeers(t *testing.T) { // Mutating the returned slice should not affect the resolver peers[0] = Peer("localhost:99") + assert.Equal(t, []Peer{Peer("localhost:1"), Peer("localhost:2")}, r.Peers()) } @@ -73,7 +76,7 @@ func TestUpdatePeers(t *testing.T) { // Consume initial update u, err := w.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal( t, resolver.Update[Peer]{ @@ -96,7 +99,7 @@ func TestUpdatePeers(t *testing.T) { // The watcher should receive the diff u, err = w.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, []Peer{Peer("localhost:3")}, u.Additions) assert.ElementsMatch(t, []Peer{Peer("localhost:1")}, u.Deletions) } @@ -108,7 +111,7 @@ func TestUpdatePeersNoChange(t *testing.T) { // Consume initial update _, err := w.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // Update with same peers — no diff r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) @@ -126,18 +129,19 @@ func TestUpdatePeersMultipleWatchers(t *testing.T) { // Consume initial updates _, err := w1.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) + _, err = w2.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) u1, err := w1.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u1.Additions) u2, err := w2.Next(context.Background(), 
resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u2.Additions) } @@ -148,11 +152,11 @@ func TestUpdatePeersClosedWatcher(t *testing.T) { // Consume initial update _, err := w.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // Close the watcher — it should unsubscribe err = w.Close() - assert.Nil(t, err) + require.NoError(t, err) // This should not block or panic r.UpdatePeers(PeersFromStrings("localhost:2")) @@ -165,13 +169,15 @@ func TestUpdatePeersBlockingWatcher(t *testing.T) { // Consume initial update _, err := w.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) // Start a blocking Next call in a goroutine done := make(chan resolver.Update[Peer], 1) + go func() { u, err := w.Next(context.Background(), resolver.ResolveOptions{}) - assert.Nil(t, err) + assert.NoError(t, err) + done <- u }() diff --git a/discovery/resolver/sync_resolver.go b/discovery/resolver/sync_resolver.go index 3895fb2..d718fe1 100644 --- a/discovery/resolver/sync_resolver.go +++ b/discovery/resolver/sync_resolver.go @@ -36,7 +36,7 @@ func (sr *syncResolver[T]) ResolveSync(ctx context.Context, n string) ([]T, erro lr, ok := sr.lrs[n] if !ok { - lr = &localResolver[T]{readyc: make(chan struct{})} + lr = &localResolver[T]{readyc: make(chan struct{}), noWait: sr.noWait} lr.p = &Puller[T]{ Resolver: sr.builder.Build(n), @@ -46,7 +46,7 @@ func (sr *syncResolver[T]) ResolveSync(ctx context.Context, n string) ([]T, erro if err := lr.p.Open(ctx); err != nil { sr.mu.Unlock() - close(lr.readyc) + return nil, err } @@ -65,7 +65,7 @@ func (sr *syncResolver[T]) Close() error { for _, lr := range sr.lrs { if err := lr.close(); err != nil { - errs = append(errs) + errs = append(errs, err) } } @@ -76,7 +76,8 @@ func (sr *syncResolver[T]) Close() error { } type localResolver[T peer.Peer] struct { - p *Puller[T] + p 
*Puller[T] + noWait bool readyOnce sync.Once readyc chan struct{} @@ -93,11 +94,7 @@ func (lr *localResolver[T]) update(u Update[T]) { } for _, p := range u.Deletions { - addr := p.Addr() - - if _, ok := lr.ps[addr]; ok { - delete(lr.ps, addr) - } + delete(lr.ps, p.Addr()) } for _, p := range u.Additions { @@ -110,7 +107,7 @@ func (lr *localResolver[T]) update(u Update[T]) { } func (lr *localResolver[T]) close() error { - return errors.Combine(lr.p.Close()) + return lr.p.Close() } func (lr *localResolver[T]) resolve(ctx context.Context) ([]T, error) { @@ -118,10 +115,21 @@ func (lr *localResolver[T]) resolve(ctx context.Context) ([]T, error) { return nil, ErrClose } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-lr.readyc: + if lr.noWait { + // With noWait, return whatever peers are available right now without + // blocking for the first update. If readyc isn't closed yet the + // background Puller hasn't delivered anything, so return empty. + select { + case <-lr.readyc: + default: + return nil, nil + } + } else { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-lr.readyc: + } } lr.mu.RLock() diff --git a/discovery/resolver/sync_resolver_test.go b/discovery/resolver/sync_resolver_test.go index 037c2fb..13a32f7 100644 --- a/discovery/resolver/sync_resolver_test.go +++ b/discovery/resolver/sync_resolver_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/resolver" "github.com/upfluence/pkg/v2/discovery/resolver/static" @@ -23,16 +24,16 @@ func TestNameResolverWithPeers(t *testing.T) { ps, err := nr.ResolveSync(ctx, "n1") - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, static.PeersFromStrings("foo", "bar"), ps) ps, err = nr.ResolveSync(ctx, "n2") - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, static.PeersFromStrings("biz", "buz"), ps) err = nr.Close() - assert.Nil(t, err) + 
require.NoError(t, err) } func TestNameResolverNoPeerNoWait(t *testing.T) { @@ -41,11 +42,11 @@ func TestNameResolverNoPeerNoWait(t *testing.T) { ps, err := nr.ResolveSync(ctx, "n1") - assert.Nil(t, err) + require.NoError(t, err) assert.ElementsMatch(t, 0, len(ps)) err = nr.Close() - assert.Nil(t, err) + require.NoError(t, err) } func TestNameResolverNoPeerWait(t *testing.T) { @@ -60,5 +61,5 @@ func TestNameResolverNoPeerWait(t *testing.T) { assert.ElementsMatch(t, 0, len(ps)) err = nr.Close() - assert.Nil(t, err) + require.NoError(t, err) } diff --git a/discovery/resolver/transform/resolver_test.go b/discovery/resolver/transform/resolver_test.go index 3acd7a0..4a898c2 100644 --- a/discovery/resolver/transform/resolver_test.go +++ b/discovery/resolver/transform/resolver_test.go @@ -3,9 +3,9 @@ package transform_test import ( "context" "testing" - "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/resolver" "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" @@ -27,6 +27,7 @@ func makeWrappedPeers(addrs ...string) []wrappedPeer { for i, addr := range addrs { peers[i] = wrappedPeer{addr: addr, prefix: "prefix-"} } + return peers } @@ -56,7 +57,7 @@ func TestTransformResolverTransformsUpdates(t *testing.T) { return wrappedPeer{addr: p.Addr(), prefix: "transformed-"} }) - assert.Nil(t, tr.Open(ctx)) + require.NoError(t, tr.Open(ctx)) defer tr.Close() w := tr.Resolve() @@ -64,7 +65,7 @@ func TestTransformResolverTransformsUpdates(t *testing.T) { // Get initial peers u, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, u.Additions, 1) assert.Equal(t, "transformed-host1:80", u.Additions[0].Addr()) @@ -72,7 +73,7 @@ func TestTransformResolverTransformsUpdates(t *testing.T) { source.UpdatePeers(static.PeersFromStrings("host2:80", "host3:80")) u, err = w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) 
// Should have additions and deletions assert.Len(t, u.Additions, 2) @@ -93,7 +94,7 @@ func TestTransformResolverMultipleWatchers(t *testing.T) { return wrappedPeer{addr: p.Addr(), prefix: "watcher-"} }) - assert.Nil(t, tr.Open(ctx)) + require.NoError(t, tr.Open(ctx)) defer tr.Close() w1 := tr.Resolve() @@ -104,31 +105,29 @@ func TestTransformResolverMultipleWatchers(t *testing.T) { // Both watchers should get initial peers u1, err := w1.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, u1.Additions, 1) assert.Equal(t, "watcher-host1:80", u1.Additions[0].Addr()) u2, err := w2.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, u2.Additions, 1) assert.Equal(t, "watcher-host1:80", u2.Additions[0].Addr()) - // Update peers + // Update peers — UpdatePeers sends synchronously into the watchers' + // buffered channels, so the update is available for NoWait reads immediately. source.UpdatePeers(static.PeersFromStrings("host2:80")) - // Give time for update to propagate - time.Sleep(10 * time.Millisecond) - // Both watchers should receive the update u1, err = w1.Next(ctx, resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, u1.Additions, 1) assert.Equal(t, "watcher-host2:80", u1.Additions[0].Addr()) assert.Len(t, u1.Deletions, 1) assert.Equal(t, "watcher-host1:80", u1.Deletions[0].Addr()) u2, err = w2.Next(ctx, resolver.ResolveOptions{NoWait: true}) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, u2.Additions, 1) assert.Equal(t, "watcher-host2:80", u2.Additions[0].Addr()) assert.Len(t, u2.Deletions, 1)