From c4b5d226eff046d4382859d70712a34cb7e8b613 Mon Sep 17 00:00:00 2001 From: Sujatha Sivaramakrishnan Date: Thu, 2 Apr 2026 14:05:31 +0530 Subject: [PATCH] *: plug in no-op metric implementations for nil fields This is a follow-up to PR #33 with the following fixes: - Initialize nil metric fields with no-op implementations at construction time (server, client, pool) so that call sites don't need nil guards. - Unexport MeteredTransport behind a NewMeteredTransport constructor that handles nil counters. --- drpcclient/dialoptions.go | 8 +- drpcconn/conn.go | 22 +++-- drpcmetrics/metrics.go | 73 +++++++++------ drpcpool/pool.go | 51 ++++++----- drpcpool/pool_test.go | 6 +- drpcserver/server.go | 35 +++----- internal/integration/metrics_test.go | 130 +++++++-------------------- 7 files changed, 141 insertions(+), 184 deletions(-) diff --git a/drpcclient/dialoptions.go b/drpcclient/dialoptions.go index 954aa84..a34d954 100644 --- a/drpcclient/dialoptions.go +++ b/drpcclient/dialoptions.go @@ -134,6 +134,11 @@ func DialContext(ctx context.Context, address string, opts ...DialOption) (conn } } + collectMetrics := true + if options.metrics == nil { + collectMetrics = false + options.metrics = &drpcmetrics.ClientMetrics{} + } return drpcconn.NewWithOptions(netConn, drpcconn.Options{ Manager: drpcmanager.Options{ Reader: drpcwire.ReaderOptions{ @@ -144,6 +149,7 @@ func DialContext(ctx context.Context, address string, opts ...DialOption) (conn }, SoftCancel: false, }, - Metrics: options.metrics, + CollectMetrics: collectMetrics, + Metrics: *options.metrics, }), nil } diff --git a/drpcconn/conn.go b/drpcconn/conn.go index 0a677ba..8727036 100644 --- a/drpcconn/conn.go +++ b/drpcconn/conn.go @@ -25,13 +25,16 @@ type Options struct { // Manager controls the options we pass to the manager of this conn. Manager drpcmanager.Options + // TODO: (server): deprecate this // CollectStats controls whether the client should collect stats on the // rpcs it creates. CollectStats bool - // Metrics holds optional metrics the client will populate. If nil, no - // metrics are recorded. - Metrics *drpcmetrics.ClientMetrics + // CollectMetrics controls whether the client should collect metrics. + CollectMetrics bool + + // Metrics holds optional metrics the client will populate. + Metrics drpcmetrics.ClientMetrics } // Conn is a drpc client connection. @@ -41,7 +44,7 @@ type Conn struct { mu sync.Mutex wbuf []byte - stats map[string]*drpcstats.Stats + stats map[string]*drpcstats.Stats // TODO (server): deprecate } var _ drpc.Conn = (*Conn)(nil) @@ -56,18 +59,19 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Conn { tr: tr, } - if opts.Metrics != nil { - mt := &drpcmetrics.MeteredTransport{Transport: tr, BytesSent: opts.Metrics.BytesSent, BytesRecv: opts.Metrics.BytesRecv} - tr = mt - c.tr = tr + if opts.CollectMetrics { + mt := drpcmetrics.ToMeteredTransport(tr, opts.Metrics.BytesSent, + opts.Metrics.BytesRecv) + c.tr = mt } + // TODO: (server): deprecate if opts.CollectStats { drpcopts.SetManagerStatsCB(&opts.Manager.Internal, c.getStats) c.stats = make(map[string]*drpcstats.Stats) } - c.man = drpcmanager.NewWithOptions(tr, opts.Manager) + c.man = drpcmanager.NewWithOptions(c.tr, opts.Manager) return c } diff --git a/drpcmetrics/metrics.go b/drpcmetrics/metrics.go index f92460b..e0f6c4a 100644 --- a/drpcmetrics/metrics.go +++ b/drpcmetrics/metrics.go @@ -11,27 +11,37 @@ import ( ) // Counter is a metric that can only be incremented (monotonically increasing). -// The labels parameter contains key-value pairs for metric dimensions -// (e.g. rpcService, rpcMethod). It may be nil when no -// dimensional context is available. -// The concrete type *must* provide a thread-safe implementation for these -// methods. +// The concrete type must provide a thread-safe implementation for the method. type Counter interface { - Inc(labels map[string]string, v int64) + Inc(v int64) } // NoOpCounter is a Counter implementation that does nothing. type NoOpCounter struct{} // Inc implements Counter. -func (NoOpCounter) Inc(labels map[string]string, v int64) {} +func (NoOpCounter) Inc(v int64) {} + +// LabeledCounter is a Counter that accepts dimensional labels on each +// increment. The labels parameter contains key-value pairs for metric +// dimensions. It may be nil when no dimensional context is available. +// The concrete type must provide a thread-safe implementation. +type LabeledCounter interface { + Inc(labels map[string]string, v int64) +} -// Gauge is a metric that can increase and decrease (e.g. pool size, -// active count). Update sets the gauge to the given absolute value. +// NoOpLabeledCounter is a LabeledCounter implementation that does nothing. +type NoOpLabeledCounter struct{} + +// Inc implements LabeledCounter. +func (NoOpLabeledCounter) Inc(labels map[string]string, v int64) {} + +// Gauge is a metric that can increase and decrease (e.g. pool size). +// Update sets the gauge to the given absolute value. // // Note: Gauge values may go up or down; Counter values must only increase. -// The concrete type *must* provide a thread-safe implementation for these -// methods. +// The concrete type must provide a thread-safe implementation for the +// method. type Gauge interface { Update(labels map[string]string, v int64) } @@ -42,30 +52,43 @@ type NoOpGauge struct{} // Update implements Gauge. func (NoOpGauge) Update(labels map[string]string, v int64) {} -// TODO (sujatha): Plug-in no-op implementation for nil metrics - -// MeteredTransport wraps a Transport and increments byte counters on each +// meteredTransport wraps a Transport and increments byte counters on each // Read and Write call. -type MeteredTransport struct { +type meteredTransport struct { drpc.Transport - BytesSent Counter - BytesRecv Counter + bytesSent Counter + bytesRecv Counter +} + +// ToMeteredTransport returns a transport that increments bytesSent and +// bytesRecv on each Write and Read call respectively. Nil counters are +// replaced with no-op implementations. +func ToMeteredTransport(tr drpc.Transport, bytesSent, + bytesRecv Counter) drpc.Transport { + if bytesSent == nil { + bytesSent = NoOpCounter{} + } + if bytesRecv == nil { + bytesRecv = NoOpCounter{} + } + return &meteredTransport{Transport: tr, bytesSent: bytesSent, + bytesRecv: bytesRecv} } -// Read reads from the underlying transport and increments BytesRecv. -func (t *MeteredTransport) Read(p []byte) (n int, err error) { +// Read reads from the underlying transport and increments bytesRecv. +func (t *meteredTransport) Read(p []byte) (n int, err error) { n, err = t.Transport.Read(p) - if n > 0 && t.BytesRecv != nil { - t.BytesRecv.Inc(nil, int64(n)) + if n > 0 { + t.bytesRecv.Inc(int64(n)) } return n, err } -// Write writes to the underlying transport and increments BytesSent. -func (t *MeteredTransport) Write(p []byte) (n int, err error) { +// Write writes to the underlying transport and increments bytesSent. +func (t *meteredTransport) Write(p []byte) (n int, err error) { n, err = t.Transport.Write(p) - if n > 0 && t.BytesSent != nil { - t.BytesSent.Inc(nil, int64(n)) + if n > 0 { + t.bytesSent.Inc(int64(n)) } return n, err } diff --git a/drpcpool/pool.go b/drpcpool/pool.go index 5224ae6..0e8e999 100644 --- a/drpcpool/pool.go +++ b/drpcpool/pool.go @@ -17,8 +17,8 @@ import ( // PoolMetrics holds optional metrics for connection pool monitoring. type PoolMetrics struct { PoolSize drpcmetrics.Gauge - ConnectionHitsTotal drpcmetrics.Counter - ConnectionMissesTotal drpcmetrics.Counter + ConnectionHitsTotal drpcmetrics.LabeledCounter + ConnectionMissesTotal drpcmetrics.LabeledCounter } // Options contains the options to configure a pool. @@ -36,9 +36,8 @@ type Options struct { // no values for any single key. KeyCapacity int - // Metrics holds optional metrics the pool will populate. If nil, - // no metrics are recorded. - Metrics *PoolMetrics + // Metrics holds optional metrics the pool will populate. + Metrics PoolMetrics // Labels holds optional labels to be attached to all metrics. Labels map[string]string @@ -61,36 +60,39 @@ func New[K comparable, V Conn](opts Options) *Pool[K, V] { opts: opts, entries: make(map[K]*list[K, V]), } + + pool.initPoolMetrics() + // emit the metric (0 value) so it shows up as soon as the pool is created pool.updatePoolSize() return &pool } -func (p *Pool[K, V]) recordHit() { - if p.opts.Metrics == nil { - return +// initPoolMetrics copies the caller-supplied metrics into the pool, +// substituting no-op implementations for any nil fields. +func (p *Pool[K, V]) initPoolMetrics() { + metrics := &p.opts.Metrics + if metrics.PoolSize == nil { + metrics.PoolSize = drpcmetrics.NoOpGauge{} } - if p.opts.Metrics.ConnectionHitsTotal != nil { - p.opts.Metrics.ConnectionHitsTotal.Inc(p.opts.Labels, 1) + if metrics.ConnectionHitsTotal == nil { + metrics.ConnectionHitsTotal = drpcmetrics.NoOpLabeledCounter{} } + if metrics.ConnectionMissesTotal == nil { + metrics.ConnectionMissesTotal = drpcmetrics.NoOpLabeledCounter{} + } +} + +func (p *Pool[K, V]) recordHit() { + p.opts.Metrics.ConnectionHitsTotal.Inc(p.opts.Labels, 1) } func (p *Pool[K, V]) recordMiss() { - if p.opts.Metrics == nil { - return - } - if p.opts.Metrics.ConnectionMissesTotal != nil { - p.opts.Metrics.ConnectionMissesTotal.Inc(p.opts.Labels, 1) - } + p.opts.Metrics.ConnectionMissesTotal.Inc(p.opts.Labels, 1) } func (p *Pool[K, V]) updatePoolSize() { - if p.opts.Metrics == nil { - return - } - if p.opts.Metrics.PoolSize != nil { - p.opts.Metrics.PoolSize.Update(p.opts.Labels, int64(p.order.count)) - } + p.opts.Metrics.PoolSize.Update(p.opts.Labels, int64(p.order.count)) } func (p *Pool[K, V]) log(what string, cb func() string) { @@ -120,8 +122,9 @@ func (p *Pool[K, V]) Close() (err error) { // Get returns a new Conn that will use the provided dial function to create an // underlying conn to be cached by the Pool when Conn methods are invoked. It will // share any cached connections with other conns that use the same key. -func (p *Pool[K, V]) Get(ctx context.Context, key K, - dial func(ctx context.Context, key K) (V, error)) Conn { +func (p *Pool[K, V]) Get( + ctx context.Context, key K, dial func(ctx context.Context, key K) (V, error), +) Conn { return &poolConn[K, V]{ key: key, pool: p, diff --git a/drpcpool/pool_test.go b/drpcpool/pool_test.go index 0aab51d..a7b42ce 100644 --- a/drpcpool/pool_test.go +++ b/drpcpool/pool_test.go @@ -448,7 +448,7 @@ func TestPoolMetrics_PutTakeClose(t *testing.T) { pool := New[string, Conn](Options{ Capacity: 10, - Metrics: &PoolMetrics{ + Metrics: PoolMetrics{ PoolSize: poolSize, ConnectionHitsTotal: hits, ConnectionMissesTotal: misses, @@ -501,7 +501,7 @@ func TestPoolMetrics_Eviction(t *testing.T) { pool := New[string, Conn](Options{ Capacity: 1, KeyCapacity: 1, - Metrics: &PoolMetrics{ + Metrics: PoolMetrics{ PoolSize: poolSize, ConnectionMissesTotal: misses, }, @@ -529,7 +529,7 @@ func TestPoolMetrics_Eviction(t *testing.T) { func TestPoolMetrics_NilFields(t *testing.T) { // All PoolMetrics fields are nil — should not panic. pool := New[string, Conn](Options{ - Metrics: &PoolMetrics{}, + Metrics: PoolMetrics{}, }) conn := &callbackConn{ diff --git a/drpcserver/server.go b/drpcserver/server.go index ad6f187..75e034f 100644 --- a/drpcserver/server.go +++ b/drpcserver/server.go @@ -46,9 +46,8 @@ type Options struct { // restrictions. If it returns a non-nil error the connection is rejected. TLSCipherRestrict func(conn net.Conn) error - // Metrics holds optional metrics the server will populate. If nil, no - // metrics are recorded. - Metrics *ServerMetrics + // Metrics holds optional metrics the server will populate. + Metrics ServerMetrics } // ServerMetrics holds optional metrics that the server will populate during @@ -56,24 +55,12 @@ type Options struct { // Metrics are defined and registered by the caller (e.g. in CockroachDB) and // passed in; this package never imports a metrics library. type ServerMetrics struct { - BytesSent drpcmetrics.Counter - BytesRecv drpcmetrics.Counter TLSHandshakeErrors drpcmetrics.Counter } -// addTLSHandshakeError increments the TLS handshake error counter. -func (m *ServerMetrics) addTLSHandshakeError() { - if m != nil && m.TLSHandshakeErrors != nil { - m.TLSHandshakeErrors.Inc(nil, 1) - } -} - -// toMeteredTransport wraps tr with byte counters. -func (m *ServerMetrics) toMeteredTransport(tr drpc.Transport) drpc.Transport { - if m == nil { - return tr - } - return &drpcmetrics.MeteredTransport{Transport: tr, BytesSent: m.BytesSent, BytesRecv: m.BytesRecv} +// recordTLSHandshakeError increments the TLS handshake error counter. +func (s *Server) recordTLSHandshakeError() { + s.opts.Metrics.TLSHandshakeErrors.Inc(1) } // Server is an implementation of drpc.Server to serve drpc connections. @@ -103,12 +90,14 @@ func NewWithOptions(handler drpc.Handler, opts Options) *Server { opts: opts, handler: handler, } - if s.opts.CollectStats { + // TODO: (server): deprecate stats drpcopts.SetManagerStatsCB(&s.opts.Manager.Internal, s.getStats) s.stats = make(map[string]*drpcstats.Stats) } - + if s.opts.Metrics.TLSHandshakeErrors == nil { + s.opts.Metrics.TLSHandshakeErrors = drpcmetrics.NoOpCounter{} + } return s } @@ -156,12 +145,12 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { // anyway. err := tlsConn.HandshakeContext(ctx) if err != nil { - s.opts.Metrics.addTLSHandshakeError() + s.recordTLSHandshakeError() return drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) } if s.opts.TLSCipherRestrict != nil { if err := s.opts.TLSCipherRestrict(tlsConn); err != nil { - s.opts.Metrics.addTLSHandshakeError() + s.recordTLSHandshakeError() return drpc.ConnectionError.New("server handshake [%q] failed: %w", tlsConn.RemoteAddr(), err) } } @@ -172,8 +161,6 @@ func (s *Server) ServeOne(ctx context.Context, tr drpc.Transport) (err error) { } } - tr = s.opts.Metrics.toMeteredTransport(tr) - man := drpcmanager.NewWithOptions(tr, s.opts.Manager) defer func() { err = errs.Combine(err, man.Close()) }() diff --git a/internal/integration/metrics_test.go b/internal/integration/metrics_test.go index 3c7f73e..a4a0288 100644 --- a/internal/integration/metrics_test.go +++ b/internal/integration/metrics_test.go @@ -27,63 +27,36 @@ import ( // type testCounter struct { - mu sync.Mutex - calls []metricCall + mu sync.Mutex + total_ float64 + count_ int } -type metricCall struct { - labels map[string]string - value float64 -} - -func (c *testCounter) Inc(labels map[string]string, v int64) { +func (c *testCounter) Inc(v int64) { c.mu.Lock() defer c.mu.Unlock() - c.calls = append(c.calls, metricCall{labels: labels, value: float64(v)}) + c.total_ += float64(v) + c.count_++ } func (c *testCounter) total() float64 { c.mu.Lock() defer c.mu.Unlock() - var t float64 - for _, call := range c.calls { - t += call.value - } - return t + return c.total_ } func (c *testCounter) count() int { c.mu.Lock() defer c.mu.Unlock() - return len(c.calls) + return c.count_ } // // connection helpers // -func createMeteredServerConnection( - t testing.TB, server DRPCServiceServer, metrics *drpcserver.ServerMetrics, -) (DRPCServiceClient, func()) { - ctx := drpctest.NewTracker(t) - c1, c2 := net.Pipe() - mux := drpcmux.New() - assert.NoError(t, DRPCRegisterService(mux, server)) - srv := drpcserver.NewWithOptions(mux, drpcserver.Options{ - Metrics: metrics, - }) - ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) - conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{}, - }) - return NewDRPCServiceClient(conn), func() { - _ = conn.Close() - ctx.Close() - } -} - func createMeteredClientConnection( - t testing.TB, server DRPCServiceServer, metrics *drpcmetrics.ClientMetrics, + t testing.TB, server DRPCServiceServer, metrics drpcmetrics.ClientMetrics, ) (DRPCServiceClient, func()) { ctx := drpctest.NewTracker(t) c1, c2 := net.Pipe() @@ -92,8 +65,9 @@ func createMeteredClientConnection( srv := drpcserver.New(mux) ctx.Run(func(ctx context.Context) { _ = srv.ServeOne(ctx, c1) }) conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ - Manager: drpcmanager.Options{}, - Metrics: metrics, + Manager: drpcmanager.Options{}, + Metrics: metrics, + CollectMetrics: true, }) return NewDRPCServiceClient(conn), func() { _ = conn.Close() @@ -101,20 +75,6 @@ func createMeteredClientConnection( } } -// waitForCount waits until counter.count() reaches at least n. -// Server-side byte metrics are recorded inside transport Read/Write which -// may not have completed by the time the client observes the response. -func waitForCount(t testing.TB, c interface{ count() int }, n int) { - t.Helper() - deadline := time.Now().Add(5 * time.Second) - for c.count() < n { - if time.Now().After(deadline) { - t.Fatalf("timed out waiting for count >= %d (got %d)", n, c.count()) - } - time.Sleep(time.Millisecond) - } -} - // // tests // @@ -125,7 +85,7 @@ func TestClientByteMetrics(t *testing.T) { sent := &testCounter{} recv := &testCounter{} - cli, close := createMeteredClientConnection(t, standardImpl, &drpcmetrics.ClientMetrics{ + cli, close := createMeteredClientConnection(t, standardImpl, drpcmetrics.ClientMetrics{ BytesSent: sent, BytesRecv: recv, }) @@ -162,7 +122,7 @@ func TestClientByteMetricsPartialNil(t *testing.T) { defer ctx.Close() sent := &testCounter{} - cli, close := createMeteredClientConnection(t, standardImpl, &drpcmetrics.ClientMetrics{ + cli, close := createMeteredClientConnection(t, standardImpl, drpcmetrics.ClientMetrics{ BytesSent: sent, // BytesRecv intentionally nil. }) @@ -174,61 +134,35 @@ func TestClientByteMetricsPartialNil(t *testing.T) { assert.That(t, sent.total() > 0) } -func TestServerByteMetrics(t *testing.T) { +func TestClientByteMetricsNotCollected(t *testing.T) { ctx := drpctest.NewTracker(t) defer ctx.Close() sent := &testCounter{} recv := &testCounter{} - cli, close := createMeteredServerConnection(t, standardImpl, &drpcserver.ServerMetrics{ - BytesSent: sent, - BytesRecv: recv, - }) - defer close() - - out, err := cli.Method1(ctx, in(1)) - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 1})) - - // The server's byte counters are incremented inside transport - // Read/Write which may not have returned by the time the client - // observes the response, so poll briefly. - waitForCount(t, sent, 1) - waitForCount(t, recv, 1) - assert.That(t, sent.total() > 0) - assert.That(t, recv.total() > 0) -} - -func TestServerMetricsAllNilFields(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - // Non-nil ServerMetrics with all nil fields — should not panic. - cli, close := createMeteredServerConnection(t, standardImpl, &drpcserver.ServerMetrics{}) - defer close() - - out, err := cli.Method1(ctx, in(1)) - assert.NoError(t, err) - assert.True(t, Equal(out, &Out{Out: 1})) -} -func TestServerByteMetricsPartialNil(t *testing.T) { - ctx := drpctest.NewTracker(t) - defer ctx.Close() - - sent := &testCounter{} - // BytesRecv intentionally nil. - cli, close := createMeteredServerConnection(t, standardImpl, &drpcserver.ServerMetrics{ - BytesSent: sent, + c1, c2 := net.Pipe() + mux := drpcmux.New() + assert.NoError(t, DRPCRegisterService(mux, standardImpl)) + srv := drpcserver.New(mux) + ctx.Run(func(ctx2 context.Context) { _ = srv.ServeOne(ctx2, c1) }) + conn := drpcconn.NewWithOptions(c2, drpcconn.Options{ + Metrics: drpcmetrics.ClientMetrics{ + BytesSent: sent, + BytesRecv: recv, + }, }) - defer close() + cli := NewDRPCServiceClient(conn) out, err := cli.Method1(ctx, in(1)) assert.NoError(t, err) assert.True(t, Equal(out, &Out{Out: 1})) - waitForCount(t, sent, 1) - assert.That(t, sent.total() > 0) + // CollectMetrics is false, so no metrics should be collected. + assert.Equal(t, sent.total(), 0.0) + assert.Equal(t, recv.total(), 0.0) + + _ = conn.Close() } func TestServerTLSHandshakeErrorMetric(t *testing.T) { @@ -237,7 +171,7 @@ func TestServerTLSHandshakeErrorMetric(t *testing.T) { mux := drpcmux.New() assert.NoError(t, DRPCRegisterService(mux, standardImpl)) srv := drpcserver.NewWithOptions(mux, drpcserver.Options{ - Metrics: &drpcserver.ServerMetrics{ + Metrics: drpcserver.ServerMetrics{ TLSHandshakeErrors: tlsErrors, }, })