Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions universalClient/chains/chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,6 @@ const (
func (c *Chains) determineChainAction(cfg *uregistrytypes.ChainConfig) chainAction {
chainID := cfg.Chain

// Skip disabled chains
if cfg.Enabled == nil || (!cfg.Enabled.IsInboundEnabled && !cfg.Enabled.IsOutboundEnabled) {
c.logger.Debug().Str("chain", chainID).Msg("chain is disabled, skipping")
return chainActionSkip
}

// Check if chain exists
c.chainsMu.RLock()
_, exists := c.chains[chainID]
Expand Down Expand Up @@ -485,24 +479,73 @@ func sanitizeChainID(chainID string) string {
return result
}

// configsEqual compares two chain configurations
// configsEqual compares two chain configurations for fields relevant to the universal client
func configsEqual(a, b *uregistrytypes.ChainConfig) bool {
if a == nil || b == nil {
return a == b
}

// Handle Enabled field comparison
enabledEqual := false
if a.Enabled == nil && b.Enabled == nil {
enabledEqual = true
} else if a.Enabled != nil && b.Enabled != nil {
enabledEqual = a.Enabled.IsInboundEnabled == b.Enabled.IsInboundEnabled &&
a.Enabled.IsOutboundEnabled == b.Enabled.IsOutboundEnabled
if a.Chain != b.Chain ||
a.VmType != b.VmType ||
a.GatewayAddress != b.GatewayAddress {
return false
}

// Compare gateway methods
if !gatewayMethodsEqual(a.GatewayMethods, b.GatewayMethods) {
return false
}

// Compare vault methods
if !vaultMethodsEqual(a.VaultMethods, b.VaultMethods) {
return false
}

// Compare relevant fields
return a.Chain == b.Chain &&
a.VmType == b.VmType &&
a.GatewayAddress == b.GatewayAddress &&
enabledEqual
// Compare block confirmation
if !blockConfirmationEqual(a.BlockConfirmation, b.BlockConfirmation) {
return false
}

return true
}

func gatewayMethodsEqual(a, b []*uregistrytypes.GatewayMethods) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Name != b[i].Name ||
a[i].Identifier != b[i].Identifier ||
a[i].EventIdentifier != b[i].EventIdentifier ||
a[i].ConfirmationType != b[i].ConfirmationType {
return false
}
}
return true
}

func vaultMethodsEqual(a, b []*uregistrytypes.VaultMethods) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Name != b[i].Name ||
a[i].Identifier != b[i].Identifier ||
a[i].EventIdentifier != b[i].EventIdentifier ||
a[i].ConfirmationType != b[i].ConfirmationType {
return false
}
}
return true
}

func blockConfirmationEqual(a, b *uregistrytypes.BlockConfirmation) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return a.FastInbound == b.FastInbound &&
a.StandardInbound == b.StandardInbound
}
103 changes: 42 additions & 61 deletions universalClient/chains/chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,52 +163,71 @@ func TestConfigsEqual(t *testing.T) {
assert.False(t, configsEqual(cfg1, cfg2))
})

t.Run("different enabled state returns false", func(t *testing.T) {
t.Run("different gateway methods returns false", func(t *testing.T) {
cfg1 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
IsOutboundEnabled: true,
GatewayMethods: []*uregistrytypes.GatewayMethods{
{Name: "method1", Identifier: "0xabc"},
},
}
cfg2 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: false,
IsOutboundEnabled: true,
GatewayMethods: []*uregistrytypes.GatewayMethods{
{Name: "method1", Identifier: "0xdef"},
},
}

assert.False(t, configsEqual(cfg1, cfg2))
})

t.Run("both enabled nil returns true", func(t *testing.T) {
t.Run("different vault methods returns false", func(t *testing.T) {
cfg1 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: nil,
Chain: "chain1",
VaultMethods: []*uregistrytypes.VaultMethods{
{Name: "vault1", Identifier: "0xabc"},
},
}
cfg2 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: nil,
Chain: "chain1",
VaultMethods: nil,
}

assert.True(t, configsEqual(cfg1, cfg2))
assert.False(t, configsEqual(cfg1, cfg2))
})

t.Run("one enabled nil returns false", func(t *testing.T) {
t.Run("different block confirmation returns false", func(t *testing.T) {
cfg1 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
},
Chain: "chain1",
BlockConfirmation: &uregistrytypes.BlockConfirmation{FastInbound: 2, StandardInbound: 12},
}
cfg2 := &uregistrytypes.ChainConfig{
Chain: "chain1",
Enabled: nil,
Chain: "chain1",
BlockConfirmation: &uregistrytypes.BlockConfirmation{FastInbound: 5, StandardInbound: 12},
}

assert.False(t, configsEqual(cfg1, cfg2))
})

t.Run("same full config returns true", func(t *testing.T) {
cfg1 := &uregistrytypes.ChainConfig{
Chain: "chain1",
GatewayAddress: "0x123",
GatewayMethods: []*uregistrytypes.GatewayMethods{
{Name: "m1", Identifier: "0xabc"},
},
BlockConfirmation: &uregistrytypes.BlockConfirmation{FastInbound: 2, StandardInbound: 12},
}
cfg2 := &uregistrytypes.ChainConfig{
Chain: "chain1",
GatewayAddress: "0x123",
GatewayMethods: []*uregistrytypes.GatewayMethods{
{Name: "m1", Identifier: "0xabc"},
},
BlockConfirmation: &uregistrytypes.BlockConfirmation{FastInbound: 2, StandardInbound: 12},
}

assert.True(t, configsEqual(cfg1, cfg2))
})
}

func TestChainAction(t *testing.T) {
Expand Down Expand Up @@ -240,36 +259,10 @@ func TestDetermineChainAction(t *testing.T) {
}
chains := NewChains(nil, nil, cfg, logger)

t.Run("disabled chain returns skip", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
Enabled: nil,
}

action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionSkip, action)
})

t.Run("disabled inbound and outbound returns skip", func(t *testing.T) {
t.Run("new chain returns add", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: false,
IsOutboundEnabled: false,
},
}

action := chains.determineChainAction(chainCfg)
assert.Equal(t, chainActionSkip, action)
})

t.Run("new enabled chain returns add", func(t *testing.T) {
chainCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
IsOutboundEnabled: false,
},
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
}

action := chains.determineChainAction(chainCfg)
Expand All @@ -281,10 +274,6 @@ func TestDetermineChainAction(t *testing.T) {
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x123",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
IsOutboundEnabled: true,
},
}

// Add the chain first
Expand All @@ -308,20 +297,12 @@ func TestDetermineChainAction(t *testing.T) {
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x123",
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
IsOutboundEnabled: true,
},
}

newCfg := &uregistrytypes.ChainConfig{
Chain: "eip155:1",
VmType: uregistrytypes.VmType_EVM,
GatewayAddress: "0x456", // Different address
Enabled: &uregistrytypes.ChainEnabled{
IsInboundEnabled: true,
IsOutboundEnabled: true,
},
}

// Add the chain first
Expand Down
75 changes: 54 additions & 21 deletions universalClient/chains/evm/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"time"

ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/rs/zerolog"

"github.com/pushchain/push-chain-node/universalClient/chains/common"
Expand Down Expand Up @@ -195,11 +197,22 @@ func (c *Client) initializeComponents() error {
eventStartFrom = c.chainConfig.EventStartFrom
}

// Fetch vault address from gateway contract
fetchCtx, fetchCancel := context.WithTimeout(c.ctx, 15*time.Second)
vaultAddr, err := FetchVaultAddress(fetchCtx, c.rpcClient, ethcommon.HexToAddress(c.registryConfig.GatewayAddress))
fetchCancel()
if err != nil {
return fmt.Errorf("failed to fetch vault address from gateway: %w", err)
}
c.logger.Info().Str("vault_address", vaultAddr.Hex()).Msg("vault address fetched from gateway")

eventListener, err := NewEventListener(
c.rpcClient,
c.registryConfig.GatewayAddress,
vaultAddr.Hex(),
c.registryConfig.Chain,
c.registryConfig.GatewayMethods,
c.registryConfig.VaultMethods,
c.database,
eventPollingSeconds,
eventStartFrom,
Expand All @@ -209,6 +222,25 @@ func (c *Client) initializeComponents() error {
return fmt.Errorf("failed to create event listener: %w", err)
}
c.eventListener = eventListener

// Create txBuilder
chainIDInt, err := parseEVMChainID(c.chainIDStr)
if err != nil {
return fmt.Errorf("failed to parse chain ID for txBuilder: %w", err)
}

txBuilder, err := NewTxBuilder(
c.rpcClient,
c.chainIDStr,
chainIDInt,
c.registryConfig.GatewayAddress,
vaultAddr,
c.logger,
)
if err != nil {
return fmt.Errorf("failed to create txBuilder: %w", err)
}
c.txBuilder = txBuilder
}

// Apply defaults for all configuration values
Expand Down Expand Up @@ -236,27 +268,6 @@ func (c *Client) initializeComponents() error {
)
}

// Create txBuilder if gateway is configured
if c.registryConfig != nil && c.registryConfig.GatewayAddress != "" {
// Parse chain ID to integer
chainIDInt, err := parseEVMChainID(c.chainIDStr)
if err != nil {
return fmt.Errorf("failed to parse chain ID for txBuilder: %w", err)
}

txBuilder, err := NewTxBuilder(
c.rpcClient,
c.chainIDStr,
chainIDInt,
c.registryConfig.GatewayAddress,
c.logger,
)
if err != nil {
return fmt.Errorf("failed to create txBuilder: %w", err)
}
c.txBuilder = txBuilder
}

return nil
}

Expand Down Expand Up @@ -363,3 +374,25 @@ func parseEVMChainID(caip2 string) (int64, error) {

return chainID, nil
}

// FetchVaultAddress calls the gateway's VAULT() public getter to retrieve the vault address.
func FetchVaultAddress(ctx context.Context, rpcClient *RPCClient, gatewayAddress ethcommon.Address) (ethcommon.Address, error) {
// vaultCallSelector is the 4-byte selector for VAULT() public getter
vaultCallSelector := crypto.Keccak256([]byte("VAULT()"))[:4]

result, err := rpcClient.CallContract(ctx, gatewayAddress, vaultCallSelector, nil)
if err != nil {
return ethcommon.Address{}, fmt.Errorf("VAULT() call failed: %w", err)
}

if len(result) < 32 {
return ethcommon.Address{}, fmt.Errorf("VAULT() returned invalid data (len=%d)", len(result))
}

addr := ethcommon.BytesToAddress(result[12:32])
if addr == (ethcommon.Address{}) {
return ethcommon.Address{}, fmt.Errorf("VAULT() returned zero address")
}

return addr, nil
}
Loading
Loading