From f7aaa45a91c8f77829dc5704d37b976ead505824 Mon Sep 17 00:00:00 2001 From: RodrigoAD <15104916+RodrigoAD@users.noreply.github.com> Date: Thu, 12 Feb 2026 17:34:01 +0100 Subject: [PATCH] add canton unit tests --- .mockery.yaml | 13 + sdk/canton/chain_metadata_test.go | 367 +++++++++ sdk/canton/configurer.go | 4 +- sdk/canton/configurer_test.go | 218 ++++++ sdk/canton/encoder_test.go | 429 +++++++++++ sdk/canton/executor.go | 2 +- sdk/canton/executor_test.go | 414 ++++++++++ sdk/canton/helpers_test.go | 194 +++++ sdk/canton/inspector_test.go | 720 +++++++++++++++--- .../mocks/apiv2/command_service_client.go | 262 +++++++ .../mocks/apiv2/state_service_client.go | 336 ++++++++ sdk/mocks/logger.go | 76 ++ 12 files changed, 2907 insertions(+), 128 deletions(-) create mode 100644 sdk/canton/chain_metadata_test.go create mode 100644 sdk/canton/configurer_test.go create mode 100644 sdk/canton/encoder_test.go create mode 100644 sdk/canton/executor_test.go create mode 100644 sdk/canton/helpers_test.go create mode 100644 sdk/canton/mocks/apiv2/command_service_client.go create mode 100644 sdk/canton/mocks/apiv2/state_service_client.go create mode 100644 sdk/mocks/logger.go diff --git a/.mockery.yaml b/.mockery.yaml index d4ad3d30..fe17a2db 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -160,6 +160,19 @@ packages: config: dir: "./sdk/sui/mocks/mcmsdeployer" filename: "mcmsdeployerencoder.go" + github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2: + config: + all: false + outpkg: "mock_apiv2" + interfaces: + CommandServiceClient: + config: + dir: "./sdk/canton/mocks/apiv2" + filename: "command_service_client.go" + StateServiceClient: + config: + dir: "./sdk/canton/mocks/apiv2" + filename: "state_service_client.go" # Required to fix the following deprecation warning: # https://vektra.github.io/mockery/v2.48/deprecations/#issue-845-fix diff --git a/sdk/canton/chain_metadata_test.go b/sdk/canton/chain_metadata_test.go new file mode 100644 index 00000000..6fc87f89 --- /dev/null +++ b/sdk/canton/chain_metadata_test.go @@ -0,0 +1,367 @@ +package canton + +import ( + "encoding/json" + "testing" + + "github.com/smartcontractkit/mcms/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAdditionalFieldsMetadata_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input AdditionalFieldsMetadata + wantErr string + }{ + { + name: "valid metadata", + input: AdditionalFieldsMetadata{ + ChainId: 1, + MultisigId: "test-multisig", + PreOpCount: 0, + PostOpCount: 5, + OverridePreviousRoot: false, + }, + wantErr: "", + }, + { + name: "valid metadata with override", + input: AdditionalFieldsMetadata{ + ChainId: 123, + MultisigId: "another-multisig", + PreOpCount: 10, + PostOpCount: 20, + OverridePreviousRoot: true, + }, + wantErr: "", + }, + { + name: "valid metadata with same pre and post op count", + input: AdditionalFieldsMetadata{ + ChainId: 1, + MultisigId: "test-multisig", + PreOpCount: 5, + PostOpCount: 5, + }, + wantErr: "", + }, + { + name: "missing chainId", + input: AdditionalFieldsMetadata{ + MultisigId: "test-multisig", + PreOpCount: 0, + PostOpCount: 5, + }, + wantErr: "chainId is required", + }, + { + name: "missing multisigId", + input: AdditionalFieldsMetadata{ + ChainId: 1, + PreOpCount: 0, + PostOpCount: 5, + }, + wantErr: "multisigId is required", + }, + { + name: "postOpCount less than preOpCount", + input: AdditionalFieldsMetadata{ + ChainId: 1, + MultisigId: "test-multisig", + PreOpCount: 10, + PostOpCount: 5, + }, + wantErr: "postOpCount must be >= preOpCount", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.input.Validate() + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestNewChainMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + preOpCount uint64 + postOpCount uint64 + chainId int64 + multisigId string + mcmsContractID string + overridePreviousRoot bool + wantErr string + }{ + { + name: "valid metadata", + preOpCount: 0, + postOpCount: 5, + chainId: 1, + multisigId: "test-multisig", + mcmsContractID: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + wantErr: "", + }, + { + name: "valid metadata with override", + preOpCount: 10, + postOpCount: 20, + chainId: 123, + multisigId: "another-multisig", + mcmsContractID: "11f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + overridePreviousRoot: true, + wantErr: "", + }, + { + name: "missing mcmsContractID", + preOpCount: 0, + postOpCount: 5, + chainId: 1, + multisigId: "test-multisig", + mcmsContractID: "", + wantErr: "MCMS contract ID is required", + }, + { + name: "missing chainId", + preOpCount: 0, + postOpCount: 5, + chainId: 0, + multisigId: "test-multisig", + mcmsContractID: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + wantErr: "chainId is required", + }, + { + name: "missing multisigId", + preOpCount: 0, + postOpCount: 5, + chainId: 1, + multisigId: "", + mcmsContractID: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + wantErr: "multisigId is required", + }, + { + name: "postOpCount less than preOpCount", + preOpCount: 10, + postOpCount: 5, + chainId: 1, + multisigId: "test-multisig", + mcmsContractID: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + wantErr: "postOpCount must be >= preOpCount", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := NewChainMetadata( + tt.preOpCount, + tt.postOpCount, + tt.chainId, + tt.multisigId, + tt.mcmsContractID, + tt.overridePreviousRoot, + ) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Equal(t, types.ChainMetadata{}, got) + } else { + require.NoError(t, err) + assert.Equal(t, tt.mcmsContractID, got.MCMAddress) + assert.Equal(t, tt.preOpCount, got.StartingOpCount) + + // Validate additional fields were marshaled correctly + var additionalFields AdditionalFieldsMetadata + err = json.Unmarshal(got.AdditionalFields, &additionalFields) + require.NoError(t, err) + assert.Equal(t, tt.chainId, additionalFields.ChainId) + assert.Equal(t, tt.multisigId, additionalFields.MultisigId) + assert.Equal(t, tt.preOpCount, additionalFields.PreOpCount) + assert.Equal(t, tt.postOpCount, additionalFields.PostOpCount) + assert.Equal(t, tt.overridePreviousRoot, additionalFields.OverridePreviousRoot) + } + }) + } +} + +func TestValidateChainMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata types.ChainMetadata + wantErr string + }{ + { + name: "valid metadata", + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5, + "overridePreviousRoot": false + }`), + }, + wantErr: "", + }, + { + name: "invalid additional fields - missing chainId", + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5 + }`), + }, + wantErr: "chainId is required", + }, + { + name: "invalid additional fields - missing multisigId", + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "preOpCount": 0, + "postOpCount": 5 + }`), + }, + wantErr: "multisigId is required", + }, + { + name: "invalid additional fields - postOpCount less than preOpCount", + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 10, + "postOpCount": 5 + }`), + }, + wantErr: "postOpCount must be >= preOpCount", + }, + { + name: "invalid JSON in additional fields", + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{invalid json}`), + }, + wantErr: "unable to unmarshal additional fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateChainMetadata(tt.metadata) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTimelockRole_String(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role TimelockRole + want string + }{ + { + name: "bypasser", + role: TimelockRoleBypasser, + want: "Bypasser", + }, + { + name: "proposer", + role: TimelockRoleProposer, + want: "Proposer", + }, + { + name: "canceller", + role: TimelockRoleCanceller, + want: "Canceller", + }, + { + name: "unknown", + role: TimelockRole(99), + want: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := tt.role.String() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestTimelockRole_Byte(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + role TimelockRole + want uint8 + }{ + { + name: "bypasser", + role: TimelockRoleBypasser, + want: 0, + }, + { + name: "canceller", + role: TimelockRoleCanceller, + want: 1, + }, + { + name: "proposer", + role: TimelockRoleProposer, + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := tt.role.Byte() + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/sdk/canton/configurer.go b/sdk/canton/configurer.go index a723874a..6ee1bb6d 100644 --- a/sdk/canton/configurer.go +++ b/sdk/canton/configurer.go @@ -73,7 +73,7 @@ func (c Configurer) SetConfig(ctx context.Context, mcmsAddr string, cfg *types.C exerciseCmd := mcmsContract.SetConfig(mcmsAddr, input) // Parse template ID - packageID, moduleName, entityName, err := parseTemplateIDFromString(mcmsContract.GetTemplateID()) + packageID, moduleName, entityName, err := ParseTemplateIDFromString(mcmsContract.GetTemplateID()) if err != nil { return types.TransactionResult{}, fmt.Errorf("failed to parse template ID: %w", err) } @@ -113,7 +113,7 @@ func (c Configurer) SetConfig(ctx context.Context, mcmsAddr string, cfg *types.C transaction := submitResp.GetTransaction() for _, ev := range transaction.GetEvents() { if createdEv := ev.GetCreated(); createdEv != nil { - templateID := formatTemplateID(createdEv.GetTemplateId()) + templateID := FormatTemplateID(createdEv.GetTemplateId()) normalized := NormalizeTemplateKey(templateID) if normalized == MCMSTemplateKey { newMCMSContractID = createdEv.GetContractId() diff --git a/sdk/canton/configurer_test.go b/sdk/canton/configurer_test.go new file mode 100644 index 00000000..eccb0fb5 --- /dev/null +++ b/sdk/canton/configurer_test.go @@ -0,0 +1,218 @@ +package canton + +import ( + "context" + "testing" + + apiv2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + mock_apiv2 "github.com/smartcontractkit/mcms/sdk/canton/mocks/apiv2" + "github.com/smartcontractkit/mcms/types" +) + +func TestNewConfigurer(t *testing.T) { + t.Parallel() + + mockCommandClient := mock_apiv2.NewCommandServiceClient(t) + userId := "user123" + party := "Alice::party123" + role := TimelockRoleProposer + + configurer, err := NewConfigurer(mockCommandClient, userId, party, role) + + require.NoError(t, err) + require.NotNil(t, configurer) + assert.Equal(t, mockCommandClient, configurer.client) + assert.Equal(t, userId, configurer.userId) + assert.Equal(t, party, configurer.party) + assert.Equal(t, role, configurer.role) +} + +func TestConfigurer_SetConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mcmsAddr string + cfg *types.Config + clearRoot bool + role TimelockRole + mockSetup func(*mock_apiv2.CommandServiceClient) + wantErr string + wantTxHash string + }{ + { + name: "success - simple config", + mcmsAddr: "contract-id-123", + cfg: &types.Config{ + Quorum: 2, + Signers: []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + common.HexToAddress("0x2222222222222222222222222222222222222222"), + }, + GroupSigners: []types.Config{}, + }, + clearRoot: false, + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + // Mock successful submission + mockClient.EXPECT().SubmitAndWaitForTransaction( + mock.Anything, + mock.MatchedBy(func(req *apiv2.SubmitAndWaitForTransactionRequest) bool { + return req.Commands != nil && + req.Commands.WorkflowId == "mcms-set-config" && + len(req.Commands.Commands) == 1 && + req.Commands.Commands[0].GetExercise() != nil + }), + ).Return(&apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-123", + Events: []*apiv2.Event{ + { + Event: &apiv2.Event_Created{ + Created: &apiv2.CreatedEvent{ + ContractId: "new-contract-id-456", + TemplateId: &apiv2.Identifier{ + PackageId: "mcms-package", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + }, + }, + }, + }, + }, + }, nil) + }, + wantTxHash: "tx.Digest", + wantErr: "", + }, + { + name: "success - hierarchical config", + mcmsAddr: "contract-id-456", + cfg: &types.Config{ + Quorum: 1, + Signers: []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + }, + GroupSigners: []types.Config{ + { + Quorum: 2, + Signers: []common.Address{ + common.HexToAddress("0x2222222222222222222222222222222222222222"), + common.HexToAddress("0x3333333333333333333333333333333333333333"), + }, + GroupSigners: []types.Config{}, + }, + }, + }, + clearRoot: true, + role: TimelockRoleBypasser, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + &apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-456", + Events: []*apiv2.Event{ + { + Event: &apiv2.Event_Created{ + Created: &apiv2.CreatedEvent{ + ContractId: "new-contract-id-789", + TemplateId: &apiv2.Identifier{ + PackageId: "mcms-package", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + }, + }, + }, + }, + }, + }, + nil, + ) + }, + wantTxHash: "tx.Digest", + wantErr: "", + }, + { + name: "failure - submission error", + mcmsAddr: "contract-id-bad", + cfg: &types.Config{ + Quorum: 1, + Signers: []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + }, + GroupSigners: []types.Config{}, + }, + clearRoot: false, + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + nil, + assert.AnError, + ) + }, + wantTxHash: "", + wantErr: "failed to set config", + }, + { + name: "failure - no MCMS created event", + mcmsAddr: "contract-id-no-event", + cfg: &types.Config{ + Quorum: 1, + Signers: []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + }, + GroupSigners: []types.Config{}, + }, + clearRoot: false, + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + &apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-no-event", + Events: []*apiv2.Event{}, // No events + }, + }, + nil, + ) + }, + wantTxHash: "", + wantErr: "set-config tx had no Created MCMS event", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + mockCommandClient := mock_apiv2.NewCommandServiceClient(t) + if tt.mockSetup != nil { + tt.mockSetup(mockCommandClient) + } + + configurer, err := NewConfigurer(mockCommandClient, "user123", "Alice::party123", tt.role) + require.NoError(t, err) + + result, err := configurer.SetConfig(ctx, tt.mcmsAddr, tt.cfg, tt.clearRoot) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantTxHash, result.Hash) + assert.NotNil(t, result.RawData) + assert.Contains(t, result.RawData, "NewMCMSContractID") + assert.Contains(t, result.RawData, "NewMCMSTemplateID") + } + }) + } +} diff --git a/sdk/canton/encoder_test.go b/sdk/canton/encoder_test.go new file mode 100644 index 00000000..4e051b2c --- /dev/null +++ b/sdk/canton/encoder_test.go @@ -0,0 +1,429 @@ +package canton + +import ( + "encoding/json" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/mcms/types" +) + +func AssertErrorContains(errorMessage string) assert.ErrorAssertionFunc { + return func(t assert.TestingT, err error, msgAndArgs ...any) bool { + if err == nil { + return assert.Fail(t, "Expected error to be returned", msgAndArgs...) + } + return assert.Contains(t, err.Error(), errorMessage, msgAndArgs...) + } +} + +func TestEncoder_HashOperation(t *testing.T) { + t.Parallel() + + type fields struct { + ChainSelector types.ChainSelector + TxCount uint64 + OverridePreviousRoot bool + } + type args struct { + opCount uint32 + metadata types.ChainMetadata + op types.Operation + } + tests := []struct { + name string + fields fields + args args + want common.Hash + wantErr assert.ErrorAssertionFunc + }{ + { + name: "success", + fields: fields{ + ChainSelector: 1, + TxCount: 5, + OverridePreviousRoot: false, + }, + args: args{ + opCount: 3, + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5, + "overridePreviousRoot": false + }`), + }, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33, 0x44}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "functionName": "executeAction", + "operationData": "1122334455", + "targetCid": "cid-123" + }`), + }, + }, + }, + want: common.HexToHash("0xe4f7155153e90245c12d484e518341394978318905a6710b230b54977170138a"), + wantErr: assert.NoError, + }, + { + name: "success with different values", + fields: fields{ + ChainSelector: 2, + TxCount: 10, + OverridePreviousRoot: true, + }, + args: args{ + opCount: 5, + metadata: types.ChainMetadata{ + MCMAddress: "11f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 123, + "multisigId": "prod-multisig", + "preOpCount": 5, + "postOpCount": 15, + "overridePreviousRoot": true + }`), + }, + op: types.Operation{ + Transaction: types.Transaction{ + To: "another-target", + Data: []byte{0xaa, 0xbb, 0xcc}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "prod-instance", + "functionName": "transfer", + "operationData": "aabbccdd", + "targetCid": "cid-456" + }`), + }, + }, + }, + // Different encoding with different values + want: common.HexToHash("0xe391f3bfd945f06842d5dd286ecdabd1a148a94fe982e840b82db16daf6848ba"), + wantErr: assert.NoError, + }, + { + name: "failure - invalid metadata additional fields JSON", + fields: fields{ + ChainSelector: 1, + }, + args: args{ + opCount: 3, + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{invalid json}`), + }, + op: types.Operation{}, + }, + wantErr: AssertErrorContains("failed to unmarshal metadata additional fields"), + }, + { + name: "failure - invalid operation additional fields JSON", + fields: fields{ + ChainSelector: 1, + }, + args: args{ + opCount: 3, + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5 + }`), + }, + op: types.Operation{ + Transaction: types.Transaction{ + AdditionalFields: json.RawMessage(`{invalid json}`), + }, + }, + }, + wantErr: AssertErrorContains("failed to unmarshal operation additional fields"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + e := NewEncoder(tt.fields.ChainSelector, tt.fields.TxCount, tt.fields.OverridePreviousRoot) + got, err := e.HashOperation(tt.args.opCount, tt.args.metadata, tt.args.op) + + tt.wantErr(t, err) + if err == nil { + t.Logf("Computed hash for %s: %s", tt.name, got.Hex()) + assert.Equal(t, tt.want, got, "hash should match expected value") + } + }) + } +} + +func TestEncoder_HashMetadata(t *testing.T) { + t.Parallel() + + type fields struct { + ChainSelector types.ChainSelector + TxCount uint64 + OverridePreviousRoot bool + } + type args struct { + metadata types.ChainMetadata + } + tests := []struct { + name string + fields fields + args args + want common.Hash + wantErr assert.ErrorAssertionFunc + }{ + { + name: "success without override", + fields: fields{ + ChainSelector: 1, + TxCount: 5, + OverridePreviousRoot: false, + }, + args: args{ + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5, + "overridePreviousRoot": false + }`), + }, + }, + want: common.HexToHash("0xb88e2ae0ecfa263c7a6fa6322e9aac55a06e722a1d2cf470cca361dd1325f9a2"), + wantErr: assert.NoError, + }, + { + name: "success with override", + fields: fields{ + ChainSelector: 2, + TxCount: 10, + OverridePreviousRoot: true, + }, + args: args{ + metadata: types.ChainMetadata{ + MCMAddress: "11f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 123, + "multisigId": "prod-multisig", + "preOpCount": 5, + "postOpCount": 15, + "overridePreviousRoot": true + }`), + }, + }, + want: common.HexToHash("0x8d9da63765997aa703892ca624faaecf58c83d0201642ae313d98588a9705f08"), + wantErr: assert.NoError, + }, + { + name: "success with different chain ID", + fields: fields{ + ChainSelector: 3, + }, + args: args{ + metadata: types.ChainMetadata{ + MCMAddress: "22f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{ + "chainId": 999, + "multisigId": "another-multisig", + "preOpCount": 100, + "postOpCount": 200, + "overridePreviousRoot": false + }`), + }, + }, + want: common.HexToHash("0x92794b1a017c4f55705839076c13ceb257e75fad47cb9420420e58cd38854351"), + wantErr: assert.NoError, + }, + { + name: "failure - invalid metadata additional fields JSON", + fields: fields{ + ChainSelector: 1, + }, + args: args{ + metadata: types.ChainMetadata{ + MCMAddress: "00f8a3c8ed6c7e34bb3f3f16ed5d8a5fc9b7a6c1d2e3f4a5b6c7d8e9f0a1b2c3", + AdditionalFields: json.RawMessage(`{invalid json}`), + }, + }, + wantErr: AssertErrorContains("failed to unmarshal metadata additional fields"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + e := NewEncoder(tt.fields.ChainSelector, tt.fields.TxCount, tt.fields.OverridePreviousRoot) + got, err := e.HashMetadata(tt.args.metadata) + + tt.wantErr(t, err) + if err == nil { + t.Logf("Computed hash for %s: %s", tt.name, got.Hex()) + assert.Equal(t, tt.want, got, "hash should match expected value") + } + }) + } +} + +func TestPadLeft32(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + { + name: "pad short string", + input: "1", + want: "0000000000000000000000000000000000000000000000000000000000000001", + }, + { + name: "pad medium string", + input: "123abc", + want: "000000000000000000000000000000000000000000000000000000000123abc", + }, + { + name: "no padding needed", + input: "0000000000000000000000000000000000000000000000000000000000000001", + want: "0000000000000000000000000000000000000000000000000000000000000001", + }, + { + name: "truncate if too long", + input: "00000000000000000000000000000000000000000000000000000000000000001", + want: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + name: "empty string", + input: "", + want: "0000000000000000000000000000000000000000000000000000000000000000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := padLeft32(tt.input) + assert.Equal(t, tt.want, got) + assert.Equal(t, 64, len(got), "result should always be 64 characters") + }) + } +} + +func TestIntToHex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input int + want string + }{ + { + name: "zero", + input: 0, + want: "0", + }, + { + name: "single digit", + input: 5, + want: "5", + }, + { + name: "two digits", + input: 15, + want: "f", + }, + { + name: "large number", + input: 255, + want: "ff", + }, + { + name: "very large number", + input: 1234567, + want: "12d687", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := intToHex(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAsciiToHex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want string + }{ + { + name: "simple string", + input: "hello", + want: "68656c6c6f", + }, + { + name: "with hyphen", + input: "test-multisig", + want: "746573742d6d756c7469736967", + }, + { + name: "numbers", + input: "12345", + want: "3132333435", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "special characters", + input: "a@b.c", + want: "6140622e63", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := asciiToHex(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestNewEncoder(t *testing.T) { + t.Parallel() + + encoder := NewEncoder(123, 456, true) + + require.NotNil(t, encoder) + assert.Equal(t, types.ChainSelector(123), encoder.ChainSelector) + assert.Equal(t, uint64(456), encoder.TxCount) + assert.True(t, encoder.OverridePreviousRoot) +} diff --git a/sdk/canton/executor.go b/sdk/canton/executor.go index 68ee8cc4..407acd38 100644 --- a/sdk/canton/executor.go +++ b/sdk/canton/executor.go @@ -123,7 +123,7 @@ func (e Executor) ExecuteOperation( choiceArgument = ledger.MapToValue(input) // Parse template ID - packageID, moduleName, entityName, err := parseTemplateIDFromString(mcmsContract.GetTemplateID()) + packageID, moduleName, entityName, err := ParseTemplateIDFromString(mcmsContract.GetTemplateID()) if err != nil { return types.TransactionResult{}, fmt.Errorf("failed to parse template ID: %w", err) } diff --git a/sdk/canton/executor_test.go b/sdk/canton/executor_test.go new file mode 100644 index 00000000..c05069ee --- /dev/null +++ b/sdk/canton/executor_test.go @@ -0,0 +1,414 @@ +package canton + +import ( + "context" + "encoding/json" + "testing" + + apiv2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + mock_apiv2 "github.com/smartcontractkit/mcms/sdk/canton/mocks/apiv2" + "github.com/smartcontractkit/mcms/types" +) + +func TestNewExecutor(t *testing.T) { + t.Parallel() + + encoder := NewEncoder(1, 5, false) + mockStateClient := mock_apiv2.NewStateServiceClient(t) + inspector := NewInspector(mockStateClient, "Alice::party123", TimelockRoleProposer) + mockCommandClient := mock_apiv2.NewCommandServiceClient(t) + userId := "user123" + party := "Alice::party123" + role := TimelockRoleProposer + + executor, err := NewExecutor(encoder, inspector, mockCommandClient, userId, party, role) + + require.NoError(t, err) + require.NotNil(t, executor) + assert.Equal(t, encoder, executor.Encoder) + assert.Equal(t, inspector, executor.Inspector) + assert.Equal(t, mockCommandClient, executor.client) + assert.Equal(t, userId, executor.userId) + assert.Equal(t, party, executor.party) + assert.Equal(t, role, executor.role) +} + +func TestExecutor_ExecuteOperation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata types.ChainMetadata + nonce uint32 + proof []common.Hash + op types.Operation + mockSetup func(*mock_apiv2.CommandServiceClient) + wantErr string + }{ + { + name: "success - execute operation", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-123", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{ + common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111"), + common.HexToHash("0x2222222222222222222222222222222222222222222222222222222222222222"), + }, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "functionName": "executeAction", + "operationData": "112233", + "targetCid": "cid-123", + "contractIds": ["cid-456", "cid-789"] + }`), + }, + }, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction( + mock.Anything, + mock.MatchedBy(func(req *apiv2.SubmitAndWaitForTransactionRequest) bool { + return req.Commands != nil && + req.Commands.WorkflowId == "mcms-execute-op" && + len(req.Commands.Commands) == 1 + }), + ).Return(&apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-execute-123", + Events: []*apiv2.Event{ + { + Event: &apiv2.Event_Created{ + Created: &apiv2.CreatedEvent{ + ContractId: "new-contract-id-after-execute", + TemplateId: &apiv2.Identifier{ + PackageId: "mcms-package", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + }, + }, + }, + }, + }, + }, nil) + }, + wantErr: "", + }, + { + name: "failure - missing targetInstanceId", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-123", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{}, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "functionName": "executeAction", + "operationData": "112233", + "targetCid": "cid-123" + }`), + }, + }, + mockSetup: nil, + wantErr: "targetInstanceId is required", + }, + { + name: "failure - missing functionName", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-123", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{}, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "operationData": "112233", + "targetCid": "cid-123" + }`), + }, + }, + mockSetup: nil, + wantErr: "functionName is required", + }, + { + name: "failure - missing targetCid", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-123", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{}, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "functionName": "executeAction", + "operationData": "112233" + }`), + }, + }, + mockSetup: nil, + wantErr: "targetCid is required", + }, + { + name: "failure - submission error", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-bad", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{}, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "functionName": "executeAction", + "operationData": "112233", + "targetCid": "cid-123" + }`), + }, + }, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + nil, + assert.AnError, + ) + }, + wantErr: "failed to execute operation", + }, + { + name: "failure - no MCMS created event after execution", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-no-event", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig" + }`), + }, + nonce: 1, + proof: []common.Hash{}, + op: types.Operation{ + Transaction: types.Transaction{ + To: "target-contract", + Data: []byte{0x11, 0x22, 0x33}, + AdditionalFields: json.RawMessage(`{ + "targetInstanceId": "instance-123", + "functionName": "executeAction", + "operationData": "112233", + "targetCid": "cid-123" + }`), + }, + }, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + &apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-no-event", + Events: []*apiv2.Event{}, // No events + }, + }, + nil, + ) + }, + wantErr: "execute-op tx had no Created MCMS event", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + encoder := NewEncoder(1, 5, false) + mockStateClient := mock_apiv2.NewStateServiceClient(t) + inspector := NewInspector(mockStateClient, "Alice::party123", TimelockRoleProposer) + mockCommandClient := mock_apiv2.NewCommandServiceClient(t) + + if tt.mockSetup != nil { + tt.mockSetup(mockCommandClient) + } + + executor, err := NewExecutor(encoder, inspector, mockCommandClient, "user123", "Alice::party123", TimelockRoleProposer) + require.NoError(t, err) + + result, err := executor.ExecuteOperation(ctx, tt.metadata, tt.nonce, tt.proof, tt.op) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.NotEmpty(t, result.Hash) + assert.NotNil(t, result.RawData) + assert.Contains(t, result.RawData, "NewMCMSContractID") + } + }) + } +} + +func TestExecutor_SetRoot(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata types.ChainMetadata + proof []common.Hash + root [32]byte + validUntil uint32 + sortedSignatures []types.Signature + mockSetup func(*mock_apiv2.CommandServiceClient) + wantErr string + }{ + { + name: "success - set root", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-123", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5 + }`), + }, + proof: []common.Hash{ + common.HexToHash("0x1111111111111111111111111111111111111111111111111111111111111111"), + }, + root: [32]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}, + validUntil: 1234567890, + sortedSignatures: []types.Signature{ + { + V: 27, + R: common.Hash{0x11}, + S: common.Hash{0x22}, + }, + }, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction( + mock.Anything, + mock.MatchedBy(func(req *apiv2.SubmitAndWaitForTransactionRequest) bool { + return req.Commands != nil && + req.Commands.WorkflowId == "mcms-set-root" && + len(req.Commands.Commands) == 1 + }), + ).Return(&apiv2.SubmitAndWaitForTransactionResponse{ + Transaction: &apiv2.Transaction{ + UpdateId: "tx-setroot-123", + Events: []*apiv2.Event{ + { + Event: &apiv2.Event_Created{ + Created: &apiv2.CreatedEvent{ + ContractId: "new-contract-id-after-setroot", + TemplateId: &apiv2.Identifier{ + PackageId: "mcms-package", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + }, + }, + }, + }, + }, + }, nil) + }, + wantErr: "", + }, + { + name: "failure - submission error", + metadata: types.ChainMetadata{ + MCMAddress: "contract-id-bad", + StartingOpCount: 0, + AdditionalFields: json.RawMessage(`{ + "chainId": 1, + "multisigId": "test-multisig", + "preOpCount": 0, + "postOpCount": 5 + }`), + }, + proof: []common.Hash{}, + root: [32]byte{0x12}, + validUntil: 1234567890, + sortedSignatures: []types.Signature{}, + mockSetup: func(mockClient *mock_apiv2.CommandServiceClient) { + mockClient.EXPECT().SubmitAndWaitForTransaction(mock.Anything, mock.Anything).Return( + nil, + assert.AnError, + ) + }, + wantErr: "failed to set root", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + encoder := NewEncoder(1, 5, false) + mockStateClient := mock_apiv2.NewStateServiceClient(t) + inspector := NewInspector(mockStateClient, "Alice::party123", TimelockRoleProposer) + mockCommandClient := mock_apiv2.NewCommandServiceClient(t) + + if tt.mockSetup != nil { + tt.mockSetup(mockCommandClient) + } + + executor, err := NewExecutor(encoder, inspector, mockCommandClient, "user123", "Alice::party123", TimelockRoleProposer) + require.NoError(t, err) + + result, err := executor.SetRoot(ctx, tt.metadata, tt.proof, tt.root, tt.validUntil, tt.sortedSignatures) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.NotEmpty(t, result.Hash) + assert.NotNil(t, result.RawData) + } + }) + } +} diff --git a/sdk/canton/helpers_test.go b/sdk/canton/helpers_test.go new file mode 100644 index 00000000..7dfb30a1 --- /dev/null +++ b/sdk/canton/helpers_test.go @@ -0,0 +1,194 @@ +package canton + +import ( + "testing" + + apiv2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseTemplateIDFromString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + templateID string + wantPkg string + wantModule string + wantEntity string + wantErr string + }{ + { + name: "valid template ID", + templateID: "#packageid123:MCMS.Main:MCMS", + wantPkg: "#packageid123", + wantModule: "MCMS.Main", + wantEntity: "MCMS", + wantErr: "", + }, + { + name: "another valid template ID", + templateID: "#abc123def456:Module.Submodule:Contract", + wantPkg: "#abc123def456", + wantModule: "Module.Submodule", + wantEntity: "Contract", + wantErr: "", + }, + { + name: "missing hash prefix", + templateID: "packageid123:MCMS.Main:MCMS", + wantPkg: "", + wantModule: "", + wantEntity: "", + wantErr: "template ID must start with #", + }, + { + name: "too few parts", + templateID: "#packageid123:MCMS", + wantPkg: "", + wantModule: "", + wantEntity: "", + wantErr: "template ID must have format #package:module:entity", + }, + { + name: "too many parts", + templateID: "#packageid123:MCMS.Main:MCMS:Extra", + wantPkg: "", + wantModule: "", + wantEntity: "", + wantErr: "template ID must have format #package:module:entity", + }, + { + name: "empty string", + templateID: "", + wantPkg: "", + wantModule: "", + wantEntity: "", + wantErr: "template ID must start with #", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + pkg, module, entity, err := ParseTemplateIDFromString(tt.templateID) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Empty(t, pkg) + assert.Empty(t, module) + assert.Empty(t, entity) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantPkg, pkg) + assert.Equal(t, tt.wantModule, module) + assert.Equal(t, tt.wantEntity, entity) + } + }) + } +} + +func TestFormatTemplateID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id *apiv2.Identifier + want string + }{ + { + name: "valid identifier", + id: &apiv2.Identifier{ + PackageId: "packageid123", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + want: "packageid123:MCMS.Main:MCMS", + }, + { + name: "another valid identifier", + id: &apiv2.Identifier{ + PackageId: "abc123def456", + ModuleName: "Module.Submodule", + EntityName: "Contract", + }, + want: "abc123def456:Module.Submodule:Contract", + }, + { + name: "nil identifier", + id: nil, + want: "", + }, + { + name: "empty fields", + id: &apiv2.Identifier{ + PackageId: "", + ModuleName: "", + EntityName: "", + }, + want: "::", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := FormatTemplateID(tt.id) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestNormalizeTemplateKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tid string + want string + }{ + { + name: "with hash prefix and full path", + tid: "#packageid123:MCMS.Main:MCMS", + want: "MCMS.Main:MCMS", + }, + { + name: "without hash prefix", + tid: "packageid123:MCMS.Main:MCMS", + want: "MCMS.Main:MCMS", + }, + { + name: "only two parts", + tid: "MCMS.Main:MCMS", + want: "MCMS.Main:MCMS", + }, + { + name: "single part", + tid: "MCMS", + want: "MCMS", + }, + { + name: "four parts with hash", + tid: "#pkg:ver:Module:Entity", + want: "Module:Entity", + }, + { + name: "empty string", + tid: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := NormalizeTemplateKey(tt.tid) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/sdk/canton/inspector_test.go b/sdk/canton/inspector_test.go index 2c275a5a..8edc3869 100644 --- a/sdk/canton/inspector_test.go +++ b/sdk/canton/inspector_test.go @@ -1,193 +1,663 @@ -//go:build e2e - package canton import ( + "context" + "io" "testing" + "time" + apiv2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" "github.com/ethereum/go-ethereum/common" - "github.com/smartcontractkit/chainlink-canton/bindings/mcms" - "github.com/smartcontractkit/go-daml/pkg/types" - mcmstypes "github.com/smartcontractkit/mcms/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/smartcontractkit/chainlink-canton/bindings/mcms" + cantontypes "github.com/smartcontractkit/go-daml/pkg/types" + mock_apiv2 "github.com/smartcontractkit/mcms/sdk/canton/mocks/apiv2" + "github.com/smartcontractkit/mcms/types" ) -func TestToConfig(t *testing.T) { +// mockGetActiveContractsClient implements the streaming client for GetActiveContracts +type mockGetActiveContractsClient struct { + grpc.ClientStream + responses []*apiv2.GetActiveContractsResponse + index int +} + +func (m *mockGetActiveContractsClient) Recv() (*apiv2.GetActiveContractsResponse, error) { + if m.index >= len(m.responses) { + return nil, io.EOF + } + resp := m.responses[m.index] + m.index++ + return resp, nil +} + +func (m *mockGetActiveContractsClient) CloseSend() error { + return nil +} + +func TestNewInspector(t *testing.T) { + t.Parallel() + + mockStateClient := mock_apiv2.NewStateServiceClient(t) + party := "Alice::party123" + role := TimelockRoleProposer + + inspector := NewInspector(mockStateClient, party, role) + + require.NotNil(t, inspector) + assert.Equal(t, mockStateClient, inspector.stateClient) + assert.Equal(t, party, inspector.party) + assert.Equal(t, role, inspector.role) + assert.Nil(t, inspector.contractCache) +} + +func TestInspector_GetConfig(t *testing.T) { + t.Parallel() + tests := []struct { - name string - description string - input mcms.MultisigConfig - expected mcmstypes.Config + name string + mcmsAddr string + role TimelockRole + mockSetup func(*mock_apiv2.StateServiceClient) *mcms.MCMS + want *types.Config + wantErr string }{ { - name: "simple_2of3", - description: "Simple 2-of-3 multisig with all signers in root group (group 0)", - input: mcms.MultisigConfig{ - Signers: []mcms.SignerInfo{ - {SignerAddress: types.TEXT("0x1111111111111111111111111111111111111111"), SignerIndex: types.INT64(0), SignerGroup: types.INT64(0)}, - {SignerAddress: types.TEXT("0x2222222222222222222222222222222222222222"), SignerIndex: types.INT64(1), SignerGroup: types.INT64(0)}, - {SignerAddress: types.TEXT("0x3333333333333333333333333333333333333333"), SignerIndex: types.INT64(2), SignerGroup: types.INT64(0)}, - }, - GroupQuorums: []types.INT64{2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - GroupParents: []types.INT64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + name: "success - proposer role", + mcmsAddr: "contract-id-123", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) *mcms.MCMS { + setupMockGetMCMSContract(t, mockClient, "contract-id-123", &mcms.MCMS{ + Owner: "Alice::party123", + InstanceId: "instance-123", + ChainId: 1, + Proposer: mcms.RoleState{ + Config: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{ + { + SignerAddress: "1122334455667788", + SignerGroup: 0, + SignerIndex: 0, + }, + { + SignerAddress: "2233445566778899", + SignerGroup: 1, + SignerIndex: 1, + }, + }, + GroupQuorums: []cantontypes.INT64{2, 1}, + GroupParents: []cantontypes.INT64{0, 0}, + }, + }, + }) + return nil }, - expected: mcmstypes.Config{ + want: &types.Config{ Quorum: 2, Signers: []common.Address{ - common.HexToAddress("1111111111111111111111111111111111111111"), - common.HexToAddress("2222222222222222222222222222222222222222"), - common.HexToAddress("3333333333333333333333333333333333333333"), + common.HexToAddress("0x1122334455667788"), + }, + GroupSigners: []types.Config{ + { + Quorum: 1, + Signers: []common.Address{ + common.HexToAddress("0x2233445566778899"), + }, + GroupSigners: []types.Config{}, + }, }, - GroupSigners: []mcmstypes.Config{}, }, + wantErr: "", }, { - name: "hierarchical_2level", - description: "2-level hierarchy: root group 0 has 1 direct signer + group 1 as child. Group 1 has 3 signers with quorum 2. Root quorum is 1 (can be satisfied by direct signer OR group 1 reaching quorum).", - input: mcms.MultisigConfig{ - Signers: []mcms.SignerInfo{ - {SignerAddress: types.TEXT("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), SignerIndex: types.INT64(0), SignerGroup: types.INT64(0)}, - {SignerAddress: types.TEXT("0xbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"), SignerIndex: types.INT64(1), SignerGroup: types.INT64(1)}, - {SignerAddress: types.TEXT("0xcccccccccccccccccccccccccccccccccccccccc"), SignerIndex: types.INT64(2), SignerGroup: types.INT64(1)}, - {SignerAddress: types.TEXT("0xdddddddddddddddddddddddddddddddddddddddd"), SignerIndex: types.INT64(3), SignerGroup: types.INT64(1)}, - }, - GroupQuorums: []types.INT64{1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - GroupParents: []types.INT64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + name: "success - bypasser role", + mcmsAddr: "contract-id-456", + role: TimelockRoleBypasser, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) *mcms.MCMS { + setupMockGetMCMSContract(t, mockClient, "contract-id-456", &mcms.MCMS{ + Owner: "Bob::party456", + InstanceId: "instance-456", + ChainId: 2, + Bypasser: mcms.RoleState{ + Config: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{ + { + SignerAddress: "aabbccddeeff0011", + SignerGroup: 0, + SignerIndex: 0, + }, + }, + GroupQuorums: []cantontypes.INT64{1}, + GroupParents: []cantontypes.INT64{0}, + }, + }, + }) + return nil }, - expected: mcmstypes.Config{ + want: &types.Config{ Quorum: 1, Signers: []common.Address{ - common.HexToAddress("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + common.HexToAddress("0xaabbccddeeff0011"), }, - GroupSigners: []mcmstypes.Config{ - { - Quorum: 2, - Signers: []common.Address{ - common.HexToAddress("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"), - common.HexToAddress("cccccccccccccccccccccccccccccccccccccccc"), - common.HexToAddress("dddddddddddddddddddddddddddddddddddddddd"), + GroupSigners: []types.Config{}, + }, + wantErr: "", + }, + { + name: "success - canceller role", + mcmsAddr: "contract-id-789", + role: TimelockRoleCanceller, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) *mcms.MCMS { + setupMockGetMCMSContract(t, mockClient, "contract-id-789", &mcms.MCMS{ + Owner: "Carol::party789", + InstanceId: "instance-789", + ChainId: 3, + Canceller: mcms.RoleState{ + Config: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{ + { + SignerAddress: "ffeeddccbbaa9988", + SignerGroup: 0, + SignerIndex: 0, + }, + }, + GroupQuorums: []cantontypes.INT64{1}, + GroupParents: []cantontypes.INT64{0}, }, - GroupSigners: []mcmstypes.Config{}, }, + }) + return nil + }, + want: &types.Config{ + Quorum: 1, + Signers: []common.Address{ + common.HexToAddress("0xffeeddccbbaa9988"), }, + GroupSigners: []types.Config{}, }, + wantErr: "", }, { - name: "complex_3level", - description: "3-level hierarchy: Group 0 (root) quorum 2, Group 1 (parent 0) quorum 2, Group 2 (parent 0) quorum 1, Group 3 (parent 1) quorum 2. Tests deeper nesting with multiple child groups at same level.", - input: mcms.MultisigConfig{ - Signers: []mcms.SignerInfo{ - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000001"), SignerIndex: types.INT64(0), SignerGroup: types.INT64(0)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000002"), SignerIndex: types.INT64(1), SignerGroup: types.INT64(1)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000003"), SignerIndex: types.INT64(2), SignerGroup: types.INT64(1)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000004"), SignerIndex: types.INT64(3), SignerGroup: types.INT64(2)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000005"), SignerIndex: types.INT64(4), SignerGroup: types.INT64(2)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000006"), SignerIndex: types.INT64(5), SignerGroup: types.INT64(3)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000007"), SignerIndex: types.INT64(6), SignerGroup: types.INT64(3)}, - {SignerAddress: types.TEXT("0x1000000000000000000000000000000000000008"), SignerIndex: types.INT64(7), SignerGroup: types.INT64(3)}, - }, - GroupQuorums: []types.INT64{2, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - GroupParents: []types.INT64{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + name: "failure - contract not found", + mcmsAddr: "nonexistent-contract", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) *mcms.MCMS { + mockClient.EXPECT().GetLedgerEnd(mock.Anything, mock.Anything).Return( + &apiv2.GetLedgerEndResponse{ + Offset: 123, + }, + nil, + ) + + streamClient := &mockGetActiveContractsClient{ + responses: []*apiv2.GetActiveContractsResponse{}, + } + mockClient.EXPECT().GetActiveContracts(mock.Anything, mock.Anything).Return( + streamClient, + nil, + ) + return nil }, - expected: mcmstypes.Config{ - Quorum: 2, - Signers: []common.Address{ - common.HexToAddress("1000000000000000000000000000000000000001"), - }, - GroupSigners: []mcmstypes.Config{ - { - Quorum: 2, - Signers: []common.Address{ - common.HexToAddress("1000000000000000000000000000000000000002"), - common.HexToAddress("1000000000000000000000000000000000000003"), + want: nil, + wantErr: "MCMS contract with ID nonexistent-contract not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + mockStateClient := mock_apiv2.NewStateServiceClient(t) + if tt.mockSetup != nil { + tt.mockSetup(mockStateClient) + } + + inspector := NewInspector(mockStateClient, "Alice::party123", tt.role) + + got, err := inspector.GetConfig(ctx, tt.mcmsAddr) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Nil(t, got) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestInspector_GetOpCount(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mcmsAddr string + role TimelockRole + mockSetup func(*mock_apiv2.StateServiceClient) + want uint64 + wantErr string + }{ + { + name: "success - proposer role", + mcmsAddr: "contract-id-123", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-123", &mcms.MCMS{ + Proposer: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + OpCount: 5, }, - GroupSigners: []mcmstypes.Config{ - { - Quorum: 2, - Signers: []common.Address{ - common.HexToAddress("1000000000000000000000000000000000000006"), - common.HexToAddress("1000000000000000000000000000000000000007"), - common.HexToAddress("1000000000000000000000000000000000000008"), - }, - GroupSigners: []mcmstypes.Config{}, - }, + }, + }) + }, + want: 5, + wantErr: "", + }, + { + name: "success - bypasser role", + mcmsAddr: "contract-id-456", + role: TimelockRoleBypasser, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-456", &mcms.MCMS{ + Bypasser: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + OpCount: 10, }, }, - { - Quorum: 1, - Signers: []common.Address{ - common.HexToAddress("1000000000000000000000000000000000000004"), - common.HexToAddress("1000000000000000000000000000000000000005"), + }) + }, + want: 10, + wantErr: "", + }, + { + name: "success - canceller role with zero op count", + mcmsAddr: "contract-id-789", + role: TimelockRoleCanceller, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-789", &mcms.MCMS{ + Canceller: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + OpCount: 0, }, - GroupSigners: []mcmstypes.Config{}, + }, + }) + }, + want: 0, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + mockStateClient := mock_apiv2.NewStateServiceClient(t) + if tt.mockSetup != nil { + tt.mockSetup(mockStateClient) + } + + inspector := NewInspector(mockStateClient, "Alice::party123", tt.role) + + got, err := inspector.GetOpCount(ctx, tt.mcmsAddr) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestInspector_GetRoot(t *testing.T) { + t.Parallel() + + validUntilTime := time.Date(2026, 3, 15, 10, 30, 0, 0, time.UTC) + + tests := []struct { + name string + mcmsAddr string + role TimelockRole + mockSetup func(*mock_apiv2.StateServiceClient) + wantRoot common.Hash + wantValidUntil uint32 + wantErr string + }{ + { + name: "success - proposer role", + mcmsAddr: "contract-id-123", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-123", &mcms.MCMS{ + Proposer: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + Root: "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef", + ValidUntil: cantontypes.TIMESTAMP(validUntilTime), + }, + }, + }) + }, + wantRoot: common.HexToHash("0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef"), + wantValidUntil: uint32(validUntilTime.Unix()), + wantErr: "", + }, + { + name: "success - bypasser role", + mcmsAddr: "contract-id-456", + role: TimelockRoleBypasser, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-456", &mcms.MCMS{ + Bypasser: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + Root: "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", + ValidUntil: cantontypes.TIMESTAMP(validUntilTime), + }, + }, + }) + }, + wantRoot: common.HexToHash("0xabcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890"), + wantValidUntil: uint32(validUntilTime.Unix()), + wantErr: "", + }, + { + name: "failure - invalid root hex", + mcmsAddr: "contract-id-bad", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-bad", &mcms.MCMS{ + Proposer: mcms.RoleState{ + ExpiringRoot: mcms.ExpiringRoot{ + Root: "invalid-hex-string", + ValidUntil: cantontypes.TIMESTAMP(validUntilTime), + }, + }, + }) + }, + wantRoot: common.Hash{}, + wantValidUntil: 0, + wantErr: "failed to decode root hash", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + mockStateClient := mock_apiv2.NewStateServiceClient(t) + if tt.mockSetup != nil { + tt.mockSetup(mockStateClient) + } + + inspector := NewInspector(mockStateClient, "Alice::party123", tt.role) + + gotRoot, gotValidUntil, err := inspector.GetRoot(ctx, tt.mcmsAddr) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantRoot, gotRoot) + assert.Equal(t, tt.wantValidUntil, gotValidUntil) + } + }) + } +} + +func TestInspector_GetRootMetadata(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mcmsAddr string + role TimelockRole + mockSetup func(*mock_apiv2.StateServiceClient) + want types.ChainMetadata + wantErr string + }{ + { + name: "success - proposer role", + mcmsAddr: "contract-id-123", + role: TimelockRoleProposer, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-123", &mcms.MCMS{ + InstanceId: "instance-123", + Proposer: mcms.RoleState{ + RootMetadata: mcms.RootMetadata{ + ChainId: 1, + PreOpCount: 5, + PostOpCount: 10, + }, + }, + }) + }, + want: types.ChainMetadata{ + StartingOpCount: 5, + MCMAddress: "instance-123", + }, + wantErr: "", + }, + { + name: "success - bypasser role", + mcmsAddr: "contract-id-456", + role: TimelockRoleBypasser, + mockSetup: func(mockClient *mock_apiv2.StateServiceClient) { + setupMockGetMCMSContract(t, mockClient, "contract-id-456", &mcms.MCMS{ + InstanceId: "instance-456", + Bypasser: mcms.RoleState{ + RootMetadata: mcms.RootMetadata{ + ChainId: 2, + PreOpCount: 0, + PostOpCount: 3, + }, + }, + }) + }, + want: types.ChainMetadata{ + StartingOpCount: 0, + MCMAddress: "instance-456", + }, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + mockStateClient := mock_apiv2.NewStateServiceClient(t) + if tt.mockSetup != nil { + tt.mockSetup(mockStateClient) + } + + inspector := NewInspector(mockStateClient, "Alice::party123", tt.role) + + got, err := inspector.GetRootMetadata(ctx, tt.mcmsAddr) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestToConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + bindConfig mcms.MultisigConfig + want *types.Config + wantErr string + }{ + { + name: "success - simple single group", + bindConfig: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{ + { + SignerAddress: "1122334455667788", + SignerGroup: 0, + SignerIndex: 0, + }, + { + SignerAddress: "2233445566778899", + SignerGroup: 0, + SignerIndex: 1, }, }, + GroupQuorums: []cantontypes.INT64{2}, + GroupParents: []cantontypes.INT64{0}, }, + want: &types.Config{ + Quorum: 2, + Signers: []common.Address{ + common.HexToAddress("0x1122334455667788"), + common.HexToAddress("0x2233445566778899"), + }, + GroupSigners: []types.Config{}, + }, + wantErr: "", }, { - name: "empty_groups_edge_case", - description: "Edge case: groups with quorum 0 (disabled) interspersed with active groups. Group 0 active (quorum 1), Group 1 disabled (quorum 0), Group 2 active (quorum 2, parent 0). The toConfig function should skip disabled groups.", - input: mcms.MultisigConfig{ + name: "success - hierarchical groups", + bindConfig: mcms.MultisigConfig{ Signers: []mcms.SignerInfo{ - {SignerAddress: types.TEXT("0xdead000000000000000000000000000000000001"), SignerIndex: types.INT64(0), SignerGroup: types.INT64(0)}, - {SignerAddress: types.TEXT("0xdead000000000000000000000000000000000002"), SignerIndex: types.INT64(1), SignerGroup: types.INT64(2)}, - {SignerAddress: types.TEXT("0xdead000000000000000000000000000000000003"), SignerIndex: types.INT64(2), SignerGroup: types.INT64(2)}, + { + SignerAddress: "1122334455667788", + SignerGroup: 0, + SignerIndex: 0, + }, + { + SignerAddress: "2233445566778899", + SignerGroup: 1, + SignerIndex: 1, + }, + { + SignerAddress: "3344556677889900", + SignerGroup: 1, + SignerIndex: 2, + }, }, - GroupQuorums: []types.INT64{1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - GroupParents: []types.INT64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + GroupQuorums: []cantontypes.INT64{2, 2}, + GroupParents: []cantontypes.INT64{0, 0}, }, - expected: mcmstypes.Config{ - Quorum: 1, + want: &types.Config{ + Quorum: 2, Signers: []common.Address{ - common.HexToAddress("dead000000000000000000000000000000000001"), + common.HexToAddress("0x1122334455667788"), }, - GroupSigners: []mcmstypes.Config{ + GroupSigners: []types.Config{ { Quorum: 2, Signers: []common.Address{ - common.HexToAddress("dead000000000000000000000000000000000002"), - common.HexToAddress("dead000000000000000000000000000000000003"), + common.HexToAddress("0x2233445566778899"), + common.HexToAddress("0x3344556677889900"), }, - GroupSigners: []mcmstypes.Config{}, + GroupSigners: []types.Config{}, + }, + }, + }, + wantErr: "", + }, + { + name: "failure - empty config", + bindConfig: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{}, + GroupQuorums: []cantontypes.INT64{}, + GroupParents: []cantontypes.INT64{}, + }, + want: nil, + wantErr: "Quorum must be greater than 0", + }, + { + name: "failure - group index exceeds maximum", + bindConfig: mcms.MultisigConfig{ + Signers: []mcms.SignerInfo{ + { + SignerAddress: "1122334455667788", + SignerGroup: 32, // Exceeds maximum of 31 + SignerIndex: 0, }, }, + GroupQuorums: []cantontypes.INT64{1}, + GroupParents: []cantontypes.INT64{0}, }, + want: nil, + wantErr: "signer group index 32 exceeds maximum of 31", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := toConfig(tt.input) - require.NoError(t, err, tt.description) - require.NotNil(t, result) + t.Parallel() - // Compare the result with expected - require.Equal(t, tt.expected.Quorum, result.Quorum, "quorum mismatch") - require.Equal(t, len(tt.expected.Signers), len(result.Signers), "signers count mismatch") + got, err := toConfig(tt.bindConfig) - // Compare signers - for i, expectedSigner := range tt.expected.Signers { - require.Equal(t, expectedSigner, result.Signers[i], "signer mismatch at index %d", i) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Nil(t, got) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) } - - // Compare group signers recursively - compareGroupSigners(t, tt.expected.GroupSigners, result.GroupSigners) }) } } -func compareGroupSigners(t *testing.T, expected, actual []mcmstypes.Config) { - require.Equal(t, len(expected), len(actual), "group signers count mismatch") +// Helper function to setup mock for getMCMSContract +func setupMockGetMCMSContract(t *testing.T, mockClient *mock_apiv2.StateServiceClient, contractID string, mcmsContract *mcms.MCMS) { + // Mock GetLedgerEnd + mockClient.EXPECT().GetLedgerEnd(mock.Anything, mock.Anything).Return( + &apiv2.GetLedgerEndResponse{ + Offset: 123, + }, + nil, + ) - for i := range expected { - require.Equal(t, expected[i].Quorum, actual[i].Quorum, "group %d quorum mismatch", i) - require.Equal(t, len(expected[i].Signers), len(actual[i].Signers), "group %d signers count mismatch", i) + // Create the created event for the MCMS contract + createdEvent := &apiv2.CreatedEvent{ + ContractId: contractID, + TemplateId: &apiv2.Identifier{ + PackageId: "#mcms", + ModuleName: "MCMS.Main", + EntityName: "MCMS", + }, + } - for j, expectedSigner := range expected[i].Signers { - require.Equal(t, expectedSigner, actual[i].Signers[j], "group %d signer mismatch at index %d", i, j) - } + // Create response with the active contract + responses := []*apiv2.GetActiveContractsResponse{ + { + ContractEntry: &apiv2.GetActiveContractsResponse_ActiveContract{ + ActiveContract: &apiv2.ActiveContract{ + CreatedEvent: createdEvent, + }, + }, + }, + } - // Recursively compare nested group signers - compareGroupSigners(t, expected[i].GroupSigners, actual[i].GroupSigners) + streamClient := &mockGetActiveContractsClient{ + responses: responses, } + + mockClient.EXPECT().GetActiveContracts(mock.Anything, mock.Anything).Return( + streamClient, + nil, + ) } diff --git a/sdk/canton/mocks/apiv2/command_service_client.go b/sdk/canton/mocks/apiv2/command_service_client.go new file mode 100644 index 00000000..06ecb3fe --- /dev/null +++ b/sdk/canton/mocks/apiv2/command_service_client.go @@ -0,0 +1,262 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mock_apiv2 + +import ( + context "context" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" + + v2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" +) + +// CommandServiceClient is an autogenerated mock type for the CommandServiceClient type +type CommandServiceClient struct { + mock.Mock +} + +type CommandServiceClient_Expecter struct { + mock *mock.Mock +} + +func (_m *CommandServiceClient) EXPECT() *CommandServiceClient_Expecter { + return &CommandServiceClient_Expecter{mock: &_m.Mock} +} + +// SubmitAndWait provides a mock function with given fields: ctx, in, opts +func (_m *CommandServiceClient) SubmitAndWait(ctx context.Context, in *v2.SubmitAndWaitRequest, opts ...grpc.CallOption) (*v2.SubmitAndWaitResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SubmitAndWait") + } + + var r0 *v2.SubmitAndWaitResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitRequest, ...grpc.CallOption) (*v2.SubmitAndWaitResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitRequest, ...grpc.CallOption) *v2.SubmitAndWaitResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.SubmitAndWaitResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.SubmitAndWaitRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CommandServiceClient_SubmitAndWait_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubmitAndWait' +type CommandServiceClient_SubmitAndWait_Call struct { + *mock.Call +} + +// SubmitAndWait is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.SubmitAndWaitRequest +// - opts ...grpc.CallOption +func (_e *CommandServiceClient_Expecter) SubmitAndWait(ctx interface{}, in interface{}, opts ...interface{}) *CommandServiceClient_SubmitAndWait_Call { + return &CommandServiceClient_SubmitAndWait_Call{Call: _e.mock.On("SubmitAndWait", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *CommandServiceClient_SubmitAndWait_Call) Run(run func(ctx context.Context, in *v2.SubmitAndWaitRequest, opts ...grpc.CallOption)) *CommandServiceClient_SubmitAndWait_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.SubmitAndWaitRequest), variadicArgs...) + }) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWait_Call) Return(_a0 *v2.SubmitAndWaitResponse, _a1 error) *CommandServiceClient_SubmitAndWait_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWait_Call) RunAndReturn(run func(context.Context, *v2.SubmitAndWaitRequest, ...grpc.CallOption) (*v2.SubmitAndWaitResponse, error)) *CommandServiceClient_SubmitAndWait_Call { + _c.Call.Return(run) + return _c +} + +// SubmitAndWaitForReassignment provides a mock function with given fields: ctx, in, opts +func (_m *CommandServiceClient) SubmitAndWaitForReassignment(ctx context.Context, in *v2.SubmitAndWaitForReassignmentRequest, opts ...grpc.CallOption) (*v2.SubmitAndWaitForReassignmentResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SubmitAndWaitForReassignment") + } + + var r0 *v2.SubmitAndWaitForReassignmentResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitForReassignmentRequest, ...grpc.CallOption) (*v2.SubmitAndWaitForReassignmentResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitForReassignmentRequest, ...grpc.CallOption) *v2.SubmitAndWaitForReassignmentResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.SubmitAndWaitForReassignmentResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.SubmitAndWaitForReassignmentRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CommandServiceClient_SubmitAndWaitForReassignment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubmitAndWaitForReassignment' +type CommandServiceClient_SubmitAndWaitForReassignment_Call struct { + *mock.Call +} + +// SubmitAndWaitForReassignment is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.SubmitAndWaitForReassignmentRequest +// - opts ...grpc.CallOption +func (_e *CommandServiceClient_Expecter) SubmitAndWaitForReassignment(ctx interface{}, in interface{}, opts ...interface{}) *CommandServiceClient_SubmitAndWaitForReassignment_Call { + return &CommandServiceClient_SubmitAndWaitForReassignment_Call{Call: _e.mock.On("SubmitAndWaitForReassignment", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *CommandServiceClient_SubmitAndWaitForReassignment_Call) Run(run func(ctx context.Context, in *v2.SubmitAndWaitForReassignmentRequest, opts ...grpc.CallOption)) *CommandServiceClient_SubmitAndWaitForReassignment_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.SubmitAndWaitForReassignmentRequest), variadicArgs...) + }) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWaitForReassignment_Call) Return(_a0 *v2.SubmitAndWaitForReassignmentResponse, _a1 error) *CommandServiceClient_SubmitAndWaitForReassignment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWaitForReassignment_Call) RunAndReturn(run func(context.Context, *v2.SubmitAndWaitForReassignmentRequest, ...grpc.CallOption) (*v2.SubmitAndWaitForReassignmentResponse, error)) *CommandServiceClient_SubmitAndWaitForReassignment_Call { + _c.Call.Return(run) + return _c +} + +// SubmitAndWaitForTransaction provides a mock function with given fields: ctx, in, opts +func (_m *CommandServiceClient) SubmitAndWaitForTransaction(ctx context.Context, in *v2.SubmitAndWaitForTransactionRequest, opts ...grpc.CallOption) (*v2.SubmitAndWaitForTransactionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SubmitAndWaitForTransaction") + } + + var r0 *v2.SubmitAndWaitForTransactionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitForTransactionRequest, ...grpc.CallOption) (*v2.SubmitAndWaitForTransactionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.SubmitAndWaitForTransactionRequest, ...grpc.CallOption) *v2.SubmitAndWaitForTransactionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.SubmitAndWaitForTransactionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.SubmitAndWaitForTransactionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CommandServiceClient_SubmitAndWaitForTransaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SubmitAndWaitForTransaction' +type CommandServiceClient_SubmitAndWaitForTransaction_Call struct { + *mock.Call +} + +// SubmitAndWaitForTransaction is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.SubmitAndWaitForTransactionRequest +// - opts ...grpc.CallOption +func (_e *CommandServiceClient_Expecter) SubmitAndWaitForTransaction(ctx interface{}, in interface{}, opts ...interface{}) *CommandServiceClient_SubmitAndWaitForTransaction_Call { + return &CommandServiceClient_SubmitAndWaitForTransaction_Call{Call: _e.mock.On("SubmitAndWaitForTransaction", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *CommandServiceClient_SubmitAndWaitForTransaction_Call) Run(run func(ctx context.Context, in *v2.SubmitAndWaitForTransactionRequest, opts ...grpc.CallOption)) *CommandServiceClient_SubmitAndWaitForTransaction_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.SubmitAndWaitForTransactionRequest), variadicArgs...) + }) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWaitForTransaction_Call) Return(_a0 *v2.SubmitAndWaitForTransactionResponse, _a1 error) *CommandServiceClient_SubmitAndWaitForTransaction_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *CommandServiceClient_SubmitAndWaitForTransaction_Call) RunAndReturn(run func(context.Context, *v2.SubmitAndWaitForTransactionRequest, ...grpc.CallOption) (*v2.SubmitAndWaitForTransactionResponse, error)) *CommandServiceClient_SubmitAndWaitForTransaction_Call { + _c.Call.Return(run) + return _c +} + +// NewCommandServiceClient creates a new instance of CommandServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCommandServiceClient(t interface { + mock.TestingT + Cleanup(func()) +}) *CommandServiceClient { + mock := &CommandServiceClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/sdk/canton/mocks/apiv2/state_service_client.go b/sdk/canton/mocks/apiv2/state_service_client.go new file mode 100644 index 00000000..998db25e --- /dev/null +++ b/sdk/canton/mocks/apiv2/state_service_client.go @@ -0,0 +1,336 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mock_apiv2 + +import ( + context "context" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" + + v2 "github.com/digital-asset/dazl-client/v8/go/api/com/daml/ledger/api/v2" +) + +// StateServiceClient is an autogenerated mock type for the StateServiceClient type +type StateServiceClient struct { + mock.Mock +} + +type StateServiceClient_Expecter struct { + mock *mock.Mock +} + +func (_m *StateServiceClient) EXPECT() *StateServiceClient_Expecter { + return &StateServiceClient_Expecter{mock: &_m.Mock} +} + +// GetActiveContracts provides a mock function with given fields: ctx, in, opts +func (_m *StateServiceClient) GetActiveContracts(ctx context.Context, in *v2.GetActiveContractsRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[v2.GetActiveContractsResponse], error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetActiveContracts") + } + + var r0 grpc.ServerStreamingClient[v2.GetActiveContractsResponse] + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetActiveContractsRequest, ...grpc.CallOption) (grpc.ServerStreamingClient[v2.GetActiveContractsResponse], error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetActiveContractsRequest, ...grpc.CallOption) grpc.ServerStreamingClient[v2.GetActiveContractsResponse]); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(grpc.ServerStreamingClient[v2.GetActiveContractsResponse]) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.GetActiveContractsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StateServiceClient_GetActiveContracts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetActiveContracts' +type StateServiceClient_GetActiveContracts_Call struct { + *mock.Call +} + +// GetActiveContracts is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.GetActiveContractsRequest +// - opts ...grpc.CallOption +func (_e *StateServiceClient_Expecter) GetActiveContracts(ctx interface{}, in interface{}, opts ...interface{}) *StateServiceClient_GetActiveContracts_Call { + return &StateServiceClient_GetActiveContracts_Call{Call: _e.mock.On("GetActiveContracts", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *StateServiceClient_GetActiveContracts_Call) Run(run func(ctx context.Context, in *v2.GetActiveContractsRequest, opts ...grpc.CallOption)) *StateServiceClient_GetActiveContracts_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.GetActiveContractsRequest), variadicArgs...) + }) + return _c +} + +func (_c *StateServiceClient_GetActiveContracts_Call) Return(_a0 grpc.ServerStreamingClient[v2.GetActiveContractsResponse], _a1 error) *StateServiceClient_GetActiveContracts_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StateServiceClient_GetActiveContracts_Call) RunAndReturn(run func(context.Context, *v2.GetActiveContractsRequest, ...grpc.CallOption) (grpc.ServerStreamingClient[v2.GetActiveContractsResponse], error)) *StateServiceClient_GetActiveContracts_Call { + _c.Call.Return(run) + return _c +} + +// GetConnectedSynchronizers provides a mock function with given fields: ctx, in, opts +func (_m *StateServiceClient) GetConnectedSynchronizers(ctx context.Context, in *v2.GetConnectedSynchronizersRequest, opts ...grpc.CallOption) (*v2.GetConnectedSynchronizersResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetConnectedSynchronizers") + } + + var r0 *v2.GetConnectedSynchronizersResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetConnectedSynchronizersRequest, ...grpc.CallOption) (*v2.GetConnectedSynchronizersResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetConnectedSynchronizersRequest, ...grpc.CallOption) *v2.GetConnectedSynchronizersResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.GetConnectedSynchronizersResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.GetConnectedSynchronizersRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StateServiceClient_GetConnectedSynchronizers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnectedSynchronizers' +type StateServiceClient_GetConnectedSynchronizers_Call struct { + *mock.Call +} + +// GetConnectedSynchronizers is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.GetConnectedSynchronizersRequest +// - opts ...grpc.CallOption +func (_e *StateServiceClient_Expecter) GetConnectedSynchronizers(ctx interface{}, in interface{}, opts ...interface{}) *StateServiceClient_GetConnectedSynchronizers_Call { + return &StateServiceClient_GetConnectedSynchronizers_Call{Call: _e.mock.On("GetConnectedSynchronizers", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *StateServiceClient_GetConnectedSynchronizers_Call) Run(run func(ctx context.Context, in *v2.GetConnectedSynchronizersRequest, opts ...grpc.CallOption)) *StateServiceClient_GetConnectedSynchronizers_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.GetConnectedSynchronizersRequest), variadicArgs...) + }) + return _c +} + +func (_c *StateServiceClient_GetConnectedSynchronizers_Call) Return(_a0 *v2.GetConnectedSynchronizersResponse, _a1 error) *StateServiceClient_GetConnectedSynchronizers_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StateServiceClient_GetConnectedSynchronizers_Call) RunAndReturn(run func(context.Context, *v2.GetConnectedSynchronizersRequest, ...grpc.CallOption) (*v2.GetConnectedSynchronizersResponse, error)) *StateServiceClient_GetConnectedSynchronizers_Call { + _c.Call.Return(run) + return _c +} + +// GetLatestPrunedOffsets provides a mock function with given fields: ctx, in, opts +func (_m *StateServiceClient) GetLatestPrunedOffsets(ctx context.Context, in *v2.GetLatestPrunedOffsetsRequest, opts ...grpc.CallOption) (*v2.GetLatestPrunedOffsetsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetLatestPrunedOffsets") + } + + var r0 *v2.GetLatestPrunedOffsetsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetLatestPrunedOffsetsRequest, ...grpc.CallOption) (*v2.GetLatestPrunedOffsetsResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetLatestPrunedOffsetsRequest, ...grpc.CallOption) *v2.GetLatestPrunedOffsetsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.GetLatestPrunedOffsetsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.GetLatestPrunedOffsetsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StateServiceClient_GetLatestPrunedOffsets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLatestPrunedOffsets' +type StateServiceClient_GetLatestPrunedOffsets_Call struct { + *mock.Call +} + +// GetLatestPrunedOffsets is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.GetLatestPrunedOffsetsRequest +// - opts ...grpc.CallOption +func (_e *StateServiceClient_Expecter) GetLatestPrunedOffsets(ctx interface{}, in interface{}, opts ...interface{}) *StateServiceClient_GetLatestPrunedOffsets_Call { + return &StateServiceClient_GetLatestPrunedOffsets_Call{Call: _e.mock.On("GetLatestPrunedOffsets", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *StateServiceClient_GetLatestPrunedOffsets_Call) Run(run func(ctx context.Context, in *v2.GetLatestPrunedOffsetsRequest, opts ...grpc.CallOption)) *StateServiceClient_GetLatestPrunedOffsets_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.GetLatestPrunedOffsetsRequest), variadicArgs...) + }) + return _c +} + +func (_c *StateServiceClient_GetLatestPrunedOffsets_Call) Return(_a0 *v2.GetLatestPrunedOffsetsResponse, _a1 error) *StateServiceClient_GetLatestPrunedOffsets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StateServiceClient_GetLatestPrunedOffsets_Call) RunAndReturn(run func(context.Context, *v2.GetLatestPrunedOffsetsRequest, ...grpc.CallOption) (*v2.GetLatestPrunedOffsetsResponse, error)) *StateServiceClient_GetLatestPrunedOffsets_Call { + _c.Call.Return(run) + return _c +} + +// GetLedgerEnd provides a mock function with given fields: ctx, in, opts +func (_m *StateServiceClient) GetLedgerEnd(ctx context.Context, in *v2.GetLedgerEndRequest, opts ...grpc.CallOption) (*v2.GetLedgerEndResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetLedgerEnd") + } + + var r0 *v2.GetLedgerEndResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetLedgerEndRequest, ...grpc.CallOption) (*v2.GetLedgerEndResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *v2.GetLedgerEndRequest, ...grpc.CallOption) *v2.GetLedgerEndResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v2.GetLedgerEndResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *v2.GetLedgerEndRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StateServiceClient_GetLedgerEnd_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetLedgerEnd' +type StateServiceClient_GetLedgerEnd_Call struct { + *mock.Call +} + +// GetLedgerEnd is a helper method to define mock.On call +// - ctx context.Context +// - in *v2.GetLedgerEndRequest +// - opts ...grpc.CallOption +func (_e *StateServiceClient_Expecter) GetLedgerEnd(ctx interface{}, in interface{}, opts ...interface{}) *StateServiceClient_GetLedgerEnd_Call { + return &StateServiceClient_GetLedgerEnd_Call{Call: _e.mock.On("GetLedgerEnd", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *StateServiceClient_GetLedgerEnd_Call) Run(run func(ctx context.Context, in *v2.GetLedgerEndRequest, opts ...grpc.CallOption)) *StateServiceClient_GetLedgerEnd_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*v2.GetLedgerEndRequest), variadicArgs...) + }) + return _c +} + +func (_c *StateServiceClient_GetLedgerEnd_Call) Return(_a0 *v2.GetLedgerEndResponse, _a1 error) *StateServiceClient_GetLedgerEnd_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *StateServiceClient_GetLedgerEnd_Call) RunAndReturn(run func(context.Context, *v2.GetLedgerEndRequest, ...grpc.CallOption) (*v2.GetLedgerEndResponse, error)) *StateServiceClient_GetLedgerEnd_Call { + _c.Call.Return(run) + return _c +} + +// NewStateServiceClient creates a new instance of StateServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewStateServiceClient(t interface { + mock.TestingT + Cleanup(func()) +}) *StateServiceClient { + mock := &StateServiceClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/sdk/mocks/logger.go b/sdk/mocks/logger.go new file mode 100644 index 00000000..378e8c46 --- /dev/null +++ b/sdk/mocks/logger.go @@ -0,0 +1,76 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Logger is an autogenerated mock type for the Logger type +type Logger struct { + mock.Mock +} + +type Logger_Expecter struct { + mock *mock.Mock +} + +func (_m *Logger) EXPECT() *Logger_Expecter { + return &Logger_Expecter{mock: &_m.Mock} +} + +// Infof provides a mock function with given fields: template, args +func (_m *Logger) Infof(template string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, template) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// Logger_Infof_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Infof' +type Logger_Infof_Call struct { + *mock.Call +} + +// Infof is a helper method to define mock.On call +// - template string +// - args ...interface{} +func (_e *Logger_Expecter) Infof(template interface{}, args ...interface{}) *Logger_Infof_Call { + return &Logger_Infof_Call{Call: _e.mock.On("Infof", + append([]interface{}{template}, args...)...)} +} + +func (_c *Logger_Infof_Call) Run(run func(template string, args ...interface{})) *Logger_Infof_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(string), variadicArgs...) + }) + return _c +} + +func (_c *Logger_Infof_Call) Return() *Logger_Infof_Call { + _c.Call.Return() + return _c +} + +func (_c *Logger_Infof_Call) RunAndReturn(run func(string, ...interface{})) *Logger_Infof_Call { + _c.Run(run) + return _c +} + +// NewLogger creates a new instance of Logger. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewLogger(t interface { + mock.TestingT + Cleanup(func()) +}) *Logger { + mock := &Logger{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}