diff --git a/multinode/multi_node.go b/multinode/multi_node.go index c350953..3f8007c 100644 --- a/multinode/multi_node.go +++ b/multinode/multi_node.go @@ -164,7 +164,7 @@ func (c *MultiNode[CHAIN_ID, RPC]) start(ctx context.Context) error { } c.eng.Go(c.runLoop) - if c.leaseDuration.Seconds() > 0 && c.selectionMode != NodeSelectionModeRoundRobin { + if c.leaseDuration.Seconds() > 0 && c.selectionMode != NodeSelectionModeRoundRobin && c.selectionMode != NodeSelectionModeRandomRPC { c.lggr.Infof("The MultiNode will switch to best node every %s", c.leaseDuration.String()) c.eng.Go(c.checkLeaseLoop) } else { @@ -192,6 +192,10 @@ func (c *MultiNode[CHAIN_ID, RPC]) SelectRPC(ctx context.Context) (rpc RPC, err // selectNode returns the active Node, if it is still nodeStateAlive, otherwise it selects a new one from the NodeSelector. func (c *MultiNode[CHAIN_ID, RPC]) selectNode(ctx context.Context) (node Node[CHAIN_ID, RPC], err error) { + if c.selectionMode == NodeSelectionModeRandomRPC { + return c.awaitNodeSelection(ctx) + } + c.activeMu.RLock() node = c.activeNode c.activeMu.RUnlock() @@ -213,15 +217,26 @@ func (c *MultiNode[CHAIN_ID, RPC]) selectNode(ctx context.Context) (node Node[CH c.activeNode.UnsubscribeAllExceptAliveLoop() } + c.activeNode, err = c.awaitNodeSelection(ctx) + if err != nil { + return nil, err + } + + c.lggr.Debugw("Switched to a new active node due to prev node heath issues", "prevNode", prevNodeName, "newNode", c.activeNode.String()) + return c.activeNode, err +} + +// awaitNodeSelection blocks until nodeSelector returns a live node or all nodes +// finish initializing. Returns ErrNodeError when no live nodes are available. +func (c *MultiNode[CHAIN_ID, RPC]) awaitNodeSelection(ctx context.Context) (Node[CHAIN_ID, RPC], error) { for { - c.activeNode = c.nodeSelector.Select() - if c.activeNode != nil { - break + node := c.nodeSelector.Select() + if node != nil { + return node, nil } if slices.ContainsFunc(c.primaryNodes, func(n Node[CHAIN_ID, RPC]) bool { return n.State().isInitializing() }) { - // initial dial still in-progress - retry until done select { case <-ctx.Done(): return nil, ctx.Err() @@ -233,9 +248,6 @@ func (c *MultiNode[CHAIN_ID, RPC]) selectNode(ctx context.Context) (node Node[CH c.eng.EmitHealthErr(fmt.Errorf("no live nodes available for chain %s", c.chainID.String())) return nil, ErrNodeError } - - c.lggr.Debugw("Switched to a new active node due to prev node heath issues", "prevNode", prevNodeName, "newNode", c.activeNode.String()) - return c.activeNode, err } // LatestChainInfo - returns number of live nodes available in the pool, so we can prevent the last alive node in a pool from being marked as out-of-sync. diff --git a/multinode/multi_node_test.go b/multinode/multi_node_test.go index 89b54f1..efde024 100644 --- a/multinode/multi_node_test.go +++ b/multinode/multi_node_test.go @@ -406,6 +406,101 @@ func TestMultiNode_selectNode(t *testing.T) { }) } +func TestMultiNode_RandomRPC(t *testing.T) { + t.Parallel() + t.Run("RandomRPC disables lease check", func(t *testing.T) { + t.Parallel() + chainID := RandomID() + node := newHealthyNode(t, chainID) + lggr, observedLogs := logger.TestObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRandomRPC, + chainID: chainID, + logger: lggr, + nodes: []Node[ID, multiNodeRPCClient]{node}, + }) + servicetest.Run(t, mn) + tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled") + }) + t.Run("RandomRPC is non-sticky, calls Select on every invocation", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + chainID := RandomID() + node1 := newMockNode[ID, multiNodeRPCClient](t) + node1.On("State").Return(nodeStateAlive).Maybe() + node1.On("String").Return("node1").Maybe() + node2 := newMockNode[ID, multiNodeRPCClient](t) + node2.On("State").Return(nodeStateAlive).Maybe() + node2.On("String").Return("node2").Maybe() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRandomRPC, + chainID: chainID, + nodes: []Node[ID, multiNodeRPCClient]{node1, node2}, + }) + nodeSelector := newMockNodeSelector[ID, multiNodeRPCClient](t) + nodeSelector.On("Select").Return(node1).Once() + nodeSelector.On("Select").Return(node2).Once() + mn.nodeSelector = nodeSelector + + first, err := mn.selectNode(ctx) + require.NoError(t, err) + assert.Same(t, node1, first) + + second, err := mn.selectNode(ctx) + require.NoError(t, err) + assert.Same(t, node2, second) + }) + t.Run("RandomRPC does not unsubscribe previous node on selection", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + chainID := RandomID() + node1 := newMockNode[ID, multiNodeRPCClient](t) + node1.On("State").Return(nodeStateAlive).Maybe() + node1.On("String").Return("node1").Maybe() + node2 := newMockNode[ID, multiNodeRPCClient](t) + node2.On("State").Return(nodeStateAlive).Maybe() + node2.On("String").Return("node2").Maybe() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRandomRPC, + chainID: chainID, + nodes: []Node[ID, multiNodeRPCClient]{node1, node2}, + }) + nodeSelector := newMockNodeSelector[ID, multiNodeRPCClient](t) + nodeSelector.On("Select").Return(node1).Once() + nodeSelector.On("Select").Return(node2).Once() + mn.nodeSelector = nodeSelector + + _, err := mn.selectNode(ctx) + require.NoError(t, err) + _, err = mn.selectNode(ctx) + require.NoError(t, err) + + // UnsubscribeAllExceptAliveLoop must NOT have been called on either node. + // mockNode would fail the test if an unexpected call was made. + node1.AssertNotCalled(t, "UnsubscribeAllExceptAliveLoop") + node2.AssertNotCalled(t, "UnsubscribeAllExceptAliveLoop") + }) + t.Run("RandomRPC reports error when no nodes available", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + chainID := RandomID() + lggr, observedLogs := logger.TestObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRandomRPC, + chainID: chainID, + logger: lggr, + }) + nodeSelector := newMockNodeSelector[ID, multiNodeRPCClient](t) + nodeSelector.On("Select").Return(nil).Once() + nodeSelector.On("Name").Return("MockedNodeSelector").Once() + mn.nodeSelector = nodeSelector + node, err := mn.selectNode(ctx) + require.EqualError(t, err, ErrNodeError.Error()) + require.Nil(t, node) + tests.RequireLogMessage(t, observedLogs, "No live RPC nodes available") + }) +} + func TestMultiNode_ChainInfo(t *testing.T) { t.Parallel() type nodeParams struct { diff --git a/multinode/node_selector.go b/multinode/node_selector.go index ebf5166..603ac6c 100644 --- a/multinode/node_selector.go +++ b/multinode/node_selector.go @@ -9,6 +9,7 @@ const ( NodeSelectionModeRoundRobin = "RoundRobin" NodeSelectionModeTotalDifficulty = "TotalDifficulty" NodeSelectionModePriorityLevel = "PriorityLevel" + NodeSelectionModeRandomRPC = "RandomRPC" ) type NodeSelector[ @@ -35,6 +36,8 @@ func newNodeSelector[ return NewTotalDifficultyNodeSelector[CHAIN_ID, RPC](nodes) case NodeSelectionModePriorityLevel: return NewPriorityLevelNodeSelector[CHAIN_ID, RPC](nodes) + case NodeSelectionModeRandomRPC: + return NewRandomRPCSelector[CHAIN_ID, RPC](nodes) default: panic(fmt.Sprintf("unsupported NodeSelectionMode: %s", selectionMode)) } diff --git a/multinode/node_selector_random_rpc.go b/multinode/node_selector_random_rpc.go new file mode 100644 index 0000000..6bf45fa --- /dev/null +++ b/multinode/node_selector_random_rpc.go @@ -0,0 +1,43 @@ +package multinode + +import ( + "math/rand/v2" +) + +type randomRPCSelector[ + CHAIN_ID ID, + RPC any, +] struct { + nodes []Node[CHAIN_ID, RPC] +} + +func NewRandomRPCSelector[ + CHAIN_ID ID, + RPC any, +](nodes []Node[CHAIN_ID, RPC]) NodeSelector[CHAIN_ID, RPC] { + return &randomRPCSelector[CHAIN_ID, RPC]{ + nodes: nodes, + } +} + +func (s *randomRPCSelector[CHAIN_ID, RPC]) Select() Node[CHAIN_ID, RPC] { + var liveNodes []Node[CHAIN_ID, RPC] + for _, n := range s.nodes { + if n.State() == nodeStateAlive { + liveNodes = append(liveNodes, n) + } else { + n.UnsubscribeAllExceptAliveLoop() + } + } + + if len(liveNodes) == 0 { + return nil + } + + // #nosec G404 + return liveNodes[rand.IntN(len(liveNodes))] +} + +func (s *randomRPCSelector[CHAIN_ID, RPC]) Name() string { + return NodeSelectionModeRandomRPC +} diff --git a/multinode/node_selector_random_rpc_test.go b/multinode/node_selector_random_rpc_test.go new file mode 100644 index 0000000..4392585 --- /dev/null +++ b/multinode/node_selector_random_rpc_test.go @@ -0,0 +1,104 @@ +package multinode + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandomRPCNodeSelectorName(t *testing.T) { + selector := newNodeSelector[ID, RPCClient[ID, Head]](NodeSelectionModeRandomRPC, nil) + assert.Equal(t, NodeSelectionModeRandomRPC, selector.Name()) +} + +func TestRandomRPCNodeSelector(t *testing.T) { + t.Parallel() + + type nodeClient RPCClient[ID, Head] + var nodes []Node[ID, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[ID, nodeClient](t) + if i == 0 { + node.On("State").Return(nodeStateOutOfSync) + node.On("UnsubscribeAllExceptAliveLoop") + } else { + node.On("State").Return(nodeStateAlive) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeRandomRPC, nodes) + + // All selections should be from alive nodes only + for i := 0; i < 20; i++ { + selected := selector.Select() + assert.NotNil(t, selected) + assert.Contains(t, []Node[ID, nodeClient]{nodes[1], nodes[2]}, selected) + } +} + +func TestRandomRPCNodeSelector_None(t *testing.T) { + t.Parallel() + + type nodeClient RPCClient[ID, Head] + var nodes []Node[ID, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[ID, nodeClient](t) + if i == 0 { + node.On("State").Return(nodeStateOutOfSync) + } else { + node.On("State").Return(nodeStateUnreachable) + } + node.On("UnsubscribeAllExceptAliveLoop") + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeRandomRPC, nodes) + assert.Nil(t, selector.Select()) +} + +func TestRandomRPCNodeSelector_Distribution(t *testing.T) { + t.Parallel() + + type nodeClient RPCClient[ID, Head] + var nodes []Node[ID, nodeClient] + + const nAlive = 3 + for i := 0; i < nAlive; i++ { + node := newMockNode[ID, nodeClient](t) + node.On("State").Return(nodeStateAlive) + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeRandomRPC, nodes) + + const iterations = 1000 + counts := make(map[Node[ID, nodeClient]]int, nAlive) + for i := 0; i < iterations; i++ { + selected := selector.Select() + assert.NotNil(t, selected) + counts[selected]++ + } + + // Each node should be selected at least once with overwhelming probability + for _, n := range nodes { + assert.Positive(t, counts[n], "expected every alive node to be selected at least once") + } +} + +func TestRandomRPCNodeSelector_SingleNode(t *testing.T) { + t.Parallel() + + type nodeClient RPCClient[ID, Head] + + node := newMockNode[ID, nodeClient](t) + node.On("State").Return(nodeStateAlive) + + selector := newNodeSelector(NodeSelectionModeRandomRPC, []Node[ID, nodeClient]{node}) + + for i := 0; i < 5; i++ { + assert.Same(t, node, selector.Select()) + } +}