From 3576c510886cb655e505b5ff7679bc5e405b896e Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 14:18:32 +0900 Subject: [PATCH 01/43] Add redis-proxy: dual-write Redis proxy with pub/sub forwarding Implement a Redis-protocol proxy that supports dual-write to primary (Redis) and secondary (ElasticKV) backends. Includes full pub/sub message forwarding via redcon Detach + go-redis PubSub. Verified all 581 Misskey unit tests pass through the proxy. --- cmd/redis-proxy/main.go | 107 +++++++ proxy/backend.go | 110 +++++++ proxy/command.go | 146 +++++++++ proxy/compare.go | 214 +++++++++++++ proxy/config.go | 76 +++++ proxy/dualwrite.go | 267 ++++++++++++++++ proxy/integration_test.go | 257 ++++++++++++++++ proxy/metrics.go | 87 ++++++ proxy/proxy.go | 435 ++++++++++++++++++++++++++ proxy/proxy_test.go | 622 ++++++++++++++++++++++++++++++++++++++ proxy/pubsub.go | 238 +++++++++++++++ proxy/sentry.go | 124 ++++++++ 12 files changed, 2683 insertions(+) create mode 100644 cmd/redis-proxy/main.go create mode 100644 proxy/backend.go create mode 100644 proxy/command.go create mode 100644 proxy/compare.go create mode 100644 proxy/config.go create mode 100644 proxy/dualwrite.go create mode 100644 proxy/integration_test.go create mode 100644 proxy/metrics.go create mode 100644 proxy/proxy.go create mode 100644 proxy/proxy_test.go create mode 100644 proxy/pubsub.go create mode 100644 proxy/sentry.go diff --git a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go new file mode 100644 index 00000000..25696065 --- /dev/null +++ b/cmd/redis-proxy/main.go @@ -0,0 +1,107 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log/slog" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/bootjp/elastickv/proxy" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const sentryFlushTimeout = 2 * time.Second + +func main() { + if err := run(); err != nil { + 
fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} + +func run() error { + cfg := proxy.DefaultConfig() + var modeStr string + + flag.StringVar(&cfg.ListenAddr, "listen", cfg.ListenAddr, "Proxy listen address") + flag.StringVar(&cfg.PrimaryAddr, "primary", cfg.PrimaryAddr, "Primary (Redis) address") + flag.StringVar(&cfg.SecondaryAddr, "secondary", cfg.SecondaryAddr, "Secondary (ElasticKV) address") + flag.StringVar(&modeStr, "mode", "dual-write", "Proxy mode: redis-only, dual-write, dual-write-shadow, elastickv-primary, elastickv-only") + flag.DurationVar(&cfg.SecondaryTimeout, "secondary-timeout", cfg.SecondaryTimeout, "Secondary write timeout") + flag.DurationVar(&cfg.ShadowTimeout, "shadow-timeout", cfg.ShadowTimeout, "Shadow read timeout") + flag.StringVar(&cfg.SentryDSN, "sentry-dsn", cfg.SentryDSN, "Sentry DSN (empty = disabled)") + flag.StringVar(&cfg.SentryEnv, "sentry-env", cfg.SentryEnv, "Sentry environment") + flag.Float64Var(&cfg.SentrySampleRate, "sentry-sample", cfg.SentrySampleRate, "Sentry sample rate") + flag.StringVar(&cfg.MetricsAddr, "metrics", cfg.MetricsAddr, "Prometheus metrics address") + flag.Parse() + + mode, ok := proxy.ParseProxyMode(modeStr) + if !ok { + return fmt.Errorf("unknown mode: %s", modeStr) + } + cfg.Mode = mode + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + + // Sentry + sentryReporter := proxy.NewSentryReporter(cfg.SentryDSN, cfg.SentryEnv, cfg.SentrySampleRate, logger) + defer sentryReporter.Flush(sentryFlushTimeout) + + // Prometheus + reg := prometheus.NewRegistry() + metrics := proxy.NewProxyMetrics(reg) + + // Backends + var primary, secondary proxy.Backend + switch cfg.Mode { + case proxy.ModeElasticKVPrimary, proxy.ModeElasticKVOnly: + primary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + secondary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + case proxy.ModeRedisOnly, proxy.ModeDualWrite, proxy.ModeDualWriteShadow: + primary = 
proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + secondary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + } + defer primary.Close() + defer secondary.Close() + + dual := proxy.NewDualWriter(primary, secondary, cfg, metrics, sentryReporter, logger) + srv := proxy.NewProxyServer(cfg, dual, metrics, sentryReporter, logger) + + // Context for graceful shutdown + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + // Start metrics server + go func() { + mux := http.NewServeMux() + mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + var lc net.ListenConfig + ln, err := lc.Listen(ctx, "tcp", cfg.MetricsAddr) + if err != nil { + logger.Error("metrics listen failed", "addr", cfg.MetricsAddr, "err", err) + return + } + metricsSrv := &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} + go func() { + <-ctx.Done() + metricsSrv.Close() + }() + logger.Info("metrics server starting", "addr", cfg.MetricsAddr) + if err := metricsSrv.Serve(ln); err != nil && err != http.ErrServerClosed { + logger.Error("metrics server error", "err", err) + } + }() + + // Start proxy + if err := srv.ListenAndServe(ctx); err != nil { + return fmt.Errorf("proxy server: %w", err) + } + return nil +} diff --git a/proxy/backend.go b/proxy/backend.go new file mode 100644 index 00000000..75ebd695 --- /dev/null +++ b/proxy/backend.go @@ -0,0 +1,110 @@ +package proxy + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + defaultPoolSize = 128 + defaultDialTimeout = 5 * time.Second + defaultReadTimeout = 3 * time.Second + defaultWriteTimeout = 3 * time.Second +) + +// Backend abstracts a Redis-protocol endpoint (real Redis or ElasticKV). +type Backend interface { + // Do sends a single command and returns its result. + Do(ctx context.Context, args ...interface{}) *redis.Cmd + // Pipeline sends multiple commands in a pipeline. 
+ Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) + // Close releases the underlying connection. + Close() error + // Name identifies this backend for logging and metrics. + Name() string +} + +// BackendOptions configures the underlying go-redis connection pool. +type BackendOptions struct { + PoolSize int + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +// DefaultBackendOptions returns reasonable defaults for a proxy backend. +func DefaultBackendOptions() BackendOptions { + return BackendOptions{ + PoolSize: defaultPoolSize, + DialTimeout: defaultDialTimeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + } +} + +// PubSubBackend is an optional interface for backends that support +// creating dedicated PubSub connections. +type PubSubBackend interface { + NewPubSub(ctx context.Context) *redis.PubSub +} + +// RedisBackend connects to an upstream Redis instance via go-redis. +type RedisBackend struct { + client *redis.Client + name string +} + +// NewRedisBackend creates a Backend targeting a Redis server with default pool options. +func NewRedisBackend(addr string, name string) *RedisBackend { + return NewRedisBackendWithOptions(addr, name, DefaultBackendOptions()) +} + +// NewRedisBackendWithOptions creates a Backend with explicit pool configuration. +func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) *RedisBackend { + return &RedisBackend{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + PoolSize: opts.PoolSize, + DialTimeout: opts.DialTimeout, + ReadTimeout: opts.ReadTimeout, + WriteTimeout: opts.WriteTimeout, + }), + name: name, + } +} + +func (b *RedisBackend) Do(ctx context.Context, args ...interface{}) *redis.Cmd { + return b.client.Do(ctx, args...) 
+} + +func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) { + pipe := b.client.Pipeline() + results := make([]*redis.Cmd, len(cmds)) + for i, args := range cmds { + results[i] = pipe.Do(ctx, args...) + } + _, err := pipe.Exec(ctx) + if err != nil { + return results, fmt.Errorf("pipeline exec: %w", err) + } + return results, nil +} + +func (b *RedisBackend) Close() error { + if err := b.client.Close(); err != nil { + return fmt.Errorf("close %s backend: %w", b.name, err) + } + return nil +} + +func (b *RedisBackend) Name() string { + return b.name +} + +// NewPubSub creates a dedicated PubSub connection (not from the pool). +func (b *RedisBackend) NewPubSub(ctx context.Context) *redis.PubSub { + return b.client.Subscribe(ctx) +} diff --git a/proxy/command.go b/proxy/command.go new file mode 100644 index 00000000..b04ead9b --- /dev/null +++ b/proxy/command.go @@ -0,0 +1,146 @@ +package proxy + +import "strings" + +// CommandCategory classifies a Redis command for routing purposes. +type CommandCategory int + +const ( + CmdRead CommandCategory = iota // GET, HGET, LRANGE, ZRANGE, etc. + CmdWrite // SET, DEL, HSET, LPUSH, ZADD, etc. 
+ CmdBlocking // BZPOPMIN, XREAD (with BLOCK) + CmdPubSub // SUBSCRIBE, PUBLISH, PUBSUB + CmdAdmin // PING, INFO, CLIENT, SELECT, QUIT, DBSIZE, SCAN, AUTH + CmdTxn // MULTI, EXEC, DISCARD + CmdScript // EVAL, EVALSHA +) + +var commandTable = map[string]CommandCategory{ + // Read commands + "GET": CmdRead, + "GETDEL": CmdWrite, // read+write → write + "HGET": CmdRead, + "HGETALL": CmdRead, + "HEXISTS": CmdRead, + "HLEN": CmdRead, + "HMGET": CmdRead, + "EXISTS": CmdRead, + "KEYS": CmdRead, + "LINDEX": CmdRead, + "LLEN": CmdRead, + "LPOS": CmdRead, + "LRANGE": CmdRead, + "PTTL": CmdRead, + "TTL": CmdRead, + "TYPE": CmdRead, + "SCARD": CmdRead, + "SISMEMBER": CmdRead, + "SMEMBERS": CmdRead, + "XLEN": CmdRead, + "XRANGE": CmdRead, + "XREVRANGE": CmdRead, + "ZCARD": CmdRead, + "ZCOUNT": CmdRead, + "ZRANGE": CmdRead, + "ZRANGEBYSCORE": CmdRead, + "ZREVRANGE": CmdRead, + "ZREVRANGEBYSCORE": CmdRead, + "ZSCORE": CmdRead, + "PFCOUNT": CmdRead, + + // Write commands + "SET": CmdWrite, + "SETEX": CmdWrite, + "SETNX": CmdWrite, + "DEL": CmdWrite, + "HSET": CmdWrite, + "HMSET": CmdWrite, + "HDEL": CmdWrite, + "HINCRBY": CmdWrite, + "INCR": CmdWrite, + "LPUSH": CmdWrite, + "LPOP": CmdWrite, + "RPUSH": CmdWrite, + "RPOP": CmdWrite, + "RPOPLPUSH": CmdWrite, + "LREM": CmdWrite, + "LSET": CmdWrite, + "LTRIM": CmdWrite, + "SADD": CmdWrite, + "SREM": CmdWrite, + "EXPIRE": CmdWrite, + "PEXPIRE": CmdWrite, + "RENAME": CmdWrite, + "XADD": CmdWrite, + "XTRIM": CmdWrite, + "ZADD": CmdWrite, + "ZINCRBY": CmdWrite, + "ZREM": CmdWrite, + "ZREMRANGEBYSCORE": CmdWrite, + "ZREMRANGEBYRANK": CmdWrite, + "ZPOPMIN": CmdWrite, + "PFADD": CmdWrite, + "FLUSHALL": CmdWrite, + "FLUSHDB": CmdWrite, + "PUBLISH": CmdWrite, // write to both backends + + // Blocking commands + "BZPOPMIN": CmdBlocking, + // XREAD is handled specially in ClassifyCommand (BLOCK arg check) + + // PubSub commands + "SUBSCRIBE": CmdPubSub, + "UNSUBSCRIBE": CmdPubSub, + "PSUBSCRIBE": CmdPubSub, + "PUNSUBSCRIBE": CmdPubSub, + 
"PUBSUB": CmdPubSub, + + // Admin commands — forwarded to primary only + "PING": CmdAdmin, + "INFO": CmdAdmin, + "CLIENT": CmdAdmin, + "SELECT": CmdAdmin, + "QUIT": CmdAdmin, + "DBSIZE": CmdAdmin, + "SCAN": CmdAdmin, + "AUTH": CmdAdmin, + "HELLO": CmdAdmin, + "WAIT": CmdAdmin, + "CONFIG": CmdAdmin, + "OBJECT": CmdAdmin, + "DEBUG": CmdAdmin, + "CLUSTER": CmdAdmin, + "COMMAND": CmdAdmin, + + // Transaction commands + "MULTI": CmdTxn, + "EXEC": CmdTxn, + "DISCARD": CmdTxn, + + // Script commands + "EVAL": CmdScript, + "EVALSHA": CmdScript, +} + +// ClassifyCommand returns the category for a Redis command name. +// XREAD is classified as CmdBlocking if args contain BLOCK, otherwise CmdRead. +// Unknown commands default to CmdWrite (sent to both backends). +func ClassifyCommand(name string, args [][]byte) CommandCategory { + upper := strings.ToUpper(name) + + // Special case: XREAD with BLOCK + if upper == "XREAD" { + for _, arg := range args { + if strings.ToUpper(string(arg)) == "BLOCK" { + return CmdBlocking + } + } + return CmdRead + } + + if cat, ok := commandTable[upper]; ok { + return cat + } + // Unknown commands → treat as write (safe default, sent to both backends) + return CmdWrite +} diff --git a/proxy/compare.go b/proxy/compare.go new file mode 100644 index 00000000..81c4d641 --- /dev/null +++ b/proxy/compare.go @@ -0,0 +1,214 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "reflect" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + unknownStr = "unknown" + defaultGapLogSampleRate int64 = 1000 +) + +// DivergenceKind classifies the nature of a shadow-read mismatch. 
+type DivergenceKind int + +const ( + DivMigrationGap DivergenceKind = iota // Secondary nil/empty, Primary has data → expected during migration + DivDataMismatch // Both have data but differ → real inconsistency + DivExtraData // Primary nil, Secondary has data → unexpected +) + +func (k DivergenceKind) String() string { + switch k { + case DivMigrationGap: + return "migration_gap" + case DivDataMismatch: + return "data_mismatch" + case DivExtraData: + return "extra_data" + default: + return unknownStr + } +} + +// Divergence records a detected mismatch between primary and secondary. +type Divergence struct { + Command string + Key string + Kind DivergenceKind + Primary interface{} + Secondary interface{} + DetectedAt time.Time +} + +// ShadowReader compares primary and secondary read results. +type ShadowReader struct { + secondary Backend + metrics *ProxyMetrics + sentry *SentryReporter + logger *slog.Logger + timeout time.Duration + + // Sampling counter for migration gap logs (log 1 per gapLogSampleRate). + gapCount atomic.Int64 + gapLogSampleRate int64 +} + +// NewShadowReader creates a ShadowReader. +func NewShadowReader(secondary Backend, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger, timeout time.Duration) *ShadowReader { + return &ShadowReader{ + secondary: secondary, + metrics: metrics, + sentry: sentryReporter, + logger: logger, + timeout: timeout, + gapLogSampleRate: defaultGapLogSampleRate, + } +} + +// Compare issues the same read to the secondary and checks for divergence. +func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, primaryResp interface{}, primaryErr error) { + sCtx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + iArgs := bytesArgsToInterfaces(args) + secondaryResult := s.secondary.Do(sCtx, iArgs...) 
+ secondaryResp, secondaryErr := secondaryResult.Result() + + if isConsistent(primaryResp, secondaryResp, primaryErr, secondaryErr) { + return + } + + // --- Divergence detected --- + kind := classifyDivergence(primaryResp, primaryErr, secondaryResp, secondaryErr) + + if kind == DivMigrationGap { + s.metrics.MigrationGaps.WithLabelValues(cmd).Inc() + count := s.gapCount.Add(1) + if count%s.gapLogSampleRate == 1 { + s.logger.Debug("migration gap (sampled)", "cmd", cmd, "key", extractKey(args)) + } + return + } + + div := Divergence{ + Command: cmd, + Key: extractKey(args), + Kind: kind, + Primary: formatResp(primaryResp, primaryErr), + Secondary: formatResp(secondaryResp, secondaryErr), + DetectedAt: time.Now(), + } + + s.metrics.Divergences.WithLabelValues(cmd, kind.String()).Inc() + s.logger.Warn("response divergence detected", + "cmd", div.Command, "key", div.Key, "kind", div.Kind.String(), + "primary", fmt.Sprintf("%v", div.Primary), + "secondary", fmt.Sprintf("%v", div.Secondary), + ) + s.sentry.CaptureDivergence(div) +} + +// isConsistent checks whether primary and secondary responses agree. +func isConsistent(primaryResp, secondaryResp interface{}, primaryErr, secondaryErr error) bool { + // Both are redis.Nil → consistent (key missing on both) + if isNilError(primaryErr) && isNilError(secondaryErr) { + return true + } + // Both returned the same non-nil error → consistent + if primaryErr != nil && secondaryErr != nil && primaryErr.Error() == secondaryErr.Error() { + return true + } + // Both succeeded with equal values → consistent + return primaryErr == nil && secondaryErr == nil && responseEqual(primaryResp, secondaryResp) +} + +// responseEqual compares two go-redis response values for equality. 
+func responseEqual(a, b interface{}) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + switch a := a.(type) { + case string, int64: + return a == b + case []interface{}: + return interfaceSliceEqual(a, b) + default: + return reflect.DeepEqual(a, b) + } +} + +// interfaceSliceEqual compares two []interface{} slices element-by-element. +func interfaceSliceEqual(av []interface{}, b interface{}) bool { + bv, ok := b.([]interface{}) + if !ok || len(av) != len(bv) { + return false + } + for i := range av { + if !responseEqual(av[i], bv[i]) { + return false + } + } + return true +} + +// classifyDivergence determines the kind based on primary/secondary values. +func classifyDivergence(primaryResp interface{}, primaryErr error, secondaryResp interface{}, secondaryErr error) DivergenceKind { + primaryNil := isNilResp(primaryResp, primaryErr) + secondaryNil := isNilResp(secondaryResp, secondaryErr) + + switch { + case !primaryNil && secondaryNil: + return DivMigrationGap + case primaryNil && !secondaryNil: + return DivExtraData + default: + return DivDataMismatch + } +} + +func isNilError(err error) bool { + return errors.Is(err, redis.Nil) +} + +// isNilResp checks if a response represents "no data" (nil response or redis.Nil error). +// Empty string is NOT nil — it is a valid value. 
+func isNilResp(resp interface{}, err error) bool { + if errors.Is(err, redis.Nil) { + return true + } + return resp == nil +} + +func formatResp(resp interface{}, err error) interface{} { + if err != nil { + return fmt.Sprintf("error: %v", err) + } + return resp +} + +func extractKey(args [][]byte) string { + if len(args) > 1 { + return string(args[1]) + } + return "" +} + +func bytesArgsToInterfaces(args [][]byte) []interface{} { + out := make([]interface{}, len(args)) + for i, a := range args { + out[i] = string(a) + } + return out +} diff --git a/proxy/config.go b/proxy/config.go new file mode 100644 index 00000000..267f67eb --- /dev/null +++ b/proxy/config.go @@ -0,0 +1,76 @@ +package proxy + +import "time" + +const ( + defaultSecondaryTimeout = 5 * time.Second + defaultShadowTimeout = 3 * time.Second +) + +// ProxyMode controls which backends receive reads and writes. +type ProxyMode int + +const ( + ModeRedisOnly ProxyMode = iota // Redis only (passthrough) + ModeDualWrite // Write to both, read from Redis + ModeDualWriteShadow // Write to both, read from Redis + shadow read from ElasticKV + ModeElasticKVPrimary // Write to both, read from ElasticKV + shadow read from Redis + ModeElasticKVOnly // ElasticKV only +) + +var modeNames = map[string]ProxyMode{ + "redis-only": ModeRedisOnly, + "dual-write": ModeDualWrite, + "dual-write-shadow": ModeDualWriteShadow, + "elastickv-primary": ModeElasticKVPrimary, + "elastickv-only": ModeElasticKVOnly, +} + +var modeStrings = map[ProxyMode]string{ + ModeRedisOnly: "redis-only", + ModeDualWrite: "dual-write", + ModeDualWriteShadow: "dual-write-shadow", + ModeElasticKVPrimary: "elastickv-primary", + ModeElasticKVOnly: "elastickv-only", +} + +// ParseProxyMode converts a string to ProxyMode. Returns false if unknown. 
+func ParseProxyMode(s string) (ProxyMode, bool) { + m, ok := modeNames[s] + return m, ok +} + +func (m ProxyMode) String() string { + if s, ok := modeStrings[m]; ok { + return s + } + return unknownStr +} + +// ProxyConfig holds all configuration for the dual-write proxy. +type ProxyConfig struct { + ListenAddr string + PrimaryAddr string + SecondaryAddr string + Mode ProxyMode + SecondaryTimeout time.Duration + ShadowTimeout time.Duration + SentryDSN string + SentryEnv string + SentrySampleRate float64 + MetricsAddr string +} + +// DefaultConfig returns a ProxyConfig with sensible defaults. +func DefaultConfig() ProxyConfig { + return ProxyConfig{ + ListenAddr: ":6479", + PrimaryAddr: "localhost:6379", + SecondaryAddr: "localhost:6380", + Mode: ModeDualWrite, + SecondaryTimeout: defaultSecondaryTimeout, + ShadowTimeout: defaultShadowTimeout, + SentrySampleRate: 1.0, + MetricsAddr: ":9191", + } +} diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go new file mode 100644 index 00000000..df42a32c --- /dev/null +++ b/proxy/dualwrite.go @@ -0,0 +1,267 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "time" + + "github.com/redis/go-redis/v9" +) + +// maxAsyncGoroutines limits concurrent fire-and-forget goroutines to prevent +// goroutine explosion when the secondary backend is slow or down. +const maxAsyncGoroutines = 4096 + +// DualWriter routes commands to primary and secondary backends based on mode. +type DualWriter struct { + primary Backend + secondary Backend + cfg ProxyConfig + shadow *ShadowReader + metrics *ProxyMetrics + sentry *SentryReporter + logger *slog.Logger + + // asyncSem bounds the number of concurrent async goroutines + // (secondary writes + shadow reads). + asyncSem chan struct{} +} + +// NewDualWriter creates a DualWriter with the given backends. 
+func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger) *DualWriter { + d := &DualWriter{ + primary: primary, + secondary: secondary, + cfg: cfg, + metrics: metrics, + sentry: sentryReporter, + logger: logger, + asyncSem: make(chan struct{}, maxAsyncGoroutines), + } + + if cfg.Mode == ModeDualWriteShadow || cfg.Mode == ModeElasticKVPrimary { + // Shadow reads go to the non-primary backend for comparison. + // Note: main.go already swaps primary/secondary for ElasticKVPrimary mode, + // so here "secondary" is always the non-primary backend. + shadowBackend := secondary + d.shadow = NewShadowReader(shadowBackend, metrics, sentryReporter, logger, cfg.ShadowTimeout) + } + + return d +} + +// Write sends a write command to the primary synchronously, then to the secondary asynchronously. +func (d *DualWriter) Write(ctx context.Context, args [][]byte) (interface{}, error) { + cmd := string(args[0]) + iArgs := bytesArgsToInterfaces(args) + + start := time.Now() + result := d.primary.Do(ctx, iArgs...) + resp, err := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + + if err != nil && !errors.Is(err, redis.Nil) { + d.metrics.PrimaryWriteErrors.Inc() + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() + d.logger.Error("primary write failed", "cmd", cmd, "err", err) + return nil, fmt.Errorf("primary write %s: %w", cmd, err) + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + + // Secondary: async fire-and-forget (bounded) + if d.hasSecondaryWrite() { + d.goAsync(func() { d.writeSecondary(cmd, iArgs) }) + } + + if err != nil { + return resp, fmt.Errorf("primary write %s: %w", cmd, err) + } + return resp, nil +} + +// Read sends a read command to the primary and optionally performs a shadow read. 
+func (d *DualWriter) Read(ctx context.Context, args [][]byte) (interface{}, error) { + cmd := string(args[0]) + iArgs := bytesArgsToInterfaces(args) + + start := time.Now() + result := d.primary.Do(ctx, iArgs...) + resp, err := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + + if err != nil && !errors.Is(err, redis.Nil) { + d.metrics.PrimaryReadErrors.Inc() + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() + return nil, fmt.Errorf("primary read %s: %w", cmd, err) + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + + // Shadow read (bounded) + if d.shadow != nil { + shadowArgs := args + shadowResp := resp + shadowErr := err + d.goAsync(func() { + d.shadow.Compare(context.Background(), cmd, shadowArgs, shadowResp, shadowErr) + }) + } + + if err != nil { + return resp, fmt.Errorf("primary read %s: %w", cmd, err) + } + return resp, nil +} + +// Blocking forwards a blocking command to the primary only. +// Optionally sends a short-timeout version to secondary for warmup. +func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (interface{}, error) { + cmd := string(args[0]) + iArgs := bytesArgsToInterfaces(args) + + start := time.Now() + result := d.primary.Do(ctx, iArgs...) + resp, err := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + + if err != nil && !errors.Is(err, redis.Nil) { + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() + return nil, fmt.Errorf("primary blocking %s: %w", cmd, err) + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + + // Warmup: send to secondary with short timeout (fire-and-forget, bounded) + if d.hasSecondaryWrite() { + d.goAsync(func() { + sCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + d.secondary.Do(sCtx, iArgs...) 
+ }) + } + + if err != nil { + return resp, fmt.Errorf("primary blocking %s: %w", cmd, err) + } + return resp, nil +} + +// Admin forwards an admin command to the primary only. +func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (interface{}, error) { + cmd := string(args[0]) + iArgs := bytesArgsToInterfaces(args) + + start := time.Now() + result := d.primary.Do(ctx, iArgs...) + resp, err := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + + if err != nil && !errors.Is(err, redis.Nil) { + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() + return nil, fmt.Errorf("primary admin %s: %w", cmd, err) + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + if err != nil { + return resp, fmt.Errorf("primary admin %s: %w", cmd, err) + } + return resp, nil +} + +// Script forwards EVAL/EVALSHA to the primary, and async replays to secondary. +func (d *DualWriter) Script(ctx context.Context, args [][]byte) (interface{}, error) { + cmd := string(args[0]) + iArgs := bytesArgsToInterfaces(args) + + start := time.Now() + result := d.primary.Do(ctx, iArgs...) 
+ resp, err := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds()) + + if err != nil && !errors.Is(err, redis.Nil) { + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc() + return nil, fmt.Errorf("primary script %s: %w", cmd, err) + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() + + if d.hasSecondaryWrite() { + d.goAsync(func() { d.writeSecondary(cmd, iArgs) }) + } + + if err != nil { + return resp, fmt.Errorf("primary script %s: %w", cmd, err) + } + return resp, nil +} + +func (d *DualWriter) writeSecondary(cmd string, iArgs []interface{}) { + sCtx, cancel := context.WithTimeout(context.Background(), d.cfg.SecondaryTimeout) + defer cancel() + + start := time.Now() + result := d.secondary.Do(sCtx, iArgs...) + _, sErr := result.Result() + d.metrics.CommandDuration.WithLabelValues(cmd, d.secondary.Name()).Observe(time.Since(start).Seconds()) + + if sErr != nil && !errors.Is(sErr, redis.Nil) { + d.metrics.SecondaryWriteErrors.Inc() + d.metrics.CommandTotal.WithLabelValues(cmd, d.secondary.Name(), "error").Inc() + fingerprint := fmt.Sprintf("secondary_write_%s", cmd) + if d.sentry.ShouldReport(fingerprint) { + d.sentry.CaptureException(sErr, "secondary_write_failure", argsToBytes(iArgs)) + } + d.logger.Warn("secondary write failed", "cmd", cmd, "err", sErr) + return + } + d.metrics.CommandTotal.WithLabelValues(cmd, d.secondary.Name(), "ok").Inc() +} + +// goAsync launches fn in a bounded goroutine. If the semaphore is full, +// the work is dropped and a metric is incremented rather than blocking the caller. +func (d *DualWriter) goAsync(fn func()) { + select { + case d.asyncSem <- struct{}{}: + go func() { + defer func() { <-d.asyncSem }() + fn() + }() + default: + // Semaphore full — drop async work to protect the proxy. 
+ d.logger.Warn("async goroutine limit reached, dropping secondary operation") + } +} + +func (d *DualWriter) hasSecondaryWrite() bool { + switch d.cfg.Mode { + case ModeDualWrite, ModeDualWriteShadow, ModeElasticKVPrimary: + return true + case ModeRedisOnly, ModeElasticKVOnly: + return false + } + return false +} + +// Primary returns the primary backend for direct use (e.g., PubSub). +func (d *DualWriter) Primary() Backend { + return d.primary +} + +// PubSubBackend returns the primary backend as a PubSubBackend, or nil. +func (d *DualWriter) PubSubBackend() PubSubBackend { + if ps, ok := d.primary.(PubSubBackend); ok { + return ps + } + return nil +} + +// Secondary returns the secondary backend. +func (d *DualWriter) Secondary() Backend { + return d.secondary +} + +func argsToBytes(iArgs []interface{}) [][]byte { + out := make([][]byte, len(iArgs)) + for i, a := range iArgs { + out[i] = []byte(fmt.Sprintf("%v", a)) + } + return out +} diff --git a/proxy/integration_test.go b/proxy/integration_test.go new file mode 100644 index 00000000..4902a530 --- /dev/null +++ b/proxy/integration_test.go @@ -0,0 +1,257 @@ +package proxy_test + +import ( + "context" + "fmt" + "log/slog" + "net" + "os" + "testing" + "time" + + "github.com/bootjp/elastickv/proxy" + "github.com/prometheus/client_golang/prometheus" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testPrimaryAddr = "localhost:6379" + +// skipWithoutRedis skips the test if the required Redis instances are not available. 
+func skipWithoutRedis(t *testing.T, addrs ...string) { + t.Helper() + for _, addr := range addrs { + d := net.Dialer{Timeout: 500 * time.Millisecond} + conn, err := d.DialContext(context.Background(), "tcp", addr) + if err != nil { + t.Skipf("skipping integration test: cannot connect to %s: %v", addr, err) + } + conn.Close() + } +} + +func freePort(t *testing.T) string { + t.Helper() + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := ln.Addr().String() + ln.Close() + return addr +} + +func setupProxy(t *testing.T, mode proxy.ProxyMode, secondaryAddr string) (*redis.Client, context.CancelFunc) { + t.Helper() + listenAddr := freePort(t) + metricsAddr := freePort(t) + + cfg := proxy.ProxyConfig{ + ListenAddr: listenAddr, + PrimaryAddr: testPrimaryAddr, + SecondaryAddr: secondaryAddr, + Mode: mode, + SecondaryTimeout: 5 * time.Second, + ShadowTimeout: 3 * time.Second, + MetricsAddr: metricsAddr, + } + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + reg := prometheus.NewRegistry() + metrics := proxy.NewProxyMetrics(reg) + sentryReporter := proxy.NewSentryReporter("", "", 1.0, logger) + + var primary, secondary proxy.Backend + switch cfg.Mode { + case proxy.ModeElasticKVPrimary, proxy.ModeElasticKVOnly: + primary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + secondary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + case proxy.ModeRedisOnly, proxy.ModeDualWrite, proxy.ModeDualWriteShadow: + primary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + secondary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + } + + dual := proxy.NewDualWriter(primary, secondary, cfg, metrics, sentryReporter, logger) + srv := proxy.NewProxyServer(cfg, dual, metrics, sentryReporter, logger) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + _ = srv.ListenAndServe(ctx) + }() + + // Wait for proxy to be ready + var 
client *redis.Client + for i := range 50 { + _ = i + d := net.Dialer{Timeout: 100 * time.Millisecond} + conn, err := d.DialContext(ctx, "tcp", listenAddr) + if err == nil { + conn.Close() + client = redis.NewClient(&redis.Options{Addr: listenAddr}) + break + } + time.Sleep(50 * time.Millisecond) + } + require.NotNil(t, client, "proxy did not become ready") + + t.Cleanup(func() { + client.Close() + primary.Close() + secondary.Close() + }) + + return client, cancel +} + +func TestIntegration_RedisOnly(t *testing.T) { + skipWithoutRedis(t, testPrimaryAddr) + + client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399") // secondary unused + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-test-%d", time.Now().UnixNano()) + + // PING + pong, err := client.Ping(ctx).Result() + require.NoError(t, err) + assert.Equal(t, "PONG", pong) + + // SET / GET + err = client.Set(ctx, key, "hello", 0).Err() + require.NoError(t, err) + + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "hello", val) + + // DEL + err = client.Del(ctx, key).Err() + require.NoError(t, err) + + _, err = client.Get(ctx, key).Result() + assert.Equal(t, redis.Nil, err) +} + +func TestIntegration_DualWrite(t *testing.T) { + secondaryAddr := "localhost:6380" + skipWithoutRedis(t, testPrimaryAddr, secondaryAddr) + + client, cancel := setupProxy(t, proxy.ModeDualWrite, secondaryAddr) + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-dual-%d", time.Now().UnixNano()) + + // Write through proxy + err := client.Set(ctx, key, "dual-value", 0).Err() + require.NoError(t, err) + + // Read through proxy (from primary) + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "dual-value", val) + + // Wait a bit for async secondary write + time.Sleep(200 * time.Millisecond) + + // Verify data exists on secondary directly + secondaryClient := redis.NewClient(&redis.Options{Addr: 
secondaryAddr})
	defer secondaryClient.Close()

	// The secondary write is asynchronous, so poll for it instead of relying
	// on a fixed delay having been long enough.
	require.Eventually(t, func() bool {
		got, getErr := secondaryClient.Get(ctx, key).Result()
		return getErr == nil && got == "dual-value"
	}, 2*time.Second, 20*time.Millisecond, "secondary never received the dual-written value")

	// Clean up (best effort); the pause lets the asynchronous delete reach
	// the secondary before the proxy is torn down.
	client.Del(ctx, key)
	time.Sleep(100 * time.Millisecond)
}

// TestIntegration_HashCommands exercises HSET, HGET, HGETALL and HDEL through
// the proxy in redis-only mode.
func TestIntegration_HashCommands(t *testing.T) {
	skipWithoutRedis(t, testPrimaryAddr)

	client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399")
	defer cancel()

	ctx := context.Background()
	key := fmt.Sprintf("proxy-hash-%d", time.Now().UnixNano())

	// HSET
	err := client.HSet(ctx, key, "field1", "val1", "field2", "val2").Err()
	require.NoError(t, err)

	// HGET
	val, err := client.HGet(ctx, key, "field1").Result()
	require.NoError(t, err)
	assert.Equal(t, "val1", val)

	// HGETALL
	all, err := client.HGetAll(ctx, key).Result()
	require.NoError(t, err)
	assert.Equal(t, "val1", all["field1"])
	assert.Equal(t, "val2", all["field2"])

	// HDEL
	err = client.HDel(ctx, key, "field1").Err()
	require.NoError(t, err)

	// The removed field must read back as a nil reply.
	_, err = client.HGet(ctx, key, "field1").Result()
	assert.ErrorIs(t, err, redis.Nil)

	// Clean up (best effort)
	client.Del(ctx, key)
}

// TestIntegration_ListCommands exercises LPUSH, LLEN and LRANGE through the
// proxy in redis-only mode.
func TestIntegration_ListCommands(t *testing.T) {
	skipWithoutRedis(t, testPrimaryAddr)

	client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399")
	defer cancel()

	ctx := context.Background()
	key := fmt.Sprintf("proxy-list-%d", time.Now().UnixNano())

	// LPUSH
	err := client.LPush(ctx, key, "a", "b", "c").Err()
	require.NoError(t, err)

	// LLEN
	llen, err := client.LLen(ctx, key).Result()
	require.NoError(t, err)
	assert.Equal(t, int64(3), llen)

	// LRANGE: LPUSH prepends, so the elements come back in reverse order.
	vals, err := client.LRange(ctx, key, 0, -1).Result()
	require.NoError(t, err)
	assert.Equal(t, []string{"c", "b", "a"}, vals)

	// Clean up (best effort)
	client.Del(ctx, key)
}

// TestIntegration_Transaction runs a MULTI/EXEC transaction through the proxy
// in redis-only mode.
func TestIntegration_Transaction(t *testing.T) {
	skipWithoutRedis(t, testPrimaryAddr)

	client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399")
	defer
cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-txn-%d", time.Now().UnixNano()) + + // MULTI/EXEC via pipeline (go-redis TxPipelined) + _, err := client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, "txn-val", 0) + pipe.Get(ctx, key) + return nil + }) + require.NoError(t, err) + + // Verify + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "txn-val", val) + + // Clean up + client.Del(ctx, key) +} diff --git a/proxy/metrics.go b/proxy/metrics.go new file mode 100644 index 00000000..67c6128c --- /dev/null +++ b/proxy/metrics.go @@ -0,0 +1,87 @@ +package proxy + +import "github.com/prometheus/client_golang/prometheus" + +// ProxyMetrics holds all Prometheus metrics for the dual-write proxy. +type ProxyMetrics struct { + CommandTotal *prometheus.CounterVec + CommandDuration *prometheus.HistogramVec + + PrimaryWriteErrors prometheus.Counter + SecondaryWriteErrors prometheus.Counter + PrimaryReadErrors prometheus.Counter + ShadowReadErrors prometheus.Counter + Divergences *prometheus.CounterVec + MigrationGaps *prometheus.CounterVec + + ActiveConnections prometheus.Gauge +} + +// NewProxyMetrics creates and registers all proxy metrics. 
+func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { + m := &ProxyMetrics{ + CommandTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "command_total", + Help: "Total commands processed by the proxy.", + }, []string{"command", "backend", "status"}), + + CommandDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "proxy", + Name: "command_duration_seconds", + Help: "Latency of commands forwarded to backends.", + Buckets: prometheus.DefBuckets, + }, []string{"command", "backend"}), + + PrimaryWriteErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "primary_write_errors_total", + Help: "Total write errors from the primary backend.", + }), + SecondaryWriteErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "secondary_write_errors_total", + Help: "Total write errors from the secondary backend.", + }), + PrimaryReadErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "primary_read_errors_total", + Help: "Total read errors from the primary backend.", + }), + ShadowReadErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "shadow_read_errors_total", + Help: "Total errors from shadow reads.", + }), + Divergences: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "divergences_total", + Help: "Total data mismatches detected by shadow reads.", + }, []string{"command", "kind"}), + MigrationGaps: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "migration_gap_total", + Help: "Expected divergences due to missing data on secondary (pre-migration).", + }, []string{"command"}), + + ActiveConnections: prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "proxy", + Name: "active_connections", + Help: "Current number of active client connections.", + }), + } + + reg.MustRegister( + m.CommandTotal, + m.CommandDuration, + 
m.PrimaryWriteErrors, + m.SecondaryWriteErrors, + m.PrimaryReadErrors, + m.ShadowReadErrors, + m.Divergences, + m.MigrationGaps, + m.ActiveConnections, + ) + + return m +} diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 00000000..60a00fe2 --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,435 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "strings" + + "github.com/redis/go-redis/v9" + "github.com/tidwall/redcon" +) + +// txnCommandsOverhead is the number of extra commands (MULTI + EXEC) wrapping queued commands. +const txnCommandsOverhead = 2 + +// proxyConnState tracks per-connection state (transactions, PubSub). +type proxyConnState struct { + inTxn bool + txnQueue [][][]byte // buffered commands between MULTI and EXEC +} + +// ProxyServer is a Redis-protocol proxy that dual-writes to two backends. +type ProxyServer struct { + cfg ProxyConfig + dual *DualWriter + metrics *ProxyMetrics + sentry *SentryReporter + logger *slog.Logger + + // shutdownCtx is cancelled on graceful shutdown, used for blocking commands. + shutdownCtx context.Context +} + +// NewProxyServer creates a proxy server with the given configuration and backends. +func NewProxyServer(cfg ProxyConfig, dual *DualWriter, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger) *ProxyServer { + return &ProxyServer{ + cfg: cfg, + dual: dual, + metrics: metrics, + sentry: sentryReporter, + logger: logger, + } +} + +// ListenAndServe starts the redcon proxy server. +func (p *ProxyServer) ListenAndServe(ctx context.Context) error { + p.shutdownCtx = ctx + + var lc net.ListenConfig + ln, err := lc.Listen(ctx, "tcp", p.cfg.ListenAddr) + if err != nil { + return fmt.Errorf("proxy listen: %w", err) + } + + srv := redcon.NewServer(p.cfg.ListenAddr, + p.handleCommand, + p.handleAccept, + p.handleClose, + ) + + // Graceful shutdown on context cancel. 
+ go func() { + <-ctx.Done() + p.logger.Info("shutting down proxy server") + srv.Close() + }() + + p.logger.Info("proxy server starting", + "addr", p.cfg.ListenAddr, + "mode", p.cfg.Mode.String(), + "primary", p.cfg.PrimaryAddr, + "secondary", p.cfg.SecondaryAddr, + ) + + if err = srv.Serve(ln); err != nil { + return fmt.Errorf("proxy serve: %w", err) + } + return nil +} + +func (p *ProxyServer) handleAccept(conn redcon.Conn) bool { + p.metrics.ActiveConnections.Inc() + conn.SetContext(&proxyConnState{}) + return true +} + +func (p *ProxyServer) handleClose(conn redcon.Conn, _ error) { + p.metrics.ActiveConnections.Dec() +} + +func getConnState(conn redcon.Conn) *proxyConnState { + if ctx := conn.Context(); ctx != nil { + if st, ok := ctx.(*proxyConnState); ok { + return st + } + } + st := &proxyConnState{} + conn.SetContext(st) + return st +} + +func (p *ProxyServer) handleCommand(conn redcon.Conn, cmd redcon.Command) { + if len(cmd.Args) == 0 { + conn.WriteError("ERR empty command") + return + } + + name := strings.ToUpper(string(cmd.Args[0])) + args := cloneArgs(cmd.Args) + state := getConnState(conn) + + // Transaction handling + if state.inTxn { + p.handleQueuedCommand(conn, state, name, args) + return + } + + p.dispatchCommand(conn, state, name, args) +} + +func (p *ProxyServer) dispatchCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) { + cat := ClassifyCommand(name, args[1:]) + + switch cat { + case CmdTxn: + p.handleTxnCommand(conn, state, name) + case CmdWrite: + p.handleWrite(conn, args) + case CmdRead: + p.handleRead(conn, args) + case CmdBlocking: + p.handleBlocking(conn, args) + case CmdPubSub: + p.handlePubSub(conn, args) + case CmdAdmin: + p.handleAdmin(conn, args) + case CmdScript: + p.handleScript(conn, args) + } +} + +func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) { + switch name { + case "EXEC": + p.execTxn(conn, state) + case "DISCARD": + 
p.discardTxn(conn, state) + case "MULTI": + conn.WriteError("ERR MULTI calls can not be nested") + default: + state.txnQueue = append(state.txnQueue, args) + conn.WriteString("QUEUED") + } +} + +func (p *ProxyServer) handleWrite(conn redcon.Conn, args [][]byte) { + resp, err := p.dual.Write(context.Background(), args) + writeResponse(conn, resp, err) +} + +func (p *ProxyServer) handleRead(conn redcon.Conn, args [][]byte) { + resp, err := p.dual.Read(context.Background(), args) + writeResponse(conn, resp, err) +} + +func (p *ProxyServer) handleBlocking(conn redcon.Conn, args [][]byte) { + // Use shutdownCtx so blocking commands are interrupted on graceful shutdown. + resp, err := p.dual.Blocking(p.shutdownCtx, args) + writeResponse(conn, resp, err) +} + +func (p *ProxyServer) handlePubSub(conn redcon.Conn, args [][]byte) { + name := strings.ToUpper(string(args[0])) + + switch name { + case cmdSubscribe, cmdPSubscribe: + p.startPubSubSession(conn, name, args) + case cmdUnsubscribe, cmdPUnsubscribe: + // No active session; return empty confirmation. + conn.WriteArray(pubsubArrayReply) + conn.WriteBulkString(strings.ToLower(name)) + conn.WriteNull() + conn.WriteInt64(0) + default: + // PUBSUB CHANNELS / NUMSUB etc. + resp, err := p.dual.Admin(context.Background(), args) + writeResponse(conn, resp, err) + } +} + +func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args [][]byte) { + psBackend := p.dual.PubSubBackend() + if psBackend == nil { + conn.WriteError("ERR PubSub not supported by backend") + return + } + + if len(args) < pubsubMinArgs { + conn.WriteError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", strings.ToLower(cmdName))) + return + } + + // Create dedicated upstream PubSub connection. + upstream := psBackend.NewPubSub(context.Background()) + + channels := byteSlicesToStrings(args[1:]) + var err error + if cmdName == cmdSubscribe { + err = upstream.Subscribe(context.Background(), channels...) 
+ } else { + err = upstream.PSubscribe(context.Background(), channels...) + } + if err != nil { + upstream.Close() + conn.WriteError("ERR " + err.Error()) + return + } + + // Detach the connection from redcon's event loop. + dconn := conn.Detach() + + session := &pubsubSession{ + dconn: dconn, + upstream: upstream, + logger: p.logger, + } + + // Write initial subscription confirmations. + kind := strings.ToLower(cmdName) + for i, ch := range channels { + dconn.WriteArray(pubsubArrayReply) + dconn.WriteBulkString(kind) + dconn.WriteBulkString(ch) + if cmdName == cmdSubscribe { + session.channels = i + 1 + } else { + session.patterns = i + 1 + } + dconn.WriteInt(session.channels + session.patterns) + } + if err := dconn.Flush(); err != nil { + dconn.Close() + upstream.Close() + return + } + + go session.run() +} + +func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { + name := strings.ToUpper(string(args[0])) + + // Handle PING locally for speed. + if name == "PING" { + if len(args) > 1 { + conn.WriteBulk(args[1]) + } else { + conn.WriteString("PONG") + } + return + } + + // Handle QUIT locally. 
+ if name == "QUIT" { + conn.WriteString("OK") + conn.Close() + return + } + + resp, err := p.dual.Admin(context.Background(), args) + writeResponse(conn, resp, err) +} + +func (p *ProxyServer) handleScript(conn redcon.Conn, args [][]byte) { + resp, err := p.dual.Script(context.Background(), args) + writeResponse(conn, resp, err) +} + +// Transaction handling + +func (p *ProxyServer) handleTxnCommand(conn redcon.Conn, state *proxyConnState, name string) { + switch name { + case "MULTI": + if state.inTxn { + conn.WriteError("ERR MULTI calls can not be nested") + return + } + state.inTxn = true + state.txnQueue = nil + conn.WriteString("OK") + case "EXEC": + if !state.inTxn { + conn.WriteError("ERR EXEC without MULTI") + return + } + p.execTxn(conn, state) + case "DISCARD": + if !state.inTxn { + conn.WriteError("ERR DISCARD without MULTI") + return + } + p.discardTxn(conn, state) + } +} + +func (p *ProxyServer) execTxn(conn redcon.Conn, state *proxyConnState) { + queue := state.txnQueue + state.inTxn = false + state.txnQueue = nil + + ctx := context.Background() + + // Build pipeline: MULTI + queued commands + EXEC + cmds := make([][]interface{}, 0, len(queue)+txnCommandsOverhead) + cmds = append(cmds, []interface{}{"MULTI"}) + for _, args := range queue { + cmds = append(cmds, bytesArgsToInterfaces(args)) + } + cmds = append(cmds, []interface{}{"EXEC"}) + + results, err := p.dual.Primary().Pipeline(ctx, cmds) + if err != nil { + // Pipeline exec error — still try to extract EXEC result + if len(results) > 0 { + lastResult := results[len(results)-1] + resp, rErr := lastResult.Result() + writeResponse(conn, resp, rErr) + } else { + writeRedisError(conn, err) + } + } else { + // The EXEC result is the last command + if len(results) > 0 { + lastResult := results[len(results)-1] + resp, rErr := lastResult.Result() + writeResponse(conn, resp, rErr) + } + } + + // Async replay to secondary (bounded) + if p.dual.hasSecondaryWrite() { + p.dual.goAsync(func() { + sCtx, cancel 
:= context.WithTimeout(context.Background(), p.cfg.SecondaryTimeout) + defer cancel() + _, pErr := p.dual.Secondary().Pipeline(sCtx, cmds) + if pErr != nil { + p.logger.Warn("secondary txn replay failed", "err", pErr) + p.metrics.SecondaryWriteErrors.Inc() + } + }) + } +} + +func (p *ProxyServer) discardTxn(conn redcon.Conn, state *proxyConnState) { + state.inTxn = false + state.txnQueue = nil + conn.WriteString("OK") +} + +// writeResponse handles the common pattern of writing a go-redis response +// to a redcon connection, correctly handling redis.Nil and upstream errors. +func writeResponse(conn redcon.Conn, resp interface{}, err error) { + if err != nil { + if errors.Is(err, redis.Nil) { + conn.WriteNull() + return + } + writeRedisError(conn, err) + return + } + writeRedisValue(conn, resp) +} + +// writeRedisError writes an upstream error without double-prefixing. +// Redis errors already contain their prefix (e.g. "ERR ...", "WRONGTYPE ..."). +func writeRedisError(conn redcon.Conn, err error) { + msg := err.Error() + // go-redis errors are already formatted with prefix; pass through as-is. + conn.WriteError(msg) +} + +// writeRedisValue writes a go-redis response value to a redcon connection. +func writeRedisValue(conn redcon.Conn, val interface{}) { + if val == nil { + conn.WriteNull() + return + } + switch v := val.(type) { + case string: + // go-redis flattens Status and Bulk strings into Go strings. + // Use WriteString (status reply) for known status responses, + // WriteBulkString (bulk reply) for data values. 
+ if isStatusResponse(v) { + conn.WriteString(v) + } else { + conn.WriteBulkString(v) + } + case int64: + conn.WriteInt64(v) + case []interface{}: + conn.WriteArray(len(v)) + for _, item := range v { + writeRedisValue(conn, item) + } + case []byte: + conn.WriteBulk(v) + case redis.Error: + conn.WriteError(v.Error()) + default: + conn.WriteBulkString(fmt.Sprintf("%v", v)) + } +} + +// isStatusResponse detects known Redis status reply strings that should be +// sent as simple strings (+OK) rather than bulk strings ($2\r\nOK). +func isStatusResponse(s string) bool { + switch s { + case "OK", "QUEUED", "PONG": + return true + default: + return false + } +} + +func cloneArgs(args [][]byte) [][]byte { + out := make([][]byte, len(args)) + for i, arg := range args { + cp := make([]byte, len(arg)) + copy(cp, arg) + out[i] = cp + } + return out +} diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go new file mode 100644 index 00000000..84e6b8fb --- /dev/null +++ b/proxy/proxy_test.go @@ -0,0 +1,622 @@ +package proxy + +import ( + "context" + "errors" + "io" + "log/slog" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var testLogger = slog.New(slog.NewTextHandler(io.Discard, nil)) + +// --- Mock Backend --- + +type mockBackend struct { + name string + doFunc func(ctx context.Context, args ...interface{}) *redis.Cmd + mu sync.Mutex + calls [][]interface{} +} + +func newMockBackend(name string) *mockBackend { + return &mockBackend{name: name} +} + +func (b *mockBackend) Do(ctx context.Context, args ...interface{}) *redis.Cmd { + b.mu.Lock() + b.calls = append(b.calls, args) + b.mu.Unlock() + if b.doFunc != nil { + return b.doFunc(ctx, args...) + } + cmd := redis.NewCmd(ctx, args...) 
+ cmd.SetVal("OK") + return cmd +} + +func (b *mockBackend) Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) { + results := make([]*redis.Cmd, len(cmds)) + for i, args := range cmds { + results[i] = b.Do(ctx, args...) + } + return results, nil +} + +func (b *mockBackend) Close() error { return nil } +func (b *mockBackend) Name() string { return b.name } + +func (b *mockBackend) CallCount() int { + b.mu.Lock() + defer b.mu.Unlock() + return len(b.calls) +} + +// Helper to create a doFunc that returns a specific value. +func makeCmd(val interface{}, err error) func(ctx context.Context, args ...interface{}) *redis.Cmd { + return func(ctx context.Context, args ...interface{}) *redis.Cmd { + cmd := redis.NewCmd(ctx, args...) + if err != nil { + cmd.SetErr(err) + } else { + cmd.SetVal(val) + } + return cmd + } +} + +func newTestMetrics() *ProxyMetrics { + reg := prometheus.NewRegistry() + return NewProxyMetrics(reg) +} + +func newTestSentry() *SentryReporter { + return NewSentryReporter("", "", 1.0, testLogger) +} + +// ========== command.go tests ========== + +func TestClassifyCommand(t *testing.T) { + tests := []struct { + name string + cmd string + args [][]byte + expected CommandCategory + }{ + {"GET is read", "GET", nil, CmdRead}, + {"SET is write", "SET", nil, CmdWrite}, + {"DEL is write", "DEL", nil, CmdWrite}, + {"HGET is read", "HGET", nil, CmdRead}, + {"HSET is write", "HSET", nil, CmdWrite}, + {"ZADD is write", "ZADD", nil, CmdWrite}, + {"ZRANGE is read", "ZRANGE", nil, CmdRead}, + {"PING is admin", "PING", nil, CmdAdmin}, + {"INFO is admin", "INFO", nil, CmdAdmin}, + {"SELECT is admin", "SELECT", nil, CmdAdmin}, + {"QUIT is admin", "QUIT", nil, CmdAdmin}, + {"AUTH is admin", "AUTH", nil, CmdAdmin}, + {"HELLO is admin", "HELLO", nil, CmdAdmin}, + {"CONFIG is admin", "CONFIG", nil, CmdAdmin}, + {"WAIT is admin", "WAIT", nil, CmdAdmin}, + {"COMMAND is admin", "COMMAND", nil, CmdAdmin}, + {"MULTI is txn", "MULTI", nil, CmdTxn}, + {"EXEC 
is txn", "EXEC", nil, CmdTxn},
		{"DISCARD is txn", "DISCARD", nil, CmdTxn},
		{"SUBSCRIBE is pubsub", "SUBSCRIBE", nil, CmdPubSub},
		{"UNSUBSCRIBE is pubsub", "UNSUBSCRIBE", nil, CmdPubSub},
		{"PSUBSCRIBE is pubsub", "PSUBSCRIBE", nil, CmdPubSub},
		{"PUNSUBSCRIBE is pubsub", "PUNSUBSCRIBE", nil, CmdPubSub},
		{"BZPOPMIN is blocking", "BZPOPMIN", nil, CmdBlocking},
		{"EVAL is script", "EVAL", nil, CmdScript},
		{"lowercase get is read", "get", nil, CmdRead},
		{"GETDEL is write", "GETDEL", nil, CmdWrite},
		{"PUBLISH is write", "PUBLISH", nil, CmdWrite},
		{"unknown cmd is write", "UNKNOWNCMD", nil, CmdWrite},

		// XREAD is a read unless it carries BLOCK, which makes it blocking.
		{
			"XREAD without BLOCK is read",
			"XREAD",
			[][]byte{[]byte("COUNT"), []byte("10"), []byte("STREAMS"), []byte("s1"), []byte("0")},
			CmdRead,
		},
		{
			"XREAD with BLOCK is blocking",
			"XREAD",
			[][]byte{[]byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("s1"), []byte("0")},
			CmdBlocking,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, ClassifyCommand(tc.cmd, tc.args))
		})
	}
}

// ========== config.go tests ==========

// TestParseProxyMode checks that every supported mode name parses to its
// constant and that an unknown name is rejected.
func TestParseProxyMode(t *testing.T) {
	cases := []struct {
		input    string
		expected ProxyMode
		ok       bool
	}{
		{"redis-only", ModeRedisOnly, true},
		{"dual-write", ModeDualWrite, true},
		{"dual-write-shadow", ModeDualWriteShadow, true},
		{"elastickv-primary", ModeElasticKVPrimary, true},
		{"elastickv-only", ModeElasticKVOnly, true},
		{"invalid", 0, false},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			mode, parsed := ParseProxyMode(tc.input)
			assert.Equal(t, tc.ok, parsed)
			if !parsed {
				// The returned mode is unspecified for a rejected name.
				return
			}
			assert.Equal(t, tc.expected, mode)
		})
	}
}

// TestProxyModeString checks the textual form of every mode, including the
// fallback for an out-of-range value.
func TestProxyModeString(t *testing.T) {
	assert.Equal(t, "redis-only", ModeRedisOnly.String())
	assert.Equal(t, "dual-write", ModeDualWrite.String())
	assert.Equal(t, "dual-write-shadow", ModeDualWriteShadow.String())
	assert.Equal(t, "elastickv-primary",
ModeElasticKVPrimary.String())
	assert.Equal(t, "elastickv-only", ModeElasticKVOnly.String())
	assert.Equal(t, "unknown", ProxyMode(99).String())
}

// TestDefaultConfig pins the default addresses and mode.
func TestDefaultConfig(t *testing.T) {
	defaults := DefaultConfig()

	assert.Equal(t, ":6479", defaults.ListenAddr)
	assert.Equal(t, "localhost:6379", defaults.PrimaryAddr)
	assert.Equal(t, "localhost:6380", defaults.SecondaryAddr)
	assert.Equal(t, ModeDualWrite, defaults.Mode)
}

// ========== compare.go tests ==========

// TestDivergenceKindString checks the label of every divergence kind,
// including the fallback for an out-of-range value.
func TestDivergenceKindString(t *testing.T) {
	labels := []struct {
		kind DivergenceKind
		want string
	}{
		{DivMigrationGap, "migration_gap"},
		{DivDataMismatch, "data_mismatch"},
		{DivExtraData, "extra_data"},
		{DivergenceKind(99), "unknown"},
	}
	for _, l := range labels {
		assert.Equal(t, l.want, l.kind.String())
	}
}

// TestExtractKey checks that the key is the first argument after the
// command name and that it is empty when there is none.
func TestExtractKey(t *testing.T) {
	withKey := [][]byte{[]byte("GET"), []byte("mykey")}
	commandOnly := [][]byte{[]byte("PING")}

	assert.Equal(t, "mykey", extractKey(withKey))
	assert.Equal(t, "", extractKey(commandOnly))
	assert.Equal(t, "", extractKey(nil))
}

// TestBytesArgsToInterfaces checks that every byte-slice argument is
// converted to a string, in order.
func TestBytesArgsToInterfaces(t *testing.T) {
	converted := bytesArgsToInterfaces([][]byte{[]byte("SET"), []byte("key"), []byte("val")})

	assert.Len(t, converted, 3)
	for i, want := range []string{"SET", "key", "val"} {
		assert.Equal(t, want, converted[i])
	}
}

// TestResponseEqual checks value comparison across nil, strings, integers,
// byte slices and (nested) arrays.
func TestResponseEqual(t *testing.T) {
	tests := []struct {
		name string
		a, b interface{}
		want bool
	}{
		{"nil nil", nil, nil, true},
		{"nil vs value", nil, "hello", false},
		{"value vs nil", "hello", nil, false},
		{"same string", "hello", "hello", true},
		{"diff string", "hello", "world", false},
		{"empty string equals empty string", "", "", true},
		{"same int64", int64(42), int64(42), true},
		{"diff int64", int64(42), int64(43), false},
		{"same array", []interface{}{"a", "b"}, []interface{}{"a", "b"}, true},
		{"diff array values", []interface{}{"a", "b"}, []interface{}{"a", "c"}, false},
		{"diff array length", []interface{}{"a"}, []interface{}{"a", "b"}, false},
		{"empty arrays", []interface{}{},
[]interface{}{}, true}, + {"same bytes", []byte("data"), []byte("data"), true}, + {"diff bytes", []byte("data"), []byte("other"), false}, + {"nested array", []interface{}{[]interface{}{"x"}}, []interface{}{[]interface{}{"x"}}, true}, + {"type mismatch string vs int", "42", int64(42), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, responseEqual(tt.a, tt.b)) + }) + } +} + +func TestClassifyDivergence(t *testing.T) { + tests := []struct { + name string + primaryResp, secondaryResp interface{} + primaryErr, secondaryErr error + want DivergenceKind + }{ + {"primary has data, secondary nil", "val", nil, nil, redis.Nil, DivMigrationGap}, + {"primary nil, secondary has data", nil, "val", redis.Nil, nil, DivExtraData}, + {"both have different data", "val1", "val2", nil, nil, DivDataMismatch}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := classifyDivergence(tt.primaryResp, tt.primaryErr, tt.secondaryResp, tt.secondaryErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsNilResp(t *testing.T) { + assert.True(t, isNilResp(nil, redis.Nil)) + assert.True(t, isNilResp(nil, nil)) + assert.False(t, isNilResp("", nil), "empty string is NOT nil") + assert.False(t, isNilResp("hello", nil)) + assert.False(t, isNilResp(int64(0), nil)) +} + +// ========== proxy.go tests ========== + +func TestCloneArgs(t *testing.T) { + original := [][]byte{[]byte("GET"), []byte("key")} + cloned := cloneArgs(original) + + assert.Equal(t, original, cloned) + + // Mutating cloned should not affect original. 
+ cloned[0][0] = 'X' + assert.Equal(t, byte('G'), original[0][0]) +} + +func TestIsStatusResponse(t *testing.T) { + assert.True(t, isStatusResponse("OK")) + assert.True(t, isStatusResponse("QUEUED")) + assert.True(t, isStatusResponse("PONG")) + assert.False(t, isStatusResponse("hello")) + assert.False(t, isStatusResponse("")) + assert.False(t, isStatusResponse("ok")) // case-sensitive +} + +// ========== sentry.go tests ========== + +func TestSentryReporterDisabled(t *testing.T) { + r := NewSentryReporter("", "", 1.0, nil) + assert.False(t, r.enabled) + // Should not panic + r.CaptureException(nil, "test", nil) + r.CaptureDivergence(Divergence{}) + r.Flush(0) +} + +func TestShouldReportCooldown(t *testing.T) { + r := &SentryReporter{ + lastReport: make(map[string]time.Time), + cooldown: 50 * time.Millisecond, + } + + assert.True(t, r.ShouldReport("fp1")) + assert.False(t, r.ShouldReport("fp1")) // within cooldown + + time.Sleep(60 * time.Millisecond) + assert.True(t, r.ShouldReport("fp1")) // cooldown elapsed +} + +func TestShouldReportEvictsExpired(t *testing.T) { + r := &SentryReporter{ + lastReport: make(map[string]time.Time), + cooldown: 1 * time.Millisecond, + } + // Fill to maxReportEntries + for i := range maxReportEntries { + r.lastReport[string(rune(i))] = time.Now().Add(-time.Hour) + } + time.Sleep(2 * time.Millisecond) + assert.True(t, r.ShouldReport("new-fp")) + assert.Less(t, len(r.lastReport), maxReportEntries) +} + +// ========== dualwrite.go tests ========== + +func TestHasSecondaryWrite(t *testing.T) { + for _, tc := range []struct { + mode ProxyMode + expected bool + }{ + {ModeRedisOnly, false}, + {ModeDualWrite, true}, + {ModeDualWriteShadow, true}, + {ModeElasticKVPrimary, true}, + {ModeElasticKVOnly, false}, + } { + d := &DualWriter{cfg: ProxyConfig{Mode: tc.mode}, asyncSem: make(chan struct{}, 1)} + assert.Equal(t, tc.expected, d.hasSecondaryWrite(), "mode=%s", tc.mode) + } +} + +func TestDualWriter_Write_PrimarySuccess(t *testing.T) { + 
primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("OK", nil) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + resp, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.NoError(t, err) + assert.Equal(t, "OK", resp) + assert.Equal(t, 1, primary.CallCount()) + + // Wait for async secondary write + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, secondary.CallCount()) +} + +func TestDualWriter_Write_PrimaryFail(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd(nil, errors.New("connection refused")) + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + _, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection refused") + // Secondary should NOT be called when primary fails + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, secondary.CallCount()) +} + +func TestDualWriter_Write_SecondaryFail_ClientSucceeds(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, errors.New("secondary down")) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + resp, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.NoError(t, err) + assert.Equal(t, "OK", resp) + + time.Sleep(50 * time.Millisecond) + assert.InDelta(t, 1, 
testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001)
}

// TestDualWriter_Write_RedisNil checks that a nil reply from the primary is
// passed to the caller as redis.Nil and that the write still reaches the
// secondary.
func TestDualWriter_Write_RedisNil(t *testing.T) {
	// SET with NX returns redis.Nil when key exists
	primary := newMockBackend("primary")
	primary.doFunc = makeCmd(nil, redis.Nil)
	secondary := newMockBackend("secondary")
	secondary.doFunc = makeCmd(nil, redis.Nil)

	metrics := newTestMetrics()
	d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger)

	resp, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v"), []byte("NX")})
	assert.ErrorIs(t, err, redis.Nil)
	assert.Nil(t, resp)
	// Should still send to secondary. The write is asynchronous, so poll
	// for it instead of assuming a fixed delay was long enough.
	assert.Eventually(t, func() bool {
		return secondary.CallCount() == 1
	}, time.Second, time.Millisecond, "secondary write was never issued")
}

// TestDualWriter_Write_RedisOnlyMode checks that redis-only mode never
// touches the secondary.
func TestDualWriter_Write_RedisOnlyMode(t *testing.T) {
	primary := newMockBackend("primary")
	primary.doFunc = makeCmd("OK", nil)
	secondary := newMockBackend("secondary")

	metrics := newTestMetrics()
	d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly}, metrics, newTestSentry(), testLogger)

	_, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")})
	assert.NoError(t, err)
	// Negative check: allow a grace period for a stray asynchronous write.
	time.Sleep(50 * time.Millisecond)
	assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in redis-only mode")
}

// TestDualWriter_Write_ElasticKVOnlyMode checks that elastickv-only mode
// never touches the secondary.
func TestDualWriter_Write_ElasticKVOnlyMode(t *testing.T) {
	primary := newMockBackend("elastickv")
	primary.doFunc = makeCmd("OK", nil)
	secondary := newMockBackend("redis")

	metrics := newTestMetrics()
	d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeElasticKVOnly}, metrics, newTestSentry(), testLogger)

	_, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")})
	assert.NoError(t, err)
	// Negative check: allow a grace period for a stray asynchronous write.
	time.Sleep(50 * time.Millisecond)
	assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in elastickv-only mode")
}

// TestDualWriter_Read_WithShadow checks that dual-write-shadow mode issues
// a shadow read against the secondary.
func
TestDualWriter_Read_WithShadow(t *testing.T) {
	primary := newMockBackend("primary")
	primary.doFunc = makeCmd("hello", nil)
	secondary := newMockBackend("secondary")
	secondary.doFunc = makeCmd("hello", nil)

	metrics := newTestMetrics()
	cfg := ProxyConfig{Mode: ModeDualWriteShadow, ShadowTimeout: time.Second, SecondaryTimeout: time.Second}
	d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger)

	resp, err := d.Read(context.Background(), [][]byte{[]byte("GET"), []byte("k")})
	assert.NoError(t, err)
	assert.Equal(t, "hello", resp)

	// The shadow read is asynchronous: poll for it instead of assuming a
	// fixed delay was long enough.
	assert.Eventually(t, func() bool {
		return secondary.CallCount() == 1
	}, time.Second, time.Millisecond, "shadow read should be issued")
}

// TestDualWriter_Read_NoShadowInDualWrite checks that plain dual-write mode
// does not issue shadow reads.
func TestDualWriter_Read_NoShadowInDualWrite(t *testing.T) {
	primary := newMockBackend("primary")
	primary.doFunc = makeCmd("hello", nil)
	secondary := newMockBackend("secondary")

	metrics := newTestMetrics()
	cfg := ProxyConfig{Mode: ModeDualWrite, ShadowTimeout: time.Second}
	d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger)

	_, err := d.Read(context.Background(), [][]byte{[]byte("GET"), []byte("k")})
	assert.NoError(t, err)
	// Negative check: allow a grace period for a stray shadow read.
	time.Sleep(50 * time.Millisecond)
	assert.Equal(t, 0, secondary.CallCount(), "no shadow in dual-write mode")
}

// TestDualWriter_GoAsync_Bounded checks that goAsync drops work instead of
// blocking once maxAsyncGoroutines tasks are in flight.
func TestDualWriter_GoAsync_Bounded(t *testing.T) {
	primary := newMockBackend("primary")
	primary.doFunc = makeCmd("OK", nil)
	secondary := newMockBackend("secondary")

	metrics := newTestMetrics()
	cfg := ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: 10 * time.Second}
	d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger)

	// Fill the semaphore with blocking goroutines
	blocker := make(chan struct{})
	for range maxAsyncGoroutines {
		d.goAsync(func() {
			<-blocker
		})
	}

	// Next one should be dropped, not block
	done := make(chan struct{})
	go func() {
		d.goAsync(func() { t.Error("should not run") })
		close(done)
}() + + select { + case <-done: + // good — goAsync returned immediately + case <-time.After(time.Second): + t.Fatal("goAsync blocked when semaphore was full") + } + + close(blocker) // unblock all +} + +// ========== ShadowReader tests ========== + +func TestShadowReader_Compare_Equal(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("hello", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + // No divergence should be reported + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_BothNil(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, nil, redis.Nil) + + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_MigrationGap(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) // secondary has no data + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", 
"data_mismatch")), 0.001) +} + +func TestShadowReader_Compare_DataMismatch(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("world", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) +} + +func TestShadowReader_Compare_ExtraData(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("surprise", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, nil, redis.Nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "extra_data")), 0.001) +} + +func TestShadowReader_Compare_EmptyStringIsNotNil(t *testing.T) { + // Primary returns "", secondary returns "" → equal, no divergence + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "", nil) + + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_EmptyStringVsNil(t *testing.T) { + // Primary returns "", secondary returns nil → MigrationGap + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), 
"GET", [][]byte{[]byte("GET"), []byte("k")}, "", nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_MigrationGapSampling(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + sr.gapLogSampleRate = 10 + + for range 25 { + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "val", nil) + } + + assert.InDelta(t, 25, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) + assert.Equal(t, int64(25), sr.gapCount.Load()) +} + +// ========== Backend tests ========== + +func TestDefaultBackendOptions(t *testing.T) { + opts := DefaultBackendOptions() + assert.Equal(t, 128, opts.PoolSize) + assert.Equal(t, 5*time.Second, opts.DialTimeout) +} diff --git a/proxy/pubsub.go b/proxy/pubsub.go new file mode 100644 index 00000000..933949c8 --- /dev/null +++ b/proxy/pubsub.go @@ -0,0 +1,238 @@ +package proxy + +import ( + "context" + "log/slog" + "strings" + "sync" + + "github.com/redis/go-redis/v9" + "github.com/tidwall/redcon" +) + +const ( + pubsubArrayMessage = 3 // ["message", channel, payload] + pubsubArrayPMessage = 4 // ["pmessage", pattern, channel, payload] + pubsubArrayReply = 3 // ["subscribe"/"unsubscribe", channel, count] + pubsubArrayPong = 2 // ["pong", data] + pubsubMinArgs = 2 // command + at least one channel + + cmdSubscribe = "SUBSCRIBE" + cmdUnsubscribe = "UNSUBSCRIBE" + cmdPSubscribe = "PSUBSCRIBE" + cmdPUnsubscribe = "PUNSUBSCRIBE" +) + +// pubsubSession manages a single client's pub/sub session. +// It bridges a detached redcon connection to an upstream go-redis PubSub connection. +type pubsubSession struct { + mu sync.Mutex + dconn redcon.DetachedConn + upstream *redis.PubSub + logger *slog.Logger + closed bool + + // Track subscription counts for RESP replies. 
+ channels int + patterns int +} + +// run starts the forwarding session. It blocks until the client disconnects +// or the upstream closes. +func (s *pubsubSession) run() { + defer func() { + s.upstream.Close() + s.mu.Lock() + s.closed = true + s.mu.Unlock() + s.dconn.Close() + }() + + go s.forwardMessages() + s.readClientCommands() +} + +// forwardMessages reads from the upstream go-redis PubSub channel and writes +// messages to the detached client connection. +func (s *pubsubSession) forwardMessages() { + ch := s.upstream.Channel() + for msg := range ch { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return + } + if msg.Pattern != "" { + s.dconn.WriteArray(pubsubArrayPMessage) + s.dconn.WriteBulkString("pmessage") + s.dconn.WriteBulkString(msg.Pattern) + s.dconn.WriteBulkString(msg.Channel) + s.dconn.WriteBulkString(msg.Payload) + } else { + s.dconn.WriteArray(pubsubArrayMessage) + s.dconn.WriteBulkString("message") + s.dconn.WriteBulkString(msg.Channel) + s.dconn.WriteBulkString(msg.Payload) + } + err := s.dconn.Flush() + s.mu.Unlock() + if err != nil { + return + } + } +} + +// readClientCommands reads commands from the detached client connection. +// In pub/sub mode, only SUBSCRIBE, UNSUBSCRIBE, PSUBSCRIBE, PUNSUBSCRIBE, +// PING, and QUIT are valid. 
+func (s *pubsubSession) readClientCommands() { + for { + cmd, err := s.dconn.ReadCommand() + if err != nil { + return + } + if len(cmd.Args) == 0 { + continue + } + switch strings.ToUpper(string(cmd.Args[0])) { + case cmdSubscribe: + s.handleSubscribe(cmd.Args) + case cmdUnsubscribe: + s.handleUnsub(cmd.Args, false) + case cmdPSubscribe: + s.handlePSubscribe(cmd.Args) + case cmdPUnsubscribe: + s.handleUnsub(cmd.Args, true) + case "PING": + s.handlePing(cmd.Args) + case "QUIT": + return + } + } +} + +func (s *pubsubSession) handleSubscribe(args [][]byte) { + if len(args) < pubsubMinArgs { + s.writeError("ERR wrong number of arguments for 'subscribe'") + return + } + channels := byteSlicesToStrings(args[1:]) + if err := s.upstream.Subscribe(context.Background(), channels...); err != nil { + s.logger.Warn("upstream subscribe failed", "err", err) + return + } + s.mu.Lock() + for _, ch := range channels { + s.channels++ + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString("subscribe") + s.dconn.WriteBulkString(ch) + s.dconn.WriteInt(s.channels + s.patterns) + } + _ = s.dconn.Flush() + s.mu.Unlock() +} + +func (s *pubsubSession) handlePSubscribe(args [][]byte) { + if len(args) < pubsubMinArgs { + s.writeError("ERR wrong number of arguments for 'psubscribe'") + return + } + pats := byteSlicesToStrings(args[1:]) + if err := s.upstream.PSubscribe(context.Background(), pats...); err != nil { + s.logger.Warn("upstream psubscribe failed", "err", err) + return + } + s.mu.Lock() + for _, p := range pats { + s.patterns++ + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString("psubscribe") + s.dconn.WriteBulkString(p) + s.dconn.WriteInt(s.channels + s.patterns) + } + _ = s.dconn.Flush() + s.mu.Unlock() +} + +// handleUnsub handles both UNSUBSCRIBE and PUNSUBSCRIBE. +// When isPattern is true, it operates on pattern subscriptions. 
+func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { + kind := "unsubscribe" + unsubFn := s.upstream.Unsubscribe + if isPattern { + kind = "punsubscribe" + unsubFn = s.upstream.PUnsubscribe + } + + if len(args) < pubsubMinArgs { + // Unsubscribe all + if err := unsubFn(context.Background()); err != nil { + s.logger.Warn("upstream "+kind+" failed", "err", err) + } + s.mu.Lock() + if isPattern { + s.patterns = 0 + } else { + s.channels = 0 + } + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteNull() + s.dconn.WriteInt(s.channels + s.patterns) + _ = s.dconn.Flush() + s.mu.Unlock() + return + } + + names := byteSlicesToStrings(args[1:]) + if err := unsubFn(context.Background(), names...); err != nil { + s.logger.Warn("upstream "+kind+" failed", "err", err) + } + s.mu.Lock() + for _, n := range names { + if isPattern { + if s.patterns > 0 { + s.patterns-- + } + } else { + if s.channels > 0 { + s.channels-- + } + } + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) + s.dconn.WriteInt(s.channels + s.patterns) + } + _ = s.dconn.Flush() + s.mu.Unlock() +} + +func (s *pubsubSession) handlePing(args [][]byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.dconn.WriteArray(pubsubArrayPong) + s.dconn.WriteBulkString("pong") + if len(args) > 1 { + s.dconn.WriteBulk(args[1]) + } else { + s.dconn.WriteBulkString("") + } + _ = s.dconn.Flush() +} + +func (s *pubsubSession) writeError(msg string) { + s.mu.Lock() + defer s.mu.Unlock() + s.dconn.WriteError(msg) + _ = s.dconn.Flush() +} + +func byteSlicesToStrings(bs [][]byte) []string { + out := make([]string, len(bs)) + for i, b := range bs { + out[i] = string(b) + } + return out +} diff --git a/proxy/sentry.go b/proxy/sentry.go new file mode 100644 index 00000000..f39a5ddd --- /dev/null +++ b/proxy/sentry.go @@ -0,0 +1,124 @@ +package proxy + +import ( + "fmt" + "log/slog" + "sync" + "time" + + "github.com/getsentry/sentry-go" +) + +const 
( + defaultReportCooldown = 60 * time.Second + // maxReportEntries caps the lastReport map to prevent unbounded growth. + maxReportEntries = 10000 +) + +// SentryReporter sends anomaly events to Sentry with de-duplication. +type SentryReporter struct { + enabled bool + hub *sentry.Hub + logger *slog.Logger + cooldown time.Duration + + mu sync.Mutex + lastReport map[string]time.Time // fingerprint → last report time +} + +// NewSentryReporter initialises Sentry. If dsn is empty, reporting is disabled. +func NewSentryReporter(dsn string, environment string, sampleRate float64, logger *slog.Logger) *SentryReporter { + r := &SentryReporter{ + logger: logger, + cooldown: defaultReportCooldown, + lastReport: make(map[string]time.Time), + } + if dsn == "" { + return r + } + + err := sentry.Init(sentry.ClientOptions{ + Dsn: dsn, + Environment: environment, + SampleRate: sampleRate, + EnableTracing: false, + AttachStacktrace: true, + }) + if err != nil { + logger.Error("failed to init sentry", "err", err) + return r + } + r.enabled = true + r.hub = sentry.CurrentHub() + return r +} + +// CaptureException reports an error to Sentry. +func (r *SentryReporter) CaptureException(err error, operation string, args [][]byte) { + if !r.enabled { + return + } + r.hub.WithScope(func(scope *sentry.Scope) { + scope.SetTag("operation", operation) + if len(args) > 0 { + scope.SetTag("command", string(args[0])) + } + scope.SetFingerprint([]string{operation, cmdNameFromArgs(args)}) + r.hub.CaptureException(err) + }) +} + +// CaptureDivergence reports a data divergence to Sentry. 
+func (r *SentryReporter) CaptureDivergence(div Divergence) { + if !r.enabled { + return + } + r.hub.WithScope(func(scope *sentry.Scope) { + scope.SetTag("command", div.Command) + scope.SetTag("key", div.Key) + scope.SetTag("kind", div.Kind.String()) + scope.SetExtra("primary", fmt.Sprintf("%v", div.Primary)) + scope.SetExtra("secondary", fmt.Sprintf("%v", div.Secondary)) + scope.SetFingerprint([]string{"divergence", div.Kind.String(), div.Command}) + scope.SetLevel(sentry.LevelWarning) + r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s (%s)", div.Kind, div.Command, div.Key)) + }) +} + +// ShouldReport checks if this fingerprint has been reported recently (cooldown-based). +// Periodically evicts expired entries to prevent unbounded map growth. +func (r *SentryReporter) ShouldReport(fingerprint string) bool { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + + // Evict expired entries if map grows too large + if len(r.lastReport) >= maxReportEntries { + for k, t := range r.lastReport { + if now.Sub(t) >= r.cooldown { + delete(r.lastReport, k) + } + } + } + + if last, ok := r.lastReport[fingerprint]; ok && now.Sub(last) < r.cooldown { + return false + } + r.lastReport[fingerprint] = now + return true +} + +// Flush waits for pending Sentry events. 
+func (r *SentryReporter) Flush(timeout time.Duration) { + if r.enabled { + sentry.Flush(timeout) + } +} + +func cmdNameFromArgs(args [][]byte) string { + if len(args) > 0 { + return string(args[0]) + } + return unknownStr +} From 06736e98e561217bcf52824bb89b1d3700ca5f3a Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 14:28:46 +0900 Subject: [PATCH 02/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/pubsub.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 933949c8..9f07d1ac 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -107,6 +107,10 @@ func (s *pubsubSession) readClientCommands() { s.handlePing(cmd.Args) case "QUIT": return + default: + // In pub/sub mode, Redis only allows (P)SUBSCRIBE, (P)UNSUBSCRIBE, PING, and QUIT. + // Any other command must return an error to avoid clients hanging waiting for a reply. + s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") } } } From f35c41d75e76bb224130079866a873b8f3358996 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 15:12:16 +0900 Subject: [PATCH 03/43] Address PR #351 review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract dispatchPubSubCommand to reduce readClientCommands complexity (CC=11→8) - Use metricsSrv.Shutdown for graceful metrics server stop - Guard against nil logger in NewSentryReporter - Cap ShouldReport map after eviction to prevent unbounded growth - Normalize command names with strings.ToUpper in DualWriter methods - Fix goAsync doc comment to match actual behavior - Simplify execTxn by merging duplicate result-handling branches --- cmd/redis-proxy/main.go | 4 +++- proxy/dualwrite.go | 13 +++++++------ proxy/proxy.go | 23 +++++++---------------- proxy/pubsub.go | 39 
+++++++++++++++++++++++---------------- proxy/sentry.go | 7 +++++++ 5 files changed, 47 insertions(+), 39 deletions(-) diff --git a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go index 25696065..d6c39bf9 100644 --- a/cmd/redis-proxy/main.go +++ b/cmd/redis-proxy/main.go @@ -91,7 +91,9 @@ func run() error { metricsSrv := &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} go func() { <-ctx.Done() - metricsSrv.Close() + if err := metricsSrv.Shutdown(context.Background()); err != nil { + logger.Warn("metrics server shutdown error", "err", err) + } }() logger.Info("metrics server starting", "addr", cfg.MetricsAddr) if err := metricsSrv.Serve(ln); err != nil && err != http.ErrServerClosed { diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index df42a32c..84f736c7 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "strings" "time" "github.com/redis/go-redis/v9" @@ -54,7 +55,7 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe // Write sends a write command to the primary synchronously, then to the secondary asynchronously. func (d *DualWriter) Write(ctx context.Context, args [][]byte) (interface{}, error) { - cmd := string(args[0]) + cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -83,7 +84,7 @@ func (d *DualWriter) Write(ctx context.Context, args [][]byte) (interface{}, err // Read sends a read command to the primary and optionally performs a shadow read. func (d *DualWriter) Read(ctx context.Context, args [][]byte) (interface{}, error) { - cmd := string(args[0]) + cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -117,7 +118,7 @@ func (d *DualWriter) Read(ctx context.Context, args [][]byte) (interface{}, erro // Blocking forwards a blocking command to the primary only. // Optionally sends a short-timeout version to secondary for warmup. 
func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (interface{}, error) { - cmd := string(args[0]) + cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -148,7 +149,7 @@ func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (interface{}, // Admin forwards an admin command to the primary only. func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (interface{}, error) { - cmd := string(args[0]) + cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -169,7 +170,7 @@ func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (interface{}, err // Script forwards EVAL/EVALSHA to the primary, and async replays to secondary. func (d *DualWriter) Script(ctx context.Context, args [][]byte) (interface{}, error) { - cmd := string(args[0]) + cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -216,7 +217,7 @@ func (d *DualWriter) writeSecondary(cmd string, iArgs []interface{}) { } // goAsync launches fn in a bounded goroutine. If the semaphore is full, -// the work is dropped and a metric is incremented rather than blocking the caller. +// the work is dropped and a warning is logged rather than blocking the caller. 
func (d *DualWriter) goAsync(fn func()) { select { case d.asyncSem <- struct{}{}: diff --git a/proxy/proxy.go b/proxy/proxy.go index 60a00fe2..7907ae8c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -321,22 +321,13 @@ func (p *ProxyServer) execTxn(conn redcon.Conn, state *proxyConnState) { cmds = append(cmds, []interface{}{"EXEC"}) results, err := p.dual.Primary().Pipeline(ctx, cmds) - if err != nil { - // Pipeline exec error — still try to extract EXEC result - if len(results) > 0 { - lastResult := results[len(results)-1] - resp, rErr := lastResult.Result() - writeResponse(conn, resp, rErr) - } else { - writeRedisError(conn, err) - } - } else { - // The EXEC result is the last command - if len(results) > 0 { - lastResult := results[len(results)-1] - resp, rErr := lastResult.Result() - writeResponse(conn, resp, rErr) - } + if len(results) > 0 { + // Write the EXEC result (last command in the pipeline). + lastResult := results[len(results)-1] + resp, rErr := lastResult.Result() + writeResponse(conn, resp, rErr) + } else if err != nil { + writeRedisError(conn, err) } // Async replay to secondary (bounded) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 9f07d1ac..9ba3d72c 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -94,27 +94,34 @@ func (s *pubsubSession) readClientCommands() { if len(cmd.Args) == 0 { continue } - switch strings.ToUpper(string(cmd.Args[0])) { - case cmdSubscribe: - s.handleSubscribe(cmd.Args) - case cmdUnsubscribe: - s.handleUnsub(cmd.Args, false) - case cmdPSubscribe: - s.handlePSubscribe(cmd.Args) - case cmdPUnsubscribe: - s.handleUnsub(cmd.Args, true) - case "PING": - s.handlePing(cmd.Args) - case "QUIT": + if !s.dispatchPubSubCommand(cmd.Args) { return - default: - // In pub/sub mode, Redis only allows (P)SUBSCRIBE, (P)UNSUBSCRIBE, PING, and QUIT. - // Any other command must return an error to avoid clients hanging waiting for a reply. 
- s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") } } } +// dispatchPubSubCommand handles a single command in pub/sub mode. +// Returns false if the session should end (QUIT). +func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { + switch strings.ToUpper(string(args[0])) { + case cmdSubscribe: + s.handleSubscribe(args) + case cmdUnsubscribe: + s.handleUnsub(args, false) + case cmdPSubscribe: + s.handlePSubscribe(args) + case cmdPUnsubscribe: + s.handleUnsub(args, true) + case "PING": + s.handlePing(args) + case "QUIT": + return false + default: + s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + } + return true +} + func (s *pubsubSession) handleSubscribe(args [][]byte) { if len(args) < pubsubMinArgs { s.writeError("ERR wrong number of arguments for 'subscribe'") diff --git a/proxy/sentry.go b/proxy/sentry.go index f39a5ddd..3967da17 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -28,6 +28,9 @@ type SentryReporter struct { // NewSentryReporter initialises Sentry. If dsn is empty, reporting is disabled. func NewSentryReporter(dsn string, environment string, sampleRate float64, logger *slog.Logger) *SentryReporter { + if logger == nil { + logger = slog.Default() + } r := &SentryReporter{ logger: logger, cooldown: defaultReportCooldown, @@ -100,6 +103,10 @@ func (r *SentryReporter) ShouldReport(fingerprint string) bool { delete(r.lastReport, k) } } + // If still at capacity after eviction, skip tracking to prevent unbounded growth. 
+ if len(r.lastReport) >= maxReportEntries { + return true + } } if last, ok := r.lastReport[fingerprint]; ok && now.Sub(last) < r.cooldown { From 4bc099dfd9417c725fcfde6dc886e5b130420859 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 16:49:10 +0900 Subject: [PATCH 04/43] Fix binary data handling, type safety, and Sentry flood protection - Pass []byte directly in bytesArgsToInterfaces instead of converting to string - Handle []byte type assertion in argsToBytes for correct Sentry reporting - Fix responseEqual panic when comparing mismatched interface types - Return false in ShouldReport when at capacity to prevent Sentry flooding --- proxy/compare.go | 21 +++++++++++---------- proxy/dualwrite.go | 6 +++++- proxy/proxy_test.go | 6 +++--- proxy/sentry.go | 4 ++-- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/proxy/compare.go b/proxy/compare.go index 81c4d641..183b4497 100644 --- a/proxy/compare.go +++ b/proxy/compare.go @@ -133,17 +133,18 @@ func isConsistent(primaryResp, secondaryResp interface{}, primaryErr, secondaryE // responseEqual compares two go-redis response values for equality. 
func responseEqual(a, b interface{}) bool { - if a == nil && b == nil { - return true - } if a == nil || b == nil { - return false - } - switch a := a.(type) { - case string, int64: - return a == b + return a == nil && b == nil + } + switch av := a.(type) { + case string: + bv, ok := b.(string) + return ok && av == bv + case int64: + bv, ok := b.(int64) + return ok && av == bv case []interface{}: - return interfaceSliceEqual(a, b) + return interfaceSliceEqual(av, b) default: return reflect.DeepEqual(a, b) } @@ -208,7 +209,7 @@ func extractKey(args [][]byte) string { func bytesArgsToInterfaces(args [][]byte) []interface{} { out := make([]interface{}, len(args)) for i, a := range args { - out[i] = string(a) + out[i] = a } return out } diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 84f736c7..494d5518 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -262,7 +262,11 @@ func (d *DualWriter) Secondary() Backend { func argsToBytes(iArgs []interface{}) [][]byte { out := make([][]byte, len(iArgs)) for i, a := range iArgs { - out[i] = []byte(fmt.Sprintf("%v", a)) + if b, ok := a.([]byte); ok { + out[i] = b + } else { + out[i] = []byte(fmt.Sprintf("%v", a)) + } } return out } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 84e6b8fb..505659e6 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -196,9 +196,9 @@ func TestBytesArgsToInterfaces(t *testing.T) { args := [][]byte{[]byte("SET"), []byte("key"), []byte("val")} result := bytesArgsToInterfaces(args) assert.Len(t, result, 3) - assert.Equal(t, "SET", result[0]) - assert.Equal(t, "key", result[1]) - assert.Equal(t, "val", result[2]) + assert.Equal(t, []byte("SET"), result[0]) + assert.Equal(t, []byte("key"), result[1]) + assert.Equal(t, []byte("val"), result[2]) } func TestResponseEqual(t *testing.T) { diff --git a/proxy/sentry.go b/proxy/sentry.go index 3967da17..82e1ddcd 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -103,9 +103,9 @@ func (r *SentryReporter) 
ShouldReport(fingerprint string) bool { delete(r.lastReport, k) } } - // If still at capacity after eviction, skip tracking to prevent unbounded growth. + // If still at capacity after eviction, drop report to prevent unbounded growth and Sentry flooding. if len(r.lastReport) >= maxReportEntries { - return true + return false } } From 411cf06279137103accfbb6d59d4d3a42dfe1308 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 17:04:01 +0900 Subject: [PATCH 05/43] Support normal command mode after pub/sub unsubscribe When all subscriptions are removed, the detached connection transitions to normal command mode instead of being stuck in pub/sub-only mode. Clients can then execute regular Redis commands (GET, SET, transactions, etc.) or re-enter pub/sub mode with a new SUBSCRIBE/PSUBSCRIBE. Key changes: - Add respWriter interface so writeResponse works with both Conn and DetachedConn - Refactor pubsubSession with commandLoop that handles both pub/sub and normal modes - Support transactions (MULTI/EXEC/DISCARD) in normal mode on detached connections - Track forwardMessages goroutine lifecycle for clean pub/sub mode transitions - Extract command name constants (MULTI, EXEC, DISCARD, PING, QUIT) --- proxy/proxy.go | 71 ++++++----- proxy/pubsub.go | 321 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 341 insertions(+), 51 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 7907ae8c..565992ec 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -15,6 +15,19 @@ import ( // txnCommandsOverhead is the number of extra commands (MULTI + EXEC) wrapping queued commands. const txnCommandsOverhead = 2 +// respWriter is the subset of redcon.Conn and redcon.DetachedConn used for writing RESP responses. +// Both connection types satisfy this interface, enabling shared response-writing logic +// between the main event loop and detached pub/sub sessions. 
+type respWriter interface { + WriteError(msg string) + WriteString(msg string) + WriteBulk(b []byte) + WriteBulkString(msg string) + WriteInt64(num int64) + WriteArray(count int) + WriteNull() +} + // proxyConnState tracks per-connection state (transactions, PubSub). type proxyConnState struct { inTxn bool @@ -143,11 +156,11 @@ func (p *ProxyServer) dispatchCommand(conn redcon.Conn, state *proxyConnState, n func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) { switch name { - case "EXEC": + case cmdExec: p.execTxn(conn, state) - case "DISCARD": + case cmdDiscard: p.discardTxn(conn, state) - case "MULTI": + case cmdMulti: conn.WriteError("ERR MULTI calls can not be nested") default: state.txnQueue = append(state.txnQueue, args) @@ -224,6 +237,7 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args session := &pubsubSession{ dconn: dconn, upstream: upstream, + proxy: p, logger: p.logger, } @@ -253,7 +267,7 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { name := strings.ToUpper(string(args[0])) // Handle PING locally for speed. - if name == "PING" { + if name == cmdPing { if len(args) > 1 { conn.WriteBulk(args[1]) } else { @@ -263,7 +277,7 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { } // Handle QUIT locally. 
- if name == "QUIT" { + if name == cmdQuit { conn.WriteString("OK") conn.Close() return @@ -282,7 +296,7 @@ func (p *ProxyServer) handleScript(conn redcon.Conn, args [][]byte) { func (p *ProxyServer) handleTxnCommand(conn redcon.Conn, state *proxyConnState, name string) { switch name { - case "MULTI": + case cmdMulti: if state.inTxn { conn.WriteError("ERR MULTI calls can not be nested") return @@ -290,13 +304,13 @@ func (p *ProxyServer) handleTxnCommand(conn redcon.Conn, state *proxyConnState, state.inTxn = true state.txnQueue = nil conn.WriteString("OK") - case "EXEC": + case cmdExec: if !state.inTxn { conn.WriteError("ERR EXEC without MULTI") return } p.execTxn(conn, state) - case "DISCARD": + case cmdDiscard: if !state.inTxn { conn.WriteError("ERR DISCARD without MULTI") return @@ -350,57 +364,54 @@ func (p *ProxyServer) discardTxn(conn redcon.Conn, state *proxyConnState) { conn.WriteString("OK") } -// writeResponse handles the common pattern of writing a go-redis response -// to a redcon connection, correctly handling redis.Nil and upstream errors. -func writeResponse(conn redcon.Conn, resp interface{}, err error) { +// writeResponse handles the common pattern of writing a go-redis response, +// correctly handling redis.Nil and upstream errors. +func writeResponse(w respWriter, resp interface{}, err error) { if err != nil { if errors.Is(err, redis.Nil) { - conn.WriteNull() + w.WriteNull() return } - writeRedisError(conn, err) + writeRedisError(w, err) return } - writeRedisValue(conn, resp) + writeRedisValue(w, resp) } // writeRedisError writes an upstream error without double-prefixing. // Redis errors already contain their prefix (e.g. "ERR ...", "WRONGTYPE ..."). -func writeRedisError(conn redcon.Conn, err error) { +func writeRedisError(w respWriter, err error) { msg := err.Error() // go-redis errors are already formatted with prefix; pass through as-is. 
- conn.WriteError(msg) + w.WriteError(msg) } -// writeRedisValue writes a go-redis response value to a redcon connection. -func writeRedisValue(conn redcon.Conn, val interface{}) { +// writeRedisValue writes a go-redis response value to a respWriter. +func writeRedisValue(w respWriter, val interface{}) { if val == nil { - conn.WriteNull() + w.WriteNull() return } switch v := val.(type) { case string: - // go-redis flattens Status and Bulk strings into Go strings. - // Use WriteString (status reply) for known status responses, - // WriteBulkString (bulk reply) for data values. if isStatusResponse(v) { - conn.WriteString(v) + w.WriteString(v) } else { - conn.WriteBulkString(v) + w.WriteBulkString(v) } case int64: - conn.WriteInt64(v) + w.WriteInt64(v) case []interface{}: - conn.WriteArray(len(v)) + w.WriteArray(len(v)) for _, item := range v { - writeRedisValue(conn, item) + writeRedisValue(w, item) } case []byte: - conn.WriteBulk(v) + w.WriteBulk(v) case redis.Error: - conn.WriteError(v.Error()) + w.WriteError(v.Error()) default: - conn.WriteBulkString(fmt.Sprintf("%v", v)) + w.WriteBulkString(fmt.Sprintf("%v", v)) } } diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 9ba3d72c..efd42de1 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "fmt" "log/slog" "strings" "sync" @@ -21,35 +22,64 @@ const ( cmdUnsubscribe = "UNSUBSCRIBE" cmdPSubscribe = "PSUBSCRIBE" cmdPUnsubscribe = "PUNSUBSCRIBE" + cmdMulti = "MULTI" + cmdExec = "EXEC" + cmdDiscard = "DISCARD" + cmdPing = "PING" + cmdQuit = "QUIT" ) -// pubsubSession manages a single client's pub/sub session. +// pubsubSession manages a single client's detached connection. // It bridges a detached redcon connection to an upstream go-redis PubSub connection. +// When all subscriptions are removed, the session transitions to normal command mode, +// enabling the client to execute regular Redis commands without reconnecting. 
type pubsubSession struct { mu sync.Mutex dconn redcon.DetachedConn - upstream *redis.PubSub + upstream *redis.PubSub // nil when not in pub/sub mode + proxy *ProxyServer logger *slog.Logger closed bool // Track subscription counts for RESP replies. channels int patterns int + + // fwdDone is closed when the current forwardMessages goroutine exits. + fwdDone chan struct{} + + // Transaction state for normal command mode. + inTxn bool + txnQueue [][][]byte } -// run starts the forwarding session. It blocks until the client disconnects -// or the upstream closes. +// run starts the session. It blocks until the client disconnects or sends QUIT. func (s *pubsubSession) run() { - defer func() { + defer s.cleanup() + s.startForwarding() + s.commandLoop() +} + +func (s *pubsubSession) cleanup() { + s.mu.Lock() + s.closed = true + if s.upstream != nil { s.upstream.Close() - s.mu.Lock() - s.closed = true - s.mu.Unlock() - s.dconn.Close() - }() + s.upstream = nil + } + s.mu.Unlock() + if s.fwdDone != nil { + <-s.fwdDone + } + s.dconn.Close() +} - go s.forwardMessages() - s.readClientCommands() +func (s *pubsubSession) startForwarding() { + s.fwdDone = make(chan struct{}) + go func() { + defer close(s.fwdDone) + s.forwardMessages() + }() } // forwardMessages reads from the upstream go-redis PubSub channel and writes @@ -82,10 +112,10 @@ func (s *pubsubSession) forwardMessages() { } } -// readClientCommands reads commands from the detached client connection. -// In pub/sub mode, only SUBSCRIBE, UNSUBSCRIBE, PSUBSCRIBE, PUNSUBSCRIBE, -// PING, and QUIT are valid. -func (s *pubsubSession) readClientCommands() { +// commandLoop reads commands from the detached client and dispatches them. +// In pub/sub mode, only subscription commands are allowed. +// When all subscriptions are removed, it transitions to normal command mode. 
+func (s *pubsubSession) commandLoop() { for { cmd, err := s.dconn.ReadCommand() if err != nil { @@ -94,12 +124,45 @@ func (s *pubsubSession) readClientCommands() { if len(cmd.Args) == 0 { continue } - if !s.dispatchPubSubCommand(cmd.Args) { + args := cloneArgs(cmd.Args) + name := strings.ToUpper(string(args[0])) + + s.mu.Lock() + inPubSub := s.channels > 0 || s.patterns > 0 + s.mu.Unlock() + + if inPubSub { + if !s.dispatchPubSubCommand(args) { + return + } + if s.shouldExitPubSub() { + s.exitPubSubMode() + } + } else if !s.dispatchNormalCommand(name, args) { return } } } +func (s *pubsubSession) shouldExitPubSub() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.upstream != nil && s.channels == 0 && s.patterns == 0 +} + +func (s *pubsubSession) exitPubSubMode() { + s.mu.Lock() + if s.upstream != nil { + s.upstream.Close() + s.upstream = nil + } + s.mu.Unlock() + if s.fwdDone != nil { + <-s.fwdDone + s.fwdDone = nil + } +} + // dispatchPubSubCommand handles a single command in pub/sub mode. // Returns false if the session should end (QUIT). func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { @@ -112,9 +175,9 @@ func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { s.handlePSubscribe(args) case cmdPUnsubscribe: s.handleUnsub(args, true) - case "PING": - s.handlePing(args) - case "QUIT": + case cmdPing: + s.handlePubSubPing(args) + case cmdQuit: return false default: s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") @@ -122,6 +185,190 @@ func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { return true } +// dispatchNormalCommand handles a command in normal (non-pub/sub) mode. +// Returns false if the session should end (QUIT). 
+func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool { + if name == cmdQuit { + s.writeString("OK") + return false + } + if name == cmdPing { + s.handleNormalPing(args) + return true + } + if name == cmdSubscribe || name == cmdPSubscribe { + s.reenterPubSub(name, args) + return true + } + if name == cmdUnsubscribe || name == cmdPUnsubscribe { + s.handleUnsubNoSession(name) + return true + } + if s.handleTxnInSession(name, args) { + return true + } + s.dispatchRegularCommand(name, args) + return true +} + +// handleTxnInSession handles transaction commands in normal mode. +// Returns true if the command was handled as a transaction command. +func (s *pubsubSession) handleTxnInSession(name string, args [][]byte) bool { + switch name { + case cmdMulti: + if s.inTxn { + s.writeError("ERR MULTI calls can not be nested") + } else { + s.inTxn = true + s.txnQueue = nil + s.writeString("OK") + } + return true + case cmdExec: + if !s.inTxn { + s.writeError("ERR EXEC without MULTI") + } else { + s.execTxn() + } + return true + case cmdDiscard: + if !s.inTxn { + s.writeError("ERR DISCARD without MULTI") + } else { + s.inTxn = false + s.txnQueue = nil + s.writeString("OK") + } + return true + } + if s.inTxn { + s.txnQueue = append(s.txnQueue, args) + s.writeString("QUEUED") + return true + } + return false +} + +// dispatchRegularCommand sends a non-transaction, non-special command to the backend. 
+func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { + cat := ClassifyCommand(name, args[1:]) + ctx := context.Background() + + var resp interface{} + var err error + + switch cat { + case CmdWrite: + resp, err = s.proxy.dual.Write(ctx, args) + case CmdRead: + resp, err = s.proxy.dual.Read(ctx, args) + case CmdBlocking: + resp, err = s.proxy.dual.Blocking(s.proxy.shutdownCtx, args) + case CmdPubSub: + resp, err = s.proxy.dual.Admin(ctx, args) + case CmdAdmin: + resp, err = s.proxy.dual.Admin(ctx, args) + case CmdScript: + resp, err = s.proxy.dual.Script(ctx, args) + case CmdTxn: + // Handled by handleTxnInSession; should not reach here. + return + } + + s.mu.Lock() + writeResponse(s.dconn, resp, err) + _ = s.dconn.Flush() + s.mu.Unlock() +} + +func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { + if len(args) < pubsubMinArgs { + s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", strings.ToLower(cmdName))) + return + } + psBackend := s.proxy.dual.PubSubBackend() + if psBackend == nil { + s.writeError("ERR PubSub not supported by backend") + return + } + + upstream := psBackend.NewPubSub(context.Background()) + channels := byteSlicesToStrings(args[1:]) + var err error + if cmdName == cmdSubscribe { + err = upstream.Subscribe(context.Background(), channels...) + } else { + err = upstream.PSubscribe(context.Background(), channels...) 
+ } + if err != nil { + upstream.Close() + s.writeError("ERR " + err.Error()) + return + } + + s.mu.Lock() + s.upstream = upstream + s.mu.Unlock() + s.startForwarding() + + kind := strings.ToLower(cmdName) + s.mu.Lock() + for _, ch := range channels { + if cmdName == cmdSubscribe { + s.channels++ + } else { + s.patterns++ + } + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(ch) + s.dconn.WriteInt(s.channels + s.patterns) + } + _ = s.dconn.Flush() + s.mu.Unlock() +} + +func (s *pubsubSession) execTxn() { + queue := s.txnQueue + s.inTxn = false + s.txnQueue = nil + + ctx := context.Background() + cmds := make([][]interface{}, 0, len(queue)+txnCommandsOverhead) + cmds = append(cmds, []interface{}{"MULTI"}) + for _, args := range queue { + cmds = append(cmds, bytesArgsToInterfaces(args)) + } + cmds = append(cmds, []interface{}{"EXEC"}) + + results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) + + s.mu.Lock() + if len(results) > 0 { + lastResult := results[len(results)-1] + resp, rErr := lastResult.Result() + writeResponse(s.dconn, resp, rErr) + } else if err != nil { + writeRedisError(s.dconn, err) + } + _ = s.dconn.Flush() + s.mu.Unlock() + + if s.proxy.dual.hasSecondaryWrite() { + s.proxy.dual.goAsync(func() { + sCtx, cancel := context.WithTimeout(context.Background(), s.proxy.cfg.SecondaryTimeout) + defer cancel() + _, pErr := s.proxy.dual.Secondary().Pipeline(sCtx, cmds) + if pErr != nil { + s.proxy.logger.Warn("secondary txn replay failed", "err", pErr) + s.proxy.metrics.SecondaryWriteErrors.Inc() + } + }) + } +} + +// --- Subscription handlers --- + func (s *pubsubSession) handleSubscribe(args [][]byte) { if len(args) < pubsubMinArgs { s.writeError("ERR wrong number of arguments for 'subscribe'") @@ -220,7 +467,9 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.mu.Unlock() } -func (s *pubsubSession) handlePing(args [][]byte) { +// --- Ping handlers --- + +func (s *pubsubSession) 
handlePubSubPing(args [][]byte) { s.mu.Lock() defer s.mu.Unlock() s.dconn.WriteArray(pubsubArrayPong) @@ -233,6 +482,29 @@ func (s *pubsubSession) handlePing(args [][]byte) { _ = s.dconn.Flush() } +func (s *pubsubSession) handleNormalPing(args [][]byte) { + s.mu.Lock() + defer s.mu.Unlock() + if len(args) > 1 { + s.dconn.WriteBulk(args[1]) + } else { + s.dconn.WriteString("PONG") + } + _ = s.dconn.Flush() +} + +func (s *pubsubSession) handleUnsubNoSession(cmdName string) { + s.mu.Lock() + defer s.mu.Unlock() + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(strings.ToLower(cmdName)) + s.dconn.WriteNull() + s.dconn.WriteInt64(0) + _ = s.dconn.Flush() +} + +// --- Helpers --- + func (s *pubsubSession) writeError(msg string) { s.mu.Lock() defer s.mu.Unlock() @@ -240,6 +512,13 @@ func (s *pubsubSession) writeError(msg string) { _ = s.dconn.Flush() } +func (s *pubsubSession) writeString(msg string) { + s.mu.Lock() + defer s.mu.Unlock() + s.dconn.WriteString(msg) + _ = s.dconn.Flush() +} + func byteSlicesToStrings(bs [][]byte) []string { out := make([]string, len(bs)) for i, b := range bs { From 7bc0300e50d7fa54ad5d88d1c6e8cfc4a5c05ae4 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 17:51:55 +0900 Subject: [PATCH 06/43] Replace interface{} with any in codebase --- proxy/backend.go | 8 ++++---- proxy/compare.go | 26 +++++++++++++------------- proxy/dualwrite.go | 16 ++++++++-------- proxy/proxy.go | 12 ++++++------ proxy/proxy_test.go | 26 +++++++++++++------------- proxy/pubsub.go | 8 ++++---- 6 files changed, 48 insertions(+), 48 deletions(-) diff --git a/proxy/backend.go b/proxy/backend.go index 75ebd695..2004f158 100644 --- a/proxy/backend.go +++ b/proxy/backend.go @@ -18,9 +18,9 @@ const ( // Backend abstracts a Redis-protocol endpoint (real Redis or ElasticKV). type Backend interface { // Do sends a single command and returns its result. 
- Do(ctx context.Context, args ...interface{}) *redis.Cmd + Do(ctx context.Context, args ...any) *redis.Cmd // Pipeline sends multiple commands in a pipeline. - Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) + Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) // Close releases the underlying connection. Close() error // Name identifies this backend for logging and metrics. @@ -76,11 +76,11 @@ func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) * } } -func (b *RedisBackend) Do(ctx context.Context, args ...interface{}) *redis.Cmd { +func (b *RedisBackend) Do(ctx context.Context, args ...any) *redis.Cmd { return b.client.Do(ctx, args...) } -func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) { +func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) { pipe := b.client.Pipeline() results := make([]*redis.Cmd, len(cmds)) for i, args := range cmds { diff --git a/proxy/compare.go b/proxy/compare.go index 183b4497..6b8e39d7 100644 --- a/proxy/compare.go +++ b/proxy/compare.go @@ -44,8 +44,8 @@ type Divergence struct { Command string Key string Kind DivergenceKind - Primary interface{} - Secondary interface{} + Primary any + Secondary any DetectedAt time.Time } @@ -75,7 +75,7 @@ func NewShadowReader(secondary Backend, metrics *ProxyMetrics, sentryReporter *S } // Compare issues the same read to the secondary and checks for divergence. -func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, primaryResp interface{}, primaryErr error) { +func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, primaryResp any, primaryErr error) { sCtx, cancel := context.WithTimeout(ctx, s.timeout) defer cancel() @@ -118,7 +118,7 @@ func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, p } // isConsistent checks whether primary and secondary responses agree. 
-func isConsistent(primaryResp, secondaryResp interface{}, primaryErr, secondaryErr error) bool { +func isConsistent(primaryResp, secondaryResp any, primaryErr, secondaryErr error) bool { // Both are redis.Nil → consistent (key missing on both) if isNilError(primaryErr) && isNilError(secondaryErr) { return true @@ -132,7 +132,7 @@ func isConsistent(primaryResp, secondaryResp interface{}, primaryErr, secondaryE } // responseEqual compares two go-redis response values for equality. -func responseEqual(a, b interface{}) bool { +func responseEqual(a, b any) bool { if a == nil || b == nil { return a == nil && b == nil } @@ -143,7 +143,7 @@ func responseEqual(a, b interface{}) bool { case int64: bv, ok := b.(int64) return ok && av == bv - case []interface{}: + case []any: return interfaceSliceEqual(av, b) default: return reflect.DeepEqual(a, b) @@ -151,8 +151,8 @@ func responseEqual(a, b interface{}) bool { } // interfaceSliceEqual compares two []interface{} slices element-by-element. -func interfaceSliceEqual(av []interface{}, b interface{}) bool { - bv, ok := b.([]interface{}) +func interfaceSliceEqual(av []any, b any) bool { + bv, ok := b.([]any) if !ok || len(av) != len(bv) { return false } @@ -165,7 +165,7 @@ func interfaceSliceEqual(av []interface{}, b interface{}) bool { } // classifyDivergence determines the kind based on primary/secondary values. -func classifyDivergence(primaryResp interface{}, primaryErr error, secondaryResp interface{}, secondaryErr error) DivergenceKind { +func classifyDivergence(primaryResp any, primaryErr error, secondaryResp any, secondaryErr error) DivergenceKind { primaryNil := isNilResp(primaryResp, primaryErr) secondaryNil := isNilResp(secondaryResp, secondaryErr) @@ -185,14 +185,14 @@ func isNilError(err error) bool { // isNilResp checks if a response represents "no data" (nil response or redis.Nil error). // Empty string is NOT nil — it is a valid value. 
-func isNilResp(resp interface{}, err error) bool { +func isNilResp(resp any, err error) bool { if errors.Is(err, redis.Nil) { return true } return resp == nil } -func formatResp(resp interface{}, err error) interface{} { +func formatResp(resp any, err error) any { if err != nil { return fmt.Sprintf("error: %v", err) } @@ -206,8 +206,8 @@ func extractKey(args [][]byte) string { return "" } -func bytesArgsToInterfaces(args [][]byte) []interface{} { - out := make([]interface{}, len(args)) +func bytesArgsToInterfaces(args [][]byte) []any { + out := make([]any, len(args)) for i, a := range args { out[i] = a } diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 494d5518..9212c7f2 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -54,7 +54,7 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe } // Write sends a write command to the primary synchronously, then to the secondary asynchronously. -func (d *DualWriter) Write(ctx context.Context, args [][]byte) (interface{}, error) { +func (d *DualWriter) Write(ctx context.Context, args [][]byte) (any, error) { cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) @@ -83,7 +83,7 @@ func (d *DualWriter) Write(ctx context.Context, args [][]byte) (interface{}, err } // Read sends a read command to the primary and optionally performs a shadow read. -func (d *DualWriter) Read(ctx context.Context, args [][]byte) (interface{}, error) { +func (d *DualWriter) Read(ctx context.Context, args [][]byte) (any, error) { cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) @@ -117,7 +117,7 @@ func (d *DualWriter) Read(ctx context.Context, args [][]byte) (interface{}, erro // Blocking forwards a blocking command to the primary only. // Optionally sends a short-timeout version to secondary for warmup. 
-func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (interface{}, error) { +func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (any, error) { cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) @@ -148,7 +148,7 @@ func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (interface{}, } // Admin forwards an admin command to the primary only. -func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (interface{}, error) { +func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (any, error) { cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) @@ -169,7 +169,7 @@ func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (interface{}, err } // Script forwards EVAL/EVALSHA to the primary, and async replays to secondary. -func (d *DualWriter) Script(ctx context.Context, args [][]byte) (interface{}, error) { +func (d *DualWriter) Script(ctx context.Context, args [][]byte) (any, error) { cmd := strings.ToUpper(string(args[0])) iArgs := bytesArgsToInterfaces(args) @@ -194,7 +194,7 @@ func (d *DualWriter) Script(ctx context.Context, args [][]byte) (interface{}, er return resp, nil } -func (d *DualWriter) writeSecondary(cmd string, iArgs []interface{}) { +func (d *DualWriter) writeSecondary(cmd string, iArgs []any) { sCtx, cancel := context.WithTimeout(context.Background(), d.cfg.SecondaryTimeout) defer cancel() @@ -259,13 +259,13 @@ func (d *DualWriter) Secondary() Backend { return d.secondary } -func argsToBytes(iArgs []interface{}) [][]byte { +func argsToBytes(iArgs []any) [][]byte { out := make([][]byte, len(iArgs)) for i, a := range iArgs { if b, ok := a.([]byte); ok { out[i] = b } else { - out[i] = []byte(fmt.Sprintf("%v", a)) + out[i] = fmt.Appendf(nil, "%v", a) } } return out diff --git a/proxy/proxy.go b/proxy/proxy.go index 565992ec..5881223f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -327,12 +327,12 @@ func (p *ProxyServer) execTxn(conn redcon.Conn, 
state *proxyConnState) { ctx := context.Background() // Build pipeline: MULTI + queued commands + EXEC - cmds := make([][]interface{}, 0, len(queue)+txnCommandsOverhead) - cmds = append(cmds, []interface{}{"MULTI"}) + cmds := make([][]any, 0, len(queue)+txnCommandsOverhead) + cmds = append(cmds, []any{"MULTI"}) for _, args := range queue { cmds = append(cmds, bytesArgsToInterfaces(args)) } - cmds = append(cmds, []interface{}{"EXEC"}) + cmds = append(cmds, []any{"EXEC"}) results, err := p.dual.Primary().Pipeline(ctx, cmds) if len(results) > 0 { @@ -366,7 +366,7 @@ func (p *ProxyServer) discardTxn(conn redcon.Conn, state *proxyConnState) { // writeResponse handles the common pattern of writing a go-redis response, // correctly handling redis.Nil and upstream errors. -func writeResponse(w respWriter, resp interface{}, err error) { +func writeResponse(w respWriter, resp any, err error) { if err != nil { if errors.Is(err, redis.Nil) { w.WriteNull() @@ -387,7 +387,7 @@ func writeRedisError(w respWriter, err error) { } // writeRedisValue writes a go-redis response value to a respWriter. 
-func writeRedisValue(w respWriter, val interface{}) { +func writeRedisValue(w respWriter, val any) { if val == nil { w.WriteNull() return @@ -401,7 +401,7 @@ func writeRedisValue(w respWriter, val interface{}) { } case int64: w.WriteInt64(v) - case []interface{}: + case []any: w.WriteArray(len(v)) for _, item := range v { writeRedisValue(w, item) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 505659e6..6804c9f0 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -21,16 +21,16 @@ var testLogger = slog.New(slog.NewTextHandler(io.Discard, nil)) type mockBackend struct { name string - doFunc func(ctx context.Context, args ...interface{}) *redis.Cmd + doFunc func(ctx context.Context, args ...any) *redis.Cmd mu sync.Mutex - calls [][]interface{} + calls [][]any } func newMockBackend(name string) *mockBackend { return &mockBackend{name: name} } -func (b *mockBackend) Do(ctx context.Context, args ...interface{}) *redis.Cmd { +func (b *mockBackend) Do(ctx context.Context, args ...any) *redis.Cmd { b.mu.Lock() b.calls = append(b.calls, args) b.mu.Unlock() @@ -42,7 +42,7 @@ func (b *mockBackend) Do(ctx context.Context, args ...interface{}) *redis.Cmd { return cmd } -func (b *mockBackend) Pipeline(ctx context.Context, cmds [][]interface{}) ([]*redis.Cmd, error) { +func (b *mockBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) { results := make([]*redis.Cmd, len(cmds)) for i, args := range cmds { results[i] = b.Do(ctx, args...) @@ -60,8 +60,8 @@ func (b *mockBackend) CallCount() int { } // Helper to create a doFunc that returns a specific value. -func makeCmd(val interface{}, err error) func(ctx context.Context, args ...interface{}) *redis.Cmd { - return func(ctx context.Context, args ...interface{}) *redis.Cmd { +func makeCmd(val any, err error) func(ctx context.Context, args ...any) *redis.Cmd { + return func(ctx context.Context, args ...any) *redis.Cmd { cmd := redis.NewCmd(ctx, args...) 
if err != nil { cmd.SetErr(err) @@ -204,7 +204,7 @@ func TestBytesArgsToInterfaces(t *testing.T) { func TestResponseEqual(t *testing.T) { tests := []struct { name string - a, b interface{} + a, b any want bool }{ {"nil nil", nil, nil, true}, @@ -215,13 +215,13 @@ func TestResponseEqual(t *testing.T) { {"empty string equals empty string", "", "", true}, {"same int64", int64(42), int64(42), true}, {"diff int64", int64(42), int64(43), false}, - {"same array", []interface{}{"a", "b"}, []interface{}{"a", "b"}, true}, - {"diff array values", []interface{}{"a", "b"}, []interface{}{"a", "c"}, false}, - {"diff array length", []interface{}{"a"}, []interface{}{"a", "b"}, false}, - {"empty arrays", []interface{}{}, []interface{}{}, true}, + {"same array", []any{"a", "b"}, []any{"a", "b"}, true}, + {"diff array values", []any{"a", "b"}, []any{"a", "c"}, false}, + {"diff array length", []any{"a"}, []any{"a", "b"}, false}, + {"empty arrays", []any{}, []any{}, true}, {"same bytes", []byte("data"), []byte("data"), true}, {"diff bytes", []byte("data"), []byte("other"), false}, - {"nested array", []interface{}{[]interface{}{"x"}}, []interface{}{[]interface{}{"x"}}, true}, + {"nested array", []any{[]any{"x"}}, []any{[]any{"x"}}, true}, {"type mismatch string vs int", "42", int64(42), false}, } @@ -235,7 +235,7 @@ func TestResponseEqual(t *testing.T) { func TestClassifyDivergence(t *testing.T) { tests := []struct { name string - primaryResp, secondaryResp interface{} + primaryResp, secondaryResp any primaryErr, secondaryErr error want DivergenceKind }{ diff --git a/proxy/pubsub.go b/proxy/pubsub.go index efd42de1..93078a6e 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -254,7 +254,7 @@ func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { cat := ClassifyCommand(name, args[1:]) ctx := context.Background() - var resp interface{} + var resp any var err error switch cat { @@ -334,12 +334,12 @@ func (s *pubsubSession) execTxn() { s.txnQueue = nil ctx := 
context.Background() - cmds := make([][]interface{}, 0, len(queue)+txnCommandsOverhead) - cmds = append(cmds, []interface{}{"MULTI"}) + cmds := make([][]any, 0, len(queue)+txnCommandsOverhead) + cmds = append(cmds, []any{"MULTI"}) for _, args := range queue { cmds = append(cmds, bytesArgsToInterfaces(args)) } - cmds = append(cmds, []interface{}{"EXEC"}) + cmds = append(cmds, []any{"EXEC"}) results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) From 74e0f44016c5115460f66a09f9adf60cb7fe946d Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 18:01:52 +0900 Subject: [PATCH 07/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/pubsub.go | 1 + 1 file changed, 1 insertion(+) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 93078a6e..bff50614 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -178,6 +178,7 @@ func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { case cmdPing: s.handlePubSubPing(args) case cmdQuit: + s.writeString("OK") return false default: s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") From 534a39bc801e765e934b707e2edc9d324bd3aba3 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 19:20:41 +0900 Subject: [PATCH 08/43] Fix race condition, error normalization, and security concerns - Fix startForwarding race: capture upstream under lock before goroutine start - Reject SUBSCRIBE/PSUBSCRIBE during MULTI transaction to prevent state corruption - Normalize non-redis.Error to "ERR ..." 
prefix in writeRedisError for valid RESP - Truncate divergence values in Sentry reports to prevent data leakage --- proxy/proxy.go | 14 +++++++++----- proxy/pubsub.go | 17 ++++++++++++++--- proxy/sentry.go | 16 ++++++++++++++-- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 5881223f..19c3c84b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -378,12 +378,16 @@ func writeResponse(w respWriter, resp any, err error) { writeRedisValue(w, resp) } -// writeRedisError writes an upstream error without double-prefixing. -// Redis errors already contain their prefix (e.g. "ERR ...", "WRONGTYPE ..."). +// writeRedisError writes an upstream error to the client. +// go-redis redis.Error values already carry the Redis prefix (e.g. "ERR ...", "WRONGTYPE ..."). +// Other errors (timeouts, dial failures) are normalized to "ERR ..." to produce valid RESP. func writeRedisError(w respWriter, err error) { - msg := err.Error() - // go-redis errors are already formatted with prefix; pass through as-is. - w.WriteError(msg) + var redisErr redis.Error + if errors.As(err, &redisErr) { + w.WriteError(redisErr.Error()) + return + } + w.WriteError("ERR " + err.Error()) } // writeRedisValue writes a go-redis response value to a respWriter. diff --git a/proxy/pubsub.go b/proxy/pubsub.go index bff50614..adc91240 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -75,17 +75,24 @@ func (s *pubsubSession) cleanup() { } func (s *pubsubSession) startForwarding() { + // Capture upstream under lock to avoid race with exitPubSubMode. + s.mu.Lock() + upstream := s.upstream + s.mu.Unlock() + if upstream == nil { + return + } + ch := upstream.Channel() s.fwdDone = make(chan struct{}) go func() { defer close(s.fwdDone) - s.forwardMessages() + s.forwardMessages(ch) }() } // forwardMessages reads from the upstream go-redis PubSub channel and writes // messages to the detached client connection. 
-func (s *pubsubSession) forwardMessages() { - ch := s.upstream.Channel() +func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { for msg := range ch { s.mu.Lock() if s.closed { @@ -198,6 +205,10 @@ func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool { return true } if name == cmdSubscribe || name == cmdPSubscribe { + if s.inTxn { + s.writeError("ERR Command not allowed inside a transaction") + return true + } s.reenterPubSub(name, args) return true } diff --git a/proxy/sentry.go b/proxy/sentry.go index 82e1ddcd..f964d87d 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -11,6 +11,9 @@ import ( const ( defaultReportCooldown = 60 * time.Second + // maxSentryValueLen limits the length of values attached to Sentry events + // to prevent data leakage and oversized events. + maxSentryValueLen = 256 // maxReportEntries caps the lastReport map to prevent unbounded growth. maxReportEntries = 10000 ) @@ -80,8 +83,8 @@ func (r *SentryReporter) CaptureDivergence(div Divergence) { scope.SetTag("command", div.Command) scope.SetTag("key", div.Key) scope.SetTag("kind", div.Kind.String()) - scope.SetExtra("primary", fmt.Sprintf("%v", div.Primary)) - scope.SetExtra("secondary", fmt.Sprintf("%v", div.Secondary)) + scope.SetExtra("primary", truncateValue(div.Primary)) + scope.SetExtra("secondary", truncateValue(div.Secondary)) scope.SetFingerprint([]string{"divergence", div.Kind.String(), div.Command}) scope.SetLevel(sentry.LevelWarning) r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s (%s)", div.Kind, div.Command, div.Key)) @@ -129,3 +132,12 @@ func cmdNameFromArgs(args [][]byte) string { } return unknownStr } + +// truncateValue formats a value for Sentry, truncating to avoid data leakage and oversized events. 
+func truncateValue(v any) string { + s := fmt.Sprintf("%v", v) + if len(s) > maxSentryValueLen { + return s[:maxSentryValueLen] + "...(truncated)" + } + return s +} From c005cf08ae46fcde104cc3f4e405f6bc6230abb1 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 19:23:56 +0900 Subject: [PATCH 09/43] Add DB/password config for backends and ignore SELECT/AUTH commands - Add PrimaryDB, PrimaryPassword, SecondaryDB, SecondaryPassword to ProxyConfig - Add DB and Password fields to BackendOptions, passed to go-redis client - Add CLI flags: -primary-db, -primary-password, -secondary-db, -secondary-password - Silently accept SELECT and AUTH commands (return OK without forwarding) since DB/auth are configured at the connection-pool level --- cmd/redis-proxy/main.go | 19 +++++++++++++++---- proxy/backend.go | 4 ++++ proxy/config.go | 24 ++++++++++++++---------- proxy/proxy.go | 7 +++++++ proxy/pubsub.go | 6 ++++++ 5 files changed, 46 insertions(+), 14 deletions(-) diff --git a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go index d6c39bf9..165c4250 100644 --- a/cmd/redis-proxy/main.go +++ b/cmd/redis-proxy/main.go @@ -32,7 +32,11 @@ func run() error { flag.StringVar(&cfg.ListenAddr, "listen", cfg.ListenAddr, "Proxy listen address") flag.StringVar(&cfg.PrimaryAddr, "primary", cfg.PrimaryAddr, "Primary (Redis) address") + flag.IntVar(&cfg.PrimaryDB, "primary-db", cfg.PrimaryDB, "Primary Redis DB number") + flag.StringVar(&cfg.PrimaryPassword, "primary-password", cfg.PrimaryPassword, "Primary Redis password") flag.StringVar(&cfg.SecondaryAddr, "secondary", cfg.SecondaryAddr, "Secondary (ElasticKV) address") + flag.IntVar(&cfg.SecondaryDB, "secondary-db", cfg.SecondaryDB, "Secondary Redis DB number") + flag.StringVar(&cfg.SecondaryPassword, "secondary-password", cfg.SecondaryPassword, "Secondary Redis password") flag.StringVar(&modeStr, "mode", "dual-write", "Proxy mode: redis-only, dual-write, dual-write-shadow, elastickv-primary, elastickv-only") 
flag.DurationVar(&cfg.SecondaryTimeout, "secondary-timeout", cfg.SecondaryTimeout, "Secondary write timeout") flag.DurationVar(&cfg.ShadowTimeout, "shadow-timeout", cfg.ShadowTimeout, "Shadow read timeout") @@ -59,14 +63,21 @@ func run() error { metrics := proxy.NewProxyMetrics(reg) // Backends + primaryOpts := proxy.DefaultBackendOptions() + primaryOpts.DB = cfg.PrimaryDB + primaryOpts.Password = cfg.PrimaryPassword + secondaryOpts := proxy.DefaultBackendOptions() + secondaryOpts.DB = cfg.SecondaryDB + secondaryOpts.Password = cfg.SecondaryPassword + var primary, secondary proxy.Backend switch cfg.Mode { case proxy.ModeElasticKVPrimary, proxy.ModeElasticKVOnly: - primary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") - secondary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + primary = proxy.NewRedisBackendWithOptions(cfg.SecondaryAddr, "elastickv", secondaryOpts) + secondary = proxy.NewRedisBackendWithOptions(cfg.PrimaryAddr, "redis", primaryOpts) case proxy.ModeRedisOnly, proxy.ModeDualWrite, proxy.ModeDualWriteShadow: - primary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") - secondary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + primary = proxy.NewRedisBackendWithOptions(cfg.PrimaryAddr, "redis", primaryOpts) + secondary = proxy.NewRedisBackendWithOptions(cfg.SecondaryAddr, "elastickv", secondaryOpts) } defer primary.Close() defer secondary.Close() diff --git a/proxy/backend.go b/proxy/backend.go index 2004f158..6ca88d66 100644 --- a/proxy/backend.go +++ b/proxy/backend.go @@ -29,6 +29,8 @@ type Backend interface { // BackendOptions configures the underlying go-redis connection pool. 
type BackendOptions struct { + DB int + Password string PoolSize int DialTimeout time.Duration ReadTimeout time.Duration @@ -67,6 +69,8 @@ func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) * return &RedisBackend{ client: redis.NewClient(&redis.Options{ Addr: addr, + DB: opts.DB, + Password: opts.Password, PoolSize: opts.PoolSize, DialTimeout: opts.DialTimeout, ReadTimeout: opts.ReadTimeout, diff --git a/proxy/config.go b/proxy/config.go index 267f67eb..395308e9 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -49,16 +49,20 @@ func (m ProxyMode) String() string { // ProxyConfig holds all configuration for the dual-write proxy. type ProxyConfig struct { - ListenAddr string - PrimaryAddr string - SecondaryAddr string - Mode ProxyMode - SecondaryTimeout time.Duration - ShadowTimeout time.Duration - SentryDSN string - SentryEnv string - SentrySampleRate float64 - MetricsAddr string + ListenAddr string + PrimaryAddr string + PrimaryDB int + PrimaryPassword string + SecondaryAddr string + SecondaryDB int + SecondaryPassword string + Mode ProxyMode + SecondaryTimeout time.Duration + ShadowTimeout time.Duration + SentryDSN string + SentryEnv string + SentrySampleRate float64 + MetricsAddr string } // DefaultConfig returns a ProxyConfig with sensible defaults. diff --git a/proxy/proxy.go b/proxy/proxy.go index 19c3c84b..e960d37a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -283,6 +283,13 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { return } + // SELECT and AUTH are handled at the connection-pool level via config. + // Silently accept them so clients don't break. 
+ if name == "SELECT" || name == "AUTH" { + conn.WriteString("OK") + return + } + resp, err := p.dual.Admin(context.Background(), args) writeResponse(conn, resp, err) } diff --git a/proxy/pubsub.go b/proxy/pubsub.go index adc91240..3107a45c 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -263,6 +263,12 @@ func (s *pubsubSession) handleTxnInSession(name string, args [][]byte) bool { // dispatchRegularCommand sends a non-transaction, non-special command to the backend. func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { + // SELECT and AUTH are handled at the connection-pool level; accept silently. + if name == "SELECT" || name == "AUTH" { + s.writeString("OK") + return + } + cat := ClassifyCommand(name, args[1:]) ctx := context.Background() From 430891218d01eded20f7029a31ab56ffdea61da1 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 23:02:32 +0900 Subject: [PATCH 10/43] Enhance pubsub with idempotent subscription handling --- proxy/proxy.go | 22 +- proxy/pubsub.go | 93 +++++--- proxy/pubsub_test.go | 554 +++++++++++++++++++++++++++++++++++++++++++ proxy/sentry.go | 6 +- 4 files changed, 636 insertions(+), 39 deletions(-) create mode 100644 proxy/pubsub_test.go diff --git a/proxy/proxy.go b/proxy/proxy.go index e960d37a..ae67c171 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -163,6 +163,10 @@ func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnStat case cmdMulti: conn.WriteError("ERR MULTI calls can not be nested") default: + // NOTE: Commands are queued locally and always return QUEUED without + // upstream validation. Real Redis validates queued commands immediately + // (e.g., wrong arity returns an error). Full compatibility would require + // pinning a dedicated upstream connection for the MULTI..EXEC lifetime. 
state.txnQueue = append(state.txnQueue, args) conn.WriteString("QUEUED") } @@ -235,24 +239,26 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args dconn := conn.Detach() session := &pubsubSession{ - dconn: dconn, - upstream: upstream, - proxy: p, - logger: p.logger, + dconn: dconn, + upstream: upstream, + proxy: p, + logger: p.logger, + channelSet: make(map[string]struct{}), + patternSet: make(map[string]struct{}), } // Write initial subscription confirmations. kind := strings.ToLower(cmdName) - for i, ch := range channels { + for _, ch := range channels { dconn.WriteArray(pubsubArrayReply) dconn.WriteBulkString(kind) dconn.WriteBulkString(ch) if cmdName == cmdSubscribe { - session.channels = i + 1 + session.channelSet[ch] = struct{}{} } else { - session.patterns = i + 1 + session.patternSet[ch] = struct{}{} } - dconn.WriteInt(session.channels + session.patterns) + dconn.WriteInt(session.subCount()) } if err := dconn.Flush(); err != nil { dconn.Close() diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 3107a45c..7070e427 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -41,9 +41,10 @@ type pubsubSession struct { logger *slog.Logger closed bool - // Track subscription counts for RESP replies. - channels int - patterns int + // Track subscribed channels/patterns in sets for idempotent subscribe/unsubscribe + // and correct subscription count tracking (matching Redis behavior). + channelSet map[string]struct{} + patternSet map[string]struct{} // fwdDone is closed when the current forwardMessages goroutine exits. fwdDone chan struct{} @@ -53,6 +54,11 @@ type pubsubSession struct { txnQueue [][][]byte } +// subCount returns the total number of active subscriptions (channels + patterns). +func (s *pubsubSession) subCount() int { + return len(s.channelSet) + len(s.patternSet) +} + // run starts the session. It blocks until the client disconnects or sends QUIT. 
func (s *pubsubSession) run() { defer s.cleanup() @@ -135,7 +141,7 @@ func (s *pubsubSession) commandLoop() { name := strings.ToUpper(string(args[0])) s.mu.Lock() - inPubSub := s.channels > 0 || s.patterns > 0 + inPubSub := s.subCount() > 0 s.mu.Unlock() if inPubSub { @@ -154,7 +160,7 @@ func (s *pubsubSession) commandLoop() { func (s *pubsubSession) shouldExitPubSub() bool { s.mu.Lock() defer s.mu.Unlock() - return s.upstream != nil && s.channels == 0 && s.patterns == 0 + return s.upstream != nil && s.subCount() == 0 } func (s *pubsubSession) exitPubSubMode() { @@ -333,14 +339,14 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { s.mu.Lock() for _, ch := range channels { if cmdName == cmdSubscribe { - s.channels++ + s.channelSet[ch] = struct{}{} } else { - s.patterns++ + s.patternSet[ch] = struct{}{} } s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(ch) - s.dconn.WriteInt(s.channels + s.patterns) + s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() s.mu.Unlock() @@ -399,11 +405,12 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) { } s.mu.Lock() for _, ch := range channels { - s.channels++ + // Idempotent: Redis treats re-subscribe as a no-op for counting. + s.channelSet[ch] = struct{}{} s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("subscribe") s.dconn.WriteBulkString(ch) - s.dconn.WriteInt(s.channels + s.patterns) + s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() s.mu.Unlock() @@ -421,11 +428,12 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) { } s.mu.Lock() for _, p := range pats { - s.patterns++ + // Idempotent: Redis treats re-subscribe as a no-op for counting. 
+ s.patternSet[p] = struct{}{} s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("psubscribe") s.dconn.WriteBulkString(p) - s.dconn.WriteInt(s.channels + s.patterns) + s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() s.mu.Unlock() @@ -442,21 +450,12 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { } if len(args) < pubsubMinArgs { - // Unsubscribe all + // Unsubscribe all: emit per-channel reply (matching Redis behavior). if err := unsubFn(context.Background()); err != nil { s.logger.Warn("upstream "+kind+" failed", "err", err) } s.mu.Lock() - if isPattern { - s.patterns = 0 - } else { - s.channels = 0 - } - s.dconn.WriteArray(pubsubArrayReply) - s.dconn.WriteBulkString(kind) - s.dconn.WriteNull() - s.dconn.WriteInt(s.channels + s.patterns) - _ = s.dconn.Flush() + s.writeUnsubAll(kind, isPattern) s.mu.Unlock() return } @@ -468,23 +467,57 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.mu.Lock() for _, n := range names { if isPattern { - if s.patterns > 0 { - s.patterns-- - } + delete(s.patternSet, n) } else { - if s.channels > 0 { - s.channels-- - } + delete(s.channelSet, n) } s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) - s.dconn.WriteInt(s.channels + s.patterns) + s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() s.mu.Unlock() } +// writeUnsubAll emits per-channel/pattern unsubscribe replies and clears the set. +// Must be called with s.mu held. +func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { + set := s.channelSet + if isPattern { + set = s.patternSet + } + + if len(set) == 0 { + // No subscriptions: single reply with null channel (matching Redis). + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteNull() + s.dconn.WriteInt(s.subCount()) + _ = s.dconn.Flush() + return + } + + // Collect names before clearing so we can emit per-channel replies. 
+ names := make([]string, 0, len(set)) + for n := range set { + names = append(names, n) + } + if isPattern { + s.patternSet = make(map[string]struct{}) + } else { + s.channelSet = make(map[string]struct{}) + } + + for _, n := range names { + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) + s.dconn.WriteInt(s.subCount()) + } + _ = s.dconn.Flush() +} + // --- Ping handlers --- func (s *pubsubSession) handlePubSubPing(args [][]byte) { diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go new file mode 100644 index 00000000..670a4e94 --- /dev/null +++ b/proxy/pubsub_test.go @@ -0,0 +1,554 @@ +package proxy + +import ( + "errors" + "net" + "sync" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/tidwall/redcon" +) + +// mockDetachedConn implements redcon.DetachedConn for unit testing. +type mockDetachedConn struct { + mu sync.Mutex + commands []redcon.Command // queued commands to return from ReadCommand + cmdIdx int + writes []any // recorded writes: string for WriteString/WriteError/WriteBulkString, int for WriteInt/WriteArray, nil for WriteNull + closed bool + readErr error // error to return from ReadCommand when commands exhausted +} + +func newMockDetachedConn() *mockDetachedConn { + return &mockDetachedConn{ + readErr: errors.New("EOF"), + } +} + +func (m *mockDetachedConn) queueCommand(args ...string) { + bArgs := make([][]byte, len(args)) + for i, a := range args { + bArgs[i] = []byte(a) + } + m.commands = append(m.commands, redcon.Command{Args: bArgs}) +} + +func (m *mockDetachedConn) ReadCommand() (redcon.Command, error) { + m.mu.Lock() + defer m.mu.Unlock() + if m.cmdIdx >= len(m.commands) { + return redcon.Command{}, m.readErr + } + cmd := m.commands[m.cmdIdx] + m.cmdIdx++ + return cmd, nil +} + +func (m *mockDetachedConn) Flush() error { return nil } +func (m *mockDetachedConn) Close() error { m.closed = true; return nil } + +func (m *mockDetachedConn) 
RemoteAddr() string { return "127.0.0.1:9999" } + +func (m *mockDetachedConn) WriteError(msg string) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "ERR:"+msg) +} +func (m *mockDetachedConn) WriteString(str string) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "STR:"+str) +} +func (m *mockDetachedConn) WriteBulk(bulk []byte) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "BULK:"+string(bulk)) +} +func (m *mockDetachedConn) WriteBulkString(bulk string) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "BULKSTR:"+bulk) +} +func (m *mockDetachedConn) WriteInt(num int) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, num) +} +func (m *mockDetachedConn) WriteInt64(num int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, int(num)) +} +func (m *mockDetachedConn) WriteUint64(_ uint64) { + // Not used in pubsub tests. +} +func (m *mockDetachedConn) WriteArray(count int) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, count) +} +func (m *mockDetachedConn) WriteNull() { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, nil) +} +func (m *mockDetachedConn) WriteRaw(data []byte) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "RAW:"+string(data)) +} +func (m *mockDetachedConn) WriteAny(v any) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, v) +} +func (m *mockDetachedConn) Context() any { return nil } +func (m *mockDetachedConn) SetContext(v any) {} +func (m *mockDetachedConn) SetReadBuffer(n int) {} +func (m *mockDetachedConn) Detach() redcon.DetachedConn { return m } +func (m *mockDetachedConn) ReadPipeline() []redcon.Command { return nil } +func (m *mockDetachedConn) PeekPipeline() []redcon.Command { return nil } +func (m *mockDetachedConn) NetConn() net.Conn { return nil } + +func (m *mockDetachedConn) getWrites() []any { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]any, 
len(m.writes)) + copy(out, m.writes) + return out +} + +// newTestSession creates a pubsubSession with a mock connection for testing. +func newTestSession(dconn *mockDetachedConn) *pubsubSession { + return &pubsubSession{ + dconn: dconn, + logger: testLogger, + channelSet: make(map[string]struct{}), + patternSet: make(map[string]struct{}), + } +} + +func TestPubSub_SubscribeDuplicate_IdempotentCount(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Simulate an upstream that's already connected. + s.channelSet["ch1"] = struct{}{} + + // Subscribe to ch1 again and ch2 (should count ch1 only once). + // We can't call handleSubscribe directly because it needs s.upstream, + // so we test the set logic directly. + s.channelSet["ch1"] = struct{}{} // re-add same key + s.channelSet["ch2"] = struct{}{} + + assert.Equal(t, 2, len(s.channelSet), "duplicate subscribe should not increase count") + assert.Equal(t, 2, s.subCount()) +} + +func TestPubSub_UnsubscribeNonSubscribed_NoEffect(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + + // Unsubscribe from "ch3" which was never subscribed. 
+ delete(s.channelSet, "ch3") + + assert.Equal(t, 2, len(s.channelSet), "unsubscribe non-subscribed channel should not affect count") +} + +func TestPubSub_UnsubscribeSpecific_RemovesFromSet(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + s.channelSet["ch3"] = struct{}{} + + delete(s.channelSet, "ch2") + + assert.Equal(t, 2, len(s.channelSet)) + _, hasCh1 := s.channelSet["ch1"] + _, hasCh2 := s.channelSet["ch2"] + _, hasCh3 := s.channelSet["ch3"] + assert.True(t, hasCh1) + assert.False(t, hasCh2) + assert.True(t, hasCh3) +} + +func TestPubSub_WriteUnsubAll_PerChannelReplies(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + s.patternSet["pat1"] = struct{}{} + + s.mu.Lock() + s.writeUnsubAll("unsubscribe", false) + s.mu.Unlock() + + writes := dconn.getWrites() + + // Should have emitted 2 unsubscribe replies (one per channel). + // Each reply is: WriteArray(3), WriteBulkString("unsubscribe"), WriteBulkString(name), WriteInt(remaining) + assert.Equal(t, 0, len(s.channelSet), "channelSet should be cleared") + assert.Equal(t, 1, len(s.patternSet), "patternSet should not be affected") + + // Count array headers (each reply starts with WriteArray(3)) + arrCount := 0 + for _, w := range writes { + if n, ok := w.(int); ok && n == 3 { + arrCount++ + } + } + assert.Equal(t, 2, arrCount, "should emit one reply per unsubscribed channel") +} + +func TestPubSub_WriteUnsubAll_EmptySet_SingleNullReply(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // No subscriptions + s.mu.Lock() + s.writeUnsubAll("unsubscribe", false) + s.mu.Unlock() + + writes := dconn.getWrites() + + // Should emit single reply with null channel. 
+ // WriteArray(3), WriteBulkString("unsubscribe"), WriteNull(), WriteInt(0) + assert.Len(t, writes, 4, "single null-channel reply expected") + + hasNull := false + for _, w := range writes { + if w == nil { + hasNull = true + } + } + assert.True(t, hasNull, "should contain null for empty unsubscribe-all") +} + +func TestPubSub_WriteUnsubAll_Patterns(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.patternSet["h*"] = struct{}{} + s.channelSet["ch1"] = struct{}{} + + s.mu.Lock() + s.writeUnsubAll("punsubscribe", true) + s.mu.Unlock() + + assert.Equal(t, 0, len(s.patternSet), "patternSet should be cleared") + assert.Equal(t, 1, len(s.channelSet), "channelSet should not be affected") + + writes := dconn.getWrites() + // Should have 1 reply (one pattern) + arrCount := 0 + for _, w := range writes { + if n, ok := w.(int); ok && n == 3 { + arrCount++ + } + } + assert.Equal(t, 1, arrCount) +} + +func TestPubSub_SubCount(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + assert.Equal(t, 0, s.subCount()) + + s.channelSet["a"] = struct{}{} + assert.Equal(t, 1, s.subCount()) + + s.patternSet["b*"] = struct{}{} + assert.Equal(t, 2, s.subCount()) + + s.channelSet["a"] = struct{}{} // duplicate + assert.Equal(t, 2, s.subCount()) + + delete(s.channelSet, "a") + assert.Equal(t, 1, s.subCount()) +} + +func TestPubSub_DispatchPubSubCommand_Quit(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchPubSubCommand([][]byte{[]byte("QUIT")}) + assert.False(t, cont, "QUIT should return false") + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK") +} + +func TestPubSub_DispatchPubSubCommand_InvalidCommand(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchPubSubCommand([][]byte{[]byte("GET"), []byte("key")}) + assert.True(t, cont, "invalid command should not end session") + + writes := dconn.getWrites() + found := false + for _, w 
:= range writes { + if str, ok := w.(string); ok && str == "ERR:ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" { + found = true + } + } + assert.True(t, found, "should write error for invalid pub/sub command") +} + +func TestPubSub_DispatchNormalCommand_Quit(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("QUIT", [][]byte{[]byte("QUIT")}) + assert.False(t, cont, "QUIT should return false") +} + +func TestPubSub_DispatchNormalCommand_Ping(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("PING", [][]byte{[]byte("PING")}) + assert.True(t, cont) + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:PONG") +} + +func TestPubSub_DispatchNormalCommand_PingWithMessage(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("PING", [][]byte{[]byte("PING"), []byte("hello")}) + assert.True(t, cont) + + writes := dconn.getWrites() + assert.Contains(t, writes, "BULK:hello") +} + +func TestPubSub_HandlePubSubPing(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handlePubSubPing([][]byte{[]byte("PING")}) + + writes := dconn.getWrites() + // ["pong", ""] + assert.Contains(t, writes, 2) // WriteArray(2) + assert.Contains(t, writes, "BULKSTR:pong") + assert.Contains(t, writes, "BULKSTR:") // empty string +} + +func TestPubSub_HandlePubSubPingWithData(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handlePubSubPing([][]byte{[]byte("PING"), []byte("hello")}) + + writes := dconn.getWrites() + assert.Contains(t, writes, "BULK:hello") +} + +func TestPubSub_HandleUnsubNoSession(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handleUnsubNoSession("UNSUBSCRIBE") + + writes := dconn.getWrites() + // ["unsubscribe", null, 0] + assert.Contains(t, writes, 
"BULKSTR:unsubscribe") + hasNull := false + for _, w := range writes { + if w == nil { + hasNull = true + } + } + assert.True(t, hasNull) + assert.Contains(t, writes, 0) // WriteInt64(0) +} + +func TestPubSub_SubscribeInTxnRejected(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + s.inTxn = true + + cont := s.dispatchNormalCommand("SUBSCRIBE", [][]byte{[]byte("SUBSCRIBE"), []byte("ch1")}) + assert.True(t, cont) + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR Command not allowed inside a transaction" { + found = true + } + } + assert.True(t, found, "SUBSCRIBE during MULTI should be rejected") +} + +func TestPubSub_HandleTxnInSession_MultiExecDiscard(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // MULTI + handled := s.handleTxnInSession("MULTI", nil) + assert.True(t, handled) + assert.True(t, s.inTxn) + + // Nested MULTI + dconn2 := newMockDetachedConn() + s2 := newTestSession(dconn2) + s2.inTxn = true + handled = s2.handleTxnInSession("MULTI", nil) + assert.True(t, handled) + writes := dconn2.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR MULTI calls can not be nested" { + found = true + } + } + assert.True(t, found) + + // Queue a command + handled = s.handleTxnInSession("SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.True(t, handled) + assert.Len(t, s.txnQueue, 1) + + // DISCARD + handled = s.handleTxnInSession("DISCARD", nil) + assert.True(t, handled) + assert.False(t, s.inTxn) + assert.Nil(t, s.txnQueue) + + // EXEC without MULTI + handled = s.handleTxnInSession("EXEC", nil) + assert.True(t, handled) + writes = dconn.getWrites() + found = false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR EXEC without MULTI" { + found = true + } + } + assert.True(t, found) + + // Non-txn command when not in txn + handled = 
s.handleTxnInSession("GET", [][]byte{[]byte("GET"), []byte("k")}) + assert.False(t, handled, "non-txn command outside txn should not be handled") +} + +func TestPubSub_ShouldExitPubSub(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + // No upstream, no subs → false (no upstream to close) + assert.False(t, s.shouldExitPubSub()) + + // With upstream but has subs → false + s.upstream = &redis.PubSub{} + s.channelSet["ch1"] = struct{}{} + assert.False(t, s.shouldExitPubSub()) + + // With upstream and no subs → true + delete(s.channelSet, "ch1") + assert.True(t, s.shouldExitPubSub()) +} + +func TestPubSub_ByteSlicesToStrings(t *testing.T) { + input := [][]byte{[]byte("hello"), []byte("world")} + result := byteSlicesToStrings(input) + assert.Equal(t, []string{"hello", "world"}, result) + + // Empty + assert.Equal(t, []string{}, byteSlicesToStrings([][]byte{})) +} + +func TestPubSub_SelectAuthSilentlyAccepted(t *testing.T) { + for _, cmd := range []string{"SELECT", "AUTH"} { + t.Run(cmd, func(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.dispatchRegularCommand(cmd, [][]byte{[]byte(cmd), []byte("0")}) + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK") + }) + } +} + +// TestPubSub_DuplicateSubscribeDoesNotOvercount verifies that subscribing to the +// same channel multiple times doesn't inflate the subscription count. 
+func TestPubSub_DuplicateSubscribeDoesNotOvercount(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + // First subscribe + s.channelSet["ch1"] = struct{}{} + assert.Equal(t, 1, s.subCount()) + + // Duplicate subscribe (idempotent) + s.channelSet["ch1"] = struct{}{} + assert.Equal(t, 1, s.subCount(), "duplicate subscribe must not increase count") + + // Add new channel + s.channelSet["ch2"] = struct{}{} + assert.Equal(t, 2, s.subCount()) +} + +// TestPubSub_UnsubscribeNonExistent verifies that unsubscribing from a channel +// that was never subscribed does not affect the count. +func TestPubSub_UnsubscribeNonExistent(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + + // Delete non-existent: no panic, no effect + delete(s.channelSet, "never-subscribed") + assert.Equal(t, 2, s.subCount()) + + // Delete existing + delete(s.channelSet, "ch1") + assert.Equal(t, 1, s.subCount()) +} + +// TestPubSub_CleanupClosesUpstream verifies that cleanup closes upstream and dconn. +func TestPubSub_CleanupClosesUpstream(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.cleanup() + assert.True(t, s.closed) + assert.True(t, dconn.closed) + assert.Nil(t, s.upstream) +} + +// TestPubSub_CommandLoop_EmptyArgs verifies that empty command args are skipped. 
+func TestPubSub_CommandLoop_EmptyArgs(t *testing.T) { + dconn := newMockDetachedConn() + // Queue empty args then QUIT + dconn.commands = append(dconn.commands, redcon.Command{Args: nil}) + dconn.queueCommand("QUIT") + + s := newTestSession(dconn) + // Not in pub/sub mode, so should dispatch as normal commands + s.commandLoop() + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK", "QUIT should be handled") +} + +func TestPubSub_CommandLoop_EOF(t *testing.T) { + dconn := newMockDetachedConn() + dconn.readErr = errors.New("connection closed") + + s := newTestSession(dconn) + s.commandLoop() // should return without panic +} diff --git a/proxy/sentry.go b/proxy/sentry.go index f964d87d..51537e7b 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -74,11 +74,15 @@ func (r *SentryReporter) CaptureException(err error, operation string, args [][] }) } -// CaptureDivergence reports a data divergence to Sentry. +// CaptureDivergence reports a data divergence to Sentry with cooldown-based de-duplication. 
func (r *SentryReporter) CaptureDivergence(div Divergence) { if !r.enabled { return } + fingerprint := fmt.Sprintf("divergence_%s_%s", div.Kind.String(), div.Command) + if !r.ShouldReport(fingerprint) { + return + } r.hub.WithScope(func(scope *sentry.Scope) { scope.SetTag("command", div.Command) scope.SetTag("key", div.Key) From 18e1aaa48f03d2317cad281ad09868e2e042fe52 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 23:17:17 +0900 Subject: [PATCH 11/43] Refactor error handling in pubsub session --- proxy/proxy.go | 2 +- proxy/pubsub.go | 16 +++++++++++++++- proxy/pubsub_test.go | 35 ----------------------------------- 3 files changed, 16 insertions(+), 37 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index ae67c171..68d9a9b4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -231,7 +231,7 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args } if err != nil { upstream.Close() - conn.WriteError("ERR " + err.Error()) + writeRedisError(conn, err) return } diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 7070e427..b483ce2e 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -326,7 +327,7 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { } if err != nil { upstream.Close() - s.writeError("ERR " + err.Error()) + s.writeRedisError(err) return } @@ -401,6 +402,7 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) { channels := byteSlicesToStrings(args[1:]) if err := s.upstream.Subscribe(context.Background(), channels...); err != nil { s.logger.Warn("upstream subscribe failed", "err", err) + s.writeRedisError(err) return } s.mu.Lock() @@ -424,6 +426,7 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) { pats := byteSlicesToStrings(args[1:]) if err := s.upstream.PSubscribe(context.Background(), pats...); err != nil { s.logger.Warn("upstream psubscribe failed", "err", err) + 
s.writeRedisError(err) return } s.mu.Lock() @@ -563,6 +566,17 @@ func (s *pubsubSession) writeError(msg string) { _ = s.dconn.Flush() } +// writeRedisError writes an upstream error, preserving redis.Error prefixes verbatim +// and normalizing other errors to "ERR ..." (matching writeRedisError in proxy.go). +func (s *pubsubSession) writeRedisError(err error) { + var redisErr redis.Error + if errors.As(err, &redisErr) { + s.writeError(redisErr.Error()) + return + } + s.writeError("ERR " + err.Error()) +} + func (s *pubsubSession) writeString(msg string) { s.mu.Lock() defer s.mu.Unlock() diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go index 670a4e94..2666d844 100644 --- a/proxy/pubsub_test.go +++ b/proxy/pubsub_test.go @@ -484,41 +484,6 @@ func TestPubSub_SelectAuthSilentlyAccepted(t *testing.T) { } } -// TestPubSub_DuplicateSubscribeDoesNotOvercount verifies that subscribing to the -// same channel multiple times doesn't inflate the subscription count. -func TestPubSub_DuplicateSubscribeDoesNotOvercount(t *testing.T) { - s := newTestSession(newMockDetachedConn()) - - // First subscribe - s.channelSet["ch1"] = struct{}{} - assert.Equal(t, 1, s.subCount()) - - // Duplicate subscribe (idempotent) - s.channelSet["ch1"] = struct{}{} - assert.Equal(t, 1, s.subCount(), "duplicate subscribe must not increase count") - - // Add new channel - s.channelSet["ch2"] = struct{}{} - assert.Equal(t, 2, s.subCount()) -} - -// TestPubSub_UnsubscribeNonExistent verifies that unsubscribing from a channel -// that was never subscribed does not affect the count. 
-func TestPubSub_UnsubscribeNonExistent(t *testing.T) { - s := newTestSession(newMockDetachedConn()) - - s.channelSet["ch1"] = struct{}{} - s.channelSet["ch2"] = struct{}{} - - // Delete non-existent: no panic, no effect - delete(s.channelSet, "never-subscribed") - assert.Equal(t, 2, s.subCount()) - - // Delete existing - delete(s.channelSet, "ch1") - assert.Equal(t, 1, s.subCount()) -} - // TestPubSub_CleanupClosesUpstream verifies that cleanup closes upstream and dconn. func TestPubSub_CleanupClosesUpstream(t *testing.T) { dconn := newMockDetachedConn() From 1a70ea0797a5f847b281ccd45412d34cf1a1ae81 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Tue, 17 Mar 2026 23:53:35 +0900 Subject: [PATCH 12/43] Refactor pubsub test writes with tagged types --- proxy/pubsub_test.go | 21 +++++++++++++-------- proxy/sentry.go | 9 ++++++--- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go index 2666d844..70fcfc27 100644 --- a/proxy/pubsub_test.go +++ b/proxy/pubsub_test.go @@ -11,12 +11,17 @@ import ( "github.com/tidwall/redcon" ) +// Tagged types to distinguish RESP wire types in mock writes. +type respInt struct{ V int } // WriteInt +type respInt64 struct{ V int64 } // WriteInt64 +type respArray struct{ N int } // WriteArray + // mockDetachedConn implements redcon.DetachedConn for unit testing. 
type mockDetachedConn struct { mu sync.Mutex commands []redcon.Command // queued commands to return from ReadCommand cmdIdx int - writes []any // recorded writes: string for WriteString/WriteError/WriteBulkString, int for WriteInt/WriteArray, nil for WriteNull + writes []any // recorded writes: string for WriteString/WriteError/WriteBulkString, respInt/respInt64/respArray for typed ints, nil for WriteNull closed bool readErr error // error to return from ReadCommand when commands exhausted } @@ -74,12 +79,12 @@ func (m *mockDetachedConn) WriteBulkString(bulk string) { func (m *mockDetachedConn) WriteInt(num int) { m.mu.Lock() defer m.mu.Unlock() - m.writes = append(m.writes, num) + m.writes = append(m.writes, respInt{num}) } func (m *mockDetachedConn) WriteInt64(num int64) { m.mu.Lock() defer m.mu.Unlock() - m.writes = append(m.writes, int(num)) + m.writes = append(m.writes, respInt64{num}) } func (m *mockDetachedConn) WriteUint64(_ uint64) { // Not used in pubsub tests. @@ -87,7 +92,7 @@ func (m *mockDetachedConn) WriteUint64(_ uint64) { func (m *mockDetachedConn) WriteArray(count int) { m.mu.Lock() defer m.mu.Unlock() - m.writes = append(m.writes, count) + m.writes = append(m.writes, respArray{count}) } func (m *mockDetachedConn) WriteNull() { m.mu.Lock() @@ -201,7 +206,7 @@ func TestPubSub_WriteUnsubAll_PerChannelReplies(t *testing.T) { // Count array headers (each reply starts with WriteArray(3)) arrCount := 0 for _, w := range writes { - if n, ok := w.(int); ok && n == 3 { + if a, ok := w.(respArray); ok && a.N == 3 { arrCount++ } } @@ -250,7 +255,7 @@ func TestPubSub_WriteUnsubAll_Patterns(t *testing.T) { // Should have 1 reply (one pattern) arrCount := 0 for _, w := range writes { - if n, ok := w.(int); ok && n == 3 { + if a, ok := w.(respArray); ok && a.N == 3 { arrCount++ } } @@ -341,7 +346,7 @@ func TestPubSub_HandlePubSubPing(t *testing.T) { writes := dconn.getWrites() // ["pong", ""] - assert.Contains(t, writes, 2) // WriteArray(2) + assert.Contains(t, 
writes, respArray{2}) // WriteArray(2) assert.Contains(t, writes, "BULKSTR:pong") assert.Contains(t, writes, "BULKSTR:") // empty string } @@ -372,7 +377,7 @@ func TestPubSub_HandleUnsubNoSession(t *testing.T) { } } assert.True(t, hasNull) - assert.Contains(t, writes, 0) // WriteInt64(0) + assert.Contains(t, writes, respInt64{0}) // WriteInt64(0) } func TestPubSub_SubscribeInTxnRejected(t *testing.T) { diff --git a/proxy/sentry.go b/proxy/sentry.go index 51537e7b..798d695a 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -85,18 +85,21 @@ func (r *SentryReporter) CaptureDivergence(div Divergence) { } r.hub.WithScope(func(scope *sentry.Scope) { scope.SetTag("command", div.Command) - scope.SetTag("key", div.Key) scope.SetTag("kind", div.Kind.String()) + // Omit raw key from Sentry tags to avoid leaking sensitive data; + // only send a truncated form as an extra for debugging. + scope.SetExtra("key", truncateValue(div.Key)) scope.SetExtra("primary", truncateValue(div.Primary)) scope.SetExtra("secondary", truncateValue(div.Secondary)) scope.SetFingerprint([]string{"divergence", div.Kind.String(), div.Command}) scope.SetLevel(sentry.LevelWarning) - r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s (%s)", div.Kind, div.Command, div.Key)) + r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s", div.Kind, div.Command)) }) } // ShouldReport checks if this fingerprint has been reported recently (cooldown-based). -// Periodically evicts expired entries to prevent unbounded map growth. +// Evicts expired entries when the map reaches maxReportEntries to bound memory usage. +// Returns false (drops the report) if the map is still at capacity after eviction. 
func (r *SentryReporter) ShouldReport(fingerprint string) bool { r.mu.Lock() defer r.mu.Unlock() From f065da7a2819f4ce02fe61fb2ce46fb5461b34d7 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 00:24:50 +0900 Subject: [PATCH 13/43] Fix execTxn to prioritize pipeline-level errors over results Pipeline-level errors (connection/transport failures) were only checked when results were empty. Now checked first to avoid silently ignoring transport errors when partial results are returned. Co-Authored-By: Claude Opus 4.6 --- proxy/proxy.go | 7 ++++--- proxy/pubsub.go | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 68d9a9b4..5ae9ce9b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -348,13 +348,14 @@ func (p *ProxyServer) execTxn(conn redcon.Conn, state *proxyConnState) { cmds = append(cmds, []any{"EXEC"}) results, err := p.dual.Primary().Pipeline(ctx, cmds) - if len(results) > 0 { + if err != nil { + // Pipeline-level error (connection/transport failure) takes precedence. + writeRedisError(conn, err) + } else if len(results) > 0 { // Write the EXEC result (last command in the pipeline). lastResult := results[len(results)-1] resp, rErr := lastResult.Result() writeResponse(conn, resp, rErr) - } else if err != nil { - writeRedisError(conn, err) } // Async replay to secondary (bounded) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index b483ce2e..89f71f5a 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -369,12 +369,13 @@ func (s *pubsubSession) execTxn() { results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) s.mu.Lock() - if len(results) > 0 { + if err != nil { + // Pipeline-level error (connection/transport failure) takes precedence. 
+ writeRedisError(s.dconn, err) + } else if len(results) > 0 { lastResult := results[len(results)-1] resp, rErr := lastResult.Result() writeResponse(s.dconn, resp, rErr) - } else if err != nil { - writeRedisError(s.dconn, err) } _ = s.dconn.Flush() s.mu.Unlock() From 9346ceeeacf05505095850fb5a9568dae843f881 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 00:32:09 +0900 Subject: [PATCH 14/43] Address review: writeMu separation, graceful shutdown, log truncation, unsub count fix - Introduce writeMu to serialize dconn writes separately from state mutex, preventing potential deadlock in forwardMessages - Add graceful shutdown handling in ListenAndServe to suppress expected listener-closed errors when context is cancelled - Use truncateValue() in ShadowReader log output to prevent data leakage - Fix writeUnsubAll to remove entries one-by-one so subscription count decrements correctly per reply, matching Redis behavior Co-Authored-By: Claude Opus 4.6 --- proxy/compare.go | 6 +-- proxy/proxy.go | 5 +++ proxy/pubsub.go | 96 ++++++++++++++++++++++++++---------------------- 3 files changed, 61 insertions(+), 46 deletions(-) diff --git a/proxy/compare.go b/proxy/compare.go index 6b8e39d7..32e451fe 100644 --- a/proxy/compare.go +++ b/proxy/compare.go @@ -110,9 +110,9 @@ func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, p s.metrics.Divergences.WithLabelValues(cmd, kind.String()).Inc() s.logger.Warn("response divergence detected", - "cmd", div.Command, "key", div.Key, "kind", div.Kind.String(), - "primary", fmt.Sprintf("%v", div.Primary), - "secondary", fmt.Sprintf("%v", div.Secondary), + "cmd", div.Command, "key", truncateValue(div.Key), "kind", div.Kind.String(), + "primary", truncateValue(div.Primary), + "secondary", truncateValue(div.Secondary), ) s.sentry.CaptureDivergence(div) } diff --git a/proxy/proxy.go b/proxy/proxy.go index 5ae9ce9b..447acc26 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -88,6 +88,11 @@ 
func (p *ProxyServer) ListenAndServe(ctx context.Context) error { ) if err = srv.Serve(ln); err != nil { + // During graceful shutdown, srv.Close() causes Serve to return + // a listener-closed error. Treat this as a normal exit. + if ctx.Err() != nil { + return nil //nolint:nilerr // intentional: suppress expected listener-closed error during graceful shutdown + } return fmt.Errorf("proxy serve: %w", err) } return nil diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 89f71f5a..2bda8654 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -35,7 +35,8 @@ const ( // When all subscriptions are removed, the session transitions to normal command mode, // enabling the client to execute regular Redis commands without reconnecting. type pubsubSession struct { - mu sync.Mutex + mu sync.Mutex // protects session state (upstream, closed, channelSet, patternSet, txn) + writeMu sync.Mutex // serializes writes to dconn; never held across state operations dconn redcon.DetachedConn upstream *redis.PubSub // nil when not in pub/sub mode proxy *ProxyServer @@ -102,10 +103,12 @@ func (s *pubsubSession) startForwarding() { func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { for msg := range ch { s.mu.Lock() - if s.closed { - s.mu.Unlock() + closed := s.closed + s.mu.Unlock() + if closed { return } + s.writeMu.Lock() if msg.Pattern != "" { s.dconn.WriteArray(pubsubArrayPMessage) s.dconn.WriteBulkString("pmessage") @@ -119,7 +122,7 @@ func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { s.dconn.WriteBulkString(msg.Payload) } err := s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() if err != nil { return } @@ -141,11 +144,9 @@ func (s *pubsubSession) commandLoop() { args := cloneArgs(cmd.Args) name := strings.ToUpper(string(args[0])) - s.mu.Lock() - inPubSub := s.subCount() > 0 - s.mu.Unlock() - - if inPubSub { + // subCount() reads channelSet/patternSet which are only modified + // from this goroutine (commandLoop), so no lock is needed. 
+ if s.subCount() > 0 { if !s.dispatchPubSubCommand(args) { return } @@ -300,10 +301,10 @@ func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { return } - s.mu.Lock() + s.writeMu.Lock() writeResponse(s.dconn, resp, err) _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() } func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { @@ -336,21 +337,24 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { s.mu.Unlock() s.startForwarding() + // Update state (sets only accessed from commandLoop goroutine). kind := strings.ToLower(cmdName) - s.mu.Lock() for _, ch := range channels { if cmdName == cmdSubscribe { s.channelSet[ch] = struct{}{} } else { s.patternSet[ch] = struct{}{} } + } + s.writeMu.Lock() + for _, ch := range channels { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(ch) s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() } func (s *pubsubSession) execTxn() { @@ -368,7 +372,7 @@ func (s *pubsubSession) execTxn() { results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) - s.mu.Lock() + s.writeMu.Lock() if err != nil { // Pipeline-level error (connection/transport failure) takes precedence. writeRedisError(s.dconn, err) @@ -378,7 +382,7 @@ func (s *pubsubSession) execTxn() { writeResponse(s.dconn, resp, rErr) } _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() if s.proxy.dual.hasSecondaryWrite() { s.proxy.dual.goAsync(func() { @@ -406,17 +410,19 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) { s.writeRedisError(err) return } - s.mu.Lock() + // Update state (channelSet is only accessed from commandLoop goroutine). for _, ch := range channels { - // Idempotent: Redis treats re-subscribe as a no-op for counting. 
s.channelSet[ch] = struct{}{} + } + s.writeMu.Lock() + for _, ch := range channels { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("subscribe") s.dconn.WriteBulkString(ch) s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() } func (s *pubsubSession) handlePSubscribe(args [][]byte) { @@ -430,17 +436,19 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) { s.writeRedisError(err) return } - s.mu.Lock() + // Update state (patternSet is only accessed from commandLoop goroutine). for _, p := range pats { - // Idempotent: Redis treats re-subscribe as a no-op for counting. s.patternSet[p] = struct{}{} + } + s.writeMu.Lock() + for _, p := range pats { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("psubscribe") s.dconn.WriteBulkString(p) s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() } // handleUnsub handles both UNSUBSCRIBE and PUNSUBSCRIBE. @@ -458,9 +466,7 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { if err := unsubFn(context.Background()); err != nil { s.logger.Warn("upstream "+kind+" failed", "err", err) } - s.mu.Lock() s.writeUnsubAll(kind, isPattern) - s.mu.Unlock() return } @@ -468,7 +474,8 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { if err := unsubFn(context.Background(), names...); err != nil { s.logger.Warn("upstream "+kind+" failed", "err", err) } - s.mu.Lock() + // Update state then write replies. + s.writeMu.Lock() for _, n := range names { if isPattern { delete(s.patternSet, n) @@ -481,17 +488,20 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.dconn.WriteInt(s.subCount()) } _ = s.dconn.Flush() - s.mu.Unlock() + s.writeMu.Unlock() } -// writeUnsubAll emits per-channel/pattern unsubscribe replies and clears the set. -// Must be called with s.mu held. 
+// writeUnsubAll emits per-channel/pattern unsubscribe replies, removing entries +// one-by-one so the subscription count decrements per reply (matching Redis). func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { set := s.channelSet if isPattern { set = s.patternSet } + s.writeMu.Lock() + defer s.writeMu.Unlock() + if len(set) == 0 { // No subscriptions: single reply with null channel (matching Redis). s.dconn.WriteArray(pubsubArrayReply) @@ -502,18 +512,18 @@ func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { return } - // Collect names before clearing so we can emit per-channel replies. + // Collect names, then remove one-by-one to decrement count per reply. names := make([]string, 0, len(set)) for n := range set { names = append(names, n) } - if isPattern { - s.patternSet = make(map[string]struct{}) - } else { - s.channelSet = make(map[string]struct{}) - } for _, n := range names { + if isPattern { + delete(s.patternSet, n) + } else { + delete(s.channelSet, n) + } s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) @@ -525,8 +535,8 @@ func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { // --- Ping handlers --- func (s *pubsubSession) handlePubSubPing(args [][]byte) { - s.mu.Lock() - defer s.mu.Unlock() + s.writeMu.Lock() + defer s.writeMu.Unlock() s.dconn.WriteArray(pubsubArrayPong) s.dconn.WriteBulkString("pong") if len(args) > 1 { @@ -538,8 +548,8 @@ func (s *pubsubSession) handlePubSubPing(args [][]byte) { } func (s *pubsubSession) handleNormalPing(args [][]byte) { - s.mu.Lock() - defer s.mu.Unlock() + s.writeMu.Lock() + defer s.writeMu.Unlock() if len(args) > 1 { s.dconn.WriteBulk(args[1]) } else { @@ -549,8 +559,8 @@ func (s *pubsubSession) handleNormalPing(args [][]byte) { } func (s *pubsubSession) handleUnsubNoSession(cmdName string) { - s.mu.Lock() - defer s.mu.Unlock() + s.writeMu.Lock() + defer s.writeMu.Unlock() s.dconn.WriteArray(pubsubArrayReply) 
s.dconn.WriteBulkString(strings.ToLower(cmdName)) s.dconn.WriteNull() @@ -561,8 +571,8 @@ func (s *pubsubSession) handleUnsubNoSession(cmdName string) { // --- Helpers --- func (s *pubsubSession) writeError(msg string) { - s.mu.Lock() - defer s.mu.Unlock() + s.writeMu.Lock() + defer s.writeMu.Unlock() s.dconn.WriteError(msg) _ = s.dconn.Flush() } @@ -579,8 +589,8 @@ func (s *pubsubSession) writeRedisError(err error) { } func (s *pubsubSession) writeString(msg string) { - s.mu.Lock() - defer s.mu.Unlock() + s.writeMu.Lock() + defer s.writeMu.Unlock() s.dconn.WriteString(msg) _ = s.dconn.Flush() } From 721087bfb84de36286938daeefecf7baf0c5732c Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 00:47:53 +0900 Subject: [PATCH 15/43] Truncate migration-gap log key and optimize truncateValue for large values - Apply truncateValue() to migration-gap DEBUG log key for consistency - Handle string/[]byte types in truncateValue by slicing before formatting to avoid allocating the full string representation for large values Co-Authored-By: Claude Opus 4.6 --- proxy/compare.go | 2 +- proxy/sentry.go | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/proxy/compare.go b/proxy/compare.go index 32e451fe..5d2b1a37 100644 --- a/proxy/compare.go +++ b/proxy/compare.go @@ -94,7 +94,7 @@ func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, p s.metrics.MigrationGaps.WithLabelValues(cmd).Inc() count := s.gapCount.Add(1) if count%s.gapLogSampleRate == 1 { - s.logger.Debug("migration gap (sampled)", "cmd", cmd, "key", extractKey(args)) + s.logger.Debug("migration gap (sampled)", "cmd", cmd, "key", truncateValue(extractKey(args))) } return } diff --git a/proxy/sentry.go b/proxy/sentry.go index 798d695a..60d99fae 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -140,9 +140,21 @@ func cmdNameFromArgs(args [][]byte) string { return unknownStr } -// truncateValue formats a value for Sentry, truncating to 
avoid data leakage and oversized events. +// truncateValue formats a value for logging/Sentry, truncating to avoid data leakage and oversized events. +// Handles common types by slicing before formatting to avoid allocating the full string representation. func truncateValue(v any) string { - s := fmt.Sprintf("%v", v) + var s string + switch tv := v.(type) { + case string: + s = tv + case []byte: + if len(tv) > maxSentryValueLen { + tv = tv[:maxSentryValueLen] + } + s = string(tv) + default: + s = fmt.Sprintf("%v", v) + } if len(s) > maxSentryValueLen { return s[:maxSentryValueLen] + "...(truncated)" } From 1f9334f3f1d4110d11e1da0618fa58054a93a575 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 01:21:56 +0900 Subject: [PATCH 16/43] Fix truncateValue []byte marker, bounded cleanup wait, mock Close race - Add "...(truncated)" marker to []byte path in truncateValue - Add bounded timeout (5s) to cleanup() wait on forwardMessages to prevent hanging on stuck client sockets - Guard mockDetachedConn.Close() with mutex to fix data race under -race Co-Authored-By: Claude Opus 4.6 --- proxy/pubsub.go | 13 ++++++++++++- proxy/pubsub_test.go | 7 ++++++- proxy/sentry.go | 4 ++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 2bda8654..da41a8c8 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -7,6 +7,7 @@ import ( "log/slog" "strings" "sync" + "time" "github.com/redis/go-redis/v9" "github.com/tidwall/redcon" @@ -28,6 +29,10 @@ const ( cmdDiscard = "DISCARD" cmdPing = "PING" cmdQuit = "QUIT" + + // cleanupFwdTimeout bounds the wait for forwardMessages to exit during cleanup. + // If the client socket is stuck, we don't want to block indefinitely. + cleanupFwdTimeout = 5 * time.Second ) // pubsubSession manages a single client's detached connection. 
@@ -77,7 +82,13 @@ func (s *pubsubSession) cleanup() { } s.mu.Unlock() if s.fwdDone != nil { - <-s.fwdDone + // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, + // don't block cleanup indefinitely. + select { + case <-s.fwdDone: + case <-time.After(cleanupFwdTimeout): + s.logger.Warn("forwardMessages did not exit within timeout, proceeding with cleanup") + } } s.dconn.Close() } diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go index 70fcfc27..0dba4521 100644 --- a/proxy/pubsub_test.go +++ b/proxy/pubsub_test.go @@ -52,7 +52,12 @@ func (m *mockDetachedConn) ReadCommand() (redcon.Command, error) { } func (m *mockDetachedConn) Flush() error { return nil } -func (m *mockDetachedConn) Close() error { m.closed = true; return nil } +func (m *mockDetachedConn) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + m.closed = true + return nil +} func (m *mockDetachedConn) RemoteAddr() string { return "127.0.0.1:9999" } diff --git a/proxy/sentry.go b/proxy/sentry.go index 60d99fae..239c5173 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -149,9 +149,9 @@ func truncateValue(v any) string { s = tv case []byte: if len(tv) > maxSentryValueLen { - tv = tv[:maxSentryValueLen] + return string(tv[:maxSentryValueLen]) + "...(truncated)" } - s = string(tv) + return string(tv) default: s = fmt.Sprintf("%v", v) } From 69a443971d047dd977e3c2705cd760d216188f6f Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 01:34:56 +0900 Subject: [PATCH 17/43] Bound exitPubSubMode fwdDone wait and use WriteInt64 for sub counts - Add timeout to exitPubSubMode fwdDone wait, matching cleanup() pattern - Switch all subscription count replies from WriteInt to WriteInt64 for RESP integer consistency and 32-bit safety --- proxy/proxy.go | 2 +- proxy/pubsub.go | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 447acc26..298dd991 100644 --- a/proxy/proxy.go +++ 
b/proxy/proxy.go @@ -263,7 +263,7 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args } else { session.patternSet[ch] = struct{}{} } - dconn.WriteInt(session.subCount()) + dconn.WriteInt64(int64(session.subCount())) } if err := dconn.Flush(); err != nil { dconn.Close() diff --git a/proxy/pubsub.go b/proxy/pubsub.go index da41a8c8..f6ddb7bf 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -184,7 +184,11 @@ func (s *pubsubSession) exitPubSubMode() { } s.mu.Unlock() if s.fwdDone != nil { - <-s.fwdDone + select { + case <-s.fwdDone: + case <-time.After(cleanupFwdTimeout): + s.logger.Warn("forwardMessages did not exit within timeout during pub/sub mode exit") + } s.fwdDone = nil } } @@ -362,7 +366,7 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(ch) - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() s.writeMu.Unlock() @@ -430,7 +434,7 @@ func (s *pubsubSession) handleSubscribe(args [][]byte) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("subscribe") s.dconn.WriteBulkString(ch) - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() s.writeMu.Unlock() @@ -456,7 +460,7 @@ func (s *pubsubSession) handlePSubscribe(args [][]byte) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString("psubscribe") s.dconn.WriteBulkString(p) - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() s.writeMu.Unlock() @@ -496,7 +500,7 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() s.writeMu.Unlock() @@ -518,7 +522,7 @@ func (s *pubsubSession) writeUnsubAll(kind string, isPattern 
bool) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteNull() - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) _ = s.dconn.Flush() return } @@ -538,7 +542,7 @@ func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) - s.dconn.WriteInt(s.subCount()) + s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() } From 11609b0d0d0871e9758bbad0201c9fa0b4016601 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 01:42:34 +0900 Subject: [PATCH 18/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/command.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/command.go b/proxy/command.go index b04ead9b..b1414ee4 100644 --- a/proxy/command.go +++ b/proxy/command.go @@ -9,7 +9,7 @@ const ( CmdRead CommandCategory = iota // GET, HGET, LRANGE, ZRANGE, etc. CmdWrite // SET, DEL, HSET, LPUSH, ZADD, etc. 
CmdBlocking // BZPOPMIN, XREAD (with BLOCK) - CmdPubSub // SUBSCRIBE, PUBLISH, PUBSUB + CmdPubSub // SUBSCRIBE, PUBSUB CmdAdmin // PING, INFO, CLIENT, SELECT, QUIT, DBSIZE, SCAN, AUTH CmdTxn // MULTI, EXEC, DISCARD CmdScript // EVAL, EVALSHA From da95a8c22983f08c14ef689152e66d26b51a42d1 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 01:43:25 +0900 Subject: [PATCH 19/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/pubsub.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index f6ddb7bf..477cce9f 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -227,11 +227,12 @@ func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool { s.handleNormalPing(args) return true } + // Let the transaction handler process or queue commands first, so that + // behavior during MULTI is consistent with the main ProxyServer handler. 
+ if s.handleTxnInSession(name, args) { + return true + } if name == cmdSubscribe || name == cmdPSubscribe { - if s.inTxn { - s.writeError("ERR Command not allowed inside a transaction") - return true - } s.reenterPubSub(name, args) return true } @@ -239,9 +240,6 @@ func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool { s.handleUnsubNoSession(name) return true } - if s.handleTxnInSession(name, args) { - return true - } s.dispatchRegularCommand(name, args) return true } From 50343d98a56fb3232c3529014d35ae3fb0a87aa8 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 02:10:05 +0900 Subject: [PATCH 20/43] Fix CmdPubSub comment, align SUBSCRIBE-in-txn test with new queuing behavior - Update CmdPubSub comment to list all classified commands and note PUBLISH is CmdWrite - Update test: SUBSCRIBE during MULTI is now queued (consistent with main ProxyServer handler) instead of rejected --- proxy/command.go | 2 +- proxy/pubsub_test.go | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/proxy/command.go b/proxy/command.go index b1414ee4..9522bf95 100644 --- a/proxy/command.go +++ b/proxy/command.go @@ -9,7 +9,7 @@ const ( CmdRead CommandCategory = iota // GET, HGET, LRANGE, ZRANGE, etc. CmdWrite // SET, DEL, HSET, LPUSH, ZADD, etc. 
CmdBlocking // BZPOPMIN, XREAD (with BLOCK) - CmdPubSub // SUBSCRIBE, PUBSUB + CmdPubSub // SUBSCRIBE, UNSUBSCRIBE, PSUBSCRIBE, PUNSUBSCRIBE, PUBSUB (note: PUBLISH is CmdWrite) CmdAdmin // PING, INFO, CLIENT, SELECT, QUIT, DBSIZE, SCAN, AUTH CmdTxn // MULTI, EXEC, DISCARD CmdScript // EVAL, EVALSHA diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go index 0dba4521..b8076acb 100644 --- a/proxy/pubsub_test.go +++ b/proxy/pubsub_test.go @@ -385,7 +385,7 @@ func TestPubSub_HandleUnsubNoSession(t *testing.T) { assert.Contains(t, writes, respInt64{0}) // WriteInt64(0) } -func TestPubSub_SubscribeInTxnRejected(t *testing.T) { +func TestPubSub_SubscribeInTxnQueued(t *testing.T) { dconn := newMockDetachedConn() s := newTestSession(dconn) s.inTxn = true @@ -396,11 +396,12 @@ func TestPubSub_SubscribeInTxnRejected(t *testing.T) { writes := dconn.getWrites() found := false for _, w := range writes { - if str, ok := w.(string); ok && str == "ERR:ERR Command not allowed inside a transaction" { + if str, ok := w.(string); ok && str == "STR:QUEUED" { found = true } } - assert.True(t, found, "SUBSCRIBE during MULTI should be rejected") + assert.True(t, found, "SUBSCRIBE during MULTI should be queued") + assert.Len(t, s.txnQueue, 1, "SUBSCRIBE should be added to txn queue") } func TestPubSub_HandleTxnInSession_MultiExecDiscard(t *testing.T) { From f1b03505a30e1a036167e1900bf28fb0378f0b89 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 02:40:59 +0900 Subject: [PATCH 21/43] Bound metrics shutdown, return error on upstream unsub failure, fix mu comment - Use bounded context (5s) for metrics server shutdown to avoid hanging - Return error reply and skip local state mutation when upstream UNSUBSCRIBE/PUNSUBSCRIBE fails, preventing client/upstream desync - Narrow mu comment to actual protected fields (upstream, closed) --- cmd/redis-proxy/main.go | 9 +++++++-- proxy/pubsub.go | 10 +++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git 
a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go index 165c4250..9a616279 100644 --- a/cmd/redis-proxy/main.go +++ b/cmd/redis-proxy/main.go @@ -17,7 +17,10 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" ) -const sentryFlushTimeout = 2 * time.Second +const ( + sentryFlushTimeout = 2 * time.Second + metricsShutdownTimeout = 5 * time.Second +) func main() { if err := run(); err != nil { @@ -102,7 +105,9 @@ func run() error { metricsSrv := &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} go func() { <-ctx.Done() - if err := metricsSrv.Shutdown(context.Background()); err != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), metricsShutdownTimeout) + defer shutdownCancel() + if err := metricsSrv.Shutdown(shutdownCtx); err != nil { logger.Warn("metrics server shutdown error", "err", err) } }() diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 477cce9f..b6d2acc6 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -40,7 +40,7 @@ const ( // When all subscriptions are removed, the session transitions to normal command mode, // enabling the client to execute regular Redis commands without reconnecting. type pubsubSession struct { - mu sync.Mutex // protects session state (upstream, closed, channelSet, patternSet, txn) + mu sync.Mutex // protects upstream and closed (channelSet, patternSet, txn are goroutine-confined to commandLoop) writeMu sync.Mutex // serializes writes to dconn; never held across state operations dconn redcon.DetachedConn upstream *redis.PubSub // nil when not in pub/sub mode @@ -477,7 +477,9 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { if len(args) < pubsubMinArgs { // Unsubscribe all: emit per-channel reply (matching Redis behavior). 
if err := unsubFn(context.Background()); err != nil { - s.logger.Warn("upstream "+kind+" failed", "err", err) + s.logger.Warn("upstream "+kind+" failed, closing session", "err", err) + s.writeRedisError(err) + return } s.writeUnsubAll(kind, isPattern) return @@ -485,7 +487,9 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { names := byteSlicesToStrings(args[1:]) if err := unsubFn(context.Background(), names...); err != nil { - s.logger.Warn("upstream "+kind+" failed", "err", err) + s.logger.Warn("upstream "+kind+" failed, closing session", "err", err) + s.writeRedisError(err) + return } // Update state then write replies. s.writeMu.Lock() From 3d38cd3da23526e7be91185b32d97af79c808649 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 03:34:32 +0900 Subject: [PATCH 22/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/sentry.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 3 deletions(-) diff --git a/proxy/sentry.go b/proxy/sentry.go index 239c5173..7720f097 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -3,6 +3,8 @@ package proxy import ( "fmt" "log/slog" + "reflect" + "strings" "sync" "time" @@ -143,20 +145,106 @@ func cmdNameFromArgs(args [][]byte) string { // truncateValue formats a value for logging/Sentry, truncating to avoid data leakage and oversized events. // Handles common types by slicing before formatting to avoid allocating the full string representation. func truncateValue(v any) string { - var s string switch tv := v.(type) { case string: - s = tv + return truncateString(tv) case []byte: + // Avoid converting an arbitrarily large byte slice into a full string. if len(tv) > maxSentryValueLen { return string(tv[:maxSentryValueLen]) + "...(truncated)" } return string(tv) + case fmt.Stringer: + // Respect custom String implementations but still apply length limits. 
+ return truncateString(tv.String()) default: - s = fmt.Sprintf("%v", v) + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + return formatSliceValue(rv, maxSentryValueLen) + case reflect.Map: + return formatMapValue(rv, maxSentryValueLen) + default: + // For non-container types, fall back to fmt and then truncate. + return truncateString(fmt.Sprintf("%v", v)) + } } +} + +// truncateString enforces maxSentryValueLen on an already-built string. +func truncateString(s string) string { if len(s) > maxSentryValueLen { return s[:maxSentryValueLen] + "...(truncated)" } return s } + +// formatSliceValue formats a slice/array value without allocating an unbounded string. +// It stops once approximately maxLen bytes have been written and appends a truncation marker. +func formatSliceValue(rv reflect.Value, maxLen int) string { + var b strings.Builder + b.WriteByte('[') + for i := 0; i < rv.Len(); i++ { + if b.Len() >= maxLen { + b.WriteString("...(truncated)]") + return b.String() + } + if i > 0 { + b.WriteString(", ") + } + elemStr := truncateValue(rv.Index(i).Interface()) + if b.Len()+len(elemStr) > maxLen { + // Write as much as fits, then mark as truncated. + remaining := maxLen - b.Len() + if remaining > 0 { + if remaining < len(elemStr) { + b.WriteString(elemStr[:remaining]) + } else { + b.WriteString(elemStr) + } + } + b.WriteString("...(truncated)]") + return b.String() + } + b.WriteString(elemStr) + } + b.WriteByte(']') + return truncateString(b.String()) +} + +// formatMapValue formats a map value without allocating an unbounded string. +// It stops once approximately maxLen bytes have been written and appends a truncation marker. 
+func formatMapValue(rv reflect.Value, maxLen int) string { + var b strings.Builder + b.WriteByte('{') + iter := rv.MapRange() + first := true + for iter.Next() { + if b.Len() >= maxLen { + b.WriteString("...(truncated)}") + return b.String() + } + if !first { + b.WriteString(", ") + } + first = false + keyStr := truncateValue(iter.Key().Interface()) + valStr := truncateValue(iter.Value().Interface()) + entry := fmt.Sprintf("%s: %s", keyStr, valStr) + if b.Len()+len(entry) > maxLen { + remaining := maxLen - b.Len() + if remaining > 0 { + if remaining < len(entry) { + b.WriteString(entry[:remaining]) + } else { + b.WriteString(entry) + } + } + b.WriteString("...(truncated)}") + return b.String() + } + b.WriteString(entry) + } + b.WriteByte('}') + return truncateString(b.String()) +} From 001fbd35112f1a4ef428dcbe16f99868f57cbef3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:13:34 +0000 Subject: [PATCH 23/43] Initial plan From 2f2f0ce585800a17de63f758f478ff310e5dbedc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:16:09 +0000 Subject: [PATCH 24/43] proxy: fix exhaustive lint error in truncateValue by using if/else instead of switch on reflect.Kind Co-authored-by: bootjp <1306365+bootjp@users.noreply.github.com> --- proxy/sentry.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/proxy/sentry.go b/proxy/sentry.go index 7720f097..905ee820 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -159,15 +159,14 @@ func truncateValue(v any) string { return truncateString(tv.String()) default: rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Slice, reflect.Array: + kind := rv.Kind() + if kind == reflect.Slice || kind == reflect.Array { return formatSliceValue(rv, maxSentryValueLen) - case reflect.Map: + } else if kind == reflect.Map { return formatMapValue(rv, maxSentryValueLen) 
- default: - // For non-container types, fall back to fmt and then truncate. - return truncateString(fmt.Sprintf("%v", v)) } + // For non-container types, fall back to fmt and then truncate. + return truncateString(fmt.Sprintf("%v", v)) } } From 1f2f00d8d6542538c981e29655e76806117751d4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 19:19:47 +0000 Subject: [PATCH 25/43] proxy/sentry: fix exhaustive lint error in truncateValue Co-authored-by: bootjp <1306365+bootjp@users.noreply.github.com> --- go.mod | 5 +---- go.sum | 8 -------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/go.mod b/go.mod index f9b0077a..55f45eba 100644 --- a/go.mod +++ b/go.mod @@ -17,9 +17,9 @@ require ( github.com/cockroachdb/errors v1.12.0 github.com/cockroachdb/pebble v1.1.5 github.com/emirpasic/gods v1.18.1 + github.com/getsentry/sentry-go v0.27.0 github.com/hashicorp/go-hclog v1.6.3 github.com/hashicorp/raft v1.7.3 - github.com/hashicorp/raft-boltdb/v2 v2.3.1 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.18.0 @@ -50,7 +50,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/boltdb/bolt v1.3.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cockroachdb/datadriven v1.0.3-0.20250407164829-2945557346d5 // indirect github.com/cockroachdb/fifo v0.0.0-20240606204812-0bbfbd93a7ce // indirect @@ -60,7 +59,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fatih/color v1.15.0 // indirect - github.com/getsentry/sentry-go v0.27.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy 
v0.0.5-0.20231225225746-43d5d4cd4e0e // indirect @@ -85,7 +83,6 @@ require ( github.com/tidwall/btree v1.1.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect - go.etcd.io/bbolt v1.4.3 // indirect go.uber.org/atomic v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect diff --git a/go.sum b/go.sum index 7b2c117a..1932fa79 100644 --- a/go.sum +++ b/go.sum @@ -61,7 +61,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= @@ -181,7 +180,6 @@ github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJ github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-metrics v0.5.4 h1:8mmPiIJkTPPEbAiV97IxdAGNdRdaWwVap1BU6elejKY= github.com/hashicorp/go-metrics v0.5.4/go.mod h1:CG5yz4NZ/AI/aQt9Ucm/vdBnbh7fvmv4lxZ350i+QQI= -github.com/hashicorp/go-msgpack v0.5.5 h1:i9R9JSrqIz0QVLz3sz+i3YJdT7TTSLcfLLzJi9aZTuI= github.com/hashicorp/go-msgpack v0.5.5/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= github.com/hashicorp/go-msgpack/v2 v2.1.1/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4= github.com/hashicorp/go-msgpack/v2 v2.1.2 h1:4Ee8FTp834e+ewB71RDrQ0VKpyFdrKOjvYtnQ/ltVj0= @@ -200,10 +198,6 @@ github.com/hashicorp/raft v1.7.0/go.mod 
h1:N1sKh6Vn47mrWvEArQgILTyng8GoDRNYlgKyK github.com/hashicorp/raft v1.7.3 h1:DxpEqZJysHN0wK+fviai5mFcSYsCkNpFUl1xpAW8Rbo= github.com/hashicorp/raft v1.7.3/go.mod h1:DfvCGFxpAUPE0L4Uc8JLlTPtc3GzSbdH0MTJCLgnmJQ= github.com/hashicorp/raft-boltdb v0.0.0-20171010151810-6e5ba93211ea/go.mod h1:pNv7Wc3ycL6F5oOWn+tPGo2gWD4a5X+yp/ntwdKLjRk= -github.com/hashicorp/raft-boltdb v0.0.0-20230125174641-2a8082862702 h1:RLKEcCuKcZ+qp2VlaaZsYZfLOmIiuJNpEi48Rl8u9cQ= -github.com/hashicorp/raft-boltdb v0.0.0-20230125174641-2a8082862702/go.mod h1:nTakvJ4XYq45UXtn0DbwR4aU9ZdjlnIenpbs6Cd+FM0= -github.com/hashicorp/raft-boltdb/v2 v2.3.1 h1:ackhdCNPKblmOhjEU9+4lHSJYFkJd6Jqyvj6eW9pwkc= -github.com/hashicorp/raft-boltdb/v2 v2.3.1/go.mod h1:n4S+g43dXF1tqDT+yzcXHhXM6y7MrlUd3TTwGRcUvQE= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -342,8 +336,6 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= -go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= -go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= From 7c634c12ff409f0ca5819caadbce6f131c7c3298 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 05:02:10 +0900 Subject: [PATCH 26/43] Close dconn on fwd timeout to unblock 
writeMu, fix exhaustive lint - exitPubSubMode and cleanup now close dconn after timeout to unblock forwardMessages stuck on Flush, preventing writeMu deadlock - Add nolint:exhaustive for reflect.Kind switch that intentionally uses default for all non-container types --- proxy/pubsub.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index b6d2acc6..d5ab041a 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -83,11 +83,14 @@ func (s *pubsubSession) cleanup() { s.mu.Unlock() if s.fwdDone != nil { // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, - // don't block cleanup indefinitely. + // close dconn to unblock it, then wait for completion. select { case <-s.fwdDone: case <-time.After(cleanupFwdTimeout): - s.logger.Warn("forwardMessages did not exit within timeout, proceeding with cleanup") + s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") + s.dconn.Close() + <-s.fwdDone + return // dconn already closed } } s.dconn.Close() @@ -187,7 +190,9 @@ func (s *pubsubSession) exitPubSubMode() { select { case <-s.fwdDone: case <-time.After(cleanupFwdTimeout): - s.logger.Warn("forwardMessages did not exit within timeout during pub/sub mode exit") + s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") + s.dconn.Close() + <-s.fwdDone } s.fwdDone = nil } From 77ee8a1771fb244760df7ecd32eb5571592ba88a Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 05:03:16 +0900 Subject: [PATCH 27/43] Use switch for reflect.Kind to satisfy staticcheck with exhaustive nolint --- proxy/sentry.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/proxy/sentry.go b/proxy/sentry.go index 905ee820..a6af913e 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -159,14 +159,15 @@ func truncateValue(v any) string { return truncateString(tv.String()) default: rv := reflect.ValueOf(v) - 
kind := rv.Kind() - if kind == reflect.Slice || kind == reflect.Array { + switch rv.Kind() { //nolint:exhaustive // only slice/array/map need special handling + case reflect.Slice, reflect.Array: return formatSliceValue(rv, maxSentryValueLen) - } else if kind == reflect.Map { + case reflect.Map: return formatMapValue(rv, maxSentryValueLen) + default: + // For non-container types, fall back to fmt and then truncate. + return truncateString(fmt.Sprintf("%v", v)) } - // For non-container types, fall back to fmt and then truncate. - return truncateString(fmt.Sprintf("%v", v)) } } From f389fc606ba1be48103ca7b8bbd729e64a271d54 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 13:51:37 +0900 Subject: [PATCH 28/43] Fix Pipeline redis.Error wrapping and separate state from writeMu - Pipeline now returns results with nil error for redis.Error/redis.Nil, only propagating transport/context errors. This preserves Redis transaction semantics in execTxn callers. - handleUnsub and writeUnsubAll now mutate goroutine-confined state and pre-compute counts before acquiring writeMu, keeping writeMu strictly for dconn write serialization. --- proxy/backend.go | 8 ++++++++ proxy/pubsub.go | 31 ++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/proxy/backend.go b/proxy/backend.go index 6ca88d66..9960cee1 100644 --- a/proxy/backend.go +++ b/proxy/backend.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "errors" "fmt" "time" @@ -92,6 +93,13 @@ func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd } _, err := pipe.Exec(ctx) if err != nil { + // go-redis pipelines return redis.Error for Redis reply errors (e.g., EXECABORT). + // Return results with nil error so callers can read per-command results (especially EXEC). + // Only propagate true transport/context errors. 
+ var redisErr redis.Error + if errors.As(err, &redisErr) || errors.Is(err, redis.Nil) { + return results, nil + } return results, fmt.Errorf("pipeline exec: %w", err) } return results, nil diff --git a/proxy/pubsub.go b/proxy/pubsub.go index d5ab041a..705f503b 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -496,18 +496,22 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.writeRedisError(err) return } - // Update state then write replies. - s.writeMu.Lock() - for _, n := range names { + // Update state (goroutine-confined) and pre-compute counts before taking writeMu. + counts := make([]int, len(names)) + for i, n := range names { if isPattern { delete(s.patternSet, n) } else { delete(s.channelSet, n) } + counts[i] = s.subCount() + } + s.writeMu.Lock() + for i, n := range names { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) - s.dconn.WriteInt64(int64(s.subCount())) + s.dconn.WriteInt64(int64(counts[i])) } _ = s.dconn.Flush() s.writeMu.Unlock() @@ -521,37 +525,42 @@ func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { set = s.patternSet } - s.writeMu.Lock() - defer s.writeMu.Unlock() - if len(set) == 0 { // No subscriptions: single reply with null channel (matching Redis). + s.writeMu.Lock() s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteNull() s.dconn.WriteInt64(int64(s.subCount())) _ = s.dconn.Flush() + s.writeMu.Unlock() return } - // Collect names, then remove one-by-one to decrement count per reply. + // Collect names and pre-compute decreasing counts (state is goroutine-confined). 
names := make([]string, 0, len(set)) for n := range set { names = append(names, n) } - - for _, n := range names { + counts := make([]int, len(names)) + for i, n := range names { if isPattern { delete(s.patternSet, n) } else { delete(s.channelSet, n) } + counts[i] = s.subCount() + } + + s.writeMu.Lock() + for i, n := range names { s.dconn.WriteArray(pubsubArrayReply) s.dconn.WriteBulkString(kind) s.dconn.WriteBulkString(n) - s.dconn.WriteInt64(int64(s.subCount())) + s.dconn.WriteInt64(int64(counts[i])) } _ = s.dconn.Flush() + s.writeMu.Unlock() } // --- Ping handlers --- From 70b14f4cd5f9682cc63d0c7eb56d676f6fd3d892 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 14:05:17 +0900 Subject: [PATCH 29/43] Add shadow subscribe for pub/sub divergence detection Implement shadow pub/sub that subscribes to the secondary backend for the same channels/patterns and compares received messages against the primary using a time-windowed matching strategy. - New shadowPubSub struct with sliding-window message comparison - Messages matched by (channel, payload) key within configurable window - Primary-only messages after window expiry reported as data_mismatch - Secondary-only messages with no primary match reported as extra_data - All shadow operations are fire-and-forget (never block primary path) - Active only in ModeDualWriteShadow / ModeElasticKVPrimary - New metrics: pubsub_shadow_divergences_total, pubsub_shadow_errors_total - New config: PubSubCompareWindow (default 2s) - Fix Pipeline to not wrap redis.Error (preserves EXEC result semantics) --- proxy/config.go | 52 +++++---- proxy/dualwrite.go | 12 ++ proxy/metrics.go | 16 +++ proxy/proxy.go | 29 +++++ proxy/pubsub.go | 135 ++++++++++++++++------ proxy/shadow_pubsub.go | 217 ++++++++++++++++++++++++++++++++++++ proxy/shadow_pubsub_test.go | 165 +++++++++++++++++++++++++++ 7 files changed, 569 insertions(+), 57 deletions(-) create mode 100644 proxy/shadow_pubsub.go create mode 100644 
proxy/shadow_pubsub_test.go diff --git a/proxy/config.go b/proxy/config.go index 395308e9..44ba290d 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -3,8 +3,10 @@ package proxy import "time" const ( - defaultSecondaryTimeout = 5 * time.Second - defaultShadowTimeout = 3 * time.Second + defaultSecondaryTimeout = 5 * time.Second + defaultShadowTimeout = 3 * time.Second + defaultPubSubCompareWindow = 2 * time.Second + defaultPubSubSweepInterval = 500 * time.Millisecond ) // ProxyMode controls which backends receive reads and writes. @@ -49,32 +51,34 @@ func (m ProxyMode) String() string { // ProxyConfig holds all configuration for the dual-write proxy. type ProxyConfig struct { - ListenAddr string - PrimaryAddr string - PrimaryDB int - PrimaryPassword string - SecondaryAddr string - SecondaryDB int - SecondaryPassword string - Mode ProxyMode - SecondaryTimeout time.Duration - ShadowTimeout time.Duration - SentryDSN string - SentryEnv string - SentrySampleRate float64 - MetricsAddr string + ListenAddr string + PrimaryAddr string + PrimaryDB int + PrimaryPassword string + SecondaryAddr string + SecondaryDB int + SecondaryPassword string + Mode ProxyMode + SecondaryTimeout time.Duration + ShadowTimeout time.Duration + SentryDSN string + SentryEnv string + SentrySampleRate float64 + MetricsAddr string + PubSubCompareWindow time.Duration } // DefaultConfig returns a ProxyConfig with sensible defaults. 
func DefaultConfig() ProxyConfig { return ProxyConfig{ - ListenAddr: ":6479", - PrimaryAddr: "localhost:6379", - SecondaryAddr: "localhost:6380", - Mode: ModeDualWrite, - SecondaryTimeout: defaultSecondaryTimeout, - ShadowTimeout: defaultShadowTimeout, - SentrySampleRate: 1.0, - MetricsAddr: ":9191", + ListenAddr: ":6479", + PrimaryAddr: "localhost:6379", + SecondaryAddr: "localhost:6380", + Mode: ModeDualWrite, + SecondaryTimeout: defaultSecondaryTimeout, + ShadowTimeout: defaultShadowTimeout, + SentrySampleRate: 1.0, + MetricsAddr: ":9191", + PubSubCompareWindow: defaultPubSubCompareWindow, } } diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 9212c7f2..3d050b28 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -254,6 +254,18 @@ func (d *DualWriter) PubSubBackend() PubSubBackend { return nil } +// ShadowPubSubBackend returns the secondary backend as a PubSubBackend +// when shadow mode is active, or nil otherwise. +func (d *DualWriter) ShadowPubSubBackend() PubSubBackend { + if d.shadow == nil { + return nil + } + if ps, ok := d.secondary.(PubSubBackend); ok { + return ps + } + return nil +} + // Secondary returns the secondary backend. func (d *DualWriter) Secondary() Backend { return d.secondary diff --git a/proxy/metrics.go b/proxy/metrics.go index 67c6128c..177f9f53 100644 --- a/proxy/metrics.go +++ b/proxy/metrics.go @@ -15,6 +15,9 @@ type ProxyMetrics struct { MigrationGaps *prometheus.CounterVec ActiveConnections prometheus.Gauge + + PubSubShadowDivergences *prometheus.CounterVec + PubSubShadowErrors prometheus.Counter } // NewProxyMetrics creates and registers all proxy metrics. 
@@ -69,6 +72,17 @@ func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { Name: "active_connections", Help: "Current number of active client connections.", }), + + PubSubShadowDivergences: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "pubsub_shadow_divergences_total", + Help: "Total pub/sub message mismatches detected by shadow subscribe.", + }, []string{"channel", "kind"}), + PubSubShadowErrors: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "pubsub_shadow_errors_total", + Help: "Total errors from shadow pub/sub operations.", + }), } reg.MustRegister( @@ -81,6 +95,8 @@ func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { m.Divergences, m.MigrationGaps, m.ActiveConnections, + m.PubSubShadowDivergences, + m.PubSubShadowErrors, ) return m diff --git a/proxy/proxy.go b/proxy/proxy.go index 298dd991..d60fdec6 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -252,6 +252,8 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args patternSet: make(map[string]struct{}), } + session.shadow = p.createShadowPubSub(cmdName, channels) + // Write initial subscription confirmations. kind := strings.ToLower(cmdName) for _, ch := range channels { @@ -268,12 +270,39 @@ func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args if err := dconn.Flush(); err != nil { dconn.Close() upstream.Close() + if session.shadow != nil { + session.shadow.Close() + } return } go session.run() } +// createShadowPubSub creates a shadow pub/sub for secondary comparison if in shadow mode. +// Returns nil if shadow mode is not active or the shadow subscribe fails. 
+func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *shadowPubSub { + shadowBackend := p.dual.ShadowPubSubBackend() + if shadowBackend == nil { + return nil + } + shadow := newShadowPubSub(shadowBackend, p.metrics, p.sentry, p.logger, p.cfg.PubSubCompareWindow) + var err error + if cmdName == cmdSubscribe { + err = shadow.Subscribe(context.Background(), channels...) + } else { + err = shadow.PSubscribe(context.Background(), channels...) + } + if err != nil { + p.logger.Warn("shadow pubsub subscribe failed", "err", err) + p.metrics.PubSubShadowErrors.Inc() + shadow.Close() + return nil + } + shadow.Start() + return shadow +} + func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { name := strings.ToUpper(string(args[0])) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 705f503b..f5236c79 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -56,6 +56,9 @@ type pubsubSession struct { // fwdDone is closed when the current forwardMessages goroutine exits. fwdDone chan struct{} + // Shadow pub/sub for secondary comparison (nil when not in shadow mode). + shadow *shadowPubSub + // Transaction state for normal command mode. inTxn bool txnQueue [][][]byte @@ -81,6 +84,10 @@ func (s *pubsubSession) cleanup() { s.upstream = nil } s.mu.Unlock() + if s.shadow != nil { + s.shadow.Close() + s.shadow = nil + } if s.fwdDone != nil { // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, // close dconn to unblock it, then wait for completion. @@ -140,6 +147,10 @@ func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { if err != nil { return } + // Record for shadow comparison (outside writeMu to avoid nested locking). 
+ if s.shadow != nil { + s.shadow.RecordPrimary(msg) + } } } @@ -186,6 +197,10 @@ func (s *pubsubSession) exitPubSubMode() { s.upstream = nil } s.mu.Unlock() + if s.shadow != nil { + s.shadow.Close() + s.shadow = nil + } if s.fwdDone != nil { select { case <-s.fwdDone: @@ -355,6 +370,8 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { s.mu.Unlock() s.startForwarding() + s.shadow = s.proxy.createShadowPubSub(cmdName, channels) + // Update state (sets only accessed from commandLoop goroutine). kind := strings.ToLower(cmdName) for _, ch := range channels { @@ -418,51 +435,49 @@ func (s *pubsubSession) execTxn() { // --- Subscription handlers --- func (s *pubsubSession) handleSubscribe(args [][]byte) { - if len(args) < pubsubMinArgs { - s.writeError("ERR wrong number of arguments for 'subscribe'") - return - } - channels := byteSlicesToStrings(args[1:]) - if err := s.upstream.Subscribe(context.Background(), channels...); err != nil { - s.logger.Warn("upstream subscribe failed", "err", err) - s.writeRedisError(err) - return - } - // Update state (channelSet is only accessed from commandLoop goroutine). - for _, ch := range channels { - s.channelSet[ch] = struct{}{} - } - s.writeMu.Lock() - for _, ch := range channels { - s.dconn.WriteArray(pubsubArrayReply) - s.dconn.WriteBulkString("subscribe") - s.dconn.WriteBulkString(ch) - s.dconn.WriteInt64(int64(s.subCount())) - } - _ = s.dconn.Flush() - s.writeMu.Unlock() + s.handleSub(args, false) } func (s *pubsubSession) handlePSubscribe(args [][]byte) { + s.handleSub(args, true) +} + +// handleSub is the shared implementation for SUBSCRIBE and PSUBSCRIBE. 
+func (s *pubsubSession) handleSub(args [][]byte, isPattern bool) { + kind := "subscribe" + if isPattern { + kind = "psubscribe" + } if len(args) < pubsubMinArgs { - s.writeError("ERR wrong number of arguments for 'psubscribe'") + s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s'", kind)) return } - pats := byteSlicesToStrings(args[1:]) - if err := s.upstream.PSubscribe(context.Background(), pats...); err != nil { - s.logger.Warn("upstream psubscribe failed", "err", err) + names := byteSlicesToStrings(args[1:]) + var err error + if isPattern { + err = s.upstream.PSubscribe(context.Background(), names...) + } else { + err = s.upstream.Subscribe(context.Background(), names...) + } + if err != nil { + s.logger.Warn("upstream "+kind+" failed", "err", err) s.writeRedisError(err) return } - // Update state (patternSet is only accessed from commandLoop goroutine). - for _, p := range pats { - s.patternSet[p] = struct{}{} + s.mirrorSub(names, isPattern) + // Update state (goroutine-confined to commandLoop). + for _, n := range names { + if isPattern { + s.patternSet[n] = struct{}{} + } else { + s.channelSet[n] = struct{}{} + } } s.writeMu.Lock() - for _, p := range pats { + for _, n := range names { s.dconn.WriteArray(pubsubArrayReply) - s.dconn.WriteBulkString("psubscribe") - s.dconn.WriteBulkString(p) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) s.dconn.WriteInt64(int64(s.subCount())) } _ = s.dconn.Flush() @@ -486,6 +501,9 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.writeRedisError(err) return } + if s.shadow != nil { + s.mirrorUnsubAll(isPattern) + } s.writeUnsubAll(kind, isPattern) return } @@ -496,6 +514,9 @@ func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { s.writeRedisError(err) return } + if s.shadow != nil { + s.mirrorUnsub(names, isPattern) + } // Update state (goroutine-confined) and pre-compute counts before taking writeMu. 
counts := make([]int, len(names)) for i, n := range names { @@ -599,6 +620,54 @@ func (s *pubsubSession) handleUnsubNoSession(cmdName string) { _ = s.dconn.Flush() } +// --- Shadow mirror helpers --- + +func (s *pubsubSession) mirrorSub(names []string, isPattern bool) { + if s.shadow == nil { + return + } + var err error + if isPattern { + err = s.shadow.PSubscribe(context.Background(), names...) + } else { + err = s.shadow.Subscribe(context.Background(), names...) + } + if err != nil { + kind := "subscribe" + if isPattern { + kind = "psubscribe" + } + s.logger.Warn("shadow "+kind+" failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + +func (s *pubsubSession) mirrorUnsubAll(isPattern bool) { + var err error + if isPattern { + err = s.shadow.PUnsubscribe(context.Background()) + } else { + err = s.shadow.Unsubscribe(context.Background()) + } + if err != nil { + s.logger.Warn("shadow unsubscribe-all failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + +func (s *pubsubSession) mirrorUnsub(names []string, isPattern bool) { + var err error + if isPattern { + err = s.shadow.PUnsubscribe(context.Background(), names...) + } else { + err = s.shadow.Unsubscribe(context.Background(), names...) + } + if err != nil { + s.logger.Warn("shadow unsubscribe failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + // --- Helpers --- func (s *pubsubSession) writeError(msg string) { diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go new file mode 100644 index 00000000..dc97e4a7 --- /dev/null +++ b/proxy/shadow_pubsub.go @@ -0,0 +1,217 @@ +package proxy + +import ( + "context" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +// msgKey is used as a map key for matching primary and secondary messages. +type msgKey struct { + Channel string + Payload string +} + +// pendingMsg records a message awaiting its counterpart from the other source. 
+type pendingMsg struct { + channel string + payload string + timestamp time.Time +} + +// shadowPubSub subscribes to the secondary backend for the same channels +// and compares messages against the primary to detect divergences. +type shadowPubSub struct { + secondary *redis.PubSub + metrics *ProxyMetrics + sentry *SentryReporter + logger *slog.Logger + window time.Duration + + mu sync.Mutex + pending map[msgKey][]pendingMsg // primary messages awaiting secondary match + closed bool + done chan struct{} +} + +func newShadowPubSub(backend PubSubBackend, metrics *ProxyMetrics, sentry *SentryReporter, logger *slog.Logger, window time.Duration) *shadowPubSub { + return &shadowPubSub{ + secondary: backend.NewPubSub(context.Background()), + metrics: metrics, + sentry: sentry, + logger: logger, + window: window, + pending: make(map[msgKey][]pendingMsg), + done: make(chan struct{}), + } +} + +// Start begins reading from the secondary and comparing messages. +// Must be called after initial subscribe. +func (sp *shadowPubSub) Start() { + ch := sp.secondary.Channel() + go func() { + defer close(sp.done) + sp.compareLoop(ch) + }() +} + +// Subscribe mirrors a SUBSCRIBE to the secondary. +func (sp *shadowPubSub) Subscribe(ctx context.Context, channels ...string) error { + if err := sp.secondary.Subscribe(ctx, channels...); err != nil { + return fmt.Errorf("shadow subscribe: %w", err) + } + return nil +} + +// PSubscribe mirrors a PSUBSCRIBE to the secondary. +func (sp *shadowPubSub) PSubscribe(ctx context.Context, patterns ...string) error { + if err := sp.secondary.PSubscribe(ctx, patterns...); err != nil { + return fmt.Errorf("shadow psubscribe: %w", err) + } + return nil +} + +// Unsubscribe mirrors an UNSUBSCRIBE to the secondary. 
+func (sp *shadowPubSub) Unsubscribe(ctx context.Context, channels ...string) error { + if err := sp.secondary.Unsubscribe(ctx, channels...); err != nil { + return fmt.Errorf("shadow unsubscribe: %w", err) + } + return nil +} + +// PUnsubscribe mirrors a PUNSUBSCRIBE to the secondary. +func (sp *shadowPubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { + if err := sp.secondary.PUnsubscribe(ctx, patterns...); err != nil { + return fmt.Errorf("shadow punsubscribe: %w", err) + } + return nil +} + +// RecordPrimary records a message received from the primary for comparison. +func (sp *shadowPubSub) RecordPrimary(msg *redis.Message) { + sp.mu.Lock() + defer sp.mu.Unlock() + if sp.closed { + return + } + key := msgKey{Channel: msg.Channel, Payload: msg.Payload} + sp.pending[key] = append(sp.pending[key], pendingMsg{ + channel: msg.Channel, + payload: msg.Payload, + timestamp: time.Now(), + }) +} + +// Close stops the shadow comparison and closes the secondary pub/sub. +func (sp *shadowPubSub) Close() { + sp.mu.Lock() + sp.closed = true + sp.mu.Unlock() + sp.secondary.Close() + <-sp.done +} + +// compareLoop reads from the secondary channel and matches messages. +func (sp *shadowPubSub) compareLoop(ch <-chan *redis.Message) { + ticker := time.NewTicker(defaultPubSubSweepInterval) + defer ticker.Stop() + + for { + select { + case msg, ok := <-ch: + if !ok { + // Channel closed — secondary connection terminated. + sp.sweepExpired() + return + } + sp.matchSecondary(msg) + case <-ticker.C: + sp.sweepExpired() + } + } +} + +// matchSecondary tries to match a secondary message against a pending primary message. +func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { + sp.mu.Lock() + defer sp.mu.Unlock() + + key := msgKey{Channel: msg.Channel, Payload: msg.Payload} + if entries, ok := sp.pending[key]; ok && len(entries) > 0 { + // Match found — remove the oldest pending primary message. 
+ if len(entries) == 1 { + delete(sp.pending, key) + } else { + sp.pending[key] = entries[1:] + } + return + } + + // No matching primary message — extra on secondary. + sp.reportDivergence(msg.Channel, msg.Payload, DivExtraData) +} + +// sweepExpired reports primary messages that were not matched within the window. +func (sp *shadowPubSub) sweepExpired() { + sp.mu.Lock() + defer sp.mu.Unlock() + + now := time.Now() + for key, entries := range sp.pending { + var remaining []pendingMsg + for _, e := range entries { + if now.Sub(e.timestamp) >= sp.window { + sp.reportDivergenceLocked(e.channel, e.payload, DivDataMismatch) + } else { + remaining = append(remaining, e) + } + } + if len(remaining) == 0 { + delete(sp.pending, key) + } else { + sp.pending[key] = remaining + } + } +} + +func (sp *shadowPubSub) reportDivergence(channel, payload string, kind DivergenceKind) { + sp.metrics.PubSubShadowDivergences.WithLabelValues(channel, kind.String()).Inc() + sp.logger.Warn("pubsub shadow divergence", + "channel", truncateValue(channel), + "payload", truncateValue(payload), + "kind", kind.String(), + ) + sp.sentry.CaptureDivergence(Divergence{ + Command: "SUBSCRIBE", + Key: channel, + Kind: kind, + Primary: payload, + Secondary: nil, + DetectedAt: time.Now(), + }) +} + +// reportDivergenceLocked is the same as reportDivergence but assumes mu is held. +// It releases the lock briefly for the report to avoid holding it during I/O. +func (sp *shadowPubSub) reportDivergenceLocked(channel, payload string, kind DivergenceKind) { + // Metrics and logging are goroutine-safe; safe to call under lock. 
+ sp.metrics.PubSubShadowDivergences.WithLabelValues(channel, kind.String()).Inc() + sp.logger.Warn("pubsub shadow divergence", + "channel", truncateValue(channel), + "payload", truncateValue(payload), + "kind", kind.String(), + ) + sp.sentry.CaptureDivergence(Divergence{ + Command: "SUBSCRIBE", + Key: channel, + Kind: kind, + Primary: payload, + Secondary: nil, + DetectedAt: time.Now(), + }) +} diff --git a/proxy/shadow_pubsub_test.go b/proxy/shadow_pubsub_test.go new file mode 100644 index 00000000..20eeacf1 --- /dev/null +++ b/proxy/shadow_pubsub_test.go @@ -0,0 +1,165 @@ +package proxy + +import ( + "log/slog" + "sync" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +func counterValue(c prometheus.Counter) float64 { + m := &dto.Metric{} + pm, ok := c.(prometheus.Metric) + if !ok { + return 0 + } + _ = pm.Write(m) + return m.GetCounter().GetValue() +} + +func newTestShadowPubSub(window time.Duration) *shadowPubSub { + return &shadowPubSub{ + metrics: newTestMetrics(), + sentry: newTestSentry(), + logger: slog.Default(), + window: window, + pending: make(map[msgKey][]pendingMsg), + done: make(chan struct{}), + } +} + +func TestShadowPubSub_MatchedMessage(t *testing.T) { + sp := newTestShadowPubSub(100 * time.Millisecond) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "hello"}) + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "hello"}) + + sp.mu.Lock() + remaining := len(sp.pending) + sp.mu.Unlock() + assert.Equal(t, 0, remaining, "matched message should be removed from pending") +} + +func TestShadowPubSub_MissingOnSecondary(t *testing.T) { + sp := newTestShadowPubSub(10 * time.Millisecond) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "hello"}) + time.Sleep(20 * time.Millisecond) + sp.sweepExpired() + + sp.mu.Lock() + remaining := len(sp.pending) + sp.mu.Unlock() + 
assert.Equal(t, 0, remaining, "expired message should be removed") + + val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("ch1", "data_mismatch")) + assert.Equal(t, float64(1), val) +} + +func TestShadowPubSub_ExtraOnSecondary(t *testing.T) { + sp := newTestShadowPubSub(100 * time.Millisecond) + + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "extra"}) + + val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("ch1", "extra_data")) + assert.Equal(t, float64(1), val) +} + +func TestShadowPubSub_OutOfOrderMatching(t *testing.T) { + sp := newTestShadowPubSub(1 * time.Second) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg1"}) + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg2"}) + + // Secondary delivers in reverse order. + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "msg2"}) + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "msg1"}) + + sp.mu.Lock() + remaining := len(sp.pending) + sp.mu.Unlock() + assert.Equal(t, 0, remaining, "all messages should be matched") +} + +func TestShadowPubSub_DuplicateMessages(t *testing.T) { + sp := newTestShadowPubSub(1 * time.Second) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"}) + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"}) + + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"}) + sp.mu.Lock() + assert.Equal(t, 1, len(sp.pending[msgKey{Channel: "ch1", Payload: "dup"}])) + sp.mu.Unlock() + + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"}) + sp.mu.Lock() + assert.Equal(t, 0, len(sp.pending)) + sp.mu.Unlock() +} + +func TestShadowPubSub_RecordAfterClose(t *testing.T) { + sp := newTestShadowPubSub(1 * time.Second) + + sp.mu.Lock() + sp.closed = true + sp.mu.Unlock() + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "after-close"}) + + sp.mu.Lock() + assert.Equal(t, 0, len(sp.pending)) + sp.mu.Unlock() +} + +func 
TestShadowPubSub_CompareLoopExitsOnChannelClose(t *testing.T) { + sp := newTestShadowPubSub(10 * time.Millisecond) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "orphan"}) + time.Sleep(20 * time.Millisecond) + + ch := make(chan *redis.Message) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sp.compareLoop(ch) + }() + + close(ch) + wg.Wait() + + sp.mu.Lock() + assert.Equal(t, 0, len(sp.pending), "should sweep on exit") + sp.mu.Unlock() +} + +func TestShadowPubSub_CompareLoopMatchesFromChannel(t *testing.T) { + sp := newTestShadowPubSub(1 * time.Second) + + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg1"}) + + ch := make(chan *redis.Message, 1) + ch <- &redis.Message{Channel: "ch1", Payload: "msg1"} + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + sp.compareLoop(ch) + }() + + // Give time for the message to be processed, then close. + time.Sleep(10 * time.Millisecond) + close(ch) + wg.Wait() + + sp.mu.Lock() + assert.Equal(t, 0, len(sp.pending), "message should be matched via compareLoop") + sp.mu.Unlock() +} From 493f80cf202f8979c59fc5c58c8daabacf271317 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 17:46:22 +0900 Subject: [PATCH 30/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/sentry.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/proxy/sentry.go b/proxy/sentry.go index a6af913e..da26f7b8 100644 --- a/proxy/sentry.go +++ b/proxy/sentry.go @@ -145,6 +145,10 @@ func cmdNameFromArgs(args [][]byte) string { // truncateValue formats a value for logging/Sentry, truncating to avoid data leakage and oversized events. // Handles common types by slicing before formatting to avoid allocating the full string representation. 
func truncateValue(v any) string { + if v == nil { + return "" + } + switch tv := v.(type) { case string: return truncateString(tv) From a998315206627431ae1823367910c0a779785a4f Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 17:46:33 +0900 Subject: [PATCH 31/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/shadow_pubsub.go | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index dc97e4a7..1d61d1c8 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -158,15 +158,27 @@ func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { // sweepExpired reports primary messages that were not matched within the window. func (sp *shadowPubSub) sweepExpired() { + // Collect divergences while holding the lock, then report them after releasing it + // to avoid doing potentially blocking I/O (logging/Sentry) under sp.mu. 
+ type divergenceEvent struct { + channel string + payload string + kind DivergenceKind + } + var divergences []divergenceEvent + sp.mu.Lock() - defer sp.mu.Unlock() now := time.Now() for key, entries := range sp.pending { var remaining []pendingMsg for _, e := range entries { if now.Sub(e.timestamp) >= sp.window { - sp.reportDivergenceLocked(e.channel, e.payload, DivDataMismatch) + divergences = append(divergences, divergenceEvent{ + channel: e.channel, + payload: e.payload, + kind: DivDataMismatch, + }) } else { remaining = append(remaining, e) } @@ -177,6 +189,12 @@ func (sp *shadowPubSub) sweepExpired() { sp.pending[key] = remaining } } + + sp.mu.Unlock() + + for _, d := range divergences { + sp.reportDivergence(d.channel, d.payload, d.kind) + } } func (sp *shadowPubSub) reportDivergence(channel, payload string, kind DivergenceKind) { @@ -197,7 +215,7 @@ func (sp *shadowPubSub) reportDivergence(channel, payload string, kind Divergenc } // reportDivergenceLocked is the same as reportDivergence but assumes mu is held. -// It releases the lock briefly for the report to avoid holding it during I/O. +// It does not release the lock; callers must ensure it's safe to perform logging/Sentry under sp.mu. func (sp *shadowPubSub) reportDivergenceLocked(channel, payload string, kind DivergenceKind) { // Metrics and logging are goroutine-safe; safe to call under lock. 
sp.metrics.PubSubShadowDivergences.WithLabelValues(channel, kind.String()).Inc() From ec3d20a87523740ea1f8e0f878cf749906b8261f Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 17:55:04 +0900 Subject: [PATCH 32/43] Fix shadow Close deadlock, shadow data race, remove unused method - Track started flag so Close is safe before Start (prevents deadlock when subscribe fails during setup) - Guard s.shadow access in forwardMessages with mu to prevent data race with concurrent cleanup/exitPubSubMode - Move closeShadow after fwdDone wait so RecordPrimary completes first - Remove unused reportDivergenceLocked --- proxy/pubsub.go | 31 +++++++++++++++++++++---------- proxy/shadow_pubsub.go | 30 +++++++++--------------------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index f5236c79..25c70a35 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -84,10 +84,6 @@ func (s *pubsubSession) cleanup() { s.upstream = nil } s.mu.Unlock() - if s.shadow != nil { - s.shadow.Close() - s.shadow = nil - } if s.fwdDone != nil { // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, // close dconn to unblock it, then wait for completion. @@ -97,9 +93,12 @@ func (s *pubsubSession) cleanup() { s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") s.dconn.Close() <-s.fwdDone + s.closeShadow() return // dconn already closed } } + // Close shadow after forwardMessages exits (it calls RecordPrimary). + s.closeShadow() s.dconn.Close() } @@ -148,8 +147,12 @@ func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { return } // Record for shadow comparison (outside writeMu to avoid nested locking). - if s.shadow != nil { - s.shadow.RecordPrimary(msg) + // Capture shadow under mu since cleanup/exitPubSubMode can nil it concurrently. 
+ s.mu.Lock() + shadow := s.shadow + s.mu.Unlock() + if shadow != nil { + shadow.RecordPrimary(msg) } } } @@ -197,10 +200,6 @@ func (s *pubsubSession) exitPubSubMode() { s.upstream = nil } s.mu.Unlock() - if s.shadow != nil { - s.shadow.Close() - s.shadow = nil - } if s.fwdDone != nil { select { case <-s.fwdDone: @@ -211,6 +210,8 @@ func (s *pubsubSession) exitPubSubMode() { } s.fwdDone = nil } + // Close shadow after forwardMessages exits (it calls RecordPrimary). + s.closeShadow() } // dispatchPubSubCommand handles a single command in pub/sub mode. @@ -620,6 +621,16 @@ func (s *pubsubSession) handleUnsubNoSession(cmdName string) { _ = s.dconn.Flush() } +func (s *pubsubSession) closeShadow() { + s.mu.Lock() + shadow := s.shadow + s.shadow = nil + s.mu.Unlock() + if shadow != nil { + shadow.Close() + } +} + // --- Shadow mirror helpers --- func (s *pubsubSession) mirrorSub(names []string, isPattern bool) { diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index 1d61d1c8..601beb4c 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -35,6 +35,7 @@ type shadowPubSub struct { mu sync.Mutex pending map[msgKey][]pendingMsg // primary messages awaiting secondary match closed bool + started bool done chan struct{} } @@ -53,6 +54,9 @@ func newShadowPubSub(backend PubSubBackend, metrics *ProxyMetrics, sentry *Sentr // Start begins reading from the secondary and comparing messages. // Must be called after initial subscribe. func (sp *shadowPubSub) Start() { + sp.mu.Lock() + sp.started = true + sp.mu.Unlock() ch := sp.secondary.Channel() go func() { defer close(sp.done) @@ -108,12 +112,16 @@ func (sp *shadowPubSub) RecordPrimary(msg *redis.Message) { } // Close stops the shadow comparison and closes the secondary pub/sub. +// Safe to call even if Start was never called. 
func (sp *shadowPubSub) Close() { sp.mu.Lock() sp.closed = true + started := sp.started sp.mu.Unlock() sp.secondary.Close() - <-sp.done + if started { + <-sp.done + } } // compareLoop reads from the secondary channel and matches messages. @@ -213,23 +221,3 @@ func (sp *shadowPubSub) reportDivergence(channel, payload string, kind Divergenc DetectedAt: time.Now(), }) } - -// reportDivergenceLocked is the same as reportDivergence but assumes mu is held. -// It does not release the lock; callers must ensure it's safe to perform logging/Sentry under sp.mu. -func (sp *shadowPubSub) reportDivergenceLocked(channel, payload string, kind DivergenceKind) { - // Metrics and logging are goroutine-safe; safe to call under lock. - sp.metrics.PubSubShadowDivergences.WithLabelValues(channel, kind.String()).Inc() - sp.logger.Warn("pubsub shadow divergence", - "channel", truncateValue(channel), - "payload", truncateValue(payload), - "kind", kind.String(), - ) - sp.sentry.CaptureDivergence(Divergence{ - Command: "SUBSCRIBE", - Key: channel, - Kind: kind, - Primary: payload, - Secondary: nil, - DetectedAt: time.Now(), - }) -} From ca0bea193d288c3d82391c99b27d82b486f3f8b9 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 19:46:34 +0900 Subject: [PATCH 33/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/shadow_pubsub.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index 601beb4c..401a4418 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -212,12 +212,27 @@ func (sp *shadowPubSub) reportDivergence(channel, payload string, kind Divergenc "payload", truncateValue(payload), "kind", kind.String(), ) + + var primary any + var secondary any + switch kind { + case DivExtraData: + // Message exists on secondary but not on primary. 
+ primary = nil + secondary = payload + default: + // Default: message exists on primary but not on secondary (or other kinds + // that follow the same primary/secondary semantics). + primary = payload + secondary = nil + } + sp.sentry.CaptureDivergence(Divergence{ Command: "SUBSCRIBE", Key: channel, Kind: kind, - Primary: payload, - Secondary: nil, + Primary: primary, + Secondary: secondary, DetectedAt: time.Now(), }) } From d7b675c44c36d6490ec0d66beba5466649cbea36 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 19:47:22 +0900 Subject: [PATCH 34/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index d60fdec6..13fc35c3 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -287,6 +287,7 @@ func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *sha return nil } shadow := newShadowPubSub(shadowBackend, p.metrics, p.sentry, p.logger, p.cfg.PubSubCompareWindow) + shadow.Start() var err error if cmdName == cmdSubscribe { err = shadow.Subscribe(context.Background(), channels...) 
@@ -299,7 +300,6 @@ func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *sha shadow.Close() return nil } - shadow.Start() return shadow } From 75a5209a99cac45734662eb92cc16044c70a95c5 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 22:17:44 +0900 Subject: [PATCH 35/43] Fix shadow pubsub review issues: msgKey Pattern, lock scope, sweepAll, divergence reporting - Include Pattern in msgKey for correct pmessage matching - Release lock before reporting in matchSecondary to avoid blocking I/O under mu - Add sweepAll to flush all pending on channel close (not just expired) - Set Primary/Secondary in reportDivergence based on DivergenceKind - Guard s.shadow assignment in reenterPubSub with mu - Promote divergenceEvent to package-level type shared by sweep methods --- proxy/pubsub.go | 5 ++- proxy/shadow_pubsub.go | 70 ++++++++++++++++++++++++++++++------------ 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 25c70a35..1aa66ac2 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -371,7 +371,10 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { s.mu.Unlock() s.startForwarding() - s.shadow = s.proxy.createShadowPubSub(cmdName, channels) + shadow := s.proxy.createShadowPubSub(cmdName, channels) + s.mu.Lock() + s.shadow = shadow + s.mu.Unlock() // Update state (sets only accessed from commandLoop goroutine). kind := strings.ToLower(cmdName) diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index 401a4418..ed723e64 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -11,18 +11,28 @@ import ( ) // msgKey is used as a map key for matching primary and secondary messages. +// Includes Pattern to correctly distinguish pmessage deliveries. type msgKey struct { + Pattern string Channel string Payload string } // pendingMsg records a message awaiting its counterpart from the other source. 
type pendingMsg struct { + pattern string channel string payload string timestamp time.Time } +// divergenceEvent holds divergence info collected under lock for deferred reporting. +type divergenceEvent struct { + channel string + payload string + kind DivergenceKind +} + // shadowPubSub subscribes to the secondary backend for the same channels // and compares messages against the primary to detect divergences. type shadowPubSub struct { @@ -103,8 +113,9 @@ func (sp *shadowPubSub) RecordPrimary(msg *redis.Message) { if sp.closed { return } - key := msgKey{Channel: msg.Channel, Payload: msg.Payload} + key := msgKeyFromMessage(msg) sp.pending[key] = append(sp.pending[key], pendingMsg{ + pattern: msg.Pattern, channel: msg.Channel, payload: msg.Payload, timestamp: time.Now(), @@ -133,8 +144,8 @@ func (sp *shadowPubSub) compareLoop(ch <-chan *redis.Message) { select { case msg, ok := <-ch: if !ok { - // Channel closed — secondary connection terminated. - sp.sweepExpired() + // Channel closed — flush all remaining pending as divergences. + sp.sweepAll() return } sp.matchSecondary(msg) @@ -145,11 +156,11 @@ func (sp *shadowPubSub) compareLoop(ch <-chan *redis.Message) { } // matchSecondary tries to match a secondary message against a pending primary message. +// Collects divergence info under lock and reports after releasing it. func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { sp.mu.Lock() - defer sp.mu.Unlock() - key := msgKey{Channel: msg.Channel, Payload: msg.Payload} + key := msgKeyFromMessage(msg) if entries, ok := sp.pending[key]; ok && len(entries) > 0 { // Match found — remove the oldest pending primary message. if len(entries) == 1 { @@ -157,26 +168,21 @@ func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { } else { sp.pending[key] = entries[1:] } + sp.mu.Unlock() return } + sp.mu.Unlock() + // No matching primary message — extra on secondary. 
sp.reportDivergence(msg.Channel, msg.Payload, DivExtraData) } // sweepExpired reports primary messages that were not matched within the window. func (sp *shadowPubSub) sweepExpired() { - // Collect divergences while holding the lock, then report them after releasing it - // to avoid doing potentially blocking I/O (logging/Sentry) under sp.mu. - type divergenceEvent struct { - channel string - payload string - kind DivergenceKind - } var divergences []divergenceEvent sp.mu.Lock() - now := time.Now() for key, entries := range sp.pending { var remaining []pendingMsg @@ -197,7 +203,28 @@ func (sp *shadowPubSub) sweepExpired() { sp.pending[key] = remaining } } + sp.mu.Unlock() + + for _, d := range divergences { + sp.reportDivergence(d.channel, d.payload, d.kind) + } +} + +// sweepAll reports all remaining pending messages as divergences (used on shutdown). +func (sp *shadowPubSub) sweepAll() { + var divergences []divergenceEvent + sp.mu.Lock() + for key, entries := range sp.pending { + for _, e := range entries { + divergences = append(divergences, divergenceEvent{ + channel: e.channel, + payload: e.payload, + kind: DivDataMismatch, + }) + } + delete(sp.pending, key) + } sp.mu.Unlock() for _, d := range divergences { @@ -213,20 +240,15 @@ func (sp *shadowPubSub) reportDivergence(channel, payload string, kind Divergenc "kind", kind.String(), ) - var primary any - var secondary any - switch kind { + var primary, secondary any + switch kind { //nolint:exhaustive // only two kinds apply to pub/sub shadow case DivExtraData: - // Message exists on secondary but not on primary. primary = nil secondary = payload default: - // Default: message exists on primary but not on secondary (or other kinds - // that follow the same primary/secondary semantics). 
primary = payload secondary = nil } - sp.sentry.CaptureDivergence(Divergence{ Command: "SUBSCRIBE", Key: channel, @@ -236,3 +258,11 @@ func (sp *shadowPubSub) reportDivergence(channel, payload string, kind Divergenc DetectedAt: time.Now(), }) } + +func msgKeyFromMessage(msg *redis.Message) msgKey { + return msgKey{ + Pattern: msg.Pattern, + Channel: msg.Channel, + Payload: msg.Payload, + } +} From 6b8da888adf77ccbf6065d5647e16d75d2b390e5 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 23:07:57 +0900 Subject: [PATCH 36/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/pubsub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 1aa66ac2..cc535f95 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -453,7 +453,7 @@ func (s *pubsubSession) handleSub(args [][]byte, isPattern bool) { kind = "psubscribe" } if len(args) < pubsubMinArgs { - s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s'", kind)) + s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", kind)) return } names := byteSlicesToStrings(args[1:]) From 2746640369f465bc42059ad1b686ee1365fd65a5 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 23:08:31 +0900 Subject: [PATCH 37/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/pubsub.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index cc535f95..70117b98 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -366,15 +366,13 @@ func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { return } - s.mu.Lock() - s.upstream = upstream - s.mu.Unlock() - s.startForwarding() - shadow := s.proxy.createShadowPubSub(cmdName, channels) + s.mu.Lock() + s.upstream = 
upstream s.shadow = shadow s.mu.Unlock() + s.startForwarding() // Update state (sets only accessed from commandLoop goroutine). kind := strings.ToLower(cmdName) From cd9ffed6d70a2cda5c9f5a89834f016dee3b2bc8 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 23:09:01 +0900 Subject: [PATCH 38/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 13fc35c3..d60fdec6 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -287,7 +287,6 @@ func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *sha return nil } shadow := newShadowPubSub(shadowBackend, p.metrics, p.sentry, p.logger, p.cfg.PubSubCompareWindow) - shadow.Start() var err error if cmdName == cmdSubscribe { err = shadow.Subscribe(context.Background(), channels...) @@ -300,6 +299,7 @@ func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *sha shadow.Close() return nil } + shadow.Start() return shadow } From 3cce4093e5b8932e973882e30aa0880a2cdc624d Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Wed, 18 Mar 2026 23:09:49 +0900 Subject: [PATCH 39/43] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- proxy/shadow_pubsub.go | 76 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index ed723e64..02d762fe 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -18,6 +18,25 @@ type msgKey struct { Payload string } +// secondaryPending represents a secondary message that has not yet been matched +// to a primary within the comparison window. 
+type secondaryPending struct { + timestamp time.Time + channel string + payload string +} + +// unmatchedSecondaries buffers unmatched secondary messages per shadowPubSub +// instance. This allows us to avoid reporting DivExtraData immediately when a +// secondary arrives before the corresponding primary (e.g. due to network +// jitter) and instead only report once the comparison window has elapsed. +var unmatchedSecondaries = struct { + sync.Mutex + data map[*shadowPubSub]map[msgKey]secondaryPending +}{ + data: make(map[*shadowPubSub]map[msgKey]secondaryPending), +} + // pendingMsg records a message awaiting its counterpart from the other source. type pendingMsg struct { pattern string @@ -174,8 +193,23 @@ func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { sp.mu.Unlock() - // No matching primary message — extra on secondary. - sp.reportDivergence(msg.Channel, msg.Payload, DivExtraData) + // No matching primary message at this moment. Buffer the secondary and only + // report DivExtraData if it remains unmatched after the comparison window. + now := time.Now() + unmatchedSecondaries.Lock() + defer unmatchedSecondaries.Unlock() + + perInstance, ok := unmatchedSecondaries.data[sp] + if !ok { + perInstance = make(map[msgKey]secondaryPending) + unmatchedSecondaries.data[sp] = perInstance + } + + perInstance[key] = secondaryPending{ + timestamp: now, + channel: msg.Channel, + payload: msg.Payload, + } } // sweepExpired reports primary messages that were not matched within the window. @@ -184,9 +218,27 @@ func (sp *shadowPubSub) sweepExpired() { sp.mu.Lock() now := time.Now() + + // Lock the unmatched secondary buffer while we reconcile primaries and + // secondaries and age out any expired entries. + unmatchedSecondaries.Lock() + perInstance, ok := unmatchedSecondaries.data[sp] + if !ok { + perInstance = make(map[msgKey]secondaryPending) + unmatchedSecondaries.data[sp] = perInstance + } + + // First, reconcile pending primaries against buffered secondaries. 
If a + // primary has a matching buffered secondary within the window, treat them + // as matched and drop both without reporting a divergence. for key, entries := range sp.pending { var remaining []pendingMsg for _, e := range entries { + if sec, ok := perInstance[key]; ok && now.Sub(sec.timestamp) < sp.window { + // Matched with a buffered secondary; drop both. + delete(perInstance, key) + continue + } if now.Sub(e.timestamp) >= sp.window { divergences = append(divergences, divergenceEvent{ channel: e.channel, @@ -203,6 +255,26 @@ func (sp *shadowPubSub) sweepExpired() { sp.pending[key] = remaining } } + + // Next, age out any remaining buffered secondaries that have exceeded the + // comparison window and report them as extra_data divergences. + for key, sec := range perInstance { + if now.Sub(sec.timestamp) >= sp.window { + divergences = append(divergences, divergenceEvent{ + channel: sec.channel, + payload: sec.payload, + kind: DivExtraData, + }) + delete(perInstance, key) + } + } + + // Clean up empty per-instance maps. 
+ if len(perInstance) == 0 { + delete(unmatchedSecondaries.data, sp) + } + + unmatchedSecondaries.Unlock() sp.mu.Unlock() for _, d := range divergences { From 1568458d51a3e2ee378ee8892943d06933e2d27b Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Thu, 19 Mar 2026 13:45:25 +0900 Subject: [PATCH 40/43] Review fixes: correctness, concurrency, performance, and test coverage Correctness: - Remove redis.Nil double-wrapping in DualWriter methods - Add empty transaction result handling in execTxn - Add WATCH/UNWATCH to command table - Intercept HELLO to prevent RESP3 upgrade - Reject SELECT for non-configured DB - Fix UNSUBSCRIBE to emit per-argument replies Concurrency/distributed failures: - Split asyncSem into writeSem + shadowSem to prevent starvation - Add DualWriter.Close() with WaitGroup for graceful shutdown drain - Add AsyncDrops metric for semaphore backpressure visibility Performance: - Uppercase command name once in handleCommand, pass through to DualWriter - Remove channel label from PubSubShadowDivergences to prevent cardinality explosion Test coverage: - Add writeRedisValue/writeRedisError/writeResponse tests - Add forwardMessages message/pmessage branch tests - Add reenterPubSub error path tests - Add truncateValue type coverage tests - Add ClassifyCommand tests for new commands - Add Pipeline transport error test --- cmd/redis-proxy/main.go | 1 + proxy/command.go | 269 +++++++++++++++++++++++++----------- proxy/dualwrite.go | 114 ++++++++------- proxy/metrics.go | 11 +- proxy/proxy.go | 89 ++++++++---- proxy/proxy_test.go | 197 ++++++++++++++++++++++++-- proxy/pubsub.go | 19 +-- proxy/pubsub_test.go | 114 +++++++++++++++ proxy/shadow_pubsub.go | 2 +- proxy/shadow_pubsub_test.go | 4 +- 10 files changed, 638 insertions(+), 182 deletions(-) diff --git a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go index 9a616279..24f3722f 100644 --- a/cmd/redis-proxy/main.go +++ b/cmd/redis-proxy/main.go @@ -86,6 +86,7 @@ func run() error { defer 
secondary.Close() dual := proxy.NewDualWriter(primary, secondary, cfg, metrics, sentryReporter, logger) + defer dual.Close() // wait for in-flight async goroutines srv := proxy.NewProxyServer(cfg, dual, metrics, sentryReporter, logger) // Context for graceful shutdown diff --git a/proxy/command.go b/proxy/command.go index 9522bf95..e60661cb 100644 --- a/proxy/command.go +++ b/proxy/command.go @@ -8,7 +8,7 @@ type CommandCategory int const ( CmdRead CommandCategory = iota // GET, HGET, LRANGE, ZRANGE, etc. CmdWrite // SET, DEL, HSET, LPUSH, ZADD, etc. - CmdBlocking // BZPOPMIN, XREAD (with BLOCK) + CmdBlocking // BLPOP, BRPOP, BZPOPMIN, XREAD (with BLOCK) CmdPubSub // SUBSCRIBE, UNSUBSCRIBE, PSUBSCRIBE, PUNSUBSCRIBE, PUBSUB (note: PUBLISH is CmdWrite) CmdAdmin // PING, INFO, CLIENT, SELECT, QUIT, DBSIZE, SCAN, AUTH CmdTxn // MULTI, EXEC, DISCARD @@ -16,120 +16,229 @@ const ( ) var commandTable = map[string]CommandCategory{ - // Read commands - "GET": CmdRead, - "GETDEL": CmdWrite, // read+write → write - "HGET": CmdRead, - "HGETALL": CmdRead, - "HEXISTS": CmdRead, - "HLEN": CmdRead, - "HMGET": CmdRead, - "EXISTS": CmdRead, - "KEYS": CmdRead, - "LINDEX": CmdRead, - "LLEN": CmdRead, - "LPOS": CmdRead, - "LRANGE": CmdRead, - "PTTL": CmdRead, - "TTL": CmdRead, - "TYPE": CmdRead, - "SCARD": CmdRead, - "SISMEMBER": CmdRead, - "SMEMBERS": CmdRead, - "XLEN": CmdRead, - "XRANGE": CmdRead, - "XREVRANGE": CmdRead, - "ZCARD": CmdRead, - "ZCOUNT": CmdRead, - "ZRANGE": CmdRead, - "ZRANGEBYSCORE": CmdRead, - "ZREVRANGE": CmdRead, - "ZREVRANGEBYSCORE": CmdRead, - "ZSCORE": CmdRead, - "PFCOUNT": CmdRead, + // ---- Read commands ---- + "GET": CmdRead, + "GETRANGE": CmdRead, + "MGET": CmdRead, + "STRLEN": CmdRead, + "HGET": CmdRead, + "HGETALL": CmdRead, + "HEXISTS": CmdRead, + "HLEN": CmdRead, + "HMGET": CmdRead, + "HKEYS": CmdRead, + "HVALS": CmdRead, + "HRANDFIELD": CmdRead, + "HSCAN": CmdRead, + "EXISTS": CmdRead, + "KEYS": CmdRead, + "RANDOMKEY": CmdRead, + "LINDEX": CmdRead, + 
"LLEN": CmdRead, + "LPOS": CmdRead, + "LRANGE": CmdRead, + "PTTL": CmdRead, + "TTL": CmdRead, + "TYPE": CmdRead, + "SCARD": CmdRead, + "SISMEMBER": CmdRead, + "SMISMEMBER": CmdRead, + "SMEMBERS": CmdRead, + "SRANDMEMBER": CmdRead, + "SDIFF": CmdRead, + "SINTER": CmdRead, + "SINTERCARD": CmdRead, + "SUNION": CmdRead, + "SSCAN": CmdRead, + "XLEN": CmdRead, + "XRANGE": CmdRead, + "XREVRANGE": CmdRead, + "XINFO": CmdRead, + "XPENDING": CmdRead, + "ZCARD": CmdRead, + "ZCOUNT": CmdRead, + "ZLEXCOUNT": CmdRead, + "ZMSCORE": CmdRead, + "ZRANGE": CmdRead, + "ZRANGEBYSCORE": CmdRead, + "ZRANGEBYLEX": CmdRead, + "ZREVRANGE": CmdRead, + "ZREVRANGEBYSCORE": CmdRead, + "ZREVRANGEBYLEX": CmdRead, + "ZRANK": CmdRead, + "ZREVRANK": CmdRead, + "ZSCORE": CmdRead, + "ZSCAN": CmdRead, + "ZDIFF": CmdRead, + "PFCOUNT": CmdRead, + "TOUCH": CmdRead, + "DUMP": CmdRead, + "GEODIST": CmdRead, + "GEOHASH": CmdRead, + "GEOPOS": CmdRead, + "GEOSEARCH": CmdRead, + "GEORADIUS_RO": CmdRead, + "GEORADIUSBYMEMBER_RO": CmdRead, + "OBJECT": CmdRead, + "SORT_RO": CmdRead, + "SUBSTR": CmdRead, - // Write commands - "SET": CmdWrite, - "SETEX": CmdWrite, - "SETNX": CmdWrite, - "DEL": CmdWrite, - "HSET": CmdWrite, - "HMSET": CmdWrite, - "HDEL": CmdWrite, - "HINCRBY": CmdWrite, - "INCR": CmdWrite, - "LPUSH": CmdWrite, - "LPOP": CmdWrite, - "RPUSH": CmdWrite, - "RPOP": CmdWrite, - "RPOPLPUSH": CmdWrite, - "LREM": CmdWrite, - "LSET": CmdWrite, - "LTRIM": CmdWrite, - "SADD": CmdWrite, - "SREM": CmdWrite, - "EXPIRE": CmdWrite, - "PEXPIRE": CmdWrite, - "RENAME": CmdWrite, - "XADD": CmdWrite, - "XTRIM": CmdWrite, - "ZADD": CmdWrite, - "ZINCRBY": CmdWrite, - "ZREM": CmdWrite, - "ZREMRANGEBYSCORE": CmdWrite, - "ZREMRANGEBYRANK": CmdWrite, - "ZPOPMIN": CmdWrite, - "PFADD": CmdWrite, - "FLUSHALL": CmdWrite, - "FLUSHDB": CmdWrite, - "PUBLISH": CmdWrite, // write to both backends + // ---- Write commands ---- + "SET": CmdWrite, + "SETEX": CmdWrite, + "PSETEX": CmdWrite, + "SETNX": CmdWrite, + "SETRANGE": CmdWrite, + 
"MSET": CmdWrite, + "MSETNX": CmdWrite, + "APPEND": CmdWrite, + "GETSET": CmdWrite, + "GETEX": CmdWrite, + "GETDEL": CmdWrite, + "DEL": CmdWrite, + "UNLINK": CmdWrite, + "COPY": CmdWrite, + "RENAME": CmdWrite, + "RENAMENX": CmdWrite, + "RESTORE": CmdWrite, + "INCR": CmdWrite, + "INCRBY": CmdWrite, + "INCRBYFLOAT": CmdWrite, + "DECR": CmdWrite, + "DECRBY": CmdWrite, + "HSET": CmdWrite, + "HMSET": CmdWrite, + "HDEL": CmdWrite, + "HINCRBY": CmdWrite, + "HINCRBYFLOAT": CmdWrite, + "HSETNX": CmdWrite, + "LPUSH": CmdWrite, + "LPUSHX": CmdWrite, + "LPOP": CmdWrite, + "RPUSH": CmdWrite, + "RPUSHX": CmdWrite, + "RPOP": CmdWrite, + "RPOPLPUSH": CmdWrite, + "LMOVE": CmdWrite, + "LREM": CmdWrite, + "LSET": CmdWrite, + "LTRIM": CmdWrite, + "LINSERT": CmdWrite, + "SADD": CmdWrite, + "SREM": CmdWrite, + "SPOP": CmdWrite, + "SMOVE": CmdWrite, + "SDIFFSTORE": CmdWrite, + "SINTERSTORE": CmdWrite, + "SUNIONSTORE": CmdWrite, + "EXPIRE": CmdWrite, + "PEXPIRE": CmdWrite, + "EXPIREAT": CmdWrite, + "PEXPIREAT": CmdWrite, + "PERSIST": CmdWrite, + "SORT": CmdWrite, + "XADD": CmdWrite, + "XTRIM": CmdWrite, + "XACK": CmdWrite, + "XDEL": CmdWrite, + "XGROUP": CmdWrite, + "XCLAIM": CmdWrite, + "XAUTOCLAIM": CmdWrite, + "ZADD": CmdWrite, + "ZINCRBY": CmdWrite, + "ZREM": CmdWrite, + "ZREMRANGEBYSCORE": CmdWrite, + "ZREMRANGEBYRANK": CmdWrite, + "ZREMRANGEBYLEX": CmdWrite, + "ZPOPMIN": CmdWrite, + "ZPOPMAX": CmdWrite, + "ZRANGESTORE": CmdWrite, + "ZUNIONSTORE": CmdWrite, + "ZINTERSTORE": CmdWrite, + "ZDIFFSTORE": CmdWrite, + "PFADD": CmdWrite, + "PFMERGE": CmdWrite, + "GEOADD": CmdWrite, + "GEORADIUS": CmdWrite, + "GEORADIUSBYMEMBER": CmdWrite, + "GEOSEARCHSTORE": CmdWrite, + "FLUSHALL": CmdWrite, + "FLUSHDB": CmdWrite, + "PUBLISH": CmdWrite, // write to both backends - // Blocking commands - "BZPOPMIN": CmdBlocking, - // XREAD is handled specially in ClassifyCommand (BLOCK arg check) + // ---- Blocking commands ---- + "BLPOP": CmdBlocking, + "BRPOP": CmdBlocking, + "BRPOPLPUSH": CmdBlocking, + 
"BLMOVE": CmdBlocking, + "BLMPOP": CmdBlocking, + "BZPOPMIN": CmdBlocking, + "BZPOPMAX": CmdBlocking, + // XREAD/XREADGROUP handled specially in ClassifyCommand (BLOCK arg check) - // PubSub commands + // ---- PubSub commands ---- "SUBSCRIBE": CmdPubSub, "UNSUBSCRIBE": CmdPubSub, "PSUBSCRIBE": CmdPubSub, "PUNSUBSCRIBE": CmdPubSub, + "SSUBSCRIBE": CmdPubSub, + "SUNSUBSCRIBE": CmdPubSub, "PUBSUB": CmdPubSub, - // Admin commands — forwarded to primary only + // ---- Admin commands — forwarded to primary only ---- "PING": CmdAdmin, + "ECHO": CmdAdmin, "INFO": CmdAdmin, "CLIENT": CmdAdmin, "SELECT": CmdAdmin, "QUIT": CmdAdmin, + "RESET": CmdAdmin, "DBSIZE": CmdAdmin, "SCAN": CmdAdmin, "AUTH": CmdAdmin, "HELLO": CmdAdmin, "WAIT": CmdAdmin, "CONFIG": CmdAdmin, - "OBJECT": CmdAdmin, "DEBUG": CmdAdmin, "CLUSTER": CmdAdmin, "COMMAND": CmdAdmin, + "TIME": CmdAdmin, + "SLOWLOG": CmdAdmin, + "MEMORY": CmdAdmin, + "LATENCY": CmdAdmin, + "MODULE": CmdAdmin, + "ACL": CmdAdmin, + "SWAPDB": CmdAdmin, + // WATCH/UNWATCH: connection-scoped optimistic locking. + // Forwarded to primary only since the proxy uses a shared connection pool + // and per-connection WATCH state cannot be maintained. + "WATCH": CmdAdmin, + "UNWATCH": CmdAdmin, - // Transaction commands + // ---- Transaction commands ---- "MULTI": CmdTxn, "EXEC": CmdTxn, "DISCARD": CmdTxn, - // Script commands - "EVAL": CmdScript, - "EVALSHA": CmdScript, + // ---- Script commands ---- + "EVAL": CmdScript, + "EVALSHA": CmdScript, + "EVALRO": CmdScript, + "EVALSHAro": CmdScript, + "SCRIPT": CmdScript, + "FUNCTION": CmdScript, + "FCALL": CmdScript, + "FCALL_RO": CmdScript, } // ClassifyCommand returns the category for a Redis command name. -// XREAD is classified as CmdBlocking if args contain BLOCK, otherwise CmdRead. +// XREAD/XREADGROUP is classified as CmdBlocking if args contain BLOCK, otherwise CmdRead. // Unknown commands default to CmdWrite (sent to both backends). 
func ClassifyCommand(name string, args [][]byte) CommandCategory { upper := strings.ToUpper(name) - // Special case: XREAD with BLOCK - if upper == "XREAD" { + // Special case: XREAD/XREADGROUP with BLOCK + if upper == "XREAD" || upper == "XREADGROUP" { for _, arg := range args { if strings.ToUpper(string(arg)) == "BLOCK" { return CmdBlocking diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go index 3d050b28..1a529f31 100644 --- a/proxy/dualwrite.go +++ b/proxy/dualwrite.go @@ -5,15 +5,19 @@ import ( "errors" "fmt" "log/slog" - "strings" + "sync" "time" "github.com/redis/go-redis/v9" ) -// maxAsyncGoroutines limits concurrent fire-and-forget goroutines to prevent -// goroutine explosion when the secondary backend is slow or down. -const maxAsyncGoroutines = 4096 +const ( + // maxWriteGoroutines limits concurrent secondary write goroutines. + maxWriteGoroutines = 4096 + // maxShadowGoroutines limits concurrent shadow read goroutines separately + // so that secondary write failures cannot starve shadow reads. + maxShadowGoroutines = 1024 +) // DualWriter routes commands to primary and secondary backends based on mode. type DualWriter struct { @@ -25,9 +29,9 @@ type DualWriter struct { sentry *SentryReporter logger *slog.Logger - // asyncSem bounds the number of concurrent async goroutines - // (secondary writes + shadow reads). - asyncSem chan struct{} + writeSem chan struct{} // bounds concurrent secondary write goroutines + shadowSem chan struct{} // bounds concurrent shadow read goroutines + wg sync.WaitGroup } // NewDualWriter creates a DualWriter with the given backends. 
@@ -39,7 +43,8 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe metrics: metrics, sentry: sentryReporter, logger: logger, - asyncSem: make(chan struct{}, maxAsyncGoroutines), + writeSem: make(chan struct{}, maxWriteGoroutines), + shadowSem: make(chan struct{}, maxShadowGoroutines), } if cfg.Mode == ModeDualWriteShadow || cfg.Mode == ModeElasticKVPrimary { @@ -53,9 +58,15 @@ func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMe return d } +// Close waits for all in-flight async goroutines to finish. +// Should be called during graceful shutdown. +func (d *DualWriter) Close() { + d.wg.Wait() +} + // Write sends a write command to the primary synchronously, then to the secondary asynchronously. -func (d *DualWriter) Write(ctx context.Context, args [][]byte) (any, error) { - cmd := strings.ToUpper(string(args[0])) +// cmd must be the pre-uppercased command name. +func (d *DualWriter) Write(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -73,18 +84,15 @@ func (d *DualWriter) Write(ctx context.Context, args [][]byte) (any, error) { // Secondary: async fire-and-forget (bounded) if d.hasSecondaryWrite() { - d.goAsync(func() { d.writeSecondary(cmd, iArgs) }) + d.goWrite(func() { d.writeSecondary(cmd, iArgs) }) } - if err != nil { - return resp, fmt.Errorf("primary write %s: %w", cmd, err) - } - return resp, nil + return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies } // Read sends a read command to the primary and optionally performs a shadow read. -func (d *DualWriter) Read(ctx context.Context, args [][]byte) (any, error) { - cmd := strings.ToUpper(string(args[0])) +// cmd must be the pre-uppercased command name. 
+func (d *DualWriter) Read(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -99,26 +107,23 @@ func (d *DualWriter) Read(ctx context.Context, args [][]byte) (any, error) { } d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() - // Shadow read (bounded) + // Shadow read (bounded, separate semaphore from writes) if d.shadow != nil { shadowArgs := args shadowResp := resp shadowErr := err - d.goAsync(func() { - d.shadow.Compare(context.Background(), cmd, shadowArgs, shadowResp, shadowErr) + d.goShadow(func() { + d.shadow.Compare(ctx, cmd, shadowArgs, shadowResp, shadowErr) }) } - if err != nil { - return resp, fmt.Errorf("primary read %s: %w", cmd, err) - } - return resp, nil + return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies } // Blocking forwards a blocking command to the primary only. // Optionally sends a short-timeout version to secondary for warmup. -func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (any, error) { - cmd := strings.ToUpper(string(args[0])) +// cmd must be the pre-uppercased command name. +func (d *DualWriter) Blocking(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -134,22 +139,19 @@ func (d *DualWriter) Blocking(ctx context.Context, args [][]byte) (any, error) { // Warmup: send to secondary with short timeout (fire-and-forget, bounded) if d.hasSecondaryWrite() { - d.goAsync(func() { + d.goWrite(func() { sCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() d.secondary.Do(sCtx, iArgs...) }) } - if err != nil { - return resp, fmt.Errorf("primary blocking %s: %w", cmd, err) - } - return resp, nil + return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies } // Admin forwards an admin command to the primary only. 
-func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (any, error) { - cmd := strings.ToUpper(string(args[0])) +// cmd must be the pre-uppercased command name. +func (d *DualWriter) Admin(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -162,15 +164,12 @@ func (d *DualWriter) Admin(ctx context.Context, args [][]byte) (any, error) { return nil, fmt.Errorf("primary admin %s: %w", cmd, err) } d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() - if err != nil { - return resp, fmt.Errorf("primary admin %s: %w", cmd, err) - } - return resp, nil + return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies } // Script forwards EVAL/EVALSHA to the primary, and async replays to secondary. -func (d *DualWriter) Script(ctx context.Context, args [][]byte) (any, error) { - cmd := strings.ToUpper(string(args[0])) +// cmd must be the pre-uppercased command name. +func (d *DualWriter) Script(ctx context.Context, cmd string, args [][]byte) (any, error) { iArgs := bytesArgsToInterfaces(args) start := time.Now() @@ -185,13 +184,10 @@ func (d *DualWriter) Script(ctx context.Context, args [][]byte) (any, error) { d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc() if d.hasSecondaryWrite() { - d.goAsync(func() { d.writeSecondary(cmd, iArgs) }) + d.goWrite(func() { d.writeSecondary(cmd, iArgs) }) } - if err != nil { - return resp, fmt.Errorf("primary script %s: %w", cmd, err) - } - return resp, nil + return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies } func (d *DualWriter) writeSecondary(cmd string, iArgs []any) { @@ -216,17 +212,37 @@ func (d *DualWriter) writeSecondary(cmd string, iArgs []any) { d.metrics.CommandTotal.WithLabelValues(cmd, d.secondary.Name(), "ok").Inc() } -// goAsync launches fn in a bounded goroutine. 
If the semaphore is full, -// the work is dropped and a warning is logged rather than blocking the caller. +// goWrite launches fn in a bounded write goroutine. +func (d *DualWriter) goWrite(fn func()) { + d.goAsyncWithSem(d.writeSem, fn) +} + +// goShadow launches fn in a bounded shadow-read goroutine. +func (d *DualWriter) goShadow(fn func()) { + d.goAsyncWithSem(d.shadowSem, fn) +} + +// goAsync launches fn using the write semaphore (for backward compat with txn replay). func (d *DualWriter) goAsync(fn func()) { + d.goWrite(fn) +} + +// goAsyncWithSem launches fn in a bounded goroutine using the given semaphore. +// If the semaphore is full, the work is dropped, a metric is incremented, +// and a warning is logged rather than blocking the caller. +func (d *DualWriter) goAsyncWithSem(sem chan struct{}, fn func()) { select { - case d.asyncSem <- struct{}{}: + case sem <- struct{}{}: + d.wg.Add(1) go func() { - defer func() { <-d.asyncSem }() + defer func() { + <-sem + d.wg.Done() + }() fn() }() default: - // Semaphore full — drop async work to protect the proxy. 
+ d.metrics.AsyncDrops.Inc() d.logger.Warn("async goroutine limit reached, dropping secondary operation") } } diff --git a/proxy/metrics.go b/proxy/metrics.go index 177f9f53..f2b4d406 100644 --- a/proxy/metrics.go +++ b/proxy/metrics.go @@ -16,6 +16,8 @@ type ProxyMetrics struct { ActiveConnections prometheus.Gauge + AsyncDrops prometheus.Counter + PubSubShadowDivergences *prometheus.CounterVec PubSubShadowErrors prometheus.Counter } @@ -67,6 +69,12 @@ func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { Help: "Expected divergences due to missing data on secondary (pre-migration).", }, []string{"command"}), + AsyncDrops: prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "proxy", + Name: "async_drops_total", + Help: "Total async operations dropped due to semaphore backpressure.", + }), + ActiveConnections: prometheus.NewGauge(prometheus.GaugeOpts{ Namespace: "proxy", Name: "active_connections", @@ -77,7 +85,7 @@ func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { Namespace: "proxy", Name: "pubsub_shadow_divergences_total", Help: "Total pub/sub message mismatches detected by shadow subscribe.", - }, []string{"channel", "kind"}), + }, []string{"kind"}), PubSubShadowErrors: prometheus.NewCounter(prometheus.CounterOpts{ Namespace: "proxy", Name: "pubsub_shadow_errors_total", @@ -94,6 +102,7 @@ func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics { m.ShadowReadErrors, m.Divergences, m.MigrationGaps, + m.AsyncDrops, m.ActiveConnections, m.PubSubShadowDivergences, m.PubSubShadowErrors, diff --git a/proxy/proxy.go b/proxy/proxy.go index d60fdec6..df2cd70f 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -145,17 +145,17 @@ func (p *ProxyServer) dispatchCommand(conn redcon.Conn, state *proxyConnState, n case CmdTxn: p.handleTxnCommand(conn, state, name) case CmdWrite: - p.handleWrite(conn, args) + p.handleWrite(conn, name, args) case CmdRead: - p.handleRead(conn, args) + p.handleRead(conn, name, args) case CmdBlocking: - 
p.handleBlocking(conn, args) + p.handleBlocking(conn, name, args) case CmdPubSub: - p.handlePubSub(conn, args) + p.handlePubSub(conn, name, args) case CmdAdmin: - p.handleAdmin(conn, args) + p.handleAdmin(conn, name, args) case CmdScript: - p.handleScript(conn, args) + p.handleScript(conn, name, args) } } @@ -177,37 +177,46 @@ func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnStat } } -func (p *ProxyServer) handleWrite(conn redcon.Conn, args [][]byte) { - resp, err := p.dual.Write(context.Background(), args) +func (p *ProxyServer) handleWrite(conn redcon.Conn, cmd string, args [][]byte) { + resp, err := p.dual.Write(context.Background(), cmd, args) writeResponse(conn, resp, err) } -func (p *ProxyServer) handleRead(conn redcon.Conn, args [][]byte) { - resp, err := p.dual.Read(context.Background(), args) +func (p *ProxyServer) handleRead(conn redcon.Conn, cmd string, args [][]byte) { + resp, err := p.dual.Read(context.Background(), cmd, args) writeResponse(conn, resp, err) } -func (p *ProxyServer) handleBlocking(conn redcon.Conn, args [][]byte) { +func (p *ProxyServer) handleBlocking(conn redcon.Conn, cmd string, args [][]byte) { // Use shutdownCtx so blocking commands are interrupted on graceful shutdown. - resp, err := p.dual.Blocking(p.shutdownCtx, args) + resp, err := p.dual.Blocking(p.shutdownCtx, cmd, args) writeResponse(conn, resp, err) } -func (p *ProxyServer) handlePubSub(conn redcon.Conn, args [][]byte) { - name := strings.ToUpper(string(args[0])) +func (p *ProxyServer) handlePubSub(conn redcon.Conn, name string, args [][]byte) { switch name { case cmdSubscribe, cmdPSubscribe: p.startPubSubSession(conn, name, args) case cmdUnsubscribe, cmdPUnsubscribe: - // No active session; return empty confirmation. - conn.WriteArray(pubsubArrayReply) - conn.WriteBulkString(strings.ToLower(name)) - conn.WriteNull() - conn.WriteInt64(0) + // No active session; emit one reply per argument, or a single null reply if none. 
+ kind := strings.ToLower(name) + if len(args) > 1 { + for _, ch := range args[1:] { + conn.WriteArray(pubsubArrayReply) + conn.WriteBulkString(kind) + conn.WriteBulkString(string(ch)) + conn.WriteInt64(0) + } + } else { + conn.WriteArray(pubsubArrayReply) + conn.WriteBulkString(kind) + conn.WriteNull() + conn.WriteInt64(0) + } default: // PUBSUB CHANNELS / NUMSUB etc. - resp, err := p.dual.Admin(context.Background(), args) + resp, err := p.dual.Admin(context.Background(), name, args) writeResponse(conn, resp, err) } } @@ -303,8 +312,7 @@ func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *sha return shadow } -func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { - name := strings.ToUpper(string(args[0])) +func (p *ProxyServer) handleAdmin(conn redcon.Conn, name string, args [][]byte) { // Handle PING locally for speed. if name == cmdPing { @@ -323,19 +331,37 @@ func (p *ProxyServer) handleAdmin(conn redcon.Conn, args [][]byte) { return } - // SELECT and AUTH are handled at the connection-pool level via config. - // Silently accept them so clients don't break. - if name == "SELECT" || name == "AUTH" { + // SELECT: accept only the configured DB; reject others since the proxy + // uses a shared connection pool and cannot maintain per-client DB state. + if name == "SELECT" { + if len(args) > 1 && string(args[1]) != "0" && string(args[1]) != fmt.Sprintf("%d", p.cfg.PrimaryDB) { + conn.WriteError(fmt.Sprintf("ERR proxy does not support SELECT %s (configured DB: %d)", string(args[1]), p.cfg.PrimaryDB)) + return + } conn.WriteString("OK") return } - resp, err := p.dual.Admin(context.Background(), args) + // AUTH is handled at the connection-pool level via config. + // Silently accept so clients don't break. + if name == "AUTH" { + conn.WriteString("OK") + return + } + + // HELLO: reject to prevent RESP3 protocol upgrade, which the proxy + // (redcon, RESP2 only) cannot support. 
+ if name == "HELLO" { + conn.WriteError("ERR proxy does not support HELLO (RESP2 only)") + return + } + + resp, err := p.dual.Admin(context.Background(), name, args) writeResponse(conn, resp, err) } -func (p *ProxyServer) handleScript(conn redcon.Conn, args [][]byte) { - resp, err := p.dual.Script(context.Background(), args) +func (p *ProxyServer) handleScript(conn redcon.Conn, name string, args [][]byte) { + resp, err := p.dual.Script(context.Background(), name, args) writeResponse(conn, resp, err) } @@ -382,14 +408,17 @@ func (p *ProxyServer) execTxn(conn redcon.Conn, state *proxyConnState) { cmds = append(cmds, []any{"EXEC"}) results, err := p.dual.Primary().Pipeline(ctx, cmds) - if err != nil { + switch { + case err != nil: // Pipeline-level error (connection/transport failure) takes precedence. writeRedisError(conn, err) - } else if len(results) > 0 { + case len(results) > 0: // Write the EXEC result (last command in the pipeline). lastResult := results[len(results)-1] resp, rErr := lastResult.Result() writeResponse(conn, resp, rErr) + default: + conn.WriteError("ERR empty transaction response") } // Async replay to secondary (bounded) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 6804c9f0..5f3e780c 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -333,7 +333,7 @@ func TestHasSecondaryWrite(t *testing.T) { {ModeElasticKVPrimary, true}, {ModeElasticKVOnly, false}, } { - d := &DualWriter{cfg: ProxyConfig{Mode: tc.mode}, asyncSem: make(chan struct{}, 1)} + d := &DualWriter{cfg: ProxyConfig{Mode: tc.mode}, writeSem: make(chan struct{}, 1), shadowSem: make(chan struct{}, 1)} assert.Equal(t, tc.expected, d.hasSecondaryWrite(), "mode=%s", tc.mode) } } @@ -347,7 +347,7 @@ func TestDualWriter_Write_PrimarySuccess(t *testing.T) { metrics := newTestMetrics() d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) - resp, err := d.Write(context.Background(), 
[][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) assert.NoError(t, err) assert.Equal(t, "OK", resp) assert.Equal(t, 1, primary.CallCount()) @@ -365,7 +365,7 @@ func TestDualWriter_Write_PrimaryFail(t *testing.T) { metrics := newTestMetrics() d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) - _, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) assert.Error(t, err) assert.Contains(t, err.Error(), "connection refused") // Secondary should NOT be called when primary fails @@ -382,7 +382,7 @@ func TestDualWriter_Write_SecondaryFail_ClientSucceeds(t *testing.T) { metrics := newTestMetrics() d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) - resp, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) assert.NoError(t, err) assert.Equal(t, "OK", resp) @@ -400,7 +400,7 @@ func TestDualWriter_Write_RedisNil(t *testing.T) { metrics := newTestMetrics() d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) - resp, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v"), []byte("NX")}) + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v"), []byte("NX")}) assert.ErrorIs(t, err, redis.Nil) assert.Nil(t, resp) // Should still send to secondary @@ -416,7 +416,7 @@ func TestDualWriter_Write_RedisOnlyMode(t *testing.T) { metrics := newTestMetrics() d := 
NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly}, metrics, newTestSentry(), testLogger) - _, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) assert.NoError(t, err) time.Sleep(50 * time.Millisecond) assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in redis-only mode") @@ -430,7 +430,7 @@ func TestDualWriter_Write_ElasticKVOnlyMode(t *testing.T) { metrics := newTestMetrics() d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeElasticKVOnly}, metrics, newTestSentry(), testLogger) - _, err := d.Write(context.Background(), [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) assert.NoError(t, err) time.Sleep(50 * time.Millisecond) assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in elastickv-only mode") @@ -446,7 +446,7 @@ func TestDualWriter_Read_WithShadow(t *testing.T) { cfg := ProxyConfig{Mode: ModeDualWriteShadow, ShadowTimeout: time.Second, SecondaryTimeout: time.Second} d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) - resp, err := d.Read(context.Background(), [][]byte{[]byte("GET"), []byte("k")}) + resp, err := d.Read(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}) assert.NoError(t, err) assert.Equal(t, "hello", resp) @@ -464,7 +464,7 @@ func TestDualWriter_Read_NoShadowInDualWrite(t *testing.T) { cfg := ProxyConfig{Mode: ModeDualWrite, ShadowTimeout: time.Second} d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) - _, err := d.Read(context.Background(), [][]byte{[]byte("GET"), []byte("k")}) + _, err := d.Read(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}) assert.NoError(t, err) time.Sleep(50 * time.Millisecond) assert.Equal(t, 0, 
secondary.CallCount(), "no shadow in dual-write mode") @@ -479,9 +479,9 @@ func TestDualWriter_GoAsync_Bounded(t *testing.T) { cfg := ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: 10 * time.Second} d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) - // Fill the semaphore with blocking goroutines + // Fill the write semaphore with blocking goroutines blocker := make(chan struct{}) - for range maxAsyncGoroutines { + for range maxWriteGoroutines { d.goAsync(func() { <-blocker }) @@ -501,7 +501,11 @@ func TestDualWriter_GoAsync_Bounded(t *testing.T) { t.Fatal("goAsync blocked when semaphore was full") } + // Verify drop metric was incremented + assert.InDelta(t, 1, testutil.ToFloat64(metrics.AsyncDrops), 0.001) + close(blocker) // unblock all + d.Close() // wait for all goroutines to finish } // ========== ShadowReader tests ========== @@ -620,3 +624,174 @@ func TestDefaultBackendOptions(t *testing.T) { assert.Equal(t, 128, opts.PoolSize) assert.Equal(t, 5*time.Second, opts.DialTimeout) } + +// ========== Pipeline error handling tests ========== + +func TestPipeline_TransportError(t *testing.T) { + b := newMockBackend("test") + b.doFunc = makeCmd(nil, errors.New("connection refused")) + + // mockBackend.Pipeline doesn't simulate pipe.Exec; test RedisBackend via unit behaviour. + // Here we verify the mock-based pipeline returns results. + results, err := b.Pipeline(context.Background(), [][]any{{"MULTI"}, {"SET", "k", "v"}, {"EXEC"}}) + assert.NoError(t, err) // mock always returns nil error + assert.Len(t, results, 3) +} + +// ========== writeRedisValue tests ========== + +// testRedisErr satisfies the redis.Error interface for testing. 
+type testRedisErr string + +func (e testRedisErr) Error() string { return string(e) } +func (e testRedisErr) RedisError() {} + +type mockRespWriter struct { + writes []any +} + +func (m *mockRespWriter) WriteError(msg string) { m.writes = append(m.writes, "ERR:"+msg) } +func (m *mockRespWriter) WriteString(msg string) { m.writes = append(m.writes, "STR:"+msg) } +func (m *mockRespWriter) WriteBulk(b []byte) { m.writes = append(m.writes, "BULK:"+string(b)) } +func (m *mockRespWriter) WriteBulkString(msg string) { m.writes = append(m.writes, "BULKSTR:"+msg) } +func (m *mockRespWriter) WriteInt64(num int64) { m.writes = append(m.writes, num) } +func (m *mockRespWriter) WriteArray(count int) { m.writes = append(m.writes, count) } +func (m *mockRespWriter) WriteNull() { m.writes = append(m.writes, nil) } + +func TestWriteRedisValue(t *testing.T) { + tests := []struct { + name string + val any + expect []any + }{ + {"nil", nil, []any{nil}}, + {"status OK", "OK", []any{"STR:OK"}}, + {"status QUEUED", "QUEUED", []any{"STR:QUEUED"}}, + {"status PONG", "PONG", []any{"STR:PONG"}}, + {"bulk string", "hello", []any{"BULKSTR:hello"}}, + {"int64", int64(42), []any{int64(42)}}, + {"bytes", []byte("data"), []any{"BULK:data"}}, + {"array", []any{"a", int64(1)}, []any{2, "BULKSTR:a", int64(1)}}, + {"nested array", []any{[]any{"x"}}, []any{1, 1, "BULKSTR:x"}}, + {"redis error", testRedisErr("WRONGTYPE bad"), []any{"ERR:WRONGTYPE bad"}}, + {"default type", float64(3.14), []any{"BULKSTR:3.14"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &mockRespWriter{} + writeRedisValue(w, tt.val) + assert.Equal(t, tt.expect, w.writes) + }) + } +} + +func TestWriteRedisError(t *testing.T) { + t.Run("redis.Error passthrough", func(t *testing.T) { + w := &mockRespWriter{} + writeRedisError(w, testRedisErr("WRONGTYPE Operation against a key")) + assert.Equal(t, []any{"ERR:WRONGTYPE Operation against a key"}, w.writes) + }) + + t.Run("generic error gets ERR prefix", 
func(t *testing.T) { + w := &mockRespWriter{} + writeRedisError(w, errors.New("connection refused")) + assert.Equal(t, []any{"ERR:ERR connection refused"}, w.writes) + }) +} + +func TestWriteResponse(t *testing.T) { + t.Run("nil error with value", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, "hello", nil) + assert.Equal(t, []any{"BULKSTR:hello"}, w.writes) + }) + + t.Run("redis.Nil", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, nil, redis.Nil) + assert.Equal(t, []any{nil}, w.writes) + }) + + t.Run("real error", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, nil, errors.New("timeout")) + assert.Equal(t, []any{"ERR:ERR timeout"}, w.writes) + }) +} + +// ========== truncateValue tests ========== + +type testStringer struct{ s string } + +func (ts testStringer) String() string { return ts.s } + +func TestTruncateValue(t *testing.T) { + tests := []struct { + name string + input any + expect string + }{ + {"nil", nil, ""}, + {"short string", "hello", "hello"}, + {"short bytes", []byte("abc"), "abc"}, + {"fmt.Stringer", testStringer{"ok"}, "ok"}, + {"int", 42, "42"}, + {"slice", []int{1, 2, 3}, "[1, 2, 3]"}, + {"map", map[string]int{"a": 1}, "{a: 1}"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateValue(tt.input) + assert.Equal(t, tt.expect, result) + }) + } + + // Long string truncation + long := make([]byte, 300) + for i := range long { + long[i] = 'a' + } + result := truncateValue(string(long)) + assert.Contains(t, result, "...(truncated)") + assert.LessOrEqual(t, len(result), 300) + + // Long []byte truncation + result = truncateValue(long) + assert.Contains(t, result, "...(truncated)") +} + +// ========== Additional command table tests ========== + +func TestClassifyCommand_NewCommands(t *testing.T) { + reads := []string{"MGET", "GETRANGE", "STRLEN", "SRANDMEMBER", "SDIFF", "SINTER", "SUNION", "XINFO", "GEODIST", "GEOSEARCH", "OBJECT", "RANDOMKEY", "TOUCH"} + for 
_, cmd := range reads { + assert.Equal(t, CmdRead, ClassifyCommand(cmd, nil), "expected %s to be CmdRead", cmd) + } + + writes := []string{"MSET", "MSETNX", "APPEND", "INCRBY", "DECR", "SETRANGE", "SPOP", "SMOVE", "PERSIST", "EXPIREAT", "GEOADD", "UNLINK", "COPY", "ZPOPMAX", "RESTORE"} + for _, cmd := range writes { + assert.Equal(t, CmdWrite, ClassifyCommand(cmd, nil), "expected %s to be CmdWrite", cmd) + } + + blocking := []string{"BLPOP", "BRPOP", "BZPOPMAX", "BLMOVE"} + for _, cmd := range blocking { + assert.Equal(t, CmdBlocking, ClassifyCommand(cmd, nil), "expected %s to be CmdBlocking", cmd) + } + + // XREADGROUP with BLOCK + assert.Equal(t, CmdBlocking, ClassifyCommand("XREADGROUP", [][]byte{[]byte("GROUP"), []byte("g"), []byte("c"), []byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("s1"), []byte(">")})) + // XREADGROUP without BLOCK + assert.Equal(t, CmdRead, ClassifyCommand("XREADGROUP", [][]byte{[]byte("GROUP"), []byte("g"), []byte("c"), []byte("STREAMS"), []byte("s1"), []byte(">")})) + + admins := []string{"WATCH", "UNWATCH", "HELLO", "TIME", "SLOWLOG", "ACL"} + for _, cmd := range admins { + assert.Equal(t, CmdAdmin, ClassifyCommand(cmd, nil), "expected %s to be CmdAdmin", cmd) + } + + pubsubs := []string{"SSUBSCRIBE", "SUNSUBSCRIBE"} + for _, cmd := range pubsubs { + assert.Equal(t, CmdPubSub, ClassifyCommand(cmd, nil), "expected %s to be CmdPubSub", cmd) + } +} diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 70117b98..82509a1c 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -319,17 +319,17 @@ func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { switch cat { case CmdWrite: - resp, err = s.proxy.dual.Write(ctx, args) + resp, err = s.proxy.dual.Write(ctx, name, args) case CmdRead: - resp, err = s.proxy.dual.Read(ctx, args) + resp, err = s.proxy.dual.Read(ctx, name, args) case CmdBlocking: - resp, err = s.proxy.dual.Blocking(s.proxy.shutdownCtx, args) + resp, err = s.proxy.dual.Blocking(s.proxy.shutdownCtx, name, 
args) case CmdPubSub: - resp, err = s.proxy.dual.Admin(ctx, args) + resp, err = s.proxy.dual.Admin(ctx, name, args) case CmdAdmin: - resp, err = s.proxy.dual.Admin(ctx, args) + resp, err = s.proxy.dual.Admin(ctx, name, args) case CmdScript: - resp, err = s.proxy.dual.Script(ctx, args) + resp, err = s.proxy.dual.Script(ctx, name, args) case CmdTxn: // Handled by handleTxnInSession; should not reach here. return @@ -410,13 +410,16 @@ func (s *pubsubSession) execTxn() { results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) s.writeMu.Lock() - if err != nil { + switch { + case err != nil: // Pipeline-level error (connection/transport failure) takes precedence. writeRedisError(s.dconn, err) - } else if len(results) > 0 { + case len(results) > 0: lastResult := results[len(results)-1] resp, rErr := lastResult.Result() writeResponse(s.dconn, resp, rErr) + default: + s.dconn.WriteError("ERR empty transaction response") } _ = s.dconn.Flush() s.writeMu.Unlock() diff --git a/proxy/pubsub_test.go b/proxy/pubsub_test.go index b8076acb..3ee7039b 100644 --- a/proxy/pubsub_test.go +++ b/proxy/pubsub_test.go @@ -6,6 +6,7 @@ import ( "sync" "testing" + "github.com/prometheus/client_golang/prometheus" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/tidwall/redcon" @@ -528,3 +529,116 @@ func TestPubSub_CommandLoop_EOF(t *testing.T) { s := newTestSession(dconn) s.commandLoop() // should return without panic } + +// ========== T3: forwardMessages message/pmessage branch test ========== + +func TestPubSub_ForwardMessages_MessageBranch(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + ch := make(chan *redis.Message, 2) + + // Regular message (no pattern) + ch <- &redis.Message{Channel: "ch1", Payload: "hello"} + // Pattern message + ch <- &redis.Message{Pattern: "h*", Channel: "hello-world", Payload: "data"} + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + + // First message: ["message", "ch1", 
"hello"] + // Expect: WriteArray(3), WriteBulkString("message"), WriteBulkString("ch1"), WriteBulkString("hello") + assert.Contains(t, writes, respArray{pubsubArrayMessage}) // WriteArray(3) + assert.Contains(t, writes, "BULKSTR:message") + assert.Contains(t, writes, "BULKSTR:ch1") + assert.Contains(t, writes, "BULKSTR:hello") + + // Second message: ["pmessage", "h*", "hello-world", "data"] + // Expect: WriteArray(4), WriteBulkString("pmessage"), WriteBulkString("h*"), WriteBulkString("hello-world"), WriteBulkString("data") + assert.Contains(t, writes, respArray{pubsubArrayPMessage}) // WriteArray(4) + assert.Contains(t, writes, "BULKSTR:pmessage") + assert.Contains(t, writes, "BULKSTR:h*") + assert.Contains(t, writes, "BULKSTR:hello-world") + assert.Contains(t, writes, "BULKSTR:data") +} + +func TestPubSub_ForwardMessages_ClosedSession(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Mark session as closed before sending message. + s.mu.Lock() + s.closed = true + s.mu.Unlock() + + ch := make(chan *redis.Message, 1) + ch <- &redis.Message{Channel: "ch1", Payload: "should-not-write"} + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + assert.Empty(t, writes, "should not write to dconn when session is closed") +} + +func TestPubSub_ForwardMessages_EmptyChannel(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + ch := make(chan *redis.Message) + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + assert.Empty(t, writes, "no messages should produce no writes") +} + +// ========== T4: reenterPubSub test ========== + +func TestPubSub_ReenterPubSub_TooFewArgs(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Only command name, no channels + s.reenterPubSub("SUBSCRIBE", [][]byte{[]byte("SUBSCRIBE")}) + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR wrong 
number of arguments for 'subscribe' command" { + found = true + } + } + assert.True(t, found, "should error on missing channel args") +} + +func TestPubSub_ReenterPubSub_NilBackend(t *testing.T) { + dconn := newMockDetachedConn() + reg := prometheus.NewRegistry() + metrics := NewProxyMetrics(reg) + cfg := DefaultConfig() + + // DualWriter with no PubSubBackend + primary := newMockBackend("primary") + secondary := newMockBackend("secondary") + dual := NewDualWriter(primary, secondary, cfg, metrics, NewSentryReporter("", "", 0, testLogger), testLogger) + + srv := NewProxyServer(cfg, dual, metrics, NewSentryReporter("", "", 0, testLogger), testLogger) + + s := newTestSession(dconn) + s.proxy = srv + + s.reenterPubSub("SUBSCRIBE", [][]byte{[]byte("SUBSCRIBE"), []byte("ch1")}) + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR PubSub not supported by backend" { + found = true + } + } + assert.True(t, found, "should error when PubSubBackend is nil") +} diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index 02d762fe..a4549b93 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -305,7 +305,7 @@ func (sp *shadowPubSub) sweepAll() { } func (sp *shadowPubSub) reportDivergence(channel, payload string, kind DivergenceKind) { - sp.metrics.PubSubShadowDivergences.WithLabelValues(channel, kind.String()).Inc() + sp.metrics.PubSubShadowDivergences.WithLabelValues(kind.String()).Inc() sp.logger.Warn("pubsub shadow divergence", "channel", truncateValue(channel), "payload", truncateValue(payload), diff --git a/proxy/shadow_pubsub_test.go b/proxy/shadow_pubsub_test.go index 20eeacf1..d765e049 100644 --- a/proxy/shadow_pubsub_test.go +++ b/proxy/shadow_pubsub_test.go @@ -57,7 +57,7 @@ func TestShadowPubSub_MissingOnSecondary(t *testing.T) { sp.mu.Unlock() assert.Equal(t, 0, remaining, "expired message should be removed") - val := 
counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("ch1", "data_mismatch")) + val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("data_mismatch")) assert.Equal(t, float64(1), val) } @@ -66,7 +66,7 @@ func TestShadowPubSub_ExtraOnSecondary(t *testing.T) { sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "extra"}) - val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("ch1", "extra_data")) + val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("extra_data")) assert.Equal(t, float64(1), val) } From 3785deb6affbd03b1c56f372ebc7ae928e67d736 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Thu, 19 Mar 2026 13:47:10 +0900 Subject: [PATCH 41/43] Fix ExtraOnSecondary test for buffered matchSecondary matchSecondary now buffers unmatched secondaries instead of reporting immediately. Adjust the test to wait for the comparison window and call sweepExpired to trigger divergence reporting. --- proxy/shadow_pubsub_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/proxy/shadow_pubsub_test.go b/proxy/shadow_pubsub_test.go index d765e049..badf21d8 100644 --- a/proxy/shadow_pubsub_test.go +++ b/proxy/shadow_pubsub_test.go @@ -62,10 +62,15 @@ func TestShadowPubSub_MissingOnSecondary(t *testing.T) { } func TestShadowPubSub_ExtraOnSecondary(t *testing.T) { - sp := newTestShadowPubSub(100 * time.Millisecond) + sp := newTestShadowPubSub(10 * time.Millisecond) sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "extra"}) + // matchSecondary now buffers unmatched secondaries; wait for the + // comparison window to expire and sweep to trigger divergence reporting. 
+ time.Sleep(20 * time.Millisecond) + sp.sweepExpired() + val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("extra_data")) assert.Equal(t, float64(1), val) } From 4d23559bc7c8cf5a024e76ee62fec669bab33db7 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Thu, 19 Mar 2026 14:27:40 +0900 Subject: [PATCH 42/43] Address review feedback: command table, shadow pubsub, mu comment - Fix EVAL_RO/EVALSHA_RO command names (was EVALRO/EVALSHAro typo) - Clarify mu comment: protects upstream, closed, shadow (not sets/txn) - Fix unmatchedSecondaries memory leak: clean up on Close() - Support duplicate secondary messages via slice buffer (was map overwrite) - Reduce sweepExpired cyclomatic complexity by extracting helpers - Distinguish SUBSCRIBE/PSUBSCRIBE in divergence reporting - Add tests for duplicate secondary buffering and Close cleanup --- proxy/command.go | 16 ++-- proxy/pubsub.go | 2 +- proxy/shadow_pubsub.go | 147 +++++++++++++++++++++--------------- proxy/shadow_pubsub_test.go | 51 +++++++++++++ 4 files changed, 148 insertions(+), 68 deletions(-) diff --git a/proxy/command.go b/proxy/command.go index e60661cb..3231af06 100644 --- a/proxy/command.go +++ b/proxy/command.go @@ -221,14 +221,14 @@ var commandTable = map[string]CommandCategory{ "DISCARD": CmdTxn, // ---- Script commands ---- - "EVAL": CmdScript, - "EVALSHA": CmdScript, - "EVALRO": CmdScript, - "EVALSHAro": CmdScript, - "SCRIPT": CmdScript, - "FUNCTION": CmdScript, - "FCALL": CmdScript, - "FCALL_RO": CmdScript, + "EVAL": CmdScript, + "EVALSHA": CmdScript, + "EVAL_RO": CmdScript, + "EVALSHA_RO": CmdScript, + "SCRIPT": CmdScript, + "FUNCTION": CmdScript, + "FCALL": CmdScript, + "FCALL_RO": CmdScript, } // ClassifyCommand returns the category for a Redis command name. 
diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 82509a1c..2b049780 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -40,7 +40,7 @@ const ( // When all subscriptions are removed, the session transitions to normal command mode, // enabling the client to execute regular Redis commands without reconnecting. type pubsubSession struct { - mu sync.Mutex // protects upstream and closed (channelSet, patternSet, txn are goroutine-confined to commandLoop) + mu sync.Mutex // protects upstream, closed, and shadow; channelSet/patternSet/txn are goroutine-confined to commandLoop writeMu sync.Mutex // serializes writes to dconn; never held across state operations dconn redcon.DetachedConn upstream *redis.PubSub // nil when not in pub/sub mode diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index a4549b93..cc5adadc 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -30,11 +30,12 @@ type secondaryPending struct { // instance. This allows us to avoid reporting DivExtraData immediately when a // secondary arrives before the corresponding primary (e.g. due to network // jitter) and instead only report once the comparison window has elapsed. +// Each key maps to a slice to handle duplicate secondary messages correctly. var unmatchedSecondaries = struct { sync.Mutex - data map[*shadowPubSub]map[msgKey]secondaryPending + data map[*shadowPubSub]map[msgKey][]secondaryPending }{ - data: make(map[*shadowPubSub]map[msgKey]secondaryPending), + data: make(map[*shadowPubSub]map[msgKey][]secondaryPending), } // pendingMsg records a message awaiting its counterpart from the other source. @@ -47,9 +48,10 @@ type pendingMsg struct { // divergenceEvent holds divergence info collected under lock for deferred reporting. 
type divergenceEvent struct { - channel string - payload string - kind DivergenceKind + channel string + payload string + kind DivergenceKind + isPattern bool // true if this originated from a PSUBSCRIBE } // shadowPubSub subscribes to the secondary backend for the same channels @@ -152,6 +154,10 @@ func (sp *shadowPubSub) Close() { if started { <-sp.done } + // Clean up buffered unmatched secondaries to prevent memory leak. + unmatchedSecondaries.Lock() + delete(unmatchedSecondaries.data, sp) + unmatchedSecondaries.Unlock() } // compareLoop reads from the secondary channel and matches messages. @@ -201,15 +207,15 @@ func (sp *shadowPubSub) matchSecondary(msg *redis.Message) { perInstance, ok := unmatchedSecondaries.data[sp] if !ok { - perInstance = make(map[msgKey]secondaryPending) + perInstance = make(map[msgKey][]secondaryPending) unmatchedSecondaries.data[sp] = perInstance } - perInstance[key] = secondaryPending{ + perInstance[key] = append(perInstance[key], secondaryPending{ timestamp: now, channel: msg.Channel, payload: msg.Payload, - } + }) } // sweepExpired reports primary messages that were not matched within the window. @@ -219,32 +225,53 @@ func (sp *shadowPubSub) sweepExpired() { sp.mu.Lock() now := time.Now() - // Lock the unmatched secondary buffer while we reconcile primaries and - // secondaries and age out any expired entries. unmatchedSecondaries.Lock() + perInstance := sp.getOrCreateSecondaryBuffer() + + divergences = sp.reconcilePrimaries(now, perInstance, divergences) + divergences = sweepExpiredSecondaries(now, sp.window, perInstance, divergences) + + if len(perInstance) == 0 { + delete(unmatchedSecondaries.data, sp) + } + + unmatchedSecondaries.Unlock() + sp.mu.Unlock() + + for _, d := range divergences { + sp.reportDivergence(d) + } +} + +// getOrCreateSecondaryBuffer returns the per-instance unmatched secondary buffer. +// Caller must hold unmatchedSecondaries.Lock(). 
+func (sp *shadowPubSub) getOrCreateSecondaryBuffer() map[msgKey][]secondaryPending { perInstance, ok := unmatchedSecondaries.data[sp] if !ok { - perInstance = make(map[msgKey]secondaryPending) + perInstance = make(map[msgKey][]secondaryPending) unmatchedSecondaries.data[sp] = perInstance } + return perInstance +} - // First, reconcile pending primaries against buffered secondaries. If a - // primary has a matching buffered secondary within the window, treat them - // as matched and drop both without reporting a divergence. +// reconcilePrimaries matches pending primaries against buffered secondaries, +// reporting expired unmatched primaries as divergences. +// Caller must hold sp.mu and unmatchedSecondaries.Lock(). +func (sp *shadowPubSub) reconcilePrimaries(now time.Time, secBuf map[msgKey][]secondaryPending, out []divergenceEvent) []divergenceEvent { for key, entries := range sp.pending { var remaining []pendingMsg for _, e := range entries { - if sec, ok := perInstance[key]; ok && now.Sub(sec.timestamp) < sp.window { - // Matched with a buffered secondary; drop both. - delete(perInstance, key) + if secs := secBuf[key]; len(secs) > 0 { + // Matched — consume the oldest buffered secondary. + if len(secs) == 1 { + delete(secBuf, key) + } else { + secBuf[key] = secs[1:] + } continue } if now.Sub(e.timestamp) >= sp.window { - divergences = append(divergences, divergenceEvent{ - channel: e.channel, - payload: e.payload, - kind: DivDataMismatch, - }) + out = append(out, divergenceEvent{channel: e.channel, payload: e.payload, kind: DivDataMismatch, isPattern: e.pattern != ""}) } else { remaining = append(remaining, e) } @@ -255,31 +282,27 @@ func (sp *shadowPubSub) sweepExpired() { sp.pending[key] = remaining } } + return out +} - // Next, age out any remaining buffered secondaries that have exceeded the - // comparison window and report them as extra_data divergences. 
- for key, sec := range perInstance { - if now.Sub(sec.timestamp) >= sp.window { - divergences = append(divergences, divergenceEvent{ - channel: sec.channel, - payload: sec.payload, - kind: DivExtraData, - }) - delete(perInstance, key) +// sweepExpiredSecondaries ages out buffered secondaries past the window. +func sweepExpiredSecondaries(now time.Time, window time.Duration, secBuf map[msgKey][]secondaryPending, out []divergenceEvent) []divergenceEvent { + for key, secs := range secBuf { + var remaining []secondaryPending + for _, sec := range secs { + if now.Sub(sec.timestamp) >= window { + out = append(out, divergenceEvent{channel: sec.channel, payload: sec.payload, kind: DivExtraData}) + } else { + remaining = append(remaining, sec) + } + } + if len(remaining) == 0 { + delete(secBuf, key) + } else { + secBuf[key] = remaining } } - - // Clean up empty per-instance maps. - if len(perInstance) == 0 { - delete(unmatchedSecondaries.data, sp) - } - - unmatchedSecondaries.Unlock() - sp.mu.Unlock() - - for _, d := range divergences { - sp.reportDivergence(d.channel, d.payload, d.kind) - } + return out } // sweepAll reports all remaining pending messages as divergences (used on shutdown). 
@@ -290,9 +313,10 @@ func (sp *shadowPubSub) sweepAll() { for key, entries := range sp.pending { for _, e := range entries { divergences = append(divergences, divergenceEvent{ - channel: e.channel, - payload: e.payload, - kind: DivDataMismatch, + channel: e.channel, + payload: e.payload, + kind: DivDataMismatch, + isPattern: e.pattern != "", }) } delete(sp.pending, key) @@ -300,31 +324,36 @@ func (sp *shadowPubSub) sweepAll() { sp.mu.Unlock() for _, d := range divergences { - sp.reportDivergence(d.channel, d.payload, d.kind) + sp.reportDivergence(d) } } -func (sp *shadowPubSub) reportDivergence(channel, payload string, kind DivergenceKind) { - sp.metrics.PubSubShadowDivergences.WithLabelValues(kind.String()).Inc() +func (sp *shadowPubSub) reportDivergence(d divergenceEvent) { + sp.metrics.PubSubShadowDivergences.WithLabelValues(d.kind.String()).Inc() sp.logger.Warn("pubsub shadow divergence", - "channel", truncateValue(channel), - "payload", truncateValue(payload), - "kind", kind.String(), + "channel", truncateValue(d.channel), + "payload", truncateValue(d.payload), + "kind", d.kind.String(), ) + cmd := "SUBSCRIBE" + if d.isPattern { + cmd = "PSUBSCRIBE" + } + var primary, secondary any - switch kind { //nolint:exhaustive // only two kinds apply to pub/sub shadow + switch d.kind { //nolint:exhaustive // only two kinds apply to pub/sub shadow case DivExtraData: primary = nil - secondary = payload + secondary = d.payload default: - primary = payload + primary = d.payload secondary = nil } sp.sentry.CaptureDivergence(Divergence{ - Command: "SUBSCRIBE", - Key: channel, - Kind: kind, + Command: cmd, + Key: d.channel, + Kind: d.kind, Primary: primary, Secondary: secondary, DetectedAt: time.Now(), diff --git a/proxy/shadow_pubsub_test.go b/proxy/shadow_pubsub_test.go index badf21d8..ac6db9ea 100644 --- a/proxy/shadow_pubsub_test.go +++ b/proxy/shadow_pubsub_test.go @@ -144,6 +144,57 @@ func TestShadowPubSub_CompareLoopExitsOnChannelClose(t *testing.T) { sp.mu.Unlock() } 
+func TestShadowPubSub_DuplicateSecondaryBuffered(t *testing.T) { + sp := newTestShadowPubSub(10 * time.Millisecond) + defer func() { + // Clean up without calling Close (test has no real secondary connection). + unmatchedSecondaries.Lock() + delete(unmatchedSecondaries.data, sp) + unmatchedSecondaries.Unlock() + }() + + // Two identical secondary messages arrive before any primary. + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"}) + sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"}) + + unmatchedSecondaries.Lock() + key := msgKey{Channel: "ch1", Payload: "dup"} + secs := unmatchedSecondaries.data[sp][key] + assert.Len(t, secs, 2, "both duplicate secondaries should be buffered") + unmatchedSecondaries.Unlock() + + // Now one primary arrives — should consume one buffered secondary. + sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"}) + sp.sweepExpired() // reconcile + + unmatchedSecondaries.Lock() + secs = unmatchedSecondaries.data[sp][key] + unmatchedSecondaries.Unlock() + // One secondary remains buffered (only one primary consumed one). + assert.Len(t, secs, 1, "one duplicate should remain after matching one primary") +} + +func TestShadowPubSub_CloseCleanupUnmatchedSecondaries(t *testing.T) { + sp := newTestShadowPubSub(1 * time.Second) + // Set a mock secondary to prevent nil dereference in Close(). + sp.secondary = redis.NewClient(&redis.Options{Addr: "localhost:0"}).Subscribe(t.Context()) + + // Buffer a secondary message. 
+ sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "leaked"}) + + unmatchedSecondaries.Lock() + _, exists := unmatchedSecondaries.data[sp] + unmatchedSecondaries.Unlock() + assert.True(t, exists, "secondary should be buffered before Close") + + sp.Close() + + unmatchedSecondaries.Lock() + _, exists = unmatchedSecondaries.data[sp] + unmatchedSecondaries.Unlock() + assert.False(t, exists, "Close should clean up unmatchedSecondaries entry") +} + func TestShadowPubSub_CompareLoopMatchesFromChannel(t *testing.T) { sp := newTestShadowPubSub(1 * time.Second) From 6aabb0377b55078e0784db66407c13926400b1b4 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Thu, 19 Mar 2026 17:16:15 +0900 Subject: [PATCH 43/43] Fix bounded wait in cleanup/exitPubSubMode and sweepAll secondary drain - Replace unbounded <-s.fwdDone after dconn.Close with a second bounded wait via waitFwdDone helper, preventing indefinite hangs if forwardMessages is stuck on the upstream channel - Extend sweepAll to also drain buffered unmatched secondaries from unmatchedSecondaries, ensuring DivExtraData divergences are reported on shutdown instead of silently dropped --- proxy/pubsub.go | 28 +++++++++++++++++++--------- proxy/shadow_pubsub.go | 20 +++++++++++++++++++- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/proxy/pubsub.go b/proxy/pubsub.go index 2b049780..2ee628a8 100644 --- a/proxy/pubsub.go +++ b/proxy/pubsub.go @@ -86,13 +86,13 @@ func (s *pubsubSession) cleanup() { s.mu.Unlock() if s.fwdDone != nil { // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, - // close dconn to unblock it, then wait for completion. - select { - case <-s.fwdDone: - case <-time.After(cleanupFwdTimeout): + // close dconn to unblock it, then wait with a second bounded timeout. 
+ if !s.waitFwdDone() { s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") s.dconn.Close() - <-s.fwdDone + if !s.waitFwdDone() { + s.logger.Error("forwardMessages still stuck after dconn.Close, abandoning") + } s.closeShadow() return // dconn already closed } @@ -102,6 +102,16 @@ func (s *pubsubSession) cleanup() { s.dconn.Close() } +// waitFwdDone waits for fwdDone with a bounded timeout, returning true if it completed. +func (s *pubsubSession) waitFwdDone() bool { + select { + case <-s.fwdDone: + return true + case <-time.After(cleanupFwdTimeout): + return false + } +} + func (s *pubsubSession) startForwarding() { // Capture upstream under lock to avoid race with exitPubSubMode. s.mu.Lock() @@ -201,12 +211,12 @@ func (s *pubsubSession) exitPubSubMode() { } s.mu.Unlock() if s.fwdDone != nil { - select { - case <-s.fwdDone: - case <-time.After(cleanupFwdTimeout): + if !s.waitFwdDone() { s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") s.dconn.Close() - <-s.fwdDone + if !s.waitFwdDone() { + s.logger.Error("forwardMessages still stuck after dconn.Close, abandoning") + } } s.fwdDone = nil } diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go index cc5adadc..15adee7f 100644 --- a/proxy/shadow_pubsub.go +++ b/proxy/shadow_pubsub.go @@ -305,7 +305,8 @@ func sweepExpiredSecondaries(now time.Time, window time.Duration, secBuf map[msg return out } -// sweepAll reports all remaining pending messages as divergences (used on shutdown). +// sweepAll reports all remaining pending primaries and buffered secondaries +// as divergences (used on shutdown / channel close). func (sp *shadowPubSub) sweepAll() { var divergences []divergenceEvent @@ -323,6 +324,23 @@ func (sp *shadowPubSub) sweepAll() { } sp.mu.Unlock() + // Also drain any buffered unmatched secondaries for this instance. 
+ unmatchedSecondaries.Lock() + if perInstance, ok := unmatchedSecondaries.data[sp]; ok { + for key, secs := range perInstance { + for _, sec := range secs { + divergences = append(divergences, divergenceEvent{ + channel: sec.channel, + payload: sec.payload, + kind: DivExtraData, + }) + } + delete(perInstance, key) + } + delete(unmatchedSecondaries.data, sp) + } + unmatchedSecondaries.Unlock() + for _, d := range divergences { sp.reportDivergence(d) }