Skip to content
Merged
17 changes: 17 additions & 0 deletions discovery/balancer/balancertest/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Package balancertest provides testing utilities for balancer.Policy implementations.
//
// Usage:
//
// func TestMyPolicy(t *testing.T) {
// balancertest.PolicyTest(t, func() balancer.Policy[static.Peer, static.Peer] {
// return NewPolicy[static.Peer]()
// })
// }
//
// PolicyTest runs a comprehensive test suite that covers:
// - Empty policy (no peers) - tests NoWait and context cancellation behavior
// - Single peer - verifies Get() returns the same peer consistently
// - Adding and removing peers - tests dynamic peer updates
// - Removing all peers - verifies transition back to empty state
// - Adding peers to empty policy - tests that waiting Get() calls unblock
package balancertest
199 changes: 199 additions & 0 deletions discovery/balancer/balancertest/policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package balancertest

import (
"context"
"runtime"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/upfluence/pkg/v2/discovery/balancer"
"github.com/upfluence/pkg/v2/discovery/resolver"
"github.com/upfluence/pkg/v2/discovery/resolver/static"
)

// PolicyFactory creates a new Policy instance for testing.
type PolicyFactory func() balancer.Policy[static.Peer]

// PolicyTest runs a comprehensive test suite for a Policy implementation.
// It tests common policy behaviors including:
// - Empty policy (no peers)
// - Single peer
// - Multiple peers
// - Adding and removing peers
func PolicyTest(t *testing.T, factory PolicyFactory) {
for _, tt := range []struct {
name string
test func(*testing.T, PolicyFactory)
}{
{"NoPeers", testNoPeers},
{"SinglePeer", testSinglePeer},
{"AddAndRemovePeers", testAddAndRemovePeers},
{"RemoveAllPeers", testRemoveAllPeers},
{"AddPeersToEmpty", testAddPeersToEmpty},
} {
t.Run(tt.name, func(t *testing.T) {
tt.test(t, factory)
})
}
}

func testNoPeers(t *testing.T, factory PolicyFactory) {
ctx := context.Background()
policy := factory()

// NoWait should return ErrNoPeerAvailable immediately
p, done, err := policy.Get(ctx, balancer.GetOptions{NoWait: true})
assert.Equal(t, balancer.ErrNoPeerAvailable, err)
assert.Nil(t, done)
assert.Empty(t, p.Addr())

// Canceled context should return context.Canceled
cctx, cancel := context.WithCancel(ctx)
cancel()

p, done, err = policy.Get(cctx, balancer.GetOptions{})
assert.Equal(t, context.Canceled, err)
assert.Nil(t, done)
assert.Empty(t, p.Addr())
}

func testSinglePeer(t *testing.T, factory PolicyFactory) {
ctx := context.Background()
policy := factory()

// Add a single peer
policy.Update(resolver.Update[static.Peer]{
Additions: []static.Peer{static.Peer("localhost:1")},
})

// Should get the same peer repeatedly
for range 5 {
p, done, err := policy.Get(ctx, balancer.GetOptions{})

require.NoError(t, err)
assert.NotNil(t, done)
assert.Equal(t, "localhost:1", p.Addr())
done(nil)
}
}

func testAddAndRemovePeers(t *testing.T, factory PolicyFactory) {
ctx := context.Background()
policy := factory()

// Add initial peers
policy.Update(resolver.Update[static.Peer]{
Additions: []static.Peer{
static.Peer("localhost:1"),
static.Peer("localhost:2"),
},
})

// Verify we can get peers
seen := make(map[string]bool)

for range 50 {
p, done, err := policy.Get(ctx, balancer.GetOptions{})

require.NoError(t, err)
assert.NotNil(t, done)

seen[p.Addr()] = true

done(nil)
}

assert.Contains(t, seen, "localhost:1")
assert.Contains(t, seen, "localhost:2")

// Update peers: remove localhost:1, add localhost:3
policy.Update(resolver.Update[static.Peer]{
Additions: []static.Peer{static.Peer("localhost:3")},
Deletions: []static.Peer{static.Peer("localhost:1")},
})

// Verify we only see localhost:2 and localhost:3
seen = make(map[string]bool)

for range 50 {
p, done, err := policy.Get(ctx, balancer.GetOptions{})

require.NoError(t, err)
assert.NotNil(t, done)

seen[p.Addr()] = true

done(nil)
}

assert.Contains(t, seen, "localhost:2")
assert.Contains(t, seen, "localhost:3")
assert.NotContains(t, seen, "localhost:1")
}

func testRemoveAllPeers(t *testing.T, factory PolicyFactory) {
ctx := context.Background()
policy := factory()

// Add a peer
policy.Update(resolver.Update[static.Peer]{
Additions: []static.Peer{static.Peer("localhost:1")},
})

// Verify we can get it
p, done, err := policy.Get(ctx, balancer.GetOptions{})
require.NoError(t, err)
assert.Equal(t, "localhost:1", p.Addr())
done(nil)

// Remove the peer
policy.Update(resolver.Update[static.Peer]{
Deletions: []static.Peer{static.Peer("localhost:1")},
})

// NoWait should return ErrNoPeerAvailable
p, done, err = policy.Get(ctx, balancer.GetOptions{NoWait: true})
assert.Equal(t, balancer.ErrNoPeerAvailable, err)
assert.Nil(t, done)
assert.Empty(t, p.Addr())
}

func testAddPeersToEmpty(t *testing.T, factory PolicyFactory) {
ctx := t.Context()
policy := factory()

// started is closed just before the goroutine enters Get's select,
// giving us a deterministic signal instead of a sleep.
started := make(chan struct{})
done := make(chan struct{})

go func() {
close(started)

p, doneFn, err := policy.Get(ctx, balancer.GetOptions{})
assert.NoError(t, err)
assert.NotNil(t, doneFn)
assert.NotEmpty(t, p.Addr())
doneFn(nil)
close(done)
}()

// Wait until the goroutine has been scheduled, then yield once more so
// it reaches the select inside Get before we call Update.
<-started
runtime.Gosched()

// Add a peer — this closes the notifier and unblocks Get.
policy.Update(resolver.Update[static.Peer]{
Additions: []static.Peer{static.Peer("localhost:1")},
})

select {
case <-done:
// success
case <-ctx.Done():
t.Fatal("Get() did not unblock after adding peers")
}
}
27 changes: 18 additions & 9 deletions discovery/balancer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net"
"sync"
"sync/atomic"

"github.com/upfluence/errors"

Expand All @@ -16,16 +17,23 @@ type Dialer[T peer.Peer] struct {
Dialer *net.Dialer
Options GetOptions

dialerOnce sync.Once
netDialer *net.Dialer

mu sync.Mutex
lds map[string]*localDialer[T]
}

func (d *Dialer[T]) dialer() *net.Dialer {
if d.Dialer == nil {
d.Dialer = &net.Dialer{}
}
d.dialerOnce.Do(func() {
if d.Dialer != nil {
d.netDialer = d.Dialer
} else {
d.netDialer = &net.Dialer{}
}
})

return d.Dialer
return d.netDialer
}

func (d *Dialer[T]) Dial(network, addr string) (net.Conn, error) {
Expand Down Expand Up @@ -71,7 +79,7 @@ type localDialer[T peer.Peer] struct {
d *Dialer[T]
b Balancer[T]

opened bool
opened atomic.Bool
sf syncutil.Singleflight[struct{}]
}

Expand All @@ -82,13 +90,13 @@ func (ld *localDialer[T]) open(ctx context.Context) (struct{}, error) {
}
}

ld.opened = true
ld.opened.Store(true)

return struct{}{}, nil
}

func (ld *localDialer[T]) dial(ctx context.Context, network string) (net.Conn, error) {
if !ld.opened {
if !ld.opened.Load() {
if _, _, err := ld.sf.Do(ctx, ld.open); err != nil {
return nil, err
}
Expand Down Expand Up @@ -118,13 +126,14 @@ func (ld *localDialer[T]) close() error {
type doneCloserConn struct {
net.Conn

done func(error)
doneOnce sync.Once
done func(error)
}

func (dcc *doneCloserConn) Close() error {
err := dcc.Conn.Close()

dcc.done(err)
dcc.doneOnce.Do(func() { dcc.done(err) })

return err
}
17 changes: 9 additions & 8 deletions discovery/balancer/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package balancer_test

import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/upfluence/pkg/v2/discovery/balancer"
"github.com/upfluence/pkg/v2/discovery/balancer/roundrobin"
"github.com/upfluence/pkg/v2/discovery/resolver/static"
Expand All @@ -16,13 +17,13 @@ import (
func TestDialer(t *testing.T) {
s1 := httptest.NewServer(
http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
io.WriteString(rw, "s1")
io.WriteString(rw, "s1") //nolint:errcheck
}),
)

s2 := httptest.NewServer(
http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
io.WriteString(rw, "s2")
io.WriteString(rw, "s2") //nolint:errcheck
}),
)

Expand All @@ -33,7 +34,7 @@ func TestDialer(t *testing.T) {
d := balancer.Dialer[static.Peer]{
Builder: balancer.ResolverBuilder[static.Peer]{
Builder: r,
BalancerFunc: roundrobin.BalancerFunc[static.Peer],
BalancerFunc: balancer.PolicyBalancerFunc(roundrobin.NewPolicy[static.Peer]()),
},
}

Expand All @@ -44,14 +45,14 @@ func TestDialer(t *testing.T) {
cl := http.Client{Transport: &http.Transport{DialContext: d.DialContext}}

req, err := http.NewRequest("GET", "http://example.com/foo", http.NoBody)
assert.Nil(t, err)
require.NoError(t, err)

for _, want := range []string{"s1", "s2", "s1"} {
resp, err := cl.Do(req)
assert.Nil(t, err)
require.NoError(t, err)

buf, err := ioutil.ReadAll(resp.Body)
assert.Nil(t, err)
buf, err := io.ReadAll(resp.Body)
require.NoError(t, err)
resp.Body.Close()
assert.Equal(t, want, string(buf))

Expand Down
Loading
Loading