diff --git a/cmd/redis-proxy/main.go b/cmd/redis-proxy/main.go new file mode 100644 index 00000000..24f3722f --- /dev/null +++ b/cmd/redis-proxy/main.go @@ -0,0 +1,126 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log/slog" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/bootjp/elastickv/proxy" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + sentryFlushTimeout = 2 * time.Second + metricsShutdownTimeout = 5 * time.Second +) + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} + +func run() error { + cfg := proxy.DefaultConfig() + var modeStr string + + flag.StringVar(&cfg.ListenAddr, "listen", cfg.ListenAddr, "Proxy listen address") + flag.StringVar(&cfg.PrimaryAddr, "primary", cfg.PrimaryAddr, "Primary (Redis) address") + flag.IntVar(&cfg.PrimaryDB, "primary-db", cfg.PrimaryDB, "Primary Redis DB number") + flag.StringVar(&cfg.PrimaryPassword, "primary-password", cfg.PrimaryPassword, "Primary Redis password") + flag.StringVar(&cfg.SecondaryAddr, "secondary", cfg.SecondaryAddr, "Secondary (ElasticKV) address") + flag.IntVar(&cfg.SecondaryDB, "secondary-db", cfg.SecondaryDB, "Secondary Redis DB number") + flag.StringVar(&cfg.SecondaryPassword, "secondary-password", cfg.SecondaryPassword, "Secondary Redis password") + flag.StringVar(&modeStr, "mode", "dual-write", "Proxy mode: redis-only, dual-write, dual-write-shadow, elastickv-primary, elastickv-only") + flag.DurationVar(&cfg.SecondaryTimeout, "secondary-timeout", cfg.SecondaryTimeout, "Secondary write timeout") + flag.DurationVar(&cfg.ShadowTimeout, "shadow-timeout", cfg.ShadowTimeout, "Shadow read timeout") + flag.StringVar(&cfg.SentryDSN, "sentry-dsn", cfg.SentryDSN, "Sentry DSN (empty = disabled)") + flag.StringVar(&cfg.SentryEnv, "sentry-env", cfg.SentryEnv, "Sentry environment") + flag.Float64Var(&cfg.SentrySampleRate, 
"sentry-sample", cfg.SentrySampleRate, "Sentry sample rate") + flag.StringVar(&cfg.MetricsAddr, "metrics", cfg.MetricsAddr, "Prometheus metrics address") + flag.Parse() + + mode, ok := proxy.ParseProxyMode(modeStr) + if !ok { + return fmt.Errorf("unknown mode: %s", modeStr) + } + cfg.Mode = mode + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + + // Sentry + sentryReporter := proxy.NewSentryReporter(cfg.SentryDSN, cfg.SentryEnv, cfg.SentrySampleRate, logger) + defer sentryReporter.Flush(sentryFlushTimeout) + + // Prometheus + reg := prometheus.NewRegistry() + metrics := proxy.NewProxyMetrics(reg) + + // Backends + primaryOpts := proxy.DefaultBackendOptions() + primaryOpts.DB = cfg.PrimaryDB + primaryOpts.Password = cfg.PrimaryPassword + secondaryOpts := proxy.DefaultBackendOptions() + secondaryOpts.DB = cfg.SecondaryDB + secondaryOpts.Password = cfg.SecondaryPassword + + var primary, secondary proxy.Backend + switch cfg.Mode { + case proxy.ModeElasticKVPrimary, proxy.ModeElasticKVOnly: + primary = proxy.NewRedisBackendWithOptions(cfg.SecondaryAddr, "elastickv", secondaryOpts) + secondary = proxy.NewRedisBackendWithOptions(cfg.PrimaryAddr, "redis", primaryOpts) + case proxy.ModeRedisOnly, proxy.ModeDualWrite, proxy.ModeDualWriteShadow: + primary = proxy.NewRedisBackendWithOptions(cfg.PrimaryAddr, "redis", primaryOpts) + secondary = proxy.NewRedisBackendWithOptions(cfg.SecondaryAddr, "elastickv", secondaryOpts) + } + defer primary.Close() + defer secondary.Close() + + dual := proxy.NewDualWriter(primary, secondary, cfg, metrics, sentryReporter, logger) + defer dual.Close() // wait for in-flight async goroutines + srv := proxy.NewProxyServer(cfg, dual, metrics, sentryReporter, logger) + + // Context for graceful shutdown + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + // Start metrics server + go func() { + mux := http.NewServeMux() + 
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{})) + var lc net.ListenConfig + ln, err := lc.Listen(ctx, "tcp", cfg.MetricsAddr) + if err != nil { + logger.Error("metrics listen failed", "addr", cfg.MetricsAddr, "err", err) + return + } + metricsSrv := &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} + go func() { + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), metricsShutdownTimeout) + defer shutdownCancel() + if err := metricsSrv.Shutdown(shutdownCtx); err != nil { + logger.Warn("metrics server shutdown error", "err", err) + } + }() + logger.Info("metrics server starting", "addr", cfg.MetricsAddr) + if err := metricsSrv.Serve(ln); err != nil && err != http.ErrServerClosed { + logger.Error("metrics server error", "err", err) + } + }() + + // Start proxy + if err := srv.ListenAndServe(ctx); err != nil { + return fmt.Errorf("proxy server: %w", err) + } + return nil +} diff --git a/go.mod b/go.mod index 806fe0e4..4c6824c0 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/cockroachdb/errors v1.12.0 github.com/cockroachdb/pebble v1.1.5 github.com/emirpasic/gods v1.18.1 + github.com/getsentry/sentry-go v0.27.0 github.com/hashicorp/go-hclog v1.6.3 github.com/hashicorp/raft v1.7.3 github.com/pkg/errors v0.9.1 @@ -58,7 +59,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fatih/color v1.15.0 // indirect - github.com/getsentry/sentry-go v0.27.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.5-0.20231225225746-43d5d4cd4e0e // indirect diff --git a/proxy/backend.go b/proxy/backend.go new file mode 100644 index 00000000..9960cee1 --- /dev/null +++ b/proxy/backend.go @@ -0,0 +1,122 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + 
defaultPoolSize = 128 + defaultDialTimeout = 5 * time.Second + defaultReadTimeout = 3 * time.Second + defaultWriteTimeout = 3 * time.Second +) + +// Backend abstracts a Redis-protocol endpoint (real Redis or ElasticKV). +type Backend interface { + // Do sends a single command and returns its result. + Do(ctx context.Context, args ...any) *redis.Cmd + // Pipeline sends multiple commands in a pipeline. + Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) + // Close releases the underlying connection. + Close() error + // Name identifies this backend for logging and metrics. + Name() string +} + +// BackendOptions configures the underlying go-redis connection pool. +type BackendOptions struct { + DB int + Password string + PoolSize int + DialTimeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +// DefaultBackendOptions returns reasonable defaults for a proxy backend. +func DefaultBackendOptions() BackendOptions { + return BackendOptions{ + PoolSize: defaultPoolSize, + DialTimeout: defaultDialTimeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + } +} + +// PubSubBackend is an optional interface for backends that support +// creating dedicated PubSub connections. +type PubSubBackend interface { + NewPubSub(ctx context.Context) *redis.PubSub +} + +// RedisBackend connects to an upstream Redis instance via go-redis. +type RedisBackend struct { + client *redis.Client + name string +} + +// NewRedisBackend creates a Backend targeting a Redis server with default pool options. +func NewRedisBackend(addr string, name string) *RedisBackend { + return NewRedisBackendWithOptions(addr, name, DefaultBackendOptions()) +} + +// NewRedisBackendWithOptions creates a Backend with explicit pool configuration. 
+func NewRedisBackendWithOptions(addr string, name string, opts BackendOptions) *RedisBackend { + return &RedisBackend{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + DB: opts.DB, + Password: opts.Password, + PoolSize: opts.PoolSize, + DialTimeout: opts.DialTimeout, + ReadTimeout: opts.ReadTimeout, + WriteTimeout: opts.WriteTimeout, + }), + name: name, + } +} + +func (b *RedisBackend) Do(ctx context.Context, args ...any) *redis.Cmd { + return b.client.Do(ctx, args...) +} + +func (b *RedisBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) { + pipe := b.client.Pipeline() + results := make([]*redis.Cmd, len(cmds)) + for i, args := range cmds { + results[i] = pipe.Do(ctx, args...) + } + _, err := pipe.Exec(ctx) + if err != nil { + // go-redis pipelines return redis.Error for Redis reply errors (e.g., EXECABORT). + // Return results with nil error so callers can read per-command results (especially EXEC). + // Only propagate true transport/context errors. + var redisErr redis.Error + if errors.As(err, &redisErr) || errors.Is(err, redis.Nil) { + return results, nil + } + return results, fmt.Errorf("pipeline exec: %w", err) + } + return results, nil +} + +func (b *RedisBackend) Close() error { + if err := b.client.Close(); err != nil { + return fmt.Errorf("close %s backend: %w", b.name, err) + } + return nil +} + +func (b *RedisBackend) Name() string { + return b.name +} + +// NewPubSub creates a dedicated PubSub connection (not from the pool). +func (b *RedisBackend) NewPubSub(ctx context.Context) *redis.PubSub { + return b.client.Subscribe(ctx) +} diff --git a/proxy/command.go b/proxy/command.go new file mode 100644 index 00000000..3231af06 --- /dev/null +++ b/proxy/command.go @@ -0,0 +1,255 @@ +package proxy + +import "strings" + +// CommandCategory classifies a Redis command for routing purposes. +type CommandCategory int + +const ( + CmdRead CommandCategory = iota // GET, HGET, LRANGE, ZRANGE, etc. 
+ CmdWrite // SET, DEL, HSET, LPUSH, ZADD, etc. + CmdBlocking // BLPOP, BRPOP, BZPOPMIN, XREAD (with BLOCK) + CmdPubSub // SUBSCRIBE, UNSUBSCRIBE, PSUBSCRIBE, PUNSUBSCRIBE, PUBSUB (note: PUBLISH is CmdWrite) + CmdAdmin // PING, INFO, CLIENT, SELECT, QUIT, DBSIZE, SCAN, AUTH + CmdTxn // MULTI, EXEC, DISCARD + CmdScript // EVAL, EVALSHA +) + +var commandTable = map[string]CommandCategory{ + // ---- Read commands ---- + "GET": CmdRead, + "GETRANGE": CmdRead, + "MGET": CmdRead, + "STRLEN": CmdRead, + "HGET": CmdRead, + "HGETALL": CmdRead, + "HEXISTS": CmdRead, + "HLEN": CmdRead, + "HMGET": CmdRead, + "HKEYS": CmdRead, + "HVALS": CmdRead, + "HRANDFIELD": CmdRead, + "HSCAN": CmdRead, + "EXISTS": CmdRead, + "KEYS": CmdRead, + "RANDOMKEY": CmdRead, + "LINDEX": CmdRead, + "LLEN": CmdRead, + "LPOS": CmdRead, + "LRANGE": CmdRead, + "PTTL": CmdRead, + "TTL": CmdRead, + "TYPE": CmdRead, + "SCARD": CmdRead, + "SISMEMBER": CmdRead, + "SMISMEMBER": CmdRead, + "SMEMBERS": CmdRead, + "SRANDMEMBER": CmdRead, + "SDIFF": CmdRead, + "SINTER": CmdRead, + "SINTERCARD": CmdRead, + "SUNION": CmdRead, + "SSCAN": CmdRead, + "XLEN": CmdRead, + "XRANGE": CmdRead, + "XREVRANGE": CmdRead, + "XINFO": CmdRead, + "XPENDING": CmdRead, + "ZCARD": CmdRead, + "ZCOUNT": CmdRead, + "ZLEXCOUNT": CmdRead, + "ZMSCORE": CmdRead, + "ZRANGE": CmdRead, + "ZRANGEBYSCORE": CmdRead, + "ZRANGEBYLEX": CmdRead, + "ZREVRANGE": CmdRead, + "ZREVRANGEBYSCORE": CmdRead, + "ZREVRANGEBYLEX": CmdRead, + "ZRANK": CmdRead, + "ZREVRANK": CmdRead, + "ZSCORE": CmdRead, + "ZSCAN": CmdRead, + "ZDIFF": CmdRead, + "PFCOUNT": CmdRead, + "TOUCH": CmdRead, + "DUMP": CmdRead, + "GEODIST": CmdRead, + "GEOHASH": CmdRead, + "GEOPOS": CmdRead, + "GEOSEARCH": CmdRead, + "GEORADIUS_RO": CmdRead, + "GEORADIUSBYMEMBER_RO": CmdRead, + "OBJECT": CmdRead, + "SORT_RO": CmdRead, + "SUBSTR": CmdRead, + + // ---- Write commands ---- + "SET": CmdWrite, + "SETEX": CmdWrite, + "PSETEX": CmdWrite, + "SETNX": CmdWrite, + "SETRANGE": CmdWrite, + "MSET": 
CmdWrite, + "MSETNX": CmdWrite, + "APPEND": CmdWrite, + "GETSET": CmdWrite, + "GETEX": CmdWrite, + "GETDEL": CmdWrite, + "DEL": CmdWrite, + "UNLINK": CmdWrite, + "COPY": CmdWrite, + "RENAME": CmdWrite, + "RENAMENX": CmdWrite, + "RESTORE": CmdWrite, + "INCR": CmdWrite, + "INCRBY": CmdWrite, + "INCRBYFLOAT": CmdWrite, + "DECR": CmdWrite, + "DECRBY": CmdWrite, + "HSET": CmdWrite, + "HMSET": CmdWrite, + "HDEL": CmdWrite, + "HINCRBY": CmdWrite, + "HINCRBYFLOAT": CmdWrite, + "HSETNX": CmdWrite, + "LPUSH": CmdWrite, + "LPUSHX": CmdWrite, + "LPOP": CmdWrite, + "RPUSH": CmdWrite, + "RPUSHX": CmdWrite, + "RPOP": CmdWrite, + "RPOPLPUSH": CmdWrite, + "LMOVE": CmdWrite, + "LREM": CmdWrite, + "LSET": CmdWrite, + "LTRIM": CmdWrite, + "LINSERT": CmdWrite, + "SADD": CmdWrite, + "SREM": CmdWrite, + "SPOP": CmdWrite, + "SMOVE": CmdWrite, + "SDIFFSTORE": CmdWrite, + "SINTERSTORE": CmdWrite, + "SUNIONSTORE": CmdWrite, + "EXPIRE": CmdWrite, + "PEXPIRE": CmdWrite, + "EXPIREAT": CmdWrite, + "PEXPIREAT": CmdWrite, + "PERSIST": CmdWrite, + "SORT": CmdWrite, + "XADD": CmdWrite, + "XTRIM": CmdWrite, + "XACK": CmdWrite, + "XDEL": CmdWrite, + "XGROUP": CmdWrite, + "XCLAIM": CmdWrite, + "XAUTOCLAIM": CmdWrite, + "ZADD": CmdWrite, + "ZINCRBY": CmdWrite, + "ZREM": CmdWrite, + "ZREMRANGEBYSCORE": CmdWrite, + "ZREMRANGEBYRANK": CmdWrite, + "ZREMRANGEBYLEX": CmdWrite, + "ZPOPMIN": CmdWrite, + "ZPOPMAX": CmdWrite, + "ZRANGESTORE": CmdWrite, + "ZUNIONSTORE": CmdWrite, + "ZINTERSTORE": CmdWrite, + "ZDIFFSTORE": CmdWrite, + "PFADD": CmdWrite, + "PFMERGE": CmdWrite, + "GEOADD": CmdWrite, + "GEORADIUS": CmdWrite, + "GEORADIUSBYMEMBER": CmdWrite, + "GEOSEARCHSTORE": CmdWrite, + "FLUSHALL": CmdWrite, + "FLUSHDB": CmdWrite, + "PUBLISH": CmdWrite, // write to both backends + + // ---- Blocking commands ---- + "BLPOP": CmdBlocking, + "BRPOP": CmdBlocking, + "BRPOPLPUSH": CmdBlocking, + "BLMOVE": CmdBlocking, + "BLMPOP": CmdBlocking, + "BZPOPMIN": CmdBlocking, + "BZPOPMAX": CmdBlocking, + // XREAD/XREADGROUP 
handled specially in ClassifyCommand (BLOCK arg check) + + // ---- PubSub commands ---- + "SUBSCRIBE": CmdPubSub, + "UNSUBSCRIBE": CmdPubSub, + "PSUBSCRIBE": CmdPubSub, + "PUNSUBSCRIBE": CmdPubSub, + "SSUBSCRIBE": CmdPubSub, + "SUNSUBSCRIBE": CmdPubSub, + "PUBSUB": CmdPubSub, + + // ---- Admin commands — forwarded to primary only ---- + "PING": CmdAdmin, + "ECHO": CmdAdmin, + "INFO": CmdAdmin, + "CLIENT": CmdAdmin, + "SELECT": CmdAdmin, + "QUIT": CmdAdmin, + "RESET": CmdAdmin, + "DBSIZE": CmdAdmin, + "SCAN": CmdAdmin, + "AUTH": CmdAdmin, + "HELLO": CmdAdmin, + "WAIT": CmdAdmin, + "CONFIG": CmdAdmin, + "DEBUG": CmdAdmin, + "CLUSTER": CmdAdmin, + "COMMAND": CmdAdmin, + "TIME": CmdAdmin, + "SLOWLOG": CmdAdmin, + "MEMORY": CmdAdmin, + "LATENCY": CmdAdmin, + "MODULE": CmdAdmin, + "ACL": CmdAdmin, + "SWAPDB": CmdAdmin, + // WATCH/UNWATCH: connection-scoped optimistic locking. + // Forwarded to primary only since the proxy uses a shared connection pool + // and per-connection WATCH state cannot be maintained. + "WATCH": CmdAdmin, + "UNWATCH": CmdAdmin, + + // ---- Transaction commands ---- + "MULTI": CmdTxn, + "EXEC": CmdTxn, + "DISCARD": CmdTxn, + + // ---- Script commands ---- + "EVAL": CmdScript, + "EVALSHA": CmdScript, + "EVAL_RO": CmdScript, + "EVALSHA_RO": CmdScript, + "SCRIPT": CmdScript, + "FUNCTION": CmdScript, + "FCALL": CmdScript, + "FCALL_RO": CmdScript, +} + +// ClassifyCommand returns the category for a Redis command name. +// XREAD/XREADGROUP is classified as CmdBlocking if args contain BLOCK, otherwise CmdRead. +// Unknown commands default to CmdWrite (sent to both backends). 
+func ClassifyCommand(name string, args [][]byte) CommandCategory { + upper := strings.ToUpper(name) + + // Special case: XREAD/XREADGROUP with BLOCK + if upper == "XREAD" || upper == "XREADGROUP" { + for _, arg := range args { + if strings.ToUpper(string(arg)) == "BLOCK" { + return CmdBlocking + } + } + return CmdRead + } + + if cat, ok := commandTable[upper]; ok { + return cat + } + // Unknown commands → treat as write (safe default, sent to both backends) + return CmdWrite +} diff --git a/proxy/compare.go b/proxy/compare.go new file mode 100644 index 00000000..5d2b1a37 --- /dev/null +++ b/proxy/compare.go @@ -0,0 +1,215 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "reflect" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + unknownStr = "unknown" + defaultGapLogSampleRate int64 = 1000 +) + +// DivergenceKind classifies the nature of a shadow-read mismatch. +type DivergenceKind int + +const ( + DivMigrationGap DivergenceKind = iota // Secondary nil/empty, Primary has data → expected during migration + DivDataMismatch // Both have data but differ → real inconsistency + DivExtraData // Primary nil, Secondary has data → unexpected +) + +func (k DivergenceKind) String() string { + switch k { + case DivMigrationGap: + return "migration_gap" + case DivDataMismatch: + return "data_mismatch" + case DivExtraData: + return "extra_data" + default: + return unknownStr + } +} + +// Divergence records a detected mismatch between primary and secondary. +type Divergence struct { + Command string + Key string + Kind DivergenceKind + Primary any + Secondary any + DetectedAt time.Time +} + +// ShadowReader compares primary and secondary read results. +type ShadowReader struct { + secondary Backend + metrics *ProxyMetrics + sentry *SentryReporter + logger *slog.Logger + timeout time.Duration + + // Sampling counter for migration gap logs (log 1 per gapLogSampleRate). 
+ gapCount atomic.Int64 + gapLogSampleRate int64 +} + +// NewShadowReader creates a ShadowReader. +func NewShadowReader(secondary Backend, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger, timeout time.Duration) *ShadowReader { + return &ShadowReader{ + secondary: secondary, + metrics: metrics, + sentry: sentryReporter, + logger: logger, + timeout: timeout, + gapLogSampleRate: defaultGapLogSampleRate, + } +} + +// Compare issues the same read to the secondary and checks for divergence. +func (s *ShadowReader) Compare(ctx context.Context, cmd string, args [][]byte, primaryResp any, primaryErr error) { + sCtx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + iArgs := bytesArgsToInterfaces(args) + secondaryResult := s.secondary.Do(sCtx, iArgs...) + secondaryResp, secondaryErr := secondaryResult.Result() + + if isConsistent(primaryResp, secondaryResp, primaryErr, secondaryErr) { + return + } + + // --- Divergence detected --- + kind := classifyDivergence(primaryResp, primaryErr, secondaryResp, secondaryErr) + + if kind == DivMigrationGap { + s.metrics.MigrationGaps.WithLabelValues(cmd).Inc() + count := s.gapCount.Add(1) + if count%s.gapLogSampleRate == 1 { + s.logger.Debug("migration gap (sampled)", "cmd", cmd, "key", truncateValue(extractKey(args))) + } + return + } + + div := Divergence{ + Command: cmd, + Key: extractKey(args), + Kind: kind, + Primary: formatResp(primaryResp, primaryErr), + Secondary: formatResp(secondaryResp, secondaryErr), + DetectedAt: time.Now(), + } + + s.metrics.Divergences.WithLabelValues(cmd, kind.String()).Inc() + s.logger.Warn("response divergence detected", + "cmd", div.Command, "key", truncateValue(div.Key), "kind", div.Kind.String(), + "primary", truncateValue(div.Primary), + "secondary", truncateValue(div.Secondary), + ) + s.sentry.CaptureDivergence(div) +} + +// isConsistent checks whether primary and secondary responses agree. 
+func isConsistent(primaryResp, secondaryResp any, primaryErr, secondaryErr error) bool { + // Both are redis.Nil → consistent (key missing on both) + if isNilError(primaryErr) && isNilError(secondaryErr) { + return true + } + // Both returned the same non-nil error → consistent + if primaryErr != nil && secondaryErr != nil && primaryErr.Error() == secondaryErr.Error() { + return true + } + // Both succeeded with equal values → consistent + return primaryErr == nil && secondaryErr == nil && responseEqual(primaryResp, secondaryResp) +} + +// responseEqual compares two go-redis response values for equality. +func responseEqual(a, b any) bool { + if a == nil || b == nil { + return a == nil && b == nil + } + switch av := a.(type) { + case string: + bv, ok := b.(string) + return ok && av == bv + case int64: + bv, ok := b.(int64) + return ok && av == bv + case []any: + return interfaceSliceEqual(av, b) + default: + return reflect.DeepEqual(a, b) + } +} + +// interfaceSliceEqual compares two []interface{} slices element-by-element. +func interfaceSliceEqual(av []any, b any) bool { + bv, ok := b.([]any) + if !ok || len(av) != len(bv) { + return false + } + for i := range av { + if !responseEqual(av[i], bv[i]) { + return false + } + } + return true +} + +// classifyDivergence determines the kind based on primary/secondary values. +func classifyDivergence(primaryResp any, primaryErr error, secondaryResp any, secondaryErr error) DivergenceKind { + primaryNil := isNilResp(primaryResp, primaryErr) + secondaryNil := isNilResp(secondaryResp, secondaryErr) + + switch { + case !primaryNil && secondaryNil: + return DivMigrationGap + case primaryNil && !secondaryNil: + return DivExtraData + default: + return DivDataMismatch + } +} + +func isNilError(err error) bool { + return errors.Is(err, redis.Nil) +} + +// isNilResp checks if a response represents "no data" (nil response or redis.Nil error). +// Empty string is NOT nil — it is a valid value. 
+func isNilResp(resp any, err error) bool { + if errors.Is(err, redis.Nil) { + return true + } + return resp == nil +} + +func formatResp(resp any, err error) any { + if err != nil { + return fmt.Sprintf("error: %v", err) + } + return resp +} + +func extractKey(args [][]byte) string { + if len(args) > 1 { + return string(args[1]) + } + return "" +} + +func bytesArgsToInterfaces(args [][]byte) []any { + out := make([]any, len(args)) + for i, a := range args { + out[i] = a + } + return out +} diff --git a/proxy/config.go b/proxy/config.go new file mode 100644 index 00000000..44ba290d --- /dev/null +++ b/proxy/config.go @@ -0,0 +1,84 @@ +package proxy + +import "time" + +const ( + defaultSecondaryTimeout = 5 * time.Second + defaultShadowTimeout = 3 * time.Second + defaultPubSubCompareWindow = 2 * time.Second + defaultPubSubSweepInterval = 500 * time.Millisecond +) + +// ProxyMode controls which backends receive reads and writes. +type ProxyMode int + +const ( + ModeRedisOnly ProxyMode = iota // Redis only (passthrough) + ModeDualWrite // Write to both, read from Redis + ModeDualWriteShadow // Write to both, read from Redis + shadow read from ElasticKV + ModeElasticKVPrimary // Write to both, read from ElasticKV + shadow read from Redis + ModeElasticKVOnly // ElasticKV only +) + +var modeNames = map[string]ProxyMode{ + "redis-only": ModeRedisOnly, + "dual-write": ModeDualWrite, + "dual-write-shadow": ModeDualWriteShadow, + "elastickv-primary": ModeElasticKVPrimary, + "elastickv-only": ModeElasticKVOnly, +} + +var modeStrings = map[ProxyMode]string{ + ModeRedisOnly: "redis-only", + ModeDualWrite: "dual-write", + ModeDualWriteShadow: "dual-write-shadow", + ModeElasticKVPrimary: "elastickv-primary", + ModeElasticKVOnly: "elastickv-only", +} + +// ParseProxyMode converts a string to ProxyMode. Returns false if unknown. 
+func ParseProxyMode(s string) (ProxyMode, bool) { + m, ok := modeNames[s] + return m, ok +} + +func (m ProxyMode) String() string { + if s, ok := modeStrings[m]; ok { + return s + } + return unknownStr +} + +// ProxyConfig holds all configuration for the dual-write proxy. +type ProxyConfig struct { + ListenAddr string + PrimaryAddr string + PrimaryDB int + PrimaryPassword string + SecondaryAddr string + SecondaryDB int + SecondaryPassword string + Mode ProxyMode + SecondaryTimeout time.Duration + ShadowTimeout time.Duration + SentryDSN string + SentryEnv string + SentrySampleRate float64 + MetricsAddr string + PubSubCompareWindow time.Duration +} + +// DefaultConfig returns a ProxyConfig with sensible defaults. +func DefaultConfig() ProxyConfig { + return ProxyConfig{ + ListenAddr: ":6479", + PrimaryAddr: "localhost:6379", + SecondaryAddr: "localhost:6380", + Mode: ModeDualWrite, + SecondaryTimeout: defaultSecondaryTimeout, + ShadowTimeout: defaultShadowTimeout, + SentrySampleRate: 1.0, + MetricsAddr: ":9191", + PubSubCompareWindow: defaultPubSubCompareWindow, + } +} diff --git a/proxy/dualwrite.go b/proxy/dualwrite.go new file mode 100644 index 00000000..81fa919a --- /dev/null +++ b/proxy/dualwrite.go @@ -0,0 +1,308 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + // maxWriteGoroutines limits concurrent secondary write goroutines. + maxWriteGoroutines = 4096 + // maxShadowGoroutines limits concurrent shadow read goroutines separately + // so that secondary write failures cannot starve shadow reads. + maxShadowGoroutines = 1024 +) + +// DualWriter routes commands to primary and secondary backends based on mode. 
+type DualWriter struct {
+	primary   Backend
+	secondary Backend
+	cfg       ProxyConfig
+	shadow    *ShadowReader
+	metrics   *ProxyMetrics
+	sentry    *SentryReporter
+	logger    *slog.Logger
+
+	writeSem  chan struct{} // bounds concurrent secondary write goroutines
+	shadowSem chan struct{} // bounds concurrent shadow read goroutines
+	wg        sync.WaitGroup
+
+	// closing is set when Close has begun. The atomic flag is a cheap
+	// fast-path check; mu serializes the closing-check + wg.Add pair in
+	// goAsyncWithSem against Close's store + wg.Wait, which is required by
+	// sync.WaitGroup (an Add from zero must happen-before Wait).
+	mu      sync.Mutex
+	closing int32
+}
+
+// NewDualWriter creates a DualWriter with the given backends.
+func NewDualWriter(primary, secondary Backend, cfg ProxyConfig, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger) *DualWriter {
+	d := &DualWriter{
+		primary:   primary,
+		secondary: secondary,
+		cfg:       cfg,
+		metrics:   metrics,
+		sentry:    sentryReporter,
+		logger:    logger,
+		writeSem:  make(chan struct{}, maxWriteGoroutines),
+		shadowSem: make(chan struct{}, maxShadowGoroutines),
+	}
+
+	if cfg.Mode == ModeDualWriteShadow || cfg.Mode == ModeElasticKVPrimary {
+		// Shadow reads go to the non-primary backend for comparison.
+		// Note: main.go already swaps primary/secondary for ElasticKVPrimary mode,
+		// so here "secondary" is always the non-primary backend.
+		shadowBackend := secondary
+		d.shadow = NewShadowReader(shadowBackend, metrics, sentryReporter, logger, cfg.ShadowTimeout)
+	}
+
+	return d
+}
+
+// Close waits for all in-flight async goroutines to finish.
+// Should be called during graceful shutdown. After Close begins, no new
+// async work is scheduled; the mutex guarantees no wg.Add races with the
+// wg.Wait below.
+func (d *DualWriter) Close() {
+	d.mu.Lock()
+	atomic.StoreInt32(&d.closing, 1)
+	d.mu.Unlock()
+	d.wg.Wait()
+}
+
+// Write sends a write command to the primary synchronously, then to the secondary asynchronously.
+// cmd must be the pre-uppercased command name.
+func (d *DualWriter) Write(ctx context.Context, cmd string, args [][]byte) (any, error) {
+	iArgs := bytesArgsToInterfaces(args)
+
+	start := time.Now()
+	result := d.primary.Do(ctx, iArgs...)
+	resp, err := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		d.metrics.PrimaryWriteErrors.Inc()
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
+		d.logger.Error("primary write failed", "cmd", cmd, "err", err)
+		return nil, fmt.Errorf("primary write %s: %w", cmd, err)
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc()
+
+	// Secondary: async fire-and-forget (bounded)
+	if d.hasSecondaryWrite() {
+		d.goWrite(func() { d.writeSecondary(cmd, iArgs) })
+	}
+
+	return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies
+}
+
+// Read sends a read command to the primary and optionally performs a shadow read.
+// cmd must be the pre-uppercased command name.
+func (d *DualWriter) Read(ctx context.Context, cmd string, args [][]byte) (any, error) {
+	iArgs := bytesArgsToInterfaces(args)
+
+	start := time.Now()
+	result := d.primary.Do(ctx, iArgs...)
+	resp, err := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		d.metrics.PrimaryReadErrors.Inc()
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
+		return nil, fmt.Errorf("primary read %s: %w", cmd, err)
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc()
+
+	// Shadow read (bounded, separate semaphore from writes)
+	if d.shadow != nil {
+		shadowArgs := args
+		shadowResp := resp
+		shadowErr := err
+		d.goShadow(func() {
+			d.shadow.Compare(ctx, cmd, shadowArgs, shadowResp, shadowErr)
+		})
+	}
+
+	return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies
+}
+
+// Blocking forwards a blocking command to the primary only.
+// Optionally sends a short-timeout version to secondary for warmup.
+// cmd must be the pre-uppercased command name.
+func (d *DualWriter) Blocking(ctx context.Context, cmd string, args [][]byte) (any, error) {
+	iArgs := bytesArgsToInterfaces(args)
+
+	start := time.Now()
+	result := d.primary.Do(ctx, iArgs...)
+	resp, err := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
+		return nil, fmt.Errorf("primary blocking %s: %w", cmd, err)
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc()
+
+	// Warmup: send to secondary with short timeout (fire-and-forget, bounded)
+	if d.hasSecondaryWrite() {
+		d.goWrite(func() {
+			sCtx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			d.secondary.Do(sCtx, iArgs...)
+		})
+	}
+
+	return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies
+}
+
+// Admin forwards an admin command to the primary only.
+// cmd must be the pre-uppercased command name.
+func (d *DualWriter) Admin(ctx context.Context, cmd string, args [][]byte) (any, error) {
+	iArgs := bytesArgsToInterfaces(args)
+
+	start := time.Now()
+	result := d.primary.Do(ctx, iArgs...)
+	resp, err := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
+		return nil, fmt.Errorf("primary admin %s: %w", cmd, err)
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc()
+	return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies
+}
+
+// Script forwards EVAL/EVALSHA to the primary, and async replays to secondary.
+// cmd must be the pre-uppercased command name.
+func (d *DualWriter) Script(ctx context.Context, cmd string, args [][]byte) (any, error) {
+	iArgs := bytesArgsToInterfaces(args)
+
+	start := time.Now()
+	result := d.primary.Do(ctx, iArgs...)
+	resp, err := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.primary.Name()).Observe(time.Since(start).Seconds())
+
+	if err != nil && !errors.Is(err, redis.Nil) {
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "error").Inc()
+		return nil, fmt.Errorf("primary script %s: %w", cmd, err)
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.primary.Name(), "ok").Inc()
+
+	if d.hasSecondaryWrite() {
+		d.goWrite(func() { d.writeSecondary(cmd, iArgs) })
+	}
+
+	return resp, err //nolint:wrapcheck // redis.Nil must pass through unwrapped for callers to detect nil replies
+}
+
+// writeSecondary replays a command to the secondary with its own
+// background-derived timeout context (independent of the request lifetime).
+func (d *DualWriter) writeSecondary(cmd string, iArgs []any) {
+	sCtx, cancel := context.WithTimeout(context.Background(), d.cfg.SecondaryTimeout)
+	defer cancel()
+
+	start := time.Now()
+	result := d.secondary.Do(sCtx, iArgs...)
+	_, sErr := result.Result()
+	d.metrics.CommandDuration.WithLabelValues(cmd, d.secondary.Name()).Observe(time.Since(start).Seconds())
+
+	if sErr != nil && !errors.Is(sErr, redis.Nil) {
+		d.metrics.SecondaryWriteErrors.Inc()
+		d.metrics.CommandTotal.WithLabelValues(cmd, d.secondary.Name(), "error").Inc()
+		fingerprint := fmt.Sprintf("secondary_write_%s", cmd)
+		if d.sentry.ShouldReport(fingerprint) {
+			d.sentry.CaptureException(sErr, "secondary_write_failure", argsToBytes(iArgs))
+		}
+		d.logger.Warn("secondary write failed", "cmd", cmd, "err", sErr)
+		return
+	}
+	d.metrics.CommandTotal.WithLabelValues(cmd, d.secondary.Name(), "ok").Inc()
+}
+
+// goWrite launches fn in a bounded write goroutine.
+func (d *DualWriter) goWrite(fn func()) {
+	d.goAsyncWithSem(d.writeSem, fn)
+}
+
+// goShadow launches fn in a bounded shadow-read goroutine.
+func (d *DualWriter) goShadow(fn func()) {
+	d.goAsyncWithSem(d.shadowSem, fn)
+}
+
+// goAsync launches fn using the write semaphore (for backward compat with txn replay).
+func (d *DualWriter) goAsync(fn func()) {
+	d.goWrite(fn)
+}
+
+// goAsyncWithSem launches fn in a bounded goroutine using the given semaphore.
+// If the DualWriter is closing or the semaphore is full, the work is dropped.
+// The wg.Add is performed under mu while re-checking the closing flag, so a
+// concurrent Close cannot observe a zero WaitGroup counter and return while
+// new work is still being scheduled.
+func (d *DualWriter) goAsyncWithSem(sem chan struct{}, fn func()) {
+	// Fast path: cheap atomic check before touching the semaphore.
+	if atomic.LoadInt32(&d.closing) != 0 {
+		return
+	}
+	select {
+	case sem <- struct{}{}:
+		d.mu.Lock()
+		if atomic.LoadInt32(&d.closing) != 0 {
+			// Close began after the fast-path check; drop the work.
+			d.mu.Unlock()
+			<-sem
+			return
+		}
+		d.wg.Add(1)
+		d.mu.Unlock()
+		go func() {
+			defer func() {
+				<-sem
+				d.wg.Done()
+			}()
+			fn()
+		}()
+	default:
+		d.metrics.AsyncDrops.Inc()
+		d.logger.Warn("async goroutine limit reached, dropping secondary operation")
+	}
+}
+
+// hasSecondaryWrite reports whether the current mode replicates writes to the
+// secondary backend.
+func (d *DualWriter) hasSecondaryWrite() bool {
+	switch d.cfg.Mode {
+	case ModeDualWrite, ModeDualWriteShadow, ModeElasticKVPrimary:
+		return true
+	case ModeRedisOnly, ModeElasticKVOnly:
+		return false
+	}
+	return false
+}
+
+// Primary returns the primary backend for direct use (e.g., PubSub).
+func (d *DualWriter) Primary() Backend { + return d.primary +} + +// PubSubBackend returns the primary backend as a PubSubBackend, or nil. +func (d *DualWriter) PubSubBackend() PubSubBackend { + if ps, ok := d.primary.(PubSubBackend); ok { + return ps + } + return nil +} + +// ShadowPubSubBackend returns the secondary backend as a PubSubBackend +// when shadow mode is active, or nil otherwise. +func (d *DualWriter) ShadowPubSubBackend() PubSubBackend { + if d.shadow == nil { + return nil + } + if ps, ok := d.secondary.(PubSubBackend); ok { + return ps + } + return nil +} + +// Secondary returns the secondary backend. +func (d *DualWriter) Secondary() Backend { + return d.secondary +} + +func argsToBytes(iArgs []any) [][]byte { + out := make([][]byte, len(iArgs)) + for i, a := range iArgs { + if b, ok := a.([]byte); ok { + out[i] = b + } else { + out[i] = fmt.Appendf(nil, "%v", a) + } + } + return out +} diff --git a/proxy/integration_test.go b/proxy/integration_test.go new file mode 100644 index 00000000..37ab5a5f --- /dev/null +++ b/proxy/integration_test.go @@ -0,0 +1,259 @@ +package proxy_test + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "os" + "testing" + "time" + + "github.com/bootjp/elastickv/proxy" + "github.com/prometheus/client_golang/prometheus" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testPrimaryAddr = "localhost:6379" + +// skipWithoutRedis skips the test if the required Redis instances are not available. 
+func skipWithoutRedis(t *testing.T, addrs ...string) { + t.Helper() + for _, addr := range addrs { + d := net.Dialer{Timeout: 500 * time.Millisecond} + conn, err := d.DialContext(context.Background(), "tcp", addr) + if err != nil { + t.Skipf("skipping integration test: cannot connect to %s: %v", addr, err) + } + conn.Close() + } +} + +func freePort(t *testing.T) string { + t.Helper() + var lc net.ListenConfig + ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := ln.Addr().String() + ln.Close() + return addr +} + +func setupProxy(t *testing.T, mode proxy.ProxyMode, secondaryAddr string) (*redis.Client, context.CancelFunc) { + t.Helper() + listenAddr := freePort(t) + metricsAddr := freePort(t) + + cfg := proxy.ProxyConfig{ + ListenAddr: listenAddr, + PrimaryAddr: testPrimaryAddr, + SecondaryAddr: secondaryAddr, + Mode: mode, + SecondaryTimeout: 5 * time.Second, + ShadowTimeout: 3 * time.Second, + MetricsAddr: metricsAddr, + } + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + reg := prometheus.NewRegistry() + metrics := proxy.NewProxyMetrics(reg) + sentryReporter := proxy.NewSentryReporter("", "", 1.0, logger) + + var primary, secondary proxy.Backend + switch cfg.Mode { + case proxy.ModeElasticKVPrimary, proxy.ModeElasticKVOnly: + primary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + secondary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + case proxy.ModeRedisOnly, proxy.ModeDualWrite, proxy.ModeDualWriteShadow: + primary = proxy.NewRedisBackend(cfg.PrimaryAddr, "redis") + secondary = proxy.NewRedisBackend(cfg.SecondaryAddr, "elastickv") + } + + dual := proxy.NewDualWriter(primary, secondary, cfg, metrics, sentryReporter, logger) + srv := proxy.NewProxyServer(cfg, dual, metrics, sentryReporter, logger) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + _ = srv.ListenAndServe(ctx) + }() + + // Wait for proxy to be ready + var 
client *redis.Client + for i := range 50 { + _ = i + d := net.Dialer{Timeout: 100 * time.Millisecond} + conn, err := d.DialContext(ctx, "tcp", listenAddr) + if err == nil { + conn.Close() + client = redis.NewClient(&redis.Options{Addr: listenAddr}) + break + } + time.Sleep(50 * time.Millisecond) + } + require.NotNil(t, client, "proxy did not become ready") + + t.Cleanup(func() { + client.Close() + primary.Close() + secondary.Close() + }) + + return client, cancel +} + +func TestIntegration_RedisOnly(t *testing.T) { + skipWithoutRedis(t, testPrimaryAddr) + + client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399") // secondary unused + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-test-%d", time.Now().UnixNano()) + + // PING + pong, err := client.Ping(ctx).Result() + require.NoError(t, err) + assert.Equal(t, "PONG", pong) + + // SET / GET + err = client.Set(ctx, key, "hello", 0).Err() + require.NoError(t, err) + + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "hello", val) + + // DEL + err = client.Del(ctx, key).Err() + require.NoError(t, err) + + _, err = client.Get(ctx, key).Result() + assert.Equal(t, redis.Nil, err) +} + +func TestIntegration_DualWrite(t *testing.T) { + secondaryAddr := "localhost:6380" + skipWithoutRedis(t, testPrimaryAddr, secondaryAddr) + + client, cancel := setupProxy(t, proxy.ModeDualWrite, secondaryAddr) + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-dual-%d", time.Now().UnixNano()) + + // Write through proxy + err := client.Set(ctx, key, "dual-value", 0).Err() + require.NoError(t, err) + + // Read through proxy (from primary) + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "dual-value", val) + + // Poll secondary until async write propagates (avoids flaky fixed sleeps). 
+ secondaryClient := redis.NewClient(&redis.Options{Addr: secondaryAddr}) + defer secondaryClient.Close() + + require.Eventually(t, func() bool { + sVal, sErr := secondaryClient.Get(ctx, key).Result() + return sErr == nil && sVal == "dual-value" + }, 5*time.Second, 50*time.Millisecond, "secondary should have the value") + + // Clean up — poll until key is gone on secondary. + client.Del(ctx, key) + require.Eventually(t, func() bool { + _, sErr := secondaryClient.Get(ctx, key).Result() + return errors.Is(sErr, redis.Nil) + }, 5*time.Second, 50*time.Millisecond, "secondary should have key deleted") +} + +func TestIntegration_HashCommands(t *testing.T) { + skipWithoutRedis(t, testPrimaryAddr) + + client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399") + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-hash-%d", time.Now().UnixNano()) + + // HSET + err := client.HSet(ctx, key, "field1", "val1", "field2", "val2").Err() + require.NoError(t, err) + + // HGET + val, err := client.HGet(ctx, key, "field1").Result() + require.NoError(t, err) + assert.Equal(t, "val1", val) + + // HGETALL + all, err := client.HGetAll(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "val1", all["field1"]) + assert.Equal(t, "val2", all["field2"]) + + // HDEL + err = client.HDel(ctx, key, "field1").Err() + require.NoError(t, err) + + _, err = client.HGet(ctx, key, "field1").Result() + assert.Equal(t, redis.Nil, err) + + // Clean up + client.Del(ctx, key) +} + +func TestIntegration_ListCommands(t *testing.T) { + skipWithoutRedis(t, testPrimaryAddr) + + client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399") + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-list-%d", time.Now().UnixNano()) + + // LPUSH + err := client.LPush(ctx, key, "a", "b", "c").Err() + require.NoError(t, err) + + // LLEN + llen, err := client.LLen(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, int64(3), llen) + + // 
LRANGE + vals, err := client.LRange(ctx, key, 0, -1).Result() + require.NoError(t, err) + assert.Equal(t, []string{"c", "b", "a"}, vals) + + // Clean up + client.Del(ctx, key) +} + +func TestIntegration_Transaction(t *testing.T) { + skipWithoutRedis(t, testPrimaryAddr) + + client, cancel := setupProxy(t, proxy.ModeRedisOnly, "localhost:16399") + defer cancel() + + ctx := context.Background() + key := fmt.Sprintf("proxy-txn-%d", time.Now().UnixNano()) + + // MULTI/EXEC via pipeline (go-redis TxPipelined) + _, err := client.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, "txn-val", 0) + pipe.Get(ctx, key) + return nil + }) + require.NoError(t, err) + + // Verify + val, err := client.Get(ctx, key).Result() + require.NoError(t, err) + assert.Equal(t, "txn-val", val) + + // Clean up + client.Del(ctx, key) +} diff --git a/proxy/metrics.go b/proxy/metrics.go new file mode 100644 index 00000000..f2b4d406 --- /dev/null +++ b/proxy/metrics.go @@ -0,0 +1,112 @@ +package proxy + +import "github.com/prometheus/client_golang/prometheus" + +// ProxyMetrics holds all Prometheus metrics for the dual-write proxy. +type ProxyMetrics struct { + CommandTotal *prometheus.CounterVec + CommandDuration *prometheus.HistogramVec + + PrimaryWriteErrors prometheus.Counter + SecondaryWriteErrors prometheus.Counter + PrimaryReadErrors prometheus.Counter + ShadowReadErrors prometheus.Counter + Divergences *prometheus.CounterVec + MigrationGaps *prometheus.CounterVec + + ActiveConnections prometheus.Gauge + + AsyncDrops prometheus.Counter + + PubSubShadowDivergences *prometheus.CounterVec + PubSubShadowErrors prometheus.Counter +} + +// NewProxyMetrics creates and registers all proxy metrics. 
func NewProxyMetrics(reg prometheus.Registerer) *ProxyMetrics {
	m := &ProxyMetrics{
		CommandTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "command_total",
			Help:      "Total commands processed by the proxy.",
		}, []string{"command", "backend", "status"}),

		CommandDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
			Namespace: "proxy",
			Name:      "command_duration_seconds",
			Help:      "Latency of commands forwarded to backends.",
			Buckets:   prometheus.DefBuckets,
		}, []string{"command", "backend"}),

		PrimaryWriteErrors: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "primary_write_errors_total",
			Help:      "Total write errors from the primary backend.",
		}),
		SecondaryWriteErrors: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "secondary_write_errors_total",
			Help:      "Total write errors from the secondary backend.",
		}),
		PrimaryReadErrors: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "primary_read_errors_total",
			Help:      "Total read errors from the primary backend.",
		}),
		ShadowReadErrors: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "shadow_read_errors_total",
			Help:      "Total errors from shadow reads.",
		}),
		Divergences: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "divergences_total",
			Help:      "Total data mismatches detected by shadow reads.",
		}, []string{"command", "kind"}),
		MigrationGaps: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "migration_gap_total",
			Help:      "Expected divergences due to missing data on secondary (pre-migration).",
		}, []string{"command"}),

		AsyncDrops: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "async_drops_total",
			Help:      "Total async operations dropped due to semaphore backpressure.",
		}),

		ActiveConnections: prometheus.NewGauge(prometheus.GaugeOpts{
			Namespace: "proxy",
			Name:      "active_connections",
			Help:      "Current number of active client connections.",
		}),

		PubSubShadowDivergences: prometheus.NewCounterVec(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "pubsub_shadow_divergences_total",
			Help:      "Total pub/sub message mismatches detected by shadow subscribe.",
		}, []string{"kind"}),
		PubSubShadowErrors: prometheus.NewCounter(prometheus.CounterOpts{
			Namespace: "proxy",
			Name:      "pubsub_shadow_errors_total",
			Help:      "Total errors from shadow pub/sub operations.",
		}),
	}

	// MustRegister panics on duplicate registration; callers pass a fresh
	// Registry per proxy instance (see cmd/redis-proxy and the tests).
	reg.MustRegister(
		m.CommandTotal,
		m.CommandDuration,
		m.PrimaryWriteErrors,
		m.SecondaryWriteErrors,
		m.PrimaryReadErrors,
		m.ShadowReadErrors,
		m.Divergences,
		m.MigrationGaps,
		m.AsyncDrops,
		m.ActiveConnections,
		m.PubSubShadowDivergences,
		m.PubSubShadowErrors,
	)

	return m
}
diff --git a/proxy/proxy.go b/proxy/proxy.go
new file mode 100644
index 00000000..df2cd70f
--- /dev/null
+++ b/proxy/proxy.go
@@ -0,0 +1,518 @@
package proxy

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"strings"

	"github.com/redis/go-redis/v9"
	"github.com/tidwall/redcon"
)

// txnCommandsOverhead is the number of extra commands (MULTI + EXEC) wrapping queued commands.
const txnCommandsOverhead = 2

// respWriter is the subset of redcon.Conn and redcon.DetachedConn used for writing RESP responses.
// Both connection types satisfy this interface, enabling shared response-writing logic
// between the main event loop and detached pub/sub sessions.
type respWriter interface {
	WriteError(msg string)
	WriteString(msg string)
	WriteBulk(b []byte)
	WriteBulkString(msg string)
	WriteInt64(num int64)
	WriteArray(count int)
	WriteNull()
}

// proxyConnState tracks per-connection state (transactions, PubSub).
type proxyConnState struct {
	inTxn    bool       // true between MULTI and EXEC/DISCARD
	txnQueue [][][]byte // buffered commands between MULTI and EXEC
}

// ProxyServer is a Redis-protocol proxy that dual-writes to two backends.
type ProxyServer struct {
	cfg     ProxyConfig
	dual    *DualWriter
	metrics *ProxyMetrics
	sentry  *SentryReporter
	logger  *slog.Logger

	// shutdownCtx is cancelled on graceful shutdown, used for blocking commands.
	// NOTE(review): storing a Context in a struct is normally discouraged; here
	// it is set once in ListenAndServe before any handler runs.
	shutdownCtx context.Context
}

// NewProxyServer creates a proxy server with the given configuration and backends.
func NewProxyServer(cfg ProxyConfig, dual *DualWriter, metrics *ProxyMetrics, sentryReporter *SentryReporter, logger *slog.Logger) *ProxyServer {
	return &ProxyServer{
		cfg:     cfg,
		dual:    dual,
		metrics: metrics,
		sentry:  sentryReporter,
		logger:  logger,
	}
}

// ListenAndServe starts the redcon proxy server. It blocks until the listener
// fails or ctx is cancelled (graceful shutdown).
func (p *ProxyServer) ListenAndServe(ctx context.Context) error {
	p.shutdownCtx = ctx

	var lc net.ListenConfig
	ln, err := lc.Listen(ctx, "tcp", p.cfg.ListenAddr)
	if err != nil {
		return fmt.Errorf("proxy listen: %w", err)
	}

	srv := redcon.NewServer(p.cfg.ListenAddr,
		p.handleCommand,
		p.handleAccept,
		p.handleClose,
	)

	// Graceful shutdown on context cancel.
	go func() {
		<-ctx.Done()
		p.logger.Info("shutting down proxy server")
		srv.Close()
	}()

	p.logger.Info("proxy server starting",
		"addr", p.cfg.ListenAddr,
		"mode", p.cfg.Mode.String(),
		"primary", p.cfg.PrimaryAddr,
		"secondary", p.cfg.SecondaryAddr,
	)

	if err = srv.Serve(ln); err != nil {
		// During graceful shutdown, srv.Close() causes Serve to return
		// a listener-closed error. Treat this as a normal exit.
		if ctx.Err() != nil {
			return nil //nolint:nilerr // intentional: suppress expected listener-closed error during graceful shutdown
		}
		return fmt.Errorf("proxy serve: %w", err)
	}
	return nil
}

// handleAccept initializes per-connection state and always accepts.
func (p *ProxyServer) handleAccept(conn redcon.Conn) bool {
	p.metrics.ActiveConnections.Inc()
	conn.SetContext(&proxyConnState{})
	return true
}

func (p *ProxyServer) handleClose(conn redcon.Conn, _ error) {
	p.metrics.ActiveConnections.Dec()
}

// getConnState returns the connection's state, lazily creating it if redcon
// handed us a connection without one.
func getConnState(conn redcon.Conn) *proxyConnState {
	if ctx := conn.Context(); ctx != nil {
		if st, ok := ctx.(*proxyConnState); ok {
			return st
		}
	}
	st := &proxyConnState{}
	conn.SetContext(st)
	return st
}

func (p *ProxyServer) handleCommand(conn redcon.Conn, cmd redcon.Command) {
	if len(cmd.Args) == 0 {
		conn.WriteError("ERR empty command")
		return
	}

	name := strings.ToUpper(string(cmd.Args[0]))
	// Args are cloned because redcon reuses its read buffer between commands.
	args := cloneArgs(cmd.Args)
	state := getConnState(conn)

	// Transaction handling
	if state.inTxn {
		p.handleQueuedCommand(conn, state, name, args)
		return
	}

	p.dispatchCommand(conn, state, name, args)
}

// dispatchCommand routes one command to its category handler.
func (p *ProxyServer) dispatchCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) {
	cat := ClassifyCommand(name, args[1:])

	switch cat {
	case CmdTxn:
		p.handleTxnCommand(conn, state, name)
	case CmdWrite:
		p.handleWrite(conn, name, args)
	case CmdRead:
		p.handleRead(conn, name, args)
	case CmdBlocking:
		p.handleBlocking(conn, name, args)
	case CmdPubSub:
		p.handlePubSub(conn, name, args)
	case CmdAdmin:
		p.handleAdmin(conn, name, args)
	case CmdScript:
		p.handleScript(conn, name, args)
	}
}

// handleQueuedCommand processes a command received while inside MULTI.
func (p *ProxyServer) handleQueuedCommand(conn redcon.Conn, state *proxyConnState, name string, args [][]byte) {
	switch name {
	case cmdExec:
		p.execTxn(conn, state)
	case cmdDiscard:
		p.discardTxn(conn, state)
	case cmdMulti:
		conn.WriteError("ERR MULTI calls can not be nested")
	default:
		// NOTE: Commands are queued locally and always return QUEUED without
		// upstream validation. Real Redis validates queued commands immediately
		// (e.g., wrong arity returns an error). Full compatibility would require
		// pinning a dedicated upstream connection for the MULTI..EXEC lifetime.
		state.txnQueue = append(state.txnQueue, args)
		conn.WriteString("QUEUED")
	}
}

func (p *ProxyServer) handleWrite(conn redcon.Conn, cmd string, args [][]byte) {
	resp, err := p.dual.Write(context.Background(), cmd, args)
	writeResponse(conn, resp, err)
}

func (p *ProxyServer) handleRead(conn redcon.Conn, cmd string, args [][]byte) {
	resp, err := p.dual.Read(context.Background(), cmd, args)
	writeResponse(conn, resp, err)
}

func (p *ProxyServer) handleBlocking(conn redcon.Conn, cmd string, args [][]byte) {
	// Use shutdownCtx so blocking commands are interrupted on graceful shutdown.
	resp, err := p.dual.Blocking(p.shutdownCtx, cmd, args)
	writeResponse(conn, resp, err)
}

func (p *ProxyServer) handlePubSub(conn redcon.Conn, name string, args [][]byte) {
	switch name {
	case cmdSubscribe, cmdPSubscribe:
		p.startPubSubSession(conn, name, args)
	case cmdUnsubscribe, cmdPUnsubscribe:
		// No active session; emit one reply per argument, or a single null reply if none.
		kind := strings.ToLower(name)
		if len(args) > 1 {
			for _, ch := range args[1:] {
				conn.WriteArray(pubsubArrayReply)
				conn.WriteBulkString(kind)
				conn.WriteBulkString(string(ch))
				conn.WriteInt64(0)
			}
		} else {
			conn.WriteArray(pubsubArrayReply)
			conn.WriteBulkString(kind)
			conn.WriteNull()
			conn.WriteInt64(0)
		}
	default:
		// PUBSUB CHANNELS / NUMSUB etc.
		resp, err := p.dual.Admin(context.Background(), name, args)
		writeResponse(conn, resp, err)
	}
}

// startPubSubSession detaches the client connection from redcon's event loop
// and bridges it to a dedicated upstream PubSub connection.
func (p *ProxyServer) startPubSubSession(conn redcon.Conn, cmdName string, args [][]byte) {
	psBackend := p.dual.PubSubBackend()
	if psBackend == nil {
		conn.WriteError("ERR PubSub not supported by backend")
		return
	}

	if len(args) < pubsubMinArgs {
		conn.WriteError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", strings.ToLower(cmdName)))
		return
	}

	// Create dedicated upstream PubSub connection.
	upstream := psBackend.NewPubSub(context.Background())

	channels := byteSlicesToStrings(args[1:])
	var err error
	if cmdName == cmdSubscribe {
		err = upstream.Subscribe(context.Background(), channels...)
	} else {
		err = upstream.PSubscribe(context.Background(), channels...)
	}
	if err != nil {
		upstream.Close()
		writeRedisError(conn, err)
		return
	}

	// Detach the connection from redcon's event loop.
	dconn := conn.Detach()

	session := &pubsubSession{
		dconn:      dconn,
		upstream:   upstream,
		proxy:      p,
		logger:     p.logger,
		channelSet: make(map[string]struct{}),
		patternSet: make(map[string]struct{}),
	}

	session.shadow = p.createShadowPubSub(cmdName, channels)

	// Write initial subscription confirmations.
	kind := strings.ToLower(cmdName)
	for _, ch := range channels {
		dconn.WriteArray(pubsubArrayReply)
		dconn.WriteBulkString(kind)
		dconn.WriteBulkString(ch)
		if cmdName == cmdSubscribe {
			session.channelSet[ch] = struct{}{}
		} else {
			session.patternSet[ch] = struct{}{}
		}
		dconn.WriteInt64(int64(session.subCount()))
	}
	if err := dconn.Flush(); err != nil {
		// Client went away before the confirmations could be written: tear down
		// everything we just created.
		dconn.Close()
		upstream.Close()
		if session.shadow != nil {
			session.shadow.Close()
		}
		return
	}

	go session.run()
}

// createShadowPubSub creates a shadow pub/sub for secondary comparison if in shadow mode.
// Returns nil if shadow mode is not active or the shadow subscribe fails.
func (p *ProxyServer) createShadowPubSub(cmdName string, channels []string) *shadowPubSub {
	shadowBackend := p.dual.ShadowPubSubBackend()
	if shadowBackend == nil {
		return nil
	}
	shadow := newShadowPubSub(shadowBackend, p.metrics, p.sentry, p.logger, p.cfg.PubSubCompareWindow)
	var err error
	if cmdName == cmdSubscribe {
		err = shadow.Subscribe(context.Background(), channels...)
	} else {
		err = shadow.PSubscribe(context.Background(), channels...)
	}
	if err != nil {
		// Shadow failures never break the client session — log, count, drop.
		p.logger.Warn("shadow pubsub subscribe failed", "err", err)
		p.metrics.PubSubShadowErrors.Inc()
		shadow.Close()
		return nil
	}
	shadow.Start()
	return shadow
}

// handleAdmin serves administrative commands, answering a few locally
// (PING/QUIT/SELECT/AUTH/HELLO) and forwarding the rest to the primary.
func (p *ProxyServer) handleAdmin(conn redcon.Conn, name string, args [][]byte) {
	// Handle PING locally for speed.
	if name == cmdPing {
		if len(args) > 1 {
			conn.WriteBulk(args[1])
		} else {
			conn.WriteString("PONG")
		}
		return
	}

	// Handle QUIT locally.
	if name == cmdQuit {
		conn.WriteString("OK")
		conn.Close()
		return
	}

	// SELECT: accept only the configured DB; reject others since the proxy
	// uses a shared connection pool and cannot maintain per-client DB state.
	if name == "SELECT" {
		if len(args) > 1 && string(args[1]) != "0" && string(args[1]) != fmt.Sprintf("%d", p.cfg.PrimaryDB) {
			conn.WriteError(fmt.Sprintf("ERR proxy does not support SELECT %s (configured DB: %d)", string(args[1]), p.cfg.PrimaryDB))
			return
		}
		conn.WriteString("OK")
		return
	}

	// AUTH is handled at the connection-pool level via config.
	// Silently accept so clients don't break.
	if name == "AUTH" {
		conn.WriteString("OK")
		return
	}

	// HELLO: reject to prevent RESP3 protocol upgrade, which the proxy
	// (redcon, RESP2 only) cannot support.
	if name == "HELLO" {
		conn.WriteError("ERR proxy does not support HELLO (RESP2 only)")
		return
	}

	resp, err := p.dual.Admin(context.Background(), name, args)
	writeResponse(conn, resp, err)
}

func (p *ProxyServer) handleScript(conn redcon.Conn, name string, args [][]byte) {
	resp, err := p.dual.Script(context.Background(), name, args)
	writeResponse(conn, resp, err)
}

// Transaction handling

// handleTxnCommand handles MULTI/EXEC/DISCARD received outside a transaction.
func (p *ProxyServer) handleTxnCommand(conn redcon.Conn, state *proxyConnState, name string) {
	switch name {
	case cmdMulti:
		if state.inTxn {
			conn.WriteError("ERR MULTI calls can not be nested")
			return
		}
		state.inTxn = true
		state.txnQueue = nil
		conn.WriteString("OK")
	case cmdExec:
		if !state.inTxn {
			conn.WriteError("ERR EXEC without MULTI")
			return
		}
		p.execTxn(conn, state)
	case cmdDiscard:
		if !state.inTxn {
			conn.WriteError("ERR DISCARD without MULTI")
			return
		}
		p.discardTxn(conn, state)
	}
}

// execTxn flushes the queued commands as a MULTI..EXEC pipeline against the
// primary and replies with the EXEC result; in dual-write modes the same
// pipeline is replayed asynchronously against the secondary.
// NOTE(review): the replay re-sends the whole queue, including any read-only
// commands that were queued — confirm this is the intended replication shape.
func (p *ProxyServer) execTxn(conn redcon.Conn, state *proxyConnState) {
	queue := state.txnQueue
	state.inTxn = false
	state.txnQueue = nil

	ctx := context.Background()

	// Build pipeline: MULTI + queued commands + EXEC
	cmds := make([][]any, 0, len(queue)+txnCommandsOverhead)
	cmds = append(cmds, []any{"MULTI"})
	for _, args := range queue {
		cmds = append(cmds, bytesArgsToInterfaces(args))
	}
	cmds = append(cmds, []any{"EXEC"})

	results, err := p.dual.Primary().Pipeline(ctx, cmds)
	switch {
	case err != nil:
		// Pipeline-level error (connection/transport failure) takes precedence.
		writeRedisError(conn, err)
	case len(results) > 0:
		// Write the EXEC result (last command in the pipeline).
		lastResult := results[len(results)-1]
		resp, rErr := lastResult.Result()
		writeResponse(conn, resp, rErr)
	default:
		conn.WriteError("ERR empty transaction response")
	}

	// Async replay to secondary (bounded)
	if p.dual.hasSecondaryWrite() {
		p.dual.goAsync(func() {
			sCtx, cancel := context.WithTimeout(context.Background(), p.cfg.SecondaryTimeout)
			defer cancel()
			_, pErr := p.dual.Secondary().Pipeline(sCtx, cmds)
			if pErr != nil {
				p.logger.Warn("secondary txn replay failed", "err", pErr)
				p.metrics.SecondaryWriteErrors.Inc()
			}
		})
	}
}

func (p *ProxyServer) discardTxn(conn redcon.Conn, state *proxyConnState) {
	state.inTxn = false
	state.txnQueue = nil
	conn.WriteString("OK")
}

// writeResponse handles the common pattern of writing a go-redis response,
// correctly handling redis.Nil and upstream errors.
func writeResponse(w respWriter, resp any, err error) {
	if err != nil {
		if errors.Is(err, redis.Nil) {
			w.WriteNull()
			return
		}
		writeRedisError(w, err)
		return
	}
	writeRedisValue(w, resp)
}

// writeRedisError writes an upstream error to the client.
// go-redis redis.Error values already carry the Redis prefix (e.g. "ERR ...", "WRONGTYPE ...").
// Other errors (timeouts, dial failures) are normalized to "ERR ..." to produce valid RESP.
func writeRedisError(w respWriter, err error) {
	var redisErr redis.Error
	if errors.As(err, &redisErr) {
		w.WriteError(redisErr.Error())
		return
	}
	w.WriteError("ERR " + err.Error())
}

// writeRedisValue writes a go-redis response value to a respWriter.
func writeRedisValue(w respWriter, val any) {
	if val == nil {
		w.WriteNull()
		return
	}
	switch v := val.(type) {
	case string:
		// Known status strings (OK/QUEUED/PONG) go out as simple strings so
		// clients see +OK, not a bulk string.
		if isStatusResponse(v) {
			w.WriteString(v)
		} else {
			w.WriteBulkString(v)
		}
	case int64:
		w.WriteInt64(v)
	case []any:
		w.WriteArray(len(v))
		for _, item := range v {
			writeRedisValue(w, item)
		}
	case []byte:
		w.WriteBulk(v)
	case redis.Error:
		w.WriteError(v.Error())
	default:
		// Fallback: render unknown types via %v as a bulk string.
		w.WriteBulkString(fmt.Sprintf("%v", v))
	}
}

// isStatusResponse detects known Redis status reply strings that should be
// sent as simple strings (+OK) rather than bulk strings ($2\r\nOK).
func isStatusResponse(s string) bool {
	switch s {
	case "OK", "QUEUED", "PONG":
		return true
	default:
		return false
	}
}

// cloneArgs deep-copies the argument slices so they outlive redcon's reused
// read buffer.
func cloneArgs(args [][]byte) [][]byte {
	out := make([][]byte, len(args))
	for i, arg := range args {
		cp := make([]byte, len(arg))
		copy(cp, arg)
		out[i] = cp
	}
	return out
}
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
new file mode 100644
index 00000000..4fd535a8
--- /dev/null
+++ b/proxy/proxy_test.go
@@ -0,0 +1,797 @@
package proxy

import (
	"context"
	"errors"
	"io"
	"log/slog"
	"sync"
	"testing"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/testutil"
	"github.com/redis/go-redis/v9"
	"github.com/stretchr/testify/assert"
)

var testLogger = slog.New(slog.NewTextHandler(io.Discard, nil))

// --- Mock Backend ---

// mockBackend is an in-memory Backend that records every Do call.
type mockBackend struct {
	name   string
	doFunc func(ctx context.Context, args ...any) *redis.Cmd
	mu     sync.Mutex
	calls  [][]any
}

func newMockBackend(name string) *mockBackend {
	return &mockBackend{name: name}
}

// Do records the call and either delegates to doFunc or returns "OK".
func (b *mockBackend) Do(ctx context.Context, args ...any) *redis.Cmd {
	b.mu.Lock()
	b.calls = append(b.calls, args)
	b.mu.Unlock()
	if b.doFunc != nil {
		return b.doFunc(ctx, args...)
	}
	cmd := redis.NewCmd(ctx, args...)
	cmd.SetVal("OK")
	return cmd
}

func (b *mockBackend) Pipeline(ctx context.Context, cmds [][]any) ([]*redis.Cmd, error) {
	results := make([]*redis.Cmd, len(cmds))
	for i, args := range cmds {
		results[i] = b.Do(ctx, args...)
	}
	return results, nil
}

func (b *mockBackend) Close() error { return nil }
func (b *mockBackend) Name() string { return b.name }

func (b *mockBackend) CallCount() int {
	b.mu.Lock()
	defer b.mu.Unlock()
	return len(b.calls)
}

// Helper to create a doFunc that returns a specific value.
func makeCmd(val any, err error) func(ctx context.Context, args ...any) *redis.Cmd {
	return func(ctx context.Context, args ...any) *redis.Cmd {
		cmd := redis.NewCmd(ctx, args...)
		if err != nil {
			cmd.SetErr(err)
		} else {
			cmd.SetVal(val)
		}
		return cmd
	}
}

func newTestMetrics() *ProxyMetrics {
	reg := prometheus.NewRegistry()
	return NewProxyMetrics(reg)
}

func newTestSentry() *SentryReporter {
	return NewSentryReporter("", "", 1.0, testLogger)
}

// ========== command.go tests ==========

func TestClassifyCommand(t *testing.T) {
	tests := []struct {
		name     string
		cmd      string
		args     [][]byte
		expected CommandCategory
	}{
		{"GET is read", "GET", nil, CmdRead},
		{"SET is write", "SET", nil, CmdWrite},
		{"DEL is write", "DEL", nil, CmdWrite},
		{"HGET is read", "HGET", nil, CmdRead},
		{"HSET is write", "HSET", nil, CmdWrite},
		{"ZADD is write", "ZADD", nil, CmdWrite},
		{"ZRANGE is read", "ZRANGE", nil, CmdRead},
		{"PING is admin", "PING", nil, CmdAdmin},
		{"INFO is admin", "INFO", nil, CmdAdmin},
		{"SELECT is admin", "SELECT", nil, CmdAdmin},
		{"QUIT is admin", "QUIT", nil, CmdAdmin},
		{"AUTH is admin", "AUTH", nil, CmdAdmin},
		{"HELLO is admin", "HELLO", nil, CmdAdmin},
		{"CONFIG is admin", "CONFIG", nil, CmdAdmin},
		{"WAIT is admin", "WAIT", nil, CmdAdmin},
		{"COMMAND is admin", "COMMAND", nil, CmdAdmin},
		{"MULTI is txn", "MULTI", nil, CmdTxn},
		{"EXEC is txn", "EXEC", nil, CmdTxn},
		{"DISCARD is txn", "DISCARD", nil, CmdTxn},
		{"SUBSCRIBE is pubsub", "SUBSCRIBE", nil, CmdPubSub},
		{"UNSUBSCRIBE is pubsub", "UNSUBSCRIBE", nil, CmdPubSub},
		{"PSUBSCRIBE is pubsub", "PSUBSCRIBE", nil, CmdPubSub},
		{"PUNSUBSCRIBE is pubsub", "PUNSUBSCRIBE", nil, CmdPubSub},
		{"BZPOPMIN is blocking", "BZPOPMIN", nil, CmdBlocking},
		{"EVAL is script", "EVAL", nil, CmdScript},
		{"lowercase get is read", "get", nil, CmdRead},
		{"GETDEL is write", "GETDEL", nil, CmdWrite},
		{"PUBLISH is write", "PUBLISH", nil, CmdWrite},
		{"unknown cmd is write", "UNKNOWNCMD", nil, CmdWrite},

		// XREAD special handling
		{"XREAD without BLOCK is read", "XREAD", [][]byte{[]byte("COUNT"), []byte("10"), []byte("STREAMS"), []byte("s1"), []byte("0")}, CmdRead},
		{"XREAD with BLOCK is blocking", "XREAD", [][]byte{[]byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("s1"), []byte("0")}, CmdBlocking},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ClassifyCommand(tt.cmd, tt.args)
			assert.Equal(t, tt.expected, got)
		})
	}
}

// ========== config.go tests ==========

func TestParseProxyMode(t *testing.T) {
	tests := []struct {
		input    string
		expected ProxyMode
		ok       bool
	}{
		{"redis-only", ModeRedisOnly, true},
		{"dual-write", ModeDualWrite, true},
		{"dual-write-shadow", ModeDualWriteShadow, true},
		{"elastickv-primary", ModeElasticKVPrimary, true},
		{"elastickv-only", ModeElasticKVOnly, true},
		{"invalid", 0, false},
	}

	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			got, ok := ParseProxyMode(tt.input)
			assert.Equal(t, tt.ok, ok)
			if ok {
				assert.Equal(t, tt.expected, got)
			}
		})
	}
}

func TestProxyModeString(t *testing.T) {
	assert.Equal(t, "redis-only", ModeRedisOnly.String())
	assert.Equal(t, "dual-write", ModeDualWrite.String())
	assert.Equal(t, "dual-write-shadow", ModeDualWriteShadow.String())
	assert.Equal(t, "elastickv-primary", ModeElasticKVPrimary.String())
	assert.Equal(t, "elastickv-only", ModeElasticKVOnly.String())
	assert.Equal(t, "unknown", ProxyMode(99).String())
}

func TestDefaultConfig(t *testing.T) {
	cfg := DefaultConfig()
	assert.Equal(t, ":6479", cfg.ListenAddr)
	assert.Equal(t, "localhost:6379", cfg.PrimaryAddr)
	assert.Equal(t, "localhost:6380", cfg.SecondaryAddr)
	assert.Equal(t, ModeDualWrite, cfg.Mode)
}

// ========== compare.go tests ==========

func TestDivergenceKindString(t *testing.T) {
	assert.Equal(t, "migration_gap", DivMigrationGap.String())
	assert.Equal(t, "data_mismatch", DivDataMismatch.String())
	assert.Equal(t, "extra_data", DivExtraData.String())
	assert.Equal(t, "unknown", DivergenceKind(99).String())
}

func TestExtractKey(t *testing.T) {
	assert.Equal(t, "mykey", extractKey([][]byte{[]byte("GET"), []byte("mykey")}))
	assert.Equal(t, "", extractKey([][]byte{[]byte("PING")}))
	assert.Equal(t, "", extractKey(nil))
}

func TestBytesArgsToInterfaces(t *testing.T) {
	args := [][]byte{[]byte("SET"), []byte("key"), []byte("val")}
	result := bytesArgsToInterfaces(args)
	assert.Len(t, result, 3)
	assert.Equal(t, []byte("SET"), result[0])
	assert.Equal(t, []byte("key"), result[1])
	assert.Equal(t, []byte("val"), result[2])
}

func TestResponseEqual(t *testing.T) {
	tests := []struct {
		name string
		a, b any
		want bool
	}{
		{"nil nil", nil, nil, true},
		{"nil vs value", nil, "hello", false},
		{"value vs nil", "hello", nil, false},
		{"same string", "hello", "hello", true},
		{"diff string", "hello", "world", false},
		{"empty string equals empty string", "", "", true},
		{"same int64", int64(42), int64(42), true},
		{"diff int64", int64(42), int64(43), false},
		{"same array", []any{"a", "b"}, []any{"a", "b"}, true},
		{"diff array values", []any{"a", "b"}, []any{"a", "c"}, false},
		{"diff array length", []any{"a"}, []any{"a", "b"}, false},
		{"empty arrays", []any{}, []any{}, true},
		{"same bytes", []byte("data"), []byte("data"), true},
		{"diff bytes", []byte("data"), []byte("other"), false},
		{"nested array", []any{[]any{"x"}}, []any{[]any{"x"}}, true},
		{"type mismatch string vs int", "42", int64(42), false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert.Equal(t, tt.want, responseEqual(tt.a, tt.b))
		})
	}
}

func TestClassifyDivergence(t *testing.T) {
	tests := []struct {
		name                       string
		primaryResp, secondaryResp any
		primaryErr, secondaryErr   error
		want                       DivergenceKind
	}{
		{"primary has data, secondary nil", "val", nil, nil, redis.Nil, DivMigrationGap},
		{"primary nil, secondary has data", nil, "val", redis.Nil, nil, DivExtraData},
		{"both have different data", "val1", "val2", nil, nil, DivDataMismatch},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := classifyDivergence(tt.primaryResp, tt.primaryErr, tt.secondaryResp, tt.secondaryErr)
			assert.Equal(t, tt.want, got)
		})
	}
}

func TestIsNilResp(t *testing.T) {
	assert.True(t, isNilResp(nil, redis.Nil))
	assert.True(t, isNilResp(nil, nil))
	assert.False(t, isNilResp("", nil), "empty string is NOT nil")
	assert.False(t, isNilResp("hello", nil))
	assert.False(t, isNilResp(int64(0), nil))
}

// ========== proxy.go tests ==========

func TestCloneArgs(t *testing.T) {
	original := [][]byte{[]byte("GET"), []byte("key")}
	cloned := cloneArgs(original)

	assert.Equal(t, original, cloned)

	// Mutating cloned should not affect original.
	cloned[0][0] = 'X'
	assert.Equal(t, byte('G'), original[0][0])
}

func TestIsStatusResponse(t *testing.T) {
	assert.True(t, isStatusResponse("OK"))
	assert.True(t, isStatusResponse("QUEUED"))
	assert.True(t, isStatusResponse("PONG"))
	assert.False(t, isStatusResponse("hello"))
	assert.False(t, isStatusResponse(""))
	assert.False(t, isStatusResponse("ok")) // case-sensitive
}

// ========== sentry.go tests ==========

func TestSentryReporterDisabled(t *testing.T) {
	r := NewSentryReporter("", "", 1.0, nil)
	assert.False(t, r.enabled)
	// Should not panic
	r.CaptureException(nil, "test", nil)
	r.CaptureDivergence(Divergence{})
	r.Flush(0)
}

func TestShouldReportCooldown(t *testing.T) {
	r := &SentryReporter{
		lastReport: make(map[string]time.Time),
		cooldown:   50 * time.Millisecond,
	}

	assert.True(t, r.ShouldReport("fp1"))
	assert.False(t, r.ShouldReport("fp1")) // within cooldown

	time.Sleep(60 * time.Millisecond)
	assert.True(t, r.ShouldReport("fp1")) // cooldown elapsed
}

func TestShouldReportEvictsExpired(t *testing.T) {
	r := &SentryReporter{
		lastReport: make(map[string]time.Time),
		cooldown:   1 * time.Millisecond,
	}
	// Fill to maxReportEntries
	// NOTE(review): string(rune(i)) produces colliding keys if i ever lands in
	// the surrogate range (0xD800-0xDFFF all map to "\uFFFD") — fine while
	// maxReportEntries stays below 0xD800, worth confirming.
	for i := range maxReportEntries {
		r.lastReport[string(rune(i))] = time.Now().Add(-time.Hour)
	}
	time.Sleep(2 * time.Millisecond)
	assert.True(t, r.ShouldReport("new-fp"))
	assert.Less(t, len(r.lastReport), maxReportEntries)
}

// ========== dualwrite.go tests ==========

func TestHasSecondaryWrite(t *testing.T) {
	for _, tc := range []struct {
		mode     ProxyMode
		expected bool
	}{
		{ModeRedisOnly, false},
		{ModeDualWrite, true},
		{ModeDualWriteShadow, true},
		{ModeElasticKVPrimary, true},
		{ModeElasticKVOnly, false},
	} {
		d := &DualWriter{cfg: ProxyConfig{Mode: tc.mode}, writeSem: make(chan struct{}, 1), shadowSem: make(chan struct{}, 1)}
		assert.Equal(t, tc.expected, d.hasSecondaryWrite(), "mode=%s", tc.mode)
	}
}

// (next definition is truncated at the edge of this chunk)
func 
TestDualWriter_Write_PrimarySuccess(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("OK", nil) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.NoError(t, err) + assert.Equal(t, "OK", resp) + assert.Equal(t, 1, primary.CallCount()) + + // Wait for async secondary write deterministically. + assert.Eventually(t, func() bool { return secondary.CallCount() == 1 }, + time.Second, time.Millisecond, "secondary should be called once") +} + +func TestDualWriter_Write_PrimaryFail(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd(nil, errors.New("connection refused")) + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection refused") + // Secondary should NOT be called when primary fails; drain async work. 
+ d.Close() + assert.Equal(t, 0, secondary.CallCount()) +} + +func TestDualWriter_Write_SecondaryFail_ClientSucceeds(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, errors.New("secondary down")) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.NoError(t, err) + assert.Equal(t, "OK", resp) + + d.Close() + assert.InDelta(t, 1, testutil.ToFloat64(metrics.SecondaryWriteErrors), 0.001) +} + +func TestDualWriter_Write_RedisNil(t *testing.T) { + // SET with NX returns redis.Nil when key exists + primary := newMockBackend("primary") + primary.doFunc = makeCmd(nil, redis.Nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: time.Second}, metrics, newTestSentry(), testLogger) + + resp, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v"), []byte("NX")}) + assert.ErrorIs(t, err, redis.Nil) + assert.Nil(t, resp) + // Should still send to secondary + assert.Eventually(t, func() bool { return secondary.CallCount() == 1 }, + time.Second, time.Millisecond, "secondary should be called for redis.Nil") +} + +func TestDualWriter_Write_RedisOnlyMode(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeRedisOnly}, metrics, newTestSentry(), testLogger) + + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + 
assert.NoError(t, err) + d.Close() + assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in redis-only mode") +} + +func TestDualWriter_Write_ElasticKVOnlyMode(t *testing.T) { + primary := newMockBackend("elastickv") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("redis") + + metrics := newTestMetrics() + d := NewDualWriter(primary, secondary, ProxyConfig{Mode: ModeElasticKVOnly}, metrics, newTestSentry(), testLogger) + + _, err := d.Write(context.Background(), "SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.NoError(t, err) + d.Close() + assert.Equal(t, 0, secondary.CallCount(), "secondary should not be called in elastickv-only mode") +} + +func TestDualWriter_Read_WithShadow(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("hello", nil) + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("hello", nil) + + metrics := newTestMetrics() + cfg := ProxyConfig{Mode: ModeDualWriteShadow, ShadowTimeout: time.Second, SecondaryTimeout: time.Second} + d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) + + resp, err := d.Read(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}) + assert.NoError(t, err) + assert.Equal(t, "hello", resp) + + // Wait for shadow read deterministically. 
+ assert.Eventually(t, func() bool { return secondary.CallCount() == 1 }, + time.Second, time.Millisecond, "shadow read should be issued") +} + +func TestDualWriter_Read_NoShadowInDualWrite(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("hello", nil) + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + cfg := ProxyConfig{Mode: ModeDualWrite, ShadowTimeout: time.Second} + d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) + + _, err := d.Read(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}) + assert.NoError(t, err) + d.Close() + assert.Equal(t, 0, secondary.CallCount(), "no shadow in dual-write mode") +} + +func TestDualWriter_GoAsync_Bounded(t *testing.T) { + primary := newMockBackend("primary") + primary.doFunc = makeCmd("OK", nil) + secondary := newMockBackend("secondary") + + metrics := newTestMetrics() + cfg := ProxyConfig{Mode: ModeDualWrite, SecondaryTimeout: 10 * time.Second} + d := NewDualWriter(primary, secondary, cfg, metrics, newTestSentry(), testLogger) + + // Fill the write semaphore with blocking goroutines + blocker := make(chan struct{}) + for range maxWriteGoroutines { + d.goAsync(func() { + <-blocker + }) + } + + // Next one should be dropped, not block + done := make(chan struct{}) + go func() { + d.goAsync(func() { t.Error("should not run") }) + close(done) + }() + + select { + case <-done: + // good — goAsync returned immediately + case <-time.After(time.Second): + t.Fatal("goAsync blocked when semaphore was full") + } + + // Verify drop metric was incremented + assert.InDelta(t, 1, testutil.ToFloat64(metrics.AsyncDrops), 0.001) + + close(blocker) // unblock all + d.Close() // wait for all goroutines to finish +} + +// ========== ShadowReader tests ========== + +func TestShadowReader_Compare_Equal(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("hello", nil) + + metrics := newTestMetrics() + sr := 
NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + // No divergence should be reported + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_BothNil(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, nil, redis.Nil) + + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_MigrationGap(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) // secondary has no data + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) +} + +func TestShadowReader_Compare_DataMismatch(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("world", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "hello", nil) + + assert.InDelta(t, 1, 
testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) +} + +func TestShadowReader_Compare_ExtraData(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("surprise", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, nil, redis.Nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "extra_data")), 0.001) +} + +func TestShadowReader_Compare_EmptyStringIsNotNil(t *testing.T) { + // Primary returns "", secondary returns "" → equal, no divergence + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd("", nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "", nil) + + assert.InDelta(t, 0, testutil.ToFloat64(metrics.Divergences.WithLabelValues("GET", "data_mismatch")), 0.001) + assert.InDelta(t, 0, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_EmptyStringVsNil(t *testing.T) { + // Primary returns "", secondary returns nil → MigrationGap + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, time.Second) + + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "", nil) + + assert.InDelta(t, 1, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) +} + +func TestShadowReader_Compare_MigrationGapSampling(t *testing.T) { + secondary := newMockBackend("secondary") + secondary.doFunc = makeCmd(nil, redis.Nil) + + metrics := newTestMetrics() + sr := NewShadowReader(secondary, metrics, newTestSentry(), testLogger, 
time.Second) + sr.gapLogSampleRate = 10 + + for range 25 { + sr.Compare(context.Background(), "GET", [][]byte{[]byte("GET"), []byte("k")}, "val", nil) + } + + assert.InDelta(t, 25, testutil.ToFloat64(metrics.MigrationGaps.WithLabelValues("GET")), 0.001) + assert.Equal(t, int64(25), sr.gapCount.Load()) +} + +// ========== Backend tests ========== + +func TestDefaultBackendOptions(t *testing.T) { + opts := DefaultBackendOptions() + assert.Equal(t, 128, opts.PoolSize) + assert.Equal(t, 5*time.Second, opts.DialTimeout) +} + +// ========== Pipeline error handling tests ========== + +func TestPipeline_TransportError(t *testing.T) { + b := newMockBackend("test") + b.doFunc = makeCmd(nil, errors.New("connection refused")) + + // mockBackend.Pipeline doesn't simulate pipe.Exec; test RedisBackend via unit behaviour. + // Here we verify the mock-based pipeline returns results. + results, err := b.Pipeline(context.Background(), [][]any{{"MULTI"}, {"SET", "k", "v"}, {"EXEC"}}) + assert.NoError(t, err) // mock always returns nil error + assert.Len(t, results, 3) +} + +// ========== writeRedisValue tests ========== + +// testRedisErr satisfies the redis.Error interface for testing. 
+type testRedisErr string + +func (e testRedisErr) Error() string { return string(e) } +func (e testRedisErr) RedisError() {} + +type mockRespWriter struct { + writes []any +} + +func (m *mockRespWriter) WriteError(msg string) { m.writes = append(m.writes, "ERR:"+msg) } +func (m *mockRespWriter) WriteString(msg string) { m.writes = append(m.writes, "STR:"+msg) } +func (m *mockRespWriter) WriteBulk(b []byte) { m.writes = append(m.writes, "BULK:"+string(b)) } +func (m *mockRespWriter) WriteBulkString(msg string) { m.writes = append(m.writes, "BULKSTR:"+msg) } +func (m *mockRespWriter) WriteInt64(num int64) { m.writes = append(m.writes, num) } +func (m *mockRespWriter) WriteArray(count int) { m.writes = append(m.writes, count) } +func (m *mockRespWriter) WriteNull() { m.writes = append(m.writes, nil) } + +func TestWriteRedisValue(t *testing.T) { + tests := []struct { + name string + val any + expect []any + }{ + {"nil", nil, []any{nil}}, + {"status OK", "OK", []any{"STR:OK"}}, + {"status QUEUED", "QUEUED", []any{"STR:QUEUED"}}, + {"status PONG", "PONG", []any{"STR:PONG"}}, + {"bulk string", "hello", []any{"BULKSTR:hello"}}, + {"int64", int64(42), []any{int64(42)}}, + {"bytes", []byte("data"), []any{"BULK:data"}}, + {"array", []any{"a", int64(1)}, []any{2, "BULKSTR:a", int64(1)}}, + {"nested array", []any{[]any{"x"}}, []any{1, 1, "BULKSTR:x"}}, + {"redis error", testRedisErr("WRONGTYPE bad"), []any{"ERR:WRONGTYPE bad"}}, + {"default type", float64(3.14), []any{"BULKSTR:3.14"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &mockRespWriter{} + writeRedisValue(w, tt.val) + assert.Equal(t, tt.expect, w.writes) + }) + } +} + +func TestWriteRedisError(t *testing.T) { + t.Run("redis.Error passthrough", func(t *testing.T) { + w := &mockRespWriter{} + writeRedisError(w, testRedisErr("WRONGTYPE Operation against a key")) + assert.Equal(t, []any{"ERR:WRONGTYPE Operation against a key"}, w.writes) + }) + + t.Run("generic error gets ERR prefix", 
func(t *testing.T) { + w := &mockRespWriter{} + writeRedisError(w, errors.New("connection refused")) + assert.Equal(t, []any{"ERR:ERR connection refused"}, w.writes) + }) +} + +func TestWriteResponse(t *testing.T) { + t.Run("nil error with value", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, "hello", nil) + assert.Equal(t, []any{"BULKSTR:hello"}, w.writes) + }) + + t.Run("redis.Nil", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, nil, redis.Nil) + assert.Equal(t, []any{nil}, w.writes) + }) + + t.Run("real error", func(t *testing.T) { + w := &mockRespWriter{} + writeResponse(w, nil, errors.New("timeout")) + assert.Equal(t, []any{"ERR:ERR timeout"}, w.writes) + }) +} + +// ========== truncateValue tests ========== + +type testStringer struct{ s string } + +func (ts testStringer) String() string { return ts.s } + +func TestTruncateValue(t *testing.T) { + tests := []struct { + name string + input any + expect string + }{ + {"nil", nil, ""}, + {"short string", "hello", "hello"}, + {"short bytes", []byte("abc"), "abc"}, + {"fmt.Stringer", testStringer{"ok"}, "ok"}, + {"int", 42, "42"}, + {"slice", []int{1, 2, 3}, "[1, 2, 3]"}, + {"map", map[string]int{"a": 1}, "{a: 1}"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateValue(tt.input) + assert.Equal(t, tt.expect, result) + }) + } + + // Long string truncation + long := make([]byte, 300) + for i := range long { + long[i] = 'a' + } + result := truncateValue(string(long)) + assert.Contains(t, result, "...(truncated)") + assert.LessOrEqual(t, len(result), 300) + + // Long []byte truncation + result = truncateValue(long) + assert.Contains(t, result, "...(truncated)") +} + +// ========== Additional command table tests ========== + +func TestClassifyCommand_NewCommands(t *testing.T) { + reads := []string{"MGET", "GETRANGE", "STRLEN", "SRANDMEMBER", "SDIFF", "SINTER", "SUNION", "XINFO", "GEODIST", "GEOSEARCH", "OBJECT", "RANDOMKEY", "TOUCH"} + for 
_, cmd := range reads { + assert.Equal(t, CmdRead, ClassifyCommand(cmd, nil), "expected %s to be CmdRead", cmd) + } + + writes := []string{"MSET", "MSETNX", "APPEND", "INCRBY", "DECR", "SETRANGE", "SPOP", "SMOVE", "PERSIST", "EXPIREAT", "GEOADD", "UNLINK", "COPY", "ZPOPMAX", "RESTORE"} + for _, cmd := range writes { + assert.Equal(t, CmdWrite, ClassifyCommand(cmd, nil), "expected %s to be CmdWrite", cmd) + } + + blocking := []string{"BLPOP", "BRPOP", "BZPOPMAX", "BLMOVE"} + for _, cmd := range blocking { + assert.Equal(t, CmdBlocking, ClassifyCommand(cmd, nil), "expected %s to be CmdBlocking", cmd) + } + + // XREADGROUP with BLOCK + assert.Equal(t, CmdBlocking, ClassifyCommand("XREADGROUP", [][]byte{[]byte("GROUP"), []byte("g"), []byte("c"), []byte("BLOCK"), []byte("0"), []byte("STREAMS"), []byte("s1"), []byte(">")})) + // XREADGROUP without BLOCK + assert.Equal(t, CmdRead, ClassifyCommand("XREADGROUP", [][]byte{[]byte("GROUP"), []byte("g"), []byte("c"), []byte("STREAMS"), []byte("s1"), []byte(">")})) + + admins := []string{"WATCH", "UNWATCH", "HELLO", "TIME", "SLOWLOG", "ACL"} + for _, cmd := range admins { + assert.Equal(t, CmdAdmin, ClassifyCommand(cmd, nil), "expected %s to be CmdAdmin", cmd) + } + + pubsubs := []string{"SSUBSCRIBE", "SUNSUBSCRIBE"} + for _, cmd := range pubsubs { + assert.Equal(t, CmdPubSub, ClassifyCommand(cmd, nil), "expected %s to be CmdPubSub", cmd) + } +} diff --git a/proxy/pubsub.go b/proxy/pubsub.go new file mode 100644 index 00000000..38fa7e3a --- /dev/null +++ b/proxy/pubsub.go @@ -0,0 +1,749 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" + "github.com/tidwall/redcon" +) + +const ( + pubsubArrayMessage = 3 // ["message", channel, payload] + pubsubArrayPMessage = 4 // ["pmessage", pattern, channel, payload] + pubsubArrayReply = 3 // ["subscribe"/"unsubscribe", channel, count] + pubsubArrayPong = 2 // ["pong", data] + pubsubMinArgs = 2 // command 
+ at least one channel + + cmdSubscribe = "SUBSCRIBE" + cmdUnsubscribe = "UNSUBSCRIBE" + cmdPSubscribe = "PSUBSCRIBE" + cmdPUnsubscribe = "PUNSUBSCRIBE" + cmdMulti = "MULTI" + cmdExec = "EXEC" + cmdDiscard = "DISCARD" + cmdPing = "PING" + cmdQuit = "QUIT" + + // cleanupFwdTimeout bounds the wait for forwardMessages to exit during cleanup. + // If the client socket is stuck, we don't want to block indefinitely. + cleanupFwdTimeout = 5 * time.Second +) + +// pubsubSession manages a single client's detached connection. +// It bridges a detached redcon connection to an upstream go-redis PubSub connection. +// When all subscriptions are removed, the session transitions to normal command mode, +// enabling the client to execute regular Redis commands without reconnecting. +type pubsubSession struct { + mu sync.Mutex // protects upstream, closed, and shadow; channelSet/patternSet/txn are goroutine-confined to commandLoop + writeMu sync.Mutex // serializes writes to dconn; never held across state operations + dconn redcon.DetachedConn + upstream *redis.PubSub // nil when not in pub/sub mode + proxy *ProxyServer + logger *slog.Logger + closed bool + + // Track subscribed channels/patterns in sets for idempotent subscribe/unsubscribe + // and correct subscription count tracking (matching Redis behavior). + channelSet map[string]struct{} + patternSet map[string]struct{} + + // fwdDone is closed when the current forwardMessages goroutine exits. + fwdDone chan struct{} + + // Shadow pub/sub for secondary comparison (nil when not in shadow mode). + shadow *shadowPubSub + + // Transaction state for normal command mode. + inTxn bool + txnQueue [][][]byte +} + +// subCount returns the total number of active subscriptions (channels + patterns). +func (s *pubsubSession) subCount() int { + return len(s.channelSet) + len(s.patternSet) +} + +// run starts the session. It blocks until the client disconnects or sends QUIT. 
+func (s *pubsubSession) run() { + defer s.cleanup() + s.startForwarding() + s.commandLoop() +} + +func (s *pubsubSession) cleanup() { + s.mu.Lock() + s.closed = true + if s.upstream != nil { + s.upstream.Close() + s.upstream = nil + } + s.mu.Unlock() + if s.fwdDone != nil { + // Bounded wait: if forwardMessages is stuck on a slow/dead client socket, + // close dconn to unblock it, then wait with a second bounded timeout. + if !s.waitFwdDone() { + s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") + s.dconn.Close() + if !s.waitFwdDone() { + s.logger.Error("forwardMessages still stuck after dconn.Close, abandoning") + } + s.closeShadow() + return // dconn already closed + } + } + // Close shadow after forwardMessages exits (it calls RecordPrimary). + s.closeShadow() + s.dconn.Close() +} + +// waitFwdDone waits for fwdDone with a bounded timeout, returning true if it completed. +func (s *pubsubSession) waitFwdDone() bool { + select { + case <-s.fwdDone: + return true + case <-time.After(cleanupFwdTimeout): + return false + } +} + +func (s *pubsubSession) startForwarding() { + // Capture upstream under lock to avoid race with exitPubSubMode. + s.mu.Lock() + upstream := s.upstream + s.mu.Unlock() + if upstream == nil { + return + } + ch := upstream.Channel() + s.fwdDone = make(chan struct{}) + go func() { + defer close(s.fwdDone) + s.forwardMessages(ch) + }() +} + +// forwardMessages reads from the upstream go-redis PubSub channel and writes +// messages to the detached client connection. 
+func (s *pubsubSession) forwardMessages(ch <-chan *redis.Message) { + for msg := range ch { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return + } + s.writeMu.Lock() + if msg.Pattern != "" { + s.dconn.WriteArray(pubsubArrayPMessage) + s.dconn.WriteBulkString("pmessage") + s.dconn.WriteBulkString(msg.Pattern) + s.dconn.WriteBulkString(msg.Channel) + s.dconn.WriteBulkString(msg.Payload) + } else { + s.dconn.WriteArray(pubsubArrayMessage) + s.dconn.WriteBulkString("message") + s.dconn.WriteBulkString(msg.Channel) + s.dconn.WriteBulkString(msg.Payload) + } + err := s.dconn.Flush() + s.writeMu.Unlock() + if err != nil { + // Mark closed and close dconn to unblock commandLoop, preventing + // goroutine/resource leaks (aligned with adapter/redis_pubsub.go). + s.mu.Lock() + s.closed = true + s.mu.Unlock() + _ = s.dconn.Close() + return + } + // Record for shadow comparison (outside writeMu to avoid nested locking). + // Capture shadow under mu since cleanup/exitPubSubMode can nil it concurrently. + s.mu.Lock() + shadow := s.shadow + s.mu.Unlock() + if shadow != nil { + shadow.RecordPrimary(msg) + } + } +} + +// commandLoop reads commands from the detached client and dispatches them. +// In pub/sub mode, only subscription commands are allowed. +// When all subscriptions are removed, it transitions to normal command mode. +func (s *pubsubSession) commandLoop() { + for { + cmd, err := s.dconn.ReadCommand() + if err != nil { + return + } + if len(cmd.Args) == 0 { + continue + } + args := cloneArgs(cmd.Args) + name := strings.ToUpper(string(args[0])) + + // subCount() reads channelSet/patternSet which are only modified + // from this goroutine (commandLoop), so no lock is needed. 
+ if s.subCount() > 0 { + if !s.dispatchPubSubCommand(args) { + return + } + if s.shouldExitPubSub() { + s.exitPubSubMode() + } + } else if !s.dispatchNormalCommand(name, args) { + return + } + } +} + +func (s *pubsubSession) shouldExitPubSub() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.upstream != nil && s.subCount() == 0 +} + +func (s *pubsubSession) exitPubSubMode() { + s.mu.Lock() + if s.upstream != nil { + s.upstream.Close() + s.upstream = nil + } + s.mu.Unlock() + if s.fwdDone != nil { + if !s.waitFwdDone() { + s.logger.Warn("forwardMessages did not exit within timeout, closing dconn to unblock") + s.dconn.Close() + if !s.waitFwdDone() { + s.logger.Error("forwardMessages still stuck after dconn.Close, abandoning") + } + } + s.fwdDone = nil + } + // Close shadow after forwardMessages exits (it calls RecordPrimary). + s.closeShadow() +} + +// dispatchPubSubCommand handles a single command in pub/sub mode. +// Returns false if the session should end (QUIT). +func (s *pubsubSession) dispatchPubSubCommand(args [][]byte) bool { + switch strings.ToUpper(string(args[0])) { + case cmdSubscribe: + s.handleSubscribe(args) + case cmdUnsubscribe: + s.handleUnsub(args, false) + case cmdPSubscribe: + s.handlePSubscribe(args) + case cmdPUnsubscribe: + s.handleUnsub(args, true) + case cmdPing: + s.handlePubSubPing(args) + case cmdQuit: + s.writeString("OK") + return false + default: + s.writeError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + } + return true +} + +// dispatchNormalCommand handles a command in normal (non-pub/sub) mode. +// Returns false if the session should end (QUIT). 
+func (s *pubsubSession) dispatchNormalCommand(name string, args [][]byte) bool { + if name == cmdQuit { + s.writeString("OK") + return false + } + if name == cmdPing { + s.handleNormalPing(args) + return true + } + // Let the transaction handler process or queue commands first, so that + // behavior during MULTI is consistent with the main ProxyServer handler. + if s.handleTxnInSession(name, args) { + return true + } + if name == cmdSubscribe || name == cmdPSubscribe { + s.reenterPubSub(name, args) + return true + } + if name == cmdUnsubscribe || name == cmdPUnsubscribe { + s.handleUnsubNoSession(name) + return true + } + s.dispatchRegularCommand(name, args) + return true +} + +// handleTxnInSession handles transaction commands in normal mode. +// Returns true if the command was handled as a transaction command. +func (s *pubsubSession) handleTxnInSession(name string, args [][]byte) bool { + switch name { + case cmdMulti: + if s.inTxn { + s.writeError("ERR MULTI calls can not be nested") + } else { + s.inTxn = true + s.txnQueue = nil + s.writeString("OK") + } + return true + case cmdExec: + if !s.inTxn { + s.writeError("ERR EXEC without MULTI") + } else { + s.execTxn() + } + return true + case cmdDiscard: + if !s.inTxn { + s.writeError("ERR DISCARD without MULTI") + } else { + s.inTxn = false + s.txnQueue = nil + s.writeString("OK") + } + return true + } + if s.inTxn { + s.txnQueue = append(s.txnQueue, args) + s.writeString("QUEUED") + return true + } + return false +} + +// dispatchRegularCommand sends a non-transaction, non-special command to the backend. +func (s *pubsubSession) dispatchRegularCommand(name string, args [][]byte) { + // SELECT and AUTH are handled at the connection-pool level; accept silently. 
+ if name == "SELECT" || name == "AUTH" { + s.writeString("OK") + return + } + + cat := ClassifyCommand(name, args[1:]) + ctx := context.Background() + + var resp any + var err error + + switch cat { + case CmdWrite: + resp, err = s.proxy.dual.Write(ctx, name, args) + case CmdRead: + resp, err = s.proxy.dual.Read(ctx, name, args) + case CmdBlocking: + resp, err = s.proxy.dual.Blocking(s.proxy.shutdownCtx, name, args) + case CmdPubSub: + resp, err = s.proxy.dual.Admin(ctx, name, args) + case CmdAdmin: + resp, err = s.proxy.dual.Admin(ctx, name, args) + case CmdScript: + resp, err = s.proxy.dual.Script(ctx, name, args) + case CmdTxn: + // Handled by handleTxnInSession; should not reach here. + return + } + + s.writeMu.Lock() + writeResponse(s.dconn, resp, err) + s.flushOrClose() + s.writeMu.Unlock() +} + +func (s *pubsubSession) reenterPubSub(cmdName string, args [][]byte) { + if len(args) < pubsubMinArgs { + s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", strings.ToLower(cmdName))) + return + } + psBackend := s.proxy.dual.PubSubBackend() + if psBackend == nil { + s.writeError("ERR PubSub not supported by backend") + return + } + + upstream := psBackend.NewPubSub(context.Background()) + channels := byteSlicesToStrings(args[1:]) + var err error + if cmdName == cmdSubscribe { + err = upstream.Subscribe(context.Background(), channels...) + } else { + err = upstream.PSubscribe(context.Background(), channels...) + } + if err != nil { + upstream.Close() + s.writeRedisError(err) + return + } + + shadow := s.proxy.createShadowPubSub(cmdName, channels) + + s.mu.Lock() + s.upstream = upstream + s.shadow = shadow + s.mu.Unlock() + + // Update state (sets only accessed from commandLoop goroutine). 
+ kind := strings.ToLower(cmdName) + for _, ch := range channels { + if cmdName == cmdSubscribe { + s.channelSet[ch] = struct{}{} + } else { + s.patternSet[ch] = struct{}{} + } + } + + // Write subscription confirmations before starting forwarding so that + // clients receive acknowledgements before any pub/sub messages. + s.writeMu.Lock() + for _, ch := range channels { + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(ch) + s.dconn.WriteInt64(int64(s.subCount())) + } + s.flushOrClose() + s.writeMu.Unlock() + + s.startForwarding() +} + +func (s *pubsubSession) execTxn() { + queue := s.txnQueue + s.inTxn = false + s.txnQueue = nil + + ctx := context.Background() + cmds := make([][]any, 0, len(queue)+txnCommandsOverhead) + cmds = append(cmds, []any{"MULTI"}) + for _, args := range queue { + cmds = append(cmds, bytesArgsToInterfaces(args)) + } + cmds = append(cmds, []any{"EXEC"}) + + results, err := s.proxy.dual.Primary().Pipeline(ctx, cmds) + + s.writeMu.Lock() + switch { + case err != nil: + // Pipeline-level error (connection/transport failure) takes precedence. 
+ writeRedisError(s.dconn, err) + case len(results) > 0: + lastResult := results[len(results)-1] + resp, rErr := lastResult.Result() + writeResponse(s.dconn, resp, rErr) + default: + s.dconn.WriteError("ERR empty transaction response") + } + s.flushOrClose() + s.writeMu.Unlock() + + if s.proxy.dual.hasSecondaryWrite() { + s.proxy.dual.goAsync(func() { + sCtx, cancel := context.WithTimeout(context.Background(), s.proxy.cfg.SecondaryTimeout) + defer cancel() + _, pErr := s.proxy.dual.Secondary().Pipeline(sCtx, cmds) + if pErr != nil { + s.proxy.logger.Warn("secondary txn replay failed", "err", pErr) + s.proxy.metrics.SecondaryWriteErrors.Inc() + } + }) + } +} + +// --- Subscription handlers --- + +func (s *pubsubSession) handleSubscribe(args [][]byte) { + s.handleSub(args, false) +} + +func (s *pubsubSession) handlePSubscribe(args [][]byte) { + s.handleSub(args, true) +} + +// handleSub is the shared implementation for SUBSCRIBE and PSUBSCRIBE. +func (s *pubsubSession) handleSub(args [][]byte, isPattern bool) { + kind := "subscribe" + if isPattern { + kind = "psubscribe" + } + if len(args) < pubsubMinArgs { + s.writeError(fmt.Sprintf("ERR wrong number of arguments for '%s' command", kind)) + return + } + names := byteSlicesToStrings(args[1:]) + var err error + if isPattern { + err = s.upstream.PSubscribe(context.Background(), names...) + } else { + err = s.upstream.Subscribe(context.Background(), names...) + } + if err != nil { + s.logger.Warn("upstream "+kind+" failed", "err", err) + s.writeRedisError(err) + return + } + s.mirrorSub(names, isPattern) + // Update state (goroutine-confined to commandLoop). 
+ for _, n := range names { + if isPattern { + s.patternSet[n] = struct{}{} + } else { + s.channelSet[n] = struct{}{} + } + } + s.writeMu.Lock() + for _, n := range names { + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) + s.dconn.WriteInt64(int64(s.subCount())) + } + s.flushOrClose() + s.writeMu.Unlock() +} + +// handleUnsub handles both UNSUBSCRIBE and PUNSUBSCRIBE. +// When isPattern is true, it operates on pattern subscriptions. +func (s *pubsubSession) handleUnsub(args [][]byte, isPattern bool) { + kind := "unsubscribe" + unsubFn := s.upstream.Unsubscribe + if isPattern { + kind = "punsubscribe" + unsubFn = s.upstream.PUnsubscribe + } + + if len(args) < pubsubMinArgs { + // Unsubscribe all: emit per-channel reply (matching Redis behavior). + if err := unsubFn(context.Background()); err != nil { + s.logger.Warn("upstream "+kind+" failed, closing session", "err", err) + s.writeRedisError(err) + return + } + if s.shadow != nil { + s.mirrorUnsubAll(isPattern) + } + s.writeUnsubAll(kind, isPattern) + return + } + + names := byteSlicesToStrings(args[1:]) + if err := unsubFn(context.Background(), names...); err != nil { + s.logger.Warn("upstream "+kind+" failed, closing session", "err", err) + s.writeRedisError(err) + return + } + if s.shadow != nil { + s.mirrorUnsub(names, isPattern) + } + // Update state (goroutine-confined) and pre-compute counts before taking writeMu. 
+ counts := make([]int, len(names)) + for i, n := range names { + if isPattern { + delete(s.patternSet, n) + } else { + delete(s.channelSet, n) + } + counts[i] = s.subCount() + } + s.writeMu.Lock() + for i, n := range names { + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) + s.dconn.WriteInt64(int64(counts[i])) + } + s.flushOrClose() + s.writeMu.Unlock() +} + +// writeUnsubAll emits per-channel/pattern unsubscribe replies, removing entries +// one-by-one so the subscription count decrements per reply (matching Redis). +func (s *pubsubSession) writeUnsubAll(kind string, isPattern bool) { + set := s.channelSet + if isPattern { + set = s.patternSet + } + + if len(set) == 0 { + // No subscriptions: single reply with null channel (matching Redis). + s.writeMu.Lock() + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteNull() + s.dconn.WriteInt64(int64(s.subCount())) + s.flushOrClose() + s.writeMu.Unlock() + return + } + + // Collect names and pre-compute decreasing counts (state is goroutine-confined). 
+ names := make([]string, 0, len(set)) + for n := range set { + names = append(names, n) + } + counts := make([]int, len(names)) + for i, n := range names { + if isPattern { + delete(s.patternSet, n) + } else { + delete(s.channelSet, n) + } + counts[i] = s.subCount() + } + + s.writeMu.Lock() + for i, n := range names { + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(kind) + s.dconn.WriteBulkString(n) + s.dconn.WriteInt64(int64(counts[i])) + } + s.flushOrClose() + s.writeMu.Unlock() +} + +// --- Ping handlers --- + +func (s *pubsubSession) handlePubSubPing(args [][]byte) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.dconn.WriteArray(pubsubArrayPong) + s.dconn.WriteBulkString("pong") + if len(args) > 1 { + s.dconn.WriteBulk(args[1]) + } else { + s.dconn.WriteBulkString("") + } + s.flushOrClose() +} + +func (s *pubsubSession) handleNormalPing(args [][]byte) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if len(args) > 1 { + s.dconn.WriteBulk(args[1]) + } else { + s.dconn.WriteString("PONG") + } + s.flushOrClose() +} + +func (s *pubsubSession) handleUnsubNoSession(cmdName string) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.dconn.WriteArray(pubsubArrayReply) + s.dconn.WriteBulkString(strings.ToLower(cmdName)) + s.dconn.WriteNull() + s.dconn.WriteInt64(0) + s.flushOrClose() +} + +func (s *pubsubSession) closeShadow() { + s.mu.Lock() + shadow := s.shadow + s.shadow = nil + s.mu.Unlock() + if shadow != nil { + shadow.Close() + } +} + +// --- Shadow mirror helpers --- + +func (s *pubsubSession) mirrorSub(names []string, isPattern bool) { + if s.shadow == nil { + return + } + var err error + if isPattern { + err = s.shadow.PSubscribe(context.Background(), names...) + } else { + err = s.shadow.Subscribe(context.Background(), names...) 
+ } + if err != nil { + kind := "subscribe" + if isPattern { + kind = "psubscribe" + } + s.logger.Warn("shadow "+kind+" failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + +func (s *pubsubSession) mirrorUnsubAll(isPattern bool) { + var err error + if isPattern { + err = s.shadow.PUnsubscribe(context.Background()) + } else { + err = s.shadow.Unsubscribe(context.Background()) + } + if err != nil { + s.logger.Warn("shadow unsubscribe-all failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + +func (s *pubsubSession) mirrorUnsub(names []string, isPattern bool) { + var err error + if isPattern { + err = s.shadow.PUnsubscribe(context.Background(), names...) + } else { + err = s.shadow.Unsubscribe(context.Background(), names...) + } + if err != nil { + s.logger.Warn("shadow unsubscribe failed", "err", err) + s.proxy.metrics.PubSubShadowErrors.Inc() + } +} + +// --- Helpers --- + +// flushOrClose flushes the detached connection. On error it closes the +// connection so that commandLoop will observe a read failure and shut down. +// Caller must hold s.writeMu. +func (s *pubsubSession) flushOrClose() { + if err := s.dconn.Flush(); err != nil { + s.logger.Warn("failed to flush to client; closing connection", "err", err) + _ = s.dconn.Close() + } +} + +func (s *pubsubSession) writeError(msg string) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.dconn.WriteError(msg) + s.flushOrClose() +} + +// writeRedisError writes an upstream error, preserving redis.Error prefixes verbatim +// and normalizing other errors to "ERR ..." (matching writeRedisError in proxy.go). 
// byteSlicesToStrings converts a slice of byte slices into the equivalent
// slice of strings. The result is never nil, even for empty or nil input
// (tests compare against []string{}).
func byteSlicesToStrings(bs [][]byte) []string {
	out := make([]string, 0, len(bs))
	for _, b := range bs {
		out = append(out, string(b))
	}
	return out
}
// ReadCommand pops the next queued command in FIFO order. Once the queue is
// exhausted it returns readErr, simulating a closed or failing client
// connection so that commandLoop terminates.
func (m *mockDetachedConn) ReadCommand() (redcon.Command, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.cmdIdx >= len(m.commands) {
		return redcon.Command{}, m.readErr
	}
	cmd := m.commands[m.cmdIdx]
	m.cmdIdx++
	return cmd, nil
}
append(m.writes, respInt64{num}) +} +func (m *mockDetachedConn) WriteUint64(_ uint64) { + // Not used in pubsub tests. +} +func (m *mockDetachedConn) WriteArray(count int) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, respArray{count}) +} +func (m *mockDetachedConn) WriteNull() { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, nil) +} +func (m *mockDetachedConn) WriteRaw(data []byte) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, "RAW:"+string(data)) +} +func (m *mockDetachedConn) WriteAny(v any) { + m.mu.Lock() + defer m.mu.Unlock() + m.writes = append(m.writes, v) +} +func (m *mockDetachedConn) Context() any { return nil } +func (m *mockDetachedConn) SetContext(v any) {} +func (m *mockDetachedConn) SetReadBuffer(n int) {} +func (m *mockDetachedConn) Detach() redcon.DetachedConn { return m } +func (m *mockDetachedConn) ReadPipeline() []redcon.Command { return nil } +func (m *mockDetachedConn) PeekPipeline() []redcon.Command { return nil } +func (m *mockDetachedConn) NetConn() net.Conn { return nil } + +func (m *mockDetachedConn) getWrites() []any { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]any, len(m.writes)) + copy(out, m.writes) + return out +} + +// newTestSession creates a pubsubSession with a mock connection for testing. +func newTestSession(dconn *mockDetachedConn) *pubsubSession { + return &pubsubSession{ + dconn: dconn, + logger: testLogger, + channelSet: make(map[string]struct{}), + patternSet: make(map[string]struct{}), + } +} + +func TestPubSub_SubscribeDuplicate_IdempotentCount(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Simulate an upstream that's already connected. + s.channelSet["ch1"] = struct{}{} + + // Subscribe to ch1 again and ch2 (should count ch1 only once). + // We can't call handleSubscribe directly because it needs s.upstream, + // so we test the set logic directly. 
+ s.channelSet["ch1"] = struct{}{} // re-add same key + s.channelSet["ch2"] = struct{}{} + + assert.Equal(t, 2, len(s.channelSet), "duplicate subscribe should not increase count") + assert.Equal(t, 2, s.subCount()) +} + +func TestPubSub_UnsubscribeNonSubscribed_NoEffect(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + + // Unsubscribe from "ch3" which was never subscribed. + delete(s.channelSet, "ch3") + + assert.Equal(t, 2, len(s.channelSet), "unsubscribe non-subscribed channel should not affect count") +} + +func TestPubSub_UnsubscribeSpecific_RemovesFromSet(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + s.channelSet["ch3"] = struct{}{} + + delete(s.channelSet, "ch2") + + assert.Equal(t, 2, len(s.channelSet)) + _, hasCh1 := s.channelSet["ch1"] + _, hasCh2 := s.channelSet["ch2"] + _, hasCh3 := s.channelSet["ch3"] + assert.True(t, hasCh1) + assert.False(t, hasCh2) + assert.True(t, hasCh3) +} + +func TestPubSub_WriteUnsubAll_PerChannelReplies(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.channelSet["ch1"] = struct{}{} + s.channelSet["ch2"] = struct{}{} + s.patternSet["pat1"] = struct{}{} + + s.mu.Lock() + s.writeUnsubAll("unsubscribe", false) + s.mu.Unlock() + + writes := dconn.getWrites() + + // Should have emitted 2 unsubscribe replies (one per channel). 
+ // Each reply is: WriteArray(3), WriteBulkString("unsubscribe"), WriteBulkString(name), WriteInt(remaining) + assert.Equal(t, 0, len(s.channelSet), "channelSet should be cleared") + assert.Equal(t, 1, len(s.patternSet), "patternSet should not be affected") + + // Count array headers (each reply starts with WriteArray(3)) + arrCount := 0 + for _, w := range writes { + if a, ok := w.(respArray); ok && a.N == 3 { + arrCount++ + } + } + assert.Equal(t, 2, arrCount, "should emit one reply per unsubscribed channel") +} + +func TestPubSub_WriteUnsubAll_EmptySet_SingleNullReply(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // No subscriptions + s.mu.Lock() + s.writeUnsubAll("unsubscribe", false) + s.mu.Unlock() + + writes := dconn.getWrites() + + // Should emit single reply with null channel. + // WriteArray(3), WriteBulkString("unsubscribe"), WriteNull(), WriteInt(0) + assert.Len(t, writes, 4, "single null-channel reply expected") + + hasNull := false + for _, w := range writes { + if w == nil { + hasNull = true + } + } + assert.True(t, hasNull, "should contain null for empty unsubscribe-all") +} + +func TestPubSub_WriteUnsubAll_Patterns(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.patternSet["h*"] = struct{}{} + s.channelSet["ch1"] = struct{}{} + + s.mu.Lock() + s.writeUnsubAll("punsubscribe", true) + s.mu.Unlock() + + assert.Equal(t, 0, len(s.patternSet), "patternSet should be cleared") + assert.Equal(t, 1, len(s.channelSet), "channelSet should not be affected") + + writes := dconn.getWrites() + // Should have 1 reply (one pattern) + arrCount := 0 + for _, w := range writes { + if a, ok := w.(respArray); ok && a.N == 3 { + arrCount++ + } + } + assert.Equal(t, 1, arrCount) +} + +func TestPubSub_SubCount(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + assert.Equal(t, 0, s.subCount()) + + s.channelSet["a"] = struct{}{} + assert.Equal(t, 1, s.subCount()) + + s.patternSet["b*"] = 
struct{}{} + assert.Equal(t, 2, s.subCount()) + + s.channelSet["a"] = struct{}{} // duplicate + assert.Equal(t, 2, s.subCount()) + + delete(s.channelSet, "a") + assert.Equal(t, 1, s.subCount()) +} + +func TestPubSub_DispatchPubSubCommand_Quit(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchPubSubCommand([][]byte{[]byte("QUIT")}) + assert.False(t, cont, "QUIT should return false") + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK") +} + +func TestPubSub_DispatchPubSubCommand_InvalidCommand(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchPubSubCommand([][]byte{[]byte("GET"), []byte("key")}) + assert.True(t, cont, "invalid command should not end session") + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" { + found = true + } + } + assert.True(t, found, "should write error for invalid pub/sub command") +} + +func TestPubSub_DispatchNormalCommand_Quit(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("QUIT", [][]byte{[]byte("QUIT")}) + assert.False(t, cont, "QUIT should return false") +} + +func TestPubSub_DispatchNormalCommand_Ping(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("PING", [][]byte{[]byte("PING")}) + assert.True(t, cont) + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:PONG") +} + +func TestPubSub_DispatchNormalCommand_PingWithMessage(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + cont := s.dispatchNormalCommand("PING", [][]byte{[]byte("PING"), []byte("hello")}) + assert.True(t, cont) + + writes := dconn.getWrites() + assert.Contains(t, writes, "BULK:hello") +} + +func TestPubSub_HandlePubSubPing(t 
*testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handlePubSubPing([][]byte{[]byte("PING")}) + + writes := dconn.getWrites() + // ["pong", ""] + assert.Contains(t, writes, respArray{2}) // WriteArray(2) + assert.Contains(t, writes, "BULKSTR:pong") + assert.Contains(t, writes, "BULKSTR:") // empty string +} + +func TestPubSub_HandlePubSubPingWithData(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handlePubSubPing([][]byte{[]byte("PING"), []byte("hello")}) + + writes := dconn.getWrites() + assert.Contains(t, writes, "BULK:hello") +} + +func TestPubSub_HandleUnsubNoSession(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.handleUnsubNoSession("UNSUBSCRIBE") + + writes := dconn.getWrites() + // ["unsubscribe", null, 0] + assert.Contains(t, writes, "BULKSTR:unsubscribe") + hasNull := false + for _, w := range writes { + if w == nil { + hasNull = true + } + } + assert.True(t, hasNull) + assert.Contains(t, writes, respInt64{0}) // WriteInt64(0) +} + +func TestPubSub_SubscribeInTxnQueued(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + s.inTxn = true + + cont := s.dispatchNormalCommand("SUBSCRIBE", [][]byte{[]byte("SUBSCRIBE"), []byte("ch1")}) + assert.True(t, cont) + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "STR:QUEUED" { + found = true + } + } + assert.True(t, found, "SUBSCRIBE during MULTI should be queued") + assert.Len(t, s.txnQueue, 1, "SUBSCRIBE should be added to txn queue") +} + +func TestPubSub_HandleTxnInSession_MultiExecDiscard(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // MULTI + handled := s.handleTxnInSession("MULTI", nil) + assert.True(t, handled) + assert.True(t, s.inTxn) + + // Nested MULTI + dconn2 := newMockDetachedConn() + s2 := newTestSession(dconn2) + s2.inTxn = true + handled = 
s2.handleTxnInSession("MULTI", nil) + assert.True(t, handled) + writes := dconn2.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR MULTI calls can not be nested" { + found = true + } + } + assert.True(t, found) + + // Queue a command + handled = s.handleTxnInSession("SET", [][]byte{[]byte("SET"), []byte("k"), []byte("v")}) + assert.True(t, handled) + assert.Len(t, s.txnQueue, 1) + + // DISCARD + handled = s.handleTxnInSession("DISCARD", nil) + assert.True(t, handled) + assert.False(t, s.inTxn) + assert.Nil(t, s.txnQueue) + + // EXEC without MULTI + handled = s.handleTxnInSession("EXEC", nil) + assert.True(t, handled) + writes = dconn.getWrites() + found = false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR EXEC without MULTI" { + found = true + } + } + assert.True(t, found) + + // Non-txn command when not in txn + handled = s.handleTxnInSession("GET", [][]byte{[]byte("GET"), []byte("k")}) + assert.False(t, handled, "non-txn command outside txn should not be handled") +} + +func TestPubSub_ShouldExitPubSub(t *testing.T) { + s := newTestSession(newMockDetachedConn()) + + // No upstream, no subs → false (no upstream to close) + assert.False(t, s.shouldExitPubSub()) + + // With upstream but has subs → false + s.upstream = &redis.PubSub{} + s.channelSet["ch1"] = struct{}{} + assert.False(t, s.shouldExitPubSub()) + + // With upstream and no subs → true + delete(s.channelSet, "ch1") + assert.True(t, s.shouldExitPubSub()) +} + +func TestPubSub_ByteSlicesToStrings(t *testing.T) { + input := [][]byte{[]byte("hello"), []byte("world")} + result := byteSlicesToStrings(input) + assert.Equal(t, []string{"hello", "world"}, result) + + // Empty + assert.Equal(t, []string{}, byteSlicesToStrings([][]byte{})) +} + +func TestPubSub_SelectAuthSilentlyAccepted(t *testing.T) { + for _, cmd := range []string{"SELECT", "AUTH"} { + t.Run(cmd, func(t *testing.T) { + dconn := newMockDetachedConn() + s := 
newTestSession(dconn) + + s.dispatchRegularCommand(cmd, [][]byte{[]byte(cmd), []byte("0")}) + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK") + }) + } +} + +// TestPubSub_CleanupClosesUpstream verifies that cleanup closes upstream and dconn. +func TestPubSub_CleanupClosesUpstream(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + s.cleanup() + assert.True(t, s.closed) + assert.True(t, dconn.closed) + assert.Nil(t, s.upstream) +} + +// TestPubSub_CommandLoop_EmptyArgs verifies that empty command args are skipped. +func TestPubSub_CommandLoop_EmptyArgs(t *testing.T) { + dconn := newMockDetachedConn() + // Queue empty args then QUIT + dconn.commands = append(dconn.commands, redcon.Command{Args: nil}) + dconn.queueCommand("QUIT") + + s := newTestSession(dconn) + // Not in pub/sub mode, so should dispatch as normal commands + s.commandLoop() + + writes := dconn.getWrites() + assert.Contains(t, writes, "STR:OK", "QUIT should be handled") +} + +func TestPubSub_CommandLoop_EOF(t *testing.T) { + dconn := newMockDetachedConn() + dconn.readErr = errors.New("connection closed") + + s := newTestSession(dconn) + s.commandLoop() // should return without panic +} + +// ========== T3: forwardMessages message/pmessage branch test ========== + +func TestPubSub_ForwardMessages_MessageBranch(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + ch := make(chan *redis.Message, 2) + + // Regular message (no pattern) + ch <- &redis.Message{Channel: "ch1", Payload: "hello"} + // Pattern message + ch <- &redis.Message{Pattern: "h*", Channel: "hello-world", Payload: "data"} + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + + // First message: ["message", "ch1", "hello"] + // Expect: WriteArray(3), WriteBulkString("message"), WriteBulkString("ch1"), WriteBulkString("hello") + assert.Contains(t, writes, respArray{pubsubArrayMessage}) // WriteArray(3) + assert.Contains(t, writes, 
"BULKSTR:message") + assert.Contains(t, writes, "BULKSTR:ch1") + assert.Contains(t, writes, "BULKSTR:hello") + + // Second message: ["pmessage", "h*", "hello-world", "data"] + // Expect: WriteArray(4), WriteBulkString("pmessage"), WriteBulkString("h*"), WriteBulkString("hello-world"), WriteBulkString("data") + assert.Contains(t, writes, respArray{pubsubArrayPMessage}) // WriteArray(4) + assert.Contains(t, writes, "BULKSTR:pmessage") + assert.Contains(t, writes, "BULKSTR:h*") + assert.Contains(t, writes, "BULKSTR:hello-world") + assert.Contains(t, writes, "BULKSTR:data") +} + +func TestPubSub_ForwardMessages_ClosedSession(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Mark session as closed before sending message. + s.mu.Lock() + s.closed = true + s.mu.Unlock() + + ch := make(chan *redis.Message, 1) + ch <- &redis.Message{Channel: "ch1", Payload: "should-not-write"} + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + assert.Empty(t, writes, "should not write to dconn when session is closed") +} + +func TestPubSub_ForwardMessages_EmptyChannel(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + ch := make(chan *redis.Message) + close(ch) + + s.forwardMessages(ch) + + writes := dconn.getWrites() + assert.Empty(t, writes, "no messages should produce no writes") +} + +// ========== T4: reenterPubSub test ========== + +func TestPubSub_ReenterPubSub_TooFewArgs(t *testing.T) { + dconn := newMockDetachedConn() + s := newTestSession(dconn) + + // Only command name, no channels + s.reenterPubSub("SUBSCRIBE", [][]byte{[]byte("SUBSCRIBE")}) + + writes := dconn.getWrites() + found := false + for _, w := range writes { + if str, ok := w.(string); ok && str == "ERR:ERR wrong number of arguments for 'subscribe' command" { + found = true + } + } + assert.True(t, found, "should error on missing channel args") +} + +func TestPubSub_ReenterPubSub_NilBackend(t *testing.T) { + dconn := 
// SentryReporter sends anomaly events to Sentry with de-duplication.
type SentryReporter struct {
	enabled  bool          // false when no DSN was supplied or sentry.Init failed
	hub      *sentry.Hub   // hub used for capture; nil while disabled
	logger   *slog.Logger  // never nil after NewSentryReporter
	cooldown time.Duration // minimum interval between reports of the same fingerprint

	mu         sync.Mutex
	lastReport map[string]time.Time // fingerprint → last report time
}
+func NewSentryReporter(dsn string, environment string, sampleRate float64, logger *slog.Logger) *SentryReporter { + if logger == nil { + logger = slog.Default() + } + r := &SentryReporter{ + logger: logger, + cooldown: defaultReportCooldown, + lastReport: make(map[string]time.Time), + } + if dsn == "" { + return r + } + + err := sentry.Init(sentry.ClientOptions{ + Dsn: dsn, + Environment: environment, + SampleRate: sampleRate, + EnableTracing: false, + AttachStacktrace: true, + }) + if err != nil { + logger.Error("failed to init sentry", "err", err) + return r + } + r.enabled = true + r.hub = sentry.CurrentHub() + return r +} + +// CaptureException reports an error to Sentry. +func (r *SentryReporter) CaptureException(err error, operation string, args [][]byte) { + if !r.enabled { + return + } + r.hub.WithScope(func(scope *sentry.Scope) { + scope.SetTag("operation", operation) + if len(args) > 0 { + scope.SetTag("command", string(args[0])) + } + scope.SetFingerprint([]string{operation, cmdNameFromArgs(args)}) + r.hub.CaptureException(err) + }) +} + +// CaptureDivergence reports a data divergence to Sentry with cooldown-based de-duplication. +func (r *SentryReporter) CaptureDivergence(div Divergence) { + if !r.enabled { + return + } + fingerprint := fmt.Sprintf("divergence_%s_%s", div.Kind.String(), div.Command) + if !r.ShouldReport(fingerprint) { + return + } + r.hub.WithScope(func(scope *sentry.Scope) { + scope.SetTag("command", div.Command) + scope.SetTag("kind", div.Kind.String()) + // Omit raw key from Sentry tags to avoid leaking sensitive data; + // only send a truncated form as an extra for debugging. 
+ scope.SetExtra("key", truncateValue(div.Key)) + scope.SetExtra("primary", truncateValue(div.Primary)) + scope.SetExtra("secondary", truncateValue(div.Secondary)) + scope.SetFingerprint([]string{"divergence", div.Kind.String(), div.Command}) + scope.SetLevel(sentry.LevelWarning) + r.hub.CaptureMessage(fmt.Sprintf("data divergence: %s %s", div.Kind, div.Command)) + }) +} + +// ShouldReport checks if this fingerprint has been reported recently (cooldown-based). +// Evicts expired entries when the map reaches maxReportEntries to bound memory usage. +// Returns false (drops the report) if the map is still at capacity after eviction. +func (r *SentryReporter) ShouldReport(fingerprint string) bool { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + + // Evict expired entries if map grows too large + if len(r.lastReport) >= maxReportEntries { + for k, t := range r.lastReport { + if now.Sub(t) >= r.cooldown { + delete(r.lastReport, k) + } + } + // If still at capacity after eviction, drop report to prevent unbounded growth and Sentry flooding. + if len(r.lastReport) >= maxReportEntries { + return false + } + } + + if last, ok := r.lastReport[fingerprint]; ok && now.Sub(last) < r.cooldown { + return false + } + r.lastReport[fingerprint] = now + return true +} + +// Flush waits for pending Sentry events. +func (r *SentryReporter) Flush(timeout time.Duration) { + if r.enabled { + sentry.Flush(timeout) + } +} + +func cmdNameFromArgs(args [][]byte) string { + if len(args) > 0 { + return string(args[0]) + } + return unknownStr +} + +// truncateValue formats a value for logging/Sentry, truncating to avoid data leakage and oversized events. +// Handles common types by slicing before formatting to avoid allocating the full string representation. +func truncateValue(v any) string { + if v == nil { + return "" + } + + switch tv := v.(type) { + case string: + return truncateString(tv) + case []byte: + // Avoid converting an arbitrarily large byte slice into a full string. 
+ if len(tv) > maxSentryValueLen { + return string(tv[:maxSentryValueLen]) + "...(truncated)" + } + return string(tv) + case fmt.Stringer: + // Respect custom String implementations but still apply length limits. + return truncateString(tv.String()) + default: + rv := reflect.ValueOf(v) + switch rv.Kind() { //nolint:exhaustive // only slice/array/map need special handling + case reflect.Slice, reflect.Array: + return formatSliceValue(rv, maxSentryValueLen) + case reflect.Map: + return formatMapValue(rv, maxSentryValueLen) + default: + // For non-container types, fall back to fmt and then truncate. + return truncateString(fmt.Sprintf("%v", v)) + } + } +} + +// truncateString enforces maxSentryValueLen on an already-built string. +func truncateString(s string) string { + if len(s) > maxSentryValueLen { + return s[:maxSentryValueLen] + "...(truncated)" + } + return s +} + +// formatSliceValue formats a slice/array value without allocating an unbounded string. +// It stops once approximately maxLen bytes have been written and appends a truncation marker. +func formatSliceValue(rv reflect.Value, maxLen int) string { + var b strings.Builder + b.WriteByte('[') + for i := 0; i < rv.Len(); i++ { + if b.Len() >= maxLen { + b.WriteString("...(truncated)]") + return b.String() + } + if i > 0 { + b.WriteString(", ") + } + elemStr := truncateValue(rv.Index(i).Interface()) + if b.Len()+len(elemStr) > maxLen { + // Write as much as fits, then mark as truncated. + remaining := maxLen - b.Len() + if remaining > 0 { + if remaining < len(elemStr) { + b.WriteString(elemStr[:remaining]) + } else { + b.WriteString(elemStr) + } + } + b.WriteString("...(truncated)]") + return b.String() + } + b.WriteString(elemStr) + } + b.WriteByte(']') + return truncateString(b.String()) +} + +// formatMapValue formats a map value without allocating an unbounded string. +// It stops once approximately maxLen bytes have been written and appends a truncation marker. 
+func formatMapValue(rv reflect.Value, maxLen int) string {
+	var b strings.Builder
+	b.WriteByte('{')
+	// Go map iteration order is randomized, so the rendered entry order is
+	// nondeterministic between calls; only the length bound is guaranteed.
+	iter := rv.MapRange()
+	first := true
+	for iter.Next() {
+		if b.Len() >= maxLen {
+			b.WriteString("...(truncated)}")
+			return b.String()
+		}
+		if !first {
+			b.WriteString(", ")
+		}
+		first = false
+		keyStr := truncateValue(iter.Key().Interface())
+		valStr := truncateValue(iter.Value().Interface())
+		entry := fmt.Sprintf("%s: %s", keyStr, valStr)
+		if b.Len()+len(entry) > maxLen {
+			// Write as much of the entry as fits, then mark as truncated.
+			remaining := maxLen - b.Len()
+			if remaining > 0 {
+				if remaining < len(entry) {
+					b.WriteString(entry[:remaining])
+				} else {
+					b.WriteString(entry)
+				}
+			}
+			b.WriteString("...(truncated)}")
+			return b.String()
+		}
+		b.WriteString(entry)
+	}
+	b.WriteByte('}')
+	return truncateString(b.String())
+}
diff --git a/proxy/shadow_pubsub.go b/proxy/shadow_pubsub.go
new file mode 100644
index 00000000..15f23848
--- /dev/null
+++ b/proxy/shadow_pubsub.go
@@ -0,0 +1,402 @@
+package proxy
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"sync"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+)
+
+// msgKey is used as a map key for matching primary and secondary messages.
+// Includes Pattern to correctly distinguish pmessage deliveries.
+type msgKey struct {
+	Pattern string
+	Channel string
+	Payload string
+}
+
+// secondaryPending represents a secondary message that has not yet been matched
+// to a primary within the comparison window.
+type secondaryPending struct {
+	timestamp time.Time
+	channel   string
+	payload   string
+}
+
+// unmatchedSecondaries buffers unmatched secondary messages per shadowPubSub
+// instance. This allows us to avoid reporting DivExtraData immediately when a
+// secondary arrives before the corresponding primary (e.g. due to network
+// jitter) and instead only report once the comparison window has elapsed.
+// Each key maps to a slice to handle duplicate secondary messages correctly.
+// NOTE(review): this is package-level mutable state keyed by instance pointer;
+// folding it into a field on shadowPubSub would remove the global lock and the
+// cross-instance coupling — consider for a follow-up.
+var unmatchedSecondaries = struct {
+	sync.Mutex
+	data map[*shadowPubSub]map[msgKey][]secondaryPending
+}{
+	data: make(map[*shadowPubSub]map[msgKey][]secondaryPending),
+}
+
+// pendingMsg records a message awaiting its counterpart from the other source.
+type pendingMsg struct {
+	pattern   string
+	channel   string
+	payload   string
+	timestamp time.Time
+}
+
+// divergenceEvent holds divergence info collected under lock for deferred reporting.
+type divergenceEvent struct {
+	channel   string
+	payload   string
+	kind      DivergenceKind
+	isPattern bool // true if this originated from a PSUBSCRIBE
+}
+
+// shadowPubSub subscribes to the secondary backend for the same channels
+// and compares messages against the primary to detect divergences.
+type shadowPubSub struct {
+	secondary *redis.PubSub
+	metrics   *ProxyMetrics
+	sentry    *SentryReporter
+	logger    *slog.Logger
+	window    time.Duration
+	nowFunc   func() time.Time // injectable clock; defaults to time.Now
+
+	mu        sync.Mutex
+	pending   map[msgKey][]pendingMsg // primary messages awaiting secondary match
+	closed    bool
+	started   bool
+	startOnce sync.Once
+	done      chan struct{}
+}
+
+// newShadowPubSub builds a shadow comparator that mirrors subscriptions to the
+// secondary backend; the caller must invoke Start after the initial subscribe
+// and Close on teardown.
+func newShadowPubSub(backend PubSubBackend, metrics *ProxyMetrics, sentry *SentryReporter, logger *slog.Logger, window time.Duration) *shadowPubSub {
+	return &shadowPubSub{
+		secondary: backend.NewPubSub(context.Background()),
+		metrics:   metrics,
+		sentry:    sentry,
+		logger:    logger,
+		window:    window,
+		nowFunc:   time.Now,
+		pending:   make(map[msgKey][]pendingMsg),
+		done:      make(chan struct{}),
+	}
+}
+
+// Start begins reading from the secondary and comparing messages.
+// Must be called after initial subscribe. Safe to call multiple times;
+// only the first call launches the compare loop.
+func (sp *shadowPubSub) Start() {
+	sp.startOnce.Do(func() {
+		// Mark started before launching the goroutine so Close knows it must
+		// wait on done.
+		sp.mu.Lock()
+		sp.started = true
+		sp.mu.Unlock()
+		ch := sp.secondary.Channel()
+		go func() {
+			defer close(sp.done)
+			sp.compareLoop(ch)
+		}()
+	})
+}
+
+// Subscribe mirrors a SUBSCRIBE to the secondary.
+func (sp *shadowPubSub) Subscribe(ctx context.Context, channels ...string) error {
+	if err := sp.secondary.Subscribe(ctx, channels...); err != nil {
+		return fmt.Errorf("shadow subscribe: %w", err)
+	}
+	return nil
+}
+
+// PSubscribe mirrors a PSUBSCRIBE to the secondary.
+func (sp *shadowPubSub) PSubscribe(ctx context.Context, patterns ...string) error {
+	if err := sp.secondary.PSubscribe(ctx, patterns...); err != nil {
+		return fmt.Errorf("shadow psubscribe: %w", err)
+	}
+	return nil
+}
+
+// Unsubscribe mirrors an UNSUBSCRIBE to the secondary.
+func (sp *shadowPubSub) Unsubscribe(ctx context.Context, channels ...string) error {
+	if err := sp.secondary.Unsubscribe(ctx, channels...); err != nil {
+		return fmt.Errorf("shadow unsubscribe: %w", err)
+	}
+	return nil
+}
+
+// PUnsubscribe mirrors a PUNSUBSCRIBE to the secondary.
+func (sp *shadowPubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
+	if err := sp.secondary.PUnsubscribe(ctx, patterns...); err != nil {
+		return fmt.Errorf("shadow punsubscribe: %w", err)
+	}
+	return nil
+}
+
+// RecordPrimary records a message received from the primary for comparison.
+// Messages recorded after Close (or after the secondary channel closed) are
+// dropped so they cannot leak into pending.
+func (sp *shadowPubSub) RecordPrimary(msg *redis.Message) {
+	sp.mu.Lock()
+	defer sp.mu.Unlock()
+	if sp.closed {
+		return
+	}
+	key := msgKeyFromMessage(msg)
+	sp.pending[key] = append(sp.pending[key], pendingMsg{
+		pattern:   msg.Pattern,
+		channel:   msg.Channel,
+		payload:   msg.Payload,
+		timestamp: sp.nowFunc(),
+	})
+}
+
+// Close stops the shadow comparison and closes the secondary pub/sub.
+// Safe to call even if Start was never called.
+func (sp *shadowPubSub) Close() {
+	sp.mu.Lock()
+	sp.closed = true
+	started := sp.started
+	sp.mu.Unlock()
+	if sp.secondary != nil {
+		// Best-effort shutdown: the Close error is intentionally discarded.
+		sp.secondary.Close()
+	}
+	// Closing the secondary closes its channel, which makes compareLoop exit;
+	// wait for it only if Start actually launched the goroutine.
+	if started {
+		<-sp.done
+	}
+	// Clean up buffered unmatched secondaries to prevent memory leak.
+	unmatchedSecondaries.Lock()
+	delete(unmatchedSecondaries.data, sp)
+	unmatchedSecondaries.Unlock()
+}
+
+// compareLoop reads from the secondary channel and matches messages.
+func (sp *shadowPubSub) compareLoop(ch <-chan *redis.Message) {
+	ticker := time.NewTicker(defaultPubSubSweepInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case msg, ok := <-ch:
+			if !ok {
+				// Channel closed — mark as closed so RecordPrimary becomes a no-op,
+				// then flush all remaining pending as divergences.
+				sp.mu.Lock()
+				sp.closed = true
+				sp.mu.Unlock()
+				sp.sweepAll()
+				return
+			}
+			sp.matchSecondary(msg)
+		case <-ticker.C:
+			sp.sweepExpired()
+		}
+	}
+}
+
+// matchSecondary tries to match a secondary message against a pending primary message.
+// Collects divergence info under lock and reports after releasing it.
+func (sp *shadowPubSub) matchSecondary(msg *redis.Message) {
+	sp.mu.Lock()
+
+	key := msgKeyFromMessage(msg)
+	if entries, ok := sp.pending[key]; ok && len(entries) > 0 {
+		// Match found — remove the oldest pending primary message (FIFO).
+		if len(entries) == 1 {
+			delete(sp.pending, key)
+		} else {
+			sp.pending[key] = entries[1:]
+		}
+		sp.mu.Unlock()
+		return
+	}
+
+	sp.mu.Unlock()
+
+	// No matching primary message at this moment. Buffer the secondary and only
+	// report DivExtraData if it remains unmatched after the comparison window.
+	now := sp.nowFunc()
+	// A primary recorded between the sp.mu release above and this buffering is
+	// not lost: the next sweepExpired pass reconciles pending primaries against
+	// this buffer before expiring either side.
+	unmatchedSecondaries.Lock()
+	defer unmatchedSecondaries.Unlock()
+
+	perInstance, ok := unmatchedSecondaries.data[sp]
+	if !ok {
+		perInstance = make(map[msgKey][]secondaryPending)
+		unmatchedSecondaries.data[sp] = perInstance
+	}
+
+	perInstance[key] = append(perInstance[key], secondaryPending{
+		timestamp: now,
+		channel:   msg.Channel,
+		payload:   msg.Payload,
+	})
+}
+
+// sweepExpired reports primary messages that were not matched within the window.
+// Lock order is sp.mu before unmatchedSecondaries, matching the other call sites.
+func (sp *shadowPubSub) sweepExpired() {
+	var divergences []divergenceEvent
+
+	sp.mu.Lock()
+	now := sp.nowFunc()
+
+	unmatchedSecondaries.Lock()
+	perInstance := sp.getOrCreateSecondaryBuffer()
+
+	// Expire old buffered secondaries first so they cannot be consumed as a
+	// "match" during reconciliation (prevents bypassing the comparison window).
+	divergences = sweepExpiredSecondaries(now, sp.window, perInstance, divergences)
+	divergences = sp.reconcilePrimaries(now, perInstance, divergences)
+
+	if len(perInstance) == 0 {
+		delete(unmatchedSecondaries.data, sp)
+	}
+
+	unmatchedSecondaries.Unlock()
+	sp.mu.Unlock()
+
+	// Report outside the locks to avoid holding them across metrics/Sentry calls.
+	for _, d := range divergences {
+		sp.reportDivergence(d)
+	}
+}
+
+// getOrCreateSecondaryBuffer returns the per-instance unmatched secondary buffer.
+// Caller must hold unmatchedSecondaries.Lock().
+func (sp *shadowPubSub) getOrCreateSecondaryBuffer() map[msgKey][]secondaryPending {
+	perInstance, ok := unmatchedSecondaries.data[sp]
+	if !ok {
+		perInstance = make(map[msgKey][]secondaryPending)
+		unmatchedSecondaries.data[sp] = perInstance
+	}
+	return perInstance
+}
+
+// reconcilePrimaries matches pending primaries against buffered secondaries,
+// reporting expired unmatched primaries as divergences.
+// Caller must hold sp.mu and unmatchedSecondaries.Lock().
+func (sp *shadowPubSub) reconcilePrimaries(now time.Time, secBuf map[msgKey][]secondaryPending, out []divergenceEvent) []divergenceEvent {
+	for key, entries := range sp.pending {
+		var remaining []pendingMsg
+		for _, e := range entries {
+			// The match check runs before the expiry check, so a primary that has
+			// already aged past the window still counts as matched when a buffered
+			// secondary is available.
+			if secs := secBuf[key]; len(secs) > 0 {
+				// Matched — consume the oldest buffered secondary.
+				if len(secs) == 1 {
+					delete(secBuf, key)
+				} else {
+					secBuf[key] = secs[1:]
+				}
+				continue
+			}
+			if now.Sub(e.timestamp) >= sp.window {
+				out = append(out, divergenceEvent{channel: e.channel, payload: e.payload, kind: DivDataMismatch, isPattern: e.pattern != ""})
+			} else {
+				remaining = append(remaining, e)
+			}
+		}
+		if len(remaining) == 0 {
+			delete(sp.pending, key)
+		} else {
+			sp.pending[key] = remaining
+		}
+	}
+	return out
+}
+
+// sweepExpiredSecondaries ages out buffered secondaries past the window,
+// appending a DivExtraData event for each expired entry.
+func sweepExpiredSecondaries(now time.Time, window time.Duration, secBuf map[msgKey][]secondaryPending, out []divergenceEvent) []divergenceEvent {
+	for key, secs := range secBuf {
+		var remaining []secondaryPending
+		for _, sec := range secs {
+			if now.Sub(sec.timestamp) >= window {
+				out = append(out, divergenceEvent{channel: sec.channel, payload: sec.payload, kind: DivExtraData, isPattern: key.Pattern != ""})
+			} else {
+				remaining = append(remaining, sec)
+			}
+		}
+		if len(remaining) == 0 {
+			delete(secBuf, key)
+		} else {
+			secBuf[key] = remaining
+		}
+	}
+	return out
+}
+
+// sweepAll reports all remaining pending primaries and buffered secondaries
+// as divergences (used on shutdown / channel close).
+func (sp *shadowPubSub) sweepAll() { + var divergences []divergenceEvent + + sp.mu.Lock() + for key, entries := range sp.pending { + for _, e := range entries { + divergences = append(divergences, divergenceEvent{ + channel: e.channel, + payload: e.payload, + kind: DivDataMismatch, + isPattern: e.pattern != "", + }) + } + delete(sp.pending, key) + } + sp.mu.Unlock() + + // Also drain any buffered unmatched secondaries for this instance. + unmatchedSecondaries.Lock() + if perInstance, ok := unmatchedSecondaries.data[sp]; ok { + for key, secs := range perInstance { + for _, sec := range secs { + divergences = append(divergences, divergenceEvent{ + channel: sec.channel, + payload: sec.payload, + kind: DivExtraData, + isPattern: key.Pattern != "", + }) + } + delete(perInstance, key) + } + delete(unmatchedSecondaries.data, sp) + } + unmatchedSecondaries.Unlock() + + for _, d := range divergences { + sp.reportDivergence(d) + } +} + +func (sp *shadowPubSub) reportDivergence(d divergenceEvent) { + sp.metrics.PubSubShadowDivergences.WithLabelValues(d.kind.String()).Inc() + sp.logger.Warn("pubsub shadow divergence", + "channel", truncateValue(d.channel), + "payload", truncateValue(d.payload), + "kind", d.kind.String(), + ) + + cmd := "SUBSCRIBE" + if d.isPattern { + cmd = "PSUBSCRIBE" + } + + var primary, secondary any + switch d.kind { //nolint:exhaustive // only two kinds apply to pub/sub shadow + case DivExtraData: + primary = nil + secondary = d.payload + default: + primary = d.payload + secondary = nil + } + sp.sentry.CaptureDivergence(Divergence{ + Command: cmd, + Key: d.channel, + Kind: d.kind, + Primary: primary, + Secondary: secondary, + DetectedAt: time.Now(), + }) +} + +func msgKeyFromMessage(msg *redis.Message) msgKey { + return msgKey{ + Pattern: msg.Pattern, + Channel: msg.Channel, + Payload: msg.Payload, + } +} diff --git a/proxy/shadow_pubsub_test.go b/proxy/shadow_pubsub_test.go new file mode 100644 index 00000000..6fe0236d --- /dev/null +++ 
b/proxy/shadow_pubsub_test.go
@@ -0,0 +1,252 @@
+package proxy
+
+import (
+	"log/slog"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/prometheus/client_golang/prometheus"
+	dto "github.com/prometheus/client_model/go"
+	"github.com/redis/go-redis/v9"
+	"github.com/stretchr/testify/assert"
+)
+
+// counterValue extracts the current value of a prometheus counter for assertions.
+// NOTE(review): the type assertion to prometheus.Metric looks like it can never
+// fail (Counter embeds Metric in client_golang) — the ok-branch is likely dead.
+func counterValue(c prometheus.Counter) float64 {
+	m := &dto.Metric{}
+	pm, ok := c.(prometheus.Metric)
+	if !ok {
+		return 0
+	}
+	_ = pm.Write(m)
+	return m.GetCounter().GetValue()
+}
+
+// testClock provides a deterministic clock for tests, avoiding time.Sleep flakiness.
+type testClock struct {
+	v atomic.Value // stores time.Time
+}
+
+func newTestClock() *testClock {
+	c := &testClock{}
+	c.v.Store(time.Now())
+	return c
+}
+
+// Now returns the current fake time (zero time if never stored).
+func (c *testClock) Now() time.Time {
+	v, ok := c.v.Load().(time.Time)
+	if !ok {
+		return time.Time{}
+	}
+	return v
+}
+
+// Advance moves the fake clock forward by d.
+func (c *testClock) Advance(d time.Duration) { c.v.Store(c.Now().Add(d)) }
+
+// newTestShadowPubSub builds a shadowPubSub with the real clock and no
+// secondary connection (Close tolerates nil secondary).
+func newTestShadowPubSub(window time.Duration) *shadowPubSub {
+	return newTestShadowPubSubWithClock(window, time.Now)
+}
+
+// newTestShadowPubSubWithClock builds a shadowPubSub with an injected clock.
+func newTestShadowPubSubWithClock(window time.Duration, nowFunc func() time.Time) *shadowPubSub {
+	return &shadowPubSub{
+		metrics: newTestMetrics(),
+		sentry:  newTestSentry(),
+		logger:  slog.Default(),
+		window:  window,
+		nowFunc: nowFunc,
+		pending: make(map[msgKey][]pendingMsg),
+		done:    make(chan struct{}),
+	}
+}
+
+func TestShadowPubSub_MatchedMessage(t *testing.T) {
+	sp := newTestShadowPubSub(100 * time.Millisecond)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "hello"})
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "hello"})
+
+	sp.mu.Lock()
+	remaining := len(sp.pending)
+	sp.mu.Unlock()
+	assert.Equal(t, 0, remaining, "matched message should be removed from pending")
+}
+
+func TestShadowPubSub_MissingOnSecondary(t *testing.T) {
+	clock := newTestClock()
+	sp := newTestShadowPubSubWithClock(10*time.Millisecond, clock.Now)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "hello"})
+	clock.Advance(20 * time.Millisecond)
+	sp.sweepExpired()
+
+	sp.mu.Lock()
+	remaining := len(sp.pending)
+	sp.mu.Unlock()
+	assert.Equal(t, 0, remaining, "expired message should be removed")
+
+	val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("data_mismatch"))
+	assert.Equal(t, float64(1), val)
+}
+
+func TestShadowPubSub_ExtraOnSecondary(t *testing.T) {
+	clock := newTestClock()
+	sp := newTestShadowPubSubWithClock(10*time.Millisecond, clock.Now)
+	defer func() {
+		unmatchedSecondaries.Lock()
+		delete(unmatchedSecondaries.data, sp)
+		unmatchedSecondaries.Unlock()
+	}()
+
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "extra"})
+
+	// Advance the clock past the comparison window and sweep.
+	clock.Advance(20 * time.Millisecond)
+	sp.sweepExpired()
+
+	val := counterValue(sp.metrics.PubSubShadowDivergences.WithLabelValues("extra_data"))
+	assert.Equal(t, float64(1), val)
+}
+
+func TestShadowPubSub_OutOfOrderMatching(t *testing.T) {
+	sp := newTestShadowPubSub(1 * time.Second)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg1"})
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg2"})
+
+	// Secondary delivers in reverse order.
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "msg2"})
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "msg1"})
+
+	sp.mu.Lock()
+	remaining := len(sp.pending)
+	sp.mu.Unlock()
+	assert.Equal(t, 0, remaining, "all messages should be matched")
+}
+
+func TestShadowPubSub_DuplicateMessages(t *testing.T) {
+	sp := newTestShadowPubSub(1 * time.Second)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"})
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"})
+
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"})
+	sp.mu.Lock()
+	assert.Equal(t, 1, len(sp.pending[msgKey{Channel: "ch1", Payload: "dup"}]))
+	sp.mu.Unlock()
+
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"})
+	sp.mu.Lock()
+	assert.Equal(t, 0, len(sp.pending))
+	sp.mu.Unlock()
+}
+
+func TestShadowPubSub_RecordAfterClose(t *testing.T) {
+	sp := newTestShadowPubSub(1 * time.Second)
+
+	sp.mu.Lock()
+	sp.closed = true
+	sp.mu.Unlock()
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "after-close"})
+
+	sp.mu.Lock()
+	assert.Equal(t, 0, len(sp.pending))
+	sp.mu.Unlock()
+}
+
+func TestShadowPubSub_CompareLoopExitsOnChannelClose(t *testing.T) {
+	clock := newTestClock()
+	sp := newTestShadowPubSubWithClock(10*time.Millisecond, clock.Now)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "orphan"})
+	clock.Advance(20 * time.Millisecond)
+
+	ch := make(chan *redis.Message)
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		sp.compareLoop(ch)
+	}()
+
+	close(ch)
+	wg.Wait()
+
+	sp.mu.Lock()
+	assert.Equal(t, 0, len(sp.pending), "should sweep on exit")
+	assert.True(t, sp.closed, "should mark closed on channel close to prevent RecordPrimary leak")
+	sp.mu.Unlock()
+}
+
+// NOTE(review): this test uses the real clock with a 10ms window; a slow runner
+// could age out the buffered secondaries before sweepExpired reconciles them,
+// making the final assertion flaky — consider the injected testClock instead.
+func TestShadowPubSub_DuplicateSecondaryBuffered(t *testing.T) {
+	sp := newTestShadowPubSub(10 * time.Millisecond)
+	defer func() {
+		// Clean up without calling Close (test has no real secondary connection).
+		unmatchedSecondaries.Lock()
+		delete(unmatchedSecondaries.data, sp)
+		unmatchedSecondaries.Unlock()
+	}()
+
+	// Two identical secondary messages arrive before any primary.
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"})
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "dup"})
+
+	unmatchedSecondaries.Lock()
+	key := msgKey{Channel: "ch1", Payload: "dup"}
+	secs := unmatchedSecondaries.data[sp][key]
+	assert.Len(t, secs, 2, "both duplicate secondaries should be buffered")
+	unmatchedSecondaries.Unlock()
+
+	// Now one primary arrives — should consume one buffered secondary.
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "dup"})
+	sp.sweepExpired() // reconcile
+
+	unmatchedSecondaries.Lock()
+	secs = unmatchedSecondaries.data[sp][key]
+	unmatchedSecondaries.Unlock()
+	// One secondary remains buffered (only one primary consumed one).
+	assert.Len(t, secs, 1, "one duplicate should remain after matching one primary")
+}
+
+func TestShadowPubSub_CloseCleanupUnmatchedSecondaries(t *testing.T) {
+	// Close() now tolerates nil secondary, so no mock client is needed.
+	sp := newTestShadowPubSub(1 * time.Second)
+
+	// Buffer a secondary message.
+	sp.matchSecondary(&redis.Message{Channel: "ch1", Payload: "leaked"})
+
+	unmatchedSecondaries.Lock()
+	_, exists := unmatchedSecondaries.data[sp]
+	unmatchedSecondaries.Unlock()
+	assert.True(t, exists, "secondary should be buffered before Close")
+
+	sp.Close()
+
+	unmatchedSecondaries.Lock()
+	_, exists = unmatchedSecondaries.data[sp]
+	unmatchedSecondaries.Unlock()
+	assert.False(t, exists, "Close should clean up unmatchedSecondaries entry")
+}
+
+func TestShadowPubSub_CompareLoopMatchesFromChannel(t *testing.T) {
+	sp := newTestShadowPubSub(1 * time.Second)
+
+	sp.RecordPrimary(&redis.Message{Channel: "ch1", Payload: "msg1"})
+
+	ch := make(chan *redis.Message, 1)
+	ch <- &redis.Message{Channel: "ch1", Payload: "msg1"}
+	close(ch)
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		sp.compareLoop(ch)
+	}()
+
+	wg.Wait()
+
+	sp.mu.Lock()
+	assert.Equal(t, 0, len(sp.pending), "message should be matched via compareLoop")
+	sp.mu.Unlock()
+}