diff --git a/discovery/balancer/balancertest/doc.go b/discovery/balancer/balancertest/doc.go
new file mode 100644
index 0000000..37443b8
--- /dev/null
+++ b/discovery/balancer/balancertest/doc.go
@@ -0,0 +1,17 @@
+// Package balancertest provides testing utilities for balancer.Policy implementations.
+//
+// Usage:
+//
+//	func TestMyPolicy(t *testing.T) {
+//		balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] {
+//			return NewPolicy[static.Peer]()
+//		})
+//	}
+//
+// PolicyTest runs a comprehensive test suite that covers:
+// - Empty policy (no peers) - tests NoWait and context cancellation behavior
+// - Single peer - verifies Get() returns the same peer consistently
+// - Adding and removing peers - tests dynamic peer updates
+// - Removing all peers - verifies transition back to empty state
+// - Adding peers to empty policy - tests that waiting Get() calls unblock
+package balancertest
diff --git a/discovery/balancer/balancertest/policy.go b/discovery/balancer/balancertest/policy.go
new file mode 100644
index 0000000..76fd087
--- /dev/null
+++ b/discovery/balancer/balancertest/policy.go
@@ -0,0 +1,199 @@
+package balancertest
+
+import (
+	"context"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+)
+
+// PolicyFactory creates a new Policy instance for testing.
+type PolicyFactory func() balancer.Policy[static.Peer]
+
+// PolicyTest runs a comprehensive test suite for a Policy implementation.
+// It tests common policy behaviors including:
+// - Empty policy (no peers)
+// - Single peer
+// - Multiple peers
+// - Adding and removing peers
+func PolicyTest(t *testing.T, factory PolicyFactory) {
+	for _, tt := range []struct {
+		name string
+		test func(*testing.T, PolicyFactory)
+	}{
+		{"NoPeers", testNoPeers},
+		{"SinglePeer", testSinglePeer},
+		{"AddAndRemovePeers", testAddAndRemovePeers},
+		{"RemoveAllPeers", testRemoveAllPeers},
+		{"AddPeersToEmpty", testAddPeersToEmpty},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.test(t, factory)
+		})
+	}
+}
+
+func testNoPeers(t *testing.T, factory PolicyFactory) {
+	ctx := context.Background()
+	policy := factory()
+
+	// NoWait should return ErrNoPeerAvailable immediately
+	p, done, err := policy.Get(ctx, balancer.GetOptions{NoWait: true})
+	assert.Equal(t, balancer.ErrNoPeerAvailable, err)
+	assert.Nil(t, done)
+	assert.Empty(t, p.Addr())
+
+	// Canceled context should return context.Canceled
+	cctx, cancel := context.WithCancel(ctx)
+	cancel()
+
+	p, done, err = policy.Get(cctx, balancer.GetOptions{})
+	assert.Equal(t, context.Canceled, err)
+	assert.Nil(t, done)
+	assert.Empty(t, p.Addr())
+}
+
+func testSinglePeer(t *testing.T, factory PolicyFactory) {
+	ctx := context.Background()
+	policy := factory()
+
+	// Add a single peer
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// Should get the same peer repeatedly
+	for range 5 {
+		p, done, err := policy.Get(ctx, balancer.GetOptions{})
+
+		require.NoError(t, err)
+		assert.NotNil(t, done)
+		assert.Equal(t, "localhost:1", p.Addr())
+		done(nil)
+	}
+}
+
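+// testAddAndRemovePeers samples Get 50 times per phase so that, for any
+// policy with a reasonable selection spread, every live peer should show up
+// at least once.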
+func testAddAndRemovePeers(t *testing.T, factory PolicyFactory) {
+	ctx := context.Background()
+	policy := factory()
+
+	// Add initial peers
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{
+			static.Peer("localhost:1"),
+			static.Peer("localhost:2"),
+		},
+	})
+
+	// Verify we can get peers
+	seen := make(map[string]bool)
+
+	for range 50 {
+		p, done, err := policy.Get(ctx, balancer.GetOptions{})
+
+		require.NoError(t, err)
+		assert.NotNil(t, done)
+
+		seen[p.Addr()] = true
+
+		done(nil)
+	}
+
+	assert.Contains(t, seen, "localhost:1")
+	assert.Contains(t, seen, "localhost:2")
+
+	// Update peers: remove localhost:1, add localhost:3
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:3")},
+		Deletions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// Verify we only see localhost:2 and localhost:3
+	seen = make(map[string]bool)
+
+	for range 50 {
+		p, done, err := policy.Get(ctx, balancer.GetOptions{})
+
+		require.NoError(t, err)
+		assert.NotNil(t, done)
+
+		seen[p.Addr()] = true
+
+		done(nil)
+	}
+
+	assert.Contains(t, seen, "localhost:2")
+	assert.Contains(t, seen, "localhost:3")
+	assert.NotContains(t, seen, "localhost:1")
+}
+
+func testRemoveAllPeers(t *testing.T, factory PolicyFactory) {
+	ctx := context.Background()
+	policy := factory()
+
+	// Add a peer
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// Verify we can get it
+	p, done, err := policy.Get(ctx, balancer.GetOptions{})
+	require.NoError(t, err)
+	assert.Equal(t, "localhost:1", p.Addr())
+	done(nil)
+
+	// Remove the peer
+	policy.Update(resolver.Update[static.Peer]{
+		Deletions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// NoWait should return ErrNoPeerAvailable
+	p, done, err = policy.Get(ctx, balancer.GetOptions{NoWait: true})
+	assert.Equal(t, balancer.ErrNoPeerAvailable, err)
+	assert.Nil(t, done)
+	assert.Empty(t, p.Addr())
+}
+
+func testAddPeersToEmpty(t *testing.T, factory PolicyFactory) {
+	ctx := t.Context()
+	policy := factory()
+
+	// started is closed just before the goroutine enters Get's select,
+	// giving us a deterministic signal instead of a sleep.
+	started := make(chan struct{})
+	done := make(chan struct{})
+
+	go func() {
+		close(started)
+
+		p, doneFn, err := policy.Get(ctx, balancer.GetOptions{})
+		assert.NoError(t, err)
+		assert.NotNil(t, doneFn)
+		assert.NotEmpty(t, p.Addr())
+		doneFn(nil)
+		close(done)
+	}()
+
+	// Wait until the goroutine has been scheduled, then yield once more so
+	// it reaches the select inside Get before we call Update.
+	<-started
+	runtime.Gosched()
+
+	// Add a peer — this closes the notifier and unblocks Get.
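+	// Even if the goroutine has not yet reached the select, this is not racy:
+	// a Policy must check its peer list before deciding to wait, so a Get
+	// that starts after this Update simply returns immediately.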
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	select {
+	case <-done:
+		// success
+	case <-ctx.Done():
+		t.Fatal("Get() did not unblock after adding peers")
+	}
+}
diff --git a/discovery/balancer/dialer.go b/discovery/balancer/dialer.go
index 10fd6ba..f892779 100644
--- a/discovery/balancer/dialer.go
+++ b/discovery/balancer/dialer.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"sync"
+	"sync/atomic"
 
 	"github.com/upfluence/errors"
 
@@ -16,16 +17,23 @@ type Dialer[T peer.Peer] struct {
 	Dialer  *net.Dialer
 	Options GetOptions
 
+	dialerOnce sync.Once
+	netDialer  *net.Dialer
+
 	mu  sync.Mutex
 	lds map[string]*localDialer[T]
 }
 
 func (d *Dialer[T]) dialer() *net.Dialer {
-	if d.Dialer == nil {
-		d.Dialer = &net.Dialer{}
-	}
+	d.dialerOnce.Do(func() {
+		if d.Dialer != nil {
+			d.netDialer = d.Dialer
+		} else {
+			d.netDialer = &net.Dialer{}
+		}
+	})
 
-	return d.Dialer
+	return d.netDialer
 }
 
 func (d *Dialer[T]) Dial(network, addr string) (net.Conn, error) {
@@ -71,7 +79,7 @@ type localDialer[T peer.Peer] struct {
 	d *Dialer[T]
 	b Balancer[T]
 
-	opened bool
+	opened atomic.Bool
 
 	sf syncutil.Singleflight[struct{}]
 }
@@ -82,13 +90,13 @@ func (ld *localDialer[T]) open(ctx context.Context) (struct{}, error) {
 		}
 	}
 
-	ld.opened = true
+	ld.opened.Store(true)
 
 	return struct{}{}, nil
 }
 
 func (ld *localDialer[T]) dial(ctx context.Context, network string) (net.Conn, error) {
-	if !ld.opened {
+	if !ld.opened.Load() {
 		if _, _, err := ld.sf.Do(ctx, ld.open); err != nil {
 			return nil, err
 		}
@@ -118,13 +126,14 @@ func (ld *localDialer[T]) close() error {
 type doneCloserConn struct {
 	net.Conn
 
-	done func(error)
+	doneOnce sync.Once
+	done     func(error)
 }
 
 func (dcc *doneCloserConn) Close() error {
 	err := dcc.Conn.Close()
 
-	dcc.done(err)
+	dcc.doneOnce.Do(func() { dcc.done(err) })
 
 	return err
 }
diff --git a/discovery/balancer/dialer_test.go b/discovery/balancer/dialer_test.go
index 98f9eca..d6cbc51 100644
--- a/discovery/balancer/dialer_test.go
+++ b/discovery/balancer/dialer_test.go
@@ -2,12 +2,13 @@ package balancer_test
 
 import (
 	"io"
-	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/upfluence/pkg/v2/discovery/balancer"
 	"github.com/upfluence/pkg/v2/discovery/balancer/roundrobin"
 	"github.com/upfluence/pkg/v2/discovery/resolver/static"
@@ -16,13 +17,13 @@ import (
 func TestDialer(t *testing.T) {
 	s1 := httptest.NewServer(
 		http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
-			io.WriteString(rw, "s1")
+			io.WriteString(rw, "s1") //nolint:errcheck
 		}),
 	)
 
 	s2 := httptest.NewServer(
 		http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
-			io.WriteString(rw, "s2")
+			io.WriteString(rw, "s2") //nolint:errcheck
 		}),
 	)
 
@@ -33,7 +34,7 @@ func TestDialer(t *testing.T) {
 	d := balancer.Dialer[static.Peer]{
 		Builder: balancer.ResolverBuilder[static.Peer]{
 			Builder:      r,
-			BalancerFunc: roundrobin.BalancerFunc[static.Peer],
+			BalancerFunc: balancer.PolicyBalancerFunc(roundrobin.NewPolicy[static.Peer]()),
 		},
 	}
 
@@ -44,14 +45,14 @@ func TestDialer(t *testing.T) {
 	cl := http.Client{Transport: &http.Transport{DialContext: d.DialContext}}
 
 	req, err := http.NewRequest("GET", "http://example.com/foo", http.NoBody)
-	assert.Nil(t, err)
+	require.NoError(t, err)
 
 	for _, want := range []string{"s1", "s2", "s1"} {
 		resp, err := cl.Do(req)
-		assert.Nil(t, err)
+		require.NoError(t, err)
 
-		buf, err := ioutil.ReadAll(resp.Body)
-		assert.Nil(t, err)
+		buf, err := io.ReadAll(resp.Body)
+		require.NoError(t, err)
 
 		resp.Body.Close()
 
 		assert.Equal(t, want, string(buf))
diff --git a/discovery/balancer/policy.go b/discovery/balancer/policy.go
new file mode 100644
index 0000000..9a31c90
--- /dev/null
+++ b/discovery/balancer/policy.go
@@ -0,0 +1,74 @@
+package balancer
+
+import (
+	"context"
+	"sync"
+
+	"github.com/upfluence/pkg/v2/discovery/peer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/log"
+)
+
+type Policy[T peer.Peer] interface {
+	Get(context.Context, GetOptions) (T, func(error), error)
+	Update(resolver.Update[T])
+}
+
+type policyBalancer[S, T peer.Peer] struct {
+	*resolver.Puller[S]
+	Policy[T]
+
+	mu      sync.Mutex
+	peers   map[string]T
+	builder func(S) (T, error)
+}
+
+func WrapPolicy[S, T peer.Peer](r resolver.Resolver[S], p Policy[T], build func(S) (T, error)) Balancer[T] {
+	b := &policyBalancer[S, T]{
+		Policy:  p,
+		peers:   make(map[string]T),
+		builder: build,
+	}
+
+	b.Puller = &resolver.Puller[S]{
+		Resolver:   r,
+		UpdateFunc: b.handleUpdate,
+	}
+
+	return b
+}
+
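+// PolicyBalancerFunc is typically handed to a ResolverBuilder. Illustrative
+// usage, assuming a static resolver and the roundrobin policy (see
+// dialer_test.go for the actual wiring):
+//
+//	bf := PolicyBalancerFunc(roundrobin.NewPolicy[static.Peer]())
+//	b := bf(static.NewResolverFromStrings([]string{"localhost:1"}))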
+func PolicyBalancerFunc[T peer.Peer](p Policy[T]) func(resolver.Resolver[T]) Balancer[T] {
+	return func(r resolver.Resolver[T]) Balancer[T] {
+		return WrapPolicy(r, p, func(t T) (T, error) { return t, nil })
+	}
+}
+
+func (b *policyBalancer[S, T]) handleUpdate(u resolver.Update[S]) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+
+	var mapped resolver.Update[T]
+
+	for _, sp := range u.Additions {
+		tp, err := b.builder(sp)
+		if err != nil {
+			log.WithError(err).Errorf("balancer: failed to build peer for %q, skipping", sp.Addr())
+
+			continue
+		}
+
+		b.peers[sp.Addr()] = tp
+		mapped.Additions = append(mapped.Additions, tp)
+	}
+
+	for _, sp := range u.Deletions {
+		if tp, ok := b.peers[sp.Addr()]; ok {
+			delete(b.peers, sp.Addr())
+
+			mapped.Deletions = append(mapped.Deletions, tp)
+		}
+	}
+
+	b.Update(mapped)
+}
diff --git a/discovery/balancer/policy_test.go b/discovery/balancer/policy_test.go
new file mode 100644
index 0000000..3d80aaa
--- /dev/null
+++ b/discovery/balancer/policy_test.go
@@ -0,0 +1,216 @@
+package balancer_test
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+	"github.com/upfluence/pkg/v2/metadata"
+)
+
+type wrappedPeer struct {
+	addr string
+}
+
+func (p wrappedPeer) Addr() string                { return p.addr }
+func (p wrappedPeer) Metadata() metadata.Metadata { return nil }
+
+// testPolicy is a Policy implementation used in tests that records all
+// updates it receives and exposes waitForUpdate to synchronise on them
+// without relying on time.Sleep.
+type testPolicy struct {
+	mu      sync.Mutex
+	peers   []wrappedPeer
+	updates []resolver.Update[wrappedPeer]
+	// updatec is closed on each Update call and immediately replaced so
+	// that waitForUpdate returns exactly once per Update.
+	updatec chan struct{}
+}
+
+func newTestPolicy() *testPolicy {
+	return &testPolicy{updatec: make(chan struct{})}
+}
+
+func (p *testPolicy) Get(ctx context.Context, opts balancer.GetOptions) (wrappedPeer, func(error), error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	if len(p.peers) == 0 {
+		if opts.NoWait {
+			return wrappedPeer{}, nil, balancer.ErrNoPeerAvailable
+		}
+
+		<-ctx.Done()
+
+		return wrappedPeer{}, nil, ctx.Err()
+	}
+
+	peer := p.peers[0]
+	p.peers = p.peers[1:]
+
+	return peer, func(error) {}, nil
+}
+
+func (p *testPolicy) Update(u resolver.Update[wrappedPeer]) {
+	p.mu.Lock()
+
+	p.updates = append(p.updates, u)
+	p.peers = append(p.peers, u.Additions...)
+
+	// Close the current channel and replace it before releasing the lock so
+	// waitForUpdate never misses a notification.
+	ch := p.updatec
+	p.updatec = make(chan struct{})
+
+	p.mu.Unlock()
+
+	close(ch)
+}
+
+// waitForUpdate blocks until the next Update call completes or ctx is done.
+func (p *testPolicy) waitForUpdate(ctx context.Context) error {
+	p.mu.Lock()
+	ch := p.updatec
+	p.mu.Unlock()
+
+	select {
+	case <-ch:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+func (p *testPolicy) getUpdates() []resolver.Update[wrappedPeer] {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	updates := make([]resolver.Update[wrappedPeer], len(p.updates))
+	copy(updates, p.updates)
+
+	return updates
+}
+
+func TestWrapPolicyMapsAdditions(t *testing.T) {
+	ctx := t.Context()
+	r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"})
+	policy := newTestPolicy()
+
+	b := balancer.WrapPolicy(
+		r,
+		policy,
+		func(sp static.Peer) (wrappedPeer, error) {
+			return wrappedPeer{addr: sp.Addr()}, nil
+		},
+	)
+
+	require.NoError(t, b.Open(ctx))
+	require.NoError(t, policy.waitForUpdate(ctx))
+
+	updates := policy.getUpdates()
+	assert.Len(t, updates, 1)
+	assert.ElementsMatch(t, []wrappedPeer{
+		{addr: "localhost:1"},
+		{addr: "localhost:2"},
+	}, updates[0].Additions)
+	assert.Empty(t, updates[0].Deletions)
+
+	peer, done, err := b.Get(ctx, balancer.GetOptions{})
+	require.NoError(t, err)
+	assert.Equal(t, "localhost:1", peer.Addr())
+	done(nil)
+
+	assert.NoError(t, b.Close())
+}
+
+func TestWrapPolicyMapsDeletions(t *testing.T) {
+	ctx := t.Context()
+	r := static.NewResolverFromStrings([]string{"localhost:1", "localhost:2"})
+	policy := newTestPolicy()
+
+	b := balancer.WrapPolicy(
+		r,
+		policy,
+		func(sp static.Peer) (wrappedPeer, error) {
+			return wrappedPeer{addr: sp.Addr()}, nil
+		},
+	)
+
+	require.NoError(t, b.Open(ctx))
+	require.NoError(t, policy.waitForUpdate(ctx)) // initial peers
+
+	r.UpdatePeers(static.PeersFromStrings("localhost:2", "localhost:3"))
+	require.NoError(t, policy.waitForUpdate(ctx)) // diff update
+
+	updates := policy.getUpdates()
+	assert.Len(t, updates, 2)
+
+	assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:3"}}, updates[1].Additions)
+	assert.ElementsMatch(t, []wrappedPeer{{addr: "localhost:1"}}, updates[1].Deletions)
+
+	assert.NoError(t, b.Close())
+}
+
+func TestWrapPolicySkipsFailedBuilds(t *testing.T) {
+	ctx := t.Context()
+	r := static.NewResolverFromStrings([]string{"localhost:1", "fail:2", "localhost:3"})
+	policy := newTestPolicy()
+
+	b := balancer.WrapPolicy(
+		r,
+		policy,
+		func(sp static.Peer) (wrappedPeer, error) {
+			if sp.Addr() == "fail:2" {
+				return wrappedPeer{}, errors.New("build failed")
+			}
+
+			return wrappedPeer{addr: sp.Addr()}, nil
+		},
+	)
+
+	require.NoError(t, b.Open(ctx))
+	require.NoError(t, policy.waitForUpdate(ctx))
+
+	updates := policy.getUpdates()
+	assert.Len(t, updates, 1)
+	assert.ElementsMatch(t, []wrappedPeer{
+		{addr: "localhost:1"},
+		{addr: "localhost:3"},
+	}, updates[0].Additions)
+
+	assert.NoError(t, b.Close())
+}
+
+func TestWrapPolicyDelegatesGetToPolicy(t *testing.T) {
+	ctx := t.Context()
+	r := static.NewResolverFromStrings([]string{"localhost:1"})
+	policy := newTestPolicy()
+
+	b := balancer.WrapPolicy(
+		r,
+		policy,
+		func(sp static.Peer) (wrappedPeer, error) {
+			return wrappedPeer{addr: sp.Addr()}, nil
+		},
+	)
+
+	require.NoError(t, b.Open(ctx))
+	require.NoError(t, policy.waitForUpdate(ctx))
+
+	peer, done, err := b.Get(ctx, balancer.GetOptions{})
+	require.NoError(t, err)
+	assert.Equal(t, "localhost:1", peer.Addr())
+	done(nil)
+
+	_, _, err = b.Get(ctx, balancer.GetOptions{NoWait: true})
+	assert.Equal(t, balancer.ErrNoPeerAvailable, err)
+
+	assert.NoError(t, b.Close())
+}
diff --git a/discovery/balancer/random/balancer.go b/discovery/balancer/random/balancer.go
index 4fa383a..fcfbdb8 100644
--- a/discovery/balancer/random/balancer.go
+++ b/discovery/balancer/random/balancer.go
@@ -2,132 +2,49 @@ package random
 
 import (
 	"context"
-	"fmt"
-	"math/rand"
+	"math/rand/v2"
 	"sync"
-	"time"
 
 	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/balancer/simple"
 	"github.com/upfluence/pkg/v2/discovery/peer"
 	"github.com/upfluence/pkg/v2/discovery/resolver"
 )
 
+// Rand is the interface consumed by the random picker, allowing tests to
+// inject a deterministic source.
 type Rand interface {
-	Intn(int) int
+	IntN(int) int
 }
 
-type Balancer[T peer.Peer] struct {
-	*resolver.Puller[T]
-
-	peers   []T
-	peersMu *sync.RWMutex
-	rand    Rand
-
-	notifier chan interface{}
-	closeFn  func()
+// lockedRand wraps a *rand.Rand with a mutex so it is safe for concurrent use.
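+// A *rand.Rand from math/rand/v2 is not synchronized on its own, and Pick
+// may be invoked from concurrent Get calls, hence the lock.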
+type lockedRand struct {
+	mu sync.Mutex
+	r  *rand.Rand
 }
 
-func NewBalancer[T peer.Peer](r resolver.Resolver[T]) *Balancer[T] {
-	var b = &Balancer[T]{
-		rand:     rand.New(rand.NewSource(time.Now().UnixNano())),
-		peersMu:  &sync.RWMutex{},
-		notifier: make(chan interface{}),
-	}
-
-	b.Puller, b.closeFn = resolver.NewPuller(r, b.updatePeers)
-
-	return b
-}
+func (lr *lockedRand) IntN(n int) int {
+	lr.mu.Lock()
+	v := lr.r.IntN(n)
+	lr.mu.Unlock()
 
-func (b *Balancer[T]) String() string {
-	return fmt.Sprintf("loadbalancer/random [resolver: %v]", b.Puller)
+	return v
 }
 
-func (b *Balancer[T]) updatePeers(u resolver.Update[T]) {
-	b.peersMu.Lock()
-	defer b.peersMu.Unlock()
-
-	var newPeers = make(map[T]interface{})
-
-	for _, p := range b.peers {
-		var found bool
-
-		for _, peer := range u.Deletions {
-			if p.Addr() == peer.Addr() {
-				found = true
-			}
-		}
-
-		if !found {
-			newPeers[p] = nil
-		}
-	}
-
-	for _, p := range u.Additions {
-		var found bool
-
-		for _, peer := range b.peers {
-			if p.Addr() == peer.Addr() {
-				found = true
-			}
-		}
-
-		if !found {
-			newPeers[p] = nil
-		}
-	}
-
-	var (
-		i     = 0
-		empty = len(b.peers) == 0
-	)
-
-	b.peers = make([]T, len(newPeers))
-
-	for p, _ := range newPeers {
-		b.peers[i] = p
-		i++
-	}
-
-	if empty && (len(b.peers) > 0) {
-		for {
-			select {
-			case <-b.notifier:
-			default:
-				return
-			}
-		}
-	}
+type picker[T peer.Peer] struct {
+	rand Rand
 }
 
-func (b *Balancer[T]) hasPeers() bool {
-	b.peersMu.RLock()
-	defer b.peersMu.RUnlock()
-
-	return len(b.peers) > 0
+func (p *picker[T]) Pick(_ context.Context, peers []T) (T, error) {
+	return peers[p.rand.IntN(len(peers))], nil
 }
 
-func (b *Balancer[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) {
-	var zero T
-
-	if !b.hasPeers() {
-		if opts.NoWait {
-			return zero, nil, balancer.ErrNoPeerAvailable
-		}
-
-		select {
-		case b.notifier <- true:
-		case <-ctx.Done():
-			return zero, nil, ctx.Err()
-		}
-	}
-
-	b.peersMu.RLock()
-	defer b.peersMu.RUnlock()
-
-	return b.peers[b.rand.Intn(len(b.peers))], func(error) {}, nil
+func NewPolicy[T peer.Peer]() balancer.Policy[T] {
+	return simple.NewPolicy(&picker[T]{
+		rand: &lockedRand{r: rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64()))}, //nolint:gosec
+	})
 }
 
-func (b *Balancer[T]) Close() error {
-	b.closeFn()
-	return nil
+func NewBalancer[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] {
+	return balancer.WrapPolicy(r, NewPolicy[T](), func(p T) (T, error) { return p, nil })
 }
diff --git a/discovery/balancer/random/balancer_test.go b/discovery/balancer/random/balancer_test.go
new file mode 100644
index 0000000..4287856
--- /dev/null
+++ b/discovery/balancer/random/balancer_test.go
@@ -0,0 +1,48 @@
+package random
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/balancer/balancertest"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+)
+
+func TestPolicy(t *testing.T) {
+	balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] {
+		return NewPolicy[static.Peer]()
+	})
+}
+
+func TestBalancerWithPeers(t *testing.T) {
+	ctx := context.Background()
+	b := NewBalancer(
+		static.NewResolverFromStrings([]string{"localhost:0", "localhost:1", "localhost:2"}),
+	)
+
+	err := b.Open(ctx)
+	require.NoError(t, err)
+
+	seen := make(map[string]int)
+
+	for range 100 {
+		p, done, err := b.Get(ctx, balancer.GetOptions{})
+
+		require.NoError(t, err)
+		assert.NotEmpty(t, p.Addr())
+		seen[p.Addr()]++
+
+		done(nil)
+	}
+
+	// With 100 requests across 3 peers, all should be selected at least once
+	assert.Contains(t, seen, "localhost:0")
+	assert.Contains(t, seen, "localhost:1")
+	assert.Contains(t, seen, "localhost:2")
+
+	b.Close()
+}
diff --git a/discovery/balancer/roundrobin/balancer.go b/discovery/balancer/roundrobin/balancer.go
index 23dbe11..ff44406 100644
--- a/discovery/balancer/roundrobin/balancer.go
+++ b/discovery/balancer/roundrobin/balancer.go
@@ -1,122 +1,29 @@
 package roundrobin
 
 import (
-	"container/ring"
 	"context"
-	"fmt"
-	"sync"
+	"sync/atomic"
 
 	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/balancer/simple"
 	"github.com/upfluence/pkg/v2/discovery/peer"
 	"github.com/upfluence/pkg/v2/discovery/resolver"
 )
 
-type Balancer[T peer.Peer] struct {
-	resolver.Puller[T]
-
-	addrs  map[string]*ring.Ring
-	ring   *ring.Ring
-	ringMu sync.RWMutex
-
-	notifier chan struct{}
-}
-
-func BalancerFunc[T peer.Peer](r resolver.Resolver[T]) balancer.Balancer[T] {
-	return NewBalancer[T](r)
+type picker[T peer.Peer] struct {
+	index atomic.Uint64
 }
 
-func NewBalancer[T peer.Peer](r resolver.Resolver[T]) *Balancer[T] {
-	var b = Balancer[T]{
-		addrs:    make(map[string]*ring.Ring),
-		notifier: make(chan struct{}),
-	}
-
-	b.Puller = resolver.Puller[T]{Resolver: r, UpdateFunc: b.updateRing}
-
-	return &b
-}
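+// Pick advances a shared atomic counter and maps it onto the peer slice it
+// is handed. When the peer set changes size, the modulo mapping may skip or
+// revisit a peer for one step, an accepted round-robin trade-off.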
"github.com/stretchr/testify/require" "github.com/upfluence/pkg/v2/discovery/balancer" + "github.com/upfluence/pkg/v2/discovery/balancer/balancertest" "github.com/upfluence/pkg/v2/discovery/resolver/static" ) -func TestBalanceEmpty(t *testing.T) { - ctx := context.Background() - b := NewBalancer(&static.Resolver[static.Peer]{}) - - p, done, err := b.Get(ctx, balancer.GetOptions{NoWait: true}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, balancer.ErrNoPeerAvailable, err) - - cctx, cancel := context.WithCancel(ctx) - cancel() - - p, done, err = b.Get(cctx, balancer.GetOptions{}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, err, context.Canceled) - - err = b.Close() - assert.NoError(t, err) - - p, done, err = b.Get(ctx, balancer.GetOptions{}) - - assert.Empty(t, p.Addr()) - assert.Nil(t, done) - assert.Equal(t, err, context.Canceled) +func TestPolicy(t *testing.T) { + balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] { + return NewPolicy[static.Peer]() + }) } -func TestBalanceWithPerrs(t *testing.T) { +func TestBalanceRoundRobinOrder(t *testing.T) { ctx := context.Background() b := NewBalancer( - static.NewResolverFromStrings([]string{"localhost:0", "localhost:1"}), + static.NewResolverFromStrings([]string{"localhost:0", "localhost:1", "localhost:2"}), ) err := b.Open(ctx) - assert.Nil(t, err) + require.NoError(t, err) + + // Collect the first full cycle. The initial Get blocks until the Puller + // goroutine delivers peers, so this also acts as the synchronisation point. + // The policy preserves insertion order, so subsequent cycles are identical. + var firstCycle [3]string + + for i := range 3 { + p, done, err := b.Get(ctx, balancer.GetOptions{}) + + require.NoError(t, err) - p, done, err := b.Get(ctx, balancer.GetOptions{}) - done(nil) + firstCycle[i] = p.Addr() - assert.Nil(t, err) - assert.Equal(t, "localhost:0", p.Addr()) + done(nil) + } - p, done, err = b.Get(ctx, balancer.GetOptions{}) - done(nil) + // All three peers must appear in the first cycle. + assert.ElementsMatch(t, []string{"localhost:0", "localhost:1", "localhost:2"}, firstCycle[:]) - assert.Nil(t, err) - assert.Equal(t, "localhost:1", p.Addr()) + // Subsequent cycles must repeat in the exact same order. + for range 2 { + for i := range 3 { + p, done, err := b.Get(ctx, balancer.GetOptions{}) - p, done, err = b.Get(ctx, balancer.GetOptions{}) - done(nil) + require.NoError(t, err) + assert.Equal(t, firstCycle[i], p.Addr()) - assert.Nil(t, err) - assert.Equal(t, "localhost:0", p.Addr()) + done(nil) + } + } b.Close() } diff --git a/discovery/balancer/simple/doc.go b/discovery/balancer/simple/doc.go new file mode 100644 index 0000000..daec72f --- /dev/null +++ b/discovery/balancer/simple/doc.go @@ -0,0 +1,15 @@ +// Package simple provides a flexible balancer policy that delegates peer +// selection to a Picker implementation. 
+//
+// The Picker interface allows for custom balancing strategies without
+// implementing the full Policy interface:
+//
+//	type CustomPicker struct{}
+//
+//	func (p *CustomPicker) Pick(ctx context.Context, peers []peer.Peer) (peer.Peer, error) {
+//		// Custom selection logic
+//		return peers[0], nil
+//	}
+//
+//	policy := simple.NewPolicy[peer.Peer](new(CustomPicker))
+package simple
diff --git a/discovery/balancer/simple/policy.go b/discovery/balancer/simple/policy.go
new file mode 100644
index 0000000..b2320a3
--- /dev/null
+++ b/discovery/balancer/simple/policy.go
@@ -0,0 +1,103 @@
+package simple
+
+import (
+	"context"
+	"slices"
+	"sync"
+
+	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/peer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+)
+
+// Picker selects a peer from a list of available peers.
+type Picker[T peer.Peer] interface {
+	// Pick selects a peer from the provided list. The implementation should
+	// return an error if no suitable peer can be selected.
+	Pick(context.Context, []T) (T, error)
+}
+
+type policy[T peer.Peer] struct {
+	picker Picker[T]
+
+	mu    sync.RWMutex
+	peers []T
+	// peerSet is an auxiliary index for O(1) presence checks; the
+	// authoritative ordering lives in peers.
+	peerSet  map[string]struct{}
+	notifier chan struct{}
+}
+
+// NewPolicy creates a new Policy that delegates peer selection to the provided Picker.
+func NewPolicy[T peer.Peer](picker Picker[T]) balancer.Policy[T] {
+	return &policy[T]{
+		picker:   picker,
+		peerSet:  make(map[string]struct{}),
+		notifier: make(chan struct{}),
+	}
+}
+
+func (p *policy[T]) Update(u resolver.Update[T]) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	wasEmpty := len(p.peers) == 0
+
+	// Apply deletions: remove from both the set and the ordered slice.
+	for _, peer := range u.Deletions {
+		addr := peer.Addr()
+
+		if _, ok := p.peerSet[addr]; ok {
+			delete(p.peerSet, addr)
+			p.peers = slices.DeleteFunc(p.peers, func(q T) bool {
+				return q.Addr() == addr
+			})
+		}
+	}
+
+	// Apply additions: append new peers while preserving existing order.
+	for _, peer := range u.Additions {
+		addr := peer.Addr()
+
+		if _, ok := p.peerSet[addr]; !ok {
+			p.peerSet[addr] = struct{}{}
+			p.peers = append(p.peers, peer)
+		}
+	}
+
+	if wasEmpty && len(p.peers) > 0 {
+		close(p.notifier)
+		p.notifier = make(chan struct{})
+	}
+}
+
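+// Get returns a peer chosen by the Picker from a snapshot of the current
+// peer list. With no peers it fails fast under NoWait; otherwise it blocks
+// until an Update adds peers or the context is cancelled.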
+func (p *policy[T]) Get(ctx context.Context, opts balancer.GetOptions) (T, func(error), error) {
+	var zero T
+
+	for {
+		p.mu.RLock()
+		notifier := p.notifier
+		peers := slices.Clone(p.peers)
+		p.mu.RUnlock()
+
+		if len(peers) == 0 {
+			if opts.NoWait {
+				return zero, nil, balancer.ErrNoPeerAvailable
+			}
+
+			select {
+			case <-notifier:
+				// Notifier closed, peers may be available now.
+				// Loop back to re-check in case peers were removed
+				// between the notification and now.
+				continue
+			case <-ctx.Done():
+				return zero, nil, ctx.Err()
+			}
+		}
+
+		peer, err := p.picker.Pick(ctx, peers)
+
+		return peer, func(error) {}, err
+	}
+}
diff --git a/discovery/balancer/simple/policy_test.go b/discovery/balancer/simple/policy_test.go
new file mode 100644
index 0000000..902f790
--- /dev/null
+++ b/discovery/balancer/simple/policy_test.go
@@ -0,0 +1,166 @@
+package simple
+
+import (
+	"context"
+	"errors"
+	"runtime"
+	"sync"
+	"testing"
+
+	"github.com/upfluence/pkg/v2/discovery/balancer"
+	"github.com/upfluence/pkg/v2/discovery/balancer/balancertest"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+)
+
+func TestPolicy(t *testing.T) {
+	balancertest.PolicyTest(t, func() balancer.Policy[static.Peer] {
+		return NewPolicy(&roundRobinPicker{})
+	})
+}
+
+// roundRobinPicker cycles through peers sequentially
+type roundRobinPicker struct {
+	mu    sync.Mutex
+	index int
+}
+
+func (p *roundRobinPicker) Pick(_ context.Context, peers []static.Peer) (static.Peer, error) {
+	if len(peers) == 0 {
+		return static.Peer(""), errors.New("no peers available")
+	}
+
+	p.mu.Lock()
+	idx := p.index % len(peers)
+	p.index++
+	p.mu.Unlock()
+
+	return peers[idx], nil
+}
+
+func TestPickerDelegation(t *testing.T) {
+	policy := NewPolicy(&lastPicker{})
+
+	peers := []static.Peer{
+		static.Peer("peer1"),
+		static.Peer("peer2"),
+		static.Peer("peer3"),
+	}
+
+	// Update is synchronous: peers are visible to Get immediately after return.
+	policy.Update(resolver.Update[static.Peer]{Additions: peers})
+
+	peer, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true})
+	if err != nil {
+		t.Fatalf("Get() error = %v", err)
+	}
+
+	found := false
+
+	for _, p := range peers {
+		if peer.Addr() == p.Addr() {
+			found = true
+
+			break
+		}
+	}
+
+	if !found {
+		t.Errorf("Get() returned unexpected peer %q", peer.Addr())
+	}
+}
+
+func TestPickerError(t *testing.T) {
+	policy := NewPolicy(&errorPicker{})
+
+	peers := []static.Peer{static.Peer("peer1")}
+	// Update is synchronous: no sleep needed before Get.
+	policy.Update(resolver.Update[static.Peer]{Additions: peers})
+
+	_, _, err := policy.Get(context.Background(), balancer.GetOptions{NoWait: true})
+	if err == nil {
+		t.Fatal("Get() expected error, got nil")
+	}
+
+	if err.Error() != "picker error" {
+		t.Errorf("Get() error = %q, want %q", err.Error(), "picker error")
+	}
+}
+
+// lastPicker picks the last peer from the list
+type lastPicker struct{}
+
+func (p *lastPicker) Pick(_ context.Context, peers []static.Peer) (static.Peer, error) {
+	if len(peers) == 0 {
+		return static.Peer(""), errors.New("no peers available")
+	}
+
+	return peers[len(peers)-1], nil
+}
+
+// errorPicker always returns an error
+type errorPicker struct{}
+
+func (p *errorPicker) Pick(_ context.Context, _ []static.Peer) (static.Peer, error) {
+	return static.Peer(""), errors.New("picker error")
+}
+
+// TestRaceConditionPeerRemovalAfterWakeup verifies that Get() retries when
+// peers are removed between the notifier being closed and Get() re-reading
+// the peer list.
+func TestRaceConditionPeerRemovalAfterWakeup(t *testing.T) {
+	policy := NewPolicy(&roundRobinPicker{})
+	ctx := t.Context()
+
+	// started is closed just before the goroutine enters Get's select, giving
+	// the test a deterministic signal to proceed rather than sleeping.
+	started := make(chan struct{})
+	gotPeer := make(chan static.Peer, 1)
+	gotErr := make(chan error, 1)
+
+	go func() {
+		close(started)
+
+		peer, _, err := policy.Get(ctx, balancer.GetOptions{})
+		if err != nil {
+			gotErr <- err
+		} else {
+			gotPeer <- peer
+		}
+	}()
+
+	// Wait until the goroutine has been scheduled, then yield once more so it
+	// reaches the select inside Get before we send any updates.
+	<-started
+	runtime.Gosched()
+
+	// Add a peer — closes the notifier, waking the goroutine.
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// Immediately remove it to exercise the retry path: the goroutine may have
+	// woken from the notifier but not yet re-read the peer slice.
+	policy.Update(resolver.Update[static.Peer]{
+		Deletions: []static.Peer{static.Peer("localhost:1")},
+	})
+
+	// Re-add a peer so Get() can eventually succeed.
+	// No sleep needed: Update() is synchronous and the goroutine will pick up
+	// the new notifier on its next iteration.
+	policy.Update(resolver.Update[static.Peer]{
+		Additions: []static.Peer{static.Peer("localhost:2")},
+	})
+
+	select {
+	case peer := <-gotPeer:
+		addr := peer.Addr()
+		if addr != "localhost:1" && addr != "localhost:2" {
+			t.Errorf("Get() returned unexpected peer %q, want localhost:1 or localhost:2", addr)
+		}
+	case err := <-gotErr:
+		t.Fatalf("Get() returned error %v, expected to retry and succeed", err)
+	case <-ctx.Done():
+		t.Fatal("Get() did not complete before test deadline")
+	}
+}
diff --git a/discovery/resolver/filter/resolver.go b/discovery/resolver/filter/resolver.go
new file mode 100644
index 0000000..0527785
--- /dev/null
+++ b/discovery/resolver/filter/resolver.go
@@ -0,0 +1,99 @@
+package filter
+
+import (
+	"context"
+
+	"github.com/upfluence/pkg/v2/discovery/peer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+)
+
+type filterResolver[T peer.Peer] struct {
+	inner resolver.Resolver[T]
+	allow func(T) bool
+}
+
+func WrapResolver[T peer.Peer](r resolver.Resolver[T], allow func(T) bool) resolver.Resolver[T] {
+	return &filterResolver[T]{inner: r, allow: allow}
+}
+
+func (r *filterResolver[T]) Open(ctx context.Context) error {
+	return r.inner.Open(ctx)
+}
+
+func (r *filterResolver[T]) Close() error {
+	return r.inner.Close()
+}
+
+func (r *filterResolver[T]) Resolve() resolver.Watcher[T] {
+	return &watcher[T]{inner: r.inner.Resolve(), allow: r.allow}
+}
+
+type watcher[T peer.Peer] struct {
+	inner    resolver.Watcher[T]
+	allow    func(T) bool
+	admitted map[string]struct{}
+}
+
+func (w *watcher[T]) Close() error {
+	return w.inner.Close()
+}
+
+func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) {
+	// When NoWait is requested we only attempt one read from the inner watcher.
+	// If that read is empty (ErrNoUpdates) or entirely filtered out, we return
+	// ErrNoUpdates immediately. We must not keep consuming inner updates in a
+	// loop under NoWait — doing so would silently drain the inner channel and
+	// discard real updates the caller may want to inspect later.
+	if opts.NoWait {
+		u, err := w.inner.Next(ctx, opts)
+		if err != nil {
+			return resolver.Update[T]{}, err
+		}
+
+		filtered := w.filter(u)
+		if len(filtered.Additions) == 0 && len(filtered.Deletions) == 0 {
+			return resolver.Update[T]{}, resolver.ErrNoUpdates
+		}
+
+		return filtered, nil
+	}
+
+	// Blocking path: loop until we get at least one non-empty filtered update.
+	for {
+		u, err := w.inner.Next(ctx, opts)
+		if err != nil {
+			return resolver.Update[T]{}, err
+		}
+
+		filtered := w.filter(u)
+		if len(filtered.Additions) == 0 && len(filtered.Deletions) == 0 {
+			continue
+		}
+
+		return filtered, nil
+	}
+}
+
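+// filter applies the allow predicate to additions and records admitted
+// addresses so that deletions are only forwarded for peers the caller has
+// actually seen.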
+func (w *watcher[T]) filter(u resolver.Update[T]) resolver.Update[T] {
+	var out resolver.Update[T]
+
+	for _, p := range u.Additions {
+		if w.allow(p) {
+			if w.admitted == nil {
+				w.admitted = make(map[string]struct{})
+			}
+
+			w.admitted[p.Addr()] = struct{}{}
+			out.Additions = append(out.Additions, p)
+		}
+	}
+
+	for _, p := range u.Deletions {
+		if _, ok := w.admitted[p.Addr()]; ok {
+			delete(w.admitted, p.Addr())
+			out.Deletions = append(out.Deletions, p)
+		}
+	}
+
+	return out
+}
diff --git a/discovery/resolver/filter/resolver_test.go b/discovery/resolver/filter/resolver_test.go
new file mode 100644
index 0000000..7d6c192
--- /dev/null
+++ b/discovery/resolver/filter/resolver_test.go
@@ -0,0 +1,102 @@
+package filter
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/resolvertest"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+)
+
+func TestResolver(t *testing.T) {
+	resolvertest.ResolverTest(t, func(peers []static.Peer) (resolver.Resolver[static.Peer], []static.Peer) {
+		inner := static.NewResolver(peers)
+		r := WrapResolver(inner, func(p static.Peer) bool {
+			return strings.HasPrefix(p.Addr(), "localhost")
+		})
+
+		// Filter the expected peers
+		var expected []static.Peer
+
+		for _, p := range peers {
+			if strings.HasPrefix(p.Addr(), "localhost") {
+				expected = append(expected, p)
+			}
+		}
+
+		return r, expected
+	}, static.PeersFromStrings)
+}
+
+func TestFilterResolverAllowsAdditions(t *testing.T) {
+	ctx := context.Background()
+
+	inner := static.NewResolverFromStrings([]string{"allow:1", "deny:1"})
+	r := WrapResolver(inner, func(p static.Peer) bool {
+		return strings.HasPrefix(p.Addr(), "allow")
+	})
+
+	w := r.Resolve()
+
+	u, err := w.Next(ctx, resolver.ResolveOptions{})
+
+	require.NoError(t, err)
+	assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions)
+	assert.Empty(t, u.Deletions)
+
+	u, err = w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+
+	assert.Equal(t, resolver.ErrNoUpdates, err)
+	assert.Equal(t, resolver.Update[static.Peer]{}, u)
+}
+
+func TestFilterResolverTracksDeletions(t *testing.T) {
+	ctx := context.Background()
+
+	inner := static.NewResolverFromStrings([]string{"allow:1", "deny:1"})
+	r := WrapResolver(inner, func(p static.Peer) bool {
+		return strings.HasPrefix(p.Addr(), "allow")
+	})
+
+	w := r.Resolve()
+
+	_, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+
+	inner.UpdatePeers(static.PeersFromStrings("allow:2", "deny:2"))
+
+	u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+
+	require.NoError(t, err)
+	assert.ElementsMatch(t, []static.Peer{static.Peer("allow:2")}, u.Additions)
+	assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Deletions)
+}
+
+func TestFilterResolverNoWaitFilteredEmpty(t *testing.T) {
+	ctx := context.Background()
+
+	inner := static.NewResolverFromStrings([]string{"deny:1"})
+	r := WrapResolver(inner, func(p static.Peer) bool {
+		return strings.HasPrefix(p.Addr(), "allow")
+	})
+
+	w := r.Resolve()
+
+	u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+
+	assert.Equal(t, resolver.ErrNoUpdates, err)
+	assert.Equal(t, resolver.Update[static.Peer]{}, u)
+
+	inner.UpdatePeers(static.PeersFromStrings("allow:1"))
+
+	u, err = w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+
+	require.NoError(t, err)
+	assert.ElementsMatch(t, []static.Peer{static.Peer("allow:1")}, u.Additions)
+	assert.Empty(t, u.Deletions)
+}
diff --git a/discovery/resolver/puller.go b/discovery/resolver/puller.go
index bbcaa7c..386daed 100644
--- a/discovery/resolver/puller.go
+++ b/discovery/resolver/puller.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"sync"
+	"sync/atomic"
 
 	"github.com/upfluence/errors"
 
@@ -18,25 +19,34 @@ type Puller[T peer.Peer] struct {
 	Monitor closer.Monitor
 	NoWait  bool
 
-	openErr  error
-	openOnce sync.Once
+	openErr   error
+	openOnce  sync.Once
+	closeOnce sync.Once
+	opened    atomic.Bool
 }
 
-func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func()) {
+func NewPuller[T peer.Peer](r Resolver[T], fn func(Update[T])) (*Puller[T], func() error) {
 	var p = &Puller[T]{
 		Resolver:   r,
 		UpdateFunc: fn,
 	}
 
-	return p, func() { p.Close() }
+	return p, p.Close
 }
 
 func (p *Puller[T]) Close() error {
-	return errors.Combine(p.Monitor.Close(), p.Resolver.Close())
+	var err error
+
+	p.closeOnce.Do(func() {
+		p.opened.Store(false)
+		err = errors.Combine(p.Monitor.Close(), p.Resolver.Close())
+	})
+
+	return err
 }
 
 func (p *Puller[T]) IsOpen() bool {
-	return p.openErr == nil && p.openOnce != sync.Once{}
+	return p.opened.Load()
 }
 
 func (p *Puller[T]) String() string {
@@ -48,6 +58,7 @@ func (p *Puller[T]) Open(ctx context.Context) error {
 		p.openErr = p.Resolver.Open(ctx)
 
 		if p.openErr == nil {
+			p.opened.Store(true)
 			p.Monitor.Run(p.pull)
 		}
 	})
@@ -65,14 +76,25 @@ func (p *Puller[T]) pull(ctx context.Context) {
 	)
 
 	for {
+		err = nil
 		w = p.Resolver.Resolve()
 
 		for err == nil {
 			u, err = w.Next(ctx, ResolveOptions{NoWait: noWait})
 
-			if err == nil || err == ErrNoUpdates {
+			if err == nil {
 				noWait = false
+				p.UpdateFunc(u)
+
+				continue
+			}
+
+			if errors.Is(err, ErrNoUpdates) {
+				// No update available right now; reset noWait and keep going
+				// without calling UpdateFunc with a meaningless empty update.
+				noWait = false
+				err = nil
+
+				continue
+			}
 		}
 	}
diff --git a/discovery/resolver/puller_test.go b/discovery/resolver/puller_test.go
new file mode 100644
index 0000000..acb5862
--- /dev/null
+++ b/discovery/resolver/puller_test.go
@@ -0,0 +1,99 @@
+package resolver_test
+
+import (
+	"context"
+	"errors"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/static"
+)
+
+type transientErrorStaticResolver struct {
+	inner    *static.Resolver[static.Peer]
+	failed   atomic.Bool
+	failWith error
+}
+
+func newTransientErrorStaticResolver(peers []string, err error) *transientErrorStaticResolver {
+	return &transientErrorStaticResolver{
+		inner:    static.NewResolverFromStrings(peers),
+		failWith: err,
+	}
+}
+
+func (r *transientErrorStaticResolver) Open(ctx context.Context) error {
+	return r.inner.Open(ctx)
+}
+
+func (r *transientErrorStaticResolver) Close() error {
+	return r.inner.Close()
+}
+
+func (r *transientErrorStaticResolver) Resolve() resolver.Watcher[static.Peer] {
+	return &transientErrorWatcher{
+		inner:    r.inner.Resolve(),
+		failed:   &r.failed,
+		failWith: r.failWith,
+	}
+}
+
+func (r *transientErrorStaticResolver) UpdatePeers(peers []static.Peer) {
+	r.inner.UpdatePeers(peers)
+}
+
+type transientErrorWatcher struct {
+	inner    resolver.Watcher[static.Peer]
+	failed   *atomic.Bool
+	failWith error
+}
+
+func (w *transientErrorWatcher) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[static.Peer], error) {
+	if w.failed.CompareAndSwap(false, true) {
+		return resolver.Update[static.Peer]{}, w.failWith
+	}
+
+	return w.inner.Next(ctx, opts)
+}
+
+func (w *transientErrorWatcher) Close() error {
+	return w.inner.Close()
+}
+
+func TestPullerRecoversAfterWatcherError(t *testing.T) {
+	r := newTransientErrorStaticResolver([]string{"allow:1"}, errors.New("boom"))
+
+	updates := make(chan resolver.Update[static.Peer], 1)
+	p, _ := resolver.NewPuller(r, func(u resolver.Update[static.Peer]) {
+		if len(u.Additions) > 0 || len(u.Deletions) > 0 {
+			updates <- u
+		}
+	})
+
+	require.NoError(t, p.Open(context.Background()))
+
+	select {
+	case u := <-updates:
+		assert.Equal(t, []static.Peer{static.Peer("allow:1")}, u.Additions)
+	case <-t.Context().Done():
+		t.Fatal("timed out waiting for update after watcher error")
+	}
+
+	assert.NoError(t, p.Close())
+}
+
+func TestPullerIsOpenTracksClose(t *testing.T) {
+	r := static.NewResolverFromStrings([]string{"allow:1"})
+	p, _ := resolver.NewPuller(r, func(resolver.Update[static.Peer]) {})
+	assert.False(t, p.IsOpen())
+
+	require.NoError(t, p.Open(context.Background()))
+	assert.True(t, p.IsOpen())
+
+	require.NoError(t, p.Close())
+	assert.False(t, p.IsOpen())
+}
diff --git a/discovery/resolver/resolvertest/doc.go b/discovery/resolver/resolvertest/doc.go
new file mode 100644
index 0000000..f4334e7
--- /dev/null
+++ b/discovery/resolver/resolvertest/doc.go
@@ -0,0 +1,18 @@
+// Package resolvertest provides reusable test helpers for resolver implementations.
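+//
+// Factories receive the seed peers and return both the resolver under test
+// and the peers the suite should expect back, which lets wrapping resolvers
+// (such as filter) be exercised by the same suite.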
+//
+// The ResolverTest function runs a comprehensive test suite that validates common
+// resolver behaviors including:
+// - Empty resolver (no initial peers)
+// - Initial peers resolution
+// - NoWait behavior with no updates
+// - Context cancellation
+// - Watcher closure
+//
+// Example usage:
+//
+//	func TestMyResolver(t *testing.T) {
+//		resolvertest.ResolverTest(t, func(peers []static.Peer) (resolver.Resolver[static.Peer], []static.Peer) {
+//			return myresolver.New(peers), peers
+//		}, static.PeersFromStrings)
+//	}
+package resolvertest
diff --git a/discovery/resolver/resolvertest/resolver.go b/discovery/resolver/resolvertest/resolver.go
new file mode 100644
index 0000000..6e640d3
--- /dev/null
+++ b/discovery/resolver/resolvertest/resolver.go
@@ -0,0 +1,154 @@
+package resolvertest
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/upfluence/pkg/v2/discovery/peer"
+	"github.com/upfluence/pkg/v2/discovery/resolver"
+)
+
+// ResolverFactory creates a new Resolver instance for testing.
+// The argument is the peers that should pre-exist.
+// Returns the resolver and the peers that should actually be returned (after filtering/transformation).
+type ResolverFactory[T peer.Peer] func([]T) (resolver.Resolver[T], []T)
+
+// ResolverTest runs a comprehensive test suite for a Resolver implementation.
+// It tests common resolver behaviors including:
+// - Empty resolver (no initial peers)
+// - Initial peers resolution
+// - NoWait behavior with no updates
+// - Context cancellation
+// - Watcher closure
+func ResolverTest[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) {
+	for _, tt := range []struct {
+		name string
+		test func(*testing.T, ResolverFactory[T], func(...string) []T)
+	}{
+		{"NoSeeds", testNoSeeds[T]},
+		{"InitialPeers", testInitialPeers[T]},
+		{"NoWaitNoUpdates", testNoWaitNoUpdates[T]},
+		{"ContextCancellation", testContextCancellation[T]},
+		{"WatcherClose", testWatcherClose[T]},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.test(t, factory, makePeers)
+		})
+	}
+}
+
+func testNoSeeds[T peer.Peer](t *testing.T, factory ResolverFactory[T], _ func(...string) []T) {
+	ctx := context.Background()
+	r, expected := factory(nil)
+
+	assert.Empty(t, expected)
+
+	require.NoError(t, r.Open(ctx))
+	defer r.Close()
+
+	w := r.Resolve()
+	defer w.Close()
+
+	// NoWait should return ErrNoUpdates immediately when no peers exist
+	u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+	assert.Equal(t, resolver.ErrNoUpdates, err)
+	assert.Empty(t, u.Additions)
+	assert.Empty(t, u.Deletions)
+}
+
+func testInitialPeers[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) {
+	ctx := context.Background()
+	peers := makePeers("localhost:1", "localhost:2")
+	r, expected := factory(peers)
+
+	require.NoError(t, r.Open(ctx))
+	defer r.Close()
+
+	w := r.Resolve()
+	defer w.Close()
+
+	// Should get initial peers
+	u, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+	assert.ElementsMatch(t, expected, u.Additions)
+	assert.Empty(t, u.Deletions)
+}
+
+func testNoWaitNoUpdates[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) {
+	ctx := context.Background()
+	peers := makePeers("localhost:1")
+	r, _ := factory(peers)
+
+	require.NoError(t, r.Open(ctx))
+	defer r.Close()
+
+	w := r.Resolve()
+	defer w.Close()
+
+	// Consume initial update
+	_, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+
+	// NoWait should return ErrNoUpdates when no updates are available
+	u, err := w.Next(ctx, resolver.ResolveOptions{NoWait: true})
+	assert.Equal(t, resolver.ErrNoUpdates, err)
+	assert.Empty(t, u.Additions)
+	assert.Empty(t, u.Deletions)
+}
+
+func testContextCancellation[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) {
+	ctx := context.Background()
+	peers := makePeers("localhost:1")
+	r, _ := factory(peers)
+
+	require.NoError(t, r.Open(ctx))
+	defer r.Close()
+
+	w := r.Resolve()
+	defer w.Close()
+
+	// Consume initial update
+	_, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+
+	// Cancel context and try to wait for updates
+	cctx, cancel := context.WithCancel(ctx)
+	cancel()
+
+	u, err := w.Next(cctx, resolver.ResolveOptions{})
+	assert.Equal(t, context.Canceled, err)
+	assert.Empty(t, u.Additions)
+	assert.Empty(t, u.Deletions)
+}
+
+func testWatcherClose[T peer.Peer](t *testing.T, factory ResolverFactory[T], makePeers func(...string) []T) {
+	ctx := context.Background()
+	peers := makePeers("localhost:1")
+	r, _ := factory(peers)
+
+	require.NoError(t, r.Open(ctx))
+	defer r.Close()
+
+	w := r.Resolve()
+
+	// Consume initial update
+	_, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+
+	// Close the watcher
+	err = w.Close()
+	require.NoError(t, err)
+
+	// Trying to get next update should fail after close
+	// Give a small timeout to avoid blocking forever
+	cctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
+	defer cancel()
+
+	_, err = w.Next(cctx, resolver.ResolveOptions{})
+	assert.Error(t, err)
+	// Should be either context.Canceled, context.DeadlineExceeded, or a close error
+}
diff --git a/discovery/resolver/static/resolver.go b/discovery/resolver/static/resolver.go
index cdf8b4c..f75c06b 100644
--- a/discovery/resolver/static/resolver.go
+++ b/discovery/resolver/static/resolver.go
@@ -3,7 +3,7 @@ package static
 import (
 	"context"
 	"fmt"
-	"sync/atomic"
+	"sync"
 
 	"github.com/upfluence/pkg/v2/closer"
 	"github.com/upfluence/pkg/v2/discovery/peer"
@@ -14,7 +14,7 @@ import (
 type Builder[T peer.Peer] map[string][]T
 
 func (b Builder[T]) Build(n string) resolver.Resolver[T] {
-	return &Resolver[T]{Peers: b[n]}
+	return NewResolver(b[n])
 }
 
 func PeersFromStrings(addrs ...string) []Peer {
@@ -30,7 +30,10 @@ type Resolver[T peer.Peer] struct {
 	closer.Monitor
 
-	Peers []T
+	mu  sync.Mutex
+	chs []chan resolver.Update[T]
+
+	peers []T
 }
 
 type Peer string
@@ -43,13 +46,112 @@ func NewResolverFromStrings(addrs []string) *Resolver[Peer] {
 }
 
 func NewResolver[T peer.Peer](peers []T) *Resolver[T] {
-	return &Resolver[T]{Peers: peers}
+	return &Resolver[T]{peers: peers}
+}
+
+func (r *Resolver[T]) Peers() []T {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	out := make([]T, len(r.peers))
+	copy(out, r.peers)
+
+	return out
+}
+
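+// UpdatePeers replaces the peer set, computes the additions/deletions diff
+// against the previous set, and fans the resulting update out to every
+// subscribed watcher.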
+func (r *Resolver[T]) UpdatePeers(peers []T) {
+	r.mu.Lock()
+
+	old := make(map[string]T, len(r.peers))
+	for _, p := range r.peers {
+		old[p.Addr()] = p
+	}
+
+	cur := make(map[string]T, len(peers))
+	for _, p := range peers {
+		cur[p.Addr()] = p
+	}
+
+	var u resolver.Update[T]
+
+	for addr, p := range cur {
+		if _, ok := old[addr]; !ok {
+			u.Additions = append(u.Additions, p)
+		}
+	}
+
+	for addr, p := range old {
+		if _, ok := cur[addr]; !ok {
+			u.Deletions = append(u.Deletions, p)
+		}
+	}
+
+	r.peers = peers
+
+	if len(u.Additions) == 0 && len(u.Deletions) == 0 {
+		r.mu.Unlock()
+
+		return
+	}
+
+	chs := make([]chan resolver.Update[T], len(r.chs))
+	copy(chs, r.chs)
+	r.mu.Unlock()
+
+	for _, ch := range chs {
+		// Non-blocking send: if the watcher's buffer is already full (slow
+		// consumer), merge the pending update with the new one so no update
+		// is silently dropped and the caller is never blocked.
+		select {
+		case ch <- u:
+		default:
+			select {
+			case pending := <-ch:
+				pending.Additions = append(pending.Additions, u.Additions...)
+
+				pending.Deletions = append(pending.Deletions, u.Deletions...)
+				ch <- pending
+			default:
+				// Channel was drained concurrently; just send the new update.
+				ch <- u
+			}
+		}
+	}
+}
+
+func (r *Resolver[T]) subscribe() (chan resolver.Update[T], []T) {
+	ch := make(chan resolver.Update[T], 1)
+
+	r.mu.Lock()
+	r.chs = append(r.chs, ch)
+
+	peers := make([]T, len(r.peers))
+	copy(peers, r.peers)
+	r.mu.Unlock()
+
+	return ch, peers
+}
+
+func (r *Resolver[T]) unsubscribe(ch chan resolver.Update[T]) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	for i, c := range r.chs {
+		if c == ch {
+			r.chs = append(r.chs[:i], r.chs[i+1:]...)
+
+			return
+		}
+	}
 }
 
 func (r *Resolver[T]) String() string {
-	var addrs = make([]string, len(r.Peers))
+	r.mu.Lock()
+	defer r.mu.Unlock()
 
-	for i, peer := range r.Peers {
+	var addrs = make([]string, len(r.peers))
+
+	for i, peer := range r.peers {
 		addrs[i] = peer.Addr()
 	}
 
@@ -67,31 +169,62 @@ func (r *Resolver[T]) Resolve() resolver.Watcher[T] {
 type watcher[T peer.Peer] struct {
 	closer.Monitor
 
-	r       *Resolver[T]
-	initial int32
+	r *Resolver[T]
+
+	mu      sync.Mutex
+	ch      chan resolver.Update[T]
+	initial bool
 }
 
 func (w *watcher[T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) {
-	ok := atomic.CompareAndSwapInt32(&w.initial, 0, 1)
-
-	if opts.NoWait && (!ok || len(w.r.Peers) == 0) {
-		return resolver.Update[T]{}, resolver.ErrNoUpdates
+	w.mu.Lock()
+	if !w.initial {
+		w.initial = true
+		ch, peers := w.r.subscribe()
+		w.ch = ch
+		w.mu.Unlock()
+
+		if len(peers) > 0 {
+			return resolver.Update[T]{Additions: peers}, nil
+		}
+	} else {
+		w.mu.Unlock()
 	}
 
-	if ok && len(w.r.Peers) > 0 {
-		return resolver.Update[T]{Additions: w.r.Peers}, nil
+	if opts.NoWait {
+		select {
+		case u := <-w.ch:
+			return u, nil
+		default:
+			return resolver.Update[T]{}, resolver.ErrNoUpdates
+		}
 	}
 
 	wctx := w.Context()
 	rctx := w.r.Context()
 
 	select {
+	case u := <-w.ch:
+		return u, nil
 	case <-ctx.Done():
 		return resolver.Update[T]{}, ctx.Err()
 	case <-wctx.Done():
 		return resolver.Update[T]{}, wctx.Err()
 	case <-rctx.Done():
 		w.Close()
-		return resolver.Update[T]{}, wctx.Err()
+
+		return resolver.Update[T]{}, rctx.Err()
 	}
 }
+
+func (w *watcher[T]) Close() error {
+	w.mu.Lock()
+	ch := w.ch
+	w.mu.Unlock()
+
+	if ch != nil {
+		w.r.unsubscribe(ch)
+	}
+
+	return w.Monitor.Close()
+}
diff --git a/discovery/resolver/static/resolver_test.go b/discovery/resolver/static/resolver_test.go
index f1c0030..ecfeb23 100644
--- a/discovery/resolver/static/resolver_test.go
+++ b/discovery/resolver/static/resolver_test.go
@@ -5,9 +5,18 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/upfluence/pkg/v2/discovery/resolver"
+	"github.com/upfluence/pkg/v2/discovery/resolver/resolvertest"
 )
 
+func TestResolver(t *testing.T) {
+	resolvertest.ResolverTest(t, func(peers []Peer) (resolver.Resolver[Peer], []Peer) {
+		return NewResolver(peers), peers
+	}, PeersFromStrings)
+}
+
 func TestResolve(t *testing.T) {
 	ctx := context.Background()
 	r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"})
@@ -16,7 +25,7 @@ func TestResolve(t *testing.T) {
*testing.T) { u, err := w.Next(ctx, resolver.ResolveOptions{}) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal( t, resolver.Update[Peer]{ @@ -34,10 +43,148 @@ func TestResolve(t *testing.T) { assert.Equal(t, resolver.Update[Peer]{}, u) err = w.Close() - assert.Nil(t, err) + require.NoError(t, err) u, err = w.Next(ctx, resolver.ResolveOptions{}) assert.Equal(t, err, context.Canceled) assert.Equal(t, resolver.Update[Peer]{}, u) } + +func TestPeers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + peers := r.Peers() + + assert.Equal(t, []Peer{Peer("localhost:1"), Peer("localhost:2")}, peers) + + // Mutating the returned slice should not affect the resolver + peers[0] = Peer("localhost:99") + + assert.Equal(t, []Peer{Peer("localhost:1"), Peer("localhost:2")}, r.Peers()) +} + +func TestPeersEmpty(t *testing.T) { + r := NewResolver[Peer](nil) + assert.Equal(t, []Peer{}, r.Peers()) +} + +func TestUpdatePeers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + w := r.Resolve() + + // Consume initial update + u, err := w.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + assert.Equal( + t, + resolver.Update[Peer]{ + Additions: []Peer{ + Peer("localhost:1"), + Peer("localhost:2"), + }, + }, + u, + ) + + // Update peers: remove localhost:1, add localhost:3 + r.UpdatePeers(PeersFromStrings("localhost:2", "localhost:3")) + + assert.Equal( + t, + []Peer{Peer("localhost:2"), Peer("localhost:3")}, + r.Peers(), + ) + + // The watcher should receive the diff + u, err = w.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + require.NoError(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:3")}, u.Additions) + assert.ElementsMatch(t, []Peer{Peer("localhost:1")}, u.Deletions) +} + +func TestUpdatePeersNoChange(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1", "localhost:2"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + + // Update with same peers — no diff + r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) + + // Should have no update + _, err = w.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + assert.Equal(t, resolver.ErrNoUpdates, err) +} + +func TestUpdatePeersMultipleWatchers(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w1 := r.Resolve() + w2 := r.Resolve() + + // Consume initial updates + _, err := w1.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + + _, err = w2.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + + r.UpdatePeers(PeersFromStrings("localhost:1", "localhost:2")) + + u1, err := w1.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + require.NoError(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u1.Additions) + + u2, err := w2.Next(context.Background(), resolver.ResolveOptions{NoWait: true}) + require.NoError(t, err) + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u2.Additions) +} + +func TestUpdatePeersClosedWatcher(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + + // Close the watcher — it should unsubscribe + err = w.Close() + require.NoError(t, err) + + // This should not block or panic + 
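// (Close unsubscribed the watcher, so UpdatePeers no longer holds a
+	// channel for it)
+	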
r.UpdatePeers(PeersFromStrings("localhost:2")) +} + +func TestUpdatePeersBlockingWatcher(t *testing.T) { + r := NewResolverFromStrings([]string{"localhost:1"}) + + w := r.Resolve() + + // Consume initial update + _, err := w.Next(context.Background(), resolver.ResolveOptions{}) + require.NoError(t, err) + + // Start a blocking Next call in a goroutine + done := make(chan resolver.Update[Peer], 1) + + go func() { + u, err := w.Next(context.Background(), resolver.ResolveOptions{}) + assert.NoError(t, err) + + done <- u + }() + + // Update peers — should unblock the watcher + r.UpdatePeers(PeersFromStrings("localhost:2")) + + u := <-done + assert.ElementsMatch(t, []Peer{Peer("localhost:2")}, u.Additions) + assert.ElementsMatch(t, []Peer{Peer("localhost:1")}, u.Deletions) +} diff --git a/discovery/resolver/sync_resolver.go b/discovery/resolver/sync_resolver.go index 3895fb2..d718fe1 100644 --- a/discovery/resolver/sync_resolver.go +++ b/discovery/resolver/sync_resolver.go @@ -36,7 +36,7 @@ func (sr *syncResolver[T]) ResolveSync(ctx context.Context, n string) ([]T, erro lr, ok := sr.lrs[n] if !ok { - lr = &localResolver[T]{readyc: make(chan struct{})} + lr = &localResolver[T]{readyc: make(chan struct{}), noWait: sr.noWait} lr.p = &Puller[T]{ Resolver: sr.builder.Build(n), @@ -46,7 +46,7 @@ func (sr *syncResolver[T]) ResolveSync(ctx context.Context, n string) ([]T, erro if err := lr.p.Open(ctx); err != nil { sr.mu.Unlock() - close(lr.readyc) + return nil, err } @@ -65,7 +65,7 @@ func (sr *syncResolver[T]) Close() error { for _, lr := range sr.lrs { if err := lr.close(); err != nil { - errs = append(errs) + errs = append(errs, err) } } @@ -76,7 +76,8 @@ func (sr *syncResolver[T]) Close() error { } type localResolver[T peer.Peer] struct { - p *Puller[T] + p *Puller[T] + noWait bool readyOnce sync.Once readyc chan struct{} @@ -93,11 +94,7 @@ func (lr *localResolver[T]) update(u Update[T]) { } for _, p := range u.Deletions { - addr := p.Addr() - - if _, ok := lr.ps[addr]; ok { - delete(lr.ps, addr) - } + delete(lr.ps, p.Addr()) } for _, p := range u.Additions { @@ -110,7 +107,7 @@ func (lr *localResolver[T]) update(u Update[T]) { } func (lr *localResolver[T]) close() error { - return errors.Combine(lr.p.Close()) + return lr.p.Close() } func (lr *localResolver[T]) resolve(ctx context.Context) ([]T, error) { @@ -118,10 +115,21 @@ func (lr *localResolver[T]) resolve(ctx context.Context) ([]T, error) { return nil, ErrClose } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-lr.readyc: + if lr.noWait { + // With noWait, return whatever peers are available right now without + // blocking for the first update. If readyc isn't closed yet the + // background Puller hasn't delivered anything, so return empty. 
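+		// An empty, error-free result is therefore ambiguous under noWait:
+		// callers cannot distinguish "not resolved yet" from "no peers".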
+		select {
+		case <-lr.readyc:
+		default:
+			return nil, nil
+		}
+	} else {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-lr.readyc:
+		}
 	}
 
 	lr.mu.RLock()
diff --git a/discovery/resolver/sync_resolver_test.go b/discovery/resolver/sync_resolver_test.go
index 037c2fb..13a32f7 100644
--- a/discovery/resolver/sync_resolver_test.go
+++ b/discovery/resolver/sync_resolver_test.go
@@ -6,6 +6,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"github.com/upfluence/pkg/v2/discovery/resolver"
 	"github.com/upfluence/pkg/v2/discovery/resolver/static"
@@ -23,16 +24,16 @@ func TestNameResolverWithPeers(t *testing.T) {
 
 	ps, err := nr.ResolveSync(ctx, "n1")
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	assert.ElementsMatch(t, static.PeersFromStrings("foo", "bar"), ps)
 
 	ps, err = nr.ResolveSync(ctx, "n2")
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 	assert.ElementsMatch(t, static.PeersFromStrings("biz", "buz"), ps)
 
 	err = nr.Close()
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 }
 
 func TestNameResolverNoPeerNoWait(t *testing.T) {
@@ -41,11 +42,11 @@ func TestNameResolverNoPeerNoWait(t *testing.T) {
 
 	ps, err := nr.ResolveSync(ctx, "n1")
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
-	assert.ElementsMatch(t, 0, len(ps))
+	assert.Empty(t, ps)
 
 	err = nr.Close()
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 }
 
 func TestNameResolverNoPeerWait(t *testing.T) {
@@ -60,5 +61,5 @@ func TestNameResolverNoPeerWait(t *testing.T) {
-	assert.ElementsMatch(t, 0, len(ps))
+	assert.Empty(t, ps)
 
 	err = nr.Close()
 
-	assert.Nil(t, err)
+	require.NoError(t, err)
 }
diff --git a/discovery/resolver/transform/doc.go b/discovery/resolver/transform/doc.go
new file mode 100644
index 0000000..28d053d
--- /dev/null
+++ b/discovery/resolver/transform/doc.go
@@ -0,0 +1,20 @@
+// Package transform provides a resolver wrapper that transforms peers from one type to another.
+//
+// The transform resolver wraps an existing resolver and applies a transformation function
+// to convert source peers (S) into target peers (T).
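+//
+// Because a watcher re-applies the transform to deletions independently
+// of the matching additions, the function should be deterministic: the
+// same source peer must always map to a target peer with the same
+// address; otherwise consumers cannot match a deletion against the peer
+// that was previously added.
+//
+// 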
This is useful for: +// - Adding metadata or wrapping peers with additional information +// - Converting between different peer implementations +// - Applying prefixes or transformations to peer addresses +// +// Example: +// +// source := static.NewResolverFromStrings([]string{"host1:80", "host2:80"}) +// +// // Wrap peers with a prefix +// tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer { +// return wrappedPeer{addr: p.Addr(), prefix: "service-"} +// }) +// +// // The transformed resolver will emit wrappedPeer instances +// // with addresses like "service-host1:80", "service-host2:80" +package transform diff --git a/discovery/resolver/transform/resolver.go b/discovery/resolver/transform/resolver.go new file mode 100644 index 0000000..a9e6b70 --- /dev/null +++ b/discovery/resolver/transform/resolver.go @@ -0,0 +1,68 @@ +package transform + +import ( + "context" + + "github.com/upfluence/pkg/v2/discovery/peer" + "github.com/upfluence/pkg/v2/discovery/resolver" +) + +type transformResolver[S, T peer.Peer] struct { + source resolver.Resolver[S] + transform func(S) T +} + +func WrapResolver[S, T peer.Peer](r resolver.Resolver[S], fn func(S) T) resolver.Resolver[T] { + return &transformResolver[S, T]{ + source: r, + transform: fn, + } +} + +func (r *transformResolver[S, T]) Open(ctx context.Context) error { + return r.source.Open(ctx) +} + +func (r *transformResolver[S, T]) Close() error { + return r.source.Close() +} + +func (r *transformResolver[S, T]) Resolve() resolver.Watcher[T] { + return &watcher[S, T]{ + inner: r.source.Resolve(), + transform: r.transform, + } +} + +type watcher[S, T peer.Peer] struct { + inner resolver.Watcher[S] + transform func(S) T +} + +func (w *watcher[S, T]) Close() error { + return w.inner.Close() +} + +func (w *watcher[S, T]) Next(ctx context.Context, opts resolver.ResolveOptions) (resolver.Update[T], error) { + u, err := w.inner.Next(ctx, opts) + + if err != nil { + return resolver.Update[T]{}, err + } + + var result resolver.Update[T] + + result.Additions = make([]T, len(u.Additions)) + + for i, p := range u.Additions { + result.Additions[i] = w.transform(p) + } + + result.Deletions = make([]T, len(u.Deletions)) + + for i, p := range u.Deletions { + result.Deletions[i] = w.transform(p) + } + + return result, nil +} diff --git a/discovery/resolver/transform/resolver_test.go b/discovery/resolver/transform/resolver_test.go new file mode 100644 index 0000000..4a898c2 --- /dev/null +++ b/discovery/resolver/transform/resolver_test.go @@ -0,0 +1,135 @@ +package transform_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/upfluence/pkg/v2/discovery/resolver" + "github.com/upfluence/pkg/v2/discovery/resolver/resolvertest" + "github.com/upfluence/pkg/v2/discovery/resolver/static" + "github.com/upfluence/pkg/v2/discovery/resolver/transform" + "github.com/upfluence/pkg/v2/metadata" +) + +type wrappedPeer struct { + addr string + prefix string +} + +func (p wrappedPeer) Addr() string { return p.prefix + p.addr } +func (p wrappedPeer) Metadata() metadata.Metadata { return nil } + +func makeWrappedPeers(addrs ...string) []wrappedPeer { + peers := make([]wrappedPeer, len(addrs)) + for i, addr := range addrs { + peers[i] = wrappedPeer{addr: addr, prefix: "prefix-"} + } + + return peers +} + +func TestResolver(t *testing.T) { + resolvertest.ResolverTest(t, func(peers []wrappedPeer) (resolver.Resolver[wrappedPeer], []wrappedPeer) { + // Extract source peers from wrapped 
peers
+		sourcePeers := make([]static.Peer, len(peers))
+		for i, p := range peers {
+			// p.addr already holds the bare address; Addr() is what applies
+			// the prefix, so no stripping is needed here.
+			sourcePeers[i] = static.Peer(p.addr)
+		}
+
+		source := static.NewResolver(sourcePeers)
+		r := transform.WrapResolver(source, func(p static.Peer) wrappedPeer {
+			return wrappedPeer{addr: p.Addr(), prefix: "prefix-"}
+		})
+
+		return r, peers
+	}, makeWrappedPeers)
+}
+
+func TestTransformResolverTransformsUpdates(t *testing.T) {
+	ctx := context.Background()
+	source := static.NewResolverFromStrings([]string{"host1:80"})
+
+	tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer {
+		return wrappedPeer{addr: p.Addr(), prefix: "transformed-"}
+	})
+
+	require.NoError(t, tr.Open(ctx))
+	defer tr.Close()
+
+	w := tr.Resolve()
+	defer w.Close()
+
+	// Get initial peers
+	u, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+	assert.Len(t, u.Additions, 1)
+	assert.Equal(t, "transformed-host1:80", u.Additions[0].Addr())
+
+	// Update peers
+	source.UpdatePeers(static.PeersFromStrings("host2:80", "host3:80"))
+
+	u, err = w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+
+	// Should have additions and deletions
+	assert.Len(t, u.Additions, 2)
+	assert.Len(t, u.Deletions, 1)
+
+	additionAddrs := []string{u.Additions[0].Addr(), u.Additions[1].Addr()}
+	assert.Contains(t, additionAddrs, "transformed-host2:80")
+	assert.Contains(t, additionAddrs, "transformed-host3:80")
+
+	assert.Equal(t, "transformed-host1:80", u.Deletions[0].Addr())
+}
+
+func TestTransformResolverMultipleWatchers(t *testing.T) {
+	ctx := context.Background()
+	source := static.NewResolverFromStrings([]string{"host1:80"})
+
+	tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer {
+		return wrappedPeer{addr: p.Addr(), prefix: "watcher-"}
+	})
+
+	require.NoError(t, tr.Open(ctx))
+	defer tr.Close()
+
+	w1 := tr.Resolve()
+	defer w1.Close()
+
+	w2 := tr.Resolve()
+	defer w2.Close()
+
+	// Both watchers should get initial peers
+	u1, err := w1.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+	assert.Len(t, u1.Additions, 1)
+	assert.Equal(t, "watcher-host1:80", u1.Additions[0].Addr())
+
+	u2, err := w2.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+	assert.Len(t, u2.Additions, 1)
+	assert.Equal(t, "watcher-host1:80", u2.Additions[0].Addr())
+
+	// Update peers — UpdatePeers sends synchronously into the watchers'
+	// buffered channels, so the update is available for NoWait reads immediately.
+	source.UpdatePeers(static.PeersFromStrings("host2:80"))
+
+	// Both watchers should receive the update
+	u1, err = w1.Next(ctx, resolver.ResolveOptions{NoWait: true})
+	require.NoError(t, err)
+	assert.Len(t, u1.Additions, 1)
+	assert.Equal(t, "watcher-host2:80", u1.Additions[0].Addr())
+	assert.Len(t, u1.Deletions, 1)
+	assert.Equal(t, "watcher-host1:80", u1.Deletions[0].Addr())
+
+	u2, err = w2.Next(ctx, resolver.ResolveOptions{NoWait: true})
+	require.NoError(t, err)
+	assert.Len(t, u2.Additions, 1)
+	assert.Equal(t, "watcher-host2:80", u2.Additions[0].Addr())
+	assert.Len(t, u2.Deletions, 1)
+	assert.Equal(t, "watcher-host1:80", u2.Deletions[0].Addr())
+}
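+
+// Sketch of one more property implied by the static resolver's subscribe
+// snapshot (test name and scenario are illustrative additions): a watcher
+// created only after an UpdatePeers call sees the current peer set as its
+// initial update rather than the diff history.
+func TestTransformResolverLateWatcher(t *testing.T) {
+	ctx := context.Background()
+	source := static.NewResolverFromStrings([]string{"host1:80"})
+
+	tr := transform.WrapResolver(source, func(p static.Peer) wrappedPeer {
+		return wrappedPeer{addr: p.Addr(), prefix: "late-"}
+	})
+
+	require.NoError(t, tr.Open(ctx))
+	defer tr.Close()
+
+	// Replace the peer set before any watcher has subscribed.
+	source.UpdatePeers(static.PeersFromStrings("host2:80"))
+
+	w := tr.Resolve()
+	defer w.Close()
+
+	// The first Next returns the snapshot taken at subscription time.
+	u, err := w.Next(ctx, resolver.ResolveOptions{})
+	require.NoError(t, err)
+	assert.Len(t, u.Additions, 1)
+	assert.Equal(t, "late-host2:80", u.Additions[0].Addr())
+	assert.Empty(t, u.Deletions)
+}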