diff --git a/go.mod b/go.mod index a837f4d..fb62af5 100644 --- a/go.mod +++ b/go.mod @@ -21,9 +21,11 @@ require ( github.com/smartcontractkit/chainlink-common v0.11.2-0.20260518100439-9564f35fd264 github.com/smartcontractkit/chainlink-common/keystore v1.1.0 github.com/smartcontractkit/chainlink-evm/gethwrappers v0.0.0-20260512150409-b4068bf735e6 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260514104516-a827acdffe43 github.com/smartcontractkit/libocr v0.0.0-20260508200755-99940c85383c github.com/smartcontractkit/wsrpc v0.8.5-0.20250502134807-c57d3d995945 github.com/stretchr/testify v1.11.1 + go.uber.org/zap v1.28.0 golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a golang.org/x/sync v0.20.0 google.golang.org/grpc v1.81.0 @@ -142,7 +144,6 @@ require ( github.com/shirou/gopsutil v3.21.11+incompatible // indirect github.com/smartcontractkit/chain-selectors v1.0.100 // indirect github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 // indirect - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260514104516-a827acdffe43 // indirect github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20260512230622-65f10f4cd305 // indirect github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260512230622-65f10f4cd305 // indirect github.com/smartcontractkit/freeport v0.1.3-0.20250828155247-add56fa28aad // indirect @@ -182,7 +183,6 @@ require ( go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect - go.uber.org/zap v1.28.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect go.yaml.in/yaml/v4 v4.0.0-rc.4 // indirect golang.org/x/crypto v0.51.0 // indirect diff --git a/llo/cre/report_codec.go b/llo/cre/report_codec.go new file mode 100644 index 0000000..caf852a --- /dev/null +++ b/llo/cre/report_codec.go @@ -0,0 +1,166 @@ +package cre + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + + "github.com/shopspring/decimal" + "google.golang.org/protobuf/proto" + + commonds "github.com/smartcontractkit/chainlink-common/pkg/capabilities/datastreams" + capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-protos/cre/go/values" +) + +var _ datastreamsllo.ReportCodec = ReportCodecCapabilityTrigger{} + +type ReportCodecCapabilityTrigger struct { + lggr logger.Logger + donID uint32 +} + +func NewReportCodecCapabilityTrigger(lggr logger.Logger, donID uint32) ReportCodecCapabilityTrigger { + return ReportCodecCapabilityTrigger{lggr, donID} +} + +type ReportCodecCapabilityTriggerMultiplier struct { + Multiplier decimal.Decimal `json:"multiplier"` + StreamID llotypes.StreamID `json:"streamID"` +} + +// Opts format remains unchanged +type ReportCodecCapabilityTriggerOpts struct { + // EXAMPLE + // + // [{streamID: 1000000001, "multiplier":"10000"}, ...] + // + // The total number of streams must be n, where n is the number of + // top-level elements in this ReportCodecCapabilityTriggerMultipliers array + Multipliers []ReportCodecCapabilityTriggerMultiplier `json:"multipliers"` +} + +func (r *ReportCodecCapabilityTriggerOpts) Decode(opts []byte) error { + if len(opts) == 0 { + return nil + } + decoder := json.NewDecoder(bytes.NewReader(opts)) + decoder.DisallowUnknownFields() // Error on unrecognized fields + return decoder.Decode(r) +} + +func (r *ReportCodecCapabilityTriggerOpts) Encode() ([]byte, error) { + return json.Marshal(r) +} + +// Encode a report into a capability trigger report +// the returned byte slice is the marshaled protobuf of [capabilitiespb.OCRTriggerReport] +func (r ReportCodecCapabilityTrigger) Encode(report datastreamsllo.Report, cd llotypes.ChannelDefinition, optsCache *datastreamsllo.OptsCache) ([]byte, error) { + if len(cd.Streams) != len(report.Values) { + // Invariant violation + return nil, fmt.Errorf("capability trigger expected %d streams, got %d", len(cd.Streams), len(report.Values)) + } + if report.Specimen { + // Not supported for now + return nil, errors.New("capability trigger encoder does not currently support specimen reports") + } + + var opts ReportCodecCapabilityTriggerOpts + var err error + opts, err = datastreamsllo.GetOpts[ReportCodecCapabilityTriggerOpts](optsCache, report.ChannelID) + if err != nil { + return nil, fmt.Errorf("failed to get opts: %w", err) + } + + payload := make([]*commonds.LLOStreamDecimal, len(report.Values)) + for i, stream := range report.Values { + var d []byte + switch v := stream.(type) { + case nil: + // Missing observations are nil + case *datastreamsllo.Decimal: + multipliedStreamValue := v.Decimal() + + if len(opts.Multipliers) != 0 { + multipliedStreamValue = multipliedStreamValue.Mul(opts.Multipliers[i].Multiplier) + } + + var err error + d, err = multipliedStreamValue.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("failed to marshal decimal: %w", err) + } + default: + return nil, fmt.Errorf("only decimal StreamValues are supported, got: %T", stream) + } + payload[i] = &commonds.LLOStreamDecimal{ + StreamID: cd.Streams[i].StreamID, + Decimal: d, + } + } + + ste := commonds.LLOStreamsTriggerEvent{ + Payload: payload, + ObservationTimestampNanoseconds: report.ObservationTimestampNanoseconds, + } + outputs, err := values.WrapMap(ste) + if err != nil { + return nil, fmt.Errorf("failed to wrap map: %w", err) + } + p := &capabilitiespb.OCRTriggerReport{ + EventID: r.EventID(report), + Timestamp: report.ObservationTimestampNanoseconds, + Outputs: values.ProtoMap(outputs), + } + + b, err := proto.MarshalOptions{Deterministic: true}.Marshal(p) + if err != nil { + return nil, fmt.Errorf("failed to marshal capability trigger report: %w", err) + } + return b, nil +} + +func (r ReportCodecCapabilityTrigger) Verify(cd llotypes.ChannelDefinition) error { + opts := new(ReportCodecCapabilityTriggerOpts) + if err := opts.Decode(cd.Opts); err != nil { + return fmt.Errorf("invalid Opts, got: %q; %w", cd.Opts, err) + } + if opts != nil && opts.Multipliers != nil { + if len(opts.Multipliers) != len(cd.Streams) { + return fmt.Errorf("multipliers length %d != StreamValues length %d", len(opts.Multipliers), len(cd.Streams)) + } + + for i, stream := range cd.Streams { + if opts.Multipliers[i].StreamID != stream.StreamID { + return fmt.Errorf("LLO StreamID %d mismatched with Multiplier StreamID %d", stream.StreamID, opts.Multipliers[i].StreamID) + } + if !(opts.Multipliers[i].Multiplier.IsInteger()) { + return fmt.Errorf("multiplier for StreamID %d must be an integer", opts.Multipliers[i].StreamID) + } + if opts.Multipliers[i].Multiplier.IsZero() { + return fmt.Errorf("multiplier for StreamID %d can't be zero", opts.Multipliers[i].StreamID) + } + if opts.Multipliers[i].Multiplier.IsNegative() { + return fmt.Errorf("multiplier for StreamID %d can't be negative", opts.Multipliers[i].StreamID) + } + } + } + return nil +} + +// EventID is expected to uniquely identify a (don, round) +func (r ReportCodecCapabilityTrigger) EventID(report datastreamsllo.Report) string { + return fmt.Sprintf("streams_%d_%d", r.donID, report.ObservationTimestampNanoseconds) +} + +func (r ReportCodecCapabilityTrigger) ParseOpts(opts []byte) (any, error) { + var o ReportCodecCapabilityTriggerOpts + if err := o.Decode(opts); err != nil { + return nil, fmt.Errorf("failed to decode opts; got: '%s'; %w", opts, err) + } + return o, nil +} diff --git a/llo/cre/report_codec_test.go b/llo/cre/report_codec_test.go new file mode 100644 index 0000000..1ef2396 --- /dev/null +++ b/llo/cre/report_codec_test.go @@ -0,0 +1,468 @@ +package cre + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/wsrpc/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-data-streams/llo" + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" +) + +func Test_ReportCodec(t *testing.T) { + t.Run("Encode: Without Opts SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + optsCache := datastreamsllo.NewOptsCache() + r := datastreamsllo.Report{ + ConfigDigest: types.ConfigDigest{1, 2, 3}, + SeqNr: 32, + ChannelID: llotypes.ChannelID(31), + ValidAfterNanoseconds: 28, + ObservationTimestampNanoseconds: 34, + Values: []llo.StreamValue{llo.ToDecimal(decimal.NewFromInt(35)), llo.ToDecimal(decimal.NewFromInt(36))}, + Specimen: false, + } + optsCache.Set(r.ChannelID, []byte{}) + encoded, err := c.Encode(r, llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + }, + }, optsCache) + require.NoError(t, err) + + var pbuf capabilitiespb.OCRTriggerReport + err = proto.Unmarshal(encoded, &pbuf) + require.NoError(t, err) + + assert.Equal(t, "streams_1_34", pbuf.EventID) + assert.Equal(t, uint64(34), pbuf.Timestamp) + require.Len(t, pbuf.Outputs.Fields, 2) + assert.Equal(t, &pb.Value_Uint64Value{Uint64Value: 34}, pbuf.Outputs.Fields["ObservationTimestampNanoseconds"].Value) + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields, 2) + + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields, 2) + decimalBytes := pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields["Decimal"].Value.(*pb.Value_BytesValue).BytesValue + d := decimal.Decimal{} + require.NoError(t, (&d).UnmarshalBinary(decimalBytes)) + assert.Equal(t, "35", d.String()) + assert.Equal(t, int64(1), pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields["StreamID"].Value.(*pb.Value_Int64Value).Int64Value) + + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields, 2) + decimalBytes = pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields["Decimal"].Value.(*pb.Value_BytesValue).BytesValue + d = decimal.Decimal{} + require.NoError(t, (&d).UnmarshalBinary(decimalBytes)) + assert.Equal(t, "36", d.String()) + assert.Equal(t, int64(2), pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields["StreamID"].Value.(*pb.Value_Int64Value).Int64Value) + }) + t.Run("Encode: With Opts SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + r := datastreamsllo.Report{ + ConfigDigest: types.ConfigDigest{1, 2, 3}, + SeqNr: 32, + ChannelID: llotypes.ChannelID(31), + ValidAfterNanoseconds: 28, + ObservationTimestampNanoseconds: 34, + Values: []llo.StreamValue{llo.ToDecimal(decimal.NewFromInt(35)), llo.ToDecimal(decimal.NewFromInt(36)), llo.ToDecimal(decimal.NewFromInt(37))}, + Specimen: false, + } + + multiplier1, err := decimal.NewFromString("1") + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("1000000000000000000") // 10^18 + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + cache := datastreamsllo.NewOptsCache() + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + cache.Set(r.ChannelID, opts) + require.NoError(t, err) + encoded, err := c.Encode(r, llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + {StreamID: 3}, + }, + Opts: opts, + }, cache) + require.NoError(t, err) + + var pbuf capabilitiespb.OCRTriggerReport + err = proto.Unmarshal(encoded, &pbuf) + require.NoError(t, err) + + assert.Equal(t, "streams_1_34", pbuf.EventID) + assert.Equal(t, uint64(34), pbuf.Timestamp) + require.Len(t, pbuf.Outputs.Fields, 2) + assert.Equal(t, &pb.Value_Uint64Value{Uint64Value: 34}, pbuf.Outputs.Fields["ObservationTimestampNanoseconds"].Value) + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields, 3) + + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields, 2) + decimalBytes := pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields["Decimal"].Value.(*pb.Value_BytesValue).BytesValue + d := decimal.Decimal{} + require.NoError(t, (&d).UnmarshalBinary(decimalBytes)) + assert.Equal(t, "35", d.String()) + assert.Equal(t, int64(1), pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[0].Value.(*pb.Value_MapValue).MapValue.Fields["StreamID"].Value.(*pb.Value_Int64Value).Int64Value) + + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields, 2) + decimalBytes = pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields["Decimal"].Value.(*pb.Value_BytesValue).BytesValue + d = decimal.Decimal{} + require.NoError(t, (&d).UnmarshalBinary(decimalBytes)) + assert.Equal(t, "36000000000000000000", d.String()) + assert.Equal(t, int64(2), pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[1].Value.(*pb.Value_MapValue).MapValue.Fields["StreamID"].Value.(*pb.Value_Int64Value).Int64Value) + + require.Len(t, pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[2].Value.(*pb.Value_MapValue).MapValue.Fields, 2) + decimalBytes = pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[2].Value.(*pb.Value_MapValue).MapValue.Fields["Decimal"].Value.(*pb.Value_BytesValue).BytesValue + d = decimal.Decimal{} + require.NoError(t, (&d).UnmarshalBinary(decimalBytes)) + assert.Equal(t, "37000000", d.String()) + assert.Equal(t, int64(3), pbuf.Outputs.Fields["Payload"].Value.(*pb.Value_ListValue).ListValue.Fields[2].Value.(*pb.Value_MapValue).MapValue.Fields["StreamID"].Value.(*pb.Value_Int64Value).Int64Value) + }) + t.Run("Decode: With Opts SUCCESS", func(t *testing.T) { + optBytes := []byte{123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 115, 34, 58, 91, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 48, 49, 48, 49, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 48, 49, 48, 50, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 49, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 50, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 51, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 52, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 53, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 54, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 55, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 56, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 48, 57, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 48, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 49, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 50, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 51, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 52, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 53, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 54, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 55, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 56, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 49, 57, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 50, 48, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 50, 49, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 50, 50, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 50, 51, 125, 44, 123, 34, 109, 117, 108, 116, 105, 112, 108, 105, 101, 114, 34, 58, 34, 49, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 34, 44, 34, 115, 116, 114, 101, 97, 109, 73, 68, 34, 58, 49, 48, 50, 48, 48, 48, 49, 48, 50, 52, 125, 93, 125} + + opts := &ReportCodecCapabilityTriggerOpts{} + err := opts.Decode(optBytes) + + require.NoError(t, err) + + // Verify the decoded opts structure contains expected multipliers and stream IDs + require.Len(t, opts.Multipliers, 26) + + expectedMultiplier, err := decimal.NewFromString("1000000000000000000") // 10^18 + require.NoError(t, err) + + expectedStreamIDs := []uint32{ + 1020000101, 1020000102, 1020001001, 1020001002, 1020001003, 1020001004, + 1020001005, 1020001006, 1020001007, 1020001008, 1020001009, 1020001010, + 1020001011, 1020001012, 1020001013, 1020001014, 1020001015, 1020001016, + 1020001017, 1020001018, 1020001019, 1020001020, 1020001021, 1020001022, + 1020001023, 1020001024, + } + + for i, multiplier := range opts.Multipliers { + assert.True(t, multiplier.Multiplier.Equal(expectedMultiplier), "Multiplier %d should be %s", i, expectedMultiplier.String()) + assert.Equal(t, expectedStreamIDs[i], multiplier.StreamID, "StreamID %d should be %d", i, expectedStreamIDs[i]) + } + }) + t.Run("Verify: Without Opts SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + err := c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + }, + }, + ) + require.NoError(t, err) + }) + t.Run("Verify: Misaligned Multiplier StreamIDs FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("1") + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("1000000000000000000") // 10^18 + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 3}, + {Multiplier: multiplier3, StreamID: 2}, + }, + }).Encode() + require.NoError(t, err) + err = c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + {StreamID: 3}, + }, + Opts: opts, + }, + ) + require.EqualError(t, err, "LLO StreamID 2 mismatched with Multiplier StreamID 3") + }) + t.Run("Verify: Multiplier isn't an integer FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("123.4567") + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("89.01234") + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + require.NoError(t, err) + err = c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + {StreamID: 3}, + }, + Opts: opts, + }, + ) + require.EqualError(t, err, "multiplier for StreamID 1 must be an integer") + }) + t.Run("Verify: Multiplier is zero FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("0") + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("0") + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + require.NoError(t, err) + err = c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + {StreamID: 3}, + }, + Opts: opts, + }, + ) + require.EqualError(t, err, "multiplier for StreamID 1 can't be zero") + }) + t.Run("Verify: Multiplier is negative FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("-1000000000000000000") // -10^18 + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("-1") + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + require.NoError(t, err) + err = c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + {StreamID: 3}, + }, + Opts: opts, + }, + ) + require.EqualError(t, err, "multiplier for StreamID 1 can't be negative") + }) + t.Run("Verify: Multipliers length, StreamValues length mismatch FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("1000000000000000000") // 10^18 + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("1") + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + opts, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + require.NoError(t, err) + + err = c.Verify( + llotypes.ChannelDefinition{ + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 3}, + }, + Opts: opts, + }, + ) + require.EqualError(t, err, "multipliers length 3 != StreamValues length 2") + }) +} + +func TestReportCodecCapabilityTrigger_ParseOpts(t *testing.T) { + t.Run("ParseOpts: Valid opts with multipliers SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + multiplier1, err := decimal.NewFromString("1") + require.NoError(t, err) + multiplier2, err := decimal.NewFromString("1000000000000000000") // 10^18 + require.NoError(t, err) + multiplier3, err := decimal.NewFromString("1000000") // 10^6 + require.NoError(t, err) + + optsBytes, err := (&ReportCodecCapabilityTriggerOpts{ + Multipliers: []ReportCodecCapabilityTriggerMultiplier{ + {Multiplier: multiplier1, StreamID: 1}, + {Multiplier: multiplier2, StreamID: 2}, + {Multiplier: multiplier3, StreamID: 3}, + }, + }).Encode() + require.NoError(t, err) + + parsed, err := c.ParseOpts(optsBytes) + require.NoError(t, err) + require.NotNil(t, parsed) + + opts, ok := parsed.(ReportCodecCapabilityTriggerOpts) + require.True(t, ok, "parsed result should be ReportCodecCapabilityTriggerOpts") + + require.Len(t, opts.Multipliers, 3) + assert.True(t, opts.Multipliers[0].Multiplier.Equal(multiplier1)) + assert.Equal(t, uint32(1), opts.Multipliers[0].StreamID) + assert.True(t, opts.Multipliers[1].Multiplier.Equal(multiplier2)) + assert.Equal(t, uint32(2), opts.Multipliers[1].StreamID) + assert.True(t, opts.Multipliers[2].Multiplier.Equal(multiplier3)) + assert.Equal(t, uint32(3), opts.Multipliers[2].StreamID) + }) + + t.Run("ParseOpts: Empty opts nil SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + parsed, err := c.ParseOpts(nil) + require.NoError(t, err) + require.NotNil(t, parsed) + + opts, ok := parsed.(ReportCodecCapabilityTriggerOpts) + require.True(t, ok, "parsed result should be ReportCodecCapabilityTriggerOpts") + + assert.Nil(t, opts.Multipliers) + }) + + t.Run("ParseOpts: Empty opts empty byte slice SUCCESS", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + parsed, err := c.ParseOpts([]byte{}) + require.NoError(t, err) + require.NotNil(t, parsed) + + opts, ok := parsed.(ReportCodecCapabilityTriggerOpts) + require.True(t, ok, "parsed result should be ReportCodecCapabilityTriggerOpts") + + assert.Nil(t, opts.Multipliers) + }) + + t.Run("ParseOpts: Invalid JSON FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + invalidJSON := []byte("{invalid json}") + parsed, err := c.ParseOpts(invalidJSON) + + require.Error(t, err) + require.Nil(t, parsed) + assert.Contains(t, err.Error(), "failed to decode opts") + assert.Contains(t, err.Error(), string(invalidJSON)) + }) + + t.Run("ParseOpts: JSON with unknown fields FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + optsWithUnknownField := []byte(`{"multipliers":[],"unknown":"field"}`) + parsed, err := c.ParseOpts(optsWithUnknownField) + + require.Error(t, err) + require.Nil(t, parsed) + assert.Contains(t, err.Error(), "failed to decode opts") + assert.Contains(t, err.Error(), string(optsWithUnknownField)) + }) + + t.Run("ParseOpts: Wrong JSON structure multipliers as string FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + wrongTypeJSON := []byte(`{"multipliers":"not an array"}`) + parsed, err := c.ParseOpts(wrongTypeJSON) + + require.Error(t, err) + require.Nil(t, parsed) + assert.Contains(t, err.Error(), "failed to decode opts") + assert.Contains(t, err.Error(), string(wrongTypeJSON)) + }) + + t.Run("ParseOpts: Wrong JSON structure invalid multiplier type FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + invalidMultiplierJSON := []byte(`{"multipliers":[{"multiplier":"not a number","streamID":1}]}`) + parsed, err := c.ParseOpts(invalidMultiplierJSON) + + require.Error(t, err) + require.Nil(t, parsed) + assert.Contains(t, err.Error(), "failed to decode opts") + assert.Contains(t, err.Error(), string(invalidMultiplierJSON)) + }) + + t.Run("ParseOpts: Wrong JSON structure invalid streamID type FAIL", func(t *testing.T) { + donID := uint32(1) + c := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + + invalidStreamIDJSON := []byte(`{"multipliers":[{"multiplier":"1000","streamID":"not a number"}]}`) + parsed, err := c.ParseOpts(invalidStreamIDJSON) + + require.Error(t, err) + require.Nil(t, parsed) + assert.Contains(t, err.Error(), "failed to decode opts") + assert.Contains(t, err.Error(), string(invalidStreamIDJSON)) + }) +} diff --git a/llo/cre/transmitter.go b/llo/cre/transmitter.go new file mode 100644 index 0000000..3434821 --- /dev/null +++ b/llo/cre/transmitter.go @@ -0,0 +1,267 @@ +package cre + +import ( + "context" + "fmt" + "strconv" + "sync" + + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-protos/cre/go/values" +) + +const ( + defaultCapabilityName = "streams-trigger" + defaultCapabilityVersion = "2.0.0" // v2 = LLO + defaultTickerResolutionMs = 1000 + defaultSendChannelBufferSize = 1000 +) + +type Transmitter interface { + llotypes.Transmitter + services.Service +} + +type TransmitterConfig struct { + Logger logger.Logger `json:"-"` + CapabilitiesRegistry coretypes.CapabilitiesRegistry `json:"-"` + DonID uint32 `json:"-"` + + TriggerCapabilityName string `json:"triggerCapabilityName"` + TriggerCapabilityVersion string `json:"triggerCapabilityVersion"` + TriggerTickerMinResolutionMs int `json:"triggerTickerMinResolutionMs"` + TriggerSendChannelBufferSize int `json:"triggerSendChannelBufferSize"` +} + +var _ Transmitter = &transmitter{} +var _ capabilities.TriggerCapability = &transmitter{} + +type transmitter struct { + services.Service + eng *services.Engine + capabilities.CapabilityInfo + + config TransmitterConfig + fromAccount ocr2types.Account + registry coretypes.CapabilitiesRegistry + + subscribers map[string]*subscriber + lastReportMs uint64 + mu sync.Mutex +} + +type subscriber struct { + ch chan<- capabilities.TriggerResponse + workflowID string + config LLOTriggerConfig +} + +func (c TransmitterConfig) NewTransmitter() (*transmitter, error) { + return c.newTransmitter(c.Logger) +} + +func (c TransmitterConfig) newTransmitter(lggr logger.Logger) (*transmitter, error) { + t := &transmitter{ + config: c, + fromAccount: ocr2types.Account(lggr.Name() + strconv.FormatUint(uint64(c.DonID), 10)), + registry: c.CapabilitiesRegistry, + subscribers: make(map[string]*subscriber), + } + if t.config.TriggerCapabilityName == "" { + t.config.TriggerCapabilityName = defaultCapabilityName + } + if t.config.TriggerCapabilityVersion == "" { + t.config.TriggerCapabilityVersion = defaultCapabilityVersion + } + if t.config.TriggerTickerMinResolutionMs == 0 { + t.config.TriggerTickerMinResolutionMs = defaultTickerResolutionMs + } + if t.config.TriggerSendChannelBufferSize == 0 { + t.config.TriggerSendChannelBufferSize = defaultSendChannelBufferSize + } + + capInfo, err := capabilities.NewCapabilityInfo( + // TODO(CAPPL-645): add labels + t.config.TriggerCapabilityName+"@"+t.config.TriggerCapabilityVersion, + capabilities.CapabilityTypeTrigger, + "Streams LLO Trigger", + ) + if err != nil { + return nil, err + } + t.CapabilityInfo = capInfo + + t.Service, t.eng = services.Config{ + Name: "CRETransmitter", + Start: t.start, + Close: t.close, + }.NewServiceEngine(lggr) + + return t, nil +} + +func (t *transmitter) start(ctx context.Context) error { + return t.registry.Add(ctx, t) +} + +func (t *transmitter) close() error { + return t.registry.Remove(context.Background(), t.ID) +} + +func (t *transmitter) FromAccount(context.Context) (ocr2types.Account, error) { + return t.fromAccount, nil +} + +func (t *transmitter) Transmit( + ctx context.Context, + cd ocr2types.ConfigDigest, + seqNr uint64, + report ocr3types.ReportWithInfo[llotypes.ReportInfo], + sigs []types.AttributedOnchainSignature, +) error { + switch report.Info.ReportFormat { + case llotypes.ReportFormatCapabilityTrigger: + default: + // NOTE: Silently ignore non-capability format reports here. All + // channels are broadcast to all transmitters but this transmitter only + // cares about channels of type ReportFormatCapabilityTrigger + return nil + } + switch report.Info.LifeCycleStage { + case datastreamsllo.LifeCycleStageProduction: + default: + // NOTE: Ignore retirement and staging reports; for now we assume that + // we only care about sending production reports. + // + // Support could be added in future e.g. for verifying blue-green + // deploys etc. + return nil + } + + capSigs := make([]capabilities.OCRAttributedOnchainSignature, len(sigs)) + for i, sig := range sigs { + capSigs[i] = capabilities.OCRAttributedOnchainSignature{ + Signer: uint32(sig.Signer), + Signature: sig.Signature, + } + } + ev := &capabilities.OCRTriggerEvent{ + ConfigDigest: cd[:], + SeqNr: seqNr, + Report: report.Report, + Sigs: capSigs, + } + return t.processNewEvent(ctx, ev) +} + +func (t *transmitter) processNewEvent(ctx context.Context, event *capabilities.OCRTriggerEvent) error { + // unmarshal signed report to extract timestamp and eventID + p := &capabilitiespb.OCRTriggerReport{} + err := proto.Unmarshal(event.Report, p) + if err != nil { + return fmt.Errorf("failed to unmarshal OCRTriggerReport: %w", err) + } + + t.mu.Lock() + defer t.mu.Unlock() + tsMs := p.Timestamp / 1000000 // nanoseconds -> milliseconds + if tsMs/uint64(t.config.TriggerTickerMinResolutionMs) == t.lastReportMs/uint64(t.config.TriggerTickerMinResolutionMs) { //nolint:gosec // disable G115 + // ignore reports that are too frequent + return nil + } + t.lastReportMs = tsMs + alignedTsMs := tsMs - tsMs%uint64(t.config.TriggerTickerMinResolutionMs) //nolint:gosec // disable G115 + o, err := event.ToMap() + if err != nil { + return fmt.Errorf("failed to convert OCRTriggerEvent to map: %w", err) + } + capResponse := capabilities.TriggerResponse{ + Event: capabilities.TriggerEvent{ + TriggerType: t.ID, + ID: p.EventID, + Outputs: o, + }, + } + + t.eng.Debugw("ProcessReport pushing event", "eventID", p.EventID, "tsMs", tsMs, "alignedTsMs", alignedTsMs) + nIncludedSubscribers := 0 + for _, sub := range t.subscribers { + if alignedTsMs%sub.config.MaxFrequencyMs == 0 { + // include this subscriber + select { + case sub.ch <- capResponse: + case <-ctx.Done(): + t.eng.Error("context done, dropping event") + return ctx.Err() + default: + // drop event if channel is full - processNewEvent() should be non-blocking + t.eng.Errorw("subscriber channel full, dropping event", "eventID", p.EventID, "workflowID", sub.workflowID) + } + nIncludedSubscribers++ + } + } + t.eng.Debugw("ProcessReport done", "eventID", p.EventID, "nIncludedSubscribers", nIncludedSubscribers) + return nil +} + +func (t *transmitter) AckEvent(ctx context.Context, triggerID string, eventID string, method string) error { + return nil +} + +func (t *transmitter) RegisterTrigger(ctx context.Context, req capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { + t.mu.Lock() + defer t.mu.Unlock() + + config, err := validateConfig(req.Config, &t.config) + if err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + if _, ok := t.subscribers[req.TriggerID]; ok { + return nil, fmt.Errorf("triggerId %s already registered", t.ID) + } + + ch := make(chan capabilities.TriggerResponse, defaultSendChannelBufferSize) + t.subscribers[req.TriggerID] = + &subscriber{ + ch: ch, + workflowID: req.Metadata.WorkflowID, + config: *config, + } + return ch, nil +} + +func validateConfig(registerConfig *values.Map, capabilityConfig *TransmitterConfig) (*LLOTriggerConfig, error) { + cfg := &LLOTriggerConfig{} + if err := registerConfig.UnwrapTo(cfg); err != nil { + return nil, err + } + if int64(cfg.MaxFrequencyMs)%int64(capabilityConfig.TriggerTickerMinResolutionMs) != 0 { //nolint:gosec // disable G115 + return nil, fmt.Errorf("MaxFrequencyMs must be a multiple of %d", capabilityConfig.TriggerTickerMinResolutionMs) + } + return cfg, nil +} + +func (t *transmitter) UnregisterTrigger(ctx context.Context, req capabilities.TriggerRegistrationRequest) error { + t.mu.Lock() + defer t.mu.Unlock() + + subscriber, ok := t.subscribers[req.TriggerID] + if !ok { + return fmt.Errorf("triggerId %s not registered", t.ID) + } + close(subscriber.ch) + delete(t.subscribers, req.TriggerID) + return nil +} diff --git a/llo/cre/transmitter_test.go b/llo/cre/transmitter_test.go new file mode 100644 index 0000000..6b4795c --- /dev/null +++ b/llo/cre/transmitter_test.go @@ -0,0 +1,110 @@ +package cre + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-data-streams/llo" + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-protos/cre/go/values" +) + +const ( + donID = 4 +) + +func Test_Transmitter(t *testing.T) { + digest := types.ConfigDigest{1, 2, 3} + sigs := []types.AttributedOnchainSignature{ + { + Signer: 6, + Signature: []byte{4, 5, 6}, + }, + } + + cfg := TransmitterConfig{ + Logger: logger.Test(t), + CapabilitiesRegistry: nil, + DonID: donID, + } + tr, err := cfg.NewTransmitter() + require.NoError(t, err) + + t.Run("invalid config", func(t *testing.T) { + req := buildRegistrationRequest(t, "myID123", []LLOStreamID{12345, 67890}, 2300) + _, err = tr.RegisterTrigger(t.Context(), req) + require.Error(t, err) + }) + + t.Run("two registrations", func(t *testing.T) { + req1 := buildRegistrationRequest(t, "wf1_trigger1", []LLOStreamID{12345, 67890}, 1000) + req2 := buildRegistrationRequest(t, "wf2_trigger1", []LLOStreamID{67890}, 3000) + respCh1, err := tr.RegisterTrigger(t.Context(), req1) + require.NoError(t, err) + respCh2, err := tr.RegisterTrigger(t.Context(), req2) + require.NoError(t, err) + + require.NoError(t, tr.Transmit(t.Context(), digest, 1, encodeReport(t, 1023000000), sigs)) + require.NoError(t, tr.Transmit(t.Context(), digest, 2, encodeReport(t, 1803000000), sigs)) + require.NoError(t, tr.Transmit(t.Context(), digest, 3, encodeReport(t, 2101000000), sigs)) + require.NoError(t, tr.Transmit(t.Context(), digest, 4, encodeReport(t, 3456000000), sigs)) + require.NoError(t, tr.Transmit(t.Context(), digest, 5, encodeReport(t, 4502000000), sigs)) + require.NoError(t, tr.Transmit(t.Context(), digest, 6, encodeReport(t, 4777000000), sigs)) + require.Len(t, respCh1, 4) // every second + require.Len(t, respCh2, 1) // every 3 seconds + }) +} + +func buildRegistrationRequest(t *testing.T, triggerID string, streamIDs []LLOStreamID, maxFrequencyMs uint64) capabilities.TriggerRegistrationRequest { + cfg := &LLOTriggerConfig{ + StreamIDs: streamIDs, + MaxFrequencyMs: maxFrequencyMs, + } + wrappedCfg, err := values.WrapMap(cfg) + require.NoError(t, err) + + return capabilities.TriggerRegistrationRequest{ + TriggerID: triggerID, + Config: wrappedCfg, + } +} + +func encodeReport(t *testing.T, timestamp uint64) ocr3types.ReportWithInfo[llotypes.ReportInfo] { + codec := NewReportCodecCapabilityTrigger(logger.Test(t), donID) + rep := llo.Report{ + ConfigDigest: types.ConfigDigest{1, 2, 3}, + SeqNr: 32, + ChannelID: llotypes.ChannelID(31), + ValidAfterNanoseconds: 28, + ObservationTimestampNanoseconds: timestamp, + Values: []llo.StreamValue{llo.ToDecimal(decimal.NewFromInt(35)), llo.ToDecimal(decimal.NewFromInt(36))}, + Specimen: false, + } + cd := llotypes.ChannelDefinition{ + ReportFormat: llotypes.ReportFormatCapabilityTrigger, + Streams: []llotypes.Stream{ + {StreamID: 1}, + {StreamID: 2}, + }, + } + cache := datastreamsllo.NewOptsCache() + cache.Set(rep.ChannelID, []byte{}) + rawReport, err := codec.Encode(rep, cd, cache) + require.NoError(t, err) + + return ocr3types.ReportWithInfo[llotypes.ReportInfo]{ + Report: rawReport, + Info: llotypes.ReportInfo{ + LifeCycleStage: datastreamsllo.LifeCycleStageProduction, + ReportFormat: llotypes.ReportFormatCapabilityTrigger, + }, + } +} diff --git a/llo/cre/types.go b/llo/cre/types.go new file mode 100644 index 0000000..b5a44d8 --- /dev/null +++ b/llo/cre/types.go @@ -0,0 +1,11 @@ +package cre + +type LLOStreamID uint32 + +type LLOTriggerConfig struct { + // The IDs of the data feeds (LLO streams) that will be included in the trigger event. + StreamIDs []LLOStreamID `json:"streamIds" yaml:"streamIds" mapstructure:"streamIds"` + + // The interval in seconds after which a new trigger event is generated. + MaxFrequencyMs uint64 `json:"maxFrequencyMs" yaml:"maxFrequencyMs" mapstructure:"maxFrequencyMs"` +} diff --git a/llo/retirement/never_retire_cache.go b/llo/retirement/never_retire_cache.go new file mode 100644 index 0000000..5628713 --- /dev/null +++ b/llo/retirement/never_retire_cache.go @@ -0,0 +1,17 @@ +package retirement + +import ( + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" +) + +type neverShouldRetireCache struct{} + +func NewNeverShouldRetireCache() llotypes.ShouldRetireCache { + return &neverShouldRetireCache{} +} + +func (n *neverShouldRetireCache) ShouldRetire(digest ocrtypes.ConfigDigest) (bool, error) { + return false, nil +} diff --git a/llo/retirement/null_retirement_report_cache.go b/llo/retirement/null_retirement_report_cache.go new file mode 100644 index 0000000..a8dbedb --- /dev/null +++ b/llo/retirement/null_retirement_report_cache.go @@ -0,0 +1,25 @@ +package retirement + +import ( + "context" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" +) + +type NullRetirementReportCache struct{} + +func (n *NullRetirementReportCache) StoreAttestedRetirementReport(ctx context.Context, cd ocr2types.ConfigDigest, retirementReport []byte, sigs []types.AttributedOnchainSignature) error { + return nil +} +func (n *NullRetirementReportCache) StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error { + return nil +} +func (n *NullRetirementReportCache) AttestedRetirementReport(predecessorConfigDigest ocr2types.ConfigDigest) ([]byte, error) { + return nil, nil +} +func (n *NullRetirementReportCache) CheckAttestedRetirementReport(predecessorConfigDigest ocr2types.ConfigDigest, attestedRetirementReport []byte) (datastreamsllo.RetirementReport, error) { + return datastreamsllo.RetirementReport{}, nil +} diff --git a/llo/retirement/plugin_scoped_retirement_report_cache.go b/llo/retirement/plugin_scoped_retirement_report_cache.go new file mode 100644 index 0000000..b4cdf9d --- /dev/null +++ b/llo/retirement/plugin_scoped_retirement_report_cache.go @@ -0,0 +1,85 @@ +package retirement + +import ( + "fmt" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "google.golang.org/protobuf/proto" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + + llo "github.com/smartcontractkit/chainlink-data-streams/llo" + retirement "github.com/smartcontractkit/chainlink-data-streams/llo/reportcodecs/retirement" +) + +type RetirementReportVerifier interface { + Verify(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool +} + +// PluginScopedRetirementReportCache is a wrapper around RetirementReportCache +// that implements CheckAttestedRetirementReport +// +// This is necessary because while config digest keys are globally unique, +// different plugins may implement different signing/verification strategies +var _ llo.PredecessorRetirementReportCache = &pluginScopedRetirementReportCache{} + +type pluginScopedRetirementReportCache struct { + rrc RetirementReportCacheReader + verifier RetirementReportVerifier + codec llo.RetirementReportCodec +} + +func NewPluginScopedRetirementReportCache(rrc RetirementReportCacheReader, verifier RetirementReportVerifier, codec llo.RetirementReportCodec) llo.PredecessorRetirementReportCache { + return &pluginScopedRetirementReportCache{ + rrc: rrc, + verifier: verifier, + codec: codec, + } +} + +func (pr *pluginScopedRetirementReportCache) CheckAttestedRetirementReport(predecessorConfigDigest ocr2types.ConfigDigest, serializedAttestedRetirementReport []byte) (llo.RetirementReport, error) { + config, exists := pr.rrc.Config(predecessorConfigDigest) + if !exists { + return llo.RetirementReport{}, fmt.Errorf("Verify failed; predecessor config not found for config digest %x", predecessorConfigDigest[:]) + } + + var arr retirement.AttestedRetirementReport + if err := proto.Unmarshal(serializedAttestedRetirementReport, &arr); err != nil { + return llo.RetirementReport{}, fmt.Errorf("Verify failed; failed to unmarshal protobuf: %w", err) + } + + validSigs := 0 + for _, sig := range arr.Sigs { + // #nosec G115 + if sig.Signer >= uint32(len(config.Signers)) { + return llo.RetirementReport{}, fmt.Errorf("Verify failed; attested report signer index out of bounds (got: %d, max: %d)", sig.Signer, len(config.Signers)-1) + } + signer := config.Signers[sig.Signer] + valid := pr.verifier.Verify(types.OnchainPublicKey(signer), predecessorConfigDigest, arr.SeqNr, ocr3types.ReportWithInfo[llotypes.ReportInfo]{ + Report: arr.RetirementReport, + Info: llotypes.ReportInfo{ReportFormat: llotypes.ReportFormatRetirement}, + }, sig.Signature) + if !valid { + continue + } + validSigs++ + } + if validSigs <= int(config.F) { + return llo.RetirementReport{}, fmt.Errorf("Verify failed; not enough valid signatures (got: %d, need: %d)", validSigs, config.F+1) + } + decoded, err := pr.codec.Decode(arr.RetirementReport) + if err != nil { + return llo.RetirementReport{}, fmt.Errorf("Verify failed; failed to decode retirement report: %w", err) + } + return decoded, nil +} + +func (pr *pluginScopedRetirementReportCache) AttestedRetirementReport(predecessorConfigDigest ocr2types.ConfigDigest) ([]byte, error) { + arr, exists := pr.rrc.AttestedRetirementReport(predecessorConfigDigest) + if !exists { + return nil, nil + } + return arr, nil +} diff --git a/llo/retirement/plugin_scoped_retirement_report_cache_test.go b/llo/retirement/plugin_scoped_retirement_report_cache_test.go new file mode 100644 index 0000000..eb47c51 --- /dev/null +++ b/llo/retirement/plugin_scoped_retirement_report_cache_test.go @@ -0,0 +1,186 @@ +package retirement + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + datastreamsllo "github.com/smartcontractkit/chainlink-data-streams/llo" + + "github.com/smartcontractkit/chainlink-data-streams/llo/reportcodecs/retirement" +) + +func FuzzPluginScopedRetirementReportCache_CheckAttestedRetirementReport(f *testing.F) { + f.Add([]byte("not a protobuf")) + f.Add([]byte{0x0a, 0x00}) // empty protobuf + f.Add([]byte{0x0a, 0x02, 0x08, 0x01}) // invalid protobuf + f.Add(([]byte)(nil)) + f.Add([]byte{}) + + rrc := &mockRetirementReportCache{} + v := &mockVerifier{} + c := &mockCodec{} + psrrc := NewPluginScopedRetirementReportCache(rrc, v, c) + + exampleDigest := ocr2types.ConfigDigest{1} + + f.Fuzz(func(t *testing.T, data []byte) { + psrrc.CheckAttestedRetirementReport(exampleDigest, data) //nolint:errcheck // test that it doesn't panic, don't care about errors + }) +} + +type mockRetirementReportCache struct { + arr []byte + cfg Config + exists bool +} + +func (m *mockRetirementReportCache) AttestedRetirementReport(digest ocr2types.ConfigDigest) ([]byte, bool) { + return m.arr, m.exists +} +func (m *mockRetirementReportCache) Config(cd ocr2types.ConfigDigest) (Config, bool) { + return m.cfg, m.exists +} + +type mockVerifier struct { + verify func(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool +} + +func (m *mockVerifier) Verify(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool { + return m.verify(key, digest, seqNr, r, signature) +} + +type mockCodec struct { + decode func([]byte) (datastreamsllo.RetirementReport, error) +} + +func (m *mockCodec) Encode(datastreamsllo.RetirementReport) ([]byte, error) { + panic("not implemented") +} +func (m *mockCodec) Decode(b []byte) (datastreamsllo.RetirementReport, error) { + return m.decode(b) +} + +func Test_PluginScopedRetirementReportCache(t *testing.T) { + rrc := &mockRetirementReportCache{} + v := &mockVerifier{} + c := &mockCodec{} + psrrc := NewPluginScopedRetirementReportCache(rrc, v, c) + exampleDigest := ocr2types.ConfigDigest{1} + exampleDigest2 := ocr2types.ConfigDigest{2} + + exampleUnattestedSerializedRetirementReport := []byte("foo example unattested retirement report") + + validArr := retirement.AttestedRetirementReport{ + RetirementReport: exampleUnattestedSerializedRetirementReport, + SeqNr: 42, + Sigs: []*retirement.AttributedOnchainSignature{ + { + Signer: 0, + Signature: []byte("bar0"), + }, + { + Signer: 1, + Signature: []byte("bar1"), + }, + { + Signer: 2, + Signature: []byte("bar2"), + }, + { + Signer: 3, + Signature: []byte("bar3"), + }, + }, + } + serializedValidArr, err := proto.Marshal(&validArr) + require.NoError(t, err) + + t.Run("CheckAttestedRetirementReport", func(t *testing.T) { + t.Run("invalid", func(t *testing.T) { + // config missing + _, err := psrrc.CheckAttestedRetirementReport(exampleDigest, []byte("not valid")) + require.EqualError(t, err, "Verify failed; predecessor config not found for config digest 0100000000000000000000000000000000000000000000000000000000000000") + + rrc.cfg = Config{Digest: exampleDigest} + rrc.exists = true + + // unmarshal failure + _, err = psrrc.CheckAttestedRetirementReport(exampleDigest, []byte("not valid")) + require.Error(t, err) + assert.Contains(t, err.Error(), "Verify failed; failed to unmarshal protobuf: proto") + + // config is invalid (no signers) + _, err = psrrc.CheckAttestedRetirementReport(exampleDigest, serializedValidArr) + require.EqualError(t, err, "Verify failed; attested report signer index out of bounds (got: 0, max: -1)") + + rrc.cfg = Config{Digest: exampleDigest, Signers: [][]byte{[]byte{0}, []byte{1}, []byte{2}, []byte{3}}, F: 1} + + // no valid sigs + v.verify = func(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool { + return false + } + _, err = psrrc.CheckAttestedRetirementReport(exampleDigest, serializedValidArr) + require.EqualError(t, err, "Verify failed; not enough valid signatures (got: 0, need: 2)") + + // not enough valid sigs + v.verify = func(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool { + return string(signature) == "bar0" + } + _, err = psrrc.CheckAttestedRetirementReport(exampleDigest, serializedValidArr) + require.EqualError(t, err, "Verify failed; not enough valid signatures (got: 1, need: 2)") + + // enough valid sigs, but codec decode fails + v.verify = func(key types.OnchainPublicKey, digest types.ConfigDigest, seqNr uint64, r ocr3types.ReportWithInfo[llotypes.ReportInfo], signature []byte) bool { + if string(signature) == "bar0" || string(signature) == "bar3" { + return true + } + return false + } + c.decode = func([]byte) (datastreamsllo.RetirementReport, error) { + return datastreamsllo.RetirementReport{}, errors.New("codec decode failed") + } + _, err = psrrc.CheckAttestedRetirementReport(exampleDigest, serializedValidArr) + require.EqualError(t, err, "Verify failed; failed to decode retirement report: codec decode failed") + + exampleRetirementReport := datastreamsllo.RetirementReport{ + ValidAfterNanoseconds: map[llotypes.ChannelID]uint64{ + 0: 1, + }, + } + + // enough valid sigs and codec decode succeeds + c.decode = func(b []byte) (datastreamsllo.RetirementReport, error) { + assert.Equal(t, exampleUnattestedSerializedRetirementReport, b) + return exampleRetirementReport, nil + } + decoded, err := psrrc.CheckAttestedRetirementReport(exampleDigest, serializedValidArr) + require.NoError(t, err) + assert.Equal(t, exampleRetirementReport, decoded) + }) + }) + t.Run("AttestedRetirementReport", func(t *testing.T) { + rrc.arr = []byte("foo") + rrc.exists = true + + // exists + arr, err := psrrc.AttestedRetirementReport(exampleDigest) + require.NoError(t, err) + assert.Equal(t, rrc.arr, arr) + + rrc.exists = false + + // doesn't exist + arr, err = psrrc.AttestedRetirementReport(exampleDigest2) + require.NoError(t, err) + assert.Nil(t, arr) + }) +} diff --git a/llo/retirement/retirement_report_cache.go b/llo/retirement/retirement_report_cache.go new file mode 100644 index 0000000..35aef45 --- /dev/null +++ b/llo/retirement/retirement_report_cache.go @@ -0,0 +1,153 @@ +package retirement + +import ( + "context" + "fmt" + sync "sync" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + + "github.com/smartcontractkit/chainlink-data-streams/llo/reportcodecs/retirement" +) + +// RetirementReportCacheReader is used by the plugin-scoped +// RetirementReportCache +type RetirementReportCacheReader interface { + AttestedRetirementReport(cd ocr2types.ConfigDigest) ([]byte, bool) + Config(cd ocr2types.ConfigDigest) (Config, bool) +} + +// RetirementReportCache is intended to be a global singleton that is wrapped +// by a PluginScopedRetirementReportCache for a given plugin +type RetirementReportCache interface { + services.Service + StoreAttestedRetirementReport(ctx context.Context, cd ocrtypes.ConfigDigest, seqNr uint64, retirementReport []byte, sigs []types.AttributedOnchainSignature) error + StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error + RetirementReportCacheReader +} + +type retirementReportCache struct { + services.Service + eng *services.Engine + + mu sync.RWMutex + arrs map[ocr2types.ConfigDigest][]byte + configs map[ocr2types.ConfigDigest]Config + + orm RetirementReportCacheORM +} + +func NewRetirementReportCache(lggr logger.Logger, ds sqlutil.DataSource) RetirementReportCache { + orm := &retirementReportCacheORM{ds: ds} + return newRetirementReportCache(lggr, orm) +} + +func newRetirementReportCache(lggr logger.Logger, orm RetirementReportCacheORM) *retirementReportCache { + r := &retirementReportCache{ + arrs: make(map[ocr2types.ConfigDigest][]byte), + configs: make(map[ocr2types.ConfigDigest]Config), + orm: orm, + } + r.Service, r.eng = services.Config{ + Name: "RetirementReportCache", + Start: r.start, + }.NewServiceEngine(lggr) + return r +} + +// NOTE: Could do this lazily instead if we wanted to avoid a performance hit +// or potential tables missing etc on application startup (since +// RetirementReportCache is global) +func (r *retirementReportCache) start(ctx context.Context) (err error) { + // Load all attested retirement reports from the ORM + // and store them in the cache + r.arrs, err = r.orm.LoadAttestedRetirementReports(ctx) + if err != nil { + return fmt.Errorf("failed to load attested retirement reports: %w", err) + } + configs, err := r.orm.LoadConfigs(ctx) + if err != nil { + return fmt.Errorf("failed to load configs: %w", err) + } + for _, c := range configs { + r.configs[c.Digest] = c + } + return nil +} + +func (r *retirementReportCache) StoreAttestedRetirementReport(ctx context.Context, cd ocr2types.ConfigDigest, seqNr uint64, retirementReport []byte, sigs []types.AttributedOnchainSignature) error { + r.mu.RLock() + if _, ok := r.arrs[cd]; ok { + r.mu.RUnlock() + return nil + } + r.mu.RUnlock() + + pbSigs := make([]*retirement.AttributedOnchainSignature, len(sigs)) + for i, s := range sigs { + pbSigs[i] = &retirement.AttributedOnchainSignature{ + Signer: uint32(s.Signer), + Signature: s.Signature, + } + } + attestedRetirementReport := retirement.AttestedRetirementReport{ + RetirementReport: retirementReport, + SeqNr: seqNr, + Sigs: pbSigs, + } + + serialized, err := proto.Marshal(&attestedRetirementReport) + if err != nil { + return fmt.Errorf("StoreAttestedRetirementReport failed; failed to marshal protobuf: %w", err) + } + + if err := r.orm.StoreAttestedRetirementReport(ctx, cd, serialized); err != nil { + return fmt.Errorf("StoreAttestedRetirementReport failed; failed to persist to ORM: %w", err) + } + + r.mu.Lock() + r.arrs[cd] = serialized + r.mu.Unlock() + + return nil +} + +func (r *retirementReportCache) StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error { + r.mu.RLock() + if _, ok := r.configs[cd]; ok { + r.mu.RUnlock() + return nil + } + r.mu.RUnlock() + + r.mu.Lock() + r.configs[cd] = Config{ + Digest: cd, + Signers: signers, + F: f, + } + r.mu.Unlock() + + return r.orm.StoreConfig(ctx, cd, signers, f) +} + +func (r *retirementReportCache) AttestedRetirementReport(predecessorConfigDigest ocr2types.ConfigDigest) ([]byte, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + arr, exists := r.arrs[predecessorConfigDigest] + return arr, exists +} + +func (r *retirementReportCache) Config(cd ocr2types.ConfigDigest) (Config, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + c, exists := r.configs[cd] + return c, exists +} diff --git a/llo/retirement/retirement_report_cache_test.go b/llo/retirement/retirement_report_cache_test.go new file mode 100644 index 0000000..2ec0099 --- /dev/null +++ b/llo/retirement/retirement_report_cache_test.go @@ -0,0 +1,168 @@ +package retirement + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +type mockORM struct { + storedAttestedRetirementReports map[ocr2types.ConfigDigest][]byte + storedConfigs map[ocr2types.ConfigDigest]Config + + err error +} + +func (m *mockORM) StoreAttestedRetirementReport(ctx context.Context, cd ocr2types.ConfigDigest, attestedRetirementReport []byte) error { + m.storedAttestedRetirementReports[cd] = attestedRetirementReport + return m.err +} +func (m *mockORM) LoadAttestedRetirementReports(ctx context.Context) (map[ocr2types.ConfigDigest][]byte, error) { + return m.storedAttestedRetirementReports, m.err +} +func (m *mockORM) StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error { + m.storedConfigs[cd] = Config{Signers: signers, F: f, Digest: cd} + return m.err +} +func (m *mockORM) LoadConfigs(ctx context.Context) ([]Config, error) { + configs := make([]Config, 0, len(m.storedConfigs)) + for _, config := range m.storedConfigs { + configs = append(configs, config) + } + return configs, m.err +} + +func Test_RetirementReportCache(t *testing.T) { + t.Parallel() + + ctx := t.Context() + lggr := logger.Test(t) + orm := &mockORM{ + make(map[ocrtypes.ConfigDigest][]byte), + make(map[ocrtypes.ConfigDigest]Config), + nil, + } + exampleRetirementReport := []byte{1, 2, 3} + exampleRetirementReport2 := []byte{4, 5, 6} + exampleSignatures := []ocrtypes.AttributedOnchainSignature{ + {Signature: []byte("signature0"), Signer: 0}, + {Signature: []byte("signature1"), Signer: 1}, + {Signature: []byte("signature2"), Signer: 2}, + {Signature: []byte("signature3"), Signer: 3}, + } + // this is a serialized protobuf of report with 4 signers + exampleAttestedRetirementReport := []byte{0xa, 0x3, 0x1, 0x2, 0x3, 0x10, 0x64, 0x1a, 0xc, 0xa, 0xa, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x30, 0x1a, 0xe, 0xa, 0xa, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x31, 0x10, 0x1, 0x1a, 0xe, 0xa, 0xa, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x32, 0x10, 0x2, 0x1a, 0xe, 0xa, 0xa, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x33, 0x10, 0x3} + exampleDigest := ocrtypes.ConfigDigest{1} + exampleDigest2 := ocrtypes.ConfigDigest{2} + + seqNr := uint64(100) + + t.Run("start loads from ORM", func(t *testing.T) { + rrc := newRetirementReportCache(lggr, orm) + + t.Run("orm failure, errors", func(t *testing.T) { + orm.err = errors.New("orm failed") + err := rrc.start(ctx) + assert.EqualError(t, err, "failed to load attested retirement reports: orm failed") + }) + t.Run("orm success, loads both configs and attestedRetirementReports from orm", func(t *testing.T) { + orm.err = nil + orm.storedAttestedRetirementReports = map[ocr2types.ConfigDigest][]byte{ + exampleDigest: exampleAttestedRetirementReport, + exampleDigest2: exampleAttestedRetirementReport, + } + config1 := Config{Digest: exampleDigest, Signers: [][]byte{{1}, {2}, {3}, {4}}, F: 1} + config2 := Config{Digest: exampleDigest2, Signers: [][]byte{{5}, {6}, {7}, {8}}, F: 2} + orm.storedConfigs[exampleDigest] = config1 + orm.storedConfigs[exampleDigest2] = config2 + + err := rrc.start(ctx) + assert.NoError(t, err) + + assert.Len(t, rrc.arrs, 2) + assert.Equal(t, exampleAttestedRetirementReport, rrc.arrs[exampleDigest]) + assert.Equal(t, exampleAttestedRetirementReport, rrc.arrs[exampleDigest2]) + + assert.Len(t, rrc.configs, 2) + assert.Equal(t, config1, rrc.configs[exampleDigest]) + assert.Equal(t, config2, rrc.configs[exampleDigest2]) + }) + }) + + t.Run("StoreAttestedRetirementReport", func(t *testing.T) { + rrc := newRetirementReportCache(lggr, orm) + + err := rrc.StoreAttestedRetirementReport(ctx, exampleDigest, seqNr, exampleRetirementReport, exampleSignatures) + assert.NoError(t, err) + + assert.Len(t, rrc.arrs, 1) + assert.Equal(t, exampleAttestedRetirementReport, rrc.arrs[exampleDigest]) + assert.Equal(t, exampleAttestedRetirementReport, orm.storedAttestedRetirementReports[exampleDigest]) + + t.Run("does nothing if retirement report already exists for the given config digest", func(t *testing.T) { + err = rrc.StoreAttestedRetirementReport(ctx, exampleDigest, seqNr, exampleRetirementReport2, exampleSignatures) + assert.NoError(t, err) + assert.Len(t, rrc.arrs, 1) + assert.Equal(t, exampleAttestedRetirementReport, rrc.arrs[exampleDigest]) + }) + + t.Run("returns error if ORM store fails", func(t *testing.T) { + orm.err = errors.New("failed to store") + err = rrc.StoreAttestedRetirementReport(ctx, exampleDigest2, seqNr, exampleRetirementReport, exampleSignatures) + assert.Error(t, err) + + // it wasn't cached + assert.Len(t, rrc.arrs, 1) + }) + + t.Run("second retirement report succeeds when orm starts working again", func(t *testing.T) { + orm.err = nil + err := rrc.StoreAttestedRetirementReport(ctx, exampleDigest2, seqNr, exampleRetirementReport, exampleSignatures) + assert.NoError(t, err) + + assert.Len(t, rrc.arrs, 2) + assert.Equal(t, exampleAttestedRetirementReport, rrc.arrs[exampleDigest2]) + assert.Equal(t, exampleAttestedRetirementReport, orm.storedAttestedRetirementReports[exampleDigest2]) + + assert.Len(t, orm.storedAttestedRetirementReports, 2) + }) + }) + t.Run("AttestedRetirementReport", func(t *testing.T) { + rrc := newRetirementReportCache(lggr, orm) + + attestedRetirementReport, exists := rrc.AttestedRetirementReport(exampleDigest) + assert.False(t, exists) + assert.Nil(t, attestedRetirementReport) + + rrc.arrs[exampleDigest] = exampleAttestedRetirementReport + + attestedRetirementReport, exists = rrc.AttestedRetirementReport(exampleDigest) + assert.True(t, exists) + assert.Equal(t, exampleAttestedRetirementReport, attestedRetirementReport) + }) + t.Run("StoreConfig", func(t *testing.T) { + rrc := newRetirementReportCache(lggr, orm) + + signers := [][]byte{{1}, {2}, {3}, {4}} + + err := rrc.StoreConfig(ctx, exampleDigest, signers, 1) + assert.NoError(t, err) + + assert.Len(t, rrc.configs, 1) + assert.Equal(t, Config{Digest: exampleDigest, Signers: [][]byte{{1}, {2}, {3}, {4}}, F: 1}, rrc.configs[exampleDigest]) + assert.Equal(t, Config{Digest: exampleDigest, Signers: [][]byte{{1}, {2}, {3}, {4}}, F: 1}, orm.storedConfigs[exampleDigest]) + + t.Run("Config", func(t *testing.T) { + config, exists := rrc.Config(exampleDigest) + assert.True(t, exists) + assert.Equal(t, Config{Digest: exampleDigest, Signers: [][]byte{{1}, {2}, {3}, {4}}, F: 1}, config) + }) + }) +} diff --git a/llo/retirement/retirement_report_orm.go b/llo/retirement/retirement_report_orm.go new file mode 100644 index 0000000..cccffe5 --- /dev/null +++ b/llo/retirement/retirement_report_orm.go @@ -0,0 +1,114 @@ +package retirement + +import ( + "context" + "errors" + "fmt" + + "github.com/lib/pq" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" +) + +type RetirementReportCacheORM interface { + StoreAttestedRetirementReport(ctx context.Context, cd ocr2types.ConfigDigest, attestedRetirementReport []byte) error + LoadAttestedRetirementReports(ctx context.Context) (map[ocr2types.ConfigDigest][]byte, error) + StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error + LoadConfigs(ctx context.Context) ([]Config, error) +} + +type retirementReportCacheORM struct { + ds sqlutil.DataSource +} + +func (o *retirementReportCacheORM) StoreAttestedRetirementReport(ctx context.Context, cd ocr2types.ConfigDigest, attestedRetirementReport []byte) error { + _, err := o.ds.ExecContext(ctx, ` +INSERT INTO llo_retirement_report_cache (config_digest, attested_retirement_report, updated_at) +VALUES ($1, $2, NOW()) +ON CONFLICT (config_digest) DO NOTHING +`, cd, attestedRetirementReport) + if err != nil { + return fmt.Errorf("StoreAttestedRetirementReport failed: %w", err) + } + return nil +} + +func (o *retirementReportCacheORM) LoadAttestedRetirementReports(ctx context.Context) (map[ocr2types.ConfigDigest][]byte, error) { + rows, err := o.ds.QueryContext(ctx, "SELECT config_digest, attested_retirement_report FROM llo_retirement_report_cache") + if err != nil { + return nil, fmt.Errorf("LoadAttestedRetirementReports failed: %w", err) + } + defer rows.Close() + + reports := make(map[ocr2types.ConfigDigest][]byte) + for rows.Next() { + var rawCd []byte + var arr []byte + if err := rows.Scan(&rawCd, &arr); err != nil { + return nil, fmt.Errorf("LoadAttestedRetirementReports failed: %w", err) + } + cd, err := ocr2types.BytesToConfigDigest(rawCd) + if err != nil { + return nil, fmt.Errorf("LoadAttestedRetirementReports failed to scan config digest: %w", err) + } + reports[cd] = arr + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("LoadAttestedRetirementReports failed: %w", err) + } + + return reports, nil +} + +func (o *retirementReportCacheORM) StoreConfig(ctx context.Context, cd ocr2types.ConfigDigest, signers [][]byte, f uint8) error { + // do nothing on overwrite since configs are supposedly immutable + _, err := o.ds.ExecContext(ctx, `INSERT INTO llo_retirement_report_cache_configs (config_digest, signers, f, updated_at) VALUES ($1, $2, $3, NOW()) ON CONFLICT (config_digest) DO NOTHING`, cd, signers, f) + if err != nil { + return fmt.Errorf("StoreConfig failed: %w", err) + } + return nil +} + +type Config struct { + Digest [32]byte `db:"config_digest"` + Signers pq.ByteaArray `db:"signers"` + F uint8 `db:"f"` +} + +type scannableConfigDigest [32]byte + +func (s *scannableConfigDigest) Scan(src any) error { + b, ok := src.([]byte) + if !ok { + return errors.New("type assertion to []byte failed") + } + + cd, err := ocr2types.BytesToConfigDigest(b) + if err != nil { + return err + } + copy(s[:], cd[:]) + return nil +} + +func (o *retirementReportCacheORM) LoadConfigs(ctx context.Context) (configs []Config, err error) { + type config struct { + Digest scannableConfigDigest `db:"config_digest"` + Signers pq.ByteaArray `db:"signers"` + F uint8 `db:"f"` + } + var rawCfgs []config + err = o.ds.SelectContext(ctx, &rawCfgs, `SELECT config_digest, signers, f FROM llo_retirement_report_cache_configs ORDER BY config_digest`) + if err != nil { + return nil, fmt.Errorf("LoadConfigs failed: %w", err) + } + for _, rawCfg := range rawCfgs { + var cfg Config + copy(cfg.Digest[:], rawCfg.Digest[:]) + cfg.Signers = rawCfg.Signers + cfg.F = rawCfg.F + configs = append(configs, cfg) + } + return +} diff --git a/llo/retirement/retirement_report_orm_test.go b/llo/retirement/retirement_report_orm_test.go new file mode 100644 index 0000000..3210b62 --- /dev/null +++ b/llo/retirement/retirement_report_orm_test.go @@ -0,0 +1,66 @@ +package retirement + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-data-streams/mercury/testutils" +) + +func Test_RetirementReportCache_ORM(t *testing.T) { + db := testutils.NewSqlxDB(t) + orm := &retirementReportCacheORM{db} + ctx := t.Context() + + cd := ocr2types.ConfigDigest{1} + attestedRetirementReport := []byte("report1") + cd2 := ocr2types.ConfigDigest{2} + attestedRetirementReport2 := []byte("report2") + + t.Run("StoreAttestedRetirementReport", func(t *testing.T) { + err := orm.StoreAttestedRetirementReport(ctx, cd, attestedRetirementReport) + require.NoError(t, err) + err = orm.StoreAttestedRetirementReport(ctx, cd2, attestedRetirementReport2) + require.NoError(t, err) + }) + t.Run("LoadAttestedRetirementReports", func(t *testing.T) { + arrs, err := orm.LoadAttestedRetirementReports(ctx) + require.NoError(t, err) + + require.Len(t, arrs, 2) + assert.Equal(t, attestedRetirementReport, arrs[cd]) + assert.Equal(t, attestedRetirementReport2, arrs[cd2]) + }) + t.Run("StoreConfig", func(t *testing.T) { + signers := [][]byte{[]byte("signer1"), []byte("signer2")} + err := orm.StoreConfig(ctx, cd, signers, 1) + require.NoError(t, err) + + err = orm.StoreConfig(ctx, cd2, signers, 2) + require.NoError(t, err) + + // overwriting does nothing, the 255 is ignored + err = orm.StoreConfig(ctx, cd2, signers, 255) + require.NoError(t, err) + }) + t.Run("LoadConfigs", func(t *testing.T) { + configs, err := orm.LoadConfigs(ctx) + require.NoError(t, err) + + require.Len(t, configs, 2) + assert.Equal(t, Config{ + Digest: cd, + Signers: [][]byte{[]byte("signer1"), []byte("signer2")}, + F: 1, + }, configs[0]) + assert.Equal(t, Config{ + Digest: cd2, + Signers: [][]byte{[]byte("signer1"), []byte("signer2")}, + F: 2, + }, configs[1]) + }) +} diff --git a/llo/transmitter/bm/dummy_transmitter.go b/llo/transmitter/bm/dummy_transmitter.go new file mode 100644 index 0000000..d7666ac --- /dev/null +++ b/llo/transmitter/bm/dummy_transmitter.go @@ -0,0 +1,120 @@ +package bm + +import ( + "context" + "fmt" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-data-streams/llo/reportcodecs/evm" +) + +// A dummy transmitter useful for benchmarking and testing + +var ( + promTransmitSuccessCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "dummytransmitter", + Name: "transmit_success_count", + Help: "Running count of successful transmits", + }) +) + +type Transmitter interface { + llotypes.Transmitter + services.Service +} + +type transmitter struct { + lggr logger.Logger + fromAccount string +} + +func NewTransmitter(lggr logger.Logger, fromAccount string) Transmitter { + return &transmitter{ + logger.Named(lggr, "DummyTransmitter"), + fromAccount, + } +} + +func (t *transmitter) Start(context.Context) error { + return nil +} + +func (t *transmitter) Close() error { + return nil +} + +func (t *transmitter) Transmit( + ctx context.Context, + digest ocr2types.ConfigDigest, + seqNr uint64, + report ocr3types.ReportWithInfo[llotypes.ReportInfo], + sigs []ocr2types.AttributedOnchainSignature, +) error { + lggr := t.lggr + { + switch report.Info.ReportFormat { + case llotypes.ReportFormatJSON: + r, err := (llo.JSONReportCodec{}).Decode(report.Report) + if err != nil { + lggr.Debugw(fmt.Sprintf("Failed to decode report with type %s", report.Info.ReportFormat), "err", err) + } else if r.SeqNr > 0 { + lggr = logger.With(lggr, + "report.Report.ConfigDigest", r.ConfigDigest, + "report.Report.SeqNr", r.SeqNr, + "report.Report.ChannelID", r.ChannelID, + "report.Report.ValidAfterNanoseconds", r.ValidAfterNanoseconds, + "report.Report.ObservationTimestampNanoseconds", r.ObservationTimestampNanoseconds, + "report.Report.Values", r.Values, + "report.Report.Specimen", r.Specimen, + ) + } + case llotypes.ReportFormatEVMPremiumLegacy: + r, err := (evm.ReportCodecPremiumLegacy{}).Decode(report.Report) + if err != nil { + lggr.Debugw(fmt.Sprintf("Failed to decode report with type %s", report.Info.ReportFormat), "err", err) + } else if r.ObservationsTimestamp > 0 { + lggr = logger.With(lggr, + "report.Report.FeedId", r.FeedId, + "report.Report.ObservationsTimestamp", r.ObservationsTimestamp, + "report.Report.BenchmarkPrice", r.BenchmarkPrice, + "report.Report.Bid", r.Bid, + "report.Report.Ask", r.Ask, + "report.Report.ValidFromTimestamp", r.ValidFromTimestamp, + "report.Report.ExpiresAt", r.ExpiresAt, + "report.Report.LinkFee", r.LinkFee, + "report.Report.NativeFee", r.NativeFee, + ) + } + default: + err := fmt.Errorf("unhandled report format: %s", report.Info.ReportFormat) + lggr.Debugw(fmt.Sprintf("Failed to decode report with type %s", report.Info.ReportFormat), "err", err) + } + } + promTransmitSuccessCount.Inc() + lggr.Infow("Transmit (dummy)", "digest", digest, "seqNr", seqNr, "report.Report", report.Report, "report.Info", report.Info, "sigs", sigs) + return nil +} + +// FromAccount returns the stringified (hex) CSA public key +func (t *transmitter) FromAccount(context.Context) (ocr2types.Account, error) { + return ocr2types.Account(t.fromAccount), nil +} + +func (t *transmitter) Ready() error { return nil } + +func (t *transmitter) HealthReport() map[string]error { + report := map[string]error{t.Name(): nil} + return report +} + +func (t *transmitter) Name() string { return t.lggr.Name() } diff --git a/llo/transmitter/bm/dummy_transmitter_test.go b/llo/transmitter/bm/dummy_transmitter_test.go new file mode 100644 index 0000000..349c346 --- /dev/null +++ b/llo/transmitter/bm/dummy_transmitter_test.go @@ -0,0 +1,34 @@ +package bm + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" +) + +func Test_DummyTransmitter(t *testing.T) { + lggr, observedLogs := logger.TestObservedSugared(t, zapcore.DebugLevel) + tr := NewTransmitter(lggr, "dummy") + + servicetest.Run(t, tr) + + err := tr.Transmit( + t.Context(), + types.ConfigDigest{}, + 42, + ocr3types.ReportWithInfo[llotypes.ReportInfo]{}, + []types.AttributedOnchainSignature{}, + ) + require.NoError(t, err) + + tests.RequireLogMessage(t, observedLogs, "Transmit") +} diff --git a/llo/transmitter/de/helpers_test.go b/llo/transmitter/de/helpers_test.go new file mode 100644 index 0000000..95844a8 --- /dev/null +++ b/llo/transmitter/de/helpers_test.go @@ -0,0 +1,61 @@ +package de + +import ( + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" +) + +func makeSampleReport() ocr3types.ReportWithInfo[llotypes.ReportInfo] { + return ocr3types.ReportWithInfo[llotypes.ReportInfo]{ + Report: ocrtypes.Report{1, 2, 3}, + Info: llotypes.ReportInfo{ + LifeCycleStage: llotypes.LifeCycleStage("production"), + ReportFormat: llotypes.ReportFormatEVMPremiumLegacy, + }, + } +} + +func makeSampleConfigDigest() ocrtypes.ConfigDigest { + return ocrtypes.ConfigDigest{1, 2, 3, 4, 5, 6} +} +func makeValidTransmission() *Transmission { + return &Transmission{ + ServerURL: "wss://example.com/mercury", + ConfigDigest: types.ConfigDigest{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c}, + SeqNr: 3, + Report: ocr3types.ReportWithInfo[llotypes.ReportInfo]{ + Report: ocrtypes.Report{0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0}, + Info: llotypes.ReportInfo{ + LifeCycleStage: llotypes.LifeCycleStage("production"), + ReportFormat: llotypes.ReportFormatEVMPremiumLegacy, + }, + }, + Sigs: []types.AttributedOnchainSignature{types.AttributedOnchainSignature{Signature: []uint8{0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x1}, Signer: 0x1}, types.AttributedOnchainSignature{Signature: []uint8{0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0, 0x0}, Signer: 0x3}}, + } +} +func makeSampleTransmission(seqNr uint64, sURL string, report ocrtypes.Report) *Transmission { + return &Transmission{ + ServerURL: sURL, + ConfigDigest: types.ConfigDigest{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c}, + SeqNr: seqNr, + Report: ocr3types.ReportWithInfo[llotypes.ReportInfo]{ + Report: report, + Info: llotypes.ReportInfo{ + LifeCycleStage: llotypes.LifeCycleStage("production"), + ReportFormat: llotypes.ReportFormatEVMPremiumLegacy, + }, + }, + Sigs: []types.AttributedOnchainSignature{types.AttributedOnchainSignature{Signature: []uint8{0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x1}, Signer: 0x1}, types.AttributedOnchainSignature{Signature: []uint8{0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0, 0x0}, Signer: 0x3}}, + } +} + +func makeSampleTransmissions(n int, sURL string) []*Transmission { + transmissions := make([]*Transmission, n) + for i := range n { + transmissions[i] = makeSampleTransmission(uint64(i), sURL, ocrtypes.Report{0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0}) + } + return transmissions +} diff --git a/llo/transmitter/de/orm.go b/llo/transmitter/de/orm.go new file mode 100644 index 0000000..c51d554 --- /dev/null +++ b/llo/transmitter/de/orm.go @@ -0,0 +1,263 @@ +package de + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + "time" + + "github.com/lib/pq" + + "github.com/smartcontractkit/libocr/commontypes" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" +) + +// ORM is scoped to a single DON ID +type ORM interface { + DonID() uint32 + Insert(ctx context.Context, transmissions []*Transmission) error + Delete(ctx context.Context, hashes [][32]byte) error + Get(ctx context.Context, serverURL string, limit int, maxAge time.Duration) ([]*Transmission, error) + Prune(ctx context.Context, serverURL string, maxSize, batchSize int) (int64, error) +} + +type orm struct { + ds sqlutil.DataSource + donID uint32 +} + +func NewORM(ds sqlutil.DataSource, donID uint32) ORM { + return &orm{ds, donID} +} + +func (o *orm) DonID() uint32 { + return o.donID +} + +// Insert inserts the transmissions, ignoring duplicates +func (o *orm) Insert(ctx context.Context, transmissions []*Transmission) error { + if len(transmissions) == 0 { + return nil + } + + type transmission struct { + DonID uint32 `db:"don_id"` + ServerURL string `db:"server_url"` + ConfigDigest ocrtypes.ConfigDigest `db:"config_digest"` + SeqNr int64 `db:"seq_nr"` + Report []byte `db:"report"` + LifecycleStage string `db:"lifecycle_stage"` + ReportFormat uint32 `db:"report_format"` + Signatures [][]byte `db:"signatures"` + Signers []uint8 `db:"signers"` + TransmissionHash []byte `db:"transmission_hash"` + } + records := make([]transmission, len(transmissions)) + for i, t := range transmissions { + signatures := make([][]byte, len(t.Sigs)) + signers := make([]uint8, len(t.Sigs)) + for j, sig := range t.Sigs { + signatures[j] = sig.Signature + signers[j] = uint8(sig.Signer) + } + h := t.Hash() + if t.SeqNr > math.MaxInt64 { + // this is to appease the linter but shouldn't ever happen + return fmt.Errorf("seqNr is too large (got: %d, max: %d)", t.SeqNr, math.MaxInt64) + } + records[i] = transmission{ + DonID: o.donID, + ServerURL: t.ServerURL, + ConfigDigest: t.ConfigDigest, + SeqNr: int64(t.SeqNr), + Report: t.Report.Report, + LifecycleStage: string(t.Report.Info.LifeCycleStage), + ReportFormat: uint32(t.Report.Info.ReportFormat), + Signatures: signatures, + Signers: signers, + TransmissionHash: h[:], + } + } + + _, err := o.ds.NamedExecContext(ctx, ` + INSERT INTO llo_mercury_transmit_queue (don_id, server_url, config_digest, seq_nr, report, lifecycle_stage, report_format, signatures, signers, transmission_hash) + VALUES (:don_id, :server_url, :config_digest, :seq_nr, :report, :lifecycle_stage, :report_format, :signatures, :signers, :transmission_hash) + ON CONFLICT (transmission_hash) DO NOTHING + `, records) + + if err != nil { + return fmt.Errorf("llo orm: failed to insert transmissions: %w", err) + } + return nil +} + +// Delete deletes the given transmissions +func (o *orm) Delete(ctx context.Context, hashes [][32]byte) error { + if len(hashes) == 0 { + return nil + } + + var pqHashes pq.ByteaArray + for _, hash := range hashes { + pqHashes = append(pqHashes, hash[:]) + } + + _, err := o.ds.ExecContext(ctx, ` + DELETE FROM llo_mercury_transmit_queue + WHERE transmission_hash = ANY($1) + `, pqHashes) + if err != nil { + return fmt.Errorf("llo orm: failed to delete transmissions: %w", err) + } + return nil +} + +// Get returns all transmissions in chronologically descending order +// NOTE: passing maxAge=0 disables any age filter +func (o *orm) Get(ctx context.Context, serverURL string, limit int, maxAge time.Duration) ([]*Transmission, error) { + // The priority queue uses seqnr to sort transmissions so order by + // the same fields here for optimal insertion into the pq. + maxAgeClause := "" + params := []any{o.donID, serverURL, limit} + if maxAge > 0 { + maxAgeClause = "\nAND inserted_at >= NOW() - ($4 * INTERVAL '1 MICROSECOND')" + params = append(params, maxAge.Microseconds()) + } + q := fmt.Sprintf(` + SELECT config_digest, seq_nr, report, lifecycle_stage, report_format, signatures, signers + FROM llo_mercury_transmit_queue + WHERE don_id = $1 AND server_url = $2%s + ORDER BY seq_nr DESC, inserted_at DESC + LIMIT $3 + `, maxAgeClause) + rows, err := o.ds.QueryContext(ctx, q, params...) + if err != nil { + return nil, fmt.Errorf("llo orm: failed to get transmissions: %w", err) + } + defer rows.Close() + + var transmissions []*Transmission + for rows.Next() { + transmission := Transmission{ + ServerURL: serverURL, + } + var digest []byte + var signatures pq.ByteaArray + var signers pq.Int32Array + + err := rows.Scan( + &digest, + &transmission.SeqNr, + &transmission.Report.Report, + &transmission.Report.Info.LifeCycleStage, + &transmission.Report.Info.ReportFormat, + &signatures, + &signers, + ) + if err != nil { + return nil, fmt.Errorf("llo orm: failed to scan transmission: %w", err) + } + transmission.ConfigDigest = ocrtypes.ConfigDigest(digest) + if len(signatures) != len(signers) { + return nil, errors.New("signatures and signers must have the same length") + } + for i, sig := range signatures { + if signers[i] > math.MaxUint8 { + // this is to appease the linter but shouldn't ever happen + return nil, fmt.Errorf("signer is too large (got: %d, max: %d)", signers[i], math.MaxUint8) + } + transmission.Sigs = append(transmission.Sigs, ocrtypes.AttributedOnchainSignature{ + Signature: sig, + Signer: commontypes.OracleID(signers[i]), //nolint:gosec // G115 false positive + }) + } + + transmissions = append(transmissions, &transmission) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("llo orm: failed to scan transmissions: %w", err) + } + + return transmissions, nil +} + +// Prune keeps at most maxSize rows for the given (donID, serverURL) pair by +// deleting the oldest transmissions. +func (o *orm) Prune(ctx context.Context, serverURL string, maxSize, batchSize int) (rowsDeleted int64, err error) { + var oldest uint64 + err = o.ds.GetContext(ctx, &oldest, `SELECT seq_nr + FROM llo_mercury_transmit_queue + WHERE don_id = $1 AND server_url = $2 + ORDER BY seq_nr DESC + OFFSET $3 + LIMIT 1`, o.donID, serverURL, maxSize) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("llo orm: failed to get oldest seq_nr: %w", err) + } + // Prune the requests with seq_nr older than this, in batches to avoid long + // locking or queries + for { + var res sql.Result + res, err = o.ds.ExecContext(ctx, ` +DELETE FROM llo_mercury_transmit_queue AS q +USING ( + SELECT transmission_hash + FROM llo_mercury_transmit_queue + WHERE don_id = $1 + AND server_url = $2 + AND seq_nr < $3 + ORDER BY seq_nr ASC + LIMIT $4 +) AS to_delete +WHERE q.transmission_hash = to_delete.transmission_hash; + `, o.donID, serverURL, oldest, batchSize) + if err != nil { + return rowsDeleted, fmt.Errorf("llo orm: batch delete failed to prune transmissions: %w", err) + } + var rowsAffected int64 + rowsAffected, err = res.RowsAffected() + if err != nil { + return rowsDeleted, fmt.Errorf("llo orm: batch delete failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + break + } + rowsDeleted += rowsAffected + } + + // This query to trim off the final few rows to reach exactly maxSize with + // should now be fast and efficient because of the batch deletes that + // already completed above. + res, err := o.ds.ExecContext(ctx, ` +WITH to_delete AS ( + SELECT ctid + FROM ( + SELECT ctid, + ROW_NUMBER() OVER (PARTITION BY don_id, server_url ORDER BY seq_nr DESC, inserted_at DESC) AS row_num + FROM llo_mercury_transmit_queue + WHERE don_id = $1 AND server_url = $2 + ) sub + WHERE row_num > $3 +) +DELETE FROM llo_mercury_transmit_queue +WHERE ctid IN (SELECT ctid FROM to_delete); +`, o.donID, serverURL, maxSize) + + if err != nil { + return rowsDeleted, fmt.Errorf("llo orm: final truncate failed to prune transmissions: %w", err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return rowsDeleted, fmt.Errorf("llo orm: final truncate failed to get rows affected: %w", err) + } + rowsDeleted += rowsAffected + + return rowsDeleted, nil +} diff --git a/llo/transmitter/de/orm_test.go b/llo/transmitter/de/orm_test.go new file mode 100644 index 0000000..3bfe167 --- /dev/null +++ b/llo/transmitter/de/orm_test.go @@ -0,0 +1,147 @@ +package de + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-data-streams/mercury/testutils" +) + +var ( + sURL = "wss://example.com/mercury" + sURL2 = "wss://mercuryserver.test" + sURL3 = "wss://mercuryserver.example/foo" +) + +func TestORM(t *testing.T) { + ctx := t.Context() + db := testutils.NewSqlxDB(t) + + t.Run("Insert, Get, Delete, Prune", func(t *testing.T) { + donID := uint32(654321) + orm := NewORM(db, donID) + + t.Run("DonID", func(t *testing.T) { + assert.Equal(t, donID, orm.DonID()) + }) + const n = 10 + transmissions := makeSampleTransmissions(n, sURL) + // Insert + err := orm.Insert(ctx, transmissions) + require.NoError(t, err) + // Get limits + result, err := orm.Get(ctx, sURL, 0, 0) + require.NoError(t, err) + assert.Empty(t, result) + + // Get limits + result, err = orm.Get(ctx, sURL, 1, 0) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, transmissions[len(transmissions)-1], result[0]) + + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + + assert.ElementsMatch(t, transmissions, result) + + result, err = orm.Get(ctx, "other server url", 100, 0) + require.NoError(t, err) + assert.Empty(t, result) + // Delete + err = orm.Delete(ctx, [][32]byte{transmissions[0].Hash()}) + require.NoError(t, err) + + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + + require.Len(t, result, n-1) + assert.NotContains(t, result, transmissions[0]) + assert.Contains(t, result, transmissions[1]) + + err = orm.Delete(ctx, [][32]byte{transmissions[1].Hash()}) + require.NoError(t, err) + + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Len(t, result, n-2) + // Prune + // ensure that len(transmissions) exceeds batch size to test batching + err = orm.Insert(ctx, transmissions) + require.NoError(t, err) + + d, err := orm.Prune(ctx, sURL, 1, n/3) + require.NoError(t, err) + assert.Equal(t, int64(n-1), d) + + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, transmissions[len(transmissions)-1], result[0]) + + // Prune again, should not delete anything + d, err = orm.Prune(ctx, sURL, 1, n/3) + require.NoError(t, err) + assert.Zero(t, d) + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Len(t, result, 1) + assert.Equal(t, transmissions[len(transmissions)-1], result[0]) + + // Pruning with max allowed records = 0 deletes everything + d, err = orm.Prune(ctx, sURL, 0, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), d) + result, err = orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Empty(t, result) + }) + + t.Run("Prune trims to exactly the correct number of records", func(t *testing.T) { + donID := uint32(100) + orm := NewORM(db, donID) + + var transmissions []*Transmission + // create 100 records (10 * 10 duplicate sequence numbers) + for seqNr := range uint64(10) { + for i := range 10 { + transmissions = append(transmissions, makeSampleTransmission(seqNr, sURL, []byte{byte(i)})) + } + } + err := orm.Insert(ctx, transmissions) + require.NoError(t, err) + + d, err := orm.Prune(ctx, sURL, 43, 3) + require.NoError(t, err) + assert.Equal(t, int64(57), d) + + result, err := orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Len(t, result, 43) + assert.Equal(t, uint64(9), result[0].SeqNr) + }) + + t.Run("Get respects maxAge argument and does not retrieve records older than this", func(t *testing.T) { + donID := uint32(101) + orm := NewORM(db, donID) + + transmissions := makeSampleTransmissions(10, sURL) + err := orm.Insert(ctx, transmissions) + require.NoError(t, err) + + testutils.MustExec(t, db, `UPDATE llo_mercury_transmit_queue SET inserted_at = NOW() - INTERVAL '1 year' WHERE seq_nr < 5`) + + // Get with maxAge = 0 should return all records + result, err := orm.Get(ctx, sURL, 100, 0) + require.NoError(t, err) + require.Len(t, result, 10) + + // Get with maxAge = 1 month should return only the records with seq_nr >= 5 + result, err = orm.Get(ctx, sURL, 100, 30*24*time.Hour) + require.NoError(t, err) + require.Len(t, result, 5) + }) +} diff --git a/llo/transmitter/de/persistence_manager.go b/llo/transmitter/de/persistence_manager.go new file mode 100644 index 0000000..110c2d4 --- /dev/null +++ b/llo/transmitter/de/persistence_manager.go @@ -0,0 +1,256 @@ +package de + +import ( + "context" + "strconv" + "sync" + "time" + + "github.com/jpillora/backoff" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +const ( + // DeleteQueueMaxSize is a sanity limit to avoid unbounded memory consumption. + // We should never get anywhere close to this under normal operation. + DeleteQueueMaxSize = 1_000_000 + // DeleteBatchSize is the max number of transmission records to delete + // in one query. Setting this larger may reduce overall total transaction + // load on the DB at the expense of blocking inserts for longer. + DeleteBatchSize = 1_000 + // FlushDeletesFrequency controls how often we wake up to check if there + // are records in the delete queue, and if so, attempt to drain the queue + // and delete them all. + FlushDeletesFrequency = 15 * time.Second + + // PruneFrequency controls how often we wake up to check to see if the + // transmissions table has exceeded its allowed size, and if so, truncate + // it. This should already be automatically handled by the transmission + // queue calling AsyncDelete, but it's here anyway for safety. + PruneFrequency = 1 * time.Hour + // PruneBatchSize is the max number of transmission records to delete in + // one query when pruning the table. + PruneBatchSize = 10_000 + + // OvertimeDeleteTimeout is the maximum time we will spend trying to delete + // queued transmissions after exit signal before giving up and logging an + // error. + OvertimeDeleteTimeout = 2 * time.Second +) + +var ( + promTransmitQueueDeleteErrorCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_queue_delete_error_count", + Help: "Running count of DB errors when trying to delete an item from the queue DB", + }, + []string{"donID", "serverURL"}, + ) +) + +// persistenceManager scopes an ORM to a single serverURL and handles cleanup +// and asynchronous deletion +type persistenceManager struct { + lggr logger.Logger + orm ORM + serverURL string + donID uint32 + + once services.StateMachine + stopCh services.StopChan + wg sync.WaitGroup + + deleteMu sync.Mutex + deleteQueue [][32]byte + + maxTransmitQueueSize int + flushDeletesFrequency time.Duration + pruneFrequency time.Duration + maxAge time.Duration + + transmitQueueDeleteErrorCount prometheus.Counter +} + +func NewPersistenceManager(lggr logger.Logger, orm ORM, serverURL string, maxTransmitQueueSize int, flushDeletesFrequency, pruneFrequency, maxAge time.Duration) *persistenceManager { + return &persistenceManager{ + logger.Sugared(lggr).Named("LLOPersistenceManager"), + orm, + serverURL, + orm.DonID(), + services.StateMachine{}, + make(services.StopChan), + sync.WaitGroup{}, + sync.Mutex{}, + nil, + maxTransmitQueueSize, + flushDeletesFrequency, + pruneFrequency, + maxAge, + promTransmitQueueDeleteErrorCount.WithLabelValues(strconv.Itoa(int(orm.DonID())), serverURL), + } +} + +func (pm *persistenceManager) Start(ctx context.Context) error { + return pm.once.StartOnce("LLOMercuryPersistenceManager", func() error { + pm.wg.Add(2) + go pm.runFlushDeletesLoop() + go pm.runPruneLoop() + return nil + }) +} + +func (pm *persistenceManager) Close() error { + return pm.once.StopOnce("LLOMercuryPersistenceManager", func() error { + close(pm.stopCh) + pm.wg.Wait() + return nil + }) +} + +func (pm *persistenceManager) DonID() uint32 { + return pm.orm.DonID() +} + +func (pm *persistenceManager) AsyncDelete(hash [32]byte) { + pm.addToDeleteQueue(hash) +} + +func (pm *persistenceManager) Load(ctx context.Context) ([]*Transmission, error) { + return pm.orm.Get(ctx, pm.serverURL, pm.maxTransmitQueueSize, pm.maxAge) +} + +func (pm *persistenceManager) runFlushDeletesLoop() { + defer pm.wg.Done() + + ctx, cancel := pm.stopCh.NewCtx() + defer cancel() + + ticker := services.TickerConfig{ + // Don't prune right away, wait some time for the application to settle + // down first + Initial: services.DefaultJitter.Apply(pm.flushDeletesFrequency), + JitterPct: services.DefaultJitter, + }.NewTicker(pm.flushDeletesFrequency) + defer ticker.Stop() + for { + select { + case <-pm.stopCh: + q := pm.resetDeleteQueue() + if len(q) > 0 { + // make a final effort to clear the database that goes into + // overtime + overtimeCtx, cancel := context.WithTimeout(context.Background(), OvertimeDeleteTimeout) + pm.deleteTransmissions(overtimeCtx, q, DeleteBatchSize) + cancel() + if n := pm.lenDeleteQueue(); n > 0 { + pm.lggr.Errorw("Exiting with undeleted transmissions", "n", n) + } + } + return + case <-ticker.C: + queuedTransmissionHashes := pm.resetDeleteQueue() + if len(queuedTransmissionHashes) == 0 { + continue + } + pm.deleteTransmissions(ctx, queuedTransmissionHashes, DeleteBatchSize) + } + } +} + +// deleteTransmissions blocks until transmissions are deleted or context is canceled +// it auto-retries on errors +func (pm *persistenceManager) deleteTransmissions(ctx context.Context, hashes [][32]byte, batchSize int) { + // Exponential backoff for very rarely occurring errors (DB disconnect etc) + b := backoff.Backoff{ + Min: 10 * time.Millisecond, + Max: 1 * time.Second, + Factor: 2, + Jitter: true, + } + + for i := 0; i < len(hashes); i += batchSize { // batch deletes to avoid large transactions + end := min(i+batchSize, len(hashes)) + deleteBatch := hashes[i:end] + for { + if err := pm.orm.Delete(ctx, deleteBatch); err != nil { + pm.lggr.Errorw("Failed to delete queued transmit requests", "err", err) + pm.transmitQueueDeleteErrorCount.Inc() + select { + case <-time.After(b.Duration()): + // Wait a backoff duration before trying to delete again + continue + case <-ctx.Done(): + // put undeleted items back on the queue and exit + pm.addToDeleteQueue(hashes[i:]...) + return + } + } + break + } + } + pm.lggr.Debugw("Flushed delete queue", "nDeleted", len(hashes)) +} + +func (pm *persistenceManager) runPruneLoop() { + defer pm.wg.Done() + + ctx, cancel := pm.stopCh.NewCtx() + defer cancel() + + ticker := services.NewTicker(pm.pruneFrequency) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + overtimeCtx, cancel := context.WithTimeout(context.Background(), OvertimeDeleteTimeout) + n, err := pm.orm.Prune(overtimeCtx, pm.serverURL, pm.maxTransmitQueueSize, PruneBatchSize) + cancel() + if err != nil { + pm.lggr.Errorw("Failed to truncate transmit requests table on close", "err", err) + } else if n > 0 { + pm.lggr.Debugw("Truncated transmit requests table on close", "nDeleted", n) + } + return + case <-ticker.C: + n, err := pm.orm.Prune(ctx, pm.serverURL, pm.maxTransmitQueueSize, PruneBatchSize) + if err != nil { + pm.lggr.Errorw("Failed to prune transmit requests table", "err", err) + continue + } + if n > 0 { + pm.lggr.Debugw("Pruned transmit requests table", "nDeleted", n) + } + } + } +} + +func (pm *persistenceManager) addToDeleteQueue(hashes ...[32]byte) { + pm.deleteMu.Lock() + defer pm.deleteMu.Unlock() + pm.deleteQueue = append(pm.deleteQueue, hashes...) + if len(pm.deleteQueue) > DeleteQueueMaxSize { + // NOTE: This could only happen if inserts are succeeding while deletes are + // failing (or not fast enough) which would be very strange + pm.lggr.Errorw("Delete queue is full; dropping transmissions", "hashes", hashes, "n", len(pm.deleteQueue)) + pm.deleteQueue = pm.deleteQueue[:DeleteQueueMaxSize] + } +} + +func (pm *persistenceManager) resetDeleteQueue() [][32]byte { + pm.deleteMu.Lock() + defer pm.deleteMu.Unlock() + queue := pm.deleteQueue + pm.deleteQueue = nil + return queue +} + +func (pm *persistenceManager) lenDeleteQueue() int { + pm.deleteMu.Lock() + defer pm.deleteMu.Unlock() + return len(pm.deleteQueue) +} diff --git a/llo/transmitter/de/persistence_manager_test.go b/llo/transmitter/de/persistence_manager_test.go new file mode 100644 index 0000000..2dda63b --- /dev/null +++ b/llo/transmitter/de/persistence_manager_test.go @@ -0,0 +1,189 @@ +package de + +import ( + "sort" + "testing" + "time" + + "github.com/jmoiron/sqlx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-data-streams/mercury/testutils" +) + +func bootstrapPersistenceManager(t *testing.T, donID uint32, db *sqlx.DB, maxTransmitQueueSize int) (*persistenceManager, *observer.ObservedLogs) { + t.Helper() + lggr, observedLogs := logger.TestObservedSugared(t, zapcore.DebugLevel) + orm := NewORM(db, donID) + return NewPersistenceManager(lggr, orm, "wss://example.com/mercury", maxTransmitQueueSize, 5*time.Millisecond, 5*time.Millisecond, 30*24*time.Hour), observedLogs +} + +func TestPersistenceManager(t *testing.T) { + donID1 := uint32(1234) + donID2 := uint32(2345) + + ctx := t.Context() + db := testutils.NewSqlxDB(t) + + t.Run("loads transmissions", func(t *testing.T) { + pm, _ := bootstrapPersistenceManager(t, donID1, db, 2) + transmissions := makeSampleTransmissions(3, sURL) + err := pm.orm.Insert(ctx, transmissions) + require.NoError(t, err) + + sort.Slice(transmissions, func(i, j int) bool { + // sort by seqnr desc to match return of Get + return transmissions[i].SeqNr > transmissions[j].SeqNr + }) + result, err := pm.Load(ctx) + require.NoError(t, err) + assert.ElementsMatch(t, transmissions[0:2], result) + + err = pm.orm.Delete(ctx, [][32]byte{transmissions[0].Hash()}) + require.NoError(t, err) + }) + + t.Run("scopes load to only transmissions with matching don ID", func(t *testing.T) { + pm, _ := bootstrapPersistenceManager(t, donID1, db, 2) + transmissions := makeSampleTransmissions(3, sURL) + err := pm.orm.Insert(ctx, transmissions) + require.NoError(t, err) + + pm2, _ := bootstrapPersistenceManager(t, donID2, db, 3) + result, err := pm2.Load(ctx) + require.NoError(t, err) + + assert.Empty(t, result) + }) + + t.Run("does not load records older than maxAge", func(t *testing.T) { + pm, _ := bootstrapPersistenceManager(t, donID1, db, 3) + transmissions := makeSampleTransmissions(3, sURL) + err := pm.orm.Insert(ctx, transmissions) + require.NoError(t, err) + + testutils.MustExec(t, db, `UPDATE llo_mercury_transmit_queue SET inserted_at = NOW() - INTERVAL '1 year' WHERE seq_nr = 0`) + + result, err := pm.Load(ctx) + require.NoError(t, err) + + assert.Len(t, result, 2) + assert.Equal(t, uint64(2), result[0].SeqNr) + assert.Equal(t, uint64(1), result[1].SeqNr) + }) +} + +func TestPersistenceManagerAsyncDelete(t *testing.T) { + ctx := t.Context() + donID := uint32(1234) + db := testutils.NewSqlxDB(t) + pm, observedLogs := bootstrapPersistenceManager(t, donID, db, 1000) + + transmissions := makeSampleTransmissions(3, sURL) + err := pm.orm.Insert(ctx, transmissions) + require.NoError(t, err) + + servicetest.Run(t, pm) + + pm.AsyncDelete(transmissions[0].Hash()) + + // Wait for next poll. + observedLogs.TakeAll() + tests.AssertLogEventually(t, observedLogs, "Flushed delete queue") + + result, err := pm.Load(ctx) + require.NoError(t, err) + require.Len(t, result, 2) + assert.ElementsMatch(t, transmissions[1:], result) +} + +func TestPersistenceManagerPrune(t *testing.T) { + donID1 := uint32(123456) + donID2 := uint32(654321) + db := testutils.NewSqlxDB(t) + + ctx := t.Context() + + transmissions := make([]*Transmission, 45) + for i := range uint64(45) { + transmissions[i] = makeSampleTransmission(i, sURL, ocrtypes.Report{byte(i)}) + } + + // cut 25 down to 2 + pm, observedLogs := bootstrapPersistenceManager(t, donID1, db, 2) + err := pm.orm.Insert(ctx, transmissions[:25]) + require.NoError(t, err) + + pm2, _ := bootstrapPersistenceManager(t, donID2, db, 20) + err = pm2.orm.Insert(ctx, transmissions[25:]) + require.NoError(t, err) + + err = pm.Start(ctx) + require.NoError(t, err) + + // Wait for next poll. + observedLogs.TakeAll() + tests.AssertLogEventually(t, observedLogs, "Pruned transmit requests table") + + result, err := pm.Load(ctx) + require.NoError(t, err) + require.ElementsMatch(t, transmissions[23:25], result) + + // Test pruning stops after Close. + err = pm.Close() + require.NoError(t, err) + + err = pm.orm.Insert(ctx, transmissions) + require.NoError(t, err) + + result, err = pm.Load(ctx) + require.NoError(t, err) + require.Len(t, result, 2) + + t.Run("prune was scoped to don ID", func(t *testing.T) { + result, err = pm2.Load(ctx) + require.NoError(t, err) + assert.Len(t, result, 20) + }) +} + +func Test_PersistenceManager_deleteTransmissions(t *testing.T) { + donID1 := uint32(123456) + db := testutils.NewSqlxDB(t) + + ctx := t.Context() + + transmissions := make([]*Transmission, 45) + for i := range uint64(45) { + transmissions[i] = makeSampleTransmission(i, sURL, ocrtypes.Report{byte(i)}) + } + + pm, _ := bootstrapPersistenceManager(t, donID1, db, 1000) + require.NoError(t, pm.orm.Insert(ctx, transmissions)) + + hashesToDelete := make([][32]byte, 20) + for i := range 20 { + hashesToDelete[i] = transmissions[i].Hash() + } + pm.deleteTransmissions(ctx, hashesToDelete, 7) + + ts, err := pm.Load(ctx) + require.NoError(t, err) + + require.Len(t, ts, 25) + for i := range 20 { + assert.NotContains(t, ts, transmissions[i]) + } + for i := 20; i < 45; i++ { + assert.Contains(t, ts, transmissions[i]) + } +} diff --git a/llo/transmitter/de/queue.go b/llo/transmitter/de/queue.go new file mode 100644 index 0000000..fab3b9c --- /dev/null +++ b/llo/transmitter/de/queue.go @@ -0,0 +1,269 @@ +package de + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "strconv" + "sync" + "time" + + heap "github.com/esote/minmaxheap" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" +) + +type asyncDeleter interface { + AsyncDelete(hash [32]byte) + DonID() uint32 +} + +var _ services.Service = (*transmitQueue)(nil) + +var promTransmitQueueLoad = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_queue_load", + Help: "Current count of items in the transmit queue", +}, + []string{"donID", "serverURL", "capacity"}, +) + +// Prometheus' default interval is 15s, set this to under 7.5s to avoid +// aliasing (see: https://en.wikipedia.org/wiki/Nyquist_frequency) +const promInterval = 6500 * time.Millisecond + +// TransmitQueue is the high-level package that everything outside of this file should be using +// It stores pending transmissions, yielding the latest (highest priority) first to the caller +type transmitQueue struct { + services.StateMachine + + cond sync.Cond + lggr logger.SugaredLogger + asyncDeleter asyncDeleter + mu *sync.RWMutex + + pq *priorityQueue + maxlen int + closed bool + + // monitor loop + stopMonitor func() + transmitQueueLoad prometheus.Gauge +} + +type TransmitQueue interface { + services.Service + + BlockingPop() (t *Transmission) + Push(t *Transmission) (ok bool) + Init(ts []*Transmission) error + IsEmpty() bool +} + +// maxlen controls how many items will be stored in the queue +// 0 means unlimited - be careful, this can cause memory leaks +func NewTransmitQueue(lggr logger.Logger, serverURL string, maxlen int, asyncDeleter asyncDeleter) TransmitQueue { + mu := new(sync.RWMutex) + return &transmitQueue{ + services.StateMachine{}, + sync.Cond{L: mu}, + logger.Sugared(lggr).Named("TransmitQueue"), + asyncDeleter, + mu, + nil, // pq needs to be initialized by calling tq.Init before use + maxlen, + false, + nil, + promTransmitQueueLoad.WithLabelValues(strconv.FormatUint(uint64(asyncDeleter.DonID()), 10), serverURL, strconv.FormatInt(int64(maxlen), 10)), + } +} + +func (tq *transmitQueue) Init(ts []*Transmission) error { + if len(ts) > tq.maxlen { + return fmt.Errorf("transmit queue is too small to hold %d transmissions", len(ts)) + } + tq.lggr.Debugw("Initializing transmission queue", "nTransmissions", len(ts), "maxlen", tq.maxlen) + pq := priorityQueue(ts) + heap.Init(&pq) // ensure the heap is ordered + tq.pq = &pq + return nil +} + +func (tq *transmitQueue) Push(t *Transmission) (ok bool) { + tq.cond.L.Lock() + defer tq.cond.L.Unlock() + + if tq.closed { + return false + } + + if tq.maxlen != 0 { + for tq.pq.Len() >= tq.maxlen { + // evict oldest entries to make room + removed := heap.PopMax(tq.pq) + if removed, ok := removed.(*Transmission); ok { + hash := removed.Hash() + tq.asyncDeleter.AsyncDelete(hash) + tq.lggr.Criticalw(fmt.Sprintf("Transmit queue is full; dropping oldest transmission (reached max length of %d)", tq.maxlen), "transmission", removed, "transmissionHash", hex.EncodeToString(hash[:])) + } + } + } + + heap.Push(tq.pq, t) + tq.cond.Signal() + + return true +} + +// BlockingPop will block until at least one item is in the heap, and then return it +// If the queue is closed, it will immediately return nil +func (tq *transmitQueue) BlockingPop() (t *Transmission) { + tq.cond.L.Lock() + defer tq.cond.L.Unlock() + if tq.closed { + return nil + } + for t = tq.pop(); t == nil; t = tq.pop() { + tq.cond.Wait() + if tq.closed { + return nil + } + } + return t +} + +func (tq *transmitQueue) IsEmpty() bool { + return tq.Len() == 0 +} + +func (tq *transmitQueue) Len() int { + tq.cond.L.Lock() + defer tq.cond.L.Unlock() + + sz := tq.pq.Len() + tq.cond.Signal() + return sz +} + +func (tq *transmitQueue) Start(context.Context) error { + return tq.StartOnce("TransmitQueue", func() error { + t := services.NewTicker(promInterval) + wg := new(sync.WaitGroup) + chStop := make(chan struct{}) + tq.stopMonitor = func() { + t.Stop() + close(chStop) + wg.Wait() + } + wg.Add(1) + go tq.monitorLoop(t.C, chStop, wg) + return nil + }) +} + +func (tq *transmitQueue) Close() error { + return tq.StopOnce("TransmitQueue", func() error { + tq.cond.L.Lock() + tq.closed = true + tq.cond.L.Unlock() + tq.cond.Broadcast() + tq.stopMonitor() + return nil + }) +} + +func (tq *transmitQueue) monitorLoop(c <-chan time.Time, chStop <-chan struct{}, wg *sync.WaitGroup) { + defer wg.Done() + + for { + select { + case <-c: + tq.report() + case <-chStop: + return + } + } +} + +func (tq *transmitQueue) report() { + tq.mu.RLock() + length := tq.pq.Len() + tq.mu.RUnlock() + tq.transmitQueueLoad.Set(float64(length)) +} + +func (tq *transmitQueue) Ready() error { + return nil +} +func (tq *transmitQueue) Name() string { return tq.lggr.Name() } +func (tq *transmitQueue) HealthReport() map[string]error { + report := map[string]error{tq.Name(): errors.Join( + tq.status(), + )} + return report +} + +func (tq *transmitQueue) status() (merr error) { + tq.mu.RLock() + length := tq.pq.Len() + closed := tq.closed + tq.mu.RUnlock() + if tq.maxlen != 0 && length > (tq.maxlen/2) { + merr = errors.Join(merr, fmt.Errorf("transmit priority queue is greater than 50%% full (%d/%d)", length, tq.maxlen)) + } + if closed { + merr = errors.New("transmit queue is closed") + } + return merr +} + +// pop latest Transmission from the heap +// Not thread-safe +func (tq *transmitQueue) pop() *Transmission { + if tq.pq.Len() == 0 { + return nil + } + return heap.Pop(tq.pq).(*Transmission) +} + +// HEAP +// Adapted from https://pkg.go.dev/container/heap#example-package-PriorityQueue + +// WARNING: None of these methods are thread-safe, caller must synchronize + +var _ heap.Interface = &priorityQueue{} + +type priorityQueue []*Transmission + +func (pq priorityQueue) Len() int { return len(pq) } + +func (pq priorityQueue) Less(i, j int) bool { + // We want Pop to give us the latest round, so we use greater than here + // i.e. a later seqNr is "less" than an earlier one + return pq[i].SeqNr > pq[j].SeqNr +} + +func (pq priorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] +} + +func (pq *priorityQueue) Pop() any { + n := len(*pq) + if n == 0 { + return nil + } + old := *pq + item := old[n-1] + old[n-1] = nil // avoid memory leak + *pq = old[0 : n-1] + return item +} + +func (pq *priorityQueue) Push(x any) { + *pq = append(*pq, x.(*Transmission)) +} diff --git a/llo/transmitter/de/queue_test.go b/llo/transmitter/de/queue_test.go new file mode 100644 index 0000000..edd1504 --- /dev/null +++ b/llo/transmitter/de/queue_test.go @@ -0,0 +1,155 @@ +package de + +import ( + "sync" + "testing" + + heap "github.com/esote/minmaxheap" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" +) + +var _ asyncDeleter = &mockAsyncDeleter{} + +type mockAsyncDeleter struct { + donID uint32 + hashes [][32]byte +} + +func (m *mockAsyncDeleter) AsyncDelete(hash [32]byte) { + m.hashes = append(m.hashes, hash) +} +func (m *mockAsyncDeleter) DonID() uint32 { + return m.donID +} + +func Test_Queue(t *testing.T) { + t.Parallel() + const maxSize = 7 + + lggr, observedLogs := logger.TestObserved(t, zapcore.ErrorLevel) + + t.Run("cannot init with more transmissions than capacity", func(t *testing.T) { + transmissions := makeSampleTransmissions(maxSize+1, sURL) + tq := NewTransmitQueue(lggr, sURL, maxSize, &mockAsyncDeleter{}) + err := tq.Init(transmissions) + require.Error(t, err) + }) + + t.Run("happy cases", func(t *testing.T) { + testTransmissions := makeSampleTransmissions(3, sURL) + deleter := &mockAsyncDeleter{} + tq := NewTransmitQueue(lggr, sURL, maxSize, deleter) + + require.NoError(t, tq.Init([]*Transmission{})) + + t.Run("successfully add transmissions to transmit queue", func(t *testing.T) { + for _, tt := range testTransmissions { + ok := tq.Push(tt) + require.True(t, ok) + } + report := tq.HealthReport() + require.NoError(t, report[tq.Name()]) + }) + + t.Run("transmit queue is more than 50% full", func(t *testing.T) { + tq.Push(testTransmissions[2]) + report := tq.HealthReport() + assert.Equal(t, "transmit priority queue is greater than 50% full (4/7)", report[tq.Name()].Error()) + }) + + t.Run("transmit queue pops the highest priority transmission", func(t *testing.T) { + tr := tq.BlockingPop() + assert.Equal(t, testTransmissions[2], tr) + }) + + t.Run("transmit queue is full and evicts the oldest transmission", func(t *testing.T) { + // add 5 more transmissions to overflow the queue by 1 + for range 5 { + tq.Push(testTransmissions[1]) + } + + // expecting testTransmissions[0] to get evicted and not present in the queue anymore + tests.AssertLogEventually(t, observedLogs, "Transmit queue is full; dropping oldest transmission (reached max length of 7)") + var transmissions []*Transmission + for range 7 { + tr := tq.BlockingPop() + transmissions = append(transmissions, tr) + } + + assert.NotContains(t, transmissions, testTransmissions[0]) + require.Len(t, deleter.hashes, 1) + assert.Equal(t, testTransmissions[0].Hash(), deleter.hashes[0]) + }) + + t.Run("transmit queue blocks when empty and resumes when transmission available", func(t *testing.T) { + assert.True(t, tq.IsEmpty()) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + tr := tq.BlockingPop() + assert.Equal(t, tr, testTransmissions[0]) + }() + go func() { + defer wg.Done() + tq.Push(testTransmissions[0]) + }() + wg.Wait() + }) + + t.Run("initializes transmissions", func(t *testing.T) { + expected := makeSampleTransmission(1, sURL, []byte{1}) + transmissions := []*Transmission{ + expected, + } + tq := NewTransmitQueue(lggr, sURL, 7, deleter) + require.NoError(t, tq.Init(transmissions)) + + transmission := tq.BlockingPop() + assert.Equal(t, expected, transmission) + assert.True(t, tq.IsEmpty()) + }) + }) + + t.Run("if the queue was overfilled it evicts entries until reaching maxSize", func(t *testing.T) { + testTransmissions := makeSampleTransmissions(maxSize*3, sURL) + deleter := &mockAsyncDeleter{} + tq := NewTransmitQueue(lggr, sURL, maxSize, deleter) + + // add 3 over capacity to queue + { + // need to copy to avoid sorting original slice + init := make([]*Transmission, maxSize+3) + copy(init, testTransmissions) + pq := priorityQueue(init) + heap.Init(&pq) // ensure the heap is ordered + tq.(*transmitQueue).pq = &pq // directly assign to bypass Init check + } + + tq.Push(testTransmissions[maxSize+3]) // push one more to trigger eviction + require.Equal(t, maxSize, tq.(*transmitQueue).Len()) + require.Len(t, deleter.hashes, 4) // evicted overfill entries (3 oversize plus 1 more to make room) + + // oldest entries removed + assert.Equal(t, testTransmissions[0].Hash(), deleter.hashes[0]) + assert.Equal(t, testTransmissions[1].Hash(), deleter.hashes[1]) + assert.Equal(t, testTransmissions[2].Hash(), deleter.hashes[2]) + assert.Equal(t, testTransmissions[3].Hash(), deleter.hashes[3]) + + queueEntriesSorted := []*Transmission{} + for { + transmission := tq.(*transmitQueue).pop() + if transmission == nil { + break + } + queueEntriesSorted = append(queueEntriesSorted, transmission) + } + assert.ElementsMatch(t, testTransmissions[4:4+maxSize], queueEntriesSorted) + }) +} diff --git a/llo/transmitter/de/server.go b/llo/transmitter/de/server.go new file mode 100644 index 0000000..997e3ba --- /dev/null +++ b/llo/transmitter/de/server.go @@ -0,0 +1,317 @@ +package de + +import ( + "context" + "fmt" + "maps" + "slices" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/jpillora/backoff" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-data-streams/llo" + "github.com/smartcontractkit/chainlink-data-streams/llo/reportcodecs/evm" + "github.com/smartcontractkit/chainlink-data-streams/rpc" +) + +var ( + promTransmitQueueInsertErrorCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_queue_insert_error_count", + Help: "Running count of DB errors when trying to insert an item into the queue DB", + }, + []string{"donID", "serverURL"}, + ) + promTransmitQueuePushErrorCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_queue_push_error_count", + Help: "Running count of DB errors when trying to push an item onto the queue", + }, + []string{"donID", "serverURL"}, + ) + promTransmitServerErrorCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_server_error_count", + Help: "Number of errored transmissions that failed due to an error returned by the mercury server", + }, + []string{"donID", "serverURL", "code"}, + ) + promTransmitConcurrentTransmitGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "concurrent_transmit_gauge", + Help: "Gauge that measures the number of transmit threads currently waiting on a remote transmit call. You may wish to alert if this exceeds some number for a given period of time, or if it ever reaches its max.", + }, + []string{"donID", "serverURL"}, + ) +) + +type ReportPacker interface { + Pack(digest types.ConfigDigest, seqNr uint64, report ocr2types.Report, sigs []ocr2types.AttributedOnchainSignature) ([]byte, error) +} + +// A server handles the queue for a given mercury server + +type server struct { + lggr logger.SugaredLogger + verboseLogging bool + + transmitTimeout time.Duration + + c rpc.Client + pm *persistenceManager + q TransmitQueue + + url string + + evmPremiumLegacyPacker ReportPacker + evmStreamlinedPacker ReportPacker + jsonPacker ReportPacker + + transmitSuccessCount prometheus.Counter + transmitDuplicateCount prometheus.Counter + transmitConnectionErrorCount prometheus.Counter + transmitQueueInsertErrorCount prometheus.Counter + transmitQueuePushErrorCount prometheus.Counter + transmitConcurrentTransmitGauge prometheus.Gauge + + transmitThreadBusyCount atomic.Int32 + consecutiveTransmitErrorCount int + consecutiveTransmitUniqueErrors map[string]struct{} + consecutiveTransmitErrorMu sync.Mutex +} + +type QueueConfig interface { + ReaperMaxAge() time.Duration + TransmitQueueMaxSize() uint32 + TransmitTimeout() time.Duration +} + +func newServer(lggr logger.Logger, verboseLogging bool, cfg QueueConfig, client rpc.Client, orm ORM, serverURL string) *server { + pm := NewPersistenceManager(lggr, orm, serverURL, int(cfg.TransmitQueueMaxSize()), FlushDeletesFrequency, PruneFrequency, cfg.ReaperMaxAge()) + donIDStr := strconv.FormatUint(uint64(pm.DonID()), 10) + var codecLggr logger.Logger + if verboseLogging { + codecLggr = lggr + } else { + codecLggr = logger.Nop() + } + + s := &server{ + logger.Sugared(lggr), + verboseLogging, + cfg.TransmitTimeout(), + client, + pm, + NewTransmitQueue(lggr, serverURL, int(cfg.TransmitQueueMaxSize()), pm), + serverURL, + evm.NewReportCodecPremiumLegacy(codecLggr, pm.DonID()), + evm.NewReportCodecStreamlined(codecLggr), + llo.JSONReportCodec{}, + promTransmitSuccessCount.WithLabelValues(donIDStr, serverURL), + promTransmitDuplicateCount.WithLabelValues(donIDStr, serverURL), + promTransmitConnectionErrorCount.WithLabelValues(donIDStr, serverURL), + promTransmitQueueInsertErrorCount.WithLabelValues(donIDStr, serverURL), + promTransmitQueuePushErrorCount.WithLabelValues(donIDStr, serverURL), + promTransmitConcurrentTransmitGauge.WithLabelValues(donIDStr, serverURL), + atomic.Int32{}, + 0, + make(map[string]struct{}), + sync.Mutex{}, + } + + return s +} + +func (s *server) HealthReport() map[string]error { + report := map[string]error{} + services.CopyHealth(report, s.c.HealthReport()) + services.CopyHealth(report, s.q.HealthReport()) + return report +} + +func (s *server) transmitThreadBusyCountInc() { + val := s.transmitThreadBusyCount.Add(1) + s.transmitConcurrentTransmitGauge.Set(float64(val)) +} +func (s *server) transmitThreadBusyCountDec() { + val := s.transmitThreadBusyCount.Add(-1) + s.transmitConcurrentTransmitGauge.Set(float64(val)) +} + +func (s *server) spawnTransmitLoops(stopCh services.StopChan, wg *sync.WaitGroup, donID uint32, n int) { + donIDStr := strconv.FormatUint(uint64(donID), 10) + wg.Add(n) + for range n { + go s.spawnTransmitLoop(stopCh, wg, donIDStr) + } +} + +func (s *server) spawnTransmitLoop(stopCh services.StopChan, wg *sync.WaitGroup, donIDStr string) { + defer wg.Done() + s.transmitConcurrentTransmitGauge.Set(0) // initial set to populate metric + + // Exponential backoff with very short retry interval (since latency is a priority) + // 5ms, 10ms, 20ms, 40ms etc + b := backoff.Backoff{ + Min: 5 * time.Millisecond, + Max: 1 * time.Second, + Factor: 2, + Jitter: true, + } + ctx, cancel := stopCh.NewCtx() + defer cancel() + cont := true + for cont { + cont = func() bool { + t := s.q.BlockingPop() + if t == nil { + // queue was closed + return false + } + if t.Report.Info.ReportFormat == llotypes.ReportFormatCapabilityTrigger { + // `capability_trigger` reports are Data Feeds product specific and aren't sent to the Mercury servers + s.pm.AsyncDelete(t.Hash()) + return true + } + + s.transmitThreadBusyCountInc() + defer s.transmitThreadBusyCountDec() + + req, res, err := func(ctx context.Context) (*rpc.TransmitRequest, *rpc.TransmitResponse, error) { + ctx, cancelFn := context.WithTimeout(ctx, utils.WithJitter(s.transmitTimeout)) + defer cancelFn() + return s.transmit(ctx, t) + }(ctx) + + lggr := s.lggr.With("transmission", t, "response", res, "transmissionHash", fmt.Sprintf("%x", t.Hash())) + if req != nil { + lggr = s.lggr.With("req.Payload", req.Payload, "req.ReportFormat", req.ReportFormat) + } + + if ctx.Err() != nil { + // only canceled on transmitter close so we can exit + return false + } else if err != nil { + s.transmitConnectionErrorCount.Inc() + s.rateLimitedLogError(lggr, "Transmit report failed", err.Error()) + if ok := s.q.Push(t); !ok { + s.lggr.Error("Failed to push report to transmit queue; queue is closed") + return false + } + // Wait a backoff duration before pulling the most recent transmission + // the heap + select { + case <-time.After(b.Duration()): + return true + case <-stopCh: + return false + } + } + + b.Reset() + if res.Error == "" { + s.transmitSuccessCount.Inc() + s.resetConsecutiveTransmitFailures() + lggr.Debug("Transmit report success") + } else { + // We don't need to retry here because the mercury server + // has confirmed it received the report. We only need to retry + // on networking/unknown errors + switch res.Code { + case DuplicateReport: + s.transmitSuccessCount.Inc() + s.transmitDuplicateCount.Inc() + s.resetConsecutiveTransmitFailures() + lggr.Debug("Transmit report success; duplicate report") + default: + promTransmitServerErrorCount.WithLabelValues(donIDStr, s.url, strconv.FormatInt(int64(res.Code), 10)).Inc() + s.rateLimitedLogError(lggr, "Transmit report failed; mercury server returned error", fmt.Sprintf("mercury server returned error: %q, statusCode: %d", res.Error, res.Code)) + } + } + + s.pm.AsyncDelete(t.Hash()) + return true + }() + } +} + +func (s *server) rateLimitedLogError(lggr logger.Logger, msg string, err string) { + cnt, uniqueErrors := s.incConsecutiveTransmitErrorCount(err) + switch { + case cnt < 10: + // Log first 10 errors individually + lggr.Errorw(msg, "nErrs", 1, "err", err) + return + case cnt < 10_000: + // Log errors up to 10k in batches of 100 + if cnt%100 == 0 { + lggr.Errorw(msg+" (100 failures)", "nErrs", 100, "uniqueErrors", uniqueErrors) + } + return + default: + // After that, log every 10k errors + if cnt%10_000 == 0 { + lggr.Errorw(msg+" (10,000 failures)", "nErrs", 10_000, "uniqueErrors", uniqueErrors) + } + return + } +} + +func (s *server) incConsecutiveTransmitErrorCount(errStr string) (int, []string) { + s.consecutiveTransmitErrorMu.Lock() + defer s.consecutiveTransmitErrorMu.Unlock() + s.consecutiveTransmitErrorCount++ + s.consecutiveTransmitUniqueErrors[errStr] = struct{}{} + return s.consecutiveTransmitErrorCount, slices.Sorted(maps.Keys(s.consecutiveTransmitUniqueErrors)) +} + +func (s *server) resetConsecutiveTransmitFailures() { + s.consecutiveTransmitErrorMu.Lock() + s.consecutiveTransmitErrorCount = 0 + s.consecutiveTransmitUniqueErrors = make(map[string]struct{}) + s.consecutiveTransmitErrorMu.Unlock() +} + +func (s *server) transmit(ctx context.Context, t *Transmission) (*rpc.TransmitRequest, *rpc.TransmitResponse, error) { + var payload []byte + var err error + + switch t.Report.Info.ReportFormat { + case llotypes.ReportFormatJSON: + payload, err = s.jsonPacker.Pack(t.ConfigDigest, t.SeqNr, t.Report.Report, t.Sigs) + case llotypes.ReportFormatEVMPremiumLegacy, llotypes.ReportFormatEVMABIEncodeUnpacked, llotypes.ReportFormatEVMABIEncodeUnpackedExpr: + payload, err = s.evmPremiumLegacyPacker.Pack(t.ConfigDigest, t.SeqNr, t.Report.Report, t.Sigs) + case llotypes.ReportFormatEVMStreamlined: + payload, err = s.evmStreamlinedPacker.Pack(t.ConfigDigest, t.SeqNr, t.Report.Report, t.Sigs) + default: + return nil, nil, fmt.Errorf("Transmit failed; don't know how to Pack unsupported report format: %q", t.Report.Info.ReportFormat) + } + + if err != nil { + return nil, nil, fmt.Errorf("Transmit: encode failed; %w", err) + } + + req := &rpc.TransmitRequest{ + Payload: payload, + ReportFormat: uint32(t.Report.Info.ReportFormat), + } + + resp, err := s.c.Transmit(ctx, req) + return req, resp, err +} diff --git a/llo/transmitter/de/transmitter.go b/llo/transmitter/de/transmitter.go new file mode 100644 index 0000000..69b3471 --- /dev/null +++ b/llo/transmitter/de/transmitter.go @@ -0,0 +1,413 @@ +package de + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/sync/errgroup" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + "github.com/smartcontractkit/chainlink-data-streams/rpc" +) + +const ( + // Mercury server error codes + DuplicateReport = 2 + commitInterval = time.Millisecond * 25 +) + +var ( + promTransmitSuccessCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_success_count", + Help: "Number of successful transmissions (duplicates are counted as success)", + }, + []string{"donID", "serverURL"}, + ) + promTransmitDuplicateCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_duplicate_count", + Help: "Number of transmissions where the server told us it was a duplicate", + }, + []string{"donID", "serverURL"}, + ) + promTransmitConnectionErrorCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "llo", + Subsystem: "mercurytransmitter", + Name: "transmit_connection_error_count", + Help: "Number of errored transmissions that failed due to problem with the connection", + }, + []string{"donID", "serverURL"}, + ) +) + +type MercuryTransmitterProtocol string + +const ( + MercuryTransmitterProtocolWSRPC MercuryTransmitterProtocol = "wsrpc" + MercuryTransmitterProtocolGRPC MercuryTransmitterProtocol = "grpc" +) + +func (m MercuryTransmitterProtocol) String() string { + return string(m) +} + +func (m *MercuryTransmitterProtocol) UnmarshalText(text []byte) error { + switch string(text) { + case "wsrpc": + *m = MercuryTransmitterProtocolWSRPC + case "grpc": + *m = MercuryTransmitterProtocolGRPC + default: + return fmt.Errorf("unknown mercury transmitter protocol: %s", text) + } + return nil +} + +type Transmission struct { + ServerURL string + ConfigDigest types.ConfigDigest + SeqNr uint64 + Report ocr3types.ReportWithInfo[llotypes.ReportInfo] + Sigs []types.AttributedOnchainSignature +} + +// Hash takes sha256 hash of all fields +func (t Transmission) Hash() [32]byte { + h := sha256.New() + h.Write([]byte(t.ServerURL)) + h.Write(t.ConfigDigest[:]) + if err := binary.Write(h, binary.BigEndian, t.SeqNr); err != nil { + // This should never happen + panic(err) + } + h.Write(t.Report.Report) + h.Write([]byte(t.Report.Info.LifeCycleStage)) + if err := binary.Write(h, binary.BigEndian, t.Report.Info.ReportFormat); err != nil { + // This should never happen + panic(err) + } + for _, sig := range t.Sigs { + h.Write(sig.Signature) + if err := binary.Write(h, binary.BigEndian, sig.Signer); err != nil { + // This should never happen + panic(err) + } + } + var result [32]byte + h.Sum(result[:0]) + return result +} + +type Transmitter interface { + llotypes.Transmitter + services.Service +} + +var _ Transmitter = (*transmitter)(nil) + +type Config interface { + Protocol() MercuryTransmitterProtocol + ReaperMaxAge() time.Duration + TransmitConcurrency() uint32 + TransmitQueueMaxSize() uint32 + TransmitTimeout() time.Duration +} + +type transmitter struct { + services.StateMachine + lggr logger.SugaredLogger + verboseLogging bool + cfg Config + + orm ORM + servers map[string]*server + + donID uint32 + fromAccount string + + stopCh services.StopChan + wg *sync.WaitGroup + + commitCh chan *Transmission +} + +type Opts struct { + Lggr logger.Logger + VerboseLogging bool + Cfg Config + Clients map[string]rpc.Client + FromAccount string + DonID uint32 + ORM ORM + CapabilitiesRegistry coretypes.CapabilitiesRegistry +} + +func New(opts Opts) Transmitter { + return newTransmitter(opts) +} + +func newTransmitter(opts Opts) *transmitter { + sugared := logger.Sugared(opts.Lggr).Named("LLOMercuryTransmitter") + servers := make(map[string]*server, len(opts.Clients)) + for serverURL, client := range opts.Clients { + sLggr := sugared.Named(fmt.Sprintf("%q", serverURL)).With("serverURL", serverURL) + servers[serverURL] = newServer(sLggr, opts.VerboseLogging, opts.Cfg, client, opts.ORM, serverURL) + } + return &transmitter{ + services.StateMachine{}, + sugared.Named("LLOMercuryTransmitter"), + opts.VerboseLogging, + opts.Cfg, + opts.ORM, + servers, + opts.DonID, + opts.FromAccount, + make(services.StopChan), + &sync.WaitGroup{}, + make(chan *Transmission, 1000*len(servers)), + } +} + +func (mt *transmitter) Start(ctx context.Context) (err error) { + return mt.StartOnce("LLOMercuryTransmitter", func() error { + if mt.verboseLogging { + mt.lggr.Debugw("Loading transmit requests from database") + } + + g, startCtx := errgroup.WithContext(ctx) + // Number of goroutines spawned per server will be + // TransmitConcurrency+2 (1 for persistence manager, 1 for client) + // + // This could potentially be reduced by implementing transmit batching, + // see: https://smartcontract-it.atlassian.net/browse/MERC-6635 + for _, s := range mt.servers { + // concurrent start of all servers + g.Go(func() error { + // Load DB transmissions and populate server transmit queue + transmissions, err := s.pm.Load(startCtx) + if err != nil { + return err + } + s.q.Init(transmissions) + + // Start all associated services + // + // client, queue etc should be started before spawning server loops + // + // pm must be stopped last to give it a chance to clean up the + // remaining transmissions + startClosers := []services.StartClose{s.pm, s.c, s.q} + if err := (&services.MultiStart{}).Start(startCtx, startClosers...); err != nil { + return err + } + + // Spawn transmission loop threads + s.spawnTransmitLoops(mt.stopCh, mt.wg, mt.donID, int(mt.cfg.TransmitConcurrency())) + return nil + }) + } + + mt.spawnCommitLoops() + return g.Wait() + }) +} + +func (mt *transmitter) Close() error { + return mt.StopOnce("LLOMercuryTransmitter", func() error { + // Drain all the queues first + var qs []io.Closer + for _, s := range mt.servers { + qs = append(qs, s.q) + } + if err := services.CloseAll(qs...); err != nil { + return err + } + + close(mt.stopCh) + mt.wg.Wait() + + // Close all the persistence managers + // Close all the clients + var closers []io.Closer + for _, s := range mt.servers { + closers = append(closers, s.pm) + closers = append(closers, s.c) + } + return services.CloseAll(closers...) + }) +} + +func (mt *transmitter) Name() string { return mt.lggr.Name() } + +func (mt *transmitter) HealthReport() map[string]error { + report := map[string]error{mt.Name(): mt.Healthy()} + for _, s := range mt.servers { + services.CopyHealth(report, s.HealthReport()) + } + return report +} + +// Transmit enqueues the report for transmission to the Mercury servers +func (mt *transmitter) Transmit( + ctx context.Context, + digest types.ConfigDigest, + seqNr uint64, + report ocr3types.ReportWithInfo[llotypes.ReportInfo], + sigs []types.AttributedOnchainSignature, +) (err error) { + ok := mt.IfStarted(func() { + for serverURL := range mt.servers { + t := &Transmission{ + ServerURL: serverURL, + ConfigDigest: digest, + SeqNr: seqNr, + Report: report, + Sigs: sigs, + } + select { + case mt.commitCh <- t: + case <-ctx.Done(): + err = fmt.Errorf("failed to add transmission to commit channel: %w", ctx.Err()) + } + } + }) + + if !ok { + return errors.New("transmitter is not started") + } + + return err +} + +func (mt *transmitter) transmit(ctx context.Context, transmissions []*Transmission) error { + // On shutdown appears that libocr can pass us a pre-canceled context; + // don't even bother trying to insert/transmit in this case + if ctx.Err() != nil { + return fmt.Errorf("cannot transmit; context already canceled: %w", ctx.Err()) + } + + // NOTE: This insert on its own can leave orphaned records in the case of + // shutdown, because: + // 1. Transmitter is shut down after oracle + // 2. OCR may pass a pre-canceled context or a context that is canceled mid-transmit + // 3. Insert can succeed even if the context is canceled, but return error + // + // Usually the number of orphaned records will be very small, and they + // would be transmitted/cleaned up on the next boot anyway. + // + // However, there are two ways to avoid this: + // 1. Use a transaction to rollback the insert on error + // 2. Allow the insert anyway (it will be transmitted on next boot) and be + // sure that the persistence manager issues a final cleanup that truncates + // the table to exactly maxSize records. Since persistenceManager is shut + // down AFTER the Oracle closes, this should always catch the straggler + // records. + // + // Since this is a hot path, the performance impact of holding a + // transaction open is too high, hence we choose option 2. + // + // In very rare cases if the final delete fails for some reason, we could + // end up with slightly more than maxSize records persisted to the DB on + // application exit. + // + // Must insert BEFORE pushing to queue since the queue will handle deletion + // on queue overflow. + if err := mt.orm.Insert(ctx, transmissions); err != nil { + return err + } + + for i := range transmissions { + t := transmissions[i] + if mt.verboseLogging { + mt.lggr.Debugw("Transmit report", + "digest", t.ConfigDigest.Hex(), "seqNr", t.SeqNr, "reportFormat", t.Report.Info.ReportFormat, + "reportLifeCycleStage", t.Report.Info.LifeCycleStage, + "transmissionHash", fmt.Sprintf("%x", t.Hash())) + } + + // OK to do this synchronously since pushing to queue is just a mutex + // lock and array append and ought to be extremely fast + s := mt.servers[t.ServerURL] + if ok := s.q.Push(t); !ok { + s.transmitQueuePushErrorCount.Inc() + // This shouldn't be possible since transmitter is always shut down + // after oracle + return errors.New("transmit queue is closed") + } + } + + return nil +} + +// FromAccount returns the stringified (hex) CSA public key +func (mt *transmitter) FromAccount(ctx context.Context) (ocrtypes.Account, error) { + return ocrtypes.Account(mt.fromAccount), nil +} + +func (mt *transmitter) spawnCommitLoops() { + for x := 0; x < len(mt.servers); x++ { + mt.wg.Add(1) + + go func() { + defer mt.wg.Done() + + var err error + ctx, cancel := mt.stopCh.NewCtx() + defer cancel() + + buff := cap(mt.commitCh) / 10 + transmissions := make([]*Transmission, 0, buff) + ticker := time.NewTicker(commitInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + if len(transmissions) >= buff { + closeCtx, closeCancel := context.WithTimeout(context.Background(), time.Second) + defer closeCancel() + if err = mt.transmit(closeCtx, transmissions); err != nil { + mt.lggr.Error("Error transmitting records when stopping", "error", err) + } + } + return + + case <-ticker.C: + if len(transmissions) > 0 { + err = mt.transmit(ctx, transmissions) + transmissions = make([]*Transmission, 0, buff) + } + + case t := <-mt.commitCh: + transmissions = append(transmissions, t) + if len(transmissions) >= buff { + err = mt.transmit(ctx, transmissions) + transmissions = make([]*Transmission, 0, buff) + } + } + + if err != nil { + mt.lggr.Error("Error transmitting records", "error", err) + } + } + }() + } +} diff --git a/llo/transmitter/de/transmitter_test.go b/llo/transmitter/de/transmitter_test.go new file mode 100644 index 0000000..eb1efa5 --- /dev/null +++ b/llo/transmitter/de/transmitter_test.go @@ -0,0 +1,313 @@ +package de + +import ( + "context" + "crypto/ed25519" + "encoding/hex" + "sync" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + "github.com/smartcontractkit/chainlink-data-streams/mercury/testutils" + "github.com/smartcontractkit/chainlink-data-streams/rpc" +) + +type mockCfg struct{} + +func (m mockCfg) Protocol() MercuryTransmitterProtocol { + return MercuryTransmitterProtocolGRPC +} + +func (m mockCfg) ReaperMaxAge() time.Duration { + return 0 +} + +func (m mockCfg) TransmitQueueMaxSize() uint32 { + return 10_000 +} + +func (m mockCfg) TransmitTimeout() time.Duration { + return 1 * time.Hour +} + +func (m mockCfg) TransmitConcurrency() uint32 { + return 5 +} + +type MockGRPCClient struct { + TransmitF func(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) +} + +func (m *MockGRPCClient) Name() string { return "" } +func (m *MockGRPCClient) Start(context.Context) error { return nil } +func (m *MockGRPCClient) Close() error { return nil } +func (m *MockGRPCClient) HealthReport() map[string]error { return map[string]error{} } +func (m *MockGRPCClient) Ready() error { return nil } +func (m *MockGRPCClient) Transmit(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) { + return m.TransmitF(ctx, in) +} +func (m *MockGRPCClient) ServerURL() string { return "mock server url" } + +func Test_Transmitter_Transmit(t *testing.T) { + lggr := logger.Test(t) + db := testutils.NewSqlxDB(t) + donID := uint32(123456) + orm := NewORM(db, donID) + clients := map[string]rpc.Client{} + + t.Run("errors if not started", func(t *testing.T) { + mt := newTransmitter(Opts{ + Lggr: lggr, + Cfg: mockCfg{}, + Clients: clients, + FromAccount: hex.EncodeToString(ed25519.PublicKey{}), + DonID: donID, + ORM: orm, + }) + + seqNr := uint64(55) + report := makeSampleReport() + digest := makeSampleConfigDigest() + sigs := []types.AttributedOnchainSignature{{ + Signature: []byte{22}, + Signer: commontypes.OracleID(43), + }} + err := mt.Transmit(t.Context(), digest, seqNr, report, sigs) + require.Error(t, err) + assert.Contains(t, err.Error(), "transmitter is not started") + }) + + t.Run("with multiple mercury servers", func(t *testing.T) { + t.Run("transmission successfully enqueued", func(t *testing.T) { + c := &MockGRPCClient{} + clients[sURL] = c + clients[sURL2] = c + clients[sURL3] = c + + mt := newTransmitter(Opts{ + Lggr: lggr, + Cfg: mockCfg{}, + Clients: clients, + FromAccount: hex.EncodeToString(ed25519.PublicKey{}), + DonID: donID, + ORM: orm, + }) + err := mt.StartOnce("SimulateTransmitterStart", func() error { + // init the queue since we simulate starting transmitter + require.NoError(t, mt.servers[sURL].q.Init([]*Transmission{})) + require.NoError(t, mt.servers[sURL2].q.Init([]*Transmission{})) + require.NoError(t, mt.servers[sURL3].q.Init([]*Transmission{})) + mt.spawnCommitLoops() + + return nil + }) + require.NoError(t, err) + + seqNr := uint64(55) + report := makeSampleReport() + digest := makeSampleConfigDigest() + sigs := []types.AttributedOnchainSignature{{ + Signature: []byte{22}, + Signer: commontypes.OracleID(43), + }} + err = mt.Transmit(t.Context(), digest, seqNr, report, sigs) + require.NoError(t, err) + + // wait for the commit loop to run + time.Sleep(2 * commitInterval) + + // ensure it was added to the queue + require.Equal(t, 1, mt.servers[sURL].q.(*transmitQueue).Len()) + assert.Equal(t, &Transmission{ + ServerURL: sURL, + ConfigDigest: digest, + SeqNr: seqNr, + Report: report, + Sigs: sigs, + }, mt.servers[sURL].q.(*transmitQueue).pq.Pop().(*Transmission)) + require.Equal(t, 1, mt.servers[sURL2].q.(*transmitQueue).Len()) + assert.Equal(t, &Transmission{ + ServerURL: sURL2, + ConfigDigest: digest, + SeqNr: seqNr, + Report: report, + Sigs: sigs, + }, mt.servers[sURL2].q.(*transmitQueue).pq.Pop().(*Transmission)) + require.Equal(t, 1, mt.servers[sURL3].q.(*transmitQueue).Len()) + assert.Equal(t, &Transmission{ + ServerURL: sURL3, + ConfigDigest: digest, + SeqNr: seqNr, + Report: report, + Sigs: sigs, + }, mt.servers[sURL3].q.(*transmitQueue).pq.Pop().(*Transmission)) + }) + }) +} + +type mockQ struct { + ch chan *Transmission +} + +func newMockQ() *mockQ { + return &mockQ{make(chan *Transmission, 100)} +} + +func (m *mockQ) Start(context.Context) error { return nil } +func (m *mockQ) Close() error { + m.ch <- nil + return nil +} +func (m *mockQ) Ready() error { return nil } +func (m *mockQ) HealthReport() map[string]error { return nil } +func (m *mockQ) Name() string { return "" } +func (m *mockQ) BlockingPop() (t *Transmission) { + val := <-m.ch + return val +} +func (m *mockQ) Push(t *Transmission) (ok bool) { + m.ch <- t + return true +} +func (m *mockQ) Init(transmissions []*Transmission) error { return nil } +func (m *mockQ) IsEmpty() bool { return false } + +func Test_Transmitter_runQueueLoop(t *testing.T) { + donIDStr := "555" + lggr := logger.Test(t) + c := &MockGRPCClient{} + db := testutils.NewSqlxDB(t) + donID := uint32(123456) + orm := NewORM(db, donID) + cfg := mockCfg{} + + s := newServer(lggr, true, cfg, c, orm, sURL) + + t.Run("pulls from queue and transmits successfully", func(t *testing.T) { + transmit := make(chan *rpc.TransmitRequest, 1) + c.TransmitF = func(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) { + transmit <- in + return &rpc.TransmitResponse{Code: 0, Error: ""}, nil + } + q := newMockQ() + s.q = q + wg := &sync.WaitGroup{} + wg.Add(1) + + go s.spawnTransmitLoop(nil, wg, donIDStr) + + transmission := makeValidTransmission() + q.Push(transmission) + + select { + case tr := <-transmit: + assert.Equal(t, []byte{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xe2, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x80, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x20, 0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0}, tr.Payload) + assert.Equal(t, int(transmission.Report.Info.ReportFormat), int(tr.ReportFormat)) + case <-time.After(tests.WaitTimeout(t)): + t.Fatal("expected a transmit request to be sent") + } + + q.Close() + wg.Wait() + }) + + t.Run("on duplicate, success", func(t *testing.T) { + transmit := make(chan *rpc.TransmitRequest, 1) + c.TransmitF = func(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) { + transmit <- in + return &rpc.TransmitResponse{Code: DuplicateReport, Error: ""}, nil + } + q := newMockQ() + s.q = q + wg := &sync.WaitGroup{} + wg.Add(1) + + go s.spawnTransmitLoop(nil, wg, donIDStr) + + transmission := makeValidTransmission() + q.Push(transmission) + + select { + case tr := <-transmit: + assert.Equal(t, []byte{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xe2, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x80, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x20, 0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0}, tr.Payload) + assert.Equal(t, int(transmission.Report.Info.ReportFormat), int(tr.ReportFormat)) + case <-time.After(tests.WaitTimeout(t)): + t.Fatal("expected a transmit request to be sent") + } + + q.Close() + wg.Wait() + }) + t.Run("on server-side error, does not retry", func(t *testing.T) { + transmit := make(chan *rpc.TransmitRequest, 1) + c.TransmitF = func(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) { + transmit <- in + return &rpc.TransmitResponse{Code: DuplicateReport, Error: ""}, nil + } + q := newMockQ() + s.q = q + wg := &sync.WaitGroup{} + wg.Add(1) + + go s.spawnTransmitLoop(nil, wg, donIDStr) + + transmission := makeValidTransmission() + q.Push(transmission) + + select { + case tr := <-transmit: + assert.Equal(t, []byte{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xe2, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x80, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x20, 0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0}, tr.Payload) + assert.Equal(t, int(transmission.Report.Info.ReportFormat), int(tr.ReportFormat)) + case <-time.After(tests.WaitTimeout(t)): + t.Fatal("expected a transmit request to be sent") + } + + q.Close() + wg.Wait() + }) + t.Run("on transmit error, retries", func(t *testing.T) { + transmit := make(chan *rpc.TransmitRequest, 1) + c.TransmitF = func(ctx context.Context, in *rpc.TransmitRequest) (*rpc.TransmitResponse, error) { + transmit <- in + return &rpc.TransmitResponse{}, errors.New("transmission error") + } + q := newMockQ() + s.q = q + wg := &sync.WaitGroup{} + wg.Add(1) + stopCh := make(chan struct{}, 1) + + go s.spawnTransmitLoop(stopCh, wg, donIDStr) + + transmission := makeValidTransmission() + q.Push(transmission) + + cnt := 0 + Loop: + for { + select { + case tr := <-transmit: + assert.Equal(t, []byte{0x0, 0x9, 0x57, 0xdd, 0x2f, 0x63, 0x56, 0x69, 0x34, 0xfd, 0xc2, 0xe1, 0xcd, 0xc1, 0xe, 0x3e, 0x25, 0xb9, 0x26, 0x5a, 0x16, 0x23, 0x91, 0xa6, 0x53, 0x16, 0x66, 0x59, 0x51, 0x0, 0x28, 0x7c, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xe2, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x80, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x20, 0x0, 0x3, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xde, 0xf5, 0xba, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x8e, 0x95, 0xcf, 0xb5, 0xd8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0xd0, 0x1c, 0x67, 0xa9, 0xcf, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x66, 0xdf, 0x3, 0xca, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x1c, 0x93, 0x6d, 0xa4, 0xf2, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x14, 0x8d, 0x9a, 0xc1, 0xd9, 0x6f, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x40, 0x5c, 0xcf, 0xa1, 0xbc, 0x63, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x9d, 0xab, 0x8f, 0xa7, 0xca, 0x7, 0x62, 0x57, 0xf7, 0x11, 0x2c, 0xb7, 0xf3, 0x49, 0x37, 0x12, 0xbd, 0xe, 0x14, 0x27, 0xfc, 0x32, 0x5c, 0xec, 0xa6, 0xb9, 0x7f, 0xf9, 0xd7, 0x7b, 0xa6, 0x36, 0x9a, 0x47, 0x4a, 0x3, 0x1a, 0x95, 0xcf, 0x46, 0x10, 0xaf, 0xcc, 0x90, 0x49, 0xb2, 0xce, 0xbf, 0x63, 0xaa, 0xc7, 0x25, 0x4d, 0x2a, 0x8, 0x36, 0xda, 0xd5, 0x9f, 0x9d, 0x63, 0x69, 0x22, 0xb3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x30, 0x9d, 0x84, 0x29, 0xbf, 0xd4, 0xeb, 0xc5, 0xc9, 0x29, 0xef, 0xdd, 0xd3, 0x2f, 0xa6, 0x25, 0x63, 0xda, 0xd9, 0x2c, 0xa1, 0x4a, 0xba, 0x75, 0xb2, 0x85, 0x25, 0x8f, 0x2b, 0x84, 0xcd, 0x99, 0x36, 0xd9, 0x6e, 0xf, 0xae, 0x7b, 0xd1, 0x61, 0x59, 0xf, 0x36, 0x4a, 0x22, 0xec, 0xde, 0x45, 0x32, 0xe0, 0x5b, 0x5c, 0xe3, 0x14, 0x29, 0x4, 0x60, 0x7b, 0xce, 0xa3, 0x89, 0x6b, 0xbb, 0xe0}, tr.Payload) + assert.Equal(t, int(transmission.Report.Info.ReportFormat), int(tr.ReportFormat)) + if cnt > 2 { + break Loop + } + cnt++ + case <-time.After(tests.WaitTimeout(t)): + t.Fatal("expected 3 transmit requests to be sent") + } + } + + close(stopCh) + wg.Wait() + }) +} diff --git a/llo/transmitter/transmitter.go b/llo/transmitter/transmitter.go new file mode 100644 index 0000000..e82b517 --- /dev/null +++ b/llo/transmitter/transmitter.go @@ -0,0 +1,197 @@ +package transmitter + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "golang.org/x/sync/errgroup" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core" + llotypes "github.com/smartcontractkit/chainlink-common/pkg/types/llo" + mercurytransmitter "github.com/smartcontractkit/chainlink-data-streams/llo/transmitter/de" + + "github.com/smartcontractkit/chainlink-data-streams/llo/config" + "github.com/smartcontractkit/chainlink-data-streams/llo/cre" +) + +// LLO Transmitter implementation, based on +// core/services/relay/evm/mercury/transmitter.go +// +// If you need to "fan-out" transmits and send reports to a new destination, +// add a new subTransmitter + +const ( + // Mercury server error codes + DuplicateReport = 2 +) + +type TransmitNotifier interface { + OnTransmit(listen func(digest types.ConfigDigest, seqNr uint64)) +} + +type Transmitter interface { + llotypes.Transmitter + services.Service +} + +type TransmitterRetirementReportCacheWriter interface { + StoreAttestedRetirementReport(ctx context.Context, cd ocrtypes.ConfigDigest, seqNr uint64, retirementReport []byte, sigs []types.AttributedOnchainSignature) error +} + +type onTransmit struct { + mu sync.RWMutex + listeners []func(digest types.ConfigDigest, seqNr uint64) +} + +func (o *onTransmit) OnTransmit(listen func(digest types.ConfigDigest, seqNr uint64)) { + o.mu.Lock() + defer o.mu.Unlock() + o.listeners = append(o.listeners, listen) +} + +func (o *onTransmit) notify(digest types.ConfigDigest, seqNr uint64) { + o.mu.RLock() + defer o.mu.RUnlock() + for _, listener := range o.listeners { + go listener(digest, seqNr) + } +} + +type transmitter struct { + services.StateMachine + lggr logger.Logger + verboseLogging bool + fromAccount string + + subTransmitters []Transmitter + retirementReportCache TransmitterRetirementReportCacheWriter + *onTransmit +} + +type TransmitterOpts struct { + Lggr logger.Logger + DonID uint32 + VerboseLogging bool + FromAccount string + MercuryTransmitterOpts *mercurytransmitter.Opts + Subtransmitters []config.TransmitterConfig + RetirementReportCache TransmitterRetirementReportCacheWriter + CapabilitiesRegistry coretypes.CapabilitiesRegistry +} + +// The transmitter will handle starting and stopping the subtransmitters +func NewTransmitter(opts TransmitterOpts) (Transmitter, error) { + subTransmitters := []Transmitter{} + + if opts.MercuryTransmitterOpts != nil { + subTransmitters = append( + subTransmitters, + mercurytransmitter.New(*opts.MercuryTransmitterOpts), + ) + } + for _, cfg := range opts.Subtransmitters { + switch cfg.Type { + case config.TransmitterTypeCRE: + var creTransmitterCfg cre.TransmitterConfig + err := json.Unmarshal(cfg.Opts, &creTransmitterCfg) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal CRE transmitter config: %w", err) + } + creTransmitterCfg.Logger = opts.Lggr + creTransmitterCfg.CapabilitiesRegistry = opts.CapabilitiesRegistry + creTransmitterCfg.DonID = opts.DonID + creTransmitter, err := creTransmitterCfg.NewTransmitter() + if err != nil { + return nil, fmt.Errorf("failed to create CRE transmitter: %w", err) + } + subTransmitters = append(subTransmitters, creTransmitter) + default: + return nil, fmt.Errorf("unknown transmitter type: %s", cfg.Type) + } + } + return &transmitter{ + services.StateMachine{}, + opts.Lggr, + opts.VerboseLogging, + opts.FromAccount, + subTransmitters, + opts.RetirementReportCache, + &onTransmit{}, + }, nil +} + +func (t *transmitter) Start(ctx context.Context) error { + return t.StartOnce("llo.Transmitter", func() error { + for _, st := range t.subTransmitters { + if err := st.Start(ctx); err != nil { + return err + } + } + return nil + }) +} + +func (t *transmitter) Close() error { + return t.StopOnce("llo.Transmitter", func() error { + for _, st := range t.subTransmitters { + if err := st.Close(); err != nil { + return err + } + } + return nil + }) +} + +func (t *transmitter) HealthReport() map[string]error { + report := map[string]error{t.Name(): t.Healthy()} + for _, st := range t.subTransmitters { + services.CopyHealth(report, st.HealthReport()) + } + return report +} + +func (t *transmitter) Name() string { return t.lggr.Name() } + +func (t *transmitter) Transmit( + ctx context.Context, + digest types.ConfigDigest, + seqNr uint64, + report ocr3types.ReportWithInfo[llotypes.ReportInfo], + sigs []types.AttributedOnchainSignature, +) (err error) { + if t.verboseLogging { + t.lggr.Debugw("Transmit report", "digest", digest, "seqNr", seqNr, "report", report, "sigs", sigs) + } + + if report.Info.ReportFormat == llotypes.ReportFormatRetirement { + // Retirement reports don't get transmitted; rather, they are stored in + // the RetirementReportCache + t.lggr.Debugw("Storing retirement report", "digest", digest, "seqNr", seqNr) + if err := t.retirementReportCache.StoreAttestedRetirementReport(ctx, digest, seqNr, report.Report, sigs); err != nil { + return fmt.Errorf("failed to write retirement report to cache: %w", err) + } + return nil + } + t.notify(digest, seqNr) + + g := new(errgroup.Group) + for _, st := range t.subTransmitters { + g.Go(func() error { + return st.Transmit(ctx, digest, seqNr, report, sigs) + }) + } + return g.Wait() +} + +// FromAccount returns the stringified (hex) CSA public key +func (t *transmitter) FromAccount(ctx context.Context) (ocr2types.Account, error) { + return ocr2types.Account(t.fromAccount), nil +}