From f4d63e0b6fe01bc5582db3fa23c303e3df7f711f Mon Sep 17 00:00:00 2001 From: "red-hat-konflux[bot]" <126015336+red-hat-konflux[bot]@users.noreply.github.com> Date: Sat, 20 Jun 2026 12:10:56 +0000 Subject: [PATCH] chore(deps): update module github.com/redis/go-redis/v9 to v9.20.1 Signed-off-by: red-hat-konflux <126015336+red-hat-konflux[bot]@users.noreply.github.com> --- go.mod | 4 +- go.sum | 9 +- .../github.com/dgryski/go-rendezvous/rdv.go | 79 - .../github.com/redis/go-redis/v9/.gitignore | 11 +- .../redis/go-redis/v9/.golangci.yml | 2 + .../github.com/redis/go-redis/v9/CHANGELOG.md | 133 - .../redis/go-redis/v9/CONTRIBUTING.md | 2 +- vendor/github.com/redis/go-redis/v9/Makefile | 91 +- vendor/github.com/redis/go-redis/v9/README.md | 384 +- .../redis/go-redis/v9/RELEASE-NOTES.md | 983 +++++ .../github.com/redis/go-redis/v9/RELEASING.md | 141 +- .../redis/go-redis/v9/acl_commands.go | 27 + .../github.com/redis/go-redis/v9/adapters.go | 118 + .../redis/go-redis/v9/array_commands.go | 387 ++ .../github.com/redis/go-redis/v9/auth/auth.go | 79 + .../v9/auth/reauth_credentials_listener.go | 47 + .../redis/go-redis/v9/bitmap_commands.go | 40 +- .../redis/go-redis/v9/cluster_commands.go | 6 + .../github.com/redis/go-redis/v9/command.go | 3885 ++++++++++++++++- .../go-redis/v9/command_policy_resolver.go | 209 + .../github.com/redis/go-redis/v9/commands.go | 90 +- .../redis/go-redis/v9/dial_retry_backoff.go | 39 + .../redis/go-redis/v9/docker-compose.yml | 86 +- vendor/github.com/redis/go-redis/v9/error.go | 287 +- .../redis/go-redis/v9/generic_commands.go | 8 + .../redis/go-redis/v9/geo_commands.go | 10 +- .../redis/go-redis/v9/hash_commands.go | 23 +- .../redis/go-redis/v9/hotkeys_commands.go | 122 + .../conn_reauth_credentials_listener.go | 100 + .../internal/auth/streaming/cred_listeners.go | 77 + .../v9/internal/auth/streaming/manager.go | 137 + .../v9/internal/auth/streaming/pool_hook.go | 241 + .../go-redis/v9/internal/hashtag/hashtag.go | 17 +- .../v9/internal/hashtag/rendezvous.go | 54 + .../go-redis/v9/internal/hscan/structmap.go | 2 + .../v9/internal/interfaces/interfaces.go | 59 + .../redis/go-redis/v9/internal/internal.go | 3 +- .../redis/go-redis/v9/internal/log.go | 61 +- .../maintnotifications/logs/log_messages.go | 663 +++ .../go-redis/v9/internal/otel/metrics.go | 298 ++ .../redis/go-redis/v9/internal/pool/conn.go | 911 +++- .../go-redis/v9/internal/pool/conn_check.go | 12 +- .../v9/internal/pool/conn_check_dummy.go | 15 +- .../go-redis/v9/internal/pool/conn_state.go | 336 ++ .../redis/go-redis/v9/internal/pool/hooks.go | 165 + .../redis/go-redis/v9/internal/pool/pool.go | 1501 ++++++- .../go-redis/v9/internal/pool/pool_single.go | 58 +- .../go-redis/v9/internal/pool/pool_sticky.go | 17 +- .../redis/go-redis/v9/internal/pool/pubsub.go | 105 + .../go-redis/v9/internal/pool/want_conn.go | 115 + .../go-redis/v9/internal/proto/reader.go | 371 +- .../v9/internal/proto/redis_errors.go | 539 +++ .../redis/go-redis/v9/internal/rand/rand.go | 50 - .../redis/go-redis/v9/internal/redis.go | 3 + .../v9/internal/routing/aggregator.go | 1000 +++++ .../go-redis/v9/internal/routing/policy.go | 144 + .../v9/internal/routing/shard_picker.go | 57 + .../redis/go-redis/v9/internal/semaphore.go | 193 + .../redis/go-redis/v9/internal/util.go | 41 +- .../go-redis/v9/internal/util/atomic_max.go | 97 + .../go-redis/v9/internal/util/atomic_min.go | 96 + .../go-redis/v9/internal/util/convert.go | 41 + .../redis/go-redis/v9/internal/util/unsafe.go | 9 +- vendor/github.com/redis/go-redis/v9/json.go | 110 +- .../redis/go-redis/v9/list_commands.go | 8 + .../v9/maintnotifications/FEATURES.md | 235 + .../go-redis/v9/maintnotifications/README.md | 73 + .../v9/maintnotifications/circuit_breaker.go | 353 ++ .../go-redis/v9/maintnotifications/config.go | 502 +++ .../go-redis/v9/maintnotifications/errors.go | 76 + .../v9/maintnotifications/example_hooks.go | 101 + .../v9/maintnotifications/handoff_worker.go | 525 +++ .../go-redis/v9/maintnotifications/hooks.go | 60 + .../go-redis/v9/maintnotifications/manager.go | 362 ++ .../v9/maintnotifications/pool_hook.go | 182 + .../push_notification_handler.go | 524 +++ .../go-redis/v9/maintnotifications/state.go | 24 + .../github.com/redis/go-redis/v9/options.go | 395 +- .../redis/go-redis/v9/osscluster.go | 962 +++- .../redis/go-redis/v9/osscluster_router.go | 1002 +++++ vendor/github.com/redis/go-redis/v9/otel.go | 235 + .../github.com/redis/go-redis/v9/pipeline.go | 29 +- .../redis/go-redis/v9/probabilistic.go | 132 +- vendor/github.com/redis/go-redis/v9/pubsub.go | 189 +- .../redis/go-redis/v9/pubsub_commands.go | 14 +- .../redis/go-redis/v9/push/errors.go | 176 + .../redis/go-redis/v9/push/handler.go | 14 + .../redis/go-redis/v9/push/handler_context.go | 44 + .../redis/go-redis/v9/push/processor.go | 203 + .../github.com/redis/go-redis/v9/push/push.go | 7 + .../redis/go-redis/v9/push/registry.go | 61 + .../redis/go-redis/v9/push_notifications.go | 21 + vendor/github.com/redis/go-redis/v9/redis.go | 1114 ++++- vendor/github.com/redis/go-redis/v9/result.go | 8 + vendor/github.com/redis/go-redis/v9/ring.go | 239 +- vendor/github.com/redis/go-redis/v9/script.go | 143 +- .../redis/go-redis/v9/scripting_commands.go | 5 + .../redis/go-redis/v9/search_builders.go | 858 ++++ .../redis/go-redis/v9/search_commands.go | 2167 ++++++++- .../github.com/redis/go-redis/v9/sentinel.go | 560 ++- .../redis/go-redis/v9/set_commands.go | 171 +- .../redis/go-redis/v9/sortedset_commands.go | 71 +- .../redis/go-redis/v9/stream_commands.go | 239 +- .../redis/go-redis/v9/string_commands.go | 587 ++- .../redis/go-redis/v9/timeseries_commands.go | 246 +- vendor/github.com/redis/go-redis/v9/tx.go | 11 +- .../github.com/redis/go-redis/v9/universal.go | 247 +- .../redis/go-redis/v9/vectorset_commands.go | 480 ++ .../github.com/redis/go-redis/v9/version.go | 2 +- vendor/go.uber.org/atomic/.codecov.yml | 19 + vendor/go.uber.org/atomic/.gitignore | 15 + vendor/go.uber.org/atomic/CHANGELOG.md | 127 + .../atomic/LICENSE.txt} | 4 +- vendor/go.uber.org/atomic/Makefile | 79 + vendor/go.uber.org/atomic/README.md | 63 + vendor/go.uber.org/atomic/bool.go | 88 + vendor/go.uber.org/atomic/bool_ext.go | 53 + vendor/go.uber.org/atomic/doc.go | 23 + vendor/go.uber.org/atomic/duration.go | 89 + vendor/go.uber.org/atomic/duration_ext.go | 40 + vendor/go.uber.org/atomic/error.go | 72 + vendor/go.uber.org/atomic/error_ext.go | 39 + vendor/go.uber.org/atomic/float32.go | 77 + vendor/go.uber.org/atomic/float32_ext.go | 76 + vendor/go.uber.org/atomic/float64.go | 77 + vendor/go.uber.org/atomic/float64_ext.go | 76 + vendor/go.uber.org/atomic/gen.go | 27 + vendor/go.uber.org/atomic/int32.go | 109 + vendor/go.uber.org/atomic/int64.go | 109 + vendor/go.uber.org/atomic/nocmp.go | 35 + vendor/go.uber.org/atomic/pointer_go118.go | 31 + .../atomic/pointer_go118_pre119.go | 60 + vendor/go.uber.org/atomic/pointer_go119.go | 61 + vendor/go.uber.org/atomic/string.go | 72 + vendor/go.uber.org/atomic/string_ext.go | 54 + vendor/go.uber.org/atomic/time.go | 55 + vendor/go.uber.org/atomic/time_ext.go | 36 + vendor/go.uber.org/atomic/uint32.go | 109 + vendor/go.uber.org/atomic/uint64.go | 109 + vendor/go.uber.org/atomic/uintptr.go | 109 + vendor/go.uber.org/atomic/unsafe_pointer.go | 65 + vendor/go.uber.org/atomic/value.go | 31 + vendor/modules.txt | 19 +- 143 files changed, 28941 insertions(+), 1824 deletions(-) delete mode 100644 vendor/github.com/dgryski/go-rendezvous/rdv.go delete mode 100644 vendor/github.com/redis/go-redis/v9/CHANGELOG.md create mode 100644 vendor/github.com/redis/go-redis/v9/adapters.go create mode 100644 vendor/github.com/redis/go-redis/v9/array_commands.go create mode 100644 vendor/github.com/redis/go-redis/v9/auth/auth.go create mode 100644 vendor/github.com/redis/go-redis/v9/auth/reauth_credentials_listener.go create mode 100644 vendor/github.com/redis/go-redis/v9/command_policy_resolver.go create mode 100644 vendor/github.com/redis/go-redis/v9/dial_retry_backoff.go create mode 100644 vendor/github.com/redis/go-redis/v9/hotkeys_commands.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/auth/streaming/conn_reauth_credentials_listener.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/auth/streaming/cred_listeners.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/auth/streaming/manager.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/auth/streaming/pool_hook.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/hashtag/rendezvous.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/interfaces/interfaces.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/maintnotifications/logs/log_messages.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/otel/metrics.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/pool/conn_state.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/pool/hooks.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/pool/pubsub.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/pool/want_conn.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/proto/redis_errors.go delete mode 100644 vendor/github.com/redis/go-redis/v9/internal/rand/rand.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/redis.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/routing/aggregator.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/routing/policy.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/routing/shard_picker.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/semaphore.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/util/atomic_max.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/util/atomic_min.go create mode 100644 vendor/github.com/redis/go-redis/v9/internal/util/convert.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/FEATURES.md create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/README.md create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/circuit_breaker.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/config.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/errors.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/example_hooks.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/handoff_worker.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/hooks.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/manager.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/pool_hook.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/push_notification_handler.go create mode 100644 vendor/github.com/redis/go-redis/v9/maintnotifications/state.go create mode 100644 vendor/github.com/redis/go-redis/v9/osscluster_router.go create mode 100644 vendor/github.com/redis/go-redis/v9/otel.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/errors.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/handler.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/handler_context.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/processor.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/push.go create mode 100644 vendor/github.com/redis/go-redis/v9/push/registry.go create mode 100644 vendor/github.com/redis/go-redis/v9/push_notifications.go create mode 100644 vendor/github.com/redis/go-redis/v9/search_builders.go create mode 100644 vendor/github.com/redis/go-redis/v9/vectorset_commands.go create mode 100644 vendor/go.uber.org/atomic/.codecov.yml create mode 100644 vendor/go.uber.org/atomic/.gitignore create mode 100644 vendor/go.uber.org/atomic/CHANGELOG.md rename vendor/{github.com/dgryski/go-rendezvous/LICENSE => go.uber.org/atomic/LICENSE.txt} (92%) create mode 100644 vendor/go.uber.org/atomic/Makefile create mode 100644 vendor/go.uber.org/atomic/README.md create mode 100644 vendor/go.uber.org/atomic/bool.go create mode 100644 vendor/go.uber.org/atomic/bool_ext.go create mode 100644 vendor/go.uber.org/atomic/doc.go create mode 100644 vendor/go.uber.org/atomic/duration.go create mode 100644 vendor/go.uber.org/atomic/duration_ext.go create mode 100644 vendor/go.uber.org/atomic/error.go create mode 100644 vendor/go.uber.org/atomic/error_ext.go create mode 100644 vendor/go.uber.org/atomic/float32.go create mode 100644 vendor/go.uber.org/atomic/float32_ext.go create mode 100644 vendor/go.uber.org/atomic/float64.go create mode 100644 vendor/go.uber.org/atomic/float64_ext.go create mode 100644 vendor/go.uber.org/atomic/gen.go create mode 100644 vendor/go.uber.org/atomic/int32.go create mode 100644 vendor/go.uber.org/atomic/int64.go create mode 100644 vendor/go.uber.org/atomic/nocmp.go create mode 100644 vendor/go.uber.org/atomic/pointer_go118.go create mode 100644 vendor/go.uber.org/atomic/pointer_go118_pre119.go create mode 100644 vendor/go.uber.org/atomic/pointer_go119.go create mode 100644 vendor/go.uber.org/atomic/string.go create mode 100644 vendor/go.uber.org/atomic/string_ext.go create mode 100644 vendor/go.uber.org/atomic/time.go create mode 100644 vendor/go.uber.org/atomic/time_ext.go create mode 100644 vendor/go.uber.org/atomic/uint32.go create mode 100644 vendor/go.uber.org/atomic/uint64.go create mode 100644 vendor/go.uber.org/atomic/uintptr.go create mode 100644 vendor/go.uber.org/atomic/unsafe_pointer.go create mode 100644 vendor/go.uber.org/atomic/value.go diff --git a/go.mod b/go.mod index 0eaa73fbb..4a8bd5a95 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,6 @@ require ( github.com/cyphar/filepath-securejoin v0.6.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/davidmz/go-pageant v1.0.2 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dlclark/regexp2 v1.12.0 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect @@ -151,7 +150,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.68.1 // indirect github.com/prometheus/procfs v0.20.1 // indirect - github.com/redis/go-redis/v9 v9.8.0 // indirect + github.com/redis/go-redis/v9 v9.20.1 // indirect github.com/robfig/cron/v3 v3.0.2-0.20210106135023-bc59245fe10e // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect @@ -172,6 +171,7 @@ require ( github.com/xlab/treeprint v1.2.0 // indirect go.opentelemetry.io/otel v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.28.0 // indirect go.yaml.in/yaml/v2 v2.4.4 // indirect diff --git a/go.sum b/go.sum index fa87829ea..1caa1c7df 100644 --- a/go.sum +++ b/go.sum @@ -104,7 +104,6 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0= github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= @@ -410,8 +409,8 @@ github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4Ul github.com/r3labs/diff/v3 v3.0.2 h1:yVuxAY1V6MeM4+HNur92xkS39kB/N+cFi2hMkY06BbA= github.com/r3labs/diff/v3 v3.0.2/go.mod h1:Cy542hv0BAEmhDYWtGxXRQ4kqRsVIcEjG9gChUlTmkw= github.com/redis/go-redis/v9 v9.0.0-rc.4/go.mod h1:Vo3EsyWnicKnSKCA7HhgnvnyA74wOA69Cd2Meli5mmA= -github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= -github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/redis/go-redis/v9 v9.20.1 h1:sfCU6A8P3dXbKyWes02uxA2baehGux9dZHfEKtsTB1w= +github.com/redis/go-redis/v9 v9.20.1/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= github.com/robfig/cron/v3 v3.0.2-0.20210106135023-bc59245fe10e h1:0xChnl3lhHiXbgSJKgChye0D+DvoItkOdkGcwelDXH0= github.com/robfig/cron/v3 v3.0.2-0.20210106135023-bc59245fe10e/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -486,6 +485,8 @@ github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= @@ -496,6 +497,8 @@ go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWv go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/vendor/github.com/dgryski/go-rendezvous/rdv.go b/vendor/github.com/dgryski/go-rendezvous/rdv.go deleted file mode 100644 index 7a6f8203c..000000000 --- a/vendor/github.com/dgryski/go-rendezvous/rdv.go +++ /dev/null @@ -1,79 +0,0 @@ -package rendezvous - -type Rendezvous struct { - nodes map[string]int - nstr []string - nhash []uint64 - hash Hasher -} - -type Hasher func(s string) uint64 - -func New(nodes []string, hash Hasher) *Rendezvous { - r := &Rendezvous{ - nodes: make(map[string]int, len(nodes)), - nstr: make([]string, len(nodes)), - nhash: make([]uint64, len(nodes)), - hash: hash, - } - - for i, n := range nodes { - r.nodes[n] = i - r.nstr[i] = n - r.nhash[i] = hash(n) - } - - return r -} - -func (r *Rendezvous) Lookup(k string) string { - // short-circuit if we're empty - if len(r.nodes) == 0 { - return "" - } - - khash := r.hash(k) - - var midx int - var mhash = xorshiftMult64(khash ^ r.nhash[0]) - - for i, nhash := range r.nhash[1:] { - if h := xorshiftMult64(khash ^ nhash); h > mhash { - midx = i + 1 - mhash = h - } - } - - return r.nstr[midx] -} - -func (r *Rendezvous) Add(node string) { - r.nodes[node] = len(r.nstr) - r.nstr = append(r.nstr, node) - r.nhash = append(r.nhash, r.hash(node)) -} - -func (r *Rendezvous) Remove(node string) { - // find index of node to remove - nidx := r.nodes[node] - - // remove from the slices - l := len(r.nstr) - r.nstr[nidx] = r.nstr[l] - r.nstr = r.nstr[:l] - - r.nhash[nidx] = r.nhash[l] - r.nhash = r.nhash[:l] - - // update the map - delete(r.nodes, node) - moved := r.nstr[nidx] - r.nodes[moved] = nidx -} - -func xorshiftMult64(x uint64) uint64 { - x ^= x >> 12 // a - x ^= x << 25 // b - x ^= x >> 27 // c - return x * 2685821657736338717 -} diff --git a/vendor/github.com/redis/go-redis/v9/.gitignore b/vendor/github.com/redis/go-redis/v9/.gitignore index e9c8f5264..93affec7f 100644 --- a/vendor/github.com/redis/go-redis/v9/.gitignore +++ b/vendor/github.com/redis/go-redis/v9/.gitignore @@ -7,4 +7,13 @@ testdata/* redis8tests.sh coverage.txt **/coverage.txt -.vscode \ No newline at end of file +.vscode +tmp/* +*.test +extra/redisotel-native/metrics-collector-app/ +# maintenanceNotifications upgrade documentation (temporary) +maintenanceNotifications/docs/ + +# Docker-generated files (TLS certificates, cluster data, etc.) +dockers/*/tls/ +dockers/osscluster-tls/ diff --git a/vendor/github.com/redis/go-redis/v9/.golangci.yml b/vendor/github.com/redis/go-redis/v9/.golangci.yml index 872454ff7..dd13c2c29 100644 --- a/vendor/github.com/redis/go-redis/v9/.golangci.yml +++ b/vendor/github.com/redis/go-redis/v9/.golangci.yml @@ -26,6 +26,8 @@ linters: - builtin$ - examples$ formatters: + enable: + - gofmt exclusions: generated: lax paths: diff --git a/vendor/github.com/redis/go-redis/v9/CHANGELOG.md b/vendor/github.com/redis/go-redis/v9/CHANGELOG.md deleted file mode 100644 index e1652b179..000000000 --- a/vendor/github.com/redis/go-redis/v9/CHANGELOG.md +++ /dev/null @@ -1,133 +0,0 @@ -## Unreleased - -### Changed - -* `go-redis` won't skip span creation if the parent spans is not recording. ([#2980](https://github.com/redis/go-redis/issues/2980)) - Users can use the OpenTelemetry sampler to control the sampling behavior. - For instance, you can use the `ParentBased(NeverSample())` sampler from `go.opentelemetry.io/otel/sdk/trace` to keep - a similar behavior (drop orphan spans) of `go-redis` as before. - -## [9.0.5](https://github.com/redis/go-redis/compare/v9.0.4...v9.0.5) (2023-05-29) - - -### Features - -* Add ACL LOG ([#2536](https://github.com/redis/go-redis/issues/2536)) ([31ba855](https://github.com/redis/go-redis/commit/31ba855ddebc38fbcc69a75d9d4fb769417cf602)) -* add field protocol to setupClusterQueryParams ([#2600](https://github.com/redis/go-redis/issues/2600)) ([840c25c](https://github.com/redis/go-redis/commit/840c25cb6f320501886a82a5e75f47b491e46fbe)) -* add protocol option ([#2598](https://github.com/redis/go-redis/issues/2598)) ([3917988](https://github.com/redis/go-redis/commit/391798880cfb915c4660f6c3ba63e0c1a459e2af)) - - - -## [9.0.4](https://github.com/redis/go-redis/compare/v9.0.3...v9.0.4) (2023-05-01) - - -### Bug Fixes - -* reader float parser ([#2513](https://github.com/redis/go-redis/issues/2513)) ([46f2450](https://github.com/redis/go-redis/commit/46f245075e6e3a8bd8471f9ca67ea95fd675e241)) - - -### Features - -* add client info command ([#2483](https://github.com/redis/go-redis/issues/2483)) ([b8c7317](https://github.com/redis/go-redis/commit/b8c7317cc6af444603731f7017c602347c0ba61e)) -* no longer verify HELLO error messages ([#2515](https://github.com/redis/go-redis/issues/2515)) ([7b4f217](https://github.com/redis/go-redis/commit/7b4f2179cb5dba3d3c6b0c6f10db52b837c912c8)) -* read the structure to increase the judgment of the omitempty op… ([#2529](https://github.com/redis/go-redis/issues/2529)) ([37c057b](https://github.com/redis/go-redis/commit/37c057b8e597c5e8a0e372337f6a8ad27f6030af)) - - - -## [9.0.3](https://github.com/redis/go-redis/compare/v9.0.2...v9.0.3) (2023-04-02) - -### New Features - -- feat(scan): scan time.Time sets the default decoding (#2413) -- Add support for CLUSTER LINKS command (#2504) -- Add support for acl dryrun command (#2502) -- Add support for COMMAND GETKEYS & COMMAND GETKEYSANDFLAGS (#2500) -- Add support for LCS Command (#2480) -- Add support for BZMPOP (#2456) -- Adding support for ZMPOP command (#2408) -- Add support for LMPOP (#2440) -- feat: remove pool unused fields (#2438) -- Expiretime and PExpireTime (#2426) -- Implement `FUNCTION` group of commands (#2475) -- feat(zadd): add ZAddLT and ZAddGT (#2429) -- Add: Support for COMMAND LIST command (#2491) -- Add support for BLMPOP (#2442) -- feat: check pipeline.Do to prevent confusion with Exec (#2517) -- Function stats, function kill, fcall and fcall_ro (#2486) -- feat: Add support for CLUSTER SHARDS command (#2507) -- feat(cmd): support for adding byte,bit parameters to the bitpos command (#2498) - -### Fixed - -- fix: eval api cmd.SetFirstKeyPos (#2501) -- fix: limit the number of connections created (#2441) -- fixed #2462 v9 continue support dragonfly, it's Hello command return "NOAUTH Authentication required" error (#2479) -- Fix for internal/hscan/structmap.go:89:23: undefined: reflect.Pointer (#2458) -- fix: group lag can be null (#2448) - -### Maintenance - -- Updating to the latest version of redis (#2508) -- Allowing for running tests on a port other than the fixed 6380 (#2466) -- redis 7.0.8 in tests (#2450) -- docs: Update redisotel example for v9 (#2425) -- chore: update go mod, Upgrade golang.org/x/net version to 0.7.0 (#2476) -- chore: add Chinese translation (#2436) -- chore(deps): bump github.com/bsm/gomega from 1.20.0 to 1.26.0 (#2421) -- chore(deps): bump github.com/bsm/ginkgo/v2 from 2.5.0 to 2.7.0 (#2420) -- chore(deps): bump actions/setup-go from 3 to 4 (#2495) -- docs: add instructions for the HSet api (#2503) -- docs: add reading lag field comment (#2451) -- test: update go mod before testing(go mod tidy) (#2423) -- docs: fix comment typo (#2505) -- test: remove testify (#2463) -- refactor: change ListElementCmd to KeyValuesCmd. (#2443) -- fix(appendArg): appendArg case special type (#2489) - -## [9.0.2](https://github.com/redis/go-redis/compare/v9.0.1...v9.0.2) (2023-02-01) - -### Features - -* upgrade OpenTelemetry, use the new metrics API. ([#2410](https://github.com/redis/go-redis/issues/2410)) ([e29e42c](https://github.com/redis/go-redis/commit/e29e42cde2755ab910d04185025dc43ce6f59c65)) - -## v9 2023-01-30 - -### Breaking - -- Changed Pipelines to not be thread-safe any more. - -### Added - -- Added support for [RESP3](https://github.com/antirez/RESP3/blob/master/spec.md) protocol. It was - contributed by @monkey92t who has done the majority of work in this release. -- Added `ContextTimeoutEnabled` option that controls whether the client respects context timeouts - and deadlines. See - [Redis Timeouts](https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts) for details. -- Added `ParseClusterURL` to parse URLs into `ClusterOptions`, for example, - `redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791`. -- Added metrics instrumentation using `redisotel.IstrumentMetrics`. See - [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html) -- Added `redis.HasErrorPrefix` to help working with errors. - -### Changed - -- Removed asynchronous cancellation based on the context timeout. It was racy in v8 and is - completely gone in v9. -- Reworked hook interface and added `DialHook`. -- Replaced `redisotel.NewTracingHook` with `redisotel.InstrumentTracing`. See - [example](example/otel) and - [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html). -- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as value without making - an allocation. -- Renamed the option `MaxConnAge` to `ConnMaxLifetime`. -- Renamed the option `IdleTimeout` to `ConnMaxIdleTime`. -- Removed connection reaper in favor of `MaxIdleConns`. -- Removed `WithContext` since `context.Context` can be passed directly as an arg. -- Removed `Pipeline.Close` since there is no real need to explicitly manage pipeline resources and - it can be safely reused via `sync.Pool` etc. `Pipeline.Discard` is still available if you want to - reset commands for some reason. - -### Fixed - -- Improved and fixed pipeline retries. -- As usually, added support for more commands and fixed some bugs. diff --git a/vendor/github.com/redis/go-redis/v9/CONTRIBUTING.md b/vendor/github.com/redis/go-redis/v9/CONTRIBUTING.md index 7228a4a06..8c68c522e 100644 --- a/vendor/github.com/redis/go-redis/v9/CONTRIBUTING.md +++ b/vendor/github.com/redis/go-redis/v9/CONTRIBUTING.md @@ -37,7 +37,7 @@ Here's how to get started with your code contribution: > Note: this clones and builds the docker containers specified in `docker-compose.yml`, to understand more about > the infrastructure that will be started you can check the `docker-compose.yml`. You also have the possiblity > to specify the redis image that will be pulled with the env variable `CLIENT_LIBS_TEST_IMAGE`. -> By default the docker image that will be pulled and started is `redislabs/client-libs-test:rs-7.4.0-v2`. +> By default the docker image that will be pulled and started is `redislabs/client-libs-test:8.2.1-pre`. > If you want to test with newer Redis version, using a newer version of `redislabs/client-libs-test` should work out of the box. 4. While developing, make sure the tests pass by running `make test` (if you have the docker containers running, `make test.ci` may be sufficient). diff --git a/vendor/github.com/redis/go-redis/v9/Makefile b/vendor/github.com/redis/go-redis/v9/Makefile index fc175f5f1..90f03b57e 100644 --- a/vendor/github.com/redis/go-redis/v9/Makefile +++ b/vendor/github.com/redis/go-redis/v9/Makefile @@ -1,33 +1,112 @@ GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort) +REDIS_VERSION ?= 8.8 +RE_CLUSTER ?= false +RCE_DOCKER ?= true +CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.8.0 docker.start: + export RE_CLUSTER=$(RE_CLUSTER) && \ + export RCE_DOCKER=$(RCE_DOCKER) && \ + export REDIS_VERSION=$(REDIS_VERSION) && \ + export CLIENT_LIBS_TEST_IMAGE=$(CLIENT_LIBS_TEST_IMAGE) && \ docker compose --profile all up -d --quiet-pull docker.stop: docker compose --profile all down +docker.e2e.start: + @echo "Starting Redis and cae-resp-proxy for E2E tests..." + docker compose --profile e2e up -d --quiet-pull + @echo "Waiting for services to be ready..." + @sleep 3 + @echo "Services ready!" + +docker.e2e.stop: + @echo "Stopping E2E services..." + docker compose --profile e2e down + test: $(MAKE) docker.start - $(MAKE) test.ci + @if [ -z "$(REDIS_VERSION)" ]; then \ + echo "REDIS_VERSION not set, running all tests"; \ + $(MAKE) test.ci; \ + else \ + MAJOR_VERSION=$$(echo "$(REDIS_VERSION)" | cut -d. -f1); \ + if [ "$$MAJOR_VERSION" -ge 8 ]; then \ + echo "REDIS_VERSION $(REDIS_VERSION) >= 8, running all tests"; \ + $(MAKE) test.ci; \ + else \ + echo "REDIS_VERSION $(REDIS_VERSION) < 8, skipping vector_sets tests"; \ + $(MAKE) test.ci.skip-vectorsets; \ + fi; \ + fi $(MAKE) docker.stop test.ci: set -e; for dir in $(GO_MOD_DIRS); do \ echo "go test in $${dir}"; \ (cd "$${dir}" && \ - go mod tidy -compat=1.18 && \ + export RE_CLUSTER=$(RE_CLUSTER) && \ + export RCE_DOCKER=$(RCE_DOCKER) && \ + export REDIS_VERSION=$(REDIS_VERSION) && \ + go mod tidy && \ + go vet && \ + go test -v -coverprofile=coverage.txt -covermode=atomic ./... -race -skip Example); \ + done + cd internal/customvet && go build . + go vet -vettool ./internal/customvet/customvet + +test.ci.skip-vectorsets: + set -e; for dir in $(GO_MOD_DIRS); do \ + echo "go test in $${dir} (skipping vector sets)"; \ + (cd "$${dir}" && \ + export RE_CLUSTER=$(RE_CLUSTER) && \ + export RCE_DOCKER=$(RCE_DOCKER) && \ + export REDIS_VERSION=$(REDIS_VERSION) && \ + go mod tidy && \ go vet && \ - go test -v -coverprofile=coverage.txt -covermode=atomic ./... -race); \ + go test -v -coverprofile=coverage.txt -covermode=atomic ./... -race \ + -run '^(?!.*(?:VectorSet|vectorset|ExampleClient_vectorset)).*$$' -skip Example); \ done cd internal/customvet && go build . go vet -vettool ./internal/customvet/customvet bench: - go test ./... -test.run=NONE -test.bench=. -test.benchmem + export RE_CLUSTER=$(RE_CLUSTER) && \ + export RCE_DOCKER=$(RCE_DOCKER) && \ + export REDIS_VERSION=$(REDIS_VERSION) && \ + go test ./... -test.run=NONE -test.bench=. -test.benchmem -skip Example + +test.e2e: + @echo "Running E2E tests with auto-start proxy..." + $(MAKE) docker.e2e.start + @echo "Running tests..." + @E2E_SCENARIO_TESTS=true go test -v ./maintnotifications/e2e/ -timeout 30m || ($(MAKE) docker.e2e.stop && exit 1) + $(MAKE) docker.e2e.stop + @echo "E2E tests completed!" + +test.e2e.docker: + @echo "Running Docker-compatible E2E tests..." + $(MAKE) docker.e2e.start + @echo "Running unified injector tests..." + @E2E_SCENARIO_TESTS=true go test -v -run "TestUnifiedInjector|TestCreateTestFaultInjectorLogic|TestFaultInjectorClientCreation" ./maintnotifications/e2e/ -timeout 10m || ($(MAKE) docker.e2e.stop && exit 1) + $(MAKE) docker.e2e.stop + @echo "Docker E2E tests completed!" + +test.e2e.logic: + @echo "Running E2E logic tests (no proxy required)..." + @E2E_SCENARIO_TESTS=true \ + REDIS_ENDPOINTS_CONFIG_PATH=/tmp/test_endpoints_verify.json \ + FAULT_INJECTION_API_URL=http://localhost:8080 \ + go test -v -run "TestCreateTestFaultInjectorLogic|TestFaultInjectorClientCreation" ./maintnotifications/e2e/ + @echo "Logic tests completed!" -.PHONY: all test bench fmt +.PHONY: all test test.ci test.ci.skip-vectorsets bench fmt test.e2e test.e2e.logic docker.e2e.start docker.e2e.stop build: + export RE_CLUSTER=$(RE_CLUSTER) && \ + export RCE_DOCKER=$(RCE_DOCKER) && \ + export REDIS_VERSION=$(REDIS_VERSION) && \ go build . fmt: @@ -39,5 +118,5 @@ go_mod_tidy: echo "go mod tidy in $${dir}"; \ (cd "$${dir}" && \ go get -u ./... && \ - go mod tidy -compat=1.18); \ + go mod tidy); \ done diff --git a/vendor/github.com/redis/go-redis/v9/README.md b/vendor/github.com/redis/go-redis/v9/README.md index 4487c6e9a..ae90d2b7d 100644 --- a/vendor/github.com/redis/go-redis/v9/README.md +++ b/vendor/github.com/redis/go-redis/v9/README.md @@ -2,7 +2,7 @@ [![build workflow](https://github.com/redis/go-redis/actions/workflows/build.yml/badge.svg)](https://github.com/redis/go-redis/actions) [![PkgGoDev](https://pkg.go.dev/badge/github.com/redis/go-redis/v9)](https://pkg.go.dev/github.com/redis/go-redis/v9?tab=doc) -[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/) +[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.io/docs/latest/develop/clients/go/) [![Go Report Card](https://goreportcard.com/badge/github.com/redis/go-redis/v9)](https://goreportcard.com/report/github.com/redis/go-redis/v9) [![codecov](https://codecov.io/github/redis/go-redis/graph/badge.svg?token=tsrCZKuSSw)](https://codecov.io/github/redis/go-redis) @@ -17,16 +17,24 @@ ## Supported versions In `go-redis` we are aiming to support the last three releases of Redis. Currently, this means we do support: -- [Redis 7.2](https://raw.githubusercontent.com/redis/redis/7.2/00-RELEASENOTES) - using Redis Stack 7.2 for modules support -- [Redis 7.4](https://raw.githubusercontent.com/redis/redis/7.4/00-RELEASENOTES) - using Redis Stack 7.4 for modules support -- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0 where modules are included +- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0 +- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2 +- [Redis 8.4](https://raw.githubusercontent.com/redis/redis/8.4/00-RELEASENOTES) - using Redis CE 8.4 +- [Redis 8.8](https://raw.githubusercontent.com/redis/redis/8.8/00-RELEASENOTES) - using Redis CE 8.8 -Although the `go.mod` states it requires at minimum `go 1.18`, our CI is configured to run the tests against all three -versions of Redis and latest two versions of Go ([1.23](https://go.dev/doc/devel/release#go1.23.0), -[1.24](https://go.dev/doc/devel/release#go1.24.0)). We observe that some modules related test may not pass with +Although the `go.mod` states it requires at minimum `go 1.24`, our CI is configured to run the tests against all supported +versions of Redis and multiple versions of Go ([1.24](https://go.dev/doc/devel/release#go1.24.0), oldstable, and stable). We observe that some modules related test may not pass with Redis Stack 7.2 and some commands are changed with Redis CE 8.0. -Please do refer to the documentation and the tests if you experience any issues. We do plan to update the go version -in the `go.mod` to `go 1.24` in one of the next releases. +Although it is not officially supported, `go-redis/v9` should be able to work with any Redis 7.0+. +Please do refer to the documentation and the tests if you experience any issues. + +### Array data type (Redis 8.8+) + +Starting with Redis 8.8, go-redis exposes the new array data type via the `AR*` command family +(`ARSET`, `ARGET`, `ARGETRANGE`, `ARMSET`, `ARMGET`, `ARINSERT`, `ARDEL`, `ARDELRANGE`, +`ARLEN`, `ARCOUNT`, `ARNEXT`, `ARSEEK`, `ARSCAN`, `ARGREP`, `ARRING`, `ARLASTITEMS`, +`ARINFO`/`ARINFOFULL`, and the `AROP*` reducers). See `array_commands.go` for the full +surface. The API is experimental and may change in a future release. ## How do I Redis? @@ -42,10 +50,6 @@ in the `go.mod` to `go 1.24` in one of the next releases. [Work at Redis](https://redis.com/company/careers/jobs/) -## Documentation - -- [English](https://redis.uptrace.dev) -- [简体中文](https://redis.uptrace.dev/zh/) ## Resources @@ -53,29 +57,33 @@ in the `go.mod` to `go 1.24` in one of the next releases. - [Chat](https://discord.gg/W4txy5AeKM) - [Reference](https://pkg.go.dev/github.com/redis/go-redis/v9) - [Examples](https://pkg.go.dev/github.com/redis/go-redis/v9#pkg-examples) +- [Release notes](./RELEASE-NOTES.md) ([GitHub Releases](https://github.com/redis/go-redis/releases)) + +## old documentation + +- [English](https://redis.uptrace.dev) +- [简体中文](https://redis.uptrace.dev/zh/) ## Ecosystem -- [Redis Mock](https://github.com/go-redis/redismock) +- [Entra ID (Azure AD)](https://github.com/redis/go-redis-entraid) - [Distributed Locks](https://github.com/bsm/redislock) - [Redis Cache](https://github.com/go-redis/cache) - [Rate limiting](https://github.com/go-redis/redis_rate) -This client also works with [Kvrocks](https://github.com/apache/incubator-kvrocks), a distributed -key value NoSQL database that uses RocksDB as storage engine and is compatible with Redis protocol. - ## Features - Redis commands except QUIT and SYNC. - Automatic connection pooling. +- [StreamingCredentialsProvider (e.g. entra id, oauth)](#1-streaming-credentials-provider-highest-priority) (experimental) - [Pub/Sub](https://redis.uptrace.dev/guide/go-redis-pubsub.html). - [Pipelines and transactions](https://redis.uptrace.dev/guide/go-redis-pipelines.html). - [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html). - [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html). - [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html). -- [Redis Ring](https://redis.uptrace.dev/guide/ring.html). - [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html). - [Redis Probabilistic [RedisStack]](https://redis.io/docs/data-types/probabilistic/) +- [Customizable read and write buffers size.](#custom-buffer-sizes) ## Installation @@ -111,6 +119,7 @@ func ExampleClient() { Password: "", // no password set DB: 0, // use default DB }) + defer rdb.Close() err := rdb.Set(ctx, "key", "value", 0).Err() if err != nil { @@ -136,17 +145,144 @@ func ExampleClient() { } ``` -The above can be modified to specify the version of the RESP protocol by adding the `protocol` -option to the `Options` struct: +### Dial retries and backoff + +Connection establishment can be retried by the connection pool when dialing fails. + +- **`DialerRetries`**: maximum number of dial attempts (default: 5). +- **`DialerRetryTimeout`**: default delay between attempts when no custom backoff is provided (default: 100ms). +- **`DialerRetryBackoff`**: optional function hook to control the delay between attempts. + +Example: ```go - rdb := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", - Password: "", // no password set - DB: 0, // use default DB - Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3 - }) +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + + DialerRetries: 5, + DialerRetryTimeout: 100 * time.Millisecond, // used when DialerRetryBackoff is nil + + // Optional: exponential backoff with jitter and a cap. + DialerRetryBackoff: redis.DialRetryBackoffExponential(100*time.Millisecond, 2*time.Second), +}) +defer rdb.Close() +``` + +### Authentication + +The Redis client supports multiple ways to provide authentication credentials, with a clear priority order. Here are the available options: + +#### 1. Streaming Credentials Provider (Highest Priority) - Experimental feature + +The streaming credentials provider allows for dynamic credential updates during the connection lifetime. This is particularly useful for managed identity services and token-based authentication. + +```go +type StreamingCredentialsProvider interface { + Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) +} + +type CredentialsListener interface { + OnNext(credentials Credentials) // Called when credentials are updated + OnError(err error) // Called when an error occurs +} + +type Credentials interface { + BasicAuth() (username string, password string) + RawCredentials() string +} +``` + +Example usage: +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + StreamingCredentialsProvider: &MyCredentialsProvider{}, +}) +``` + +**Note:** The streaming credentials provider can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) to enable Entra ID (formerly Azure AD) authentication. This allows for seamless integration with Azure's managed identity services and token-based authentication. + +Example with Entra ID: +```go +import ( + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis-entraid" +) + +// Create an Entra ID credentials provider +provider := entraid.NewDefaultAzureIdentityProvider() + +// Configure Redis client with Entra ID authentication +rdb := redis.NewClient(&redis.Options{ + Addr: "your-redis-server.redis.cache.windows.net:6380", + StreamingCredentialsProvider: provider, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, +}) +``` + +#### 2. Context-based Credentials Provider + +The context-based provider allows credentials to be determined at the time of each operation, using the context. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + CredentialsProviderContext: func(ctx context.Context) (string, string, error) { + // Return username, password, and any error + return "user", "pass", nil + }, +}) +``` + +#### 3. Regular Credentials Provider + +A simple function-based provider that returns static credentials. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + CredentialsProvider: func() (string, string) { + // Return username and password + return "user", "pass" + }, +}) +``` + +#### 4. Username/Password Fields (Lowest Priority) +The most basic way to provide credentials is through the `Username` and `Password` fields in the options. + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Username: "user", + Password: "pass", +}) +``` + +#### Priority Order + +The client will use credentials in the following priority order: +1. Streaming Credentials Provider (if set) +2. Context-based Credentials Provider (if set) +3. Regular Credentials Provider (if set) +4. Username/Password fields (if set) + +If none of these are set, the client will attempt to connect without authentication. + +### Protocol Version + +The client supports both RESP2 and RESP3 protocols. You can specify the protocol version in the options: + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + Protocol: 3, // specify 2 for RESP 2 or 3 for RESP 3 +}) ``` ### Connecting via a redis url @@ -192,6 +328,18 @@ func main() { ``` +### Buffer Size Configuration + +go-redis uses 32KiB read and write buffers by default for optimal performance. For high-throughput applications or large pipelines, you can customize buffer sizes: + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + ReadBufferSize: 1024 * 1024, // 1MiB read buffer + WriteBufferSize: 1024 * 1024, // 1MiB write buffer +}) +``` + ### Advanced Configuration go-redis supports extending the client identification phase to allow projects to send their own custom client identification. @@ -217,18 +365,17 @@ rdb := redis.NewClient(&redis.Options{ }) ``` -#### Unstable RESP3 Structures for RediSearch Commands -When integrating Redis with application functionalities using RESP3, it's important to note that some response structures aren't final yet. This is especially true for more complex structures like search and query results. We recommend using RESP2 when using the search and query capabilities, but we plan to stabilize the RESP3-based API-s in the coming versions. You can find more guidance in the upcoming release notes. +#### RESP3 for RediSearch Commands (`UnstableResp3` is deprecated) +As of v9.20, `FT.SEARCH`, `FT.AGGREGATE`, `FT.INFO`, `FT.SPELLCHECK`, and `FT.SYNDUMP` +parse RESP3 (map) responses into the same typed result objects as RESP2. **No flag +is required — `Val()` / `Result()` work uniformly on both protocols.** -To enable unstable RESP3, set the option in your client configuration: +The legacy `UnstableResp3` option is now a **no-op** and is retained on every +options struct only for backwards compatibility. It will be removed in a future +release; new code should not set it. -```go -redis.NewClient(&redis.Options{ - UnstableResp3: true, - }) -``` -**Note:** When UnstableResp3 mode is enabled, it's necessary to use RawResult() and RawVal() to retrieve a raw data. - Since, raw response is the only option for unstable search commands Val() and Result() calls wouldn't have any affect on them: +`RawResult()` / `RawVal()` continue to work for callers that prefer the raw RESP +payload directly: ```go res1, err := client.FTSearchWithArgs(ctx, "txt", "foo bar", &redis.FTSearchOptions{}).RawResult() @@ -255,6 +402,21 @@ For example: ``` You can find further details in the [query dialect documentation](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/). +#### Custom buffer sizes +Prior to v9.12, the buffer size was the default go value of 4096 bytes. Starting from v9.12, +go-redis uses 32KiB read and write buffers by default for optimal performance. +For high-throughput applications or large pipelines, you can customize buffer sizes: + +```go +rdb := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + ReadBufferSize: 1024 * 1024, // 1MiB read buffer + WriteBufferSize: 1024 * 1024, // 1MiB write buffer +}) +``` + +**Important**: If you experience any issues with the default buffer sizes, please try setting them to the go default of 4096 bytes. + ## Contributing We welcome contributions to the go-redis library! If you have a bug fix, feature request, or improvement, please open an issue or pull request on GitHub. We appreciate your help in making go-redis better for everyone. @@ -295,38 +457,150 @@ vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello") res, err := rdb.Do(ctx, "set", "key", "value").Result() ``` -## Run the test - -go-redis will start a redis-server and run the test cases. +## Typed Errors -The paths of redis-server bin file and redis config file are defined in `main_test.go`: +go-redis provides typed error checking functions for common Redis errors: ```go -var ( - redisServerBin, _ = filepath.Abs(filepath.Join("testdata", "redis", "src", "redis-server")) - redisServerConf, _ = filepath.Abs(filepath.Join("testdata", "redis", "redis.conf")) -) +// Cluster and replication errors +redis.IsLoadingError(err) // Redis is loading the dataset +redis.IsReadOnlyError(err) // Write to read-only replica +redis.IsClusterDownError(err) // Cluster is down +redis.IsTryAgainError(err) // Command should be retried +redis.IsMasterDownError(err) // Master is down +redis.IsMovedError(err) // Returns (address, true) if key moved +redis.IsAskError(err) // Returns (address, true) if key being migrated + +// Connection and resource errors +redis.IsMaxClientsError(err) // Maximum clients reached +redis.IsAuthError(err) // Authentication failed (NOAUTH, WRONGPASS, unauthenticated) +redis.IsPermissionError(err) // Permission denied (NOPERM) +redis.IsOOMError(err) // Out of memory (OOM) + +// Transaction errors +redis.IsExecAbortError(err) // Transaction aborted (EXECABORT) ``` -For local testing, you can change the variables to refer to your local files, or create a soft link -to the corresponding folder for redis-server and copy the config file to `testdata/redis/`: +### Error Wrapping in Hooks -```shell -ln -s /usr/bin/redis-server ./go-redis/testdata/redis/src -cp ./go-redis/testdata/redis.conf ./go-redis/testdata/redis/ +When wrapping errors in hooks, use custom error types with `Unwrap()` method (preferred) or `fmt.Errorf` with `%w`. Always call `cmd.SetErr()` to preserve error type information: + +```go +// Custom error type (preferred) +type AppError struct { + Code string + RequestID string + Err error +} + +func (e *AppError) Error() string { + return fmt.Sprintf("[%s] request_id=%s: %v", e.Code, e.RequestID, e.Err) +} + +func (e *AppError) Unwrap() error { + return e.Err +} + +// Hook implementation +func (h MyHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + err := next(ctx, cmd) + if err != nil { + // Wrap with custom error type + wrappedErr := &AppError{ + Code: "REDIS_ERROR", + RequestID: getRequestID(ctx), + Err: err, + } + cmd.SetErr(wrappedErr) + return wrappedErr // Return wrapped error to preserve it + } + return nil + } +} + +// Typed error detection works through wrappers +if redis.IsLoadingError(err) { + // Retry logic +} + +// Extract custom error if needed +var appErr *AppError +if errors.As(err, &appErr) { + log.Printf("Request: %s", appErr.RequestID) +} ``` -Lastly, run: +Alternatively, use `fmt.Errorf` with `%w`: +```go +wrappedErr := fmt.Errorf("context: %w", err) +cmd.SetErr(wrappedErr) +``` -```shell -go test +### Pipeline Hook Example + +For pipeline operations, use `ProcessPipelineHook`: + +```go +type PipelineLoggingHook struct{} + +func (h PipelineLoggingHook) DialHook(next redis.DialHook) redis.DialHook { + return next +} + +func (h PipelineLoggingHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return next +} + +func (h PipelineLoggingHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + start := time.Now() + + // Execute the pipeline + err := next(ctx, cmds) + + duration := time.Since(start) + log.Printf("Pipeline executed %d commands in %v", len(cmds), duration) + + // Process individual command errors + // Note: Individual command errors are already set on each cmd by the pipeline execution + for _, cmd := range cmds { + if cmdErr := cmd.Err(); cmdErr != nil { + // Check for specific error types using typed error functions + if redis.IsAuthError(cmdErr) { + log.Printf("Auth error in pipeline command %s: %v", cmd.Name(), cmdErr) + } else if redis.IsPermissionError(cmdErr) { + log.Printf("Permission error in pipeline command %s: %v", cmd.Name(), cmdErr) + } + + // Optionally wrap individual command errors to add context + // The wrapped error preserves type information through errors.As() + wrappedErr := fmt.Errorf("pipeline cmd %s failed: %w", cmd.Name(), cmdErr) + cmd.SetErr(wrappedErr) + } + } + + // Return the pipeline-level error (connection errors, etc.) + // You can wrap it if needed, or return it as-is + return err + } +} + +// Register the hook +rdb.AddHook(PipelineLoggingHook{}) + +// Use pipeline - errors are still properly typed +pipe := rdb.Pipeline() +pipe.Set(ctx, "key1", "value1", 0) +pipe.Get(ctx, "key2") +_, err := pipe.Exec(ctx) ``` -Another option is to run your specific tests with an already running redis. The example below, tests -against a redis running on port 9999.: +## Run the test +Recommended to use Docker, just need to run: ```shell -REDIS_PORT=9999 go test +make test ``` ## See also diff --git a/vendor/github.com/redis/go-redis/v9/RELEASE-NOTES.md b/vendor/github.com/redis/go-redis/v9/RELEASE-NOTES.md index fa106cb92..de24456c3 100644 --- a/vendor/github.com/redis/go-redis/v9/RELEASE-NOTES.md +++ b/vendor/github.com/redis/go-redis/v9/RELEASE-NOTES.md @@ -1,5 +1,852 @@ # Release Notes +# 9.20.1 (2026-06-11) + +This is a patch release containing bug fixes only. There are no new features or breaking changes; upgrading from 9.20.0 is a drop-in replacement. + +## 🚀 Highlights + +### RESP3 pub/sub message loss fixed + +`PeekPushNotificationName` previously inspected only the bytes already buffered by `bufio`, so when a push frame header straddled a buffer fill boundary it could return a **truncated** notification name (e.g. `"messa"` instead of `"message"`). The push processor then mis-routed the frame and `ReadReply` silently dropped it, causing intermittent RESP3 pub/sub message loss. The peek now grows its window (36 bytes → up to 4 KiB) and reads more from the connection until the header is complete, cleanly separating incomplete prefixes from corrupt frames (including overflow-safe bulk-length handling). Fixes [#3839](https://github.com/redis/go-redis/issues/3839). + +([#3842](https://github.com/redis/go-redis/pull/3842)) by [@ndyakov](https://github.com/ndyakov) + +## 🐛 Bug Fixes + +- **RESP3 push peeking**: `PeekPushNotificationName` no longer returns a truncated notification name when a push frame header spans a buffer boundary, preventing silent RESP3 pub/sub message loss (fixes [#3839](https://github.com/redis/go-redis/issues/3839)) ([#3842](https://github.com/redis/go-redis/pull/3842)) by [@ndyakov](https://github.com/ndyakov) +- **`FT.HYBRID` vector params**: Vector data is now always sent via `PARAMS` with auto-generated param names (`__vector_param_N`, with collision avoidance) when `VectorParamName` is omitted, since Redis no longer accepts inline vector blobs; the `FTHybridOptions.Params` map is no longer mutated, so the same options struct can be reused across calls ([#3844](https://github.com/redis/go-redis/pull/3844)) by [@ndyakov](https://github.com/ndyakov) +- **`CLUSTER SHARDS` forward compatibility**: Unknown shard- and node-level attributes in the `CLUSTER SHARDS` reply are now skipped via `DiscardNext()` instead of erroring, so clients keep working when the server introduces new fields ([#3843](https://github.com/redis/go-redis/pull/3843)) by [@madolson](https://github.com/madolson) +- **PubSub double reconnect**: `PubSub.releaseConn` no longer reconnects twice when a connection is both unusable (or pending handoff) and reports a bad-connection error, avoiding a wasted connection establish-then-close cycle ([#3833](https://github.com/redis/go-redis/pull/3833)) by [@cxljs](https://github.com/cxljs) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@cxljs](https://github.com/cxljs), [@madolson](https://github.com/madolson), [@ndyakov](https://github.com/ndyakov) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.20.0...v9.20.1 + +# 9.20.0 (2026-05-28) + +## 🚀 Highlights + +### Redis 8.8 Support + +This release adds support for **Redis 8.8**. The README's supported-versions list now includes Redis 8.8 alongside 8.0/8.2/8.4, and CI exercises the `8.8-rc1` client-libs-test image across the full suite (Makefile, build workflow, doctests, run-tests action, and docker-compose). + +Coverage for the new commands that ship in the 8.x line, rounded out in this release: + +- **`AR*` array data type** ([#3813](https://github.com/redis/go-redis/pull/3813)) — new array data structure, exposed via the `ArrayCmdable` interface (see the experimental-features highlight below). +- **`INCREX`** ([#3816](https://github.com/redis/go-redis/pull/3816)) — atomic increment with expiration in a single round-trip. +- **`XNACK`** ([#3790](https://github.com/redis/go-redis/pull/3790)) — explicit negative-acknowledge of pending stream entries. +- **`XAUTOCLAIM` PEL deletes** ([#3798](https://github.com/redis/go-redis/pull/3798)) — `XAUTOCLAIM`/`XAUTOCLAIMJUSTID` now return the list of deleted message IDs from the pending entries list. +- **`TS.RANGE` multiple aggregators** ([#3791](https://github.com/redis/go-redis/pull/3791)) — `TS.RANGE`/`TS.REVRANGE`/`TS.MRANGE`/`TS.MREVRANGE` accept multiple aggregators in a single call. +- **`Z(UNION|INTER|DIFF)` `COUNT` aggregator** ([#3802](https://github.com/redis/go-redis/pull/3802)) — `COUNT` reducer for sorted-set set operations. +- **`JSON.SET FPHA`** ([#3797](https://github.com/redis/go-redis/pull/3797)) — new `FPHA` argument that specifies the floating-point type for homogeneous FP arrays. + +CI image bump ([#3814](https://github.com/redis/go-redis/pull/3814)) by [@ofekshenawa](https://github.com/ofekshenawa). Command coverage contributions by [@cxljs](https://github.com/cxljs), [@elena-kolevska](https://github.com/elena-kolevska), [@Khukharr](https://github.com/Khukharr), [@ndyakov](https://github.com/ndyakov), and [@ofekshenawa](https://github.com/ofekshenawa). + +### Stable RESP3 for RediSearch (`UnstableResp3` deprecated) + +`FT.SEARCH`, `FT.AGGREGATE`, `FT.INFO`, `FT.SPELLCHECK`, and `FT.SYNDUMP` now parse RESP3 (map) responses into the same typed result objects as RESP2 — `Val()` and `Result()` work uniformly on both protocols, no flag required. Previously, RESP3 search responses required `UnstableResp3: true` and were returned as opaque maps accessible only via `RawResult()` / `RawVal()`. + +As a result, the `UnstableResp3` option is now a **no-op** across every options struct (`Options`, `ClusterOptions`, `UniversalOptions`, `FailoverOptions`, `RingOptions`) and has been marked `// Deprecated:`. The field is retained for backwards compatibility — existing code that sets `UnstableResp3: true` will continue to compile and behave identically — but it will be removed in a future release and new code should not set it. `RawResult()` / `RawVal()` continue to work for callers that prefer the raw RESP payload. + +([#3741](https://github.com/redis/go-redis/pull/3741)) by [@ndyakov](https://github.com/ndyakov) + +### Experimental Array Data Structure Commands + +Adds an experimental `ArrayCmdable` interface with the `AR*` command family (`ARSet`, `ARGet`, `ARGetRange`, `ARMSet`, `ARMGet`, `ARDel`, `ARDelRange`, `ARScan`, `ARSeek`, `ARNext`, `ARLastItems`, `ARGrep`, `ARGrepWithValues`, `ARInfo`/`ARInfoFull`, and typed reducers `AROpSum`/`AROpMin`/`AROpMax`/`AROpAnd`/`AROpOr`/`AROpXor`/`AROpMatch`/`AROpUsed`) for working with Redis 8.8's new array data type. **API is experimental and may change in a future release.** + +([#3813](https://github.com/redis/go-redis/pull/3813)) by [@cxljs](https://github.com/cxljs) + +## ✨ New Features + +- **RESP3 search parser**: First-class RESP3 parsing for `FT.SEARCH`/`FT.AGGREGATE`/`FT.INFO`/`FT.SPELLCHECK`/`FT.SYNDUMP` responses with backwards compatibility for RESP2 ([#3741](https://github.com/redis/go-redis/pull/3741)) by [@ndyakov](https://github.com/ndyakov) +- **INCREX**: New `INCREX` command support — atomic increment with expiration ([#3816](https://github.com/redis/go-redis/pull/3816)) by [@ndyakov](https://github.com/ndyakov) +- **XNACK**: Client support for the `XNACK` stream command for explicitly negative-acknowledging pending entries ([#3790](https://github.com/redis/go-redis/pull/3790)) by [@elena-kolevska](https://github.com/elena-kolevska) +- **TS range multiple aggregators**: `TS.RANGE`/`TS.REVRANGE`/`TS.MRANGE`/`TS.MREVRANGE` now accept multiple aggregators in a single call ([#3791](https://github.com/redis/go-redis/pull/3791)) by [@elena-kolevska](https://github.com/elena-kolevska) +- **`XAutoClaim` deleted IDs**: `XAUTOCLAIM`/`XAUTOCLAIMJUSTID` now return the list of deleted message IDs from the PEL ([#3798](https://github.com/redis/go-redis/pull/3798)) by [@Khukharr](https://github.com/Khukharr) +- **`JSON.SET FPHA`**: `JSON.SET` accepts a new `FPHA` argument that specifies the floating-point type for homogeneous floating-point arrays ([#3797](https://github.com/redis/go-redis/pull/3797)) by [@ndyakov](https://github.com/ndyakov) +- **Sorted-set union/intersection COUNT**: `ZUNION`/`ZINTER`/`ZDIFF` aggregator now supports `COUNT` ([#3802](https://github.com/redis/go-redis/pull/3802)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **`FT.HYBRID` vector validation**: Validates hybrid-search vector input types and adds proper typed vector parameters ([#3756](https://github.com/redis/go-redis/pull/3756)) by [@DengY11](https://github.com/DengY11) +- **Cluster pool wait stats**: `ClusterClient.PoolStats()` now accumulates `WaitCount` and `WaitDurationNs` across all node pools (previously always zero) ([#3809](https://github.com/redis/go-redis/pull/3809)) by [@LINKIWI](https://github.com/LINKIWI) + +## 🐛 Bug Fixes + +- **TLS-only Cluster PubSub**: `CLUSTER SLOTS` port-0 entries now fall back to the origin endpoint's port, fixing `dial tcp :0: connection refused` on TLS-only clusters started with `--port 0 --tls-port ` (fixes [#3726](https://github.com/redis/go-redis/issues/3726)) ([#3828](https://github.com/redis/go-redis/pull/3828)) by [@ndyakov](https://github.com/ndyakov) +- **Sharded PubSub reconnect routing**: `PubSub.conn()` now passes both regular (`c.channels`) and sharded (`c.schannels`) channels into the per-PubSub `newConn` closure. Previously, `ClusterClient.SSubscribe`-only PubSubs reconnected to a random node (because the routing closure saw an empty channel list), the `SSUBSCRIBE` was sent to the wrong shard, and the resulting `MOVED` reply was silently dropped ([#3829](https://github.com/redis/go-redis/pull/3829)) by [@ndyakov](https://github.com/ndyakov) +- **ClusterClient `Watch` retry**: User errors returned from a `Watch` callback are no longer subjected to cluster-retry classification; transient cluster errors still retry, but a callback returning e.g. `net.ErrClosed` short-circuits immediately ([#3821](https://github.com/redis/go-redis/pull/3821)) by [@obiyang](https://github.com/obiyang) +- **Sentinel concurrent-probe leak**: `MasterAddr`'s concurrent sentinel probe now closes the non-winning sentinel clients instead of leaking them ([#3827](https://github.com/redis/go-redis/pull/3827)) by [@cxljs](https://github.com/cxljs) +- **Sentinel rediscovery loop on master-only setups**: `replicaAddrs` no longer tears down the cached sentinel client when the replica list is empty, eliminating a continuous rediscovery loop on master-only Sentinel deployments that flooded logs and added per-operation latency ([#3795](https://github.com/redis/go-redis/pull/3795)) by [@shahyash2609](https://github.com/shahyash2609) +- **Pool `CloseConn` hooks**: `Pool.CloseConn` now triggers registered hooks, fixing a memory leak when connections are closed explicitly rather than via the normal removal path ([#3818](https://github.com/redis/go-redis/pull/3818)) by [@ndyakov](https://github.com/ndyakov) +- **Dial TCP error redirection**: Wrapped `dial tcp` errors are now correctly classified as redirectable so cluster routing can recover from a single unreachable node ([#3810](https://github.com/redis/go-redis/pull/3810)) by [@vladisa88](https://github.com/vladisa88) +- **Pool `Close` health checks**: `ConnPool.Close` now only runs health checks against idle connections, avoiding spurious activity on connections still in use ([#3805](https://github.com/redis/go-redis/pull/3805)) by [@ndyakov](https://github.com/ndyakov) +- **VLinks return type**: Fixed the return type of `VLINKS`/`VLINKSWITHSCORES` vector-set replies ([#3820](https://github.com/redis/go-redis/pull/3820)) by [@romanpovol](https://github.com/romanpovol) + +## 🧪 Testing & Infrastructure + +- **Flaky tests**: Stabilized several flaky tests in the sentinel and pool suites ([#3815](https://github.com/redis/go-redis/pull/3815)) by [@ndyakov](https://github.com/ndyakov) +- **Sentinel failover metric race**: Fixed a data race in the sentinel failover metric test ([#3824](https://github.com/redis/go-redis/pull/3824)) by [@cxljs](https://github.com/cxljs) +- **`waitForSentinelClusterStable` post-conditions**: The sentinel test harness now waits for replicas to be fully connected (not just present in the count) and is robust to randomized spec ordering after failover specs, eliminating an intermittent `Expected master to equal slave` flake ([#3830](https://github.com/redis/go-redis/pull/3830)) by [@ndyakov](https://github.com/ndyakov) +- **`govulncheck` workflow**: New scheduled GitHub Actions workflow runs `govulncheck` on every push, PR, and weekly, surfacing newly disclosed Go vulnerabilities even when no code changes ([#3779](https://github.com/redis/go-redis/pull/3779)) by [@solardome](https://github.com/solardome) +- **CI Redis 8.8-rc1**: CI now exercises the 8.8-rc1 Redis image ([#3814](https://github.com/redis/go-redis/pull/3814)) by [@ofekshenawa](https://github.com/ofekshenawa) + +## 🧰 Maintenance + +- **`Cmd.Slot()` lookup refactor**: Caches the per-command `CommandInfo` and short-circuits keyless commands before the switch dispatch, removing redundant `Peek` calls ([#3804](https://github.com/redis/go-redis/pull/3804)) by [@retr0-kernel](https://github.com/retr0-kernel) +- **stdlib `math/rand`**: Replaced `internal/rand` with `math/rand` from the standard library now that the minimum Go version is 1.24 ([#3823](https://github.com/redis/go-redis/pull/3823)) by [@cxljs](https://github.com/cxljs) +- **ConnPool queue channel**: Removed the unused queue channel from `ConnPool`, trimming the pool's footprint ([#3826](https://github.com/redis/go-redis/pull/3826)) by [@cxljs](https://github.com/cxljs) +- **Extra packages LICENSE**: Added a LICENSE file to each `extra/*` package ([#3817](https://github.com/redis/go-redis/pull/3817)) by [@ndyakov](https://github.com/ndyakov) +- **README & CI image**: Documentation refresh and bumped the default CI image tag ([#3822](https://github.com/redis/go-redis/pull/3822)) by [@ndyakov](https://github.com/ndyakov) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@cxljs](https://github.com/cxljs), [@DengY11](https://github.com/DengY11), [@elena-kolevska](https://github.com/elena-kolevska), [@Khukharr](https://github.com/Khukharr), [@LINKIWI](https://github.com/LINKIWI), [@ndyakov](https://github.com/ndyakov), [@obiyang](https://github.com/obiyang), [@ofekshenawa](https://github.com/ofekshenawa), [@retr0-kernel](https://github.com/retr0-kernel), [@romanpovol](https://github.com/romanpovol), [@shahyash2609](https://github.com/shahyash2609), [@solardome](https://github.com/solardome), [@vladisa88](https://github.com/vladisa88) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.19.0...v9.20.0 + +# 9.19.0 (2026-04-27) + +## 🚀 Highlights + +### FIPS-Compatible Script Helper + +`Script` now supports a FIPS-safe execution mode that avoids client-side SHA-1 computation, which is blocked in strict FIPS environments. A new `NewScriptServerSHA` constructor uses `SCRIPT LOAD` to obtain and cache the digest from the server, then runs commands via `EVALSHA`/`EVALSHA_RO`. Falls back to `EVAL`/`EVALRO` if loading fails, and transparently retries once on `NOSCRIPT`. The default behavior is unchanged for existing users. + +([#3700](https://github.com/redis/go-redis/pull/3700)) by [@chaitanyabodlapati](https://github.com/chaitanyabodlapati) + +### FT.AGGREGATE Step-Based Pipeline Builder + +Added a new step-based `FT.AGGREGATE` pipeline API via `FTAggregateOptions.Steps`, allowing `LOAD`, `APPLY`, `GROUPBY`, and `SORTBY` (with per-step `MAX`) to be repeated and interleaved in arbitrary order — matching Redis's native multi-stage aggregation semantics. The legacy `Load`/`Apply`/`GroupBy`/`SortBy`/`SortByMax` fields are now deprecated. + +([#3782](https://github.com/redis/go-redis/pull/3782)) by [@ndyakov](https://github.com/ndyakov) + +### Raw RESP Protocol Access + +Added `DoRaw` and `DoRawWriteTo` methods for executing arbitrary commands and reading the raw RESP response. Useful for proxying, custom protocol inspection, and working with commands not yet wrapped by go-redis. + +([#3713](https://github.com/redis/go-redis/pull/3713)) by [@ofekshenawa](https://github.com/ofekshenawa) + +### Configurable Dial Retry Backoff + +Added `DialerRetryBackoff` option (plumbed through `Options`, `ClusterOptions`, `RingOptions`, `FailoverOptions`) to let callers customize the delay between failed dial attempts. Helpers `DialRetryBackoffConstant` and `DialRetryBackoffExponential` (with jitter and cap) are provided out of the box. Dial timeout is now also applied **per attempt** rather than across all retries. + +([#3706](https://github.com/redis/go-redis/pull/3706), [#3705](https://github.com/redis/go-redis/pull/3705)) by [@mwhooker](https://github.com/mwhooker) + +## ✨ New Features + +- **FT.AGGREGATE Steps**: Step-based pipeline builder for `FT.AGGREGATE` with support for repeated/interleaved `LOAD`, `APPLY`, `GROUPBY`, and `SORTBY` stages ([#3782](https://github.com/redis/go-redis/pull/3782)) by [@ndyakov](https://github.com/ndyakov) +- **VectorSet commands**: Added `VISMEMBER` and `WITHATTRIBS` support ([#3753](https://github.com/redis/go-redis/pull/3753)) by [@romanpovol](https://github.com/romanpovol) +- **FIPS-safe Script**: `NewScriptServerSHA` uses `SCRIPT LOAD` to obtain the digest from the server, avoiding client-side SHA-1 ([#3700](https://github.com/redis/go-redis/pull/3700)) by [@chaitanyabodlapati](https://github.com/chaitanyabodlapati) +- **Raw RESP access**: `DoRaw` and `DoRawWriteTo` for raw RESP protocol access ([#3713](https://github.com/redis/go-redis/pull/3713)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **Dial retry backoff**: `DialerRetryBackoff` function option with constant and exponential helpers ([#3706](https://github.com/redis/go-redis/pull/3706)) by [@mwhooker](https://github.com/mwhooker) +- **Typed NOSCRIPT error**: Redis `NOSCRIPT` replies are now surfaced as a typed error for easier handling ([#3738](https://github.com/redis/go-redis/pull/3738)) by [@LINKIWI](https://github.com/LINKIWI) +- **PubSub ClientSetName**: Added `ClientSetName` method to `PubSub` ([#3727](https://github.com/redis/go-redis/pull/3727)) by [@Flack74](https://github.com/Flack74) +- **ReplicaOf**: New `ReplicaOf` method replaces the deprecated `SlaveOf` ([#3720](https://github.com/redis/go-redis/pull/3720)) by [@Copilot](https://github.com/apps/copilot-swe-agent) +- **HSCAN BinaryUnmarshaler**: `HScan` now supports types implementing `encoding.BinaryUnmarshaler` ([#3768](https://github.com/redis/go-redis/pull/3768)) by [@Aaditya-dubey1](https://github.com/Aaditya-dubey1) + +## 🐛 Bug Fixes + +- **Auto hostname type detection**: Improved endpoint type detection for maintenance notifications using DNS-based classification; handles empty hosts and expanded private-IP ranges ([#3789](https://github.com/redis/go-redis/pull/3789)) by [@ndyakov](https://github.com/ndyakov) +- **HELLO fallback**: Don't send `CLIENT MAINT_NOTIFICATIONS` handshake when `HELLO` fails and connection falls back to RESP2; fail fast when explicitly enabled with RESP3 ([#3788](https://github.com/redis/go-redis/pull/3788)) by [@ndyakov](https://github.com/ndyakov) +- **Dial TCP retry**: `ShouldRetry` now treats `net.OpError` with `Op == "dial"` timeout errors as safe to retry since no command was sent ([#3787](https://github.com/redis/go-redis/pull/3787)) by [@vladisa88](https://github.com/vladisa88) +- **wrappedOnClose leak**: Fixed resource leak caused by repeatedly wrapping `baseClient` close logic; replaced with a bounded, concurrency-safe named-hook registry ([#3785](https://github.com/redis/go-redis/pull/3785)) by [@ndyakov](https://github.com/ndyakov) +- **Pool Close() on stale connections**: Suppress close errors (e.g., TLS `closeNotify` timeouts) for connections already dropped by the server due to idle timeout ([#3778](https://github.com/redis/go-redis/pull/3778)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **FIFO waiter ordering**: Fixed race in `ConnStateMachine.notifyWaiters` that could wake multiple waiters under a single mutex hold and violate FIFO ordering ([#3777](https://github.com/redis/go-redis/pull/3777)) by [@0x48core](https://github.com/0x48core) +- **Lua READONLY detection**: Detect `READONLY` errors embedded in Lua script error messages on read-only replicas so commands are correctly retried ([#3769](https://github.com/redis/go-redis/pull/3769)) by [@zhengjilei](https://github.com/zhengjilei) +- **VectorScoreSliceCmd RESP2**: Fixed `VSimWithScores`, `VSimWithArgsWithScores`, and `VLinksWithScores` which were broken on RESP2 connections returning flat arrays instead of maps ([#3767](https://github.com/redis/go-redis/pull/3767)) by [@Copilot](https://github.com/apps/copilot-swe-agent) +- **Closed connection handling**: Two fixes for closed connection handling in the pool ([#3764](https://github.com/redis/go-redis/pull/3764)) by [@cxljs](https://github.com/cxljs) +- **ZRangeArgs Rev**: Fixed `ZRangeArgs` with `Rev` + `ByScore`/`ByLex` incorrectly swapping `Start`/`Stop`, breaking `ZRANGESTORE` ([#3751](https://github.com/redis/go-redis/pull/3751)) by [@Copilot](https://github.com/apps/copilot-swe-agent) +- **OTel metric instrument types**: Fixed metric instrument types in `redisotel-native` ([#3743](https://github.com/redis/go-redis/pull/3743)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **Options.clone() data race**: Fixed data race when cloning `Options` ([#3739](https://github.com/redis/go-redis/pull/3739)) by [@rubensayshi](https://github.com/rubensayshi) +- **Connection closure metrics**: Fixed connection closure metrics and enabled all metric groups by default in `redisotel-native` ([#3735](https://github.com/redis/go-redis/pull/3735)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **OTel semconv v1.38.0**: Use metric definition from `otel/semconv/v1.38.0` in `redisotel-native` ([#3731](https://github.com/redis/go-redis/pull/3731)) by [@wzy9607](https://github.com/wzy9607) +- **SETNX semantics**: Use `SET ... NX` instead of the deprecated `SETNX` command ([#3723](https://github.com/redis/go-redis/pull/3723)) by [@ndyakov](https://github.com/ndyakov) +- **TIME keyless routing**: Mark `TIME` as a keyless command for correct cluster routing ([#3722](https://github.com/redis/go-redis/pull/3722)) by [@fatal10110](https://github.com/fatal10110) +- **Dial timeout per retry**: Dial timeout now applies per attempt instead of across all retry attempts combined ([#3705](https://github.com/redis/go-redis/pull/3705)) by [@mwhooker](https://github.com/mwhooker) +- **Cluster metrics attributes**: Fixed `pool.name` being appended per node, which corrupted and dropped user-provided custom attributes ([#3699](https://github.com/redis/go-redis/pull/3699)) by [@Jesse-Bonfire](https://github.com/Jesse-Bonfire) +- **initConn nil dereference**: Fixed nil pointer dereference and potential deadlock in `*baseClient.initConn()`; added explicit nil option guards to client constructors ([#3676](https://github.com/redis/go-redis/pull/3676)) by [@olde-ducke](https://github.com/olde-ducke) + +## ⚡ Performance + +- **RESP reader**: Optimized RESP reader by eliminating intermediate string allocations ([#3774](https://github.com/redis/go-redis/pull/3774)) by [@Aaditya-dubey1](https://github.com/Aaditya-dubey1) +- **Inline rendezvous hashing**: Replaced `github.com/dgryski/go-rendezvous` dependency with an in-repo implementation in `internal/hashtag`, reducing the dependency graph while preserving algorithm parity ([#3762](https://github.com/redis/go-redis/pull/3762)) by [@bigsk05](https://github.com/bigsk05) + +## 🧪 Testing & Infrastructure + +- **Release automation**: Added `repository`, `ref`, and `client-libs-test-image-tag` inputs to the `run-tests` composite action; `redis-version` is now optional so unstable builds use `REDIS_VERSION` from the Makefile ([#3749](https://github.com/redis/go-redis/pull/3749)) by [@dariaguy](https://github.com/dariaguy) +- **Go 1.24**: Updated minimum Go version to 1.24 and use `-compat=1.24` in release scripts ([#3714](https://github.com/redis/go-redis/pull/3714), [#3754](https://github.com/redis/go-redis/pull/3754)) by [@ndyakov](https://github.com/ndyakov), [@cxljs](https://github.com/cxljs) + +## 🧰 Maintenance + +- **Pool state machine**: Removed redundant `Conn.closed` atomic field in favor of the state machine's `StateClosed` ([#3783](https://github.com/redis/go-redis/pull/3783)) by [@cxljs](https://github.com/cxljs) +- **OTel SDK**: Updated OpenTelemetry SDK dependencies in `redisotel`/`redisotel-native` ([#3770](https://github.com/redis/go-redis/pull/3770)) by [@ndyakov](https://github.com/ndyakov) +- **Go 1.21+ built-ins**: Use `maps.Keys`, `slices.Collect`, `slices.Contains`, `clear()`, and `slices.SortFunc` instead of custom helpers ([#3758](https://github.com/redis/go-redis/pull/3758), [#3746](https://github.com/redis/go-redis/pull/3746)) by [@cxljs](https://github.com/cxljs) +- **HGetAll docs**: Added Go doc comment to `HGetAll` describing behavior and complexity ([#3776](https://github.com/redis/go-redis/pull/3776)) by [@0x48core](https://github.com/0x48core) +- **Docs links**: Fixed irrelevant docs links ([#3724](https://github.com/redis/go-redis/pull/3724)) by [@olzhas-sabiyev](https://github.com/olzhas-sabiyev) +- **Examples cleanup**: Removed throughput binary from examples ([#3733](https://github.com/redis/go-redis/pull/3733)) by [@ndyakov](https://github.com/ndyakov) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@0x48core](https://github.com/0x48core), [@Aaditya-dubey1](https://github.com/Aaditya-dubey1), [@Copilot](https://github.com/apps/copilot-swe-agent), [@Flack74](https://github.com/Flack74), [@Jesse-Bonfire](https://github.com/Jesse-Bonfire), [@LINKIWI](https://github.com/LINKIWI), [@bigsk05](https://github.com/bigsk05), [@chaitanyabodlapati](https://github.com/chaitanyabodlapati), [@cxljs](https://github.com/cxljs), [@dariaguy](https://github.com/dariaguy), [@fatal10110](https://github.com/fatal10110), [@mwhooker](https://github.com/mwhooker), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@olde-ducke](https://github.com/olde-ducke), [@olzhas-sabiyev](https://github.com/olzhas-sabiyev), [@romanpovol](https://github.com/romanpovol), [@rubensayshi](https://github.com/rubensayshi), [@vladisa88](https://github.com/vladisa88), [@wzy9607](https://github.com/wzy9607), [@zhengjilei](https://github.com/zhengjilei) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.18.0...v9.19.0 + +# 9.18.0 (2026-02-16) + +## 🚀 Highlights + +### Redis 8.6 Support + +Added support for Redis 8.6, including new commands and features for streams idempotent production and HOTKEYS. + +### Smart Client Handoff (Maintenance Notifications) for Cluster + +This release introduces comprehensive support for Redis Cluster maintenance notifications via SMIGRATING/SMIGRATED push notifications. The client now automatically handles slot migrations by: +- **Relaxing timeouts during migration** (SMIGRATING) to prevent false failures +- **Triggering lazy cluster state reloads** upon completion (SMIGRATED) +- Enabling seamless operations during Redis Enterprise maintenance windows + +([#3643](https://github.com/redis/go-redis/pull/3643)) by [@ndyakov](https://github.com/ndyakov) + +### OpenTelemetry Native Metrics Support + +Added comprehensive OpenTelemetry metrics support following the [OpenTelemetry Database Client Semantic Conventions](https://opentelemetry.io/docs/specs/semconv/database/database-metrics/). The implementation uses a Bridge Pattern to keep the core library dependency-free while providing optional metrics instrumentation through the new `extra/redisotel-native` package. + +**Metric groups include:** +- Command metrics: Operation duration with retry tracking +- Connection basic: Connection count and creation time +- Resiliency: Errors, handoffs, timeout relaxation +- Connection advanced: Wait time and use time +- Pubsub metrics: Published and received messages +- Stream metrics: Processing duration and maintenance notifications + +([#3637](https://github.com/redis/go-redis/pull/3637)) by [@ofekshenawa](https://github.com/ofekshenawa) + +## ✨ New Features + +- **HOTKEYS Commands**: Added support for Redis HOTKEYS feature for identifying hot keys based on CPU consumption and network utilization ([#3695](https://github.com/redis/go-redis/pull/3695)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **Streams Idempotent Production**: Added support for Redis 8.6+ Streams Idempotent Production with `ProducerID`, `IdempotentID`, `IdempotentAuto` in `XAddArgs` and new `XCFGSET` command ([#3693](https://github.com/redis/go-redis/pull/3693)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **NaN Values for TimeSeries**: Added support for NaN (Not a Number) values in Redis time series commands ([#3687](https://github.com/redis/go-redis/pull/3687)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **DialerRetries Options**: Added `DialerRetries` and `DialerRetryTimeout` to `ClusterOptions`, `RingOptions`, and `FailoverOptions` ([#3686](https://github.com/redis/go-redis/pull/3686)) by [@naveenchander30](https://github.com/naveenchander30) +- **ConnMaxLifetimeJitter**: Added jitter configuration to distribute connection expiration times and prevent thundering herd ([#3666](https://github.com/redis/go-redis/pull/3666)) by [@cyningsun](https://github.com/cyningsun) +- **Digest Helper Functions**: Added `DigestString` and `DigestBytes` helper functions for client-side xxh3 hashing compatible with Redis DIGEST command ([#3679](https://github.com/redis/go-redis/pull/3679)) by [@ofekshenawa](https://github.com/ofekshenawa) +- **SMIGRATED New Format**: Updated SMIGRATED parser to support new format and remember original host:port ([#3697](https://github.com/redis/go-redis/pull/3697)) by [@ndyakov](https://github.com/ndyakov) +- **Cluster State Reload Interval**: Added cluster state reload interval option for maintenance notifications ([#3663](https://github.com/redis/go-redis/pull/3663)) by [@ndyakov](https://github.com/ndyakov) + +## 🐛 Bug Fixes + +- **PubSub nil pointer dereference**: Fixed nil pointer dereference in PubSub after `WithTimeout()` - `pubSubPool` is now properly cloned ([#3710](https://github.com/redis/go-redis/pull/3710)) by [@Copilot](https://github.com/apps/copilot-swe-agent) +- **MaintNotificationsConfig nil check**: Guard against nil `MaintNotificationsConfig` in `initConn` ([#3707](https://github.com/redis/go-redis/pull/3707)) by [@veeceey](https://github.com/veeceey) +- **wantConnQueue zombie elements**: Fixed zombie `wantConn` elements accumulation in `wantConnQueue` ([#3680](https://github.com/redis/go-redis/pull/3680)) by [@cyningsun](https://github.com/cyningsun) +- **XADD/XTRIM approx flag**: Fixed XADD and XTRIM to use `=` when approx is false ([#3684](https://github.com/redis/go-redis/pull/3684)) by [@ndyakov](https://github.com/ndyakov) +- **Sentinel timeout retry**: When connection to a sentinel times out, attempt to connect to other sentinels ([#3654](https://github.com/redis/go-redis/pull/3654)) by [@cxljs](https://github.com/cxljs) + +## ⚡ Performance + +- **Fuzz test optimization**: Eliminated repeated string conversions, used functional approach for cleaner operation selection ([#3692](https://github.com/redis/go-redis/pull/3692)) by [@feiguoL](https://github.com/feiguoL) +- **Pre-allocate capacity**: Pre-allocate slice capacity to prevent multiple capacity expansions ([#3689](https://github.com/redis/go-redis/pull/3689)) by [@feelshu](https://github.com/feelshu) + +## 🧪 Testing + +- **Comprehensive TLS tests**: Added comprehensive TLS tests and example for standalone, cluster, and certificate authentication ([#3681](https://github.com/redis/go-redis/pull/3681)) by [@ndyakov](https://github.com/ndyakov) +- **Redis 8.6**: Updated CI to use Redis 8.6-pre ([#3685](https://github.com/redis/go-redis/pull/3685)) by [@ndyakov](https://github.com/ndyakov) + +## 🧰 Maintenance + +- **Deprecation warnings**: Added deprecation warnings for commands based on Redis documentation ([#3673](https://github.com/redis/go-redis/pull/3673)) by [@ndyakov](https://github.com/ndyakov) +- **Use errors.Join()**: Replaced custom error join function with standard library `errors.Join()` ([#3653](https://github.com/redis/go-redis/pull/3653)) by [@cxljs](https://github.com/cxljs) +- **Use Go 1.21 min/max**: Use Go 1.21's built-in min/max functions ([#3656](https://github.com/redis/go-redis/pull/3656)) by [@cxljs](https://github.com/cxljs) +- **Proper formatting**: Code formatting improvements ([#3670](https://github.com/redis/go-redis/pull/3670)) by [@12ya](https://github.com/12ya) +- **Set commands documentation**: Added comprehensive documentation to all set command methods ([#3642](https://github.com/redis/go-redis/pull/3642)) by [@iamamirsalehi](https://github.com/iamamirsalehi) +- **MaxActiveConns docs**: Added default value documentation for `MaxActiveConns` ([#3674](https://github.com/redis/go-redis/pull/3674)) by [@codykaup](https://github.com/codykaup) +- **README example update**: Updated README example ([#3657](https://github.com/redis/go-redis/pull/3657)) by [@cxljs](https://github.com/cxljs) +- **Cluster maintnotif example**: Added example application for cluster maintenance notifications ([#3651](https://github.com/redis/go-redis/pull/3651)) by [@ndyakov](https://github.com/ndyakov) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@12ya](https://github.com/12ya), [@Copilot](https://github.com/apps/copilot-swe-agent), [@codykaup](https://github.com/codykaup), [@cxljs](https://github.com/cxljs), [@cyningsun](https://github.com/cyningsun), [@feelshu](https://github.com/feelshu), [@feiguoL](https://github.com/feiguoL), [@iamamirsalehi](https://github.com/iamamirsalehi), [@naveenchander30](https://github.com/naveenchander30), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@veeceey](https://github.com/veeceey) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.0...v9.18.0 + +# 9.18.0-beta.2 (2025-12-09) + +## 🚀 Highlights + +### Go Version Update + +This release updates the minimum required Go version to 1.21. This is part of a gradual migration strategy where the minimum supported Go version will be three versions behind the latest release. With each new Go version release, we will bump the minimum version by one, ensuring compatibility while staying current with the Go ecosystem. + +### Stability Improvements + +This release includes several important stability fixes: +- Fixed a critical panic in the handoff worker manager that could occur when handling nil errors +- Improved test reliability for Smart Client Handoff functionality +- Fixed logging format issues that could cause runtime errors + +## ✨ New Features + +- OpenTelemetry metrics improvements for nil response handling ([#3638](https://github.com/redis/go-redis/pull/3638)) by [@fengve](https://github.com/fengve) + +## 🐛 Bug Fixes + +- Fixed panic on nil error in handoffWorkerManager closeConnFromRequest ([#3633](https://github.com/redis/go-redis/pull/3633)) by [@ccoVeille](https://github.com/ccoVeille) +- Fixed bad sprintf syntax in logging ([#3632](https://github.com/redis/go-redis/pull/3632)) by [@ccoVeille](https://github.com/ccoVeille) + +## 🧰 Maintenance + +- Updated minimum Go version to 1.21 ([#3640](https://github.com/redis/go-redis/pull/3640)) by [@ndyakov](https://github.com/ndyakov) +- Use Go 1.20 idiomatic string<->byte conversion ([#3435](https://github.com/redis/go-redis/pull/3435)) by [@justinhwang](https://github.com/justinhwang) +- Reduce flakiness of Smart Client Handoff test ([#3641](https://github.com/redis/go-redis/pull/3641)) by [@kiryazovi-redis](https://github.com/kiryazovi-redis) +- Revert PR #3634 (Observability metrics phase1) ([#3635](https://github.com/redis/go-redis/pull/3635)) by [@ofekshenawa](https://github.com/ofekshenawa) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@justinhwang](https://github.com/justinhwang), [@ndyakov](https://github.com/ndyakov), [@kiryazovi-redis](https://github.com/kiryazovi-redis), [@fengve](https://github.com/fengve), [@ccoVeille](https://github.com/ccoVeille), [@ofekshenawa](https://github.com/ofekshenawa) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.18.0-beta.1...v9.18.0-beta.2 + +# 9.18.0-beta.1 (2025-12-01) + +## 🚀 Highlights + +### Request and Response Policy Based Routing in Cluster Mode + +This beta release introduces comprehensive support for Redis COMMAND-based request and response policy routing for cluster clients. This feature enables intelligent command routing and response aggregation based on Redis command metadata. + +**Key Features:** +- **Command Policy Loader**: Automatically parses and caches COMMAND metadata with routing/aggregation hints +- **Enhanced Routing Engine**: Supports all request policies including: + - `default(keyless)` - Commands without keys + - `default(hashslot)` - Commands with hash slot routing + - `all_shards` - Commands that need to run on all shards + - `all_nodes` - Commands that need to run on all nodes + - `multi_shard` - Commands that span multiple shards + - `special` - Commands with custom routing logic +- **Response Aggregator**: Intelligently combines multi-shard replies based on response policies: + - `all_succeeded` - All shards must succeed + - `one_succeeded` - At least one shard must succeed + - `agg_sum` - Aggregate numeric responses + - `special` - Custom aggregation logic (e.g., FT.CURSOR) +- **Raw Command Support**: Policies are enforced on `Client.Do(ctx, args...)` + +This feature is particularly useful for Redis Stack commands like RediSearch that need to operate across multiple shards in a cluster. + +### Connection Pool Improvements + +Fixed a critical defect in the connection pool's turn management mechanism that could lead to connection leaks under certain conditions. The fix ensures proper 1:1 correspondence between turns and connections. + +## ✨ New Features + +- Request and Response Policy Based Routing in Cluster Mode ([#3422](https://github.com/redis/go-redis/pull/3422)) by [@ofekshenawa](https://github.com/ofekshenawa) + +## 🐛 Bug Fixes + +- Fixed connection pool turn management to prevent connection leaks ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun) + +## 🧰 Maintenance + +- chore(deps): bump rojopolis/spellcheck-github-actions from 0.54.0 to 0.55.0 ([#3627](https://github.com/redis/go-redis/pull/3627)) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@cyningsun](https://github.com/cyningsun), [@ofekshenawa](https://github.com/ofekshenawa), [@ndyakov](https://github.com/ndyakov) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.1...v9.18.0-beta.1 + +# 9.17.1 (2025-11-25) + +## 🐛 Bug Fixes + +- add wait to keyless commands list ([#3615](https://github.com/redis/go-redis/pull/3615)) by [@marcoferrer](https://github.com/marcoferrer) +- fix(time): remove cached time optimization ([#3611](https://github.com/redis/go-redis/pull/3611)) by [@ndyakov](https://github.com/ndyakov) + +## 🧰 Maintenance + +- chore(deps): bump golangci/golangci-lint-action from 9.0.0 to 9.1.0 ([#3609](https://github.com/redis/go-redis/pull/3609)) +- chore(deps): bump actions/checkout from 5 to 6 ([#3610](https://github.com/redis/go-redis/pull/3610)) +- chore(script): fix help call in tag.sh ([#3606](https://github.com/redis/go-redis/pull/3606)) by [@ndyakov](https://github.com/ndyakov) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@marcoferrer](https://github.com/marcoferrer) and [@ndyakov](https://github.com/ndyakov) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.0...v9.17.1 + +# 9.17.0 (2025-11-19) + +## 🚀 Highlights + +### Redis 8.4 Support +Added support for Redis 8.4, including new commands and features ([#3572](https://github.com/redis/go-redis/pull/3572)) + +### Typed Errors +Introduced typed errors for better error handling using `errors.As` instead of string checks. Errors can now be wrapped and set to commands in hooks without breaking library functionality ([#3602](https://github.com/redis/go-redis/pull/3602)) + +### New Commands +- **CAS/CAD Commands**: Added support for Compare-And-Set/Compare-And-Delete operations with conditional matching (`IFEQ`, `IFNE`, `IFDEQ`, `IFDNE`) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595)) +- **MSETEX**: Atomically set multiple key-value pairs with expiration options and conditional modes ([#3580](https://github.com/redis/go-redis/pull/3580)) +- **XReadGroup CLAIM**: Consume both incoming and idle pending entries from streams in a single call ([#3578](https://github.com/redis/go-redis/pull/3578)) +- **ACL Commands**: Added `ACLGenPass`, `ACLUsers`, and `ACLWhoAmI` ([#3576](https://github.com/redis/go-redis/pull/3576)) +- **SLOWLOG Commands**: Added `SLOWLOG LEN` and `SLOWLOG RESET` ([#3585](https://github.com/redis/go-redis/pull/3585)) +- **LATENCY Commands**: Added `LATENCY LATEST` and `LATENCY RESET` ([#3584](https://github.com/redis/go-redis/pull/3584)) + +### Search & Vector Improvements +- **Hybrid Search**: Added **EXPERIMENTAL** support for the new `FT.HYBRID` command ([#3573](https://github.com/redis/go-redis/pull/3573)) +- **Vector Range**: Added `VRANGE` command for vector sets ([#3543](https://github.com/redis/go-redis/pull/3543)) +- **FT.INFO Enhancements**: Added vector-specific attributes in FT.INFO response ([#3596](https://github.com/redis/go-redis/pull/3596)) + +### Connection Pool Improvements +- **Improved Connection Success Rate**: Implemented FIFO queue-based fairness and context pattern for connection creation to prevent premature cancellation under high concurrency ([#3518](https://github.com/redis/go-redis/pull/3518)) +- **Connection State Machine**: Resolved race conditions and improved pool performance with proper state tracking ([#3559](https://github.com/redis/go-redis/pull/3559)) +- **Pool Performance**: Significant performance improvements with faster semaphores, lockless hook manager, and reduced allocations (47-67% faster Get/Put operations) ([#3565](https://github.com/redis/go-redis/pull/3565)) + +### Metrics & Observability +- **Canceled Metric Attribute**: Added 'canceled' metrics attribute to distinguish context cancellation errors from other errors ([#3566](https://github.com/redis/go-redis/pull/3566)) + +## ✨ New Features + +- Typed errors with wrapping support ([#3602](https://github.com/redis/go-redis/pull/3602)) by [@ndyakov](https://github.com/ndyakov) +- CAS/CAD commands (marked as experimental) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595)) by [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis) +- MSETEX command support ([#3580](https://github.com/redis/go-redis/pull/3580)) by [@ofekshenawa](https://github.com/ofekshenawa) +- XReadGroup CLAIM argument ([#3578](https://github.com/redis/go-redis/pull/3578)) by [@ofekshenawa](https://github.com/ofekshenawa) +- ACL commands: GenPass, Users, WhoAmI ([#3576](https://github.com/redis/go-redis/pull/3576)) by [@destinyoooo](https://github.com/destinyoooo) +- SLOWLOG commands: LEN, RESET ([#3585](https://github.com/redis/go-redis/pull/3585)) by [@destinyoooo](https://github.com/destinyoooo) +- LATENCY commands: LATEST, RESET ([#3584](https://github.com/redis/go-redis/pull/3584)) by [@destinyoooo](https://github.com/destinyoooo) +- Hybrid search command (FT.HYBRID) ([#3573](https://github.com/redis/go-redis/pull/3573)) by [@htemelski-redis](https://github.com/htemelski-redis) +- Vector range command (VRANGE) ([#3543](https://github.com/redis/go-redis/pull/3543)) by [@cxljs](https://github.com/cxljs) +- Vector-specific attributes in FT.INFO ([#3596](https://github.com/redis/go-redis/pull/3596)) by [@ndyakov](https://github.com/ndyakov) +- Improved connection pool success rate with FIFO queue ([#3518](https://github.com/redis/go-redis/pull/3518)) by [@cyningsun](https://github.com/cyningsun) +- Canceled metrics attribute for context errors ([#3566](https://github.com/redis/go-redis/pull/3566)) by [@pvragov](https://github.com/pvragov) + +## 🐛 Bug Fixes + +- Fixed Failover Client MaintNotificationsConfig ([#3600](https://github.com/redis/go-redis/pull/3600)) by [@ajax16384](https://github.com/ajax16384) +- Fixed ACLGenPass function to use the bit parameter ([#3597](https://github.com/redis/go-redis/pull/3597)) by [@destinyoooo](https://github.com/destinyoooo) +- Return error instead of panic from commands ([#3568](https://github.com/redis/go-redis/pull/3568)) by [@dragneelfps](https://github.com/dragneelfps) +- Safety harness in `joinErrors` to prevent panic ([#3577](https://github.com/redis/go-redis/pull/3577)) by [@manisharma](https://github.com/manisharma) + +## ⚡ Performance + +- Connection state machine with race condition fixes ([#3559](https://github.com/redis/go-redis/pull/3559)) by [@ndyakov](https://github.com/ndyakov) +- Pool performance improvements: 47-67% faster Get/Put, 33% less memory, 50% fewer allocations ([#3565](https://github.com/redis/go-redis/pull/3565)) by [@ndyakov](https://github.com/ndyakov) + +## 🧪 Testing & Infrastructure + +- Updated to Redis 8.4.0 image ([#3603](https://github.com/redis/go-redis/pull/3603)) by [@ndyakov](https://github.com/ndyakov) +- Added Redis 8.4-RC1-pre to CI ([#3572](https://github.com/redis/go-redis/pull/3572)) by [@ndyakov](https://github.com/ndyakov) +- Refactored tests for idiomatic Go ([#3561](https://github.com/redis/go-redis/pull/3561), [#3562](https://github.com/redis/go-redis/pull/3562), [#3563](https://github.com/redis/go-redis/pull/3563)) by [@12ya](https://github.com/12ya) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@12ya](https://github.com/12ya), [@ajax16384](https://github.com/ajax16384), [@cxljs](https://github.com/cxljs), [@cyningsun](https://github.com/cyningsun), [@destinyoooo](https://github.com/destinyoooo), [@dragneelfps](https://github.com/dragneelfps), [@htemelski-redis](https://github.com/htemelski-redis), [@manisharma](https://github.com/manisharma), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@pvragov](https://github.com/pvragov) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.16.0...v9.17.0 + +# 9.16.0 (2025-10-23) + +## 🚀 Highlights + +### Maintenance Notifications Support + +This release introduces comprehensive support for Redis maintenance notifications, enabling applications to handle server maintenance events gracefully. The new `maintnotifications` package provides: + +- **RESP3 Push Notifications**: Full support for Redis RESP3 protocol push notifications +- **Connection Handoff**: Automatic connection migration during server maintenance with configurable retry policies and circuit breakers +- **Graceful Degradation**: Configurable timeout relaxation during maintenance windows to prevent false failures +- **Event-Driven Architecture**: Background workers with on-demand scaling for efficient handoff processing +- **Production-Ready**: Comprehensive E2E testing framework and monitoring capabilities + +For detailed usage examples and configuration options, see the [maintenance notifications documentation](maintnotifications/README.md). + +## ✨ New Features + +- **Trace Filtering**: Add support for filtering traces for specific commands, including pipeline operations and dial operations ([#3519](https://github.com/redis/go-redis/pull/3519), [#3550](https://github.com/redis/go-redis/pull/3550)) + - New `TraceCmdFilter` option to selectively trace commands + - Reduces overhead by excluding high-frequency or low-value commands from traces + +## 🐛 Bug Fixes + +- **Pipeline Error Handling**: Fix issue where pipeline repeatedly sets the same error ([#3525](https://github.com/redis/go-redis/pull/3525)) +- **Connection Pool**: Ensure re-authentication does not interfere with connection handoff operations ([#3547](https://github.com/redis/go-redis/pull/3547)) + +## 🔧 Improvements + +- **Hash Commands**: Update hash command implementations ([#3523](https://github.com/redis/go-redis/pull/3523)) +- **OpenTelemetry**: Use `metric.WithAttributeSet` to avoid unnecessary attribute copying in redisotel ([#3552](https://github.com/redis/go-redis/pull/3552)) + +## 📚 Documentation + +- **Cluster Client**: Add explanation for why `MaxRetries` is disabled for `ClusterClient` ([#3551](https://github.com/redis/go-redis/pull/3551)) + +## 🧪 Testing & Infrastructure + +- **E2E Testing**: Upgrade E2E testing framework with improved reliability and coverage ([#3541](https://github.com/redis/go-redis/pull/3541)) +- **Release Process**: Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530)) + +## 📦 Dependencies + +- Bump `rojopolis/spellcheck-github-actions` from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520)) +- Bump `github/codeql-action` from 3 to 4 ([#3544](https://github.com/redis/go-redis/pull/3544)) + +## 👥 Contributors + +We'd like to thank all the contributors who worked on this release! + +[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@Sovietaced](https://github.com/Sovietaced), [@Udhayarajan](https://github.com/Udhayarajan), [@boekkooi-impossiblecloud](https://github.com/boekkooi-impossiblecloud), [@Pika-Gopher](https://github.com/Pika-Gopher), [@cxljs](https://github.com/cxljs), [@huiyifyj](https://github.com/huiyifyj), [@omid-h70](https://github.com/omid-h70) + +--- + +**Full Changelog**: https://github.com/redis/go-redis/compare/v9.14.0...v9.16.0 + + +# 9.15.0 was accidentally released. Please use version 9.16.0 instead. + +# 9.15.0-beta.3 (2025-09-26) + +## Highlights +This beta release includes a pre-production version of processing push notifications and hitless upgrades. + +# Changes + +- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523)) + +## 🚀 New Features + +- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418)) + +## 🐛 Bug Fixes + +- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525)) + +## 🧰 Maintenance + +- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520)) +- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526)) +- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70) + + +# 9.15.0-beta.1 (2025-09-10) + +## Highlights +This beta release includes a pre-production version of processing push notifications and hitless upgrades. + +### Hitless Upgrades +Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters. +You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless). + +# Changes + +## 🚀 New Features +- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa) + + +# 9.14.0 (2025-09-10) + +## Highlights +- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510)) + +# Changes + +## 🚀 New Features + +- Added batch process method to the pipeline ([#3510](https://github.com/redis/go-redis/pull/3510)) + +## 🐛 Bug Fixes + +- fix: SetErr on Cmd if the command cannot be queued correctly in multi/exec ([#3509](https://github.com/redis/go-redis/pull/3509)) + +## 🧰 Maintenance + +- Updates release drafter config to exclude dependabot ([#3511](https://github.com/redis/go-redis/pull/3511)) +- chore(deps): bump actions/setup-go from 5 to 6 ([#3504](https://github.com/redis/go-redis/pull/3504)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@elena-kolevska](https://github.com/elena-kolevksa), [@htemelski-redis](https://github.com/htemelski-redis) and [@ndyakov](https://github.com/ndyakov) + + +# 9.13.0 (2025-09-03) + +## Highlights +- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496)) +- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470)) +- Fixes on Read and Write buffer sizes and UniversalOptions + +## Changes +- Pipeliner expose queued commands ([#3496](https://github.com/redis/go-redis/pull/3496)) +- fix(test): fix a timing issue in pubsub test ([#3498](https://github.com/redis/go-redis/pull/3498)) +- Allow users to enable read-write splitting in failover mode. ([#3482](https://github.com/redis/go-redis/pull/3482)) +- Set the read/write buffer size of the sentinel client to 4KiB ([#3476](https://github.com/redis/go-redis/pull/3476)) + +## 🚀 New Features + +- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499)) +- Support subscriptions against cluster slave nodes ([#3480](https://github.com/redis/go-redis/pull/3480)) +- Add wait metrics to otel ([#3493](https://github.com/redis/go-redis/pull/3493)) +- Clean failing timeout implementation ([#3472](https://github.com/redis/go-redis/pull/3472)) + +## 🐛 Bug Fixes + +- Do not assume that all non-IP hosts are loopbacks ([#3085](https://github.com/redis/go-redis/pull/3085)) +- Ensure that JSON.GET returns Nil response ([#3470](https://github.com/redis/go-redis/pull/3470)) + +## 🧰 Maintenance + +- fix(otel): register wait metrics ([#3499](https://github.com/redis/go-redis/pull/3499)) +- fix(make test): Add default env in makefile ([#3491](https://github.com/redis/go-redis/pull/3491)) +- Update the introduction to running tests in README.md ([#3495](https://github.com/redis/go-redis/pull/3495)) +- test: Add comprehensive edge case tests for IncrByFloat command ([#3477](https://github.com/redis/go-redis/pull/3477)) +- Set the default read/write buffer size of Redis connection to 32KiB ([#3483](https://github.com/redis/go-redis/pull/3483)) +- Bumps test image to 8.2.1-pre ([#3478](https://github.com/redis/go-redis/pull/3478)) +- fix UniversalOptions miss ReadBufferSize and WriteBufferSize options ([#3485](https://github.com/redis/go-redis/pull/3485)) +- chore(deps): bump actions/checkout from 4 to 5 ([#3484](https://github.com/redis/go-redis/pull/3484)) +- Removes dry run for stale issues policy ([#3471](https://github.com/redis/go-redis/pull/3471)) +- Update otel metrics URL ([#3474](https://github.com/redis/go-redis/pull/3474)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@LINKIWI](https://github.com/LINKIWI), [@cxljs](https://github.com/cxljs), [@cybersmeashish](https://github.com/cybersmeashish), [@elena-kolevska](https://github.com/elena-kolevska), [@htemelski-redis](https://github.com/htemelski-redis), [@mwhooker](https://github.com/mwhooker), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@suever](https://github.com/suever) + + +# 9.12.1 (2025-08-11) +## 🚀 Highlights +In the last version (9.12.0) the client introduced bigger write and read buffer sized. The default value we set was 512KiB. +However, users reported that this is too big for most use cases and can lead to high memory usage. +In this version the default value is changed to 256KiB. The `README.md` was updated to reflect the +correct default value and include a note that the default value can be changed. + +## 🐛 Bug Fixes + +- fix(options): Add buffer sizes to failover. Update README ([#3468](https://github.com/redis/go-redis/pull/3468)) + +## 🧰 Maintenance + +- fix(options): Add buffer sizes to failover. Update README ([#3468](https://github.com/redis/go-redis/pull/3468)) +- chore: update & fix otel example ([#3466](https://github.com/redis/go-redis/pull/3466)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@ndyakov](https://github.com/ndyakov) and [@vmihailenco](https://github.com/vmihailenco) + +# 9.12.0 (2025-08-05) + +## 🚀 Highlights + +- This release includes support for [Redis 8.2](https://redis.io/docs/latest/operate/oss_and_stack/stack-with-enterprise/release-notes/redisce/redisos-8.2-release-notes/). +- Introduces an experimental Query Builders for `FTSearch`, `FTAggregate` and other search commands. +- Adds support for `EPSILON` option in `FT.VSIM`. +- Includes bug fixes and improvements contributed by the community related to ring and [redisotel](https://github.com/redis/go-redis/tree/master/extra/redisotel). + +## Changes +- Improve stale issue workflow ([#3458](https://github.com/redis/go-redis/pull/3458)) +- chore(ci): Add 8.2 rc2 pre build for CI ([#3459](https://github.com/redis/go-redis/pull/3459)) +- Added new stream commands ([#3450](https://github.com/redis/go-redis/pull/3450)) +- feat: Add "skip_verify" to Sentinel ([#3428](https://github.com/redis/go-redis/pull/3428)) +- fix: `errors.Join` requires Go 1.20 or later ([#3442](https://github.com/redis/go-redis/pull/3442)) +- DOC-4344 document quickstart examples ([#3426](https://github.com/redis/go-redis/pull/3426)) +- feat(bitop): add support for the new bitop operations ([#3409](https://github.com/redis/go-redis/pull/3409)) + +## 🚀 New Features + +- feat: recover addIdleConn may occur panic ([#2445](https://github.com/redis/go-redis/pull/2445)) +- feat(ring): specify custom health check func via HeartbeatFn option ([#2940](https://github.com/redis/go-redis/pull/2940)) +- Add Query Builder for RediSearch commands ([#3436](https://github.com/redis/go-redis/pull/3436)) +- add configurable buffer sizes for Redis connections ([#3453](https://github.com/redis/go-redis/pull/3453)) +- Add VAMANA vector type to RediSearch ([#3449](https://github.com/redis/go-redis/pull/3449)) +- VSIM add `EPSILON` option ([#3454](https://github.com/redis/go-redis/pull/3454)) +- Add closing support to otel metrics instrumentation ([#3444](https://github.com/redis/go-redis/pull/3444)) + +## 🐛 Bug Fixes + +- fix(redisotel): fix buggy append in reportPoolStats ([#3122](https://github.com/redis/go-redis/pull/3122)) +- fix(search): return results even if doc is empty ([#3457](https://github.com/redis/go-redis/pull/3457)) +- [ISSUE-3402]: Ring.Pipelined return dial timeout error ([#3403](https://github.com/redis/go-redis/pull/3403)) + +## 🧰 Maintenance + +- Merges stale issues jobs into one job with two steps ([#3463](https://github.com/redis/go-redis/pull/3463)) +- improve code readability ([#3446](https://github.com/redis/go-redis/pull/3446)) +- chore(release): 9.12.0-beta.1 ([#3460](https://github.com/redis/go-redis/pull/3460)) +- DOC-5472 time series doc examples ([#3443](https://github.com/redis/go-redis/pull/3443)) +- Add VAMANA compression algorithm tests ([#3461](https://github.com/redis/go-redis/pull/3461)) +- bumped redis 8.2 version used in the CI/CD ([#3451](https://github.com/redis/go-redis/pull/3451)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@andy-stark-redis](https://github.com/andy-stark-redis), [@cxljs](https://github.com/cxljs), [@elena-kolevska](https://github.com/elena-kolevska), [@htemelski-redis](https://github.com/htemelski-redis), [@jouir](https://github.com/jouir), [@monkey92t](https://github.com/monkey92t), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@rokn](https://github.com/rokn), [@smnvdev](https://github.com/smnvdev), [@strobil](https://github.com/strobil) and [@wzy9607](https://github.com/wzy9607) + +## New Contributors +* [@htemelski-redis](https://github.com/htemelski-redis) made their first contribution in [#3409](https://github.com/redis/go-redis/pull/3409) +* [@smnvdev](https://github.com/smnvdev) made their first contribution in [#3403](https://github.com/redis/go-redis/pull/3403) +* [@rokn](https://github.com/rokn) made their first contribution in [#3444](https://github.com/redis/go-redis/pull/3444) + +# 9.11.0 (2025-06-24) + +## 🚀 Highlights + +Fixes TxPipeline to work correctly in cluster scenarios, allowing execution of commands +only in the same slot. + +# Changes + +## 🚀 New Features + +- Set cluster slot for `scan` commands, rather than random ([#2623](https://github.com/redis/go-redis/pull/2623)) +- Add CredentialsProvider field to UniversalOptions ([#2927](https://github.com/redis/go-redis/pull/2927)) +- feat(redisotel): add WithCallerEnabled option ([#3415](https://github.com/redis/go-redis/pull/3415)) + +## 🐛 Bug Fixes + +- fix(txpipeline): keyless commands should take the slot of the keyed ([#3411](https://github.com/redis/go-redis/pull/3411)) +- fix(loading): cache the loaded flag for slave nodes ([#3410](https://github.com/redis/go-redis/pull/3410)) +- fix(txpipeline): should return error on multi/exec on multiple slots ([#3408](https://github.com/redis/go-redis/pull/3408)) +- fix: check if the shard exists to avoid returning nil ([#3396](https://github.com/redis/go-redis/pull/3396)) + +## 🧰 Maintenance + +- feat: optimize connection pool waitTurn ([#3412](https://github.com/redis/go-redis/pull/3412)) +- chore(ci): update CI redis builds ([#3407](https://github.com/redis/go-redis/pull/3407)) +- chore: remove a redundant method from `Ring`, `Client` and `ClusterClient` ([#3401](https://github.com/redis/go-redis/pull/3401)) +- test: refactor TestBasicCredentials using table-driven tests ([#3406](https://github.com/redis/go-redis/pull/3406)) +- perf: reduce unnecessary memory allocation operations ([#3399](https://github.com/redis/go-redis/pull/3399)) +- fix: insert entry during iterating over a map ([#3398](https://github.com/redis/go-redis/pull/3398)) +- DOC-5229 probabilistic data type examples ([#3413](https://github.com/redis/go-redis/pull/3413)) +- chore(deps): bump rojopolis/spellcheck-github-actions from 0.49.0 to 0.51.0 ([#3414](https://github.com/redis/go-redis/pull/3414)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@andy-stark-redis](https://github.com/andy-stark-redis), [@boekkooi-impossiblecloud](https://github.com/boekkooi-impossiblecloud), [@cxljs](https://github.com/cxljs), [@dcherubini](https://github.com/dcherubini), [@dependabot[bot]](https://github.com/apps/dependabot), [@iamamirsalehi](https://github.com/iamamirsalehi), [@ndyakov](https://github.com/ndyakov), [@pete-woods](https://github.com/pete-woods), [@twz915](https://github.com/twz915) and [dependabot[bot]](https://github.com/apps/dependabot) + +# 9.10.0 (2025-06-06) + +## 🚀 Highlights + +`go-redis` now supports [vector sets](https://redis.io/docs/latest/develop/data-types/vector-sets/). This data type is marked +as "in preview" in Redis and its support in `go-redis` is marked as experimental. You can find examples in the documentation and +in the `doctests` folder. + +# Changes + +## 🚀 New Features + +- feat: support vectorset ([#3375](https://github.com/redis/go-redis/pull/3375)) + +## 🧰 Maintenance + +- Add the missing NewFloatSliceResult for testing ([#3393](https://github.com/redis/go-redis/pull/3393)) +- DOC-5078 vector set examples ([#3394](https://github.com/redis/go-redis/pull/3394)) + +## Contributors +We'd like to thank all the contributors who worked on this release! + +[@AndBobsYourUncle](https://github.com/AndBobsYourUncle), [@andy-stark-redis](https://github.com/andy-stark-redis), [@fukua95](https://github.com/fukua95) and [@ndyakov](https://github.com/ndyakov) + + + +# 9.9.0 (2025-05-27) + +## 🚀 Highlights +- **Token-based Authentication**: Added `StreamingCredentialsProvider` for dynamic credential updates (experimental) + - Can be used with [go-redis-entraid](https://github.com/redis/go-redis-entraid) for Azure AD authentication +- **Connection Statistics**: Added connection waiting statistics for better monitoring +- **Failover Improvements**: Added `ParseFailoverURL` for easier failover configuration +- **Ring Client Enhancements**: Added shard access methods for better Pub/Sub management + +## ✨ New Features +- Added `StreamingCredentialsProvider` for token-based authentication ([#3320](https://github.com/redis/go-redis/pull/3320)) + - Supports dynamic credential updates + - Includes connection close hooks + - Note: Currently marked as experimental +- Added `ParseFailoverURL` for parsing failover URLs ([#3362](https://github.com/redis/go-redis/pull/3362)) +- Added connection waiting statistics ([#2804](https://github.com/redis/go-redis/pull/2804)) +- Added new utility functions: + - `ParseFloat` and `MustParseFloat` in public utils package ([#3371](https://github.com/redis/go-redis/pull/3371)) + - Unit tests for `Atoi`, `ParseInt`, `ParseUint`, and `ParseFloat` ([#3377](https://github.com/redis/go-redis/pull/3377)) +- Added Ring client shard access methods: + - `GetShardClients()` to retrieve all active shard clients + - `GetShardClientForKey(key string)` to get the shard client for a specific key ([#3388](https://github.com/redis/go-redis/pull/3388)) + +## 🐛 Bug Fixes +- Fixed routing reads to loading slave nodes ([#3370](https://github.com/redis/go-redis/pull/3370)) +- Added support for nil lag in XINFO GROUPS ([#3369](https://github.com/redis/go-redis/pull/3369)) +- Fixed pool acquisition timeout issues ([#3381](https://github.com/redis/go-redis/pull/3381)) +- Optimized unnecessary copy operations ([#3376](https://github.com/redis/go-redis/pull/3376)) + +## 📚 Documentation +- Updated documentation for XINFO GROUPS with nil lag support ([#3369](https://github.com/redis/go-redis/pull/3369)) +- Added package-level comments for new features + +## ⚡ Performance and Reliability +- Optimized `ReplaceSpaces` function ([#3383](https://github.com/redis/go-redis/pull/3383)) +- Set default value for `Options.Protocol` in `init()` ([#3387](https://github.com/redis/go-redis/pull/3387)) +- Exported pool errors for public consumption ([#3380](https://github.com/redis/go-redis/pull/3380)) + +## 🔧 Dependencies and Infrastructure +- Updated Redis CI to version 8.0.1 ([#3372](https://github.com/redis/go-redis/pull/3372)) +- Updated spellcheck GitHub Actions ([#3389](https://github.com/redis/go-redis/pull/3389)) +- Removed unused parameters ([#3382](https://github.com/redis/go-redis/pull/3382), [#3384](https://github.com/redis/go-redis/pull/3384)) + +## 🧪 Testing +- Added unit tests for pool acquisition timeout ([#3381](https://github.com/redis/go-redis/pull/3381)) +- Added unit tests for utility functions ([#3377](https://github.com/redis/go-redis/pull/3377)) + +## 👥 Contributors + +We would like to thank all the contributors who made this release possible: + +[@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@LINKIWI](https://github.com/LINKIWI), [@iamamirsalehi](https://github.com/iamamirsalehi), [@fukua95](https://github.com/fukua95), [@lzakharov](https://github.com/lzakharov), [@DengY11](https://github.com/DengY11) + +## 📝 Changelog + +For a complete list of changes, see the [full changelog](https://github.com/redis/go-redis/compare/v9.8.0...v9.9.0). + # 9.8.0 (2025-04-30) ## 🚀 Highlights @@ -78,3 +925,139 @@ We would like to thank all the contributors who made this release possible: [@alexander-menshchikov](https://github.com/alexander-menshchikov), [@EXPEbdodla](https://github.com/EXPEbdodla), [@afti](https://github.com/afti), [@dmaier-redislabs](https://github.com/dmaier-redislabs), [@four_leaf_clover](https://github.com/four_leaf_clover), [@alohaglenn](https://github.com/alohaglenn), [@gh73962](https://github.com/gh73962), [@justinmir](https://github.com/justinmir), [@LINKIWI](https://github.com/LINKIWI), [@liushuangbill](https://github.com/liushuangbill), [@golang88](https://github.com/golang88), [@gnpaone](https://github.com/gnpaone), [@ndyakov](https://github.com/ndyakov), [@nikolaydubina](https://github.com/nikolaydubina), [@oleglacto](https://github.com/oleglacto), [@andy-stark-redis](https://github.com/andy-stark-redis), [@rodneyosodo](https://github.com/rodneyosodo), [@dependabot](https://github.com/dependabot), [@rfyiamcool](https://github.com/rfyiamcool), [@frankxjkuang](https://github.com/frankxjkuang), [@fukua95](https://github.com/fukua95), [@soleymani-milad](https://github.com/soleymani-milad), [@ofekshenawa](https://github.com/ofekshenawa), [@khasanovbi](https://github.com/khasanovbi) + + +# Old Changelog +## Unreleased + +### Changed + +* `go-redis` won't skip span creation if the parent spans is not recording. ([#2980](https://github.com/redis/go-redis/issues/2980)) + Users can use the OpenTelemetry sampler to control the sampling behavior. + For instance, you can use the `ParentBased(NeverSample())` sampler from `go.opentelemetry.io/otel/sdk/trace` to keep + a similar behavior (drop orphan spans) of `go-redis` as before. + +## [9.0.5](https://github.com/redis/go-redis/compare/v9.0.4...v9.0.5) (2023-05-29) + + +### Features + +* Add ACL LOG ([#2536](https://github.com/redis/go-redis/issues/2536)) ([31ba855](https://github.com/redis/go-redis/commit/31ba855ddebc38fbcc69a75d9d4fb769417cf602)) +* add field protocol to setupClusterQueryParams ([#2600](https://github.com/redis/go-redis/issues/2600)) ([840c25c](https://github.com/redis/go-redis/commit/840c25cb6f320501886a82a5e75f47b491e46fbe)) +* add protocol option ([#2598](https://github.com/redis/go-redis/issues/2598)) ([3917988](https://github.com/redis/go-redis/commit/391798880cfb915c4660f6c3ba63e0c1a459e2af)) + + + +## [9.0.4](https://github.com/redis/go-redis/compare/v9.0.3...v9.0.4) (2023-05-01) + + +### Bug Fixes + +* reader float parser ([#2513](https://github.com/redis/go-redis/issues/2513)) ([46f2450](https://github.com/redis/go-redis/commit/46f245075e6e3a8bd8471f9ca67ea95fd675e241)) + + +### Features + +* add client info command ([#2483](https://github.com/redis/go-redis/issues/2483)) ([b8c7317](https://github.com/redis/go-redis/commit/b8c7317cc6af444603731f7017c602347c0ba61e)) +* no longer verify HELLO error messages ([#2515](https://github.com/redis/go-redis/issues/2515)) ([7b4f217](https://github.com/redis/go-redis/commit/7b4f2179cb5dba3d3c6b0c6f10db52b837c912c8)) +* read the structure to increase the judgment of the omitempty op… ([#2529](https://github.com/redis/go-redis/issues/2529)) ([37c057b](https://github.com/redis/go-redis/commit/37c057b8e597c5e8a0e372337f6a8ad27f6030af)) + + + +## [9.0.3](https://github.com/redis/go-redis/compare/v9.0.2...v9.0.3) (2023-04-02) + +### New Features + +- feat(scan): scan time.Time sets the default decoding (#2413) +- Add support for CLUSTER LINKS command (#2504) +- Add support for acl dryrun command (#2502) +- Add support for COMMAND GETKEYS & COMMAND GETKEYSANDFLAGS (#2500) +- Add support for LCS Command (#2480) +- Add support for BZMPOP (#2456) +- Adding support for ZMPOP command (#2408) +- Add support for LMPOP (#2440) +- feat: remove pool unused fields (#2438) +- Expiretime and PExpireTime (#2426) +- Implement `FUNCTION` group of commands (#2475) +- feat(zadd): add ZAddLT and ZAddGT (#2429) +- Add: Support for COMMAND LIST command (#2491) +- Add support for BLMPOP (#2442) +- feat: check pipeline.Do to prevent confusion with Exec (#2517) +- Function stats, function kill, fcall and fcall_ro (#2486) +- feat: Add support for CLUSTER SHARDS command (#2507) +- feat(cmd): support for adding byte,bit parameters to the bitpos command (#2498) + +### Fixed + +- fix: eval api cmd.SetFirstKeyPos (#2501) +- fix: limit the number of connections created (#2441) +- fixed #2462 v9 continue support dragonfly, it's Hello command return "NOAUTH Authentication required" error (#2479) +- Fix for internal/hscan/structmap.go:89:23: undefined: reflect.Pointer (#2458) +- fix: group lag can be null (#2448) + +### Maintenance + +- Updating to the latest version of redis (#2508) +- Allowing for running tests on a port other than the fixed 6380 (#2466) +- redis 7.0.8 in tests (#2450) +- docs: Update redisotel example for v9 (#2425) +- chore: update go mod, Upgrade golang.org/x/net version to 0.7.0 (#2476) +- chore: add Chinese translation (#2436) +- chore(deps): bump github.com/bsm/gomega from 1.20.0 to 1.26.0 (#2421) +- chore(deps): bump github.com/bsm/ginkgo/v2 from 2.5.0 to 2.7.0 (#2420) +- chore(deps): bump actions/setup-go from 3 to 4 (#2495) +- docs: add instructions for the HSet api (#2503) +- docs: add reading lag field comment (#2451) +- test: update go mod before testing(go mod tidy) (#2423) +- docs: fix comment typo (#2505) +- test: remove testify (#2463) +- refactor: change ListElementCmd to KeyValuesCmd. (#2443) +- fix(appendArg): appendArg case special type (#2489) + +## [9.0.2](https://github.com/redis/go-redis/compare/v9.0.1...v9.0.2) (2023-02-01) + +### Features + +* upgrade OpenTelemetry, use the new metrics API. ([#2410](https://github.com/redis/go-redis/issues/2410)) ([e29e42c](https://github.com/redis/go-redis/commit/e29e42cde2755ab910d04185025dc43ce6f59c65)) + +## v9 2023-01-30 + +### Breaking + +- Changed Pipelines to not be thread-safe any more. + +### Added + +- Added support for [RESP3](https://github.com/antirez/RESP3/blob/master/spec.md) protocol. It was + contributed by @monkey92t who has done the majority of work in this release. +- Added `ContextTimeoutEnabled` option that controls whether the client respects context timeouts + and deadlines. See + [Redis Timeouts](https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts) for details. +- Added `ParseClusterURL` to parse URLs into `ClusterOptions`, for example, + `redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791`. +- Added metrics instrumentation using `redisotel.IstrumentMetrics`. See + [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html) +- Added `redis.HasErrorPrefix` to help working with errors. + +### Changed + +- Removed asynchronous cancellation based on the context timeout. It was racy in v8 and is + completely gone in v9. +- Reworked hook interface and added `DialHook`. +- Replaced `redisotel.NewTracingHook` with `redisotel.InstrumentTracing`. See + [example](example/otel) and + [documentation](https://redis.uptrace.dev/guide/go-redis-monitoring.html). +- Replaced `*redis.Z` with `redis.Z` since it is small enough to be passed as value without making + an allocation. +- Renamed the option `MaxConnAge` to `ConnMaxLifetime`. +- Renamed the option `IdleTimeout` to `ConnMaxIdleTime`. +- Removed connection reaper in favor of `MaxIdleConns`. +- Removed `WithContext` since `context.Context` can be passed directly as an arg. +- Removed `Pipeline.Close` since there is no real need to explicitly manage pipeline resources and + it can be safely reused via `sync.Pool` etc. `Pipeline.Discard` is still available if you want to + reset commands for some reason. + +### Fixed + +- Improved and fixed pipeline retries. +- As usually, added support for more commands and fixed some bugs. diff --git a/vendor/github.com/redis/go-redis/v9/RELEASING.md b/vendor/github.com/redis/go-redis/v9/RELEASING.md index 1115db4e3..033ec100b 100644 --- a/vendor/github.com/redis/go-redis/v9/RELEASING.md +++ b/vendor/github.com/redis/go-redis/v9/RELEASING.md @@ -1,15 +1,146 @@ # Releasing -1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: +This document is the runbook for cutting a go-redis release. It is intended +for maintainers with write/tag access to the repository. + +For the format and style of the release notes themselves, see +[.github/RELEASE_NOTES_TEMPLATE.md](./.github/RELEASE_NOTES_TEMPLATE.md). + +## Versioning + +go-redis follows [Semantic Versioning](https://semver.org/): + +- **Patch** (`vX.Y.Z+1`) — bug fixes, no API changes. +- **Minor** (`vX.Y+1.0`) — backwards-compatible new features, deprecations. +- **Major** (`vX+1.0.0`) — breaking changes. Coordinate with the team first. + +Pre-releases use `vX.Y.Z-beta.N` / `vX.Y.Z-rc.N`. + +## Pre-release checklist + +- [ ] Target branch is `master` and CI is green on the latest commit. +- [ ] All PRs intended for this release are merged. +- [ ] There are no open issues in the release milestone (if used). +- [ ] `CHANGELOG` / release notes have been considered; dependabot-only + and doc-only changes are excluded per the template. +- [ ] Confirm the next version number and decide if it's a patch / minor / major. + +## 1. Draft the release notes + +1. Open the draft release auto-generated by + [release-drafter](.github/release-drafter-config.yml) on GitHub. +2. Prepend a new section to [`RELEASE-NOTES.md`](./RELEASE-NOTES.md) using + [`.github/RELEASE_NOTES_TEMPLATE.md`](./.github/RELEASE_NOTES_TEMPLATE.md) + as the format. Keep the file in chronological order (newest first). +3. Pick 3–5 **Highlights** — the most user-facing, impactful changes. +4. Remove dependabot bumps and doc-only typo fixes from the lists. +5. Verify every PR has a contributor attribution and link. +6. Open a PR with just the release-notes change if you want review before + bumping versions, otherwise include it in the release PR below. + +## 2. Bump versions and open the release PR + +Create a release branch from `master`: + +```shell +git checkout master && git pull --ff-only +git checkout -b release/vX.Y.Z +``` + +Run the release script on that branch: ```shell -TAG=v1.0.0 ./scripts/release.sh +TAG=vX.Y.Z ./scripts/release.sh ``` -2. Open a pull request and wait for the build to finish. +What the script does (and explicitly does **not** do): -3. Merge the pull request and run `tag.sh` to create tags for packages: +- ✅ Validates `TAG` matches the semver regex and isn't already a git tag. +- ✅ Rewrites every `redis/go-redis*` line in every sub-module `go.mod` to + point at the new `TAG`. Trailing `// indirect` markers are preserved. +- ✅ Runs `go mod tidy -compat=1.24` in each sub-module. +- ✅ Updates the return value in [`version.go`](./version.go). +- ❌ Does **not** switch branches (runs in your current branch). +- ❌ Does **not** require a clean working tree (so you can mix it with + release-notes edits in the same branch). +- ❌ Does **not** commit, tag, or push anything. + +Review and commit the changes yourself: ```shell -TAG=v1.0.0 ./scripts/tag.sh +git diff # sanity-check the bumps +git add -u +git commit -m "chore: release vX.Y.Z" +git push origin release/vX.Y.Z ``` + +Then on GitHub: + +- [ ] Open a PR from `release/vX.Y.Z` into `master`. +- [ ] Wait for all required CI checks (build, golangci-lint, spellcheck, + doctests, e2e where applicable) to pass. +- [ ] Get at least one maintainer approval. +- [ ] Merge the PR (use a merge commit — the tag will point at the merge SHA). + +## 3. Tag the release + +After the release PR is merged, pull the latest `master` and dry-run the +tagger: + +```shell +git checkout master && git pull --ff-only +TAG=vX.Y.Z ./scripts/tag.sh vX.Y.Z +``` + +The script defaults to **dry-run** and prints the commands it would run. +Verify the output, then apply for real with `-t`: + +```shell +./scripts/tag.sh vX.Y.Z -t +``` + +This creates and pushes: +- The top-level tag `vX.Y.Z`. +- A per-module tag `/vX.Y.Z` for each public sub-module + (skipping `example/*` and `internal/*`). + +## 4. Publish the GitHub release + +1. On GitHub, open the draft release created by release-drafter. +2. Set the tag to `vX.Y.Z` and the target to `master`. +3. Replace the auto-generated body with the curated notes from + `RELEASE-NOTES.md` for this version. +4. For pre-releases, check **"Set as a pre-release"**. +5. Publish. + +## 5. Post-release + +- [ ] Verify the release appears on + [pkg.go.dev](https://pkg.go.dev/github.com/redis/go-redis/v9) within + a few minutes (trigger a fetch by visiting the version URL if needed). +- [ ] Announce on Discord (see the link in `CONTRIBUTING.md`). +- [ ] Close the release milestone if one was used. +- [ ] Open follow-up issues for anything deferred from this release. + +## Hotfix / patch release + +For an urgent fix on top of the latest release: + +1. Branch from the latest release tag: `git checkout -b hotfix/vX.Y.Z+1 vX.Y.Z`. +2. Cherry-pick (or re-apply) only the required fix commits. +3. Follow the normal release flow above with `TAG=vX.Y.Z+1`. +4. Make sure the fix is also present on `master` (forward-port if necessary). + +## Troubleshooting + +- **`release.sh` fails with "tag already exists"** — the tag has already + been created. Pick the next version, or delete the local tag first if + it was created by mistake. +- **`tag.sh` reports version mismatch in a `go.mod`** — a sub-module was + not updated by `release.sh`. Fix the `go.mod` manually (or re-run + `release.sh`), amend the release PR, and re-run the tagger. +- **`version.go` does not contain the tag** — `release.sh` did not run or + the bump was reverted. Re-run `release.sh` on the release branch. +- **pkg.go.dev does not show the new version** — visit + `https://pkg.go.dev/github.com/redis/go-redis/v9@vX.Y.Z` once to trigger + a fetch from the module proxy. diff --git a/vendor/github.com/redis/go-redis/v9/acl_commands.go b/vendor/github.com/redis/go-redis/v9/acl_commands.go index 9cb800bb3..0a8a195ce 100644 --- a/vendor/github.com/redis/go-redis/v9/acl_commands.go +++ b/vendor/github.com/redis/go-redis/v9/acl_commands.go @@ -8,8 +8,12 @@ type ACLCmdable interface { ACLLog(ctx context.Context, count int64) *ACLLogCmd ACLLogReset(ctx context.Context) *StatusCmd + ACLGenPass(ctx context.Context, bit int) *StringCmd + ACLSetUser(ctx context.Context, username string, rules ...string) *StatusCmd ACLDelUser(ctx context.Context, username string) *IntCmd + ACLUsers(ctx context.Context) *StringSliceCmd + ACLWhoAmI(ctx context.Context) *StringCmd ACLList(ctx context.Context) *StringSliceCmd ACLCat(ctx context.Context) *StringSliceCmd @@ -65,6 +69,29 @@ func (c cmdable) ACLSetUser(ctx context.Context, username string, rules ...strin return cmd } +func (c cmdable) ACLGenPass(ctx context.Context, bit int) *StringCmd { + args := make([]interface{}, 0, 3) + args = append(args, "acl", "genpass") + if bit > 0 { + args = append(args, bit) + } + cmd := NewStringCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) ACLUsers(ctx context.Context) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "acl", "users") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) ACLWhoAmI(ctx context.Context) *StringCmd { + cmd := NewStringCmd(ctx, "acl", "whoami") + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) ACLList(ctx context.Context) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "acl", "list") _ = c(ctx, cmd) diff --git a/vendor/github.com/redis/go-redis/v9/adapters.go b/vendor/github.com/redis/go-redis/v9/adapters.go new file mode 100644 index 000000000..952a4c266 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/adapters.go @@ -0,0 +1,118 @@ +package redis + +import ( + "context" + "errors" + "net" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/push" +) + +// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand. +var ErrInvalidCommand = errors.New("invalid command type") + +// ErrInvalidPool is returned when the pool type is not supported. +var ErrInvalidPool = errors.New("invalid pool type") + +// newClientAdapter creates a new client adapter for regular Redis clients. +func newClientAdapter(client *baseClient) interfaces.ClientInterface { + return &clientAdapter{client: client} +} + +// clientAdapter adapts a Redis client to implement interfaces.ClientInterface. +type clientAdapter struct { + client *baseClient +} + +// GetOptions returns the client options. +func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface { + return &optionsAdapter{options: ca.client.opt} +} + +// GetPushProcessor returns the client's push notification processor. +func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor { + return &pushProcessorAdapter{processor: ca.client.pushProcessor} +} + +// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface. +type optionsAdapter struct { + options *Options +} + +// GetReadTimeout returns the read timeout. +func (oa *optionsAdapter) GetReadTimeout() time.Duration { + return oa.options.ReadTimeout +} + +// GetWriteTimeout returns the write timeout. +func (oa *optionsAdapter) GetWriteTimeout() time.Duration { + return oa.options.WriteTimeout +} + +// GetNetwork returns the network type. +func (oa *optionsAdapter) GetNetwork() string { + return oa.options.Network +} + +// GetAddr returns the connection address. +func (oa *optionsAdapter) GetAddr() string { + return oa.options.Addr +} + +// GetNodeAddress returns the address of the Redis node as reported by the server. +// For cluster clients, this is the endpoint from CLUSTER SLOTS before any transformation. +// For standalone clients, this defaults to Addr. +func (oa *optionsAdapter) GetNodeAddress() string { + return oa.options.NodeAddress +} + +// IsTLSEnabled returns true if TLS is enabled. +func (oa *optionsAdapter) IsTLSEnabled() bool { + return oa.options.TLSConfig != nil +} + +// GetProtocol returns the protocol version. +func (oa *optionsAdapter) GetProtocol() int { + return oa.options.Protocol +} + +// GetPoolSize returns the connection pool size. +func (oa *optionsAdapter) GetPoolSize() int { + return oa.options.PoolSize +} + +// NewDialer returns a new dialer function for the connection. +func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) { + baseDialer := oa.options.NewDialer() + return func(ctx context.Context) (net.Conn, error) { + // Extract network and address from the options + network := oa.options.Network + addr := oa.options.Addr + return baseDialer(ctx, network, addr) + } +} + +// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor. +type pushProcessorAdapter struct { + processor push.NotificationProcessor +} + +// RegisterHandler registers a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error { + if pushHandler, ok := handler.(push.NotificationHandler); ok { + return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected) + } + return errors.New("handler must implement push.NotificationHandler") +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error { + return ppa.processor.UnregisterHandler(pushNotificationName) +} + +// GetHandler returns the handler for a specific push notification name. +func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} { + return ppa.processor.GetHandler(pushNotificationName) +} diff --git a/vendor/github.com/redis/go-redis/v9/array_commands.go b/vendor/github.com/redis/go-redis/v9/array_commands.go new file mode 100644 index 000000000..71037dfd6 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/array_commands.go @@ -0,0 +1,387 @@ +package redis + +import ( + "context" +) + +// note: the APIs is experimental and may be subject to change. +// +// ArrayCmdable defines the interface for Redis Array data structure commands +// available in Redis 8.8.0+. +// +// Redis array supports index range [0, math.MaxUint64-1), so index parameters use uint64. +type ArrayCmdable interface { + ARSet(ctx context.Context, key string, index uint64, values ...string) *IntCmd + ARGet(ctx context.Context, key string, index uint64) *StringCmd + ARGetRange(ctx context.Context, key string, start, end uint64) *SliceCmd + ARMGet(ctx context.Context, key string, indexes ...uint64) *SliceCmd + ARMSet(ctx context.Context, key string, members ...AREntry) *IntCmd + ARInsert(ctx context.Context, key string, values ...string) *UintCmd + ARDel(ctx context.Context, key string, indexes ...uint64) *IntCmd + ARDelRange(ctx context.Context, key string, ranges ...ARRange) *UintCmd + ARLen(ctx context.Context, key string) *UintCmd + ARCount(ctx context.Context, key string) *UintCmd + ARNext(ctx context.Context, key string) *UintCmd + ARSeek(ctx context.Context, key string, index uint64) *IntCmd + ARInfo(ctx context.Context, key string) *MapStringInterfaceCmd + ARInfoFull(ctx context.Context, key string) *MapStringInterfaceCmd + ARScan(ctx context.Context, key string, start, end uint64, args *ARScanArgs) *AREntrySliceCmd + AROpSum(ctx context.Context, key string, start, end uint64) *StringCmd + AROpMin(ctx context.Context, key string, start, end uint64) *StringCmd + AROpMax(ctx context.Context, key string, start, end uint64) *StringCmd + AROpAnd(ctx context.Context, key string, start, end uint64) *IntCmd + AROpOr(ctx context.Context, key string, start, end uint64) *IntCmd + AROpXor(ctx context.Context, key string, start, end uint64) *IntCmd + AROpMatch(ctx context.Context, key string, start, end uint64, value string) *IntCmd + AROpUsed(ctx context.Context, key string, start, end uint64) *IntCmd + ARGrep(ctx context.Context, key string, start, end string, args *ARGrepArgs) *UintSliceCmd + ARGrepWithValues(ctx context.Context, key string, start, end string, args *ARGrepArgs) *AREntrySliceCmd + ARRing(ctx context.Context, key string, size uint64, values ...string) *UintCmd + ARLastItems(ctx context.Context, key string, count uint64, rev bool) *SliceCmd +} + +// AREntry represents an index-value pair for ARMSET. +type AREntry struct { + Index uint64 + Value string +} + +// ARRange represents a start-end range for ARDELRANGE. +type ARRange struct { + Start uint64 + End uint64 +} + +// ARScanArgs contains optional arguments for ARSCAN. +type ARScanArgs struct { + Limit uint64 +} + +// ARGrepPredicateType defines the type of predicate for ARGREP. +type ARGrepPredicateType string + +const ( + ARGrepExact ARGrepPredicateType = "EXACT" + ARGrepMatch ARGrepPredicateType = "MATCH" + ARGrepGlob ARGrepPredicateType = "GLOB" + ARGrepRegex ARGrepPredicateType = "RE" +) + +// ARGrepPredicate represents a search predicate for ARGREP. +type ARGrepPredicate struct { + Type ARGrepPredicateType + Value string +} + +// ARGrepArgs contains optional arguments for ARGREP. +// Redis ARGREP defaults to OR when multiple predicates are given. +// Set CombineAnd to true to combine predicates with AND instead. +type ARGrepArgs struct { + Predicates []ARGrepPredicate + CombineAnd bool + Limit uint64 + NoCase bool +} + +// ARSet sets one or more contiguous values starting at an index in an array. +// Returns the number of new slots that were set (previously empty). +func (c cmdable) ARSet(ctx context.Context, key string, index uint64, values ...string) *IntCmd { + args := make([]any, 3, 3+len(values)) + args[0] = "arset" + args[1] = key + args[2] = index + for _, v := range values { + args = append(args, v) + } + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARGet gets the value at an index in an array. +// Returns redis.Nil if the key or index does not exist. +func (c cmdable) ARGet(ctx context.Context, key string, index uint64) *StringCmd { + cmd := NewStringCmd(ctx, "arget", key, index) + _ = c(ctx, cmd) + return cmd +} + +// ARGetRange gets values in a range of indexes. +// Returns values in the range, with nil for unset indexes. +func (c cmdable) ARGetRange(ctx context.Context, key string, start, end uint64) *SliceCmd { + cmd := NewSliceCmd(ctx, "argetrange", key, start, end) + _ = c(ctx, cmd) + return cmd +} + +// ARMGet gets values at multiple indexes in an array. +// Returns values at the specified indexes, with nil for unset indexes. +func (c cmdable) ARMGet(ctx context.Context, key string, indexes ...uint64) *SliceCmd { + args := make([]any, 2+len(indexes)) + args[0] = "armget" + args[1] = key + for i, idx := range indexes { + args[2+i] = idx + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARMSet sets multiple index-value pairs in an array. +// Returns the number of new slots that were set (previously empty). +func (c cmdable) ARMSet(ctx context.Context, key string, members ...AREntry) *IntCmd { + args := make([]any, 2, 2+2*len(members)) + args[0] = "armset" + args[1] = key + for _, m := range members { + args = append(args, m.Index, m.Value) + } + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARInsert inserts one or more values at consecutive indexes. +// Returns the last index where a value was inserted. +func (c cmdable) ARInsert(ctx context.Context, key string, values ...string) *UintCmd { + args := make([]any, 2, 2+len(values)) + args[0] = "arinsert" + args[1] = key + for _, v := range values { + args = append(args, v) + } + cmd := NewUintCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARDel deletes elements at the specified indexes in an array. +// Returns the number of elements deleted. +func (c cmdable) ARDel(ctx context.Context, key string, indexes ...uint64) *IntCmd { + args := make([]any, 2+len(indexes)) + args[0] = "ardel" + args[1] = key + for i, idx := range indexes { + args[2+i] = idx + } + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARDelRange deletes elements in one or more ranges. +// Returns the number of elements deleted. +func (c cmdable) ARDelRange(ctx context.Context, key string, ranges ...ARRange) *UintCmd { + args := make([]any, 2, 2+2*len(ranges)) + args[0] = "ardelrange" + args[1] = key + for _, r := range ranges { + args = append(args, r.Start, r.End) + } + cmd := NewUintCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARLen returns the length of an array (max index + 1). +// Returns 0 if the key does not exist. +func (c cmdable) ARLen(ctx context.Context, key string) *UintCmd { + cmd := NewUintCmd(ctx, "arlen", key) + _ = c(ctx, cmd) + return cmd +} + +// ARCount returns the number of non-empty elements in an array. +// Returns 0 if the key does not exist. +func (c cmdable) ARCount(ctx context.Context, key string) *UintCmd { + cmd := NewUintCmd(ctx, "arcount", key) + _ = c(ctx, cmd) + return cmd +} + +// ARNext returns the next index ARINSERT would use. +// Returns 0 for missing keys or when no insert happened yet. +// Returns nil when the insertion cursor is exhausted / would overflow. +func (c cmdable) ARNext(ctx context.Context, key string) *UintCmd { + cmd := NewUintCmd(ctx, "arnext", key) + _ = c(ctx, cmd) + return cmd +} + +// ARSeek sets the ARINSERT / ARRING cursor to a specific index. +// Returns 1 if the cursor was set, 0 if the key does not exist. +func (c cmdable) ARSeek(ctx context.Context, key string, index uint64) *IntCmd { + cmd := NewIntCmd(ctx, "arseek", key, index) + _ = c(ctx, cmd) + return cmd +} + +// ARInfo returns metadata about an array. +func (c cmdable) ARInfo(ctx context.Context, key string) *MapStringInterfaceCmd { + cmd := NewMapStringInterfaceCmd(ctx, "arinfo", key) + _ = c(ctx, cmd) + return cmd +} + +// ARInfoFull returns detailed metadata about an array including slice statistics. +func (c cmdable) ARInfoFull(ctx context.Context, key string) *MapStringInterfaceCmd { + cmd := NewMapStringInterfaceCmd(ctx, "arinfo", key, "full") + _ = c(ctx, cmd) + return cmd +} + +// ARScan iterates existing elements in a range, returning index-value pairs. +func (c cmdable) ARScan(ctx context.Context, key string, start, end uint64, scanArgs *ARScanArgs) *AREntrySliceCmd { + args := make([]any, 4, 6) + args[0], args[1], args[2], args[3] = "arscan", key, start, end + if scanArgs != nil && scanArgs.Limit > 0 { + args = append(args, "limit", scanArgs.Limit) + } + cmd := NewAREntrySliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// AROpSum returns the sum of numeric elements in a range. +func (c cmdable) AROpSum(ctx context.Context, key string, start, end uint64) *StringCmd { + cmd := NewStringCmd(ctx, "arop", key, start, end, "SUM") + _ = c(ctx, cmd) + return cmd +} + +// AROpMin returns the minimum numeric element in a range. +func (c cmdable) AROpMin(ctx context.Context, key string, start, end uint64) *StringCmd { + cmd := NewStringCmd(ctx, "arop", key, start, end, "MIN") + _ = c(ctx, cmd) + return cmd +} + +// AROpMax returns the maximum numeric element in a range. +func (c cmdable) AROpMax(ctx context.Context, key string, start, end uint64) *StringCmd { + cmd := NewStringCmd(ctx, "arop", key, start, end, "MAX") + _ = c(ctx, cmd) + return cmd +} + +// AROpAnd returns the bitwise AND of integer elements in a range. +func (c cmdable) AROpAnd(ctx context.Context, key string, start, end uint64) *IntCmd { + cmd := NewIntCmd(ctx, "arop", key, start, end, "AND") + _ = c(ctx, cmd) + return cmd +} + +// AROpOr returns the bitwise OR of integer elements in a range. +func (c cmdable) AROpOr(ctx context.Context, key string, start, end uint64) *IntCmd { + cmd := NewIntCmd(ctx, "arop", key, start, end, "OR") + _ = c(ctx, cmd) + return cmd +} + +// AROpXor returns the bitwise XOR of integer elements in a range. +func (c cmdable) AROpXor(ctx context.Context, key string, start, end uint64) *IntCmd { + cmd := NewIntCmd(ctx, "arop", key, start, end, "XOR") + _ = c(ctx, cmd) + return cmd +} + +// AROpMatch returns the count of elements matching a target string in a range. +func (c cmdable) AROpMatch(ctx context.Context, key string, start, end uint64, value string) *IntCmd { + cmd := NewIntCmd(ctx, "arop", key, start, end, "MATCH", value) + _ = c(ctx, cmd) + return cmd +} + +// AROpUsed returns the count of non-empty slots in a range. +func (c cmdable) AROpUsed(ctx context.Context, key string, start, end uint64) *IntCmd { + cmd := NewIntCmd(ctx, "arop", key, start, end, "USED") + _ = c(ctx, cmd) + return cmd +} + +// ARGrep searches array elements in a range using textual predicates. +// Returns matching indexes only. Use ARGrepWithValues to also get the values. +func (c cmdable) ARGrep(ctx context.Context, key string, start, end string, grepArgs *ARGrepArgs) *UintSliceCmd { + args := make([]any, 4, 4+grepArgs.Len()) + args[0], args[1], args[2], args[3] = "argrep", key, start, end + args = grepArgs.Append(args) + cmd := NewUintSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARGrepWithValues searches array elements in a range using textual predicates. +// Returns matching indexes and their values as index-value pairs. +func (c cmdable) ARGrepWithValues(ctx context.Context, key string, start, end string, grepArgs *ARGrepArgs) *AREntrySliceCmd { + args := make([]any, 4, 5+grepArgs.Len()) + args[0], args[1], args[2], args[3] = "argrep", key, start, end + args = grepArgs.Append(args) + args = append(args, "withvalues") + cmd := NewAREntrySliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (args *ARGrepArgs) Len() int { + if args == nil { + return 0 + } + n := 2 * len(args.Predicates) + if args.CombineAnd { + n++ + } + if args.Limit > 0 { + n += 2 + } + if args.NoCase { + n++ + } + return n +} + +func (args *ARGrepArgs) Append(a []any) []any { + if args == nil { + return a + } + for _, p := range args.Predicates { + a = append(a, string(p.Type), p.Value) + } + if args.CombineAnd { + a = append(a, "and") + } + if args.Limit > 0 { + a = append(a, "limit", args.Limit) + } + if args.NoCase { + a = append(a, "nocase") + } + return a +} + +// ARRing inserts values into a ring buffer of specified size, wrapping and truncating as needed. +// Returns the last index where a value was inserted. +func (c cmdable) ARRing(ctx context.Context, key string, size uint64, values ...string) *UintCmd { + args := make([]any, 3, 3+len(values)) + args[0] = "arring" + args[1] = key + args[2] = size + for _, v := range values { + args = append(args, v) + } + cmd := NewUintCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// ARLastItems returns the most recently inserted elements. +// When rev is true, returns items in reverse order. +func (c cmdable) ARLastItems(ctx context.Context, key string, count uint64, rev bool) *SliceCmd { + args := make([]any, 3, 4) + args[0], args[1], args[2] = "arlastitems", key, count + if rev { + args = append(args, "rev") + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} diff --git a/vendor/github.com/redis/go-redis/v9/auth/auth.go b/vendor/github.com/redis/go-redis/v9/auth/auth.go new file mode 100644 index 000000000..21667a128 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/auth/auth.go @@ -0,0 +1,79 @@ +// Package auth package provides authentication-related interfaces and types. +// It also includes a basic implementation of credentials using username and password. +package auth + +// StreamingCredentialsProvider is an interface that defines the methods for a streaming credentials provider. +// It is used to provide credentials for authentication. +// The CredentialsListener is used to receive updates when the credentials change. +type StreamingCredentialsProvider interface { + // Subscribe subscribes to the credentials provider for updates. + // It returns the current credentials, a cancel function to unsubscribe from the provider, + // and an error if any. + // + // Implementations MUST be idempotent with respect to listener identity: + // subscribing the same listener value more than once must not produce + // duplicate notifications and must not create multiple independent + // subscriptions that each need to be cancelled separately. Every + // UnsubscribeFunc returned for a given listener must cancel that + // listener's subscription; calling any one of them must be sufficient to + // stop updates to that listener, and calling subsequent ones must be a + // safe no-op. Callers (including go-redis internals) may retain only + // the most recently returned UnsubscribeFunc and rely on it to fully + // unsubscribe the listener. + // + // TODO(ndyakov): Should we add context to the Subscribe method? + Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error) +} + +// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider. +// It is used to unsubscribe from the provider when the credentials are no longer needed. +// +// Per the StreamingCredentialsProvider.Subscribe contract, if the same +// listener is subscribed multiple times, every UnsubscribeFunc returned for +// that listener must fully unsubscribe it on first invocation, and +// subsequent invocations (from any of the equivalent UnsubscribeFuncs) must +// be a safe no-op. +type UnsubscribeFunc func() error + +// CredentialsListener is an interface that defines the methods for a credentials listener. +// It is used to receive updates when the credentials change. +// The OnNext method is called when the credentials change. +// The OnError method is called when an error occurs while requesting the credentials. +type CredentialsListener interface { + OnNext(credentials Credentials) + OnError(err error) +} + +// Credentials is an interface that defines the methods for credentials. +// It is used to provide the credentials for authentication. +type Credentials interface { + // BasicAuth returns the username and password for basic authentication. + BasicAuth() (username string, password string) + // RawCredentials returns the raw credentials as a string. + // This can be used to extract the username and password from the raw credentials or + // additional information if present in the token. + RawCredentials() string +} + +type basicAuth struct { + username string + password string +} + +// RawCredentials returns the raw credentials as a string. +func (b *basicAuth) RawCredentials() string { + return b.username + ":" + b.password +} + +// BasicAuth returns the username and password for basic authentication. +func (b *basicAuth) BasicAuth() (username string, password string) { + return b.username, b.password +} + +// NewBasicCredentials creates a new Credentials object from the given username and password. +func NewBasicCredentials(username, password string) Credentials { + return &basicAuth{ + username: username, + password: password, + } +} diff --git a/vendor/github.com/redis/go-redis/v9/auth/reauth_credentials_listener.go b/vendor/github.com/redis/go-redis/v9/auth/reauth_credentials_listener.go new file mode 100644 index 000000000..40076a0b1 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/auth/reauth_credentials_listener.go @@ -0,0 +1,47 @@ +package auth + +// ReAuthCredentialsListener is a struct that implements the CredentialsListener interface. +// It is used to re-authenticate the credentials when they are updated. +// It contains: +// - reAuth: a function that takes the new credentials and returns an error if any. +// - onErr: a function that takes an error and handles it. +type ReAuthCredentialsListener struct { + reAuth func(credentials Credentials) error + onErr func(err error) +} + +// OnNext is called when the credentials are updated. +// It calls the reAuth function with the new credentials. +// If the reAuth function returns an error, it calls the onErr function with the error. +func (c *ReAuthCredentialsListener) OnNext(credentials Credentials) { + if c.reAuth == nil { + return + } + + err := c.reAuth(credentials) + if err != nil { + c.OnError(err) + } +} + +// OnError is called when an error occurs. +// It can be called from both the credentials provider and the reAuth function. +func (c *ReAuthCredentialsListener) OnError(err error) { + if c.onErr == nil { + return + } + + c.onErr(err) +} + +// NewReAuthCredentialsListener creates a new ReAuthCredentialsListener. +// Implements the auth.CredentialsListener interface. +func NewReAuthCredentialsListener(reAuth func(credentials Credentials) error, onErr func(err error)) *ReAuthCredentialsListener { + return &ReAuthCredentialsListener{ + reAuth: reAuth, + onErr: onErr, + } +} + +// Ensure ReAuthCredentialsListener implements the CredentialsListener interface. +var _ CredentialsListener = (*ReAuthCredentialsListener)(nil) diff --git a/vendor/github.com/redis/go-redis/v9/bitmap_commands.go b/vendor/github.com/redis/go-redis/v9/bitmap_commands.go index a21558289..86aa9b7ef 100644 --- a/vendor/github.com/redis/go-redis/v9/bitmap_commands.go +++ b/vendor/github.com/redis/go-redis/v9/bitmap_commands.go @@ -12,6 +12,10 @@ type BitMapCmdable interface { BitOpAnd(ctx context.Context, destKey string, keys ...string) *IntCmd BitOpOr(ctx context.Context, destKey string, keys ...string) *IntCmd BitOpXor(ctx context.Context, destKey string, keys ...string) *IntCmd + BitOpDiff(ctx context.Context, destKey string, keys ...string) *IntCmd + BitOpDiff1(ctx context.Context, destKey string, keys ...string) *IntCmd + BitOpAndOr(ctx context.Context, destKey string, keys ...string) *IntCmd + BitOpOne(ctx context.Context, destKey string, keys ...string) *IntCmd BitOpNot(ctx context.Context, destKey string, key string) *IntCmd BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd BitPosSpan(ctx context.Context, key string, bit int8, start, end int64, span string) *IntCmd @@ -78,22 +82,50 @@ func (c cmdable) bitOp(ctx context.Context, op, destKey string, keys ...string) return cmd } +// BitOpAnd creates a new bitmap in which users are members of all given bitmaps func (c cmdable) BitOpAnd(ctx context.Context, destKey string, keys ...string) *IntCmd { return c.bitOp(ctx, "and", destKey, keys...) } +// BitOpOr creates a new bitmap in which users are member of at least one given bitmap func (c cmdable) BitOpOr(ctx context.Context, destKey string, keys ...string) *IntCmd { return c.bitOp(ctx, "or", destKey, keys...) } +// BitOpXor creates a new bitmap in which users are the result of XORing all given bitmaps func (c cmdable) BitOpXor(ctx context.Context, destKey string, keys ...string) *IntCmd { return c.bitOp(ctx, "xor", destKey, keys...) } +// BitOpNot creates a new bitmap in which users are not members of a given bitmap func (c cmdable) BitOpNot(ctx context.Context, destKey string, key string) *IntCmd { return c.bitOp(ctx, "not", destKey, key) } +// BitOpDiff creates a new bitmap in which users are members of bitmap X but not of any of bitmaps Y1, Y2, … +// Introduced with Redis 8.2 +func (c cmdable) BitOpDiff(ctx context.Context, destKey string, keys ...string) *IntCmd { + return c.bitOp(ctx, "diff", destKey, keys...) +} + +// BitOpDiff1 creates a new bitmap in which users are members of one or more of bitmaps Y1, Y2, … but not members of bitmap X +// Introduced with Redis 8.2 +func (c cmdable) BitOpDiff1(ctx context.Context, destKey string, keys ...string) *IntCmd { + return c.bitOp(ctx, "diff1", destKey, keys...) +} + +// BitOpAndOr creates a new bitmap in which users are members of bitmap X and also members of one or more of bitmaps Y1, Y2, … +// Introduced with Redis 8.2 +func (c cmdable) BitOpAndOr(ctx context.Context, destKey string, keys ...string) *IntCmd { + return c.bitOp(ctx, "andor", destKey, keys...) +} + +// BitOpOne creates a new bitmap in which users are members of exactly one of the given bitmaps +// Introduced with Redis 8.2 +func (c cmdable) BitOpOne(ctx context.Context, destKey string, keys ...string) *IntCmd { + return c.bitOp(ctx, "one", destKey, keys...) +} + // BitPos is an API before Redis version 7.0, cmd: bitpos key bit start end // if you need the `byte | bit` parameter, please use `BitPosSpan`. func (c cmdable) BitPos(ctx context.Context, key string, bit int64, pos ...int64) *IntCmd { @@ -109,7 +141,9 @@ func (c cmdable) BitPos(ctx context.Context, key string, bit int64, pos ...int64 args[3] = pos[0] args[4] = pos[1] default: - panic("too many arguments") + cmd := NewIntCmd(ctx) + cmd.SetErr(errors.New("too many arguments")) + return cmd } cmd := NewIntCmd(ctx, args...) _ = c(ctx, cmd) @@ -150,7 +184,9 @@ func (c cmdable) BitFieldRO(ctx context.Context, key string, values ...interface args[0] = "BITFIELD_RO" args[1] = key if len(values)%2 != 0 { - panic("BitFieldRO: invalid number of arguments, must be even") + c := NewIntSliceCmd(ctx) + c.SetErr(errors.New("BitFieldRO: invalid number of arguments, must be even")) + return c } for i := 0; i < len(values); i += 2 { args = append(args, "GET", values[i], values[i+1]) diff --git a/vendor/github.com/redis/go-redis/v9/cluster_commands.go b/vendor/github.com/redis/go-redis/v9/cluster_commands.go index 4857b01ea..a02683f20 100644 --- a/vendor/github.com/redis/go-redis/v9/cluster_commands.go +++ b/vendor/github.com/redis/go-redis/v9/cluster_commands.go @@ -42,6 +42,9 @@ func (c cmdable) ClusterMyID(ctx context.Context) *StringCmd { return cmd } +// ClusterSlots returns the mapping of cluster slots to nodes. +// +// Deprecated: Use ClusterShards instead as of Redis 7.0.0. func (c cmdable) ClusterSlots(ctx context.Context) *ClusterSlotsCmd { cmd := NewClusterSlotsCmd(ctx, "cluster", "slots") _ = c(ctx, cmd) @@ -153,6 +156,9 @@ func (c cmdable) ClusterSaveConfig(ctx context.Context) *StatusCmd { return cmd } +// ClusterSlaves lists the replica nodes of a master node. +// +// Deprecated: Use ClusterReplicas instead as of Redis 5.0.0. func (c cmdable) ClusterSlaves(ctx context.Context, nodeID string) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "cluster", "slaves", nodeID) _ = c(ctx, cmd) diff --git a/vendor/github.com/redis/go-redis/v9/command.go b/vendor/github.com/redis/go-redis/v9/command.go index 3253af6cc..57a26c8ba 100644 --- a/vendor/github.com/redis/go-redis/v9/command.go +++ b/vendor/github.com/redis/go-redis/v9/command.go @@ -4,6 +4,8 @@ import ( "bufio" "context" "fmt" + "io" + "maps" "net" "regexp" "strconv" @@ -14,9 +16,186 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/routing" "github.com/redis/go-redis/v9/internal/util" ) +// keylessCommands contains Redis commands that have empty key specifications (9th slot empty) +// Only includes core Redis commands, excludes FT.*, ts.*, timeseries.*, search.* and subcommands +var keylessCommands = map[string]struct{}{ + "acl": {}, + "asking": {}, + "auth": {}, + "bgrewriteaof": {}, + "bgsave": {}, + "client": {}, + "cluster": {}, + "config": {}, + "debug": {}, + "discard": {}, + "echo": {}, + "exec": {}, + "failover": {}, + "function": {}, + "hello": {}, + "hotkeys": {}, + "latency": {}, + "lolwut": {}, + "module": {}, + "monitor": {}, + "multi": {}, + "pfselftest": {}, + "ping": {}, + "psubscribe": {}, + "psync": {}, + "publish": {}, + "pubsub": {}, + "punsubscribe": {}, + "quit": {}, + "readonly": {}, + "readwrite": {}, + "replconf": {}, + "replicaof": {}, + "role": {}, + "save": {}, + "script": {}, + "select": {}, + "shutdown": {}, + "slaveof": {}, + "slowlog": {}, + "subscribe": {}, + "swapdb": {}, + "sync": {}, + "time": {}, + "unsubscribe": {}, + "unwatch": {}, + "wait": {}, +} + +// CmdTyper interface for getting command type +type CmdTyper interface { + GetCmdType() CmdType +} + +// CmdTypeGetter interface for getting command type without circular imports +type CmdTypeGetter interface { + GetCmdType() CmdType +} + +type CmdType uint8 + +const ( + CmdTypeGeneric CmdType = iota + CmdTypeString + CmdTypeInt + CmdTypeBool + CmdTypeFloat + CmdTypeStringSlice + CmdTypeIntSlice + CmdTypeFloatSlice + CmdTypeBoolSlice + CmdTypeMapStringString + CmdTypeMapStringInt + CmdTypeMapStringInterface + CmdTypeMapStringInterfaceSlice + CmdTypeSlice + CmdTypeStatus + CmdTypeDuration + CmdTypeTime + CmdTypeKeyValueSlice + CmdTypeStringStructMap + CmdTypeXMessageSlice + CmdTypeXStreamSlice + CmdTypeXPending + CmdTypeXPendingExt + CmdTypeXAutoClaim + CmdTypeXAutoClaimWithDeleted + CmdTypeXAutoClaimJustID + CmdTypeXInfoConsumers + CmdTypeXInfoGroups + CmdTypeXInfoStream + CmdTypeXInfoStreamFull + CmdTypeZSlice + CmdTypeZWithKey + CmdTypeScan + CmdTypeClusterSlots + CmdTypeGeoLocation + CmdTypeGeoSearchLocation + CmdTypeGeoPos + CmdTypeCommandsInfo + CmdTypeSlowLog + CmdTypeMapStringStringSlice + CmdTypeMapMapStringInterface + CmdTypeKeyValues + CmdTypeZSliceWithKey + CmdTypeFunctionList + CmdTypeFunctionStats + CmdTypeLCS + CmdTypeKeyFlags + CmdTypeClusterLinks + CmdTypeClusterShards + CmdTypeRankWithScore + CmdTypeClientInfo + CmdTypeACLLog + CmdTypeInfo + CmdTypeMonitor + CmdTypeJSON + CmdTypeJSONSlice + CmdTypeIntPointerSlice + CmdTypeScanDump + CmdTypeBFInfo + CmdTypeCFInfo + CmdTypeCMSInfo + CmdTypeTopKInfo + CmdTypeTDigestInfo + CmdTypeFTSynDump + CmdTypeAggregate + CmdTypeFTInfo + CmdTypeFTSpellCheck + CmdTypeFTSearch + CmdTypeTSTimestampValue + CmdTypeTSTimestampValueSlice + CmdTypeHotKeys + CmdTypeIncrEXInt + CmdTypeIncrEXFloat + CmdTypeUint + CmdTypeUintSlice + CmdTypeAREntrySlice +) + +type ( + CmdTypeXAutoClaimValue struct { + messages []XMessage + start string + } + + CmdTypeXAutoClaimWithDeletedValue struct { + messages []XMessage + start string + deletedIDs []string + } + + CmdTypeXAutoClaimJustIDValue struct { + ids []string + start string + } + + CmdTypeScanValue struct { + keys []string + cursor uint64 + } + + CmdTypeKeyValuesValue struct { + key string + values []string + } + + CmdTypeZSliceWithKeyValue struct { + key string + zSlice []Z + } +) + type Cmder interface { // command name. // e.g. "set k v ex 10" -> "set", "cluster info" -> "cluster". @@ -34,15 +213,28 @@ type Cmder interface { // e.g. "set k v ex 10" -> "set k v ex 10: OK", "get k" -> "get k: v". String() string + // Clone creates a copy of the command. + Clone() Cmder + stringArg(int) string firstKeyPos() int8 SetFirstKeyPos(int8) + stepCount() int8 + SetStepCount(int8) readTimeout() *time.Duration readReply(rd *proto.Reader) error readRawReply(rd *proto.Reader) error SetErr(error) Err() error + + // NoRetry returns true if the command should not be retried on failure. + // Commands that write directly to an io.Writer should return true since + // partial writes cannot be undone on retry. + NoRetry() bool + + // GetCmdType returns the command type for fast value extraction + GetCmdType() CmdType } func setCmdsErr(cmds []Cmder, e error) { @@ -62,6 +254,18 @@ func cmdsFirstErr(cmds []Cmder) error { return nil } +// cmdsContainNoRetry returns true if any command in the slice has NoRetry() == true. +// If a pipeline contains a non-retryable command (e.g., RawWriteToCmd), the entire +// pipeline must not be retried to prevent data corruption from partial writes. +func cmdsContainNoRetry(cmds []Cmder) bool { + for _, cmd := range cmds { + if cmd.NoRetry() { + return true + } + } + return false +} + func writeCmds(wr *proto.Writer, cmds []Cmder) error { for _, cmd := range cmds { if err := writeCmd(wr, cmd); err != nil { @@ -75,26 +279,42 @@ func writeCmd(wr *proto.Writer, cmd Cmder) error { return wr.WriteArgs(cmd.Args()) } -func cmdFirstKeyPos(cmd Cmder) int { +// cmdFirstKeyPosWithInfo returns the first key position in a command's args (0 if none). +// Uses CommandInfo.FirstKeyPos when available (via cache peek, no network call), falling +// back to a hardcoded table. eval/evalsha variants are resolved from the runtime numkeys arg. +func cmdFirstKeyPosWithInfo(cmd Cmder, info *CommandInfo) int { if pos := cmd.firstKeyPos(); pos != 0 { return int(pos) } - switch cmd.Name() { + name := cmd.Name() + + // first check if the command is keyless + if _, ok := keylessCommands[name]; ok { + return 0 + } + + switch name { case "eval", "evalsha", "eval_ro", "evalsha_ro": if cmd.stringArg(2) != "0" { return 3 } return 0 - case "publish": - return 1 case "memory": // https://github.com/redis/redis/issues/7493 if cmd.stringArg(1) == "usage" { return 2 } + // CommandInfo (if available) gives the correct answer + // otherwise the hardcoded fallback applies. + } + + // Use CommandInfo cache when warm (in-memory only, no extra round-trips). + if info != nil { + return int(info.FirstKeyPos) } + return 1 } @@ -126,8 +346,10 @@ type baseCmd struct { args []interface{} err error keyPos int8 + _stepCount int8 rawVal interface{} _readTimeout *time.Duration + cmdType CmdType } var _ Cmder = (*Cmd)(nil) @@ -183,6 +405,14 @@ func (cmd *baseCmd) SetFirstKeyPos(keyPos int8) { cmd.keyPos = keyPos } +func (cmd *baseCmd) stepCount() int8 { + return cmd._stepCount +} + +func (cmd *baseCmd) SetStepCount(stepCount int8) { + cmd._stepCount = stepCount +} + func (cmd *baseCmd) SetErr(e error) { cmd.err = e } @@ -204,6 +434,41 @@ func (cmd *baseCmd) readRawReply(rd *proto.Reader) (err error) { return err } +// NoRetry returns true if the command should not be retried on failure. +// By default, commands can be retried. Commands that write directly to an +// io.Writer (like RawWriteToCmd) should override this to return true since +// partial writes cannot be undone on retry. +func (cmd *baseCmd) NoRetry() bool { + return false +} + +func (cmd *baseCmd) GetCmdType() CmdType { + return cmd.cmdType +} + +func (cmd *baseCmd) cloneBaseCmd() baseCmd { + var readTimeout *time.Duration + if cmd._readTimeout != nil { + timeout := *cmd._readTimeout + readTimeout = &timeout + } + + // Create a copy of args slice + args := make([]interface{}, len(cmd.args)) + copy(args, cmd.args) + + return baseCmd{ + ctx: cmd.ctx, + args: args, + err: cmd.err, + keyPos: cmd.keyPos, + _stepCount: cmd._stepCount, + rawVal: cmd.rawVal, + _readTimeout: readTimeout, + cmdType: cmd.cmdType, + } +} + //------------------------------------------------------------------------------ type Cmd struct { @@ -215,8 +480,9 @@ type Cmd struct { func NewCmd(ctx context.Context, args ...interface{}) *Cmd { return &Cmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeneric, }, } } @@ -489,6 +755,129 @@ func (cmd *Cmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *Cmd) Clone() Cmder { + return &Cmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +//------------------------------------------------------------------------------ + +// RawCmd returns raw RESP protocol bytes without parsing. +type RawCmd struct { + baseCmd + val []byte +} + +var _ Cmder = (*RawCmd)(nil) + +func NewRawCmd(ctx context.Context, args ...interface{}) *RawCmd { + return &RawCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeGeneric, + }, + } +} + +func (cmd *RawCmd) SetVal(val []byte) { + cmd.val = val +} + +func (cmd *RawCmd) Val() []byte { + return cmd.val +} + +func (cmd *RawCmd) Result() ([]byte, error) { + return cmd.val, cmd.err +} + +func (cmd *RawCmd) Bytes() ([]byte, error) { + return cmd.val, cmd.err +} + +func (cmd *RawCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *RawCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadRawReply() + return err +} + +func (cmd *RawCmd) Clone() Cmder { + var val []byte + if cmd.val != nil { + val = make([]byte, len(cmd.val)) + copy(val, cmd.val) + } + return &RawCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +//------------------------------------------------------------------------------ + +// RawWriteToCmd streams raw RESP protocol bytes directly to an io.Writer without intermediate allocations. +type RawWriteToCmd struct { + baseCmd + w io.Writer + written int64 +} + +var _ Cmder = (*RawWriteToCmd)(nil) + +func NewRawWriteToCmd(ctx context.Context, w io.Writer, args ...interface{}) *RawWriteToCmd { + return &RawWriteToCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeGeneric, + }, + w: w, + } +} + +func (cmd *RawWriteToCmd) SetVal(written int64) { + cmd.written = written +} + +func (cmd *RawWriteToCmd) Val() int64 { + return cmd.written +} + +func (cmd *RawWriteToCmd) Result() (int64, error) { + return cmd.written, cmd.err +} + +func (cmd *RawWriteToCmd) String() string { + return cmdString(cmd, cmd.written) +} + +func (cmd *RawWriteToCmd) readReply(rd *proto.Reader) (err error) { + cmd.written, err = rd.ReadRawReplyWriteTo(cmd.w) + return err +} + +// NoRetry returns true because RawWriteToCmd writes directly to an io.Writer. +// If a retry occurs, partial data from failed attempts would be appended to +// the writer, causing data corruption. The caller must handle retries manually +// if needed, using a fresh writer for each attempt. +func (cmd *RawWriteToCmd) NoRetry() bool { + return true +} + +func (cmd *RawWriteToCmd) Clone() Cmder { + return &RawWriteToCmd{ + baseCmd: cmd.cloneBaseCmd(), + w: cmd.w, + written: cmd.written, + } +} + //------------------------------------------------------------------------------ type SliceCmd struct { @@ -502,8 +891,9 @@ var _ Cmder = (*SliceCmd)(nil) func NewSliceCmd(ctx context.Context, args ...interface{}) *SliceCmd { return &SliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlice, }, } } @@ -549,6 +939,18 @@ func (cmd *SliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *SliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &SliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StatusCmd struct { @@ -562,8 +964,9 @@ var _ Cmder = (*StatusCmd)(nil) func NewStatusCmd(ctx context.Context, args ...interface{}) *StatusCmd { return &StatusCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStatus, }, } } @@ -593,6 +996,13 @@ func (cmd *StatusCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StatusCmd) Clone() Cmder { + return &StatusCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type IntCmd struct { @@ -606,8 +1016,9 @@ var _ Cmder = (*IntCmd)(nil) func NewIntCmd(ctx context.Context, args ...interface{}) *IntCmd { return &IntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInt, }, } } @@ -637,6 +1048,128 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *IntCmd) Clone() Cmder { + return &IntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +type UintCmd struct { + baseCmd + + val uint64 +} + +var _ Cmder = (*UintCmd)(nil) + +func NewUintCmd(ctx context.Context, args ...any) *UintCmd { + return &UintCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeUint, + }, + } +} + +func (cmd *UintCmd) SetVal(val uint64) { + cmd.val = val +} + +func (cmd *UintCmd) Val() uint64 { + return cmd.val +} + +func (cmd *UintCmd) Result() (uint64, error) { + return cmd.val, cmd.err +} + +func (cmd *UintCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *UintCmd) readReply(rd *proto.Reader) (err error) { + cmd.val, err = rd.ReadUint() + return err +} + +func (cmd *UintCmd) Clone() Cmder { + return &UintCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +//------------------------------------------------------------------------------ + +// DigestCmd is a command that returns a uint64 xxh3 hash digest. +// +// This command is specifically designed for the Redis DIGEST command, +// which returns the xxh3 hash of a key's value as a hex string. +// The hex string is automatically parsed to a uint64 value. +// +// The digest can be used for optimistic locking with SetIFDEQ, SetIFDNE, +// and DelExArgs commands. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/digest/ +type DigestCmd struct { + baseCmd + + val uint64 +} + +var _ Cmder = (*DigestCmd)(nil) + +func NewDigestCmd(ctx context.Context, args ...interface{}) *DigestCmd { + return &DigestCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *DigestCmd) SetVal(val uint64) { + cmd.val = val +} + +func (cmd *DigestCmd) Val() uint64 { + return cmd.val +} + +func (cmd *DigestCmd) Result() (uint64, error) { + return cmd.val, cmd.err +} + +func (cmd *DigestCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *DigestCmd) Clone() Cmder { + return &DigestCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +func (cmd *DigestCmd) readReply(rd *proto.Reader) (err error) { + // Redis DIGEST command returns a hex string (e.g., "a1b2c3d4e5f67890") + // We parse it as a uint64 xxh3 hash value + var hexStr string + hexStr, err = rd.ReadString() + if err != nil { + return err + } + + // Parse hex string to uint64 + cmd.val, err = strconv.ParseUint(hexStr, 16, 64) + return err +} + //------------------------------------------------------------------------------ type IntSliceCmd struct { @@ -650,8 +1183,9 @@ var _ Cmder = (*IntSliceCmd)(nil) func NewIntSliceCmd(ctx context.Context, args ...interface{}) *IntSliceCmd { return &IntSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntSlice, }, } } @@ -686,6 +1220,78 @@ func (cmd *IntSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntSliceCmd) Clone() Cmder { + var val []int64 + if cmd.val != nil { + val = make([]int64, len(cmd.val)) + copy(val, cmd.val) + } + return &IntSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +type UintSliceCmd struct { + baseCmd + + val []uint64 +} + +var _ Cmder = (*UintSliceCmd)(nil) + +func NewUintSliceCmd(ctx context.Context, args ...any) *UintSliceCmd { + return &UintSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeUintSlice, + }, + } +} + +func (cmd *UintSliceCmd) SetVal(val []uint64) { + cmd.val = val +} + +func (cmd *UintSliceCmd) Val() []uint64 { + return cmd.val +} + +func (cmd *UintSliceCmd) Result() ([]uint64, error) { + return cmd.val, cmd.err +} + +func (cmd *UintSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *UintSliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]uint64, n) + for i := range cmd.val { + if cmd.val[i], err = rd.ReadUint(); err != nil { + return err + } + } + return nil +} + +func (cmd *UintSliceCmd) Clone() Cmder { + var val []uint64 + if cmd.val != nil { + val = make([]uint64, len(cmd.val)) + copy(val, cmd.val) + } + return &UintSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type DurationCmd struct { @@ -700,8 +1306,9 @@ var _ Cmder = (*DurationCmd)(nil) func NewDurationCmd(ctx context.Context, precision time.Duration, args ...interface{}) *DurationCmd { return &DurationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeDuration, }, precision: precision, } @@ -739,6 +1346,14 @@ func (cmd *DurationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *DurationCmd) Clone() Cmder { + return &DurationCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + precision: cmd.precision, + } +} + //------------------------------------------------------------------------------ type TimeCmd struct { @@ -752,8 +1367,9 @@ var _ Cmder = (*TimeCmd)(nil) func NewTimeCmd(ctx context.Context, args ...interface{}) *TimeCmd { return &TimeCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTime, }, } } @@ -790,6 +1406,13 @@ func (cmd *TimeCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *TimeCmd) Clone() Cmder { + return &TimeCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type BoolCmd struct { @@ -803,8 +1426,9 @@ var _ Cmder = (*BoolCmd)(nil) func NewBoolCmd(ctx context.Context, args ...interface{}) *BoolCmd { return &BoolCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBool, }, } } @@ -837,6 +1461,13 @@ func (cmd *BoolCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *BoolCmd) Clone() Cmder { + return &BoolCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type StringCmd struct { @@ -850,8 +1481,9 @@ var _ Cmder = (*StringCmd)(nil) func NewStringCmd(ctx context.Context, args ...interface{}) *StringCmd { return &StringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeString, }, } } @@ -883,28 +1515,28 @@ func (cmd *StringCmd) Int() (int, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.Atoi(cmd.Val()) + return strconv.Atoi(cmd.val) } func (cmd *StringCmd) Int64() (int64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseInt(cmd.Val(), 10, 64) + return strconv.ParseInt(cmd.val, 10, 64) } func (cmd *StringCmd) Uint64() (uint64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseUint(cmd.Val(), 10, 64) + return strconv.ParseUint(cmd.val, 10, 64) } func (cmd *StringCmd) Float32() (float32, error) { if cmd.err != nil { return 0, cmd.err } - f, err := strconv.ParseFloat(cmd.Val(), 32) + f, err := strconv.ParseFloat(cmd.val, 32) if err != nil { return 0, err } @@ -915,14 +1547,14 @@ func (cmd *StringCmd) Float64() (float64, error) { if cmd.err != nil { return 0, cmd.err } - return strconv.ParseFloat(cmd.Val(), 64) + return strconv.ParseFloat(cmd.val, 64) } func (cmd *StringCmd) Time() (time.Time, error) { if cmd.err != nil { return time.Time{}, cmd.err } - return time.Parse(time.RFC3339Nano, cmd.Val()) + return time.Parse(time.RFC3339Nano, cmd.val) } func (cmd *StringCmd) Scan(val interface{}) error { @@ -941,6 +1573,13 @@ func (cmd *StringCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *StringCmd) Clone() Cmder { + return &StringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatCmd struct { @@ -954,8 +1593,9 @@ var _ Cmder = (*FloatCmd)(nil) func NewFloatCmd(ctx context.Context, args ...interface{}) *FloatCmd { return &FloatCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloat, }, } } @@ -981,6 +1621,13 @@ func (cmd *FloatCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *FloatCmd) Clone() Cmder { + return &FloatCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + //------------------------------------------------------------------------------ type FloatSliceCmd struct { @@ -994,8 +1641,9 @@ var _ Cmder = (*FloatSliceCmd)(nil) func NewFloatSliceCmd(ctx context.Context, args ...interface{}) *FloatSliceCmd { return &FloatSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFloatSlice, }, } } @@ -1036,6 +1684,18 @@ func (cmd *FloatSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *FloatSliceCmd) Clone() Cmder { + var val []float64 + if cmd.val != nil { + val = make([]float64, len(cmd.val)) + copy(val, cmd.val) + } + return &FloatSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringSliceCmd struct { @@ -1049,8 +1709,9 @@ var _ Cmder = (*StringSliceCmd)(nil) func NewStringSliceCmd(ctx context.Context, args ...interface{}) *StringSliceCmd { return &StringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringSlice, }, } } @@ -1072,7 +1733,7 @@ func (cmd *StringSliceCmd) String() string { } func (cmd *StringSliceCmd) ScanSlice(container interface{}) error { - return proto.ScanSlice(cmd.Val(), container) + return proto.ScanSlice(cmd.val, container) } func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { @@ -1094,6 +1755,99 @@ func (cmd *StringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringSliceCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &StringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +//------------------------------------------------------------------------------ + +// StringSliceSliceCmd returns a slice of string slices ([][]string). +// This is used for commands like VLINKS that return an array of arrays. +type StringSliceSliceCmd struct { + baseCmd + + val [][]string +} + +var _ Cmder = (*StringSliceSliceCmd)(nil) + +func NewStringSliceSliceCmd(ctx context.Context, args ...any) *StringSliceSliceCmd { + return &StringSliceSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *StringSliceSliceCmd) SetVal(val [][]string) { + cmd.val = val +} + +func (cmd *StringSliceSliceCmd) Val() [][]string { + return cmd.val +} + +func (cmd *StringSliceSliceCmd) Result() ([][]string, error) { + return cmd.val, cmd.err +} + +func (cmd *StringSliceSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *StringSliceSliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([][]string, n) + for i := range n { + // Read inner array + innerN, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val[i] = make([]string, innerN) + for j := range innerN { + switch s, err := rd.ReadString(); { + case err == Nil: + cmd.val[i][j] = "" + case err != nil: + return err + default: + cmd.val[i][j] = s + } + } + } + return nil +} + +func (cmd *StringSliceSliceCmd) Clone() Cmder { + var val [][]string + if cmd.val != nil { + val = make([][]string, len(cmd.val)) + for i, slice := range cmd.val { + if slice != nil { + val[i] = make([]string, len(slice)) + copy(val[i], slice) + } + } + } + return &StringSliceSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValue struct { @@ -1112,8 +1866,9 @@ var _ Cmder = (*KeyValueSliceCmd)(nil) func NewKeyValueSliceCmd(ctx context.Context, args ...interface{}) *KeyValueSliceCmd { return &KeyValueSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValueSlice, }, } } @@ -1188,6 +1943,18 @@ func (cmd *KeyValueSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *KeyValueSliceCmd) Clone() Cmder { + var val []KeyValue + if cmd.val != nil { + val = make([]KeyValue, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValueSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type BoolSliceCmd struct { @@ -1201,8 +1968,9 @@ var _ Cmder = (*BoolSliceCmd)(nil) func NewBoolSliceCmd(ctx context.Context, args ...interface{}) *BoolSliceCmd { return &BoolSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBoolSlice, }, } } @@ -1237,6 +2005,18 @@ func (cmd *BoolSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *BoolSliceCmd) Clone() Cmder { + var val []bool + if cmd.val != nil { + val = make([]bool, len(cmd.val)) + copy(val, cmd.val) + } + return &BoolSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type MapStringStringCmd struct { @@ -1250,8 +2030,9 @@ var _ Cmder = (*MapStringStringCmd)(nil) func NewMapStringStringCmd(ctx context.Context, args ...interface{}) *MapStringStringCmd { return &MapStringStringCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringString, }, } } @@ -1316,6 +2097,20 @@ func (cmd *MapStringStringCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringCmd) Clone() Cmder { + var val map[string]string + if cmd.val != nil { + val = make(map[string]string, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringStringCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type MapStringIntCmd struct { @@ -1329,8 +2124,9 @@ var _ Cmder = (*MapStringIntCmd)(nil) func NewMapStringIntCmd(ctx context.Context, args ...interface{}) *MapStringIntCmd { return &MapStringIntCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInt, }, } } @@ -1373,6 +2169,20 @@ func (cmd *MapStringIntCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringIntCmd) Clone() Cmder { + var val map[string]int64 + if cmd.val != nil { + val = make(map[string]int64, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringIntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------ type MapStringSliceInterfaceCmd struct { baseCmd @@ -1382,8 +2192,9 @@ type MapStringSliceInterfaceCmd struct { func NewMapStringSliceInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringSliceInterfaceCmd { return &MapStringSliceInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -1469,6 +2280,24 @@ func (cmd *MapStringSliceInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapStringSliceInterfaceCmd) Clone() Cmder { + var val map[string][]interface{} + if cmd.val != nil { + val = make(map[string][]interface{}, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newSlice := make([]interface{}, len(v)) + copy(newSlice, v) + val[k] = newSlice + } + } + } + return &MapStringSliceInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type StringStructMapCmd struct { @@ -1482,8 +2311,9 @@ var _ Cmder = (*StringStructMapCmd)(nil) func NewStringStructMapCmd(ctx context.Context, args ...interface{}) *StringStructMapCmd { return &StringStructMapCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeStringStructMap, }, } } @@ -1521,11 +2351,28 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *StringStructMapCmd) Clone() Cmder { + var val map[string]struct{} + if cmd.val != nil { + val = maps.Clone(cmd.val) + } + return &StringStructMapCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XMessage struct { ID string Values map[string]interface{} + // MillisElapsedFromDelivery is the number of milliseconds since the entry was last delivered. + // Only populated when using XREADGROUP with CLAIM argument for claimed entries. + MillisElapsedFromDelivery int64 + // DeliveredCount is the number of times the entry was delivered. + // Only populated when using XREADGROUP with CLAIM argument for claimed entries. + DeliveredCount int64 } type XMessageSliceCmd struct { @@ -1539,8 +2386,9 @@ var _ Cmder = (*XMessageSliceCmd)(nil) func NewXMessageSliceCmd(ctx context.Context, args ...interface{}) *XMessageSliceCmd { return &XMessageSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXMessageSlice, }, } } @@ -1566,6 +2414,25 @@ func (cmd *XMessageSliceCmd) readReply(rd *proto.Reader) (err error) { return err } +func (cmd *XMessageSliceCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = maps.Clone(msg.Values) + } + } + } + return &XMessageSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { n, err := rd.ReadArrayLen() if err != nil { @@ -1582,10 +2449,16 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) { } func readXMessage(rd *proto.Reader) (XMessage, error) { - if err := rd.ReadFixedArrayLen(2); err != nil { + // Read array length can be 2 or 4 (with CLAIM metadata) + n, err := rd.ReadArrayLen() + if err != nil { return XMessage{}, err } + if n != 2 && n != 4 { + return XMessage{}, fmt.Errorf("redis: got %d elements in the XMessage array, expected 2 or 4", n) + } + id, err := rd.ReadString() if err != nil { return XMessage{}, err @@ -1598,10 +2471,24 @@ func readXMessage(rd *proto.Reader) (XMessage, error) { } } - return XMessage{ + msg := XMessage{ ID: id, Values: v, - }, nil + } + + if n == 4 { + msg.MillisElapsedFromDelivery, err = rd.ReadInt() + if err != nil { + return XMessage{}, err + } + + msg.DeliveredCount, err = rd.ReadInt() + if err != nil { + return XMessage{}, err + } + } + + return msg, nil } func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) { @@ -1645,8 +2532,9 @@ var _ Cmder = (*XStreamSliceCmd)(nil) func NewXStreamSliceCmd(ctx context.Context, args ...interface{}) *XStreamSliceCmd { return &XStreamSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXStreamSlice, }, } } @@ -1699,6 +2587,36 @@ func (cmd *XStreamSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XStreamSliceCmd) Clone() Cmder { + var val []XStream + if cmd.val != nil { + val = make([]XStream, len(cmd.val)) + for i, stream := range cmd.val { + val[i] = XStream{ + Stream: stream.Stream, + } + if stream.Messages != nil { + val[i].Messages = make([]XMessage, len(stream.Messages)) + for j, msg := range stream.Messages { + val[i].Messages[j] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Messages[j].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Messages[j].Values[k] = v + } + } + } + } + } + } + return &XStreamSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPending struct { @@ -1718,8 +2636,9 @@ var _ Cmder = (*XPendingCmd)(nil) func NewXPendingCmd(ctx context.Context, args ...interface{}) *XPendingCmd { return &XPendingCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPending, }, } } @@ -1782,6 +2701,27 @@ func (cmd *XPendingCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingCmd) Clone() Cmder { + var val *XPending + if cmd.val != nil { + val = &XPending{ + Count: cmd.val.Count, + Lower: cmd.val.Lower, + Higher: cmd.val.Higher, + } + if cmd.val.Consumers != nil { + val.Consumers = make(map[string]int64, len(cmd.val.Consumers)) + for k, v := range cmd.val.Consumers { + val.Consumers[k] = v + } + } + } + return &XPendingCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XPendingExt struct { @@ -1801,8 +2741,9 @@ var _ Cmder = (*XPendingExtCmd)(nil) func NewXPendingExtCmd(ctx context.Context, args ...interface{}) *XPendingExtCmd { return &XPendingExtCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXPendingExt, }, } } @@ -1857,6 +2798,18 @@ func (cmd *XPendingExtCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XPendingExtCmd) Clone() Cmder { + var val []XPendingExt + if cmd.val != nil { + val = make([]XPendingExt, len(cmd.val)) + copy(val, cmd.val) + } + return &XPendingExtCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XAutoClaimCmd struct { @@ -1871,8 +2824,9 @@ var _ Cmder = (*XAutoClaimCmd)(nil) func NewXAutoClaimCmd(ctx context.Context, args ...interface{}) *XAutoClaimCmd { return &XAutoClaimCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaim, }, } } @@ -1919,52 +2873,76 @@ func (cmd *XAutoClaimCmd) readReply(rd *proto.Reader) error { } if n >= 3 { - if err := rd.DiscardNext(); err != nil { - return err - } + return rd.DiscardNext() } return nil } +func (cmd *XAutoClaimCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Values[k] = v + } + } + } + } + return &XAutoClaimCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + //------------------------------------------------------------------------------ -type XAutoClaimJustIDCmd struct { +type XAutoClaimWithDeletedCmd struct { baseCmd - start string - val []string + start string + val []XMessage + deletedIDs []string } -var _ Cmder = (*XAutoClaimJustIDCmd)(nil) +var _ Cmder = (*XAutoClaimWithDeletedCmd)(nil) -func NewXAutoClaimJustIDCmd(ctx context.Context, args ...interface{}) *XAutoClaimJustIDCmd { - return &XAutoClaimJustIDCmd{ +func NewXAutoClaimWithDeletedCmd(ctx context.Context, args ...interface{}) *XAutoClaimWithDeletedCmd { + return &XAutoClaimWithDeletedCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaimWithDeleted, }, } } -func (cmd *XAutoClaimJustIDCmd) SetVal(val []string, start string) { +func (cmd *XAutoClaimWithDeletedCmd) SetVal(val []XMessage, start string, deletedIDs []string) { cmd.val = val cmd.start = start + cmd.deletedIDs = deletedIDs } -func (cmd *XAutoClaimJustIDCmd) Val() (ids []string, start string) { - return cmd.val, cmd.start +func (cmd *XAutoClaimWithDeletedCmd) Val() (messages []XMessage, start string, deletedIDs []string) { + return cmd.val, cmd.start, cmd.deletedIDs } -func (cmd *XAutoClaimJustIDCmd) Result() (ids []string, start string, err error) { - return cmd.val, cmd.start, cmd.err +func (cmd *XAutoClaimWithDeletedCmd) Result() (messages []XMessage, start string, deletedIDs []string, err error) { + return cmd.val, cmd.start, cmd.deletedIDs, cmd.err } -func (cmd *XAutoClaimJustIDCmd) String() string { +func (cmd *XAutoClaimWithDeletedCmd) String() string { return cmdString(cmd, cmd.val) } -func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { +func (cmd *XAutoClaimWithDeletedCmd) readReply(rd *proto.Reader) error { n, err := rd.ReadArrayLen() if err != nil { return err @@ -1975,7 +2953,7 @@ func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { 3: // Redis 7: // ok default: - return fmt.Errorf("redis: got %d elements in XAutoClaimJustID reply, wanted 2/3", n) + return fmt.Errorf("redis: got %d elements in XAutoClaim reply, wanted 2/3", n) } cmd.start, err = rd.ReadString() @@ -1983,54 +2961,179 @@ func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { return err } + cmd.val, err = readXMessageSlice(rd) + if err != nil { + return err + } + + if n < 3 { + return nil + } + nn, err := rd.ReadArrayLen() if err != nil { return err } - cmd.val = make([]string, nn) + cmd.deletedIDs = make([]string, nn) for i := 0; i < nn; i++ { - cmd.val[i], err = rd.ReadString() + cmd.deletedIDs[i], err = rd.ReadString() if err != nil { return err } } - if n >= 3 { - if err := rd.DiscardNext(); err != nil { - return err + return nil +} + +func (cmd *XAutoClaimWithDeletedCmd) Clone() Cmder { + var val []XMessage + if cmd.val != nil { + val = make([]XMessage, len(cmd.val)) + for i, msg := range cmd.val { + val[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val[i].Values[k] = v + } + } } } - - return nil + var deletedIDs []string + if cmd.deletedIDs != nil { + deletedIDs = make([]string, len(cmd.deletedIDs)) + copy(deletedIDs, cmd.deletedIDs) + } + return &XAutoClaimWithDeletedCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + deletedIDs: deletedIDs, + } } //------------------------------------------------------------------------------ -type XInfoConsumersCmd struct { +type XAutoClaimJustIDCmd struct { baseCmd - val []XInfoConsumer -} -type XInfoConsumer struct { - Name string - Pending int64 - Idle time.Duration - Inactive time.Duration + start string + val []string } -var _ Cmder = (*XInfoConsumersCmd)(nil) +var _ Cmder = (*XAutoClaimJustIDCmd)(nil) -func NewXInfoConsumersCmd(ctx context.Context, stream string, group string) *XInfoConsumersCmd { - return &XInfoConsumersCmd{ +func NewXAutoClaimJustIDCmd(ctx context.Context, args ...interface{}) *XAutoClaimJustIDCmd { + return &XAutoClaimJustIDCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "consumers", stream, group}, + ctx: ctx, + args: args, + cmdType: CmdTypeXAutoClaimJustID, }, } } -func (cmd *XInfoConsumersCmd) SetVal(val []XInfoConsumer) { +func (cmd *XAutoClaimJustIDCmd) SetVal(val []string, start string) { + cmd.val = val + cmd.start = start +} + +func (cmd *XAutoClaimJustIDCmd) Val() (ids []string, start string) { + return cmd.val, cmd.start +} + +func (cmd *XAutoClaimJustIDCmd) Result() (ids []string, start string, err error) { + return cmd.val, cmd.start, cmd.err +} + +func (cmd *XAutoClaimJustIDCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *XAutoClaimJustIDCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + switch n { + case 2, // Redis 6 + 3: // Redis 7: + // ok + default: + return fmt.Errorf("redis: got %d elements in XAutoClaimJustID reply, wanted 2/3", n) + } + + cmd.start, err = rd.ReadString() + if err != nil { + return err + } + + nn, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([]string, nn) + for i := 0; i < nn; i++ { + cmd.val[i], err = rd.ReadString() + if err != nil { + return err + } + } + + if n >= 3 { + if err := rd.DiscardNext(); err != nil { + return err + } + } + + return nil +} + +func (cmd *XAutoClaimJustIDCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &XAutoClaimJustIDCmd{ + baseCmd: cmd.cloneBaseCmd(), + start: cmd.start, + val: val, + } +} + +//------------------------------------------------------------------------------ + +type XInfoConsumersCmd struct { + baseCmd + val []XInfoConsumer +} + +type XInfoConsumer struct { + Name string + Pending int64 + Idle time.Duration + Inactive time.Duration +} + +var _ Cmder = (*XInfoConsumersCmd)(nil) + +func NewXInfoConsumersCmd(ctx context.Context, stream string, group string) *XInfoConsumersCmd { + return &XInfoConsumersCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: []interface{}{"xinfo", "consumers", stream, group}, + cmdType: CmdTypeXInfoConsumers, + }, + } +} + +func (cmd *XInfoConsumersCmd) SetVal(val []XInfoConsumer) { cmd.val = val } @@ -2080,7 +3183,10 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { inactive, err = rd.ReadInt() cmd.val[i].Inactive = time.Duration(inactive) * time.Millisecond default: - return fmt.Errorf("redis: unexpected content %s in XINFO CONSUMERS reply", key) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return err + } } if err != nil { return err @@ -2091,6 +3197,18 @@ func (cmd *XInfoConsumersCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoConsumersCmd) Clone() Cmder { + var val []XInfoConsumer + if cmd.val != nil { + val = make([]XInfoConsumer, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoConsumersCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoGroupsCmd struct { @@ -2104,7 +3222,9 @@ type XInfoGroup struct { Pending int64 LastDeliveredID string EntriesRead int64 - Lag int64 + // Lag represents the number of pending messages in the stream not yet + // delivered to this consumer group. Returns -1 when the lag cannot be determined. + Lag int64 } var _ Cmder = (*XInfoGroupsCmd)(nil) @@ -2112,8 +3232,9 @@ var _ Cmder = (*XInfoGroupsCmd)(nil) func NewXInfoGroupsCmd(ctx context.Context, stream string) *XInfoGroupsCmd { return &XInfoGroupsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "groups", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "groups", stream}, + cmdType: CmdTypeXInfoGroups, }, } } @@ -2187,11 +3308,17 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { // lag: the number of entries in the stream that are still waiting to be delivered // to the group's consumers, or a NULL(Nil) when that number can't be determined. + // In that case, we return -1. if err != nil && err != Nil { return err + } else if err == Nil { + group.Lag = -1 } default: - return fmt.Errorf("redis: unexpected key %q in XINFO GROUPS reply", key) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return err + } } } } @@ -2199,6 +3326,18 @@ func (cmd *XInfoGroupsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *XInfoGroupsCmd) Clone() Cmder { + var val []XInfoGroup + if cmd.val != nil { + val = make([]XInfoGroup, len(cmd.val)) + copy(val, cmd.val) + } + return &XInfoGroupsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamCmd struct { @@ -2217,6 +3356,13 @@ type XInfoStream struct { FirstEntry XMessage LastEntry XMessage RecordedFirstEntryID string + + IDMPDuration int64 + IDMPMaxSize int64 + PIDsTracked int64 + IIDsTracked int64 + IIDsAdded int64 + IIDsDuplicates int64 } var _ Cmder = (*XInfoStreamCmd)(nil) @@ -2224,8 +3370,9 @@ var _ Cmder = (*XInfoStreamCmd)(nil) func NewXInfoStreamCmd(ctx context.Context, stream string) *XInfoStreamCmd { return &XInfoStreamCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"xinfo", "stream", stream}, + ctx: ctx, + args: []interface{}{"xinfo", "stream", stream}, + cmdType: CmdTypeXInfoStream, }, } } @@ -2309,13 +3456,85 @@ func (cmd *XInfoStreamCmd) readReply(rd *proto.Reader) error { if err != nil { return err } + case "idmp-duration": + cmd.val.IDMPDuration, err = rd.ReadInt() + if err != nil { + return err + } + case "idmp-maxsize": + cmd.val.IDMPMaxSize, err = rd.ReadInt() + if err != nil { + return err + } + case "pids-tracked": + cmd.val.PIDsTracked, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-tracked": + cmd.val.IIDsTracked, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-added": + cmd.val.IIDsAdded, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-duplicates": + cmd.val.IIDsDuplicates, err = rd.ReadInt() + if err != nil { + return err + } default: - return fmt.Errorf("redis: unexpected key %q in XINFO STREAM reply", key) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return err + } } } return nil } +func (cmd *XInfoStreamCmd) Clone() Cmder { + var val *XInfoStream + if cmd.val != nil { + val = &XInfoStream{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + Groups: cmd.val.Groups, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone XMessage fields + val.FirstEntry = XMessage{ + ID: cmd.val.FirstEntry.ID, + } + if cmd.val.FirstEntry.Values != nil { + val.FirstEntry.Values = make(map[string]interface{}, len(cmd.val.FirstEntry.Values)) + for k, v := range cmd.val.FirstEntry.Values { + val.FirstEntry.Values[k] = v + } + } + val.LastEntry = XMessage{ + ID: cmd.val.LastEntry.ID, + } + if cmd.val.LastEntry.Values != nil { + val.LastEntry.Values = make(map[string]interface{}, len(cmd.val.LastEntry.Values)) + for k, v := range cmd.val.LastEntry.Values { + val.LastEntry.Values[k] = v + } + } + } + return &XInfoStreamCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type XInfoStreamFullCmd struct { @@ -2333,6 +3552,12 @@ type XInfoStreamFull struct { Entries []XMessage Groups []XInfoStreamGroup RecordedFirstEntryID string + IDMPDuration int64 + IDMPMaxSize int64 + PIDsTracked int64 + IIDsTracked int64 + IIDsAdded int64 + IIDsDuplicates int64 } type XInfoStreamGroup struct { @@ -2341,6 +3566,7 @@ type XInfoStreamGroup struct { EntriesRead int64 Lag int64 PelCount int64 + NackedCount uint64 // redis version 8.8, number of NACK'd messages in the group Pending []XInfoStreamGroupPending Consumers []XInfoStreamConsumer } @@ -2371,8 +3597,9 @@ var _ Cmder = (*XInfoStreamFullCmd)(nil) func NewXInfoStreamFullCmd(ctx context.Context, args ...interface{}) *XInfoStreamFullCmd { return &XInfoStreamFullCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeXInfoStreamFull, }, } } @@ -2453,8 +3680,41 @@ func (cmd *XInfoStreamFullCmd) readReply(rd *proto.Reader) error { if err != nil { return err } + case "idmp-duration": + cmd.val.IDMPDuration, err = rd.ReadInt() + if err != nil { + return err + } + case "idmp-maxsize": + cmd.val.IDMPMaxSize, err = rd.ReadInt() + if err != nil { + return err + } + case "pids-tracked": + cmd.val.PIDsTracked, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-tracked": + cmd.val.IIDsTracked, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-added": + cmd.val.IIDsAdded, err = rd.ReadInt() + if err != nil { + return err + } + case "iids-duplicates": + cmd.val.IIDsDuplicates, err = rd.ReadInt() + if err != nil { + return err + } default: - return fmt.Errorf("redis: unexpected key %q in XINFO STREAM FULL reply", key) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return err + } } } return nil @@ -2508,6 +3768,11 @@ func readStreamGroups(rd *proto.Reader) ([]XInfoStreamGroup, error) { if err != nil { return nil, err } + case "nacked-count": + group.NackedCount, err = rd.ReadUint() + if err != nil { + return nil, err + } case "pending": group.Pending, err = readXInfoStreamGroupPending(rd) if err != nil { @@ -2519,7 +3784,10 @@ func readStreamGroups(rd *proto.Reader) ([]XInfoStreamGroup, error) { return nil, err } default: - return nil, fmt.Errorf("redis: unexpected key %q in XINFO STREAM FULL reply", key) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return nil, err + } } } @@ -2644,8 +3912,10 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { c.Pending = append(c.Pending, p) } default: - return nil, fmt.Errorf("redis: unexpected content %s "+ - "in XINFO STREAM FULL reply", cKey) + // skip unknown fields + if err = rd.DiscardNext(); err != nil { + return nil, err + } } if err != nil { return nil, err @@ -2657,6 +3927,45 @@ func readXInfoStreamConsumers(rd *proto.Reader) ([]XInfoStreamConsumer, error) { return consumers, nil } +func (cmd *XInfoStreamFullCmd) Clone() Cmder { + var val *XInfoStreamFull + if cmd.val != nil { + val = &XInfoStreamFull{ + Length: cmd.val.Length, + RadixTreeKeys: cmd.val.RadixTreeKeys, + RadixTreeNodes: cmd.val.RadixTreeNodes, + LastGeneratedID: cmd.val.LastGeneratedID, + MaxDeletedEntryID: cmd.val.MaxDeletedEntryID, + EntriesAdded: cmd.val.EntriesAdded, + RecordedFirstEntryID: cmd.val.RecordedFirstEntryID, + } + // Clone Entries + if cmd.val.Entries != nil { + val.Entries = make([]XMessage, len(cmd.val.Entries)) + for i, msg := range cmd.val.Entries { + val.Entries[i] = XMessage{ + ID: msg.ID, + } + if msg.Values != nil { + val.Entries[i].Values = make(map[string]interface{}, len(msg.Values)) + for k, v := range msg.Values { + val.Entries[i].Values[k] = v + } + } + } + } + // Clone Groups - simplified copy for now due to complexity + if cmd.val.Groups != nil { + val.Groups = make([]XInfoStreamGroup, len(cmd.val.Groups)) + copy(val.Groups, cmd.val.Groups) + } + } + return &XInfoStreamFullCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceCmd struct { @@ -2670,8 +3979,9 @@ var _ Cmder = (*ZSliceCmd)(nil) func NewZSliceCmd(ctx context.Context, args ...interface{}) *ZSliceCmd { return &ZSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSlice, }, } } @@ -2735,6 +4045,18 @@ func (cmd *ZSliceCmd) readReply(rd *proto.Reader) error { // nolint:dupl return nil } +func (cmd *ZSliceCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ZWithKeyCmd struct { @@ -2748,8 +4070,9 @@ var _ Cmder = (*ZWithKeyCmd)(nil) func NewZWithKeyCmd(ctx context.Context, args ...interface{}) *ZWithKeyCmd { return &ZWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZWithKey, }, } } @@ -2789,6 +4112,23 @@ func (cmd *ZWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZWithKeyCmd) Clone() Cmder { + var val *ZWithKey + if cmd.val != nil { + val = &ZWithKey{ + Z: Z{ + Score: cmd.val.Score, + Member: cmd.val.Member, + }, + Key: cmd.val.Key, + } + } + return &ZWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type ScanCmd struct { @@ -2805,8 +4145,9 @@ var _ Cmder = (*ScanCmd)(nil) func NewScanCmd(ctx context.Context, process cmdable, args ...interface{}) *ScanCmd { return &ScanCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScan, }, process: process, } @@ -2854,6 +4195,20 @@ func (cmd *ScanCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ScanCmd) Clone() Cmder { + var page []string + if cmd.page != nil { + page = make([]string, len(cmd.page)) + copy(page, cmd.page) + } + return &ScanCmd{ + baseCmd: cmd.cloneBaseCmd(), + page: page, + cursor: cmd.cursor, + process: cmd.process, + } +} + // Iterator creates a new ScanIterator. func (cmd *ScanCmd) Iterator() *ScanIterator { return &ScanIterator{ @@ -2886,8 +4241,9 @@ var _ Cmder = (*ClusterSlotsCmd)(nil) func NewClusterSlotsCmd(ctx context.Context, args ...interface{}) *ClusterSlotsCmd { return &ClusterSlotsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterSlots, }, } } @@ -3000,6 +4356,38 @@ func (cmd *ClusterSlotsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterSlotsCmd) Clone() Cmder { + var val []ClusterSlot + if cmd.val != nil { + val = make([]ClusterSlot, len(cmd.val)) + for i, slot := range cmd.val { + val[i] = ClusterSlot{ + Start: slot.Start, + End: slot.End, + } + if slot.Nodes != nil { + val[i].Nodes = make([]ClusterNode, len(slot.Nodes)) + for j, node := range slot.Nodes { + val[i].Nodes[j] = ClusterNode{ + ID: node.ID, + Addr: node.Addr, + } + if node.NetworkingMetadata != nil { + val[i].Nodes[j].NetworkingMetadata = make(map[string]string, len(node.NetworkingMetadata)) + for k, v := range node.NetworkingMetadata { + val[i].Nodes[j].NetworkingMetadata[k] = v + } + } + } + } + } + } + return &ClusterSlotsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // GeoLocation is used with GeoAdd to add geospatial location. @@ -3039,8 +4427,9 @@ var _ Cmder = (*GeoLocationCmd)(nil) func NewGeoLocationCmd(ctx context.Context, q *GeoRadiusQuery, args ...interface{}) *GeoLocationCmd { return &GeoLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: geoLocationArgs(q, args...), + ctx: ctx, + args: geoLocationArgs(q, args...), + cmdType: CmdTypeGeoLocation, }, q: q, } @@ -3148,6 +4537,34 @@ func (cmd *GeoLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoLocationCmd) Clone() Cmder { + var q *GeoRadiusQuery + if cmd.q != nil { + q = &GeoRadiusQuery{ + Radius: cmd.q.Radius, + Unit: cmd.q.Unit, + WithCoord: cmd.q.WithCoord, + WithDist: cmd.q.WithDist, + WithGeoHash: cmd.q.WithGeoHash, + Count: cmd.q.Count, + Sort: cmd.q.Sort, + Store: cmd.q.Store, + StoreDist: cmd.q.StoreDist, + withLen: cmd.q.withLen, + } + } + var locations []GeoLocation + if cmd.locations != nil { + locations = make([]GeoLocation, len(cmd.locations)) + copy(locations, cmd.locations) + } + return &GeoLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + q: q, + locations: locations, + } +} + //------------------------------------------------------------------------------ // GeoSearchQuery is used for GEOSearch/GEOSearchStore command query. @@ -3255,8 +4672,9 @@ func NewGeoSearchLocationCmd( ) *GeoSearchLocationCmd { return &GeoSearchLocationCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: geoSearchLocationArgs(opt, args), + cmdType: CmdTypeGeoSearchLocation, }, opt: opt, } @@ -3329,6 +4747,40 @@ func (cmd *GeoSearchLocationCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoSearchLocationCmd) Clone() Cmder { + var opt *GeoSearchLocationQuery + if cmd.opt != nil { + opt = &GeoSearchLocationQuery{ + GeoSearchQuery: GeoSearchQuery{ + Member: cmd.opt.Member, + Longitude: cmd.opt.Longitude, + Latitude: cmd.opt.Latitude, + Radius: cmd.opt.Radius, + RadiusUnit: cmd.opt.RadiusUnit, + BoxWidth: cmd.opt.BoxWidth, + BoxHeight: cmd.opt.BoxHeight, + BoxUnit: cmd.opt.BoxUnit, + Sort: cmd.opt.Sort, + Count: cmd.opt.Count, + CountAny: cmd.opt.CountAny, + }, + WithCoord: cmd.opt.WithCoord, + WithDist: cmd.opt.WithDist, + WithHash: cmd.opt.WithHash, + } + } + var val []GeoLocation + if cmd.val != nil { + val = make([]GeoLocation, len(cmd.val)) + copy(val, cmd.val) + } + return &GeoSearchLocationCmd{ + baseCmd: cmd.cloneBaseCmd(), + opt: opt, + val: val, + } +} + //------------------------------------------------------------------------------ type GeoPos struct { @@ -3346,8 +4798,9 @@ var _ Cmder = (*GeoPosCmd)(nil) func NewGeoPosCmd(ctx context.Context, args ...interface{}) *GeoPosCmd { return &GeoPosCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeGeoPos, }, } } @@ -3403,17 +4856,37 @@ func (cmd *GeoPosCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *GeoPosCmd) Clone() Cmder { + var val []*GeoPos + if cmd.val != nil { + val = make([]*GeoPos, len(cmd.val)) + for i, pos := range cmd.val { + if pos != nil { + val[i] = &GeoPos{ + Longitude: pos.Longitude, + Latitude: pos.Latitude, + } + } + } + } + return &GeoPosCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type CommandInfo struct { - Name string - Arity int8 - Flags []string - ACLFlags []string - FirstKeyPos int8 - LastKeyPos int8 - StepCount int8 - ReadOnly bool + Name string + Arity int8 + Flags []string + ACLFlags []string + FirstKeyPos int8 + LastKeyPos int8 + StepCount int8 + ReadOnly bool + CommandPolicy *routing.CommandPolicy } type CommandsInfoCmd struct { @@ -3427,8 +4900,9 @@ var _ Cmder = (*CommandsInfoCmd)(nil) func NewCommandsInfoCmd(ctx context.Context, args ...interface{}) *CommandsInfoCmd { return &CommandsInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCommandsInfo, }, } } @@ -3452,7 +4926,7 @@ func (cmd *CommandsInfoCmd) String() string { func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { const numArgRedis5 = 6 const numArgRedis6 = 7 - const numArgRedis7 = 10 + const numArgRedis7 = 10 // Also matches redis 8 n, err := rd.ReadArrayLen() if err != nil { @@ -3540,9 +5014,33 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { } if nn >= numArgRedis7 { - if err := rd.DiscardNext(); err != nil { + // The 8th argument is an array of tips. + tipsLen, err := rd.ReadArrayLen() + if err != nil { return err } + + rawTips := make(map[string]string, tipsLen) + if cmdInfo.ReadOnly { + rawTips[routing.ReadOnlyCMD] = "" + } + for f := 0; f < tipsLen; f++ { + tip, err := rd.ReadString() + if err != nil { + return err + } + + k, v, ok := strings.Cut(tip, ":") + if !ok { + // Handle tips that don't have a colon (like "nondeterministic_output") + rawTips[tip] = "" + } else { + // Handle normal key:value tips + rawTips[k] = v + } + } + cmdInfo.CommandPolicy = parseCommandPolicies(rawTips, cmdInfo.FirstKeyPos) + if err := rd.DiscardNext(); err != nil { return err } @@ -3557,13 +5055,47 @@ func (cmd *CommandsInfoCmd) readReply(rd *proto.Reader) error { return nil } -//------------------------------------------------------------------------------ - -type cmdsInfoCache struct { - fn func(ctx context.Context) (map[string]*CommandInfo, error) +func (cmd *CommandsInfoCmd) Clone() Cmder { + var val map[string]*CommandInfo + if cmd.val != nil { + val = make(map[string]*CommandInfo, len(cmd.val)) + for k, v := range cmd.val { + if v != nil { + newInfo := &CommandInfo{ + Name: v.Name, + Arity: v.Arity, + FirstKeyPos: v.FirstKeyPos, + LastKeyPos: v.LastKeyPos, + StepCount: v.StepCount, + ReadOnly: v.ReadOnly, + CommandPolicy: v.CommandPolicy, // CommandPolicy can be shared as it's immutable + } + if v.Flags != nil { + newInfo.Flags = make([]string, len(v.Flags)) + copy(newInfo.Flags, v.Flags) + } + if v.ACLFlags != nil { + newInfo.ACLFlags = make([]string, len(v.ACLFlags)) + copy(newInfo.ACLFlags, v.ACLFlags) + } + val[k] = newInfo + } + } + } + return &CommandsInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +//------------------------------------------------------------------------------ + +type cmdsInfoCache struct { + fn func(ctx context.Context) (map[string]*CommandInfo, error) - once internal.Once - cmds map[string]*CommandInfo + once internal.Once + refreshLock sync.RWMutex + cmds map[string]*CommandInfo } func newCmdsInfoCache(fn func(ctx context.Context) (map[string]*CommandInfo, error)) *cmdsInfoCache { @@ -3573,26 +5105,82 @@ func newCmdsInfoCache(fn func(ctx context.Context) (map[string]*CommandInfo, err } func (c *cmdsInfoCache) Get(ctx context.Context) (map[string]*CommandInfo, error) { + c.refreshLock.Lock() + defer c.refreshLock.Unlock() + err := c.once.Do(func() error { cmds, err := c.fn(ctx) if err != nil { return err } + lowerCmds := make(map[string]*CommandInfo, len(cmds)) + // Extensions have cmd names in upper case. Convert them to lower case. for k, v := range cmds { - lower := internal.ToLower(k) - if lower != k { - cmds[lower] = v - } + lowerCmds[internal.ToLower(k)] = v } - c.cmds = cmds + c.cmds = lowerCmds return nil }) return c.cmds, err } +func (c *cmdsInfoCache) Refresh() { + c.refreshLock.Lock() + defer c.refreshLock.Unlock() + + c.once = internal.Once{} +} + +// Peek returns the cached CommandInfo map without triggering a Redis round-trip. +// Returns nil when the cache is cold; callers should fall back to other heuristics. +// Note: during the very first Get() (initial population) this call will block on +// the writer lock. After that, concurrent Peek() calls do not block each other. +// The returned map and its entries MUST NOT be mutated by the caller. +func (c *cmdsInfoCache) Peek() map[string]*CommandInfo { + if c == nil { + return nil + } + c.refreshLock.RLock() + defer c.refreshLock.RUnlock() + return c.cmds +} + +// ------------------------------------------------------------------------------ +const ( + requestPolicy = "request_policy" + responsePolicy = "response_policy" +) + +func parseCommandPolicies(commandInfoTips map[string]string, firstKeyPos int8) *routing.CommandPolicy { + req := routing.ReqDefault + resp := routing.RespDefaultKeyless + if firstKeyPos > 0 { + resp = routing.RespDefaultHashSlot + } + + tips := make(map[string]string, len(commandInfoTips)) + for k, v := range commandInfoTips { + if k == requestPolicy { + if p, err := routing.ParseRequestPolicy(v); err == nil { + req = p + } + continue + } + if k == responsePolicy { + if p, err := routing.ParseResponsePolicy(v); err == nil { + resp = p + } + continue + } + tips[k] = v + } + + return &routing.CommandPolicy{Request: req, Response: resp, Tips: tips} +} + //------------------------------------------------------------------------------ type SlowLog struct { @@ -3617,8 +5205,9 @@ var _ Cmder = (*SlowLogCmd)(nil) func NewSlowLogCmd(ctx context.Context, args ...interface{}) *SlowLogCmd { return &SlowLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeSlowLog, }, } } @@ -3703,6 +5292,356 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *SlowLogCmd) Clone() Cmder { + var val []SlowLog + if cmd.val != nil { + val = make([]SlowLog, len(cmd.val)) + for i, log := range cmd.val { + val[i] = SlowLog{ + ID: log.ID, + Time: log.Time, + Duration: log.Duration, + ClientAddr: log.ClientAddr, + ClientName: log.ClientName, + } + if log.Args != nil { + val[i].Args = make([]string, len(log.Args)) + copy(val[i].Args, log.Args) + } + } + } + return &SlowLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +//----------------------------------------------------------------------- + +type Latency struct { + Name string + Time time.Time + Latest time.Duration + Max time.Duration +} + +type LatencyCmd struct { + baseCmd + val []Latency +} + +var _ Cmder = (*LatencyCmd)(nil) + +func NewLatencyCmd(ctx context.Context, args ...interface{}) *LatencyCmd { + return &LatencyCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *LatencyCmd) SetVal(val []Latency) { + cmd.val = val +} + +func (cmd *LatencyCmd) Val() []Latency { + return cmd.val +} + +func (cmd *LatencyCmd) Result() ([]Latency, error) { + return cmd.val, cmd.err +} + +func (cmd *LatencyCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *LatencyCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + cmd.val = make([]Latency, n) + for i := 0; i < len(cmd.val); i++ { + nn, err := rd.ReadArrayLen() + if err != nil { + return err + } + if nn < 3 { + return fmt.Errorf("redis: got %d elements in latency get, expected at least 3", nn) + } + if cmd.val[i].Name, err = rd.ReadString(); err != nil { + return err + } + createdAt, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[i].Time = time.Unix(createdAt, 0) + latest, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[i].Latest = time.Duration(latest) * time.Millisecond + maximum, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val[i].Max = time.Duration(maximum) * time.Millisecond + } + return nil +} + +func (cmd *LatencyCmd) Clone() Cmder { + var val []Latency + if cmd.val != nil { + val = make([]Latency, len(cmd.val)) + copy(val, cmd.val) + } + return &LatencyCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +//----------------------------------------------------------------------- + +// HotKeysSlotRange represents a slot or slot range in the response. +// Single element slice = individual slot, two element slice = slot range [start, end]. +type HotKeysSlotRange []int64 + +// HotKeysKeyEntry represents a hot key entry with its metric value. +type HotKeysKeyEntry struct { + Key string + Value interface{} // Can be int64 or string +} + +// HotKeysResult represents the response data from HOTKEYS GET command. +// Field names match the Redis response format. +type HotKeysResult struct { + TrackingActive bool + SampleRatio uint8 + SelectedSlots []HotKeysSlotRange + SampledCommandsSelectedSlots time.Duration // Present when sample-ratio > 1 and selected-slots is not empty + AllCommandsSelectedSlots time.Duration // Present when selected-slots is not empty + AllCommandsAllSlots time.Duration + NetBytesSampledCommandsSelectedSlots int64 // Present when sample-ratio > 1 and selected-slots is not empty + NetBytesAllCommandsSelectedSlots int64 // Present when selected-slots is not empty + NetBytesAllCommandsAllSlots int64 + CollectionStartTime time.Time + CollectionDuration time.Duration + UsedCPUSys time.Duration + UsedCPUUser time.Duration + TotalNetBytes int64 + ByCPUTime []HotKeysKeyEntry + ByNetBytes []HotKeysKeyEntry +} + +type HotKeysCmd struct { + baseCmd + + val *HotKeysResult +} + +var _ Cmder = (*HotKeysCmd)(nil) + +func NewHotKeysCmd(ctx context.Context, args ...interface{}) *HotKeysCmd { + return &HotKeysCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeHotKeys, + }, + } +} + +func (cmd *HotKeysCmd) SetVal(val *HotKeysResult) { + cmd.val = val +} + +func (cmd *HotKeysCmd) Val() *HotKeysResult { + return cmd.val +} + +func (cmd *HotKeysCmd) Result() (*HotKeysResult, error) { + return cmd.val, cmd.err +} + +func (cmd *HotKeysCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *HotKeysCmd) readReply(rd *proto.Reader) error { + // HOTKEYS GET response is wrapped in an array for aggregation support + arrayLen, err := rd.ReadArrayLen() + if err != nil { + return err + } + + if arrayLen == 0 { + // Empty array means no tracking was started or after reset + cmd.val = nil + return nil + } + + // Read the first (and typically only) element which is a map + n, err := rd.ReadMapLen() + if err != nil { + return err + } + + result := &HotKeysResult{} + data := make(map[string]interface{}, n) + + for i := 0; i < n; i++ { + k, err := rd.ReadString() + if err != nil { + return err + } + v, err := rd.ReadReply() + if err != nil { + if err == Nil { + data[k] = Nil + continue + } + if err, ok := err.(proto.RedisError); ok { + data[k] = err + continue + } + return err + } + data[k] = v + } + + if v, ok := data["tracking-active"].(int64); ok { + result.TrackingActive = v == 1 + } + if v, ok := data["sample-ratio"].(int64); ok { + result.SampleRatio = uint8(v) + } + if v, ok := data["selected-slots"].([]interface{}); ok { + result.SelectedSlots = make([]HotKeysSlotRange, 0, len(v)) + for _, slot := range v { + switch s := slot.(type) { + case int64: + // Single slot + result.SelectedSlots = append(result.SelectedSlots, HotKeysSlotRange{s}) + case []interface{}: + // Slot range + slotRange := make(HotKeysSlotRange, 0, len(s)) + for _, sr := range s { + if val, ok := sr.(int64); ok { + slotRange = append(slotRange, val) + } + } + result.SelectedSlots = append(result.SelectedSlots, slotRange) + } + } + } + if v, ok := data["sampled-commands-selected-slots-us"].(int64); ok { + result.SampledCommandsSelectedSlots = time.Duration(v) * time.Microsecond + } + if v, ok := data["all-commands-selected-slots-us"].(int64); ok { + result.AllCommandsSelectedSlots = time.Duration(v) * time.Microsecond + } + if v, ok := data["all-commands-all-slots-us"].(int64); ok { + result.AllCommandsAllSlots = time.Duration(v) * time.Microsecond + } + if v, ok := data["net-bytes-sampled-commands-selected-slots"].(int64); ok { + result.NetBytesSampledCommandsSelectedSlots = v + } + if v, ok := data["net-bytes-all-commands-selected-slots"].(int64); ok { + result.NetBytesAllCommandsSelectedSlots = v + } + if v, ok := data["net-bytes-all-commands-all-slots"].(int64); ok { + result.NetBytesAllCommandsAllSlots = v + } + if v, ok := data["collection-start-time-unix-ms"].(int64); ok { + result.CollectionStartTime = time.UnixMilli(v) + } + if v, ok := data["collection-duration-ms"].(int64); ok { + result.CollectionDuration = time.Duration(v) * time.Millisecond + } + if v, ok := data["used-cpu-sys-ms"].(int64); ok { + result.UsedCPUSys = time.Duration(v) * time.Millisecond + } + if v, ok := data["used-cpu-user-ms"].(int64); ok { + result.UsedCPUUser = time.Duration(v) * time.Millisecond + } + if v, ok := data["total-net-bytes"].(int64); ok { + result.TotalNetBytes = v + } + + if v, ok := data["by-cpu-time-us"].([]interface{}); ok { + result.ByCPUTime = parseHotKeysKeyEntries(v) + } + + if v, ok := data["by-net-bytes"].([]interface{}); ok { + result.ByNetBytes = parseHotKeysKeyEntries(v) + } + + cmd.val = result + return nil +} + +// parseHotKeysKeyEntries parses the key-value pairs from HOTKEYS GET response. +func parseHotKeysKeyEntries(v []interface{}) []HotKeysKeyEntry { + entries := make([]HotKeysKeyEntry, 0, len(v)/2) + for i := 0; i < len(v); i += 2 { + if i+1 < len(v) { + key, keyOk := v[i].(string) + if keyOk { + entries = append(entries, HotKeysKeyEntry{ + Key: key, + Value: v[i+1], // Can be int64 or string + }) + } + } + } + return entries +} + +func (cmd *HotKeysCmd) Clone() Cmder { + var val *HotKeysResult + if cmd.val != nil { + val = &HotKeysResult{ + TrackingActive: cmd.val.TrackingActive, + SampleRatio: cmd.val.SampleRatio, + SampledCommandsSelectedSlots: cmd.val.SampledCommandsSelectedSlots, + AllCommandsSelectedSlots: cmd.val.AllCommandsSelectedSlots, + AllCommandsAllSlots: cmd.val.AllCommandsAllSlots, + NetBytesSampledCommandsSelectedSlots: cmd.val.NetBytesSampledCommandsSelectedSlots, + NetBytesAllCommandsSelectedSlots: cmd.val.NetBytesAllCommandsSelectedSlots, + NetBytesAllCommandsAllSlots: cmd.val.NetBytesAllCommandsAllSlots, + CollectionStartTime: cmd.val.CollectionStartTime, + CollectionDuration: cmd.val.CollectionDuration, + UsedCPUSys: cmd.val.UsedCPUSys, + UsedCPUUser: cmd.val.UsedCPUUser, + TotalNetBytes: cmd.val.TotalNetBytes, + } + if cmd.val.SelectedSlots != nil { + val.SelectedSlots = make([]HotKeysSlotRange, len(cmd.val.SelectedSlots)) + for i, sr := range cmd.val.SelectedSlots { + val.SelectedSlots[i] = make(HotKeysSlotRange, len(sr)) + copy(val.SelectedSlots[i], sr) + } + } + if cmd.val.ByCPUTime != nil { + val.ByCPUTime = make([]HotKeysKeyEntry, len(cmd.val.ByCPUTime)) + copy(val.ByCPUTime, cmd.val.ByCPUTime) + } + if cmd.val.ByNetBytes != nil { + val.ByNetBytes = make([]HotKeysKeyEntry, len(cmd.val.ByNetBytes)) + copy(val.ByNetBytes, cmd.val.ByNetBytes) + } + } + return &HotKeysCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringInterfaceCmd struct { @@ -3716,8 +5655,9 @@ var _ Cmder = (*MapStringInterfaceCmd)(nil) func NewMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceCmd { return &MapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterface, }, } } @@ -3767,6 +5707,20 @@ func (cmd *MapStringInterfaceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringStringSliceCmd struct { @@ -3780,8 +5734,9 @@ var _ Cmder = (*MapStringStringSliceCmd)(nil) func NewMapStringStringSliceCmd(ctx context.Context, args ...interface{}) *MapStringStringSliceCmd { return &MapStringStringSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringStringSlice, }, } } @@ -3831,6 +5786,25 @@ func (cmd *MapStringStringSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringStringSliceCmd) Clone() Cmder { + var val []map[string]string + if cmd.val != nil { + val = make([]map[string]string, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]string, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringStringSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ----------------------------------------------------------------------- // MapMapStringInterfaceCmd represents a command that returns a map of strings to interface{}. @@ -3842,8 +5816,9 @@ type MapMapStringInterfaceCmd struct { func NewMapMapStringInterfaceCmd(ctx context.Context, args ...interface{}) *MapMapStringInterfaceCmd { return &MapMapStringInterfaceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapMapStringInterface, }, } } @@ -3909,6 +5884,20 @@ func (cmd *MapMapStringInterfaceCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *MapMapStringInterfaceCmd) Clone() Cmder { + var val map[string]interface{} + if cmd.val != nil { + val = make(map[string]interface{}, len(cmd.val)) + for k, v := range cmd.val { + val[k] = v + } + } + return &MapMapStringInterfaceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //----------------------------------------------------------------------- type MapStringInterfaceSliceCmd struct { @@ -3922,8 +5911,9 @@ var _ Cmder = (*MapStringInterfaceSliceCmd)(nil) func NewMapStringInterfaceSliceCmd(ctx context.Context, args ...interface{}) *MapStringInterfaceSliceCmd { return &MapStringInterfaceSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeMapStringInterfaceSlice, }, } } @@ -3974,6 +5964,25 @@ func (cmd *MapStringInterfaceSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *MapStringInterfaceSliceCmd) Clone() Cmder { + var val []map[string]interface{} + if cmd.val != nil { + val = make([]map[string]interface{}, len(cmd.val)) + for i, m := range cmd.val { + if m != nil { + val[i] = make(map[string]interface{}, len(m)) + for k, v := range m { + val[i][k] = v + } + } + } + } + return &MapStringInterfaceSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ type KeyValuesCmd struct { @@ -3988,8 +5997,9 @@ var _ Cmder = (*KeyValuesCmd)(nil) func NewKeyValuesCmd(ctx context.Context, args ...interface{}) *KeyValuesCmd { return &KeyValuesCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyValues, }, } } @@ -4036,6 +6046,19 @@ func (cmd *KeyValuesCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *KeyValuesCmd) Clone() Cmder { + var val []string + if cmd.val != nil { + val = make([]string, len(cmd.val)) + copy(val, cmd.val) + } + return &KeyValuesCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + //------------------------------------------------------------------------------ type ZSliceWithKeyCmd struct { @@ -4050,8 +6073,9 @@ var _ Cmder = (*ZSliceWithKeyCmd)(nil) func NewZSliceWithKeyCmd(ctx context.Context, args ...interface{}) *ZSliceWithKeyCmd { return &ZSliceWithKeyCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeZSliceWithKey, }, } } @@ -4119,6 +6143,19 @@ func (cmd *ZSliceWithKeyCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ZSliceWithKeyCmd) Clone() Cmder { + var val []Z + if cmd.val != nil { + val = make([]Z, len(cmd.val)) + copy(val, cmd.val) + } + return &ZSliceWithKeyCmd{ + baseCmd: cmd.cloneBaseCmd(), + key: cmd.key, + val: val, + } +} + type Function struct { Name string Description string @@ -4143,8 +6180,9 @@ var _ Cmder = (*FunctionListCmd)(nil) func NewFunctionListCmd(ctx context.Context, args ...interface{}) *FunctionListCmd { return &FunctionListCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionList, }, } } @@ -4271,8 +6309,39 @@ func (cmd *FunctionListCmd) readFunctions(rd *proto.Reader) ([]Function, error) return functions, nil } -// FunctionStats contains information about the scripts currently executing on the server, and the available engines -// - Engines: +func (cmd *FunctionListCmd) Clone() Cmder { + var val []Library + if cmd.val != nil { + val = make([]Library, len(cmd.val)) + for i, lib := range cmd.val { + val[i] = Library{ + Name: lib.Name, + Engine: lib.Engine, + Code: lib.Code, + } + if lib.Functions != nil { + val[i].Functions = make([]Function, len(lib.Functions)) + for j, fn := range lib.Functions { + val[i].Functions[j] = Function{ + Name: fn.Name, + Description: fn.Description, + } + if fn.Flags != nil { + val[i].Functions[j].Flags = make([]string, len(fn.Flags)) + copy(val[i].Functions[j].Flags, fn.Flags) + } + } + } + } + } + return &FunctionListCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +// FunctionStats contains information about the scripts currently executing on the server, and the available engines +// - Engines: // Statistics about the engine like number of functions and number of libraries // - RunningScript: // The script currently running on the shard we're connecting to. @@ -4324,8 +6393,9 @@ var _ Cmder = (*FunctionStatsCmd)(nil) func NewFunctionStatsCmd(ctx context.Context, args ...interface{}) *FunctionStatsCmd { return &FunctionStatsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFunctionStats, }, } } @@ -4496,6 +6566,34 @@ func (cmd *FunctionStatsCmd) readRunningScripts(rd *proto.Reader) ([]RunningScri return runningScripts, len(runningScripts) > 0, nil } +func (cmd *FunctionStatsCmd) Clone() Cmder { + val := FunctionStats{ + isRunning: cmd.val.isRunning, + rs: cmd.val.rs, // RunningScript is a simple struct, can be copied directly + } + if cmd.val.Engines != nil { + val.Engines = make([]Engine, len(cmd.val.Engines)) + copy(val.Engines, cmd.val.Engines) + } + if cmd.val.allrs != nil { + val.allrs = make([]RunningScript, len(cmd.val.allrs)) + for i, rs := range cmd.val.allrs { + val.allrs[i] = RunningScript{ + Name: rs.Name, + Duration: rs.Duration, + } + if rs.Command != nil { + val.allrs[i].Command = make([]string, len(rs.Command)) + copy(val.allrs[i].Command, rs.Command) + } + } + } + return &FunctionStatsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // LCSQuery is a parameter used for the LCS command @@ -4559,8 +6657,9 @@ func NewLCSCmd(ctx context.Context, q *LCSQuery) *LCSCmd { } } cmd.baseCmd = baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeLCS, } return cmd @@ -4672,6 +6771,25 @@ func (cmd *LCSCmd) readPosition(rd *proto.Reader) (pos LCSPosition, err error) { return pos, nil } +func (cmd *LCSCmd) Clone() Cmder { + var val *LCSMatch + if cmd.val != nil { + val = &LCSMatch{ + MatchString: cmd.val.MatchString, + Len: cmd.val.Len, + } + if cmd.val.Matches != nil { + val.Matches = make([]LCSMatchedPosition, len(cmd.val.Matches)) + copy(val.Matches, cmd.val.Matches) + } + } + return &LCSCmd{ + baseCmd: cmd.cloneBaseCmd(), + readType: cmd.readType, + val: val, + } +} + // ------------------------------------------------------------------------ type KeyFlags struct { @@ -4690,8 +6808,9 @@ var _ Cmder = (*KeyFlagsCmd)(nil) func NewKeyFlagsCmd(ctx context.Context, args ...interface{}) *KeyFlagsCmd { return &KeyFlagsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeKeyFlags, }, } } @@ -4750,6 +6869,26 @@ func (cmd *KeyFlagsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *KeyFlagsCmd) Clone() Cmder { + var val []KeyFlags + if cmd.val != nil { + val = make([]KeyFlags, len(cmd.val)) + for i, kf := range cmd.val { + val[i] = KeyFlags{ + Key: kf.Key, + } + if kf.Flags != nil { + val[i].Flags = make([]string, len(kf.Flags)) + copy(val[i].Flags, kf.Flags) + } + } + } + return &KeyFlagsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // --------------------------------------------------------------------------------------------------- type ClusterLink struct { @@ -4772,8 +6911,9 @@ var _ Cmder = (*ClusterLinksCmd)(nil) func NewClusterLinksCmd(ctx context.Context, args ...interface{}) *ClusterLinksCmd { return &ClusterLinksCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterLinks, }, } } @@ -4839,6 +6979,18 @@ func (cmd *ClusterLinksCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterLinksCmd) Clone() Cmder { + var val []ClusterLink + if cmd.val != nil { + val = make([]ClusterLink, len(cmd.val)) + copy(val, cmd.val) + } + return &ClusterLinksCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------------------------------------------------------------------------------ type SlotRange struct { @@ -4874,8 +7026,9 @@ var _ Cmder = (*ClusterShardsCmd)(nil) func NewClusterShardsCmd(ctx context.Context, args ...interface{}) *ClusterShardsCmd { return &ClusterShardsCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClusterShards, }, } } @@ -4972,7 +7125,9 @@ func (cmd *ClusterShardsCmd) readReply(rd *proto.Reader) error { case "health": cmd.val[i].Nodes[k].Health, err = rd.ReadString() default: - return fmt.Errorf("redis: unexpected key %q in CLUSTER SHARDS node reply", nodeKey) + if err = rd.DiscardNext(); err != nil { + return err + } } if err != nil { @@ -4981,7 +7136,9 @@ func (cmd *ClusterShardsCmd) readReply(rd *proto.Reader) error { } } default: - return fmt.Errorf("redis: unexpected key %q in CLUSTER SHARDS reply", key) + if err = rd.DiscardNext(); err != nil { + return err + } } } } @@ -4989,6 +7146,28 @@ func (cmd *ClusterShardsCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ClusterShardsCmd) Clone() Cmder { + var val []ClusterShard + if cmd.val != nil { + val = make([]ClusterShard, len(cmd.val)) + for i, shard := range cmd.val { + val[i] = ClusterShard{} + if shard.Slots != nil { + val[i].Slots = make([]SlotRange, len(shard.Slots)) + copy(val[i].Slots, shard.Slots) + } + if shard.Nodes != nil { + val[i].Nodes = make([]Node, len(shard.Nodes)) + copy(val[i].Nodes, shard.Nodes) + } + } + } + return &ClusterShardsCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ----------------------------------------- type RankScore struct { @@ -5007,8 +7186,9 @@ var _ Cmder = (*RankWithScoreCmd)(nil) func NewRankWithScoreCmd(ctx context.Context, args ...interface{}) *RankWithScoreCmd { return &RankWithScoreCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeRankWithScore, }, } } @@ -5049,6 +7229,13 @@ func (cmd *RankWithScoreCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *RankWithScoreCmd) Clone() Cmder { + return &RankWithScoreCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // RankScore is a simple struct, can be copied directly + } +} + // -------------------------------------------------------------------------------------------------- // ClientFlags is redis-server client flags, copy from redis/src/server.h (redis 7.0) @@ -5134,6 +7321,9 @@ type ClientInfo struct { OutputListLength int // oll, output list length (replies are queued in this list when the buffer is full) OutputMemory int // omem, output buffer memory usage TotalMemory int // tot-mem, total memory consumed by this client in its various buffers + TotalNetIn int // tot-net-in, total network input + TotalNetOut int // tot-net-out, total network output + TotalCmds int // tot-cmds, total number of commands processed IoThread int // io-thread id Events string // file descriptor events (see below) LastCmd string // cmd, last command played @@ -5142,6 +7332,9 @@ type ClientInfo struct { Resp int // redis version 7.0, client RESP protocol version LibName string // redis version 7.2, client library name LibVer string // redis version 7.2, client library version + ReadEvents uint64 // redis version 8.8, number of read events processed + AvgPipelineLenSum uint64 // redis version 8.8, sum of pipeline lengths + AvgPipelineLenCnt uint64 // redis version 8.8, count of pipeline operations } type ClientInfoCmd struct { @@ -5155,8 +7348,9 @@ var _ Cmder = (*ClientInfoCmd)(nil) func NewClientInfoCmd(ctx context.Context, args ...interface{}) *ClientInfoCmd { return &ClientInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeClientInfo, }, } } @@ -5299,6 +7493,12 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { info.OutputMemory, err = strconv.Atoi(val) case "tot-mem": info.TotalMemory, err = strconv.Atoi(val) + case "tot-net-in": + info.TotalNetIn, err = strconv.Atoi(val) + case "tot-net-out": + info.TotalNetOut, err = strconv.Atoi(val) + case "tot-cmds": + info.TotalCmds, err = strconv.Atoi(val) case "events": info.Events = val case "cmd": @@ -5315,8 +7515,14 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { info.LibVer = val case "io-thread": info.IoThread, err = strconv.Atoi(val) + case "read-events": + info.ReadEvents, err = strconv.ParseUint(val, 10, 64) + case "avg-pipeline-len-sum": + info.AvgPipelineLenSum, err = strconv.ParseUint(val, 10, 64) + case "avg-pipeline-len-cnt": + info.AvgPipelineLenCnt, err = strconv.ParseUint(val, 10, 64) default: - return nil, fmt.Errorf("redis: unexpected client info key(%s)", key) + // skip unknown fields } if err != nil { @@ -5327,6 +7533,53 @@ func parseClientInfo(txt string) (info *ClientInfo, err error) { return info, nil } +func (cmd *ClientInfoCmd) Clone() Cmder { + var val *ClientInfo + if cmd.val != nil { + val = &ClientInfo{ + ID: cmd.val.ID, + Addr: cmd.val.Addr, + LAddr: cmd.val.LAddr, + FD: cmd.val.FD, + Name: cmd.val.Name, + Age: cmd.val.Age, + Idle: cmd.val.Idle, + Flags: cmd.val.Flags, + DB: cmd.val.DB, + Sub: cmd.val.Sub, + PSub: cmd.val.PSub, + SSub: cmd.val.SSub, + Multi: cmd.val.Multi, + Watch: cmd.val.Watch, + QueryBuf: cmd.val.QueryBuf, + QueryBufFree: cmd.val.QueryBufFree, + ArgvMem: cmd.val.ArgvMem, + MultiMem: cmd.val.MultiMem, + BufferSize: cmd.val.BufferSize, + BufferPeak: cmd.val.BufferPeak, + OutputBufferLength: cmd.val.OutputBufferLength, + OutputListLength: cmd.val.OutputListLength, + OutputMemory: cmd.val.OutputMemory, + TotalMemory: cmd.val.TotalMemory, + IoThread: cmd.val.IoThread, + Events: cmd.val.Events, + LastCmd: cmd.val.LastCmd, + User: cmd.val.User, + Redir: cmd.val.Redir, + Resp: cmd.val.Resp, + LibName: cmd.val.LibName, + LibVer: cmd.val.LibVer, + ReadEvents: cmd.val.ReadEvents, + AvgPipelineLenSum: cmd.val.AvgPipelineLenSum, + AvgPipelineLenCnt: cmd.val.AvgPipelineLenCnt, + } + } + return &ClientInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // ------------------------------------------- type ACLLogEntry struct { @@ -5353,8 +7606,9 @@ var _ Cmder = (*ACLLogCmd)(nil) func NewACLLogCmd(ctx context.Context, args ...interface{}) *ACLLogCmd { return &ACLLogCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeACLLog, }, } } @@ -5424,7 +7678,10 @@ func (cmd *ACLLogCmd) readReply(rd *proto.Reader) error { case "timestamp-last-updated": entry.TimestampLastUpdated, err = rd.ReadInt() default: - return fmt.Errorf("redis: unexpected key %q in ACL LOG reply", key) + // skip unknown fields + if err := rd.DiscardNext(); err != nil { + return err + } } if err != nil { @@ -5436,6 +7693,72 @@ func (cmd *ACLLogCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *ACLLogCmd) Clone() Cmder { + var val []*ACLLogEntry + if cmd.val != nil { + val = make([]*ACLLogEntry, len(cmd.val)) + for i, entry := range cmd.val { + if entry != nil { + val[i] = &ACLLogEntry{ + Count: entry.Count, + Reason: entry.Reason, + Context: entry.Context, + Object: entry.Object, + Username: entry.Username, + AgeSeconds: entry.AgeSeconds, + EntryID: entry.EntryID, + TimestampCreated: entry.TimestampCreated, + TimestampLastUpdated: entry.TimestampLastUpdated, + } + // Clone ClientInfo if present + if entry.ClientInfo != nil { + val[i].ClientInfo = &ClientInfo{ + ID: entry.ClientInfo.ID, + Addr: entry.ClientInfo.Addr, + LAddr: entry.ClientInfo.LAddr, + FD: entry.ClientInfo.FD, + Name: entry.ClientInfo.Name, + Age: entry.ClientInfo.Age, + Idle: entry.ClientInfo.Idle, + Flags: entry.ClientInfo.Flags, + DB: entry.ClientInfo.DB, + Sub: entry.ClientInfo.Sub, + PSub: entry.ClientInfo.PSub, + SSub: entry.ClientInfo.SSub, + Multi: entry.ClientInfo.Multi, + Watch: entry.ClientInfo.Watch, + QueryBuf: entry.ClientInfo.QueryBuf, + QueryBufFree: entry.ClientInfo.QueryBufFree, + ArgvMem: entry.ClientInfo.ArgvMem, + MultiMem: entry.ClientInfo.MultiMem, + BufferSize: entry.ClientInfo.BufferSize, + BufferPeak: entry.ClientInfo.BufferPeak, + OutputBufferLength: entry.ClientInfo.OutputBufferLength, + OutputListLength: entry.ClientInfo.OutputListLength, + OutputMemory: entry.ClientInfo.OutputMemory, + TotalMemory: entry.ClientInfo.TotalMemory, + IoThread: entry.ClientInfo.IoThread, + Events: entry.ClientInfo.Events, + LastCmd: entry.ClientInfo.LastCmd, + User: entry.ClientInfo.User, + Redir: entry.ClientInfo.Redir, + Resp: entry.ClientInfo.Resp, + LibName: entry.ClientInfo.LibName, + LibVer: entry.ClientInfo.LibVer, + ReadEvents: entry.ClientInfo.ReadEvents, + AvgPipelineLenSum: entry.ClientInfo.AvgPipelineLenSum, + AvgPipelineLenCnt: entry.ClientInfo.AvgPipelineLenCnt, + } + } + } + } + } + return &ACLLogCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // LibraryInfo holds the library info. type LibraryInfo struct { LibName *string @@ -5464,8 +7787,9 @@ var _ Cmder = (*InfoCmd)(nil) func NewInfoCmd(ctx context.Context, args ...interface{}) *InfoCmd { return &InfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeInfo, }, } } @@ -5531,6 +7855,25 @@ func (cmd *InfoCmd) Item(section, key string) string { } } +func (cmd *InfoCmd) Clone() Cmder { + var val map[string]map[string]string + if cmd.val != nil { + val = make(map[string]map[string]string, len(cmd.val)) + for section, sectionMap := range cmd.val { + if sectionMap != nil { + val[section] = make(map[string]string, len(sectionMap)) + for k, v := range sectionMap { + val[section][k] = v + } + } + } + } + return &InfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + type MonitorStatus int const ( @@ -5549,8 +7892,9 @@ type MonitorCmd struct { func newMonitorCmd(ctx context.Context, ch chan string) *MonitorCmd { return &MonitorCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: []interface{}{"monitor"}, + ctx: ctx, + args: []interface{}{"monitor"}, + cmdType: CmdTypeMonitor, }, ch: ch, status: monitorStatusIdle, @@ -5615,3 +7959,1158 @@ func (cmd *MonitorCmd) Stop() { defer cmd.mu.Unlock() cmd.status = monitorStatusStop } + +type VectorScoreSliceCmd struct { + baseCmd + + val []VectorScore +} + +var _ Cmder = (*VectorScoreSliceCmd)(nil) + +func NewVectorScoreSliceCmd(ctx context.Context, args ...any) *VectorScoreSliceCmd { + return &VectorScoreSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +// NewVectorInfoSliceCmd is an alias for NewVectorScoreSliceCmd kept for backwards compatibility. +func NewVectorInfoSliceCmd(ctx context.Context, args ...any) *VectorScoreSliceCmd { + return NewVectorScoreSliceCmd(ctx, args...) +} + +func (cmd *VectorScoreSliceCmd) SetVal(val []VectorScore) { + cmd.val = val +} + +func (cmd *VectorScoreSliceCmd) Val() []VectorScore { + return cmd.val +} + +func (cmd *VectorScoreSliceCmd) Result() ([]VectorScore, error) { + return cmd.val, cmd.err +} + +func (cmd *VectorScoreSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { + typ, err := rd.PeekReplyType() + if err != nil { + return err + } + + var n int + if typ == proto.RespMap { + n, err = rd.ReadMapLen() + if err != nil { + return err + } + } else { + // RESP2 returns a flat array [name, score, name, score, ...] + n, err = rd.ReadArrayLen() + if err != nil { + return err + } + if n%2 != 0 { + return fmt.Errorf("redis: VectorScoreSliceCmd expects even number of elements, got %d", n) + } + n /= 2 + } + + cmd.val = make([]VectorScore, n) + for i := 0; i < n; i++ { + name, err := rd.ReadString() + if err != nil { + return err + } + cmd.val[i].Name = name + + score, err := rd.ReadFloat() + if err != nil { + return err + } + cmd.val[i].Score = score + } + + return nil +} + +func (cmd *VectorScoreSliceCmd) Clone() Cmder { + return &VectorScoreSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +// VectorScoreSliceSliceCmd is used for VLINKS WITHSCORES which returns an array of arrays. +// In RESP3, each inner array contains maps of element -> score. +type VectorScoreSliceSliceCmd struct { + baseCmd + + val [][]VectorScore +} + +var _ Cmder = (*VectorScoreSliceSliceCmd)(nil) + +func NewVectorScoreSliceSliceCmd(ctx context.Context, args ...any) *VectorScoreSliceSliceCmd { + return &VectorScoreSliceSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *VectorScoreSliceSliceCmd) SetVal(val [][]VectorScore) { + cmd.val = val +} + +func (cmd *VectorScoreSliceSliceCmd) Val() [][]VectorScore { + return cmd.val +} + +func (cmd *VectorScoreSliceSliceCmd) Result() ([][]VectorScore, error) { + return cmd.val, cmd.err +} + +func (cmd *VectorScoreSliceSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *VectorScoreSliceSliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + + cmd.val = make([][]VectorScore, n) + for i := range n { + // Each level can be either a map (RESP3) or an array (RESP2) + levelTyp, err := rd.PeekReplyType() + if err != nil { + return err + } + + if levelTyp == proto.RespMap { + // RESP3 format: each level is a map {element: score, element: score, ...} + mapLen, err := rd.ReadMapLen() + if err != nil { + return err + } + + cmd.val[i] = make([]VectorScore, mapLen) + for j := range mapLen { + name, err := rd.ReadString() + if err != nil { + return err + } + score, err := rd.ReadFloat() + if err != nil { + return err + } + cmd.val[i][j] = VectorScore{Name: name, Score: score} + } + } else { + // RESP2 format: each level is an array of [element, score, element, score, ...] pairs + innerLen, err := rd.ReadArrayLen() + if err != nil { + return err + } + + if innerLen%2 != 0 { + return fmt.Errorf("redis: got %d elements in the VLINKS array, wanted a multiple of 2", innerLen) + } + + cmd.val[i] = make([]VectorScore, innerLen/2) + for j := 0; j < innerLen; j += 2 { + name, err := rd.ReadString() + if err != nil { + return err + } + score, err := rd.ReadFloat() + if err != nil { + return err + } + cmd.val[i][j/2] = VectorScore{Name: name, Score: score} + } + } + } + + return nil +} + +func (cmd *VectorScoreSliceSliceCmd) Clone() Cmder { + var val [][]VectorScore + if cmd.val != nil { + val = make([][]VectorScore, len(cmd.val)) + for i, slice := range cmd.val { + if slice != nil { + val[i] = make([]VectorScore, len(slice)) + copy(val[i], slice) + } + } + } + return &VectorScoreSliceSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + +func readVectorAttribStringOrNil(rd *proto.Reader) (*string, error) { + v, err := rd.ReadReply() + if err != nil { + if err == proto.Nil { + return nil, nil + } + return nil, err + } + s, ok := v.(string) + if !ok { + return nil, fmt.Errorf("redis: can't parse reply=%T reading string", v) + } + return &s, nil +} + +type VectorAttribSliceCmd struct { + baseCmd + + val []VectorAttrib +} + +var _ Cmder = (*VectorAttribSliceCmd)(nil) + +func NewVectorAttribSliceCmd(ctx context.Context, args ...any) *VectorAttribSliceCmd { + return &VectorAttribSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *VectorAttribSliceCmd) SetVal(val []VectorAttrib) { + cmd.val = val +} + +func (cmd *VectorAttribSliceCmd) Val() []VectorAttrib { + return cmd.val +} + +func (cmd *VectorAttribSliceCmd) Result() ([]VectorAttrib, error) { + return cmd.val, cmd.err +} + +func (cmd *VectorAttribSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *VectorAttribSliceCmd) readReply(rd *proto.Reader) error { + replyType, err := rd.PeekReplyType() + if err != nil { + return err + } + + if replyType == proto.RespMap { + n, err := rd.ReadMapLen() + if err != nil { + return err + } + cmd.val = make([]VectorAttrib, n) + for i := 0; i < n; i++ { + name, err := rd.ReadString() + if err != nil { + return err + } + attrib, err := readVectorAttribStringOrNil(rd) + if err != nil { + return err + } + cmd.val[i] = VectorAttrib{Name: name, Attribs: attrib} + } + return nil + } + + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + if n%2 != 0 { + return fmt.Errorf("redis: got %d elements in the VSIM array, wanted a multiple of 2", n) + } + cmd.val = make([]VectorAttrib, n/2) + for i := range cmd.val { + name, err := rd.ReadString() + if err != nil { + return err + } + attrib, err := readVectorAttribStringOrNil(rd) + if err != nil { + return err + } + cmd.val[i] = VectorAttrib{Name: name, Attribs: attrib} + } + return nil +} + +func (cmd *VectorAttribSliceCmd) Clone() Cmder { + return &VectorAttribSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +type VectorScoreAttribSliceCmd struct { + baseCmd + + val []VectorScoreAttrib +} + +var _ Cmder = (*VectorScoreAttribSliceCmd)(nil) + +func NewVectorScoreAttribSliceCmd(ctx context.Context, args ...any) *VectorScoreAttribSliceCmd { + return &VectorScoreAttribSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *VectorScoreAttribSliceCmd) SetVal(val []VectorScoreAttrib) { + cmd.val = val +} + +func (cmd *VectorScoreAttribSliceCmd) Val() []VectorScoreAttrib { + return cmd.val +} + +func (cmd *VectorScoreAttribSliceCmd) Result() ([]VectorScoreAttrib, error) { + return cmd.val, cmd.err +} + +func (cmd *VectorScoreAttribSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *VectorScoreAttribSliceCmd) readReply(rd *proto.Reader) error { + replyType, err := rd.PeekReplyType() + if err != nil { + return err + } + + if replyType == proto.RespMap { + n, err := rd.ReadMapLen() + if err != nil { + return err + } + cmd.val = make([]VectorScoreAttrib, n) + for i := 0; i < n; i++ { + name, err := rd.ReadString() + if err != nil { + return err + } + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + score, err := rd.ReadFloat() + if err != nil { + return err + } + attrib, err := readVectorAttribStringOrNil(rd) + if err != nil { + return err + } + cmd.val[i] = VectorScoreAttrib{Name: name, Score: score, Attribs: attrib} + } + return nil + } + + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + if n%3 != 0 { + return fmt.Errorf("redis: got %d elements in the VSIM array, wanted a multiple of 3", n) + } + cmd.val = make([]VectorScoreAttrib, n/3) + for i := range cmd.val { + name, err := rd.ReadString() + if err != nil { + return err + } + score, err := rd.ReadFloat() + if err != nil { + return err + } + attrib, err := readVectorAttribStringOrNil(rd) + if err != nil { + return err + } + cmd.val[i] = VectorScoreAttrib{Name: name, Score: score, Attribs: attrib} + } + return nil +} + +func (cmd *VectorScoreAttribSliceCmd) Clone() Cmder { + return &VectorScoreAttribSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +func (cmd *MonitorCmd) Clone() Cmder { + // MonitorCmd cannot be safely cloned due to channels and goroutines + // Return a new MonitorCmd with the same channel + return newMonitorCmd(cmd.ctx, cmd.ch) +} + +// ExtractCommandValue extracts the value from a command result using the fast enum-based approach +func ExtractCommandValue(cmd interface{}) (interface{}, error) { + // First try to get the command type using the interface + if cmdTypeGetter, ok := cmd.(CmdTypeGetter); ok { + cmdType := cmdTypeGetter.GetCmdType() + + // Use fast type-based extraction + switch cmdType { + case CmdTypeGeneric: + if genericCmd, ok := cmd.(interface { + Val() interface{} + Err() error + }); ok { + return genericCmd.Val(), genericCmd.Err() + } + case CmdTypeString: + if stringCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return stringCmd.Val(), stringCmd.Err() + } + case CmdTypeInt: + if intCmd, ok := cmd.(interface { + Val() int64 + Err() error + }); ok { + return intCmd.Val(), intCmd.Err() + } + case CmdTypeUint: + if uintCmd, ok := cmd.(interface { + Val() uint64 + Err() error + }); ok { + return uintCmd.Val(), uintCmd.Err() + } + case CmdTypeBool: + if boolCmd, ok := cmd.(interface { + Val() bool + Err() error + }); ok { + return boolCmd.Val(), boolCmd.Err() + } + case CmdTypeFloat: + if floatCmd, ok := cmd.(interface { + Val() float64 + Err() error + }); ok { + return floatCmd.Val(), floatCmd.Err() + } + case CmdTypeStatus: + if statusCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return statusCmd.Val(), statusCmd.Err() + } + case CmdTypeDuration: + if durationCmd, ok := cmd.(interface { + Val() time.Duration + Err() error + }); ok { + return durationCmd.Val(), durationCmd.Err() + } + case CmdTypeTime: + if timeCmd, ok := cmd.(interface { + Val() time.Time + Err() error + }); ok { + return timeCmd.Val(), timeCmd.Err() + } + case CmdTypeStringStructMap: + if structMapCmd, ok := cmd.(interface { + Val() map[string]struct{} + Err() error + }); ok { + return structMapCmd.Val(), structMapCmd.Err() + } + case CmdTypeXMessageSlice: + if xMessageSliceCmd, ok := cmd.(interface { + Val() []XMessage + Err() error + }); ok { + return xMessageSliceCmd.Val(), xMessageSliceCmd.Err() + } + case CmdTypeXStreamSlice: + if xStreamSliceCmd, ok := cmd.(interface { + Val() []XStream + Err() error + }); ok { + return xStreamSliceCmd.Val(), xStreamSliceCmd.Err() + } + case CmdTypeXPending: + if xPendingCmd, ok := cmd.(interface { + Val() *XPending + Err() error + }); ok { + return xPendingCmd.Val(), xPendingCmd.Err() + } + case CmdTypeXPendingExt: + if xPendingExtCmd, ok := cmd.(interface { + Val() []XPendingExt + Err() error + }); ok { + return xPendingExtCmd.Val(), xPendingExtCmd.Err() + } + case CmdTypeXAutoClaim: + if xAutoClaimCmd, ok := cmd.(interface { + Val() ([]XMessage, string) + Err() error + }); ok { + messages, start := xAutoClaimCmd.Val() + return CmdTypeXAutoClaimValue{messages: messages, start: start}, xAutoClaimCmd.Err() + } + case CmdTypeXAutoClaimWithDeleted: + if xAutoClaimWithDeletedCmd, ok := cmd.(interface { + Val() ([]XMessage, string, []string) + Err() error + }); ok { + messages, start, deletedIDs := xAutoClaimWithDeletedCmd.Val() + return CmdTypeXAutoClaimWithDeletedValue{messages: messages, start: start, deletedIDs: deletedIDs}, xAutoClaimWithDeletedCmd.Err() + } + case CmdTypeXAutoClaimJustID: + if xAutoClaimJustIDCmd, ok := cmd.(interface { + Val() ([]string, string) + Err() error + }); ok { + ids, start := xAutoClaimJustIDCmd.Val() + return CmdTypeXAutoClaimJustIDValue{ids: ids, start: start}, xAutoClaimJustIDCmd.Err() + } + case CmdTypeXInfoConsumers: + if xInfoConsumersCmd, ok := cmd.(interface { + Val() []XInfoConsumer + Err() error + }); ok { + return xInfoConsumersCmd.Val(), xInfoConsumersCmd.Err() + } + case CmdTypeXInfoGroups: + if xInfoGroupsCmd, ok := cmd.(interface { + Val() []XInfoGroup + Err() error + }); ok { + return xInfoGroupsCmd.Val(), xInfoGroupsCmd.Err() + } + case CmdTypeXInfoStream: + if xInfoStreamCmd, ok := cmd.(interface { + Val() *XInfoStream + Err() error + }); ok { + return xInfoStreamCmd.Val(), xInfoStreamCmd.Err() + } + case CmdTypeXInfoStreamFull: + if xInfoStreamFullCmd, ok := cmd.(interface { + Val() *XInfoStreamFull + Err() error + }); ok { + return xInfoStreamFullCmd.Val(), xInfoStreamFullCmd.Err() + } + case CmdTypeZSlice: + if zSliceCmd, ok := cmd.(interface { + Val() []Z + Err() error + }); ok { + return zSliceCmd.Val(), zSliceCmd.Err() + } + case CmdTypeZWithKey: + if zWithKeyCmd, ok := cmd.(interface { + Val() *ZWithKey + Err() error + }); ok { + return zWithKeyCmd.Val(), zWithKeyCmd.Err() + } + case CmdTypeScan: + if scanCmd, ok := cmd.(interface { + Val() ([]string, uint64) + Err() error + }); ok { + keys, cursor := scanCmd.Val() + return CmdTypeScanValue{keys: keys, cursor: cursor}, scanCmd.Err() + } + case CmdTypeClusterSlots: + if clusterSlotsCmd, ok := cmd.(interface { + Val() []ClusterSlot + Err() error + }); ok { + return clusterSlotsCmd.Val(), clusterSlotsCmd.Err() + } + case CmdTypeGeoLocation: + if geoLocationCmd, ok := cmd.(interface { + Val() []GeoLocation + Err() error + }); ok { + return geoLocationCmd.Val(), geoLocationCmd.Err() + } + case CmdTypeGeoSearchLocation: + if geoSearchLocationCmd, ok := cmd.(interface { + Val() []GeoLocation + Err() error + }); ok { + return geoSearchLocationCmd.Val(), geoSearchLocationCmd.Err() + } + case CmdTypeGeoPos: + if geoPosCmd, ok := cmd.(interface { + Val() []*GeoPos + Err() error + }); ok { + return geoPosCmd.Val(), geoPosCmd.Err() + } + case CmdTypeCommandsInfo: + if commandsInfoCmd, ok := cmd.(interface { + Val() map[string]*CommandInfo + Err() error + }); ok { + return commandsInfoCmd.Val(), commandsInfoCmd.Err() + } + case CmdTypeSlowLog: + if slowLogCmd, ok := cmd.(interface { + Val() []SlowLog + Err() error + }); ok { + return slowLogCmd.Val(), slowLogCmd.Err() + } + case CmdTypeHotKeys: + if hotKeysCmd, ok := cmd.(interface { + Val() *HotKeysResult + Err() error + }); ok { + return hotKeysCmd.Val(), hotKeysCmd.Err() + } + case CmdTypeIncrEXInt: + if incrEXCmd, ok := cmd.(interface { + Val() IncrEXIntResult + Err() error + }); ok { + return incrEXCmd.Val(), incrEXCmd.Err() + } + case CmdTypeIncrEXFloat: + if incrEXCmd, ok := cmd.(interface { + Val() IncrEXFloatResult + Err() error + }); ok { + return incrEXCmd.Val(), incrEXCmd.Err() + } + case CmdTypeKeyValues: + if keyValuesCmd, ok := cmd.(interface { + Val() (string, []string) + Err() error + }); ok { + key, values := keyValuesCmd.Val() + return CmdTypeKeyValuesValue{key: key, values: values}, keyValuesCmd.Err() + } + case CmdTypeZSliceWithKey: + if zSliceWithKeyCmd, ok := cmd.(interface { + Val() (string, []Z) + Err() error + }); ok { + key, zSlice := zSliceWithKeyCmd.Val() + return CmdTypeZSliceWithKeyValue{key: key, zSlice: zSlice}, zSliceWithKeyCmd.Err() + } + case CmdTypeFunctionList: + if functionListCmd, ok := cmd.(interface { + Val() []Library + Err() error + }); ok { + return functionListCmd.Val(), functionListCmd.Err() + } + case CmdTypeFunctionStats: + if functionStatsCmd, ok := cmd.(interface { + Val() FunctionStats + Err() error + }); ok { + return functionStatsCmd.Val(), functionStatsCmd.Err() + } + case CmdTypeLCS: + if lcsCmd, ok := cmd.(interface { + Val() *LCSMatch + Err() error + }); ok { + return lcsCmd.Val(), lcsCmd.Err() + } + case CmdTypeKeyFlags: + if keyFlagsCmd, ok := cmd.(interface { + Val() []KeyFlags + Err() error + }); ok { + return keyFlagsCmd.Val(), keyFlagsCmd.Err() + } + case CmdTypeClusterLinks: + if clusterLinksCmd, ok := cmd.(interface { + Val() []ClusterLink + Err() error + }); ok { + return clusterLinksCmd.Val(), clusterLinksCmd.Err() + } + case CmdTypeClusterShards: + if clusterShardsCmd, ok := cmd.(interface { + Val() []ClusterShard + Err() error + }); ok { + return clusterShardsCmd.Val(), clusterShardsCmd.Err() + } + case CmdTypeRankWithScore: + if rankWithScoreCmd, ok := cmd.(interface { + Val() RankScore + Err() error + }); ok { + return rankWithScoreCmd.Val(), rankWithScoreCmd.Err() + } + case CmdTypeClientInfo: + if clientInfoCmd, ok := cmd.(interface { + Val() *ClientInfo + Err() error + }); ok { + return clientInfoCmd.Val(), clientInfoCmd.Err() + } + case CmdTypeACLLog: + if aclLogCmd, ok := cmd.(interface { + Val() []*ACLLogEntry + Err() error + }); ok { + return aclLogCmd.Val(), aclLogCmd.Err() + } + case CmdTypeInfo: + if infoCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return infoCmd.Val(), infoCmd.Err() + } + case CmdTypeMonitor: + if monitorCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return monitorCmd.Val(), monitorCmd.Err() + } + case CmdTypeJSON: + if jsonCmd, ok := cmd.(interface { + Val() string + Err() error + }); ok { + return jsonCmd.Val(), jsonCmd.Err() + } + case CmdTypeJSONSlice: + if jsonSliceCmd, ok := cmd.(interface { + Val() []interface{} + Err() error + }); ok { + return jsonSliceCmd.Val(), jsonSliceCmd.Err() + } + case CmdTypeIntPointerSlice: + if intPointerSliceCmd, ok := cmd.(interface { + Val() []*int64 + Err() error + }); ok { + return intPointerSliceCmd.Val(), intPointerSliceCmd.Err() + } + case CmdTypeScanDump: + if scanDumpCmd, ok := cmd.(interface { + Val() ScanDump + Err() error + }); ok { + return scanDumpCmd.Val(), scanDumpCmd.Err() + } + case CmdTypeBFInfo: + if bfInfoCmd, ok := cmd.(interface { + Val() BFInfo + Err() error + }); ok { + return bfInfoCmd.Val(), bfInfoCmd.Err() + } + case CmdTypeCFInfo: + if cfInfoCmd, ok := cmd.(interface { + Val() CFInfo + Err() error + }); ok { + return cfInfoCmd.Val(), cfInfoCmd.Err() + } + case CmdTypeCMSInfo: + if cmsInfoCmd, ok := cmd.(interface { + Val() CMSInfo + Err() error + }); ok { + return cmsInfoCmd.Val(), cmsInfoCmd.Err() + } + case CmdTypeTopKInfo: + if topKInfoCmd, ok := cmd.(interface { + Val() TopKInfo + Err() error + }); ok { + return topKInfoCmd.Val(), topKInfoCmd.Err() + } + case CmdTypeTDigestInfo: + if tDigestInfoCmd, ok := cmd.(interface { + Val() TDigestInfo + Err() error + }); ok { + return tDigestInfoCmd.Val(), tDigestInfoCmd.Err() + } + case CmdTypeFTSearch: + if ftSearchCmd, ok := cmd.(interface { + Val() FTSearchResult + Err() error + }); ok { + return ftSearchCmd.Val(), ftSearchCmd.Err() + } + case CmdTypeFTInfo: + if ftInfoCmd, ok := cmd.(interface { + Val() FTInfoResult + Err() error + }); ok { + return ftInfoCmd.Val(), ftInfoCmd.Err() + } + case CmdTypeFTSpellCheck: + if ftSpellCheckCmd, ok := cmd.(interface { + Val() []SpellCheckResult + Err() error + }); ok { + return ftSpellCheckCmd.Val(), ftSpellCheckCmd.Err() + } + case CmdTypeFTSynDump: + if ftSynDumpCmd, ok := cmd.(interface { + Val() []FTSynDumpResult + Err() error + }); ok { + return ftSynDumpCmd.Val(), ftSynDumpCmd.Err() + } + case CmdTypeAggregate: + if aggregateCmd, ok := cmd.(interface { + Val() *FTAggregateResult + Err() error + }); ok { + return aggregateCmd.Val(), aggregateCmd.Err() + } + case CmdTypeTSTimestampValue: + if tsTimestampValueCmd, ok := cmd.(interface { + Val() TSTimestampValue + Err() error + }); ok { + return tsTimestampValueCmd.Val(), tsTimestampValueCmd.Err() + } + case CmdTypeTSTimestampValueSlice: + if tsTimestampValueSliceCmd, ok := cmd.(interface { + Val() []TSTimestampValue + Err() error + }); ok { + return tsTimestampValueSliceCmd.Val(), tsTimestampValueSliceCmd.Err() + } + case CmdTypeStringSlice: + if stringSliceCmd, ok := cmd.(interface { + Val() []string + Err() error + }); ok { + return stringSliceCmd.Val(), stringSliceCmd.Err() + } + case CmdTypeIntSlice: + if intSliceCmd, ok := cmd.(interface { + Val() []int64 + Err() error + }); ok { + return intSliceCmd.Val(), intSliceCmd.Err() + } + case CmdTypeUintSlice: + if uintSliceCmd, ok := cmd.(interface { + Val() []uint64 + Err() error + }); ok { + return uintSliceCmd.Val(), uintSliceCmd.Err() + } + case CmdTypeBoolSlice: + if boolSliceCmd, ok := cmd.(interface { + Val() []bool + Err() error + }); ok { + return boolSliceCmd.Val(), boolSliceCmd.Err() + } + case CmdTypeFloatSlice: + if floatSliceCmd, ok := cmd.(interface { + Val() []float64 + Err() error + }); ok { + return floatSliceCmd.Val(), floatSliceCmd.Err() + } + case CmdTypeSlice: + if sliceCmd, ok := cmd.(interface { + Val() []interface{} + Err() error + }); ok { + return sliceCmd.Val(), sliceCmd.Err() + } + case CmdTypeKeyValueSlice: + if keyValueSliceCmd, ok := cmd.(interface { + Val() []KeyValue + Err() error + }); ok { + return keyValueSliceCmd.Val(), keyValueSliceCmd.Err() + } + case CmdTypeAREntrySlice: + if arEntrySliceCmd, ok := cmd.(interface { + Val() []AREntry + Err() error + }); ok { + return arEntrySliceCmd.Val(), arEntrySliceCmd.Err() + } + case CmdTypeMapStringString: + if mapCmd, ok := cmd.(interface { + Val() map[string]string + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + case CmdTypeMapStringInt: + if mapCmd, ok := cmd.(interface { + Val() map[string]int64 + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + case CmdTypeMapStringInterfaceSlice: + if mapCmd, ok := cmd.(interface { + Val() []map[string]interface{} + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + case CmdTypeMapStringInterface: + if mapCmd, ok := cmd.(interface { + Val() map[string]interface{} + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + case CmdTypeMapStringStringSlice: + if mapCmd, ok := cmd.(interface { + Val() []map[string]string + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + case CmdTypeMapMapStringInterface: + if mapCmd, ok := cmd.(interface { + Val() map[string]interface{} + Err() error + }); ok { + return mapCmd.Val(), mapCmd.Err() + } + default: + // For unknown command types, return nil + return nil, nil + } + } + + // If we can't get the command type, return nil + return nil, nil +} + +//------------------------------------------------------------------------------ + +// IncrEXIntResult is the reply of an INCREX command issued via IncrEXInt. +// Value is the new value of the key; AppliedIncrement is the increment that +// the server actually applied (0 when an out-of-bounds operation was +// rejected, clamped when SATURATE was set). +type IncrEXIntResult struct { + Value int64 + AppliedIncrement int64 +} + +type IncrEXIntCmd struct { + baseCmd + + val IncrEXIntResult +} + +var _ Cmder = (*IncrEXIntCmd)(nil) + +func NewIncrEXIntCmd(ctx context.Context, args ...interface{}) *IncrEXIntCmd { + return &IncrEXIntCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeIncrEXInt, + }, + } +} + +func (cmd *IncrEXIntCmd) SetVal(val IncrEXIntResult) { cmd.val = val } +func (cmd *IncrEXIntCmd) Val() IncrEXIntResult { return cmd.val } +func (cmd *IncrEXIntCmd) Result() (IncrEXIntResult, error) { + return cmd.val, cmd.err +} +func (cmd *IncrEXIntCmd) String() string { return cmdString(cmd, cmd.val) } + +func (cmd *IncrEXIntCmd) readReply(rd *proto.Reader) error { + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + value, err := rd.ReadInt() + if err != nil { + return err + } + applied, err := rd.ReadInt() + if err != nil { + return err + } + cmd.val = IncrEXIntResult{Value: value, AppliedIncrement: applied} + return nil +} + +func (cmd *IncrEXIntCmd) Clone() Cmder { + return &IncrEXIntCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +// IncrEXFloatResult is the reply of an INCREX command issued via IncrEXFloat. +type IncrEXFloatResult struct { + Value float64 + AppliedIncrement float64 +} + +type IncrEXFloatCmd struct { + baseCmd + + val IncrEXFloatResult +} + +var _ Cmder = (*IncrEXFloatCmd)(nil) + +func NewIncrEXFloatCmd(ctx context.Context, args ...interface{}) *IncrEXFloatCmd { + return &IncrEXFloatCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeIncrEXFloat, + }, + } +} + +func (cmd *IncrEXFloatCmd) SetVal(val IncrEXFloatResult) { cmd.val = val } +func (cmd *IncrEXFloatCmd) Val() IncrEXFloatResult { return cmd.val } +func (cmd *IncrEXFloatCmd) Result() (IncrEXFloatResult, error) { + return cmd.val, cmd.err +} +func (cmd *IncrEXFloatCmd) String() string { return cmdString(cmd, cmd.val) } + +func (cmd *IncrEXFloatCmd) readReply(rd *proto.Reader) error { + if err := rd.ReadFixedArrayLen(2); err != nil { + return err + } + value, err := rd.ReadFloat() + if err != nil { + return err + } + applied, err := rd.ReadFloat() + if err != nil { + return err + } + cmd.val = IncrEXFloatResult{Value: value, AppliedIncrement: applied} + return nil +} + +func (cmd *IncrEXFloatCmd) Clone() Cmder { + return &IncrEXFloatCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + } +} + +//------------------------------------------------------------------------------ + +// AREntrySliceCmd is a command that returns index-value pairs from ARSCAN or ARGREP. +type AREntrySliceCmd struct { + baseCmd + val []AREntry +} + +var _ Cmder = (*AREntrySliceCmd)(nil) + +func NewAREntrySliceCmd(ctx context.Context, args ...any) *AREntrySliceCmd { + return &AREntrySliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + cmdType: CmdTypeAREntrySlice, + }, + } +} + +func (cmd *AREntrySliceCmd) SetVal(val []AREntry) { + cmd.val = val +} + +func (cmd *AREntrySliceCmd) Val() []AREntry { + return cmd.val +} + +func (cmd *AREntrySliceCmd) Result() ([]AREntry, error) { + return cmd.val, cmd.err +} + +func (cmd *AREntrySliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *AREntrySliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadArrayLen() + if err != nil { + return err + } + if n == 0 { + cmd.val = make([]AREntry, 0) + return nil + } + + cmd.val = make([]AREntry, n) + for i := range n { + if err = rd.ReadFixedArrayLen(2); err != nil { + return err + } + + cmd.val[i].Index, err = rd.ReadUint() + if err != nil { + return err + } + + cmd.val[i].Value, err = rd.ReadString() + if err != nil { + return err + } + } + return nil +} + +func (cmd *AREntrySliceCmd) Clone() Cmder { + var val []AREntry + if cmd.val != nil { + val = make([]AREntry, len(cmd.val)) + copy(val, cmd.val) + } + return &AREntrySliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} diff --git a/vendor/github.com/redis/go-redis/v9/command_policy_resolver.go b/vendor/github.com/redis/go-redis/v9/command_policy_resolver.go new file mode 100644 index 000000000..da8c6d314 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/command_policy_resolver.go @@ -0,0 +1,209 @@ +package redis + +import ( + "context" + "strings" + + "github.com/redis/go-redis/v9/internal/routing" +) + +type ( + module = string + commandName = string +) + +var defaultPolicies = map[module]map[commandName]*routing.CommandPolicy{ + "ft": { + "create": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "search": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "aggregate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "dictadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dictdump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "dictdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "suglen": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "cursor": { + Request: routing.ReqSpecial, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "sugadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "sugget": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "sugdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultHashSlot, + }, + "spellcheck": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "explain": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "explaincli": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "aliasadd": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "aliasdel": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "info": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "tagvals": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "syndump": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "synupdate": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "profile": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + Tips: map[string]string{ + routing.ReadOnlyCMD: "", + }, + }, + "alter": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "dropindex": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + "drop": { + Request: routing.ReqDefault, + Response: routing.RespDefaultKeyless, + }, + }, +} + +type CommandInfoResolveFunc func(ctx context.Context, cmd Cmder) *routing.CommandPolicy + +type commandInfoResolver struct { + resolveFunc CommandInfoResolveFunc + fallBackResolver *commandInfoResolver +} + +func NewCommandInfoResolver(resolveFunc CommandInfoResolveFunc) *commandInfoResolver { + return &commandInfoResolver{ + resolveFunc: resolveFunc, + } +} + +func NewDefaultCommandPolicyResolver() *commandInfoResolver { + return NewCommandInfoResolver(func(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + module := "core" + command := cmd.Name() + cmdParts := strings.Split(command, ".") + if len(cmdParts) == 2 { + module = cmdParts[0] + command = cmdParts[1] + } + + if policy, ok := defaultPolicies[module][command]; ok { + return policy + } + + return nil + }) +} + +func (r *commandInfoResolver) GetCommandPolicy(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if r.resolveFunc == nil { + return nil + } + + policy := r.resolveFunc(ctx, cmd) + if policy != nil { + return policy + } + + if r.fallBackResolver != nil { + return r.fallBackResolver.GetCommandPolicy(ctx, cmd) + } + + return nil +} + +func (r *commandInfoResolver) SetFallbackResolver(fallbackResolver *commandInfoResolver) { + r.fallBackResolver = fallbackResolver +} diff --git a/vendor/github.com/redis/go-redis/v9/commands.go b/vendor/github.com/redis/go-redis/v9/commands.go index 271323242..d347ffeb5 100644 --- a/vendor/github.com/redis/go-redis/v9/commands.go +++ b/vendor/github.com/redis/go-redis/v9/commands.go @@ -55,6 +55,11 @@ func appendArgs(dst, src []interface{}) []interface{} { return appendArg(dst, src[0]) } + if cap(dst) < len(dst)+len(src) { + newDst := make([]interface{}, len(dst), len(dst)+len(src)) + copy(newDst, dst) + dst = newDst + } dst = append(dst, src...) return dst } @@ -193,6 +198,7 @@ type Cmdable interface { ClientID(ctx context.Context) *IntCmd ClientUnblock(ctx context.Context, id int64) *IntCmd ClientUnblockWithError(ctx context.Context, id int64) *IntCmd + ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd ConfigResetStat(ctx context.Context) *StatusCmd ConfigSet(ctx context.Context, parameter, value string) *StatusCmd @@ -209,14 +215,20 @@ type Cmdable interface { ShutdownSave(ctx context.Context) *StatusCmd ShutdownNoSave(ctx context.Context) *StatusCmd SlaveOf(ctx context.Context, host, port string) *StatusCmd + ReplicaOf(ctx context.Context, host, port string) *StatusCmd SlowLogGet(ctx context.Context, num int64) *SlowLogCmd + SlowLogLen(ctx context.Context) *IntCmd + SlowLogReset(ctx context.Context) *StatusCmd Time(ctx context.Context) *TimeCmd DebugObject(ctx context.Context, key string) *StringCmd MemoryUsage(ctx context.Context, key string, samples ...int) *IntCmd + Latency(ctx context.Context) *LatencyCmd + LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd ACLCmdable + ArrayCmdable BitMapCmdable ClusterCmdable GenericCmdable @@ -234,6 +246,7 @@ type Cmdable interface { StreamCmdable TimeseriesCmdable JSONCmdable + VectorSetCmdable } type StatefulCmdable interface { @@ -252,6 +265,7 @@ var ( _ Cmdable = (*Tx)(nil) _ Cmdable = (*Ring)(nil) _ Cmdable = (*ClusterClient)(nil) + _ Cmdable = (*Pipeline)(nil) ) type cmdable func(ctx context.Context, cmd Cmder) error @@ -436,6 +450,23 @@ func (c cmdable) Do(ctx context.Context, args ...interface{}) *Cmd { return cmd } +// DoRaw executes a command and returns the raw RESP protocol bytes without parsing. +func (c cmdable) DoRaw(ctx context.Context, args ...interface{}) *RawCmd { + cmd := NewRawCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// DoRawWriteTo executes a command and streams raw RESP bytes directly to w without intermediate allocations. +func (c cmdable) DoRawWriteTo(ctx context.Context, w io.Writer, args ...interface{}) *RawWriteToCmd { + cmd := NewRawWriteToCmd(ctx, w, args...) + _ = c(ctx, cmd) + return cmd +} + +// Quit closes the connection. +// +// Deprecated: Just close the connection instead as of Redis 7.2.0. func (c cmdable) Quit(_ context.Context) *StatusCmd { panic("not implemented") } @@ -517,6 +548,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd { return cmd } +// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades. +// When enabled, the client will receive push notifications about Redis maintenance events. +func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd { + args := []interface{}{"client", "maint_notifications"} + if enabled { + if endpointType == "" { + endpointType = "none" + } + args = append(args, "on", "moving-endpoint-type", endpointType) + } else { + args = append(args, "off") + } + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + // ------------------------------------------------------------------------------------------------ func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd { @@ -641,18 +689,56 @@ func (c cmdable) ShutdownNoSave(ctx context.Context) *StatusCmd { return c.shutdown(ctx, "nosave") } +// SlaveOf sets a Redis server as a replica of another, or promotes it to being a master. +// +// Deprecated: Use ReplicaOf instead as of Redis 5.0.0. func (c cmdable) SlaveOf(ctx context.Context, host, port string) *StatusCmd { cmd := NewStatusCmd(ctx, "slaveof", host, port) _ = c(ctx, cmd) return cmd } +// ReplicaOf sets a Redis server as a replica of another, or promotes it to being a master. +func (c cmdable) ReplicaOf(ctx context.Context, host, port string) *StatusCmd { + cmd := NewStatusCmd(ctx, "replicaof", host, port) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) SlowLogGet(ctx context.Context, num int64) *SlowLogCmd { cmd := NewSlowLogCmd(context.Background(), "slowlog", "get", num) _ = c(ctx, cmd) return cmd } +func (c cmdable) SlowLogLen(ctx context.Context) *IntCmd { + cmd := NewIntCmd(ctx, "slowlog", "len") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) SlowLogReset(ctx context.Context) *StatusCmd { + cmd := NewStatusCmd(ctx, "slowlog", "reset") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) Latency(ctx context.Context) *LatencyCmd { + cmd := NewLatencyCmd(ctx, "latency", "latest") + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd { + args := make([]interface{}, 2+len(events)) + args[0] = "latency" + args[1] = "reset" + copy(args[2:], events) + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) Sync(_ context.Context) { panic("not implemented") } @@ -673,7 +759,9 @@ func (c cmdable) MemoryUsage(ctx context.Context, key string, samples ...int) *I args := []interface{}{"memory", "usage", key} if len(samples) > 0 { if len(samples) != 1 { - panic("MemoryUsage expects single sample count") + cmd := NewIntCmd(ctx) + cmd.SetErr(errors.New("MemoryUsage expects single sample count")) + return cmd } args = append(args, "SAMPLES", samples[0]) } diff --git a/vendor/github.com/redis/go-redis/v9/dial_retry_backoff.go b/vendor/github.com/redis/go-redis/v9/dial_retry_backoff.go new file mode 100644 index 000000000..bb3e8bf2a --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/dial_retry_backoff.go @@ -0,0 +1,39 @@ +package redis + +import ( + "time" + + "github.com/redis/go-redis/v9/internal" +) + +// DialRetryBackoffConstant returns a dial retry backoff function that always returns d. +// attempt is 0-based: attempt=0 is the delay after the 1st failed dial. +func DialRetryBackoffConstant(d time.Duration) func(attempt int) time.Duration { + if d < 0 { + d = 0 + } + return func(int) time.Duration { return d } +} + +// DialRetryBackoffExponential returns a dial retry backoff function that uses exponential +// backoff with jitter and a cap, using internal.RetryBackoff. +// +// attempt is 0-based: attempt=0 is the delay after the 1st failed dial. +func DialRetryBackoffExponential(minBackoff, maxBackoff time.Duration) func(attempt int) time.Duration { + if minBackoff < 0 { + minBackoff = 0 + } + if maxBackoff < 0 { + maxBackoff = 0 + } + if minBackoff > maxBackoff { + minBackoff = maxBackoff + } + return func(attempt int) time.Duration { + // internal.RetryBackoff expects retry >= 0. + if attempt < 0 { + attempt = 0 + } + return internal.RetryBackoff(attempt, minBackoff, maxBackoff) + } +} diff --git a/vendor/github.com/redis/go-redis/v9/docker-compose.yml b/vendor/github.com/redis/go-redis/v9/docker-compose.yml index 3d4347bf2..fed908bea 100644 --- a/vendor/github.com/redis/go-redis/v9/docker-compose.yml +++ b/vendor/github.com/redis/go-redis/v9/docker-compose.yml @@ -1,12 +1,16 @@ --- +x-default-image: &default-image ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.8.0} + services: redis: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:rs-7.4.0-v2} + image: *default-image platform: linux/amd64 container_name: redis-standalone environment: - TLS_ENABLED=yes + - TLS_CLIENT_CNS=testcertuser + - TLS_AUTH_CLIENTS_USER=CN - REDIS_CLUSTER=no - PORT=6379 - TLS_PORT=6666 @@ -21,9 +25,10 @@ services: - sentinel - all-stack - all + - e2e osscluster: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:rs-7.4.0-v2} + image: *default-image platform: linux/amd64 container_name: redis-osscluster environment: @@ -39,14 +44,77 @@ services: - all-stack - all + cae-resp-proxy: + image: redislabs/client-resp-proxy:latest + container_name: cae-resp-proxy + environment: + - TARGET_HOST=redis + - TARGET_PORT=6379 + - LISTEN_PORT=17000,17001,17002,17003 # 4 proxy nodes: initially show 3, swap in 4th during SMIGRATED + - LISTEN_HOST=0.0.0.0 + - API_PORT=3000 + - DEFAULT_INTERCEPTORS=cluster,hitless + ports: + - "17000:17000" # Proxy node 1 (host:container) + - "17001:17001" # Proxy node 2 (host:container) + - "17002:17002" # Proxy node 3 (host:container) + - "17003:17003" # Proxy node 4 (host:container) - hidden initially, swapped in during SMIGRATED + - "18100:3000" # HTTP API port (host:container) + depends_on: + - redis + profiles: + - e2e + - all + + proxy-fault-injector: + build: + context: . + dockerfile: maintnotifications/e2e/cmd/proxy-fi-server/Dockerfile + container_name: proxy-fault-injector + ports: + - "15000:5000" # Fault injector API port (host:container) + depends_on: + - cae-resp-proxy + environment: + - PROXY_API_URL=http://cae-resp-proxy:3000 + profiles: + - e2e + - all + + osscluster-tls: + image: *default-image + platform: linux/amd64 + container_name: redis-osscluster-tls + environment: + - NODES=6 + - PORT=6430 + - TLS_PORT=5430 + - TLS_ENABLED=yes + - TLS_CLIENT_CNS=testcertuser + - TLS_AUTH_CLIENTS_USER=CN + - REDIS_CLUSTER=yes + - REPLICAS=1 + command: "--tls-auth-clients optional --cluster-announce-ip 127.0.0.1" + ports: + - "6430-6435:6430-6435" # Regular ports + - "5430-5435:5430-5435" # TLS ports (set via TLS_PORT env var) + - "16430-16435:16430-16435" # Cluster bus ports (PORT + 10000) + volumes: + - "./dockers/osscluster-tls:/redis/work" + profiles: + - cluster-tls + - all + sentinel-cluster: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:rs-7.4.0-v2} + image: *default-image platform: linux/amd64 container_name: redis-sentinel-cluster network_mode: "host" environment: - NODES=3 - TLS_ENABLED=yes + - TLS_CLIENT_CNS=testcertuser + - TLS_AUTH_CLIENTS_USER=CN - REDIS_CLUSTER=no - PORT=9121 command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save ""} @@ -60,7 +128,7 @@ services: - all sentinel: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:rs-7.4.0-v2} + image: *default-image platform: linux/amd64 container_name: redis-sentinel depends_on: @@ -84,19 +152,21 @@ services: - all ring-cluster: - image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:rs-7.4.0-v2} + image: *default-image platform: linux/amd64 container_name: redis-ring-cluster environment: - NODES=3 - TLS_ENABLED=yes + - TLS_CLIENT_CNS=testcertuser + - TLS_AUTH_CLIENTS_USER=CN - REDIS_CLUSTER=no - PORT=6390 command: ${REDIS_EXTRA_ARGS:---enable-debug-command yes --enable-module-command yes --tls-auth-clients optional --save ""} ports: - - 6390:6390 - - 6391:6391 - - 6392:6392 + - "6390:6390" + - "6391:6391" + - "6392:6392" volumes: - "./dockers/ring:/redis/work" profiles: diff --git a/vendor/github.com/redis/go-redis/v9/error.go b/vendor/github.com/redis/go-redis/v9/error.go index 6f47f7cf2..06ecca740 100644 --- a/vendor/github.com/redis/go-redis/v9/error.go +++ b/vendor/github.com/redis/go-redis/v9/error.go @@ -15,6 +15,24 @@ import ( // ErrClosed performs any operation on the closed client will return this error. var ErrClosed = pool.ErrClosed +// ErrPoolExhausted is returned from a pool connection method +// when the maximum number of database connections in the pool has been reached. +var ErrPoolExhausted = pool.ErrPoolExhausted + +// ErrPoolTimeout timed out waiting to get a connection from the connection pool. +var ErrPoolTimeout = pool.ErrPoolTimeout + +// ErrCrossSlot is returned when keys are used in the same Redis command and +// the keys are not in the same hash slot. This error is returned by Redis +// Cluster and will be returned by the client when TxPipeline or TxPipelined +// is used on a ClusterClient with keys in different slots. +var ErrCrossSlot = proto.RedisError("CROSSSLOT Keys in request don't hash to the same slot") + +// ErrNoScript is returned when EVALSHA is requested for a script digest that +// is not available in the script cache. Note that this error text is reproduced +// literally from that used by Redis. +var ErrNoScript = proto.RedisError("NOSCRIPT No matching script. Please use EVAL.") + // HasErrorPrefix checks if the err is a Redis error and the message contains a prefix. func HasErrorPrefix(err error, prefix string) bool { var rErr Error @@ -39,34 +57,93 @@ type Error interface { var _ Error = proto.RedisError("") func isContextError(err error) bool { - switch err { - case context.Canceled, context.DeadlineExceeded: - return true - default: - return false + // Check for wrapped context errors using errors.Is + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) +} + +// isTimeoutError checks if an error is a timeout error, even if wrapped. +// Returns (isTimeout, shouldRetryOnTimeout) where: +// - isTimeout: true if the error is any kind of timeout error +// - shouldRetryOnTimeout: true if Timeout() method returns true +func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) { + // Check for timeoutError interface (works with wrapped errors) + var te timeoutError + if errors.As(err, &te) { + return true, te.Timeout() } + + // Check for net.Error specifically (common case for network timeouts) + var netErr net.Error + if errors.As(err, &netErr) { + return true, netErr.Timeout() + } + + return false, false } func shouldRetry(err error, retryTimeout bool) bool { - switch err { - case io.EOF, io.ErrUnexpectedEOF: + if err == nil { + return false + } + + // Check for EOF errors (works with wrapped errors) + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + + // Dial errors mean TCP connection was never established — safe to retry even + // when wrapped inside context.DeadlineExceeded (from DialTimeout context). + // Must be checked before the context error check below. + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "dial" { return true - case nil, context.Canceled, context.DeadlineExceeded: + } + + // Check for context errors (works with wrapped errors) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return false - case pool.ErrPoolTimeout: + } + + // Check for pool timeout (works with wrapped errors) + if errors.Is(err, pool.ErrPoolTimeout) { // connection pool timeout, increase retries. #3289 return true } - if v, ok := err.(timeoutError); ok { - if v.Timeout() { + // Check for timeout errors (works with wrapped errors) + if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout { + if hasTimeoutFlag { return retryTimeout } return true } + // Check for typed Redis errors using errors.As (works with wrapped errors) + if proto.IsMaxClientsError(err) { + return true + } + if proto.IsLoadingError(err) { + return true + } + if proto.IsReadOnlyError(err) { + return true + } + if proto.IsMasterDownError(err) { + return true + } + if proto.IsClusterDownError(err) { + return true + } + if proto.IsTryAgainError(err) { + return true + } + if proto.IsNoReplicasError(err) { + return true + } + + // Fallback to string checking for backward compatibility with plain errors s := err.Error() - if s == "ERR max number of clients reached" { + if strings.HasPrefix(s, "ERR max number of clients reached") { return true } if strings.HasPrefix(s, "LOADING ") { @@ -75,7 +152,7 @@ func shouldRetry(err error, retryTimeout bool) bool { if strings.HasPrefix(s, "READONLY ") { return true } - if strings.HasPrefix(s, "MASTERDOWN ") { + if strings.Contains(s, "-READONLY You can't write against a read only replica") { return true } if strings.HasPrefix(s, "CLUSTERDOWN ") { @@ -84,20 +161,39 @@ func shouldRetry(err error, retryTimeout bool) bool { if strings.HasPrefix(s, "TRYAGAIN ") { return true } + if strings.HasPrefix(s, "MASTERDOWN ") { + return true + } + if strings.HasPrefix(s, "NOREPLICAS ") { + return true + } return false } func isRedisError(err error) bool { - _, ok := err.(proto.RedisError) - return ok + // Check if error implements the Error interface (works with wrapped errors) + var redisErr Error + if errors.As(err, &redisErr) { + return true + } + // Also check for proto.RedisError specifically + var protoRedisErr proto.RedisError + return errors.As(err, &protoRedisErr) } func isBadConn(err error, allowTimeout bool, addr string) bool { - switch err { - case nil: + if err == nil { return false - case context.Canceled, context.DeadlineExceeded: + } + + // Check for context errors (works with wrapped errors) + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + // Check for pool timeout errors (works with wrapped errors) + if errors.Is(err, pool.ErrConnUnusableTimeout) { return true } @@ -118,7 +214,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool { } if allowTimeout { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Check for network timeout errors (works with wrapped errors) + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { return false } } @@ -127,44 +225,151 @@ func isBadConn(err error, allowTimeout bool, addr string) bool { } func isMovedError(err error) (moved bool, ask bool, addr string) { - if !isRedisError(err) { - return + // Check for typed MovedError + if movedErr, ok := proto.IsMovedError(err); ok { + addr = movedErr.Addr() + addr = internal.GetAddr(addr) + return true, false, addr } - s := err.Error() - switch { - case strings.HasPrefix(s, "MOVED "): - moved = true - case strings.HasPrefix(s, "ASK "): - ask = true - default: - return + // Check for typed AskError + if askErr, ok := proto.IsAskError(err); ok { + addr = askErr.Addr() + addr = internal.GetAddr(addr) + return false, true, addr } - ind := strings.LastIndex(s, " ") - if ind == -1 { - return false, false, "" + // Fallback to string checking for backward compatibility + s := err.Error() + if strings.HasPrefix(s, "MOVED ") { + // Parse: MOVED 3999 127.0.0.1:6381 + parts := strings.Split(s, " ") + if len(parts) == 3 { + addr = internal.GetAddr(parts[2]) + return true, false, addr + } + } + if strings.HasPrefix(s, "ASK ") { + // Parse: ASK 3999 127.0.0.1:6381 + parts := strings.Split(s, " ") + if len(parts) == 3 { + addr = internal.GetAddr(parts[2]) + return false, true, addr + } } - addr = s[ind+1:] - addr = internal.GetAddr(addr) - return + return false, false, "" } func isLoadingError(err error) bool { - return strings.HasPrefix(err.Error(), "LOADING ") + return proto.IsLoadingError(err) } func isReadOnlyError(err error) bool { - return strings.HasPrefix(err.Error(), "READONLY ") + return proto.IsReadOnlyError(err) } func isMovedSameConnAddr(err error, addr string) bool { - redisError := err.Error() - if !strings.HasPrefix(redisError, "MOVED ") { - return false + if movedErr, ok := proto.IsMovedError(err); ok { + return strings.HasSuffix(movedErr.Addr(), addr) } - return strings.HasSuffix(redisError, " "+addr) + return false +} + +//------------------------------------------------------------------------------ + +// Typed error checking functions for public use. +// These functions work correctly even when errors are wrapped in hooks. + +// IsLoadingError checks if an error is a Redis LOADING error, even if wrapped. +// LOADING errors occur when Redis is loading the dataset in memory. +func IsLoadingError(err error) bool { + return proto.IsLoadingError(err) +} + +// IsReadOnlyError checks if an error is a Redis READONLY error, even if wrapped. +// READONLY errors occur when trying to write to a read-only replica. +func IsReadOnlyError(err error) bool { + return proto.IsReadOnlyError(err) +} + +// IsClusterDownError checks if an error is a Redis CLUSTERDOWN error, even if wrapped. +// CLUSTERDOWN errors occur when the cluster is down. +func IsClusterDownError(err error) bool { + return proto.IsClusterDownError(err) +} + +// IsTryAgainError checks if an error is a Redis TRYAGAIN error, even if wrapped. +// TRYAGAIN errors occur when a command cannot be processed and should be retried. +func IsTryAgainError(err error) bool { + return proto.IsTryAgainError(err) +} + +// IsMasterDownError checks if an error is a Redis MASTERDOWN error, even if wrapped. +// MASTERDOWN errors occur when the master is down. +func IsMasterDownError(err error) bool { + return proto.IsMasterDownError(err) +} + +// IsMaxClientsError checks if an error is a Redis max clients error, even if wrapped. +// This error occurs when the maximum number of clients has been reached. +func IsMaxClientsError(err error) bool { + return proto.IsMaxClientsError(err) +} + +// IsMovedError checks if an error is a Redis MOVED error, even if wrapped. +// MOVED errors occur in cluster mode when a key has been moved to a different node. +// Returns the address of the node where the key has been moved and a boolean indicating if it's a MOVED error. +func IsMovedError(err error) (addr string, ok bool) { + if movedErr, isMovedErr := proto.IsMovedError(err); isMovedErr { + return movedErr.Addr(), true + } + return "", false +} + +// IsAskError checks if an error is a Redis ASK error, even if wrapped. +// ASK errors occur in cluster mode when a key is being migrated and the client should ask another node. +// Returns the address of the node to ask and a boolean indicating if it's an ASK error. +func IsAskError(err error) (addr string, ok bool) { + if askErr, isAskErr := proto.IsAskError(err); isAskErr { + return askErr.Addr(), true + } + return "", false +} + +// IsAuthError checks if an error is a Redis authentication error, even if wrapped. +// Authentication errors occur when: +// - NOAUTH: Redis requires authentication but none was provided +// - WRONGPASS: Redis authentication failed due to incorrect password +// - unauthenticated: Error returned when password changed +func IsAuthError(err error) bool { + return proto.IsAuthError(err) +} + +// IsPermissionError checks if an error is a Redis permission error, even if wrapped. +// Permission errors (NOPERM) occur when a user does not have permission to execute a command. +func IsPermissionError(err error) bool { + return proto.IsPermissionError(err) +} + +// IsExecAbortError checks if an error is a Redis EXECABORT error, even if wrapped. +// EXECABORT errors occur when a transaction is aborted. +func IsExecAbortError(err error) bool { + return proto.IsExecAbortError(err) +} + +// IsOOMError checks if an error is a Redis OOM (Out Of Memory) error, even if wrapped. +// OOM errors occur when Redis is out of memory. +func IsOOMError(err error) bool { + return proto.IsOOMError(err) +} + +// IsNoReplicasError checks if an error is a Redis NOREPLICAS error, even if wrapped. +// NOREPLICAS errors occur when not enough replicas acknowledge a write operation. +// This typically happens with WAIT/WAITAOF commands or CLUSTER SETSLOT with synchronous +// replication when the required number of replicas cannot confirm the write within the timeout. +func IsNoReplicasError(err error) bool { + return proto.IsNoReplicasError(err) } //------------------------------------------------------------------------------ diff --git a/vendor/github.com/redis/go-redis/v9/generic_commands.go b/vendor/github.com/redis/go-redis/v9/generic_commands.go index dc6c3fe01..c7100222c 100644 --- a/vendor/github.com/redis/go-redis/v9/generic_commands.go +++ b/vendor/github.com/redis/go-redis/v9/generic_commands.go @@ -3,6 +3,8 @@ package redis import ( "context" "time" + + "github.com/redis/go-redis/v9/internal/hashtag" ) type GenericCmdable interface { @@ -363,6 +365,9 @@ func (c cmdable) Scan(ctx context.Context, cursor uint64, match string, count in args = append(args, "count", count) } cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(3) + } _ = c(ctx, cmd) return cmd } @@ -379,6 +384,9 @@ func (c cmdable) ScanType(ctx context.Context, cursor uint64, match string, coun args = append(args, "type", keyType) } cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(3) + } _ = c(ctx, cmd) return cmd } diff --git a/vendor/github.com/redis/go-redis/v9/geo_commands.go b/vendor/github.com/redis/go-redis/v9/geo_commands.go index f047b98aa..0f2742893 100644 --- a/vendor/github.com/redis/go-redis/v9/geo_commands.go +++ b/vendor/github.com/redis/go-redis/v9/geo_commands.go @@ -33,7 +33,10 @@ func (c cmdable) GeoAdd(ctx context.Context, key string, geoLocation ...*GeoLoca return cmd } -// GeoRadius is a read-only GEORADIUS_RO command. +// GeoRadius queries a geospatial index for members within a distance from a coordinate. +// This is a read-only variant that does not support Store or StoreDist options. +// +// Deprecated: Use GeoSearch with BYRADIUS argument instead as of Redis 6.2.0. func (c cmdable) GeoRadius( ctx context.Context, key string, longitude, latitude float64, query *GeoRadiusQuery, ) *GeoLocationCmd { @@ -60,7 +63,10 @@ func (c cmdable) GeoRadiusStore( return cmd } -// GeoRadiusByMember is a read-only GEORADIUSBYMEMBER_RO command. +// GeoRadiusByMember queries a geospatial index for members within a distance from a member. +// This is a read-only variant that does not support Store or StoreDist options. +// +// Deprecated: Use GeoSearch with BYRADIUS and FROMMEMBER arguments instead as of Redis 6.2.0. func (c cmdable) GeoRadiusByMember( ctx context.Context, key, member string, query *GeoRadiusQuery, ) *GeoLocationCmd { diff --git a/vendor/github.com/redis/go-redis/v9/hash_commands.go b/vendor/github.com/redis/go-redis/v9/hash_commands.go index 98a361b3e..256b8746b 100644 --- a/vendor/github.com/redis/go-redis/v9/hash_commands.go +++ b/vendor/github.com/redis/go-redis/v9/hash_commands.go @@ -3,6 +3,8 @@ package redis import ( "context" "time" + + "github.com/redis/go-redis/v9/internal/hashtag" ) type HashCmdable interface { @@ -68,6 +70,13 @@ func (c cmdable) HGet(ctx context.Context, key, field string) *StringCmd { return cmd } +// HGetAll returns a map of all fields and values stored at key. +// +// Returns an empty map when key does not exist. +// +// Time complexity: O(N) where N is the size of the hash. +// +// See https://redis.io/commands/hgetall/ func (c cmdable) HGetAll(ctx context.Context, key string) *MapStringStringCmd { cmd := NewMapStringStringCmd(ctx, "hgetall", key) _ = c(ctx, cmd) @@ -114,16 +123,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice // HSet accepts values in following formats: // -// - HSet("myhash", "key1", "value1", "key2", "value2") +// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2") // -// - HSet("myhash", []string{"key1", "value1", "key2", "value2"}) +// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"}) // -// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"}) +// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"}) // // Playing struct With "redis" tag. // type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` } // -// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0 +// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0 // // For struct, can be a structure pointer type, we only parse the field whose tag is redis. // if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it, @@ -192,6 +201,9 @@ func (c cmdable) HScan(ctx context.Context, key string, cursor uint64, match str args = append(args, "count", count) } cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(4) + } _ = c(ctx, cmd) return cmd } @@ -211,6 +223,9 @@ func (c cmdable) HScanNoValues(ctx context.Context, key string, cursor uint64, m } args = append(args, "novalues") cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(4) + } _ = c(ctx, cmd) return cmd } diff --git a/vendor/github.com/redis/go-redis/v9/hotkeys_commands.go b/vendor/github.com/redis/go-redis/v9/hotkeys_commands.go new file mode 100644 index 000000000..024db3ffe --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/hotkeys_commands.go @@ -0,0 +1,122 @@ +package redis + +import ( + "context" + "errors" + "strings" +) + +// HOTKEYS commands are only available on standalone *Client instances. +// They are NOT available on ClusterClient, Ring, or UniversalClient because +// HOTKEYS is a stateful command requiring session affinity - all operations +// (START, GET, STOP, RESET) must be sent to the same Redis node. +// +// If you are using UniversalClient and need HOTKEYS functionality, you must +// type assert to *Client first: +// +// if client, ok := universalClient.(*redis.Client); ok { +// result, err := client.HotKeysStart(ctx, args) +// // ... +// } + +// HotKeysMetric represents the metrics that can be tracked by the HOTKEYS command. +type HotKeysMetric string + +const ( + // HotKeysMetricCPU tracks CPU time spent on the key (in microseconds). + HotKeysMetricCPU HotKeysMetric = "CPU" + // HotKeysMetricNET tracks network bytes used by the key (ingress + egress + replication). + HotKeysMetricNET HotKeysMetric = "NET" +) + +// HotKeysStartArgs contains the arguments for the HOTKEYS START command. +// This command is only available on standalone clients due to its stateful nature +// requiring session affinity. It must NOT be used on cluster or pooled clients. +type HotKeysStartArgs struct { + // Metrics to track. At least one must be specified. + Metrics []HotKeysMetric + // Count is the number of top keys to report. + // Default: 10, Min: 10, Max: 64 + Count uint8 + // Duration is the auto-stop tracking after this many seconds. + // Default: 0 (no auto-stop) + Duration int64 + // Sample is the sample ratio - track keys with probability 1/sample. + // Default: 1 (track every key), Min: 1 + Sample int64 + // Slots specifies specific hash slots to track (0-16383). + // All specified slots must be hosted by the receiving node. + // If not specified, all slots are tracked. + Slots []uint16 +} + +// ErrHotKeysNoMetrics is returned when HotKeysStart is called without any metrics specified. +var ErrHotKeysNoMetrics = errors.New("redis: at least one metric must be specified for HOTKEYS START") + +// HotKeysStart starts collecting hotkeys data. +// At least one metric must be specified in args.Metrics. +// This command is only available on standalone clients. +func (c *Client) HotKeysStart(ctx context.Context, args *HotKeysStartArgs) *StatusCmd { + cmdArgs := make([]interface{}, 0, 16) + cmdArgs = append(cmdArgs, "hotkeys", "start") + + // Validate that at least one metric is specified + if len(args.Metrics) == 0 { + cmd := NewStatusCmd(ctx, cmdArgs...) + cmd.SetErr(ErrHotKeysNoMetrics) + return cmd + } + + cmdArgs = append(cmdArgs, "metrics", len(args.Metrics)) + for _, metric := range args.Metrics { + cmdArgs = append(cmdArgs, strings.ToLower(string(metric))) + } + + if args.Count > 0 { + cmdArgs = append(cmdArgs, "count", args.Count) + } + + if args.Duration > 0 { + cmdArgs = append(cmdArgs, "duration", args.Duration) + } + + if args.Sample > 0 { + cmdArgs = append(cmdArgs, "sample", args.Sample) + } + + if len(args.Slots) > 0 { + cmdArgs = append(cmdArgs, "slots", len(args.Slots)) + for _, slot := range args.Slots { + cmdArgs = append(cmdArgs, slot) + } + } + + cmd := NewStatusCmd(ctx, cmdArgs...) + _ = c.Process(ctx, cmd) + return cmd +} + +// HotKeysStop stops the ongoing hotkeys collection session. +// This command is only available on standalone clients. +func (c *Client) HotKeysStop(ctx context.Context) *StatusCmd { + cmd := NewStatusCmd(ctx, "hotkeys", "stop") + _ = c.Process(ctx, cmd) + return cmd +} + +// HotKeysReset discards the last hotkeys collection session results. +// Returns an error if tracking is currently active. +// This command is only available on standalone clients. +func (c *Client) HotKeysReset(ctx context.Context) *StatusCmd { + cmd := NewStatusCmd(ctx, "hotkeys", "reset") + _ = c.Process(ctx, cmd) + return cmd +} + +// HotKeysGet retrieves the results of the ongoing or last hotkeys collection session. +// This command is only available on standalone clients. +func (c *Client) HotKeysGet(ctx context.Context) *HotKeysCmd { + cmd := NewHotKeysCmd(ctx, "hotkeys", "get") + _ = c.Process(ctx, cmd) + return cmd +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/conn_reauth_credentials_listener.go b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/conn_reauth_credentials_listener.go new file mode 100644 index 000000000..22bfedf71 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/conn_reauth_credentials_listener.go @@ -0,0 +1,100 @@ +package streaming + +import ( + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ConnReAuthCredentialsListener is a credentials listener for a specific connection +// that triggers re-authentication when credentials change. +// +// This listener implements the auth.CredentialsListener interface and is subscribed +// to a StreamingCredentialsProvider. When new credentials are received via OnNext, +// it marks the connection for re-authentication through the manager. +// +// The re-authentication is always performed asynchronously to avoid blocking the +// credentials provider and to prevent potential deadlocks with the pool semaphore. +// The actual re-auth happens when the connection is returned to the pool in an idle state. +// +// Lifecycle: +// - Created during connection initialization via Manager.Listener() +// - Subscribed to the StreamingCredentialsProvider +// - Receives credential updates via OnNext() +// - Cleaned up when connection is removed from pool via Manager.RemoveListener() +type ConnReAuthCredentialsListener struct { + // reAuth is the function to re-authenticate the connection with new credentials + reAuth func(conn *pool.Conn, credentials auth.Credentials) error + + // onErr is the function to call when re-authentication or acquisition fails + onErr func(conn *pool.Conn, err error) + + // conn is the connection this listener is associated with + conn *pool.Conn + + // manager is the streaming credentials manager for coordinating re-auth + manager *Manager +} + +// OnNext is called when new credentials are received from the StreamingCredentialsProvider. +// +// This method marks the connection for asynchronous re-authentication. The actual +// re-authentication happens in the background when the connection is returned to the +// pool and is in an idle state. +// +// Asynchronous re-auth is used to: +// - Avoid blocking the credentials provider's notification goroutine +// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes) +// - Ensure re-auth happens when the connection is safe to use (not processing commands) +// +// The reAuthFn callback receives: +// - nil if the connection was successfully acquired for re-auth +// - error if acquisition timed out or failed +// +// Thread-safe: Called by the credentials provider's notification goroutine. +func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) { + if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil { + return + } + + // Always use async reauth to avoid complex pool semaphore issues + // The synchronous path can cause deadlocks in the pool's semaphore mechanism + // when called from the Subscribe goroutine, especially with small pool sizes. + // The connection pool hook will re-authenticate the connection when it is + // returned to the pool in a clean, idle state. + c.manager.MarkForReAuth(c.conn, func(err error) { + // err is from connection acquisition (timeout, etc.) + if err != nil { + // Log the error + c.OnError(err) + return + } + // err is from reauth command execution + err = c.reAuth(c.conn, credentials) + if err != nil { + // Log the error + c.OnError(err) + return + } + }) +} + +// OnError is called when an error occurs during credential streaming or re-authentication. +// +// This method can be called from: +// - The StreamingCredentialsProvider when there's an error in the credentials stream +// - The re-auth process when connection acquisition times out +// - The re-auth process when the AUTH command fails +// +// The error is delegated to the onErr callback provided during listener creation. +// +// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker). +func (c *ConnReAuthCredentialsListener) OnError(err error) { + if c.onErr == nil { + return + } + + c.onErr(c.conn, err) +} + +// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface. +var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil) diff --git a/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/cred_listeners.go b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/cred_listeners.go new file mode 100644 index 000000000..66e6eafdc --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/cred_listeners.go @@ -0,0 +1,77 @@ +package streaming + +import ( + "sync" + + "github.com/redis/go-redis/v9/auth" +) + +// CredentialsListeners is a thread-safe collection of credentials listeners +// indexed by connection ID. +// +// This collection is used by the Manager to maintain a registry of listeners +// for each connection in the pool. Listeners are reused when connections are +// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions +// to the StreamingCredentialsProvider. +// +// The collection supports concurrent access from multiple goroutines during +// connection initialization, credential updates, and connection removal. +type CredentialsListeners struct { + // listeners maps connection ID to credentials listener + listeners map[uint64]auth.CredentialsListener + + // lock protects concurrent access to the listeners map + lock sync.RWMutex +} + +// NewCredentialsListeners creates a new thread-safe credentials listeners collection. +func NewCredentialsListeners() *CredentialsListeners { + return &CredentialsListeners{ + listeners: make(map[uint64]auth.CredentialsListener), + } +} + +// Add adds or updates a credentials listener for a connection. +// +// If a listener already exists for the connection ID, it is replaced. +// This is safe because the old listener should have been unsubscribed +// before the connection was reinitialized. +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) { + c.lock.Lock() + defer c.lock.Unlock() + if c.listeners == nil { + c.listeners = make(map[uint64]auth.CredentialsListener) + } + c.listeners[connID] = listener +} + +// Get retrieves the credentials listener for a connection. +// +// Returns: +// - listener: The credentials listener for the connection, or nil if not found +// - ok: true if a listener exists for the connection ID, false otherwise +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + if len(c.listeners) == 0 { + return nil, false + } + listener, ok := c.listeners[connID] + return listener, ok +} + +// Remove removes the credentials listener for a connection. +// +// This is called when a connection is removed from the pool to prevent +// memory leaks. If no listener exists for the connection ID, this is a no-op. +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (c *CredentialsListeners) Remove(connID uint64) { + c.lock.Lock() + defer c.lock.Unlock() + delete(c.listeners, connID) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/manager.go b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/manager.go new file mode 100644 index 000000000..f785927ee --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/manager.go @@ -0,0 +1,137 @@ +package streaming + +import ( + "errors" + "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/internal/pool" +) + +// Manager coordinates streaming credentials and re-authentication for a connection pool. +// +// The manager is responsible for: +// - Creating and managing per-connection credentials listeners +// - Providing the pool hook for re-authentication +// - Coordinating between credentials updates and pool operations +// +// When credentials change via a StreamingCredentialsProvider: +// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update +// 2. It calls MarkForReAuth on the manager +// 3. The manager delegates to the pool hook +// 4. The pool hook schedules background re-authentication +// +// The manager maintains a registry of credentials listeners indexed by connection ID, +// allowing listener reuse when connections are reinitialized (e.g., after handoff). +type Manager struct { + // credentialsListeners maps connection ID to credentials listener + credentialsListeners *CredentialsListeners + + // pool is the connection pool being managed + pool pool.Pooler + + // poolHookRef is the re-authentication pool hook + poolHookRef *ReAuthPoolHook +} + +// NewManager creates a new streaming credentials manager. +// +// Parameters: +// - pl: The connection pool to manage +// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication +// +// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that +// re-auth operations don't exhaust the connection pool. +func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager { + m := &Manager{ + pool: pl, + poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout), + credentialsListeners: NewCredentialsListeners(), + } + m.poolHookRef.manager = m + return m +} + +// PoolHook returns the pool hook for re-authentication. +// +// This hook should be registered with the connection pool to enable +// automatic re-authentication when credentials change. +func (m *Manager) PoolHook() pool.PoolHook { + return m.poolHookRef +} + +// Listener returns or creates a credentials listener for a connection. +// +// This method is called during connection initialization to set up the +// credentials listener. If a listener already exists for the connection ID +// (e.g., after a handoff), it is reused. +// +// Parameters: +// - poolCn: The connection to create/get a listener for +// - reAuth: Function to re-authenticate the connection with new credentials +// - onErr: Function to call when re-authentication fails +// +// Returns: +// - auth.CredentialsListener: The listener to subscribe to the credentials provider +// - error: Non-nil if poolCn is nil +// +// Note: The reAuth and onErr callbacks are captured once when the listener is +// created and reused for the connection's lifetime. They should not change. +// +// Thread-safe: Can be called concurrently during connection initialization. +func (m *Manager) Listener( + poolCn *pool.Conn, + reAuth func(*pool.Conn, auth.Credentials) error, + onErr func(*pool.Conn, error), +) (auth.CredentialsListener, error) { + if poolCn == nil { + return nil, errors.New("poolCn cannot be nil") + } + connID := poolCn.GetID() + // if we reconnect the underlying network connection, the streaming credentials listener will continue to work + // so we can get the old listener from the cache and use it. + // subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op + listener, ok := m.credentialsListeners.Get(connID) + if !ok || listener == nil { + // Create new listener for this connection + // Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime + newCredListener := &ConnReAuthCredentialsListener{ + conn: poolCn, + reAuth: reAuth, + onErr: onErr, + manager: m, + } + + m.credentialsListeners.Add(connID, newCredListener) + listener = newCredListener + } + return listener, nil +} + +// MarkForReAuth marks a connection for re-authentication. +// +// This method is called by the credentials listener when new credentials are +// received. It delegates to the pool hook to schedule background re-authentication. +// +// Parameters: +// - poolCn: The connection to re-authenticate +// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails +// +// Thread-safe: Called by credentials listeners when credentials change. +func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { + connID := poolCn.GetID() + m.poolHookRef.MarkForReAuth(connID, reAuthFn) +} + +// RemoveListener removes the credentials listener for a connection. +// +// This method is called by the pool hook's OnRemove to clean up listeners +// when connections are removed from the pool. +// +// Parameters: +// - connID: The connection ID whose listener should be removed +// +// Thread-safe: Called during connection removal. +func (m *Manager) RemoveListener(connID uint64) { + m.credentialsListeners.Remove(connID) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/pool_hook.go b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/pool_hook.go new file mode 100644 index 000000000..aaf4f6099 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/auth/streaming/pool_hook.go @@ -0,0 +1,241 @@ +package streaming + +import ( + "context" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ReAuthPoolHook is a pool hook that manages background re-authentication of connections +// when credentials change via a streaming credentials provider. +// +// The hook uses a semaphore-based worker pool to limit concurrent re-authentication +// operations and prevent pool exhaustion. When credentials change, connections are +// marked for re-authentication and processed asynchronously in the background. +// +// The re-authentication process: +// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth +// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth) +// 3. A worker goroutine acquires the connection (waits until it's not in use) +// 4. Executes the re-auth function while holding the connection +// 5. Releases the connection back to the pool +// +// The hook ensures that: +// - Only one re-auth operation runs per connection at a time +// - Connections are not used for commands during re-authentication +// - Re-auth operations timeout if they can't acquire the connection +// - Resources are properly cleaned up on connection removal +type ReAuthPoolHook struct { + // shouldReAuth maps connection ID to re-auth function + // Connections in this map need re-authentication but haven't been scheduled yet + shouldReAuth map[uint64]func(error) + shouldReAuthLock sync.RWMutex + + // workers is a semaphore limiting concurrent re-auth operations + // Initialized with poolSize tokens to prevent pool exhaustion + // Uses FastSemaphore for better performance with eventual fairness + workers *internal.FastSemaphore + + // reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth + reAuthTimeout time.Duration + + // scheduledReAuth maps connection ID to scheduled status + // Connections in this map have a background worker attempting re-authentication + scheduledReAuth map[uint64]bool + scheduledLock sync.RWMutex + + // manager is a back-reference for cleanup operations + manager *Manager +} + +// NewReAuthPoolHook creates a new re-authentication pool hook. +// +// Parameters: +// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size) +// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication +// +// The poolSize parameter is used to initialize the worker semaphore, ensuring that +// re-auth operations don't exhaust the connection pool. +func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook { + return &ReAuthPoolHook{ + shouldReAuth: make(map[uint64]func(error)), + scheduledReAuth: make(map[uint64]bool), + workers: internal.NewFastSemaphore(int32(poolSize)), + reAuthTimeout: reAuthTimeout, + } +} + +// MarkForReAuth marks a connection for re-authentication. +// +// This method is called when credentials change and a connection needs to be +// re-authenticated. The actual re-authentication happens asynchronously when +// the connection is returned to the pool (in OnPut). +// +// Parameters: +// - connID: The connection ID to mark for re-authentication +// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails +// +// Thread-safe: Can be called concurrently from multiple goroutines. +func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) { + r.shouldReAuthLock.Lock() + defer r.shouldReAuthLock.Unlock() + r.shouldReAuth[connID] = reAuthFn +} + +// OnGet is called when a connection is retrieved from the pool. +// +// This hook checks if the connection needs re-authentication or has a scheduled +// re-auth operation. If so, it rejects the connection (returns accept=false), +// causing the pool to try another connection. +// +// Returns: +// - accept: false if connection needs re-auth, true otherwise +// - err: always nil (errors are not used in this hook) +// +// Thread-safe: Called concurrently by multiple goroutines getting connections. +func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { + connID := conn.GetID() + r.shouldReAuthLock.RLock() + _, shouldReAuth := r.shouldReAuth[connID] + r.shouldReAuthLock.RUnlock() + // This connection was marked for reauth while in the pool, + // reject the connection + if shouldReAuth { + // simply reject the connection, it will be re-authenticated in OnPut + return false, nil + } + r.scheduledLock.RLock() + _, hasScheduled := r.scheduledReAuth[connID] + r.scheduledLock.RUnlock() + // has scheduled reauth, reject the connection + if hasScheduled { + // simply reject the connection, it currently has a reauth scheduled + // and the worker is waiting for slot to execute the reauth + return false, nil + } + return true, nil +} + +// OnPut is called when a connection is returned to the pool. +// +// This hook checks if the connection needs re-authentication. If so, it schedules +// a background goroutine to perform the re-auth asynchronously. The goroutine: +// 1. Waits for a worker slot (semaphore) +// 2. Acquires the connection (waits until not in use) +// 3. Executes the re-auth function +// 4. Releases the connection and worker slot +// +// The connection is always pooled (not removed) since re-auth happens in background. +// +// Returns: +// - shouldPool: always true (connection stays in pool during background re-auth) +// - shouldRemove: always false +// - err: always nil +// +// Thread-safe: Called concurrently by multiple goroutines returning connections. +func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) { + if conn == nil { + // noop + return true, false, nil + } + connID := conn.GetID() + // Check if reauth is needed and get the function with proper locking + r.shouldReAuthLock.RLock() + reAuthFn, ok := r.shouldReAuth[connID] + r.shouldReAuthLock.RUnlock() + + if ok { + // Acquire both locks to atomically move from shouldReAuth to scheduledReAuth + // This prevents race conditions where OnGet might miss the transition + r.shouldReAuthLock.Lock() + r.scheduledLock.Lock() + r.scheduledReAuth[connID] = true + delete(r.shouldReAuth, connID) + r.scheduledLock.Unlock() + r.shouldReAuthLock.Unlock() + go func() { + r.workers.AcquireBlocking() + // safety first + if conn == nil || (conn != nil && conn.IsClosed()) { + r.workers.Release() + return + } + defer func() { + if rec := recover(); rec != nil { + // once again - safety first + internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec) + } + r.scheduledLock.Lock() + delete(r.scheduledReAuth, connID) + r.scheduledLock.Unlock() + r.workers.Release() + }() + + // Create timeout context for connection acquisition + // This prevents indefinite waiting if the connection is stuck + ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout) + defer cancel() + + // Try to acquire the connection for re-authentication + // We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE + // This prevents re-authentication from interfering with active commands + // Use AwaitAndTransition to wait for the connection to become IDLE + stateMachine := conn.GetStateMachine() + if stateMachine == nil { + // No state machine - should not happen, but handle gracefully + reAuthFn(pool.ErrConnUnusableTimeout) + return + } + + // Use predefined slice to avoid allocation + _, err := stateMachine.AwaitAndTransition(ctx, pool.ValidFromIdle(), pool.StateUnusable) + if err != nil { + // Timeout or other error occurred, cannot acquire connection + reAuthFn(err) + return + } + + // safety first + if !conn.IsClosed() { + // Successfully acquired the connection, perform reauth + reAuthFn(nil) + } + + // Release the connection: transition from UNUSABLE back to IDLE + stateMachine.Transition(pool.StateIdle) + }() + } + + // the reauth will happen in background, as far as the pool is concerned: + // pool the connection, don't remove it, no error + return true, false, nil +} + +// OnRemove is called when a connection is removed from the pool. +// +// This hook cleans up all state associated with the connection: +// - Removes from shouldReAuth map (pending re-auth) +// - Removes from scheduledReAuth map (active re-auth) +// - Removes credentials listener from manager +// +// This prevents memory leaks and ensures that removed connections don't have +// lingering re-auth operations or listeners. +// +// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure. +func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) { + connID := conn.GetID() + r.shouldReAuthLock.Lock() + r.scheduledLock.Lock() + delete(r.scheduledReAuth, connID) + delete(r.shouldReAuth, connID) + r.scheduledLock.Unlock() + r.shouldReAuthLock.Unlock() + if r.manager != nil { + r.manager.RemoveListener(connID) + } +} + +var _ pool.PoolHook = (*ReAuthPoolHook)(nil) diff --git a/vendor/github.com/redis/go-redis/v9/internal/hashtag/hashtag.go b/vendor/github.com/redis/go-redis/v9/internal/hashtag/hashtag.go index f13ee816d..8aa87db3d 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/hashtag/hashtag.go +++ b/vendor/github.com/redis/go-redis/v9/internal/hashtag/hashtag.go @@ -1,9 +1,8 @@ package hashtag import ( + "math/rand" "strings" - - "github.com/redis/go-redis/v9/internal/rand" ) const slotNumber = 16384 @@ -11,7 +10,7 @@ const slotNumber = 16384 // CRC16 implementation according to CCITT standards. // Copyright 2001-2010 Georges Menie (www.menie.org) // Copyright 2013 The Go Authors. All rights reserved. -// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c +// https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c. var crc16tab = [256]uint16{ 0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, 0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, @@ -56,6 +55,18 @@ func Key(key string) string { return key } +func Present(key string) bool { + if key == "" { + return false + } + if s := strings.IndexByte(key, '{'); s > -1 { + if e := strings.IndexByte(key[s+1:], '}'); e > 0 { + return true + } + } + return false +} + func RandomSlot() int { return rand.Intn(slotNumber) } diff --git a/vendor/github.com/redis/go-redis/v9/internal/hashtag/rendezvous.go b/vendor/github.com/redis/go-redis/v9/internal/hashtag/rendezvous.go new file mode 100644 index 000000000..214f7e8ea --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/hashtag/rendezvous.go @@ -0,0 +1,54 @@ +package hashtag + +import "github.com/cespare/xxhash/v2" + +// RendezvousHash implements HRW (Highest Random Weight) hashing. +type RendezvousHash struct { + nodes []node +} + +type node struct { + name string + hash uint64 +} + +// NewRendezvousHash builds a hash from shard names. +func NewRendezvousHash(shards []string) *RendezvousHash { + n := make([]node, len(shards)) + for i, s := range shards { + n[i] = node{ + name: s, + hash: xxhash.Sum64String(s), + } + } + return &RendezvousHash{nodes: n} +} + +// Get returns the shard name for the given key. +func (r *RendezvousHash) Get(key string) string { + if len(r.nodes) == 0 { + return "" + } + + kh := xxhash.Sum64String(key) + + bestIdx := 0 + bestScore := mix64(kh ^ r.nodes[0].hash) + + for i := 1; i < len(r.nodes); i++ { + if score := mix64(kh ^ r.nodes[i].hash); score > bestScore { + bestScore = score + bestIdx = i + } + } + + return r.nodes[bestIdx].name +} + +// mix64 is a xorshift-based mixing function. +func mix64(x uint64) uint64 { + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + return x * 2685821657736338717 +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/hscan/structmap.go b/vendor/github.com/redis/go-redis/v9/internal/hscan/structmap.go index 1a560e4a3..408ce0e4b 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/hscan/structmap.go +++ b/vendor/github.com/redis/go-redis/v9/internal/hscan/structmap.go @@ -109,6 +109,8 @@ func (s StructValue) Scan(key string, value string) error { return scan.ScanRedis(value) case encoding.TextUnmarshaler: return scan.UnmarshalText(util.StringToBytes(value)) + case encoding.BinaryUnmarshaler: + return scan.UnmarshalBinary(util.StringToBytes(value)) } } diff --git a/vendor/github.com/redis/go-redis/v9/internal/interfaces/interfaces.go b/vendor/github.com/redis/go-redis/v9/internal/interfaces/interfaces.go new file mode 100644 index 000000000..8f8569719 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/interfaces/interfaces.go @@ -0,0 +1,59 @@ +// Package interfaces provides shared interfaces used by both the main redis package +// and the maintnotifications upgrade package to avoid circular dependencies. +package interfaces + +import ( + "context" + "net" + "time" +) + +// NotificationProcessor is (most probably) a push.NotificationProcessor +// forward declaration to avoid circular imports +type NotificationProcessor interface { + RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error + UnregisterHandler(pushNotificationName string) error + GetHandler(pushNotificationName string) interface{} +} + +// ClientInterface defines the interface that clients must implement for maintnotifications upgrades. +type ClientInterface interface { + // GetOptions returns the client options. + GetOptions() OptionsInterface + + // GetPushProcessor returns the client's push notification processor. + GetPushProcessor() NotificationProcessor +} + +// OptionsInterface defines the interface for client options. +// Uses an adapter pattern to avoid circular dependencies. +type OptionsInterface interface { + // GetReadTimeout returns the read timeout. + GetReadTimeout() time.Duration + + // GetWriteTimeout returns the write timeout. + GetWriteTimeout() time.Duration + + // GetNetwork returns the network type. + GetNetwork() string + + // GetAddr returns the connection address. + GetAddr() string + + // GetNodeAddress returns the address of the Redis node as reported by the server. + // For cluster clients, this is the endpoint from CLUSTER SLOTS before any transformation. + // For standalone clients, this defaults to Addr. + GetNodeAddress() string + + // IsTLSEnabled returns true if TLS is enabled. + IsTLSEnabled() bool + + // GetProtocol returns the protocol version. + GetProtocol() int + + // GetPoolSize returns the connection pool size. + GetPoolSize() int + + // NewDialer returns a new dialer function for the connection. + NewDialer() func(context.Context) (net.Conn, error) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/internal.go b/vendor/github.com/redis/go-redis/v9/internal/internal.go index e783d139a..403db56ca 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/internal.go +++ b/vendor/github.com/redis/go-redis/v9/internal/internal.go @@ -1,9 +1,8 @@ package internal import ( + "math/rand" "time" - - "github.com/redis/go-redis/v9/internal/rand" ) func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { diff --git a/vendor/github.com/redis/go-redis/v9/internal/log.go b/vendor/github.com/redis/go-redis/v9/internal/log.go index c8b9213de..0bfffc311 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/log.go +++ b/vendor/github.com/redis/go-redis/v9/internal/log.go @@ -7,20 +7,73 @@ import ( "os" ) +// TODO (ned): Revisit logging +// Add more standardized approach with log levels and configurability + type Logging interface { Printf(ctx context.Context, format string, v ...interface{}) } -type logger struct { +type DefaultLogger struct { log *log.Logger } -func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { +func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) { _ = l.log.Output(2, fmt.Sprintf(format, v...)) } +func NewDefaultLogger() Logging { + return &DefaultLogger{ + log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), + } +} + // Logger calls Output to print to the stderr. // Arguments are handled in the manner of fmt.Print. -var Logger Logging = &logger{ - log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), +var Logger Logging = NewDefaultLogger() + +var LogLevel LogLevelT = LogLevelError + +// LogLevelT represents the logging level +type LogLevelT int + +// Log level constants for the entire go-redis library +const ( + LogLevelError LogLevelT = iota // 0 - errors only + LogLevelWarn // 1 - warnings and errors + LogLevelInfo // 2 - info, warnings, and errors + LogLevelDebug // 3 - debug, info, warnings, and errors +) + +// String returns the string representation of the log level +func (l LogLevelT) String() string { + switch l { + case LogLevelError: + return "ERROR" + case LogLevelWarn: + return "WARN" + case LogLevelInfo: + return "INFO" + case LogLevelDebug: + return "DEBUG" + default: + return "UNKNOWN" + } +} + +// IsValid returns true if the log level is valid +func (l LogLevelT) IsValid() bool { + return l >= LogLevelError && l <= LogLevelDebug +} + +func (l LogLevelT) WarnOrAbove() bool { + return l >= LogLevelWarn +} + +func (l LogLevelT) InfoOrAbove() bool { + return l >= LogLevelInfo +} + +func (l LogLevelT) DebugOrAbove() bool { + return l >= LogLevelDebug } diff --git a/vendor/github.com/redis/go-redis/v9/internal/maintnotifications/logs/log_messages.go b/vendor/github.com/redis/go-redis/v9/internal/maintnotifications/logs/log_messages.go new file mode 100644 index 000000000..93e5bded8 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/maintnotifications/logs/log_messages.go @@ -0,0 +1,663 @@ +package logs + +import ( + "encoding/json" + "fmt" + "regexp" + + "github.com/redis/go-redis/v9/internal" +) + +// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug +func appendJSONIfDebug(message string, data map[string]interface{}) string { + if internal.LogLevel.DebugOrAbove() { + jsonData, _ := json.Marshal(data) + return fmt.Sprintf("%s %s", message, string(jsonData)) + } + return message +} + +const ( + // ======================================== + // CIRCUIT_BREAKER.GO - Circuit breaker management + // ======================================== + CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open" + CircuitBreakerOpenedMessage = "circuit breaker opened" + CircuitBreakerReopenedMessage = "circuit breaker reopened" + CircuitBreakerClosedMessage = "circuit breaker closed" + CircuitBreakerCleanupMessage = "circuit breaker cleanup" + CircuitBreakerOpenMessage = "circuit breaker is open, failing fast" + + // ======================================== + // CONFIG.GO - Configuration and debug + // ======================================== + DebugLoggingEnabledMessage = "debug logging enabled" + ConfigDebugMessage = "config debug" + + // ======================================== + // ERRORS.GO - Error message constants + // ======================================== + InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0" + InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0" + InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0" + InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0" + InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0" + InvalidEndpointTypeErrorMessage = "invalid endpoint type" + InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')" + InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10" + InvalidClientErrorMessage = "invalid client type" + InvalidNotificationErrorMessage = "invalid notification format" + MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached" + HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration" + InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1" + InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0" + InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1" + ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff" + ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff" + ShutdownErrorMessage = "shutdown" + CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast" + + // ======================================== + // EXAMPLE_HOOKS.GO - Example metrics hooks + // ======================================== + MetricsHookProcessingNotificationMessage = "metrics hook processing" + MetricsHookRecordedErrorMessage = "metrics hook recorded error" + + // ======================================== + // HANDOFF_WORKER.GO - Connection handoff processing + // ======================================== + HandoffStartedMessage = "handoff started" + HandoffFailedMessage = "handoff failed" + ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries" + ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff" + HandoffRetryAttemptMessage = "Performing handoff" + CannotQueueHandoffForRetryMessage = "can't queue handoff for retry" + HandoffQueueFullMessage = "handoff queue is full" + FailedToDialNewEndpointMessage = "failed to dial new endpoint" + ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff" + HandoffSuccessMessage = "handoff succeeded" + RemovingConnectionFromPoolMessage = "removing connection from pool" + NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it" + WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown" + WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request" + WorkerPanicRecoveredMessage = "worker panic recovered" + WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout" + ReachedMaxHandoffRetriesMessage = "reached max handoff retries" + + // ======================================== + // MANAGER.GO - Moving operation tracking and handler registration + // ======================================== + DuplicateMovingOperationMessage = "duplicate MOVING operation ignored" + TrackingMovingOperationMessage = "tracking MOVING operation" + UntrackingMovingOperationMessage = "untracking MOVING operation" + OperationNotTrackedMessage = "operation not tracked" + FailedToRegisterHandlerMessage = "failed to register handler" + + // ======================================== + // HOOKS.GO - Notification processing hooks + // ======================================== + ProcessingNotificationMessage = "processing notification started" + ProcessingNotificationFailedMessage = "proccessing notification failed" + ProcessingNotificationSucceededMessage = "processing notification succeeded" + + // ======================================== + // POOL_HOOK.GO - Pool connection management + // ======================================== + FailedToQueueHandoffMessage = "failed to queue handoff" + MarkedForHandoffMessage = "connection marked for handoff" + + // ======================================== + // PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing + // ======================================== + InvalidNotificationFormatMessage = "invalid notification format" + InvalidNotificationTypeFormatMessage = "invalid notification type format" + InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification" + InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification" + InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification" + NoConnectionInHandlerContextMessage = "no connection in handler context" + InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context" + SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint" + RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification" + UnrelaxedTimeoutMessage = "clearing relaxed timeout" + ManagerNotInitializedMessage = "manager not initialized" + FailedToMarkForHandoffMessage = "failed to mark connection for handoff" + InvalidSeqIDInSMigratingNotificationMessage = "invalid SeqID in SMIGRATING notification" + InvalidSeqIDInSMigratedNotificationMessage = "invalid SeqID in SMIGRATED notification" + TriggeringClusterStateReloadMessage = "triggering cluster state reload" + + // ======================================== + // used in pool/conn + // ======================================== + UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline" +) + +func HandoffStarted(connID uint64, newEndpoint string) string { + message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": newEndpoint, + }) +} + +func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string { + message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": newEndpoint, + "attempt": attempt, + "maxAttempts": maxAttempts, + "error": err.Error(), + }) +} + +func HandoffSucceeded(connID uint64, newEndpoint string) string { + message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": newEndpoint, + }) +} + +// Timeout-related log functions +func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string { + message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "notificationType": notificationType, + "timeout": fmt.Sprintf("%v", timeout), + }) +} + +func UnrelaxedTimeout(connID uint64) string { + message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + }) +} + +func UnrelaxedTimeoutAfterDeadline(connID uint64) string { + message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + }) +} + +// Handoff queue and marking functions +func HandoffQueueFull(queueLen, queueCap int) string { + message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap) + return appendJSONIfDebug(message, map[string]interface{}{ + "queueLen": queueLen, + "queueCap": queueCap, + }) +} + +func FailedToQueueHandoff(connID uint64, err error) string { + message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "error": err.Error(), + }) +} + +func FailedToMarkForHandoff(connID uint64, err error) string { + message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "error": err.Error(), + }) +} + +func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string { + message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": endpoint, + "error": err.Error(), + }) +} + +func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string { + message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": endpoint, + "maxRetries": maxRetries, + }) +} + +// Notification processing functions +func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string { + message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "seqID": seqID, + "notificationType": notificationType, + "notification": fmt.Sprintf("%v", notification), + }) +} + +func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string { + message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "notificationType": notificationType, + "error": err.Error(), + "notification": fmt.Sprintf("%v", notification), + }) +} + +func ProcessingNotificationSucceeded(connID uint64, notificationType string) string { + message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "notificationType": notificationType, + }) +} + +// Moving operation tracking functions +func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string { + message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": endpoint, + "seqID": seqID, + }) +} + +func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string { + message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": endpoint, + "seqID": seqID, + }) +} + +func UntrackingMovingOperation(connID uint64, seqID int64) string { + message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "seqID": seqID, + }) +} + +func OperationNotTracked(connID uint64, seqID int64) string { + message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "seqID": seqID, + }) +} + +// Connection pool functions +func RemovingConnectionFromPool(connID uint64, reason error) string { + metadata := map[string]interface{}{ + "connID": connID, + "reason": "unknown", // this will be overwritten if reason is not nil + } + if reason != nil { + metadata["reason"] = reason.Error() + } + + message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason) + return appendJSONIfDebug(message, metadata) +} + +func NoPoolProvidedCannotRemove(connID uint64, reason error) string { + metadata := map[string]interface{}{ + "connID": connID, + "reason": "unknown", // this will be overwritten if reason is not nil + } + if reason != nil { + metadata["reason"] = reason.Error() + } + + message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason) + return appendJSONIfDebug(message, metadata) +} + +// Circuit breaker functions +func CircuitBreakerOpen(connID uint64, endpoint string) string { + message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "endpoint": endpoint, + }) +} + +// Additional handoff functions for specific cases +func ConnectionNotMarkedForHandoff(connID uint64) string { + message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + }) +} + +func ConnectionNotMarkedForHandoffError(connID uint64) string { + return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage) +} + +func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string { + message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "retries": retries, + "newEndpoint": newEndpoint, + "oldEndpoint": oldEndpoint, + }) +} + +func CannotQueueHandoffForRetry(err error) string { + message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "error": err.Error(), + }) +} + +// Validation and error functions +func InvalidNotificationFormat(notification interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification) + return appendJSONIfDebug(message, map[string]interface{}{ + "notification": fmt.Sprintf("%v", notification), + }) +} + +func InvalidNotificationTypeFormat(notificationType interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": fmt.Sprintf("%v", notificationType), + }) +} + +// InvalidNotification creates a log message for invalid notifications of any type +func InvalidNotification(notificationType string, notification interface{}) string { + message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + "notification": fmt.Sprintf("%v", notification), + }) +} + +func InvalidSeqIDInMovingNotification(seqID interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "seqID": fmt.Sprintf("%v", seqID), + }) +} + +func InvalidTimeSInMovingNotification(timeS interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS) + return appendJSONIfDebug(message, map[string]interface{}{ + "timeS": fmt.Sprintf("%v", timeS), + }) +} + +func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "newEndpoint": fmt.Sprintf("%v", newEndpoint), + }) +} + +func NoConnectionInHandlerContext(notificationType string) string { + message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + }) +} + +func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string { + message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + "connType": fmt.Sprintf("%T", conn), + }) +} + +func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string { + message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "seconds": seconds, + }) +} + +func ManagerNotInitialized() string { + return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{}) +} + +func FailedToRegisterHandler(notificationType string, err error) string { + message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + "error": err.Error(), + }) +} + +func ShutdownError() string { + return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{}) +} + +// Configuration validation error functions +func InvalidRelaxedTimeoutError() string { + return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{}) +} + +func InvalidHandoffTimeoutError() string { + return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{}) +} + +func InvalidHandoffWorkersError() string { + return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{}) +} + +func InvalidHandoffQueueSizeError() string { + return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{}) +} + +func InvalidPostHandoffRelaxedDurationError() string { + return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{}) +} + +func InvalidEndpointTypeError() string { + return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{}) +} + +func InvalidMaintNotificationsError() string { + return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{}) +} + +func InvalidHandoffRetriesError() string { + return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{}) +} + +func InvalidClientError() string { + return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{}) +} + +func InvalidNotificationError() string { + return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{}) +} + +func MaxHandoffRetriesReachedError() string { + return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{}) +} + +func HandoffQueueFullError() string { + return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{}) +} + +func InvalidCircuitBreakerFailureThresholdError() string { + return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{}) +} + +func InvalidCircuitBreakerResetTimeoutError() string { + return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{}) +} + +func InvalidCircuitBreakerMaxRequestsError() string { + return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{}) +} + +// Configuration and debug functions +func DebugLoggingEnabled() string { + return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{}) +} + +func ConfigDebug(config interface{}) string { + message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config) + return appendJSONIfDebug(message, map[string]interface{}{ + "config": fmt.Sprintf("%+v", config), + }) +} + +// Handoff worker functions +func WorkerExitingDueToShutdown() string { + return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{}) +} + +func WorkerExitingDueToShutdownWhileProcessing() string { + return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{}) +} + +func WorkerPanicRecovered(panicValue interface{}) string { + message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue) + return appendJSONIfDebug(message, map[string]interface{}{ + "panic": fmt.Sprintf("%v", panicValue), + }) +} + +func WorkerExitingDueToInactivityTimeout(timeout interface{}) string { + message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout) + return appendJSONIfDebug(message, map[string]interface{}{ + "timeout": fmt.Sprintf("%v", timeout), + }) +} + +func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string { + message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + "timeout": fmt.Sprintf("%v", timeout), + "until": until, + }) +} + +// Example hooks functions +func MetricsHookProcessingNotification(notificationType string, connID uint64) string { + message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + "connID": connID, + }) +} + +func MetricsHookRecordedError(notificationType string, connID uint64, err error) string { + message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err) + return appendJSONIfDebug(message, map[string]interface{}{ + "notificationType": notificationType, + "connID": connID, + "error": err.Error(), + }) +} + +// Pool hook functions +func MarkedForHandoff(connID uint64) string { + message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage) + return appendJSONIfDebug(message, map[string]interface{}{ + "connID": connID, + }) +} + +// Circuit breaker additional functions +func CircuitBreakerTransitioningToHalfOpen(endpoint string) string { + message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "endpoint": endpoint, + }) +} + +func CircuitBreakerOpened(endpoint string, failures int64) string { + message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures) + return appendJSONIfDebug(message, map[string]interface{}{ + "endpoint": endpoint, + "failures": failures, + }) +} + +func CircuitBreakerReopened(endpoint string) string { + message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint) + return appendJSONIfDebug(message, map[string]interface{}{ + "endpoint": endpoint, + }) +} + +func CircuitBreakerClosed(endpoint string, successes int64) string { + message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes) + return appendJSONIfDebug(message, map[string]interface{}{ + "endpoint": endpoint, + "successes": successes, + }) +} + +func CircuitBreakerCleanup(removed int, total int) string { + message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total) + return appendJSONIfDebug(message, map[string]interface{}{ + "removed": removed, + "total": total, + }) +} + +// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages +// Returns a map containing the parsed key-value pairs from the structured data section +// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}" +// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"} +func ExtractDataFromLogMessage(logMessage string) map[string]interface{} { + result := make(map[string]interface{}) + + // Find the JSON data section at the end of the message + re := regexp.MustCompile(`(\{.*\})$`) + matches := re.FindStringSubmatch(logMessage) + if len(matches) < 2 { + return result + } + + jsonStr := matches[1] + if jsonStr == "" { + return result + } + + // Parse the JSON directly + var jsonResult map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil { + return jsonResult + } + + // If JSON parsing fails, return empty map + return result +} + +// Cluster notification functions +func InvalidSeqIDInSMigratingNotification(seqID interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidSeqIDInSMigratingNotificationMessage, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "seqID": fmt.Sprintf("%v", seqID), + }) +} + +func InvalidSeqIDInSMigratedNotification(seqID interface{}) string { + message := fmt.Sprintf("%s: %v", InvalidSeqIDInSMigratedNotificationMessage, seqID) + return appendJSONIfDebug(message, map[string]interface{}{ + "seqID": fmt.Sprintf("%v", seqID), + }) +} + +// TriggeringClusterStateReload logs when cluster state reload is triggered (deduplicated, once per seqID) +func TriggeringClusterStateReload(seqID int64, hostPort string, slotRanges []string) string { + message := fmt.Sprintf("%s seqID=%d host:port=%s slots=%v", TriggeringClusterStateReloadMessage, seqID, hostPort, slotRanges) + return appendJSONIfDebug(message, map[string]interface{}{ + "seqID": seqID, + "hostPort": hostPort, + "slotRanges": slotRanges, + }) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/otel/metrics.go b/vendor/github.com/redis/go-redis/v9/internal/otel/metrics.go new file mode 100644 index 000000000..a3f23fffe --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/otel/metrics.go @@ -0,0 +1,298 @@ +package otel + +import ( + "context" + "crypto/rand" + "encoding/hex" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// generateUniqueID generates a short unique identifier for pool names. +func generateUniqueID() string { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + return "" + } + return hex.EncodeToString(b) +} + +// Cmder is a minimal interface for command information needed for metrics. +// This avoids circular dependencies with the main redis package. +type Cmder interface { + Name() string + FullName() string + Args() []interface{} + Err() error +} + +// Recorder is the interface for recording metrics. +type Recorder interface { + // RecordOperationDuration records the total operation duration (including all retries) + // dbIndex is the Redis database index (0-15) + RecordOperationDuration(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) + + // RecordPipelineOperationDuration records the total pipeline/transaction duration. + // operationName should be "PIPELINE" for regular pipelines or "MULTI" for transactions. + // cmdCount is the number of commands in the pipeline. + // err is the error from the pipeline execution (can be nil). + // dbIndex is the Redis database index (0-15) + RecordPipelineOperationDuration(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) + + // RecordConnectionCreateTime records the time it took to create a new connection + RecordConnectionCreateTime(ctx context.Context, duration time.Duration, cn *pool.Conn) + + // RecordConnectionRelaxedTimeout records when connection timeout is relaxed/unrelaxed + // delta: +1 for relaxed, -1 for unrelaxed + // poolName: name of the connection pool (e.g., "main", "pubsub") + // notificationType: the notification type that triggered the timeout relaxation (e.g., "MOVING") + RecordConnectionRelaxedTimeout(ctx context.Context, delta int, cn *pool.Conn, poolName, notificationType string) + + // RecordConnectionHandoff records when a connection is handed off to another node + // poolName: name of the connection pool (e.g., "main", "pubsub") + RecordConnectionHandoff(ctx context.Context, cn *pool.Conn, poolName string) + + // RecordError records client errors (ASK, MOVED, handshake failures, etc.) + // errorType: type of error (e.g., "ASK", "MOVED", "HANDSHAKE_FAILED") + // statusCode: Redis response status code if available (e.g., "MOVED", "ASK") + // isInternal: whether this is an internal error + // retryAttempts: number of retry attempts made + RecordError(ctx context.Context, errorType string, cn *pool.Conn, statusCode string, isInternal bool, retryAttempts int) + + // RecordMaintenanceNotification records when a maintenance notification is received + // notificationType: the type of notification (e.g., "MOVING", "MIGRATING", etc.) + RecordMaintenanceNotification(ctx context.Context, cn *pool.Conn, notificationType string) + + // RecordConnectionWaitTime records the time spent waiting for a connection from the pool + RecordConnectionWaitTime(ctx context.Context, duration time.Duration, cn *pool.Conn) + + // RecordConnectionClosed records when a connection is closed + // reason: reason for closing (e.g., "idle", "max_lifetime", "error", "pool_closed") + // err: the error that caused the close (nil for non-error closures) + RecordConnectionClosed(ctx context.Context, cn *pool.Conn, reason string, err error) + + // RecordPubSubMessage records a Pub/Sub message + // direction: "sent" or "received" + // channel: channel name (may be hidden for cardinality reduction) + // sharded: true for sharded pub/sub (SPUBLISH/SSUBSCRIBE) + RecordPubSubMessage(ctx context.Context, cn *pool.Conn, direction, channel string, sharded bool) + + // RecordStreamLag records the lag for stream consumer group processing + // lag: time difference between message creation and consumption + // streamName: name of the stream (may be hidden for cardinality reduction) + // consumerGroup: name of the consumer group + // consumerName: name of the consumer + RecordStreamLag(ctx context.Context, lag time.Duration, cn *pool.Conn, streamName, consumerGroup, consumerName string) + + // RecordConnectionCount records a change in connection count (UpDownCounter) + // delta: +1 when connection added, -1 when connection removed + // state: connection state (e.g., "idle", "used") + // isPubSub: true if this is a PubSub connection + RecordConnectionCount(ctx context.Context, delta int, cn *pool.Conn, state string, isPubSub bool) + + // RecordPendingRequests records a change in pending requests (UpDownCounter) + // delta: +1 when request starts waiting, -1 when request stops waiting + // poolName is passed explicitly because we may not have a connection yet when request starts + RecordPendingRequests(ctx context.Context, delta int, cn *pool.Conn, poolName string) +} + +type PubSubPooler interface { + Stats() *pool.PubSubStats +} + +type PoolRegistrar interface { + // RegisterPool is called when a new client is created with its connection pools. + // poolName: identifier for the pool (e.g., "main_abc123") + // pool: the connection pool + RegisterPool(poolName string, pool pool.Pooler) + // UnregisterPool is called when a client is closed to remove its pool from the registry. + // pool: the connection pool to unregister + UnregisterPool(pool pool.Pooler) + // RegisterPubSubPool is called when a new client is created with a PubSub pool. + // poolName: identifier for the pool (e.g., "main_abc123_pubsub") + // pool: the PubSub connection pool + RegisterPubSubPool(poolName string, pool PubSubPooler) + // UnregisterPubSubPool is called when a PubSub client is closed to remove its pool. + // pool: the PubSub connection pool to unregister + UnregisterPubSubPool(pool PubSubPooler) +} + +var ( + // recorderMu protects globalRecorder and operation duration callbacks + recorderMu sync.RWMutex + + // Global recorder instance (initialized by extra/redisotel-native) + globalRecorder Recorder = noopRecorder{} + + // Callbacks for operation duration metrics + operationDurationCallback func(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) + pipelineOperationDurationCallback func(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) +) + +// GetOperationDurationCallback returns the callback for operation duration. +func GetOperationDurationCallback() func(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) { + recorderMu.RLock() + cb := operationDurationCallback + recorderMu.RUnlock() + return cb +} + +// GetPipelineOperationDurationCallback returns the callback for pipeline operation duration. +func GetPipelineOperationDurationCallback() func(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) { + recorderMu.RLock() + cb := pipelineOperationDurationCallback + recorderMu.RUnlock() + return cb +} + +// getRecorder returns the current global recorder under a read lock. +func getRecorder() Recorder { + recorderMu.RLock() + r := globalRecorder + recorderMu.RUnlock() + return r +} + +// SetGlobalRecorder sets the global recorder (called by Init() in extra/redisotel-native) +func SetGlobalRecorder(r Recorder) { + recorderMu.Lock() + if r == nil { + globalRecorder = noopRecorder{} + operationDurationCallback = nil + pipelineOperationDurationCallback = nil + recorderMu.Unlock() + // Unregister all pool metric callbacks atomically + pool.SetAllMetricCallbacks(nil) + return + } + globalRecorder = r + + // Register operation duration callbacks + // These capture r directly since we want them to use the specific recorder + // that was set at this point in time + operationDurationCallback = func(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) { + getRecorder().RecordOperationDuration(ctx, duration, cmd, attempts, err, cn, dbIndex) + } + pipelineOperationDurationCallback = func(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) { + getRecorder().RecordPipelineOperationDuration(ctx, duration, operationName, cmdCount, attempts, err, cn, dbIndex) + } + recorderMu.Unlock() + + // Register all pool metric callbacks atomically + // These use getRecorder() to safely access the current recorder + pool.SetAllMetricCallbacks(&pool.MetricCallbacks{ + ConnectionCreateTime: func(ctx context.Context, duration time.Duration, cn *pool.Conn) { + getRecorder().RecordConnectionCreateTime(ctx, duration, cn) + }, + ConnectionRelaxedTimeout: func(ctx context.Context, delta int, cn *pool.Conn, poolName, notificationType string) { + getRecorder().RecordConnectionRelaxedTimeout(ctx, delta, cn, poolName, notificationType) + }, + ConnectionHandoff: func(ctx context.Context, cn *pool.Conn, poolName string) { + getRecorder().RecordConnectionHandoff(ctx, cn, poolName) + }, + Error: func(ctx context.Context, errorType string, cn *pool.Conn, statusCode string, isInternal bool, retryAttempts int) { + getRecorder().RecordError(ctx, errorType, cn, statusCode, isInternal, retryAttempts) + }, + MaintenanceNotification: func(ctx context.Context, cn *pool.Conn, notificationType string) { + getRecorder().RecordMaintenanceNotification(ctx, cn, notificationType) + }, + ConnectionWaitTime: func(ctx context.Context, duration time.Duration, cn *pool.Conn) { + getRecorder().RecordConnectionWaitTime(ctx, duration, cn) + }, + ConnectionClosed: func(ctx context.Context, cn *pool.Conn, reason string, err error) { + getRecorder().RecordConnectionClosed(ctx, cn, reason, err) + }, + ConnectionCount: func(ctx context.Context, delta int, cn *pool.Conn, state string, isPubSub bool) { + getRecorder().RecordConnectionCount(ctx, delta, cn, state, isPubSub) + }, + PendingRequests: func(ctx context.Context, delta int, cn *pool.Conn, poolName string) { + getRecorder().RecordPendingRequests(ctx, delta, cn, poolName) + }, + }) +} + +// RecordOperationDuration records the total operation duration. +// dbIndex is the Redis database index (0-15). +func RecordOperationDuration(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) { + getRecorder().RecordOperationDuration(ctx, duration, cmd, attempts, err, cn, dbIndex) +} + +// RecordPipelineOperationDuration records the total pipeline/transaction duration. +// This is called from redis.go after pipeline/transaction execution completes. +// operationName should be "PIPELINE" for regular pipelines or "MULTI" for transactions. +// err is the error from the pipeline execution (can be nil). +// dbIndex is the Redis database index (0-15). +func RecordPipelineOperationDuration(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) { + getRecorder().RecordPipelineOperationDuration(ctx, duration, operationName, cmdCount, attempts, err, cn, dbIndex) +} + +// RecordConnectionCreateTime records the time it took to create a new connection. +func RecordConnectionCreateTime(ctx context.Context, duration time.Duration, cn *pool.Conn) { + getRecorder().RecordConnectionCreateTime(ctx, duration, cn) +} + +// RecordPubSubMessage records a Pub/Sub message sent or received. +func RecordPubSubMessage(ctx context.Context, cn *pool.Conn, direction, channel string, sharded bool) { + getRecorder().RecordPubSubMessage(ctx, cn, direction, channel, sharded) +} + +// RecordStreamLag records the lag between message creation and consumption in a stream. +func RecordStreamLag(ctx context.Context, lag time.Duration, cn *pool.Conn, streamName, consumerGroup, consumerName string) { + getRecorder().RecordStreamLag(ctx, lag, cn, streamName, consumerGroup, consumerName) +} + +type noopRecorder struct{} + +func (noopRecorder) RecordOperationDuration(context.Context, time.Duration, Cmder, int, error, *pool.Conn, int) { +} +func (noopRecorder) RecordPipelineOperationDuration(context.Context, time.Duration, string, int, int, error, *pool.Conn, int) { +} +func (noopRecorder) RecordConnectionCreateTime(context.Context, time.Duration, *pool.Conn) {} +func (noopRecorder) RecordConnectionRelaxedTimeout(context.Context, int, *pool.Conn, string, string) { +} +func (noopRecorder) RecordConnectionHandoff(context.Context, *pool.Conn, string) {} +func (noopRecorder) RecordError(context.Context, string, *pool.Conn, string, bool, int) {} +func (noopRecorder) RecordMaintenanceNotification(context.Context, *pool.Conn, string) {} + +func (noopRecorder) RecordConnectionWaitTime(context.Context, time.Duration, *pool.Conn) {} +func (noopRecorder) RecordConnectionClosed(context.Context, *pool.Conn, string, error) {} + +func (noopRecorder) RecordPubSubMessage(context.Context, *pool.Conn, string, string, bool) {} + +func (noopRecorder) RecordStreamLag(context.Context, time.Duration, *pool.Conn, string, string, string) { +} +func (noopRecorder) RecordConnectionCount(context.Context, int, *pool.Conn, string, bool) {} +func (noopRecorder) RecordPendingRequests(context.Context, int, *pool.Conn, string) {} + +// RegisterPools registers connection pools with the global recorder. +func RegisterPools(connPool pool.Pooler, pubSubPool PubSubPooler, addr string) { + // Check if the global recorder implements PoolRegistrar + if registrar, ok := globalRecorder.(PoolRegistrar); ok { + // Generate a unique ID for this client's pools + uniqueID := generateUniqueID() + + if connPool != nil { + poolName := addr + "_" + uniqueID + registrar.RegisterPool(poolName, connPool) + } + if pubSubPool != nil { + poolName := addr + "_" + uniqueID + "_pubsub" + registrar.RegisterPubSubPool(poolName, pubSubPool) + } + } +} + +// UnregisterPools removes connection pools from the global recorder +func UnregisterPools(connPool pool.Pooler, pubSubPool PubSubPooler) { + // Check if the global recorder implements PoolRegistrar + if registrar, ok := globalRecorder.(PoolRegistrar); ok { + if connPool != nil { + registrar.UnregisterPool(connPool) + } + if pubSubPool != nil { + registrar.UnregisterPubSubPool(pubSubPool) + } + } +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/conn.go b/vendor/github.com/redis/go-redis/v9/internal/pool/conn.go index 7f45bc0bb..fab54654a 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/conn.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/conn.go @@ -1,64 +1,872 @@ +// Package pool implements the pool management package pool import ( "bufio" "context" + "errors" + "fmt" "net" + "sync" "sync/atomic" "time" + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/proto" + uberatomic "go.uber.org/atomic" ) var noDeadline = time.Time{} +// Preallocated errors for hot paths to avoid allocations +var ( + errAlreadyMarkedForHandoff = errors.New("connection is already marked for handoff") + errNotMarkedForHandoff = errors.New("connection was not marked for handoff") + errHandoffStateChanged = errors.New("handoff state changed during marking") + errConnectionNotAvailable = errors.New("redis: connection not available") + errConnNotAvailableForWrite = errors.New("redis: connection not available for write operation") +) + +// getCachedTimeNs returns the current time in nanoseconds. +// This function previously used a global cache updated by a background goroutine, +// but that caused unnecessary CPU usage when the client was idle (ticker waking up +// the scheduler every 50ms). We now use time.Now() directly, which is fast enough +// on modern systems (vDSO on Linux) and only adds ~1-2% overhead in extreme +// high-concurrency benchmarks while eliminating idle CPU usage. +func getCachedTimeNs() int64 { + return time.Now().UnixNano() +} + +// GetCachedTimeNs returns the current time in nanoseconds. +// Exported for use by other packages that need fast time access. +func GetCachedTimeNs() int64 { + return getCachedTimeNs() +} + +// Global atomic counter for connection IDs +var connIDCounter uint64 + +// HandoffState represents the atomic state for connection handoffs +// This struct is stored atomically to prevent race conditions between +// checking handoff status and reading handoff parameters +type HandoffState struct { + ShouldHandoff bool // Whether connection should be handed off + Endpoint string // New endpoint for handoff + SeqID int64 // Sequence ID from MOVING notification +} + +// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value +type atomicNetConn struct { + conn net.Conn +} + +// generateConnID generates a fast unique identifier for a connection with zero allocations +func generateConnID() uint64 { + return atomic.AddUint64(&connIDCounter, 1) +} + type Conn struct { - usedAt int64 // atomic - netConn net.Conn + // Connection identifier for unique tracking + id uint64 + + usedAt atomic.Int64 + lastPutAt atomic.Int64 + dialStartNs atomic.Int64 // Time when dial started (for connection create time metric) + + // Lock-free netConn access using atomic.Value + // Contains *atomicNetConn wrapper, accessed atomically for better performance + netConnAtomic atomic.Value // stores *atomicNetConn rd *proto.Reader bw *bufio.Writer wr *proto.Writer - Inited bool + // Lightweight mutex to protect reader operations during handoff and health checks + // Used during: + // - SetNetConn (write lock for resetting reader state) + // - HasBufferedData/PeekReplyTypeSafe (read lock for safe concurrent peek operations) + readerMu sync.RWMutex + + // State machine for connection state management + // Replaces: usable, Inited, used + // Provides thread-safe state transitions with FIFO waiting queue + // States: CREATED → INITIALIZING → IDLE ⇄ IN_USE + // ↓ + // UNUSABLE (handoff/reauth) + // ↓ + // IDLE/CLOSED + stateMachine *ConnStateMachine + + // Handoff metadata - managed separately from state machine + // These are atomic for lock-free access during handoff operations + handoffStateAtomic atomic.Value // stores *HandoffState + handoffRetriesAtomic atomic.Uint32 // retry counter + pooled bool + pubsub bool createdAt time.Time + expiresAt time.Time + poolName string // Name of the pool this connection belongs to (for metrics) + + // When a goroutine closes a connection, it usually knows the reason, so closeReason is not needed. + // closeReason is only used when an in-use connection is closed by another goroutine, + // to inform the goroutine using the connection why the connection was closed. + closeReason uberatomic.String + + // maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers + + // Using atomic operations for lock-free access to avoid mutex contention + relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch + + // Counter to track multiple relaxed timeout setters if we have nested calls + // will be decremented when ClearRelaxedTimeout is called or deadline is reached + // if counter reaches 0, we clear the relaxed timeouts + relaxedCounter atomic.Int32 + + // Connection initialization function for reconnections + initConnFunc func(context.Context, *Conn) error + + onClose func() error } func NewConn(netConn net.Conn) *Conn { + return NewConnWithBufferSize(netConn, proto.DefaultBufferSize, proto.DefaultBufferSize) +} + +func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn { + now := time.Now() cn := &Conn{ - netConn: netConn, - createdAt: time.Now(), + createdAt: now, + id: generateConnID(), // Generate unique ID for this connection + stateMachine: NewConnStateMachine(), + } + + // Use specified buffer sizes, or fall back to 32KiB defaults if 0 + if readBufSize > 0 { + cn.rd = proto.NewReaderSize(netConn, readBufSize) + } else { + cn.rd = proto.NewReader(netConn) // Uses 32KiB default + } + + if writeBufSize > 0 { + cn.bw = bufio.NewWriterSize(netConn, writeBufSize) + } else { + cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize) } - cn.rd = proto.NewReader(netConn) - cn.bw = bufio.NewWriter(netConn) + + // Store netConn atomically for lock-free access using wrapper + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) + cn.wr = proto.NewWriter(cn.bw) - cn.SetUsedAt(time.Now()) + cn.SetUsedAt(now) + // Initialize handoff state atomically + initialHandoffState := &HandoffState{ + ShouldHandoff: false, + Endpoint: "", + SeqID: 0, + } + cn.handoffStateAtomic.Store(initialHandoffState) return cn } func (cn *Conn) UsedAt() time.Time { - unix := atomic.LoadInt64(&cn.usedAt) - return time.Unix(unix, 0) + return time.Unix(0, cn.usedAt.Load()) } - func (cn *Conn) SetUsedAt(tm time.Time) { - atomic.StoreInt64(&cn.usedAt, tm.Unix()) + cn.usedAt.Store(tm.UnixNano()) +} + +func (cn *Conn) UsedAtNs() int64 { + return cn.usedAt.Load() +} +func (cn *Conn) SetUsedAtNs(ns int64) { + cn.usedAt.Store(ns) +} + +func (cn *Conn) LastPutAtNs() int64 { + return cn.lastPutAt.Load() +} +func (cn *Conn) SetLastPutAtNs(ns int64) { + cn.lastPutAt.Store(ns) +} + +// GetDialStartNs returns the time when the dial started (in nanoseconds since epoch). +// This is used to calculate the full connection creation time (TCP + handshake). +func (cn *Conn) GetDialStartNs() int64 { + return cn.dialStartNs.Load() +} + +// PoolName returns the name of the pool this connection belongs to. +// This is used for metrics to identify which pool a connection is from. +func (cn *Conn) PoolName() string { + return cn.poolName +} + +// SetPoolName sets the name of the pool this connection belongs to. +// This should be called when the connection is added to a pool. +func (cn *Conn) SetPoolName(name string) { + cn.poolName = name +} + +// Backward-compatible wrapper methods for state machine +// These maintain the existing API while using the new state machine internally + +// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free). +// +// This is used by background operations (handoff, re-auth) to acquire exclusive +// access to a connection. The operation sets usable to false, preventing the pool +// from returning the connection to clients. +// +// Returns true if the swap was successful (old value matched), false otherwise. +// +// Implementation note: This is a compatibility wrapper around the state machine. +// It checks if the current state is "usable" (IDLE or IN_USE) and transitions accordingly. +// Deprecated: Use GetStateMachine().TryTransition() directly for better state management. +func (cn *Conn) CompareAndSwapUsable(old, new bool) bool { + currentState := cn.stateMachine.GetState() + + // Check if current state matches the "old" usable value + currentUsable := (currentState == StateIdle || currentState == StateInUse) + if currentUsable != old { + return false + } + + // If we're trying to set to the same value, succeed immediately + if old == new { + return true + } + + // Transition based on new value + if new { + // Trying to make usable - transition from UNUSABLE to IDLE + // This should only work from UNUSABLE or INITIALIZING states + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition( + validFromInitializingOrUnusable, + StateIdle, + ) + return err == nil + } + // Trying to make unusable - transition from IDLE to UNUSABLE + // This is typically for acquiring the connection for background operations + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition( + validFromIdle, + StateUnusable, + ) + return err == nil +} + +// IsUsable returns true if the connection is safe to use for new commands (lock-free). +// +// A connection is "usable" when it's in a stable state and can be returned to clients. +// It becomes unusable during: +// - Handoff operations (network connection replacement) +// - Re-authentication (credential updates) +// - Other background operations that need exclusive access +// +// Note: CREATED state is considered usable because new connections need to pass OnGet() hook +// before initialization. The initialization happens after OnGet() in the client code. +func (cn *Conn) IsUsable() bool { + state := cn.stateMachine.GetState() + // CREATED, IDLE, and IN_USE states are considered usable + // CREATED: new connection, not yet initialized (will be initialized by client) + // IDLE: initialized and ready to be acquired + // IN_USE: usable but currently acquired by someone + return state == StateCreated || state == StateIdle || state == StateInUse +} + +// SetUsable sets the usable flag for the connection (lock-free). +// +// Deprecated: Use GetStateMachine().Transition() directly for better state management. +// This method is kept for backwards compatibility. +// +// This should be called to mark a connection as usable after initialization or +// to release it after a background operation completes. +// +// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions. +// Deprecated: Use GetStateMachine().Transition() directly for better state management. +func (cn *Conn) SetUsable(usable bool) { + if usable { + // Transition to IDLE state (ready to be acquired) + cn.stateMachine.Transition(StateIdle) + } else { + // Transition to UNUSABLE state (for background operations) + cn.stateMachine.Transition(StateUnusable) + } +} + +// IsInited returns true if the connection has been initialized. +// This is a backward-compatible wrapper around the state machine. +func (cn *Conn) IsInited() bool { + state := cn.stateMachine.GetState() + // Connection is initialized if it's in IDLE or any post-initialization state + return state != StateCreated && state != StateInitializing && state != StateClosed +} + +// Used - State machine based implementation + +// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free). +// This method is kept for backwards compatibility. +// +// This is the preferred method for acquiring a connection from the pool, as it +// ensures that only one goroutine marks the connection as used. +// +// Implementation: Uses state machine transitions IDLE ⇄ IN_USE +// +// Returns true if the swap was successful (old value matched), false otherwise. +// Deprecated: Use GetStateMachine().TryTransition() directly for better state management. +func (cn *Conn) CompareAndSwapUsed(old, new bool) bool { + if old == new { + // No change needed + currentState := cn.stateMachine.GetState() + currentUsed := (currentState == StateInUse) + return currentUsed == old + } + + if !old && new { + // Acquiring: IDLE → IN_USE + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse) + return err == nil + } else { + // Releasing: IN_USE → IDLE + // Use predefined slice to avoid allocation + _, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle) + return err == nil + } +} + +// IsUsed returns true if the connection is currently in use (lock-free). +// +// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity. +// This method is kept for backwards compatibility. +// +// A connection is "used" when it has been retrieved from the pool and is +// actively processing a command. Background operations (like re-auth) should +// wait until the connection is not used before executing commands. +func (cn *Conn) IsUsed() bool { + return cn.stateMachine.GetState() == StateInUse +} + +// SetUsed sets the used flag for the connection (lock-free). +// +// This should be called when returning a connection to the pool (set to false) +// or when a single-connection pool retrieves its connection (set to true). +// +// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to +// avoid race conditions. +// Deprecated: Use GetStateMachine().Transition() directly for better state management. +func (cn *Conn) SetUsed(val bool) { + if val { + cn.stateMachine.Transition(StateInUse) + } else { + cn.stateMachine.Transition(StateIdle) + } +} + +// getNetConn returns the current network connection using atomic load (lock-free). +// This is the fast path for accessing netConn without mutex overhead. +func (cn *Conn) getNetConn() net.Conn { + if v := cn.netConnAtomic.Load(); v != nil { + if wrapper, ok := v.(*atomicNetConn); ok { + return wrapper.conn + } + } + return nil +} + +// setNetConn stores the network connection atomically (lock-free). +// This is used for the fast path of connection replacement. +func (cn *Conn) setNetConn(netConn net.Conn) { + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) +} + +// Handoff state management - atomic access to handoff metadata + +// ShouldHandoff returns true if connection needs handoff (lock-free). +func (cn *Conn) ShouldHandoff() bool { + if v := cn.handoffStateAtomic.Load(); v != nil { + return v.(*HandoffState).ShouldHandoff + } + return false +} + +// GetHandoffEndpoint returns the new endpoint for handoff (lock-free). +func (cn *Conn) GetHandoffEndpoint() string { + if v := cn.handoffStateAtomic.Load(); v != nil { + return v.(*HandoffState).Endpoint + } + return "" +} + +// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free). +func (cn *Conn) GetMovingSeqID() int64 { + if v := cn.handoffStateAtomic.Load(); v != nil { + return v.(*HandoffState).SeqID + } + return 0 +} + +// GetHandoffInfo returns all handoff information atomically (lock-free). +// This method prevents race conditions by returning all handoff state in a single atomic operation. +// Returns (shouldHandoff, endpoint, seqID). +func (cn *Conn) GetHandoffInfo() (bool, string, int64) { + if v := cn.handoffStateAtomic.Load(); v != nil { + state := v.(*HandoffState) + return state.ShouldHandoff, state.Endpoint, state.SeqID + } + return false, "", 0 +} + +// HandoffRetries returns the current handoff retry count (lock-free). +func (cn *Conn) HandoffRetries() int { + return int(cn.handoffRetriesAtomic.Load()) +} + +// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free). +func (cn *Conn) IncrementAndGetHandoffRetries(n int) int { + return int(cn.handoffRetriesAtomic.Add(uint32(n))) +} + +// IsPooled returns true if the connection is managed by a pool and will be pooled on Put. +func (cn *Conn) IsPooled() bool { + return cn.pooled +} + +// IsPubSub returns true if the connection is used for PubSub. +func (cn *Conn) IsPubSub() bool { + return cn.pubsub +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades. +// These timeouts will be used for all subsequent commands until the deadline expires. +// Uses atomic operations for lock-free access. +// Note: Metrics should be recorded by the caller (notification handler) which has context about +// the notification type and pool name. +func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. +// After the deadline, timeouts automatically revert to normal values. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + cn.SetRelaxedTimeout(readTimeout, writeTimeout) + cn.relaxedDeadlineNs.Store(deadline.UnixNano()) +} + +// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior. +// Uses atomic operations for lock-free access. +func (cn *Conn) ClearRelaxedTimeout() { + // Atomically decrement counter and check if we should clear + newCount := cn.relaxedCounter.Add(-1) + deadlineNs := cn.relaxedDeadlineNs.Load() + if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) { + // Use atomic load to get current value for CAS to avoid stale value race + current := cn.relaxedCounter.Load() + if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) { + cn.clearRelaxedTimeout() + } + } +} + +func (cn *Conn) clearRelaxedTimeout() { + cn.relaxedReadTimeoutNs.Store(0) + cn.relaxedWriteTimeoutNs.Store(0) + cn.relaxedDeadlineNs.Store(0) + cn.relaxedCounter.Store(0) + + // Note: Metrics for timeout unrelaxing are not recorded here because we don't have + // context about which notification type or pool triggered the relaxation. + // In practice, relaxed timeouts expire automatically via deadline, so explicit + // unrelaxing metrics are less critical than the initial relaxation metrics. +} + +// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection. +// This checks both the timeout values and the deadline (if set). +// Uses atomic operations for lock-free access. +func (cn *Conn) HasRelaxedTimeout() bool { + // Fast path: no relaxed timeouts are set + if cn.relaxedCounter.Load() <= 0 { + return false + } + + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // If no relaxed timeouts are set, return false + if readTimeoutNs <= 0 && writeTimeoutNs <= 0 { + return false + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, relaxed timeouts are active + if deadlineNs == 0 { + return true + } + + // If deadline is set, check if it's still in the future + return time.Now().UnixNano() < deadlineNs +} + +// getEffectiveReadTimeout returns the timeout to use for read operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration { + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if readTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(readTimeoutNs) + } + + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(readTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + newCount := cn.relaxedCounter.Add(-1) + if newCount <= 0 { + internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + +// getEffectiveWriteTimeout returns the timeout to use for write operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration { + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if writeTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(writeTimeoutNs) + } + + // Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks) + nowNs := getCachedTimeNs() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(writeTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + newCount := cn.relaxedCounter.Add(-1) + if newCount <= 0 { + internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID())) + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + +// SetOnClose installs fn as the callback invoked exactly once when this +// connection is closed (via Conn.Close). +// +// IMPORTANT: SetOnClose OVERWRITES any previously installed callback — it +// does not compose, chain, or deduplicate. A Conn has room for a single +// onClose hook by design, because its lifecycle is bounded (a Conn is +// created, optionally re-initialized on its own net.Conn, and then closed +// once) and the pool's OnRemove hooks handle any registry-level cleanup +// that must survive the net.Conn being swapped. +// +// This has a subtle implication for per-connection subscriptions such as +// the unsubscribe function returned by StreamingCredentialsProvider +// (e.g. EntraID token rotation): if SetOnClose is called twice on the +// same Conn with DIFFERENT unsubscribe closures — for example because +// initConn ran a second time and obtained a fresh Subscribe() — +// the previous unsubscribe is dropped and will NEVER run, leaking a +// subscription on the provider. Callers must therefore ensure either: +// +// - the provider's Subscribe is idempotent for the same listener (the +// streaming credentials Manager deduplicates listeners by connection +// id, so re-Subscribe returns an equivalent unsubscribe), OR +// - the previous callback has already been invoked before SetOnClose is +// called again. +// +// Design note: unlike the client-level onCloseHooks registry (see +// redis.baseClient), there is intentionally NO named-hook dedup or +// multi-callback support on Conn. This is a deliberate trade-off to keep +// the Conn object slim — a pool can hold thousands of Conn values and +// each one is a hot allocation, so paying for a sync.Mutex plus a +// map[string]func() error per connection to support a feature that would +// only be used by at most one subsystem today (streaming credentials) is +// not worth the per-connection memory and allocation cost. For a single +// Conn there is at most one meaningful close callback at any point in +// time, and a richer registry here would not even solve the "stale +// closure" hazard described above. +func (cn *Conn) SetOnClose(fn func() error) { + cn.onClose = fn +} + +// SetInitConnFunc sets the connection initialization function to be called on reconnections. +func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) { + cn.initConnFunc = fn +} + +// ExecuteInitConn runs the stored connection initialization function if available. +func (cn *Conn) ExecuteInitConn(ctx context.Context) error { + if cn.initConnFunc != nil { + return cn.initConnFunc(ctx, cn) + } + return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID()) } func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn + // Store the new connection atomically first (lock-free) + cn.setNetConn(netConn) + // Protect reader reset operations to avoid data races + // Use write lock since we're modifying the reader state + cn.readerMu.Lock() cn.rd.Reset(netConn) + cn.readerMu.Unlock() + cn.bw.Reset(netConn) } +// GetNetConn safely returns the current network connection using atomic load (lock-free). +// This method is used by the pool for health checks and provides better performance. +func (cn *Conn) GetNetConn() net.Conn { + return cn.getNetConn() +} + +// SetNetConnAndInitConn replaces the underlying connection and executes the initialization. +// This method ensures only one initialization can happen at a time by using atomic state transitions. +// If another goroutine is currently initializing, this will wait for it to complete. +func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error { + // Wait for and transition to INITIALIZING state - this prevents concurrent initializations + // Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth) + // If another goroutine is initializing, we'll wait for it to finish + // if the context has a deadline, use that, otherwise use the connection read (relaxed) timeout + // which should be set during handoff. If it is not set, use a 5 second default + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(cn.getEffectiveReadTimeout(5 * time.Second)) + } + waitCtx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + // Use predefined slice to avoid allocation + finalState, err := cn.stateMachine.AwaitAndTransition( + waitCtx, + validFromCreatedIdleOrUnusable, + StateInitializing, + ) + if err != nil { + return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err) + } + + // Replace the underlying connection + cn.SetNetConn(netConn) + + // Execute initialization + // NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success + // or CLOSED on failure. We don't need to do it here. + // NOTE: Initconn returns conn in IDLE state + initErr := cn.ExecuteInitConn(ctx) + if initErr != nil { + // ExecuteInitConn already transitioned to CLOSED, just return the error + return initErr + } + + // ExecuteInitConn already transitioned to IDLE + return nil +} + +// MarkForHandoff marks the connection for handoff due to MOVING notification. +// Returns an error if the connection is already marked for handoff. +// Note: This only sets metadata - the connection state is not changed until OnPut. +// This allows the current user to finish using the connection before handoff. +func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error { + // Check if already marked for handoff + if cn.ShouldHandoff() { + return errAlreadyMarkedForHandoff + } + + // Set handoff metadata atomically + cn.handoffStateAtomic.Store(&HandoffState{ + ShouldHandoff: true, + Endpoint: newEndpoint, + SeqID: seqID, + }) + return nil +} + +// MarkQueuedForHandoff marks the connection as queued for handoff processing. +// This makes the connection unusable until handoff completes. +// This is called from OnPut hook, where the connection is typically in IN_USE state. +// The pool will preserve the UNUSABLE state and not overwrite it with IDLE. +func (cn *Conn) MarkQueuedForHandoff() error { + // Get current handoff state + currentState := cn.handoffStateAtomic.Load() + if currentState == nil { + return errNotMarkedForHandoff + } + + state := currentState.(*HandoffState) + if !state.ShouldHandoff { + return errNotMarkedForHandoff + } + + // Create new state with ShouldHandoff=false but preserve endpoint and seqID + // This prevents the connection from being queued multiple times while still + // allowing the worker to access the handoff metadata + newState := &HandoffState{ + ShouldHandoff: false, + Endpoint: state.Endpoint, // Preserve endpoint for handoff processing + SeqID: state.SeqID, // Preserve seqID for handoff processing + } + + // Atomic compare-and-swap to update state + if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) { + // State changed between load and CAS - retry or return error + return errHandoffStateChanged + } + + // Transition to UNUSABLE from IN_USE (normal flow), IDLE (edge cases), or CREATED (tests/uninitialized) + // The connection is typically in IN_USE state when OnPut is called (normal Put flow) + // But in some edge cases or tests, it might be in IDLE or CREATED state + // The pool will detect this state change and preserve it (not overwrite with IDLE) + // Use predefined slice to avoid allocation + finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable) + if err != nil { + // Check if already in UNUSABLE state (race condition or retry) + // ShouldHandoff should be false now, but check just in case + if finalState == StateUnusable && !cn.ShouldHandoff() { + // Already unusable - this is fine, keep the new handoff state + return nil + } + // Restore the original state if transition fails for other reasons + cn.handoffStateAtomic.Store(currentState) + return fmt.Errorf("failed to mark connection as unusable: %w", err) + } + return nil +} + +// GetID returns the unique identifier for this connection. +func (cn *Conn) GetID() uint64 { + return cn.id +} + +// GetStateMachine returns the connection's state machine for advanced state management. +// This is primarily used by internal packages like maintnotifications for handoff processing. +func (cn *Conn) GetStateMachine() *ConnStateMachine { + return cn.stateMachine +} + +// TryAcquire attempts to acquire the connection for use. +// This is an optimized inline method for the hot path (Get operation). +// +// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED. +// Returns true if the connection was successfully acquired, false otherwise. +// The CREATED->CREATED is done so we can keep the state correct for later +// initialization of the connection in initConn. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast() +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// The IDLE->IN_USE and CREATED->CREATED transitions don't need +// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever +// needs to notify waiters on these transitions, update this to use TryTransitionFast(). +func (cn *Conn) TryAcquire() bool { + // The || operator short-circuits, so only 1 CAS in the common case + return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) || + cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated)) +} + +// Release releases the connection back to the pool. +// This is an optimized inline method for the hot path (Put operation). +// +// It tries to transition from IN_USE -> IDLE. +// Returns true if the connection was successfully released, false otherwise. +// +// Performance: This is faster than calling GetStateMachine() + TryTransitionFast(). +// +// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's +// methods. This breaks encapsulation but is necessary for performance. +// If the state machine ever needs to notify waiters +// on this transition, update this to use TryTransitionFast(). +func (cn *Conn) Release() bool { + // Inline the hot path - single CAS operation + return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle)) +} + +// ClearHandoffState clears the handoff state after successful handoff. +// Makes the connection usable again. +func (cn *Conn) ClearHandoffState() { + // Clear handoff metadata + cn.handoffStateAtomic.Store(&HandoffState{ + ShouldHandoff: false, + Endpoint: "", + SeqID: 0, + }) + + // Reset retry counter + cn.handoffRetriesAtomic.Store(0) + + // Mark connection as usable again + // Use state machine directly instead of deprecated SetUsable + // probably done by initConn + cn.stateMachine.Transition(StateIdle) +} + +// HasBufferedData safely checks if the connection has buffered data. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) HasBufferedData() bool { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + return cn.rd.Buffered() > 0 +} + +// PeekReplyTypeSafe safely peeks at the reply type. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) PeekReplyTypeSafe() (byte, error) { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + + if cn.rd.Buffered() <= 0 { + return 0, fmt.Errorf("redis: can't peek reply type, no data available") + } + return cn.rd.PeekReplyType() +} + func (cn *Conn) Write(b []byte) (int, error) { - return cn.netConn.Write(b) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Write(b) + } + return 0, net.ErrClosed } func (cn *Conn) RemoteAddr() net.Addr { - if cn.netConn != nil { - return cn.netConn.RemoteAddr() + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.RemoteAddr() } return nil } @@ -67,7 +875,16 @@ func (cn *Conn) WithReader( ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveReadTimeout(timeout) + + // Get the connection directly from atomic storage + netConn := cn.getNetConn() + if netConn == nil { + return errConnectionNotAvailable + } + + if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } @@ -78,13 +895,25 @@ func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return err + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) + + // Set write deadline on the connection + if netConn := cn.getNetConn(); netConn != nil { + if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { + return err + } + } else { + // Connection is not available - return preallocated error + return errConnNotAvailableForWrite } } + // Reset the buffered writer if needed, should not happen if cn.bw.Buffered() > 0 { - cn.bw.Reset(cn.netConn) + if netConn := cn.getNetConn(); netConn != nil { + cn.bw.Reset(netConn) + } } if err := fn(cn.wr); err != nil { @@ -94,13 +923,49 @@ func (cn *Conn) WithWriter( return cn.bw.Flush() } +func (cn *Conn) IsClosed() bool { + return cn.stateMachine.GetState() == StateClosed +} + func (cn *Conn) Close() error { - return cn.netConn.Close() + if cn.IsClosed() { + return nil + } + // Transition to CLOSED state + cn.stateMachine.Transition(StateClosed) + + if cn.onClose != nil { + // ignore error + _ = cn.onClose() + cn.onClose = nil + } + + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Close() + } + return nil +} + +// MaybeHasData tries to peek at the next byte in the socket without consuming it +// This is used to check if there are push notifications available +// Important: This will work on Linux, but not on Windows +func (cn *Conn) MaybeHasData() bool { + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return maybeHasData(netConn) + } + return false } +// deadline computes the effective deadline time based on context and timeout. +// It updates the usedAt timestamp to now. +// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation). func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { - tm := time.Now() - cn.SetUsedAt(tm) + // Use cached time for deadline calculation (called 2x per command: read + write) + nowNs := getCachedTimeNs() + cn.SetUsedAtNs(nowNs) + tm := time.Unix(0, nowNs) if timeout > 0 { tm = tm.Add(timeout) diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check.go b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check.go index 83190d394..9e83dd833 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check.go @@ -12,6 +12,9 @@ import ( var errUnexpectedRead = errors.New("unexpected read from socket") +// connCheck checks if the connection is still alive and if there is data in the socket +// it will try to peek at the next byte without consuming it since we may want to work with it +// later on (e.g. push notifications) func connCheck(conn net.Conn) error { // Reset previous timeout. _ = conn.SetDeadline(time.Time{}) @@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error { if err := rawConn.Read(func(fd uintptr) bool { var buf [1]byte - n, err := syscall.Read(int(fd), buf[:]) + // Use MSG_PEEK to peek at data without consuming it + n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT) + switch { case n == 0 && err == nil: sysErr = io.EOF @@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error { return sysErr } + +// maybeHasData checks if there is data in the socket without consuming it +func maybeHasData(conn net.Conn) bool { + return connCheck(conn) == errUnexpectedRead +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check_dummy.go b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check_dummy.go index 295da1268..f971d94c4 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check_dummy.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_check_dummy.go @@ -2,8 +2,19 @@ package pool -import "net" +import ( + "errors" + "net" +) -func connCheck(conn net.Conn) error { +// errUnexpectedRead is placeholder error variable for non-unix build constraints +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(_ net.Conn) error { return nil } + +// since we can't check for data on the socket, we just assume there is some +func maybeHasData(_ net.Conn) bool { + return true +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/conn_state.go b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_state.go new file mode 100644 index 000000000..7dee4c49d --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/conn_state.go @@ -0,0 +1,336 @@ +package pool + +import ( + "container/list" + "context" + "errors" + "fmt" + "sync" + "sync/atomic" +) + +// ConnState represents the connection state in the state machine. +// States are designed to be lightweight and fast to check. +// +// State Transitions: +// +// CREATED → INITIALIZING → IDLE ⇄ IN_USE +// ↓ +// UNUSABLE (handoff/reauth) +// ↓ +// IDLE/CLOSED +type ConnState uint32 + +const ( + // StateCreated - Connection just created, not yet initialized + StateCreated ConnState = iota + + // StateInitializing - Connection initialization in progress + StateInitializing + + // StateIdle - Connection initialized and idle in pool, ready to be acquired + StateIdle + + // StateInUse - Connection actively processing a command (retrieved from pool) + StateInUse + + // StateUnusable - Connection temporarily unusable due to background operation + // (handoff, reauth, etc.). Cannot be acquired from pool. + StateUnusable + + // StateClosed - Connection closed + StateClosed +) + +// Predefined state slices to avoid allocations in hot paths +var ( + validFromInUse = []ConnState{StateInUse} + validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle} + validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle} + // For AwaitAndTransition calls + validFromCreatedIdleOrUnusable = []ConnState{StateCreated, StateIdle, StateUnusable} + validFromIdle = []ConnState{StateIdle} + // For CompareAndSwapUsable + validFromInitializingOrUnusable = []ConnState{StateInitializing, StateUnusable} +) + +// Accessor functions for predefined slices to avoid allocations in external packages +// These return the same slice instance, so they're zero-allocation + +// ValidFromIdle returns a predefined slice containing only StateIdle. +// Use this to avoid allocations when calling AwaitAndTransition or TryTransition. +func ValidFromIdle() []ConnState { + return validFromIdle +} + +// ValidFromCreatedIdleOrUnusable returns a predefined slice for initialization transitions. +// Use this to avoid allocations when calling AwaitAndTransition or TryTransition. +func ValidFromCreatedIdleOrUnusable() []ConnState { + return validFromCreatedIdleOrUnusable +} + +// String returns a human-readable string representation of the state. +func (s ConnState) String() string { + switch s { + case StateCreated: + return "CREATED" + case StateInitializing: + return "INITIALIZING" + case StateIdle: + return "IDLE" + case StateInUse: + return "IN_USE" + case StateUnusable: + return "UNUSABLE" + case StateClosed: + return "CLOSED" + default: + return fmt.Sprintf("UNKNOWN(%d)", s) + } +} + +var ( + // ErrInvalidStateTransition is returned when a state transition is not allowed + ErrInvalidStateTransition = errors.New("invalid state transition") + + // ErrStateMachineClosed is returned when operating on a closed state machine + ErrStateMachineClosed = errors.New("state machine is closed") + + // ErrTimeout is returned when a state transition times out + ErrTimeout = errors.New("state transition timeout") +) + +// waiter represents a goroutine waiting for a state transition. +// Designed for minimal allocations and fast processing. +type waiter struct { + validStates map[ConnState]struct{} // States we're waiting for + targetState ConnState // State to transition to + done chan error // Signaled when transition completes or times out +} + +// ConnStateMachine manages connection state transitions with FIFO waiting queue. +// Optimized for: +// - Lock-free reads (hot path) +// - Minimal allocations +// - Fast state transitions +// - FIFO fairness for waiters +// Note: Handoff metadata (endpoint, seqID, retries) is managed separately in the Conn struct. +type ConnStateMachine struct { + // Current state - atomic for lock-free reads + state atomic.Uint32 + + // FIFO queue for waiters - only locked during waiter add/remove/notify + mu sync.Mutex + waiters *list.List // List of *waiter + waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path) +} + +// NewConnStateMachine creates a new connection state machine. +// Initial state is StateCreated. +func NewConnStateMachine() *ConnStateMachine { + sm := &ConnStateMachine{ + waiters: list.New(), + } + sm.state.Store(uint32(StateCreated)) + return sm +} + +// GetState returns the current state (lock-free read). +// This is the hot path - optimized for zero allocations and minimal overhead. +// Note: Zero allocations applies to state reads; converting the returned state to a string +// (via String()) may allocate if the state is unknown. +func (sm *ConnStateMachine) GetState() ConnState { + return ConnState(sm.state.Load()) +} + +// TryTransitionFast is an optimized version for the hot path (Get/Put operations). +// It only handles simple state transitions without waiter notification. +// This is safe because: +// 1. Get/Put don't need to wait for state changes +// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match +// 3. If a background operation is in progress (state is UNUSABLE), this fails fast +// +// Returns true if transition succeeded, false otherwise. +// Use this for performance-critical paths where you don't need error details. +// +// Performance: Single CAS operation - as fast as the old atomic bool! +// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target) +// The || operator short-circuits, so only 1 CAS is executed in the common case. +func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool { + return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) +} + +// TryTransition attempts an immediate state transition without waiting. +// Returns the current state after the transition attempt and an error if the transition failed. +// The returned state is the CURRENT state (after the attempt), not the previous state. +// This is faster than AwaitAndTransition when you don't need to wait. +// Uses compare-and-swap to atomically transition, preventing concurrent transitions. +// This method does NOT wait - it fails immediately if the transition cannot be performed. +// +// Performance: Zero allocations on success path (hot path). +func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) { + // Try each valid from state with CAS + // This ensures only ONE goroutine can successfully transition at a time + for _, fromState := range validFromStates { + // Try to atomically swap from fromState to targetState + // If successful, we won the race and can proceed + if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { + // Success! We transitioned atomically + // Hot path optimization: only check for waiters if transition succeeded + // This avoids atomic load on every Get/Put when no waiters exist + if sm.waiterCount.Load() > 0 { + sm.notifyWaiters() + } + return targetState, nil + } + } + + // All CAS attempts failed - state is not valid for this transition + // Return the current state so caller can decide what to do + // Note: This error path allocates, but it's the exceptional case + currentState := sm.GetState() + return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)", + ErrInvalidStateTransition, currentState, targetState, validFromStates) +} + +// Transition unconditionally transitions to the target state. +// Use with caution - prefer AwaitAndTransition or TryTransition for safety. +// This is useful for error paths or when you know the transition is valid. +func (sm *ConnStateMachine) Transition(targetState ConnState) { + sm.state.Store(uint32(targetState)) + sm.notifyWaiters() +} + +// AwaitAndTransition waits for the connection to reach one of the valid states, +// then atomically transitions to the target state. +// Returns the current state after the transition attempt and an error if the operation failed. +// The returned state is the CURRENT state (after the attempt), not the previous state. +// Returns error if timeout expires or context is cancelled. +// +// This method implements FIFO fairness - the first caller to wait gets priority +// when the state becomes available. +// +// Performance notes: +// - If already in a valid state, this is very fast (no allocation, no waiting) +// - If waiting is required, allocates one waiter struct and one channel +func (sm *ConnStateMachine) AwaitAndTransition( + ctx context.Context, + validFromStates []ConnState, + targetState ConnState, +) (ConnState, error) { + // Fast path: try immediate transition with CAS to prevent race conditions + // BUT: only if there are no waiters in the queue (to maintain FIFO ordering) + if sm.waiterCount.Load() == 0 { + for _, fromState := range validFromStates { + // Check if we're already in target state + if fromState == targetState && sm.GetState() == targetState { + return targetState, nil + } + + // Try to atomically swap from fromState to targetState + if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) { + // Success! We transitioned atomically + sm.notifyWaiters() + return targetState, nil + } + } + } + + // Fast path failed - check if we should wait or fail + currentState := sm.GetState() + + // Check if closed + if currentState == StateClosed { + return currentState, ErrStateMachineClosed + } + + // Slow path: need to wait for state change + // Create waiter with valid states map for fast lookup + validStatesMap := make(map[ConnState]struct{}, len(validFromStates)) + for _, s := range validFromStates { + validStatesMap[s] = struct{}{} + } + + w := &waiter{ + validStates: validStatesMap, + targetState: targetState, + done: make(chan error, 1), // Buffered to avoid goroutine leak + } + + // Add to FIFO queue + sm.mu.Lock() + elem := sm.waiters.PushBack(w) + sm.waiterCount.Add(1) + sm.mu.Unlock() + + // Wait for state change or timeout + select { + case <-ctx.Done(): + // Timeout or cancellation - remove from queue + sm.mu.Lock() + sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) + sm.mu.Unlock() + return sm.GetState(), ctx.Err() + case err := <-w.done: + // Transition completed (or failed) + // Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed) + // or here (on timeout/cancellation). + return sm.GetState(), err + } +} + +// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order. +// This is called after every state transition. +func (sm *ConnStateMachine) notifyWaiters() { + // Fast path: check atomic counter without acquiring lock + // This eliminates mutex overhead in the common case (no waiters) + if sm.waiterCount.Load() == 0 { + return + } + + sm.mu.Lock() + defer sm.mu.Unlock() + + // Double-check after acquiring lock (waiters might have been processed) + if sm.waiters.Len() == 0 { + return + } + + // Track state locally so we only consider transitions made within this + // call, not concurrent transitions from woken goroutines. Re-reading the + // atomic would let a fast goroutine's Transition(StateIdle) leak into our + // view, causing us to wake multiple waiters at once and breaking FIFO + // execution ordering. + currentState := sm.GetState() + + for { + processed := false + + for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() { + w := elem.Value.(*waiter) + + if _, valid := w.validStates[currentState]; valid { + sm.waiters.Remove(elem) + sm.waiterCount.Add(-1) + + if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) { + w.done <- nil + currentState = w.targetState + processed = true + break + } else { + sm.waiters.PushFront(w) + sm.waiterCount.Add(1) + currentState = sm.GetState() + processed = true + break + } + } + } + + if !processed { + break + } + } +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/hooks.go b/vendor/github.com/redis/go-redis/v9/internal/pool/hooks.go new file mode 100644 index 000000000..a26e1976d --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/hooks.go @@ -0,0 +1,165 @@ +package pool + +import ( + "context" + "sync" +) + +// PoolHook defines the interface for connection lifecycle hooks. +type PoolHook interface { + // OnGet is called when a connection is retrieved from the pool. + // It can modify the connection or return an error to prevent its use. + // The accept flag can be used to prevent the connection from being used. + // On Accept = false the connection is rejected and returned to the pool. + // The error can be used to prevent the connection from being used and returned to the pool. + // On Errors, the connection is removed from the pool. + // It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool) + // The flag can be used for gathering metrics on pool hit/miss ratio. + OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error) + + // OnPut is called when a connection is returned to the pool. + // It returns whether the connection should be pooled and whether it should be removed. + OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) + + // OnRemove is called when a connection is removed from the pool. + // This happens when: + // - Connection fails health check + // - Connection exceeds max lifetime + // - Pool is being closed + // - Connection encounters an error + // Implementations should clean up any per-connection state. + // The reason parameter indicates why the connection was removed. + OnRemove(ctx context.Context, conn *Conn, reason error) +} + +// PoolHookManager manages multiple pool hooks. +type PoolHookManager struct { + hooks []PoolHook + hooksMu sync.RWMutex +} + +// NewPoolHookManager creates a new pool hook manager. +func NewPoolHookManager() *PoolHookManager { + return &PoolHookManager{ + hooks: make([]PoolHook, 0), + } +} + +// AddHook adds a pool hook to the manager. +// Hooks are called in the order they were added. +func (phm *PoolHookManager) AddHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + phm.hooks = append(phm.hooks, hook) +} + +// RemoveHook removes a pool hook from the manager. +func (phm *PoolHookManager) RemoveHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + + for i, h := range phm.hooks { + if h == hook { + // Remove hook by swapping with last element and truncating + phm.hooks[i] = phm.hooks[len(phm.hooks)-1] + phm.hooks = phm.hooks[:len(phm.hooks)-1] + break + } + } +} + +// ProcessOnGet calls all OnGet hooks in order. +// If any hook returns an error, processing stops and the error is returned. +func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) { + // Copy slice reference while holding lock (fast) + phm.hooksMu.RLock() + hooks := phm.hooks + phm.hooksMu.RUnlock() + + // Call hooks without holding lock (slow operations) + for _, hook := range hooks { + acceptConn, err := hook.OnGet(ctx, conn, isNewConn) + if err != nil { + return false, err + } + + if !acceptConn { + return false, nil + } + } + return true, nil +} + +// ProcessOnPut calls all OnPut hooks in order. +// The first hook that returns shouldRemove=true or shouldPool=false will stop processing. +func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) { + // Copy slice reference while holding lock (fast) + phm.hooksMu.RLock() + hooks := phm.hooks + phm.hooksMu.RUnlock() + + shouldPool = true // Default to pooling the connection + + // Call hooks without holding lock (slow operations) + for _, hook := range hooks { + hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn) + + if hookErr != nil { + return false, true, hookErr + } + + // If any hook says to remove or not pool, respect that decision + if hookShouldRemove { + return false, true, nil + } + + if !hookShouldPool { + shouldPool = false + } + } + + return shouldPool, false, nil +} + +// ProcessOnRemove calls all OnRemove hooks in order. +func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) { + // Copy slice reference while holding lock (fast) + phm.hooksMu.RLock() + hooks := phm.hooks + phm.hooksMu.RUnlock() + + // Call hooks without holding lock (slow operations) + for _, hook := range hooks { + hook.OnRemove(ctx, conn, reason) + } +} + +// GetHookCount returns the number of registered hooks (for testing). +func (phm *PoolHookManager) GetHookCount() int { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + return len(phm.hooks) +} + +// GetHooks returns a copy of all registered hooks. +func (phm *PoolHookManager) GetHooks() []PoolHook { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + hooks := make([]PoolHook, len(phm.hooks)) + copy(hooks, phm.hooks) + return hooks +} + +// Clone creates a copy of the hook manager with the same hooks. +// This is used for lock-free atomic updates of the hook manager. +func (phm *PoolHookManager) Clone() *PoolHookManager { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + newManager := &PoolHookManager{ + hooks: make([]PoolHook, len(phm.hooks)), + } + copy(newManager.hooks, phm.hooks) + return newManager +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/pool.go b/vendor/github.com/redis/go-redis/v9/internal/pool/pool.go index b69c75f4f..d551fbb17 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/pool.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/pool.go @@ -3,12 +3,49 @@ package pool import ( "context" "errors" + "math/rand" "net" "sync" "sync/atomic" "time" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/proto" +) + +// Connection close reason constants for metrics. +// These are used as the "reason" parameter in CloseConn() calls. +const ( + // CloseReasonStale indicates the connection was closed because it exceeded + // the idle timeout or max lifetime. + CloseReasonStale = "stale" + + // CloseReasonHookError indicates the connection was closed due to an error + // in a pool hook (OnGet or OnPut). + CloseReasonHookError = "hook_error" + + // CloseReasonAuthError indicates the connection was closed due to an + // authentication error during re-authentication. + CloseReasonAuthError = "auth_error" + + // CloseReasonTest is used in tests when closing connections. + CloseReasonTest = "test" + + // CloseReasonFailover indicates the connection was closed due to a failover event. + CloseReasonFailover = "failover" +) + +// Metric state constants for connection state tracking. +// These represent the logical state of a connection from a metrics perspective, +// not the internal state machine state (ConnState). +const ( + // MetricStateIdle indicates the connection is idle in the pool, + // ready to be acquired. + MetricStateIdle = "idle" + + // MetricStateUsed indicates the connection is currently being used + // by a client operation. + MetricStateUsed = "used" ) var ( @@ -21,30 +58,266 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. ErrPoolTimeout = errors.New("redis: connection pool timeout") + + // ErrConnUnusableTimeout is returned when a connection is not usable and we timed out trying to mark it as unusable. + ErrConnUnusableTimeout = errors.New("redis: timed out trying to mark connection as unusable") + + // errHookRequestedRemoval is returned when a hook requests connection removal. + errHookRequestedRemoval = errors.New("hook requested removal") + + // errConnNotPooled is returned when trying to return a non-pooled connection to the pool. + errConnNotPooled = errors.New("connection not pooled") + // metricCallbackMu protects all global metric callback functions for thread-safe access. + metricCallbackMu sync.RWMutex + + // Global metric callbacks for connection state changes + metricConnectionStateChangeCallback func(ctx context.Context, cn *Conn, fromState, toState string) + + // Global metric callback for connection creation time + metricConnectionCreateTimeCallback func(ctx context.Context, duration time.Duration, cn *Conn) + + // Global metric callback for connection relaxed timeout changes + // Parameters: ctx, delta (+1/-1), cn, poolName, notificationType + metricConnectionRelaxedTimeoutCallback func(ctx context.Context, delta int, cn *Conn, poolName, notificationType string) + + // Global metric callback for connection handoff + // Parameters: ctx, cn, poolName + metricConnectionHandoffCallback func(ctx context.Context, cn *Conn, poolName string) + + // Global metric callback for error tracking + // Parameters: ctx, errorType, cn, statusCode, isInternal, retryAttempts + metricErrorCallback func(ctx context.Context, errorType string, cn *Conn, statusCode string, isInternal bool, retryAttempts int) + + // Global metric callback for maintenance notifications + // Parameters: ctx, cn, notificationType + metricMaintenanceNotificationCallback func(ctx context.Context, cn *Conn, notificationType string) + + // Global metric callback for connection wait time + // Parameters: ctx, duration, cn + metricConnectionWaitTimeCallback func(ctx context.Context, duration time.Duration, cn *Conn) + + // Global metric callback for connection timeouts + // Parameters: ctx, cn, timeoutType + metricConnectionTimeoutCallback func(ctx context.Context, cn *Conn, timeoutType string) + + // Global metric callback for connection closed + // Parameters: ctx, cn, reason, err + metricConnectionClosedCallback func(ctx context.Context, cn *Conn, reason string, err error) + + // Global metric callback for connection count changes (UpDownCounter) + // Parameters: ctx, delta (+1/-1), cn, state, isPubSub + metricConnectionCountCallback func(ctx context.Context, delta int, cn *Conn, state string, isPubSub bool) + + // Global metric callback for pending requests changes (UpDownCounter) + // Parameters: ctx, delta (+1/-1), cn, poolName + // poolName is passed explicitly because we may not have a connection yet when request starts + metricPendingRequestsCallback func(ctx context.Context, delta int, cn *Conn, poolName string) + + // errPanicInDial is returned when a panic occurs in the dial function. + errPanicInQueuedNewConn = errors.New("panic in queuedNewConn") + + // popAttempts is the maximum number of attempts to find a usable connection + // when popping from the idle connection pool. This handles cases where connections + // are temporarily marked as unusable (e.g., during maintenanceNotifications upgrades or network issues). + // Value of 50 provides sufficient resilience without excessive overhead. + // This is capped by the idle connection count, so we won't loop excessively. + popAttempts = 50 + + // getAttempts is the maximum number of attempts to get a connection that passes + // hook validation (e.g., maintenanceNotifications upgrade hooks). This protects against race conditions + // where hooks might temporarily reject connections during cluster transitions. + // Value of 3 balances resilience with performance - most hook rejections resolve quickly. + getAttempts = 3 + + minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 + maxTime = minTime.Add(1<<63 - 1) + noExpiration = maxTime ) -var timers = sync.Pool{ - New: func() interface{} { - t := time.NewTimer(time.Hour) - t.Stop() - return t - }, +// MetricCallbacks holds all metric callback functions. +// Use SetAllMetricCallbacks to register all callbacks atomically. +type MetricCallbacks struct { + // ConnectionCreateTime is called when a new connection is created + ConnectionCreateTime func(ctx context.Context, duration time.Duration, cn *Conn) + + // ConnectionRelaxedTimeout is called when connection timeout is relaxed/unrelaxed + // delta: +1 for relaxed, -1 for unrelaxed + ConnectionRelaxedTimeout func(ctx context.Context, delta int, cn *Conn, poolName, notificationType string) + + // ConnectionHandoff is called when a connection is handed off to another node + ConnectionHandoff func(ctx context.Context, cn *Conn, poolName string) + + // Error is called when an error occurs + Error func(ctx context.Context, errorType string, cn *Conn, statusCode string, isInternal bool, retryAttempts int) + + // MaintenanceNotification is called when a maintenance notification is received + MaintenanceNotification func(ctx context.Context, cn *Conn, notificationType string) + + // ConnectionWaitTime is called to record time spent waiting for a connection + ConnectionWaitTime func(ctx context.Context, duration time.Duration, cn *Conn) + + // ConnectionClosed is called when a connection is closed + ConnectionClosed func(ctx context.Context, cn *Conn, reason string, err error) + + // ConnectionCount is called when connection count changes (UpDownCounter) + // delta: +1 when connection added, -1 when connection removed + // state: connection state (e.g., "idle", "used") + // isPubSub: true if this is a PubSub connection + ConnectionCount func(ctx context.Context, delta int, cn *Conn, state string, isPubSub bool) + + // PendingRequests is called when pending requests count changes (UpDownCounter) + // delta: +1 when request starts waiting, -1 when request stops waiting + // poolName is passed explicitly because we may not have a connection yet when request starts + PendingRequests func(ctx context.Context, delta int, cn *Conn, poolName string) +} + +// SetAllMetricCallbacks sets all metric callbacks atomically. +// Pass nil to clear all callbacks (disable metrics). +// This ensures all callbacks are set together under a single lock, +// preventing inconsistent state during registration. +// +// Note on thread safety: After returning, there is a small window where +// concurrent getMetric* calls may return the old callback value. This is +// acceptable for metrics - at most one event may go to the old recorder +// or be missed during the transition. The callbacks themselves are immutable +// function pointers, so calling an "old" callback is safe. +func SetAllMetricCallbacks(callbacks *MetricCallbacks) { + metricCallbackMu.Lock() + defer metricCallbackMu.Unlock() + + if callbacks == nil { + metricConnectionCreateTimeCallback = nil + metricConnectionRelaxedTimeoutCallback = nil + metricConnectionHandoffCallback = nil + metricErrorCallback = nil + metricMaintenanceNotificationCallback = nil + metricConnectionWaitTimeCallback = nil + metricConnectionClosedCallback = nil + metricConnectionCountCallback = nil + metricPendingRequestsCallback = nil + return + } + + metricConnectionCreateTimeCallback = callbacks.ConnectionCreateTime + metricConnectionRelaxedTimeoutCallback = callbacks.ConnectionRelaxedTimeout + metricConnectionHandoffCallback = callbacks.ConnectionHandoff + metricErrorCallback = callbacks.Error + metricMaintenanceNotificationCallback = callbacks.MaintenanceNotification + metricConnectionWaitTimeCallback = callbacks.ConnectionWaitTime + metricConnectionClosedCallback = callbacks.ConnectionClosed + metricConnectionCountCallback = callbacks.ConnectionCount + metricPendingRequestsCallback = callbacks.PendingRequests +} + +// getMetricConnectionStateChangeCallback returns the metric callback for connection state changes. +func getMetricConnectionStateChangeCallback() func(ctx context.Context, cn *Conn, fromState, toState string) { + metricCallbackMu.RLock() + cb := metricConnectionStateChangeCallback + metricCallbackMu.RUnlock() + return cb +} + +// GetMetricConnectionCreateTimeCallback returns the metric callback for connection creation time. +func GetMetricConnectionCreateTimeCallback() func(ctx context.Context, duration time.Duration, cn *Conn) { + metricCallbackMu.RLock() + cb := metricConnectionCreateTimeCallback + metricCallbackMu.RUnlock() + return cb +} + +// GetMetricConnectionRelaxedTimeoutCallback returns the metric callback for connection relaxed timeout changes. +// This is used by maintnotifications to record relaxed timeout metrics. +func GetMetricConnectionRelaxedTimeoutCallback() func(ctx context.Context, delta int, cn *Conn, poolName, notificationType string) { + metricCallbackMu.RLock() + cb := metricConnectionRelaxedTimeoutCallback + metricCallbackMu.RUnlock() + return cb +} + +// GetMetricConnectionHandoffCallback returns the metric callback for connection handoffs. +// This is used by maintnotifications to record handoff metrics. +func GetMetricConnectionHandoffCallback() func(ctx context.Context, cn *Conn, poolName string) { + metricCallbackMu.RLock() + cb := metricConnectionHandoffCallback + metricCallbackMu.RUnlock() + return cb +} + +// GetMetricErrorCallback returns the metric callback for error tracking. +// This is used by cluster and client code to record error metrics. +func GetMetricErrorCallback() func(ctx context.Context, errorType string, cn *Conn, statusCode string, isInternal bool, retryAttempts int) { + metricCallbackMu.RLock() + cb := metricErrorCallback + metricCallbackMu.RUnlock() + return cb +} + +// GetMetricMaintenanceNotificationCallback returns the metric callback for maintenance notifications. +// This is used by maintnotifications to record notification metrics. +func GetMetricMaintenanceNotificationCallback() func(ctx context.Context, cn *Conn, notificationType string) { + metricCallbackMu.RLock() + cb := metricMaintenanceNotificationCallback + metricCallbackMu.RUnlock() + return cb +} + +func getMetricConnectionWaitTimeCallback() func(ctx context.Context, duration time.Duration, cn *Conn) { + metricCallbackMu.RLock() + cb := metricConnectionWaitTimeCallback + metricCallbackMu.RUnlock() + return cb +} + +func getMetricConnectionTimeoutCallback() func(ctx context.Context, cn *Conn, timeoutType string) { + metricCallbackMu.RLock() + cb := metricConnectionTimeoutCallback + metricCallbackMu.RUnlock() + return cb +} + +func getMetricConnectionClosedCallback() func(ctx context.Context, cn *Conn, reason string, err error) { + metricCallbackMu.RLock() + cb := metricConnectionClosedCallback + metricCallbackMu.RUnlock() + return cb +} + +// getMetricConnectionCountCallback returns the metric callback for connection count changes (UpDownCounter). +func getMetricConnectionCountCallback() func(ctx context.Context, delta int, cn *Conn, state string, isPubSub bool) { + metricCallbackMu.RLock() + cb := metricConnectionCountCallback + metricCallbackMu.RUnlock() + return cb +} + +// getMetricPendingRequestsCallback returns the metric callback for pending requests changes (UpDownCounter). +func getMetricPendingRequestsCallback() func(ctx context.Context, delta int, cn *Conn, poolName string) { + metricCallbackMu.RLock() + cb := metricPendingRequestsCallback + metricCallbackMu.RUnlock() + return cb } // Stats contains pool state information and accumulated stats. type Stats struct { - Hits uint32 // number of times free connection was found in the pool - Misses uint32 // number of times free connection was NOT found in the pool - Timeouts uint32 // number of times a wait timeout occurred - - TotalConns uint32 // number of total connections in the pool - IdleConns uint32 // number of idle connections in the pool - StaleConns uint32 // number of stale connections removed from the pool + Hits uint32 // number of times free connection was found in the pool + Misses uint32 // number of times free connection was NOT found in the pool + Timeouts uint32 // number of times a wait timeout occurred + WaitCount uint32 // number of times a connection was waited + Unusable uint32 // number of times a connection was found to be unusable + WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds + + TotalConns uint32 // number of total connections in the pool + IdleConns uint32 // number of idle connections in the pool + StaleConns uint32 // number of stale connections removed from the pool + PendingRequests uint32 // number of pending requests waiting for a connection + + PubSubStats PubSubStats } type Pooler interface { NewConn(context.Context) (*Conn, error) - CloseConn(*Conn) error + CloseConn(ctx context.Context, cn *Conn, reason string, fromState string) error Get(context.Context) (*Conn, error) Put(context.Context, *Conn) @@ -54,21 +327,55 @@ type Pooler interface { IdleLen() int Stats() *Stats + // Size returns the maximum pool size (capacity). + // This is used by the streaming credentials manager to size the re-auth worker pool. + Size() int + + AddPoolHook(hook PoolHook) + RemovePoolHook(hook PoolHook) + + // RemoveWithoutTurn removes a connection from the pool without freeing a turn. + // This should be used when removing a connection from a context that didn't acquire + // a turn via Get() (e.g., background workers, cleanup tasks). + // For normal removal after Get(), use Remove() instead. + RemoveWithoutTurn(context.Context, *Conn, error) + Close() error } type Options struct { - Dialer func(context.Context) (net.Conn, error) - - PoolFIFO bool - PoolSize int - DialTimeout time.Duration - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration + Dialer func(context.Context) (net.Conn, error) + ReadBufferSize int + WriteBufferSize int + + PoolFIFO bool + PoolSize int32 + MaxConcurrentDials int + DialTimeout time.Duration + PoolTimeout time.Duration + MinIdleConns int32 + MaxIdleConns int32 + MaxActiveConns int32 + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + ConnMaxLifetimeJitter time.Duration + PushNotificationsEnabled bool + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // Default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // Default: 100ms + DialerRetryTimeout time.Duration + + // DialerRetryBackoff controls the delay between dial retry attempts. + // If nil, dial retry backoff is constant and equals DialerRetryTimeout (default: 100ms). + DialerRetryBackoff func(attempt int) time.Duration + + // Name is a unique identifier for this pool, used in metrics. + // Format: addr_uniqueID (e.g., "localhost:6379_a1b2c3d4") + Name string } type lastDialErrorWrap struct { @@ -81,74 +388,161 @@ type ConnPool struct { dialErrorsNum uint32 // atomic lastDialError atomic.Value - queue chan struct{} + dialsInProgress chan struct{} + dialsQueue *wantConnQueue + // Fast semaphore for connection limiting with eventual fairness + // Uses fast path optimization to avoid timer allocation when tokens are available + semaphore *internal.FastSemaphore connsMu sync.Mutex - conns []*Conn + conns map[uint64]*Conn idleConns []*Conn - poolSize int - idleConnsLen int + poolSize atomic.Int32 + idleConnsLen atomic.Int32 + idleCheckInProgress atomic.Bool + idleCheckNeeded atomic.Bool - stats Stats + stats Stats + waitDurationNs atomic.Int64 _closed uint32 // atomic + + // Pool hooks manager for flexible connection processing + // Using atomic.Pointer for lock-free reads in hot paths (Get/Put) + hookManager atomic.Pointer[PoolHookManager] } var _ Pooler = (*ConnPool)(nil) func NewConnPool(opt *Options) *ConnPool { p := &ConnPool{ - cfg: opt, - - queue: make(chan struct{}, opt.PoolSize), - conns: make([]*Conn, 0, opt.PoolSize), - idleConns: make([]*Conn, 0, opt.PoolSize), + cfg: opt, + semaphore: internal.NewFastSemaphore(opt.PoolSize), + conns: make(map[uint64]*Conn), + dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials), + dialsQueue: newWantConnQueue(), + idleConns: make([]*Conn, 0, opt.PoolSize), } - p.connsMu.Lock() - p.checkMinIdleConns() - p.connsMu.Unlock() + // Only create MinIdleConns if explicitly requested (> 0) + // This avoids creating connections during pool initialization for tests + if opt.MinIdleConns > 0 { + p.connsMu.Lock() + p.checkMinIdleConns() + p.connsMu.Unlock() + } return p } +// initializeHooks sets up the pool hooks system. +func (p *ConnPool) initializeHooks() { + manager := NewPoolHookManager() + p.hookManager.Store(manager) +} + +// AddPoolHook adds a pool hook to the pool. +func (p *ConnPool) AddPoolHook(hook PoolHook) { + // Lock-free read of current manager + manager := p.hookManager.Load() + if manager == nil { + p.initializeHooks() + manager = p.hookManager.Load() + } + + // Create new manager with added hook + newManager := manager.Clone() + newManager.AddHook(hook) + + // Atomically swap to new manager + p.hookManager.Store(newManager) +} + +// RemovePoolHook removes a pool hook from the pool. +func (p *ConnPool) RemovePoolHook(hook PoolHook) { + manager := p.hookManager.Load() + if manager != nil { + // Create new manager with removed hook + newManager := manager.Clone() + newManager.RemoveHook(hook) + + // Atomically swap to new manager + p.hookManager.Store(newManager) + } +} + func (p *ConnPool) checkMinIdleConns() { + // If a check is already in progress, mark that we need another check and return + if !p.idleCheckInProgress.CompareAndSwap(false, true) { + p.idleCheckNeeded.Store(true) + return + } + if p.cfg.MinIdleConns == 0 { + p.idleCheckInProgress.Store(false) return } - for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns { - select { - case p.queue <- struct{}{}: - p.poolSize++ - p.idleConnsLen++ + // Keep checking until no more checks are needed + // This handles the case where multiple Remove() calls happen concurrently + for { + // Clear the "check needed" flag before we start + p.idleCheckNeeded.Store(false) + + // Only create idle connections if we haven't reached the total pool size limit + // MinIdleConns should be a subset of PoolSize, not additional connections + for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { + // Try to acquire a semaphore token + if !p.semaphore.TryAcquire() { + // Semaphore is full, can't create more connections right now + // Break out of inner loop to check if we need to retry + break + } + + p.poolSize.Add(1) + p.idleConnsLen.Add(1) go func() { + defer func() { + if err := recover(); err != nil { + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) + + p.freeTurn() + internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) + } + }() + err := p.addIdleConn() if err != nil && err != ErrClosed { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) } - p.freeTurn() }() - default: + } + + // If no one requested another check while we were working, we're done + if !p.idleCheckNeeded.Load() { + p.idleCheckInProgress.Store(false) return } + + // Otherwise, loop again to handle the new requests } } func (p *ConnPool) addIdleConn() error { - ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) - defer cancel() - - cn, err := p.dialConn(ctx, true) + // Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + cn, err := p.dialConn(context.Background(), true) if err != nil { return err } + // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn() + // when first acquired from the pool. Do NOT transition to IDLE here - that happens + // after initialization completes. + p.connsMu.Lock() defer p.connsMu.Unlock() @@ -158,11 +552,21 @@ func (p *ConnPool) addIdleConn() error { return ErrClosed } - p.conns = append(p.conns, cn) + p.conns[cn.GetID()] = cn p.idleConns = append(p.idleConns, cn) + + // Record connection count increment (new idle connection from min-idle prewarm) + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(context.Background(), 1, cn, "idle", false) + } + return nil } +// NewConn creates a new connection and returns it to the user. +// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size. +// +// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support maintnotifications upgrades. func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.newConn(ctx, false) } @@ -172,36 +576,64 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrClosed } - p.connsMu.Lock() - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { - p.connsMu.Unlock() + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= p.cfg.MaxActiveConns { return nil, ErrPoolExhausted } - p.connsMu.Unlock() + // Protect against nil context due to race condition in queuedNewConn + // where the context can be set to nil after timeout/cancellation + if ctx == nil { + ctx = context.Background() + } + + // Do not apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + // We still propagate ctx so callers can cancel explicitly. cn, err := p.dialConn(ctx, pooled) if err != nil { return nil, err } - p.connsMu.Lock() - defer p.connsMu.Unlock() + // NOTE: Connection is in CREATED state and will be initialized by redis.go:initConn() + // when first used. Do NOT transition to IDLE here - that happens after initialization completes. + // The state machine flow is: CREATED → INITIALIZING (in initConn) → IDLE (after init success) - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > p.cfg.MaxActiveConns { _ = cn.Close() return nil, ErrPoolExhausted } - p.conns = append(p.conns, cn) + p.connsMu.Lock() + defer p.connsMu.Unlock() + if p.closed() { + _ = cn.Close() + return nil, ErrClosed + } + // Check if pool was closed while we were waiting for the lock + if p.conns == nil { + p.conns = make(map[uint64]*Conn) + } + p.conns[cn.GetID()] = cn + if pooled { // If pool is full remove the cn on next Put. - if p.poolSize >= p.cfg.PoolSize { + currentPoolSize := p.poolSize.Load() + if currentPoolSize >= p.cfg.PoolSize { cn.pooled = false } else { - p.poolSize++ + p.poolSize.Add(1) } } + // All new connections start as "used" metrically. For the miss path in getConn, + // this is the final state. For putIdleConn (undelivered conn), a used→idle + // transition is emitted when it's added to idleConns. + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, "", MetricStateUsed) + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, 1, cn, "used", false) + } + return cn, nil } @@ -214,18 +646,114 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, p.getLastDialError() } - netConn, err := p.cfg.Dialer(ctx) - if err != nil { - p.setLastDialError(err) - if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { - go p.tryDial() + // Record dial start time for connection creation metric + // This will be used after handshake completes in redis.go _getConn() + // Only call time.Now() if callback is registered to avoid overhead + var dialStartNs int64 + if GetMetricConnectionCreateTimeCallback() != nil { + dialStartNs = time.Now().UnixNano() + } + + // Retry dialing with backoff + // Dial timeout is applied per attempt (so retries/backoff don't eat into the next + // attempt's dial budget), while still honoring caller cancellation via ctx. + maxRetries := p.cfg.DialerRetries + if maxRetries <= 0 { + maxRetries = 5 // Default value + } + + var lastErr error + shouldLoop := true + // when the timeout is reached, we should stop retrying + // but keep the lastErr to return to the caller + // instead of a generic context deadline exceeded error + attempt := 0 + for attempt = 0; (attempt < maxRetries) && shouldLoop; attempt++ { + attemptCtx := ctx + var cancel context.CancelFunc + if p.cfg.DialTimeout > 0 { + // Apply DialTimeout per attempt, but never extend an existing earlier deadline. + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > p.cfg.DialTimeout { + attemptCtx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout) + } } - return nil, err + + netConn, err := p.cfg.Dialer(attemptCtx) + if cancel != nil { + cancel() + } + if err != nil { + lastErr = err + // Add backoff delay for retry attempts + // (not for the first attempt, do at least one) + // Do not sleep after the last attempt. + if attempt+1 < maxRetries { + backoffDuration := p.dialRetryBackoff(attempt) + select { + case <-ctx.Done(): + shouldLoop = false + case <-time.After(backoffDuration): + // Continue with retry + } + } + continue + } + + cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) + cn.pooled = pooled + // Store dial start time only if we recorded it + if dialStartNs > 0 { + cn.dialStartNs.Store(dialStartNs) + } + cn.expiresAt = p.calcConnExpiresAt() + // Set pool name for metrics + cn.SetPoolName(p.cfg.Name) + + return cn, nil } - cn := NewConn(netConn) - cn.pooled = pooled - return cn, nil + internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", attempt, lastErr) + // All retries failed - handle error tracking + p.setLastDialError(lastErr) + if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { + go p.tryDial() + } + return nil, lastErr +} + +func (p *ConnPool) dialRetryBackoff(attempt int) time.Duration { + if p.cfg.DialerRetryBackoff != nil { + d := p.cfg.DialerRetryBackoff(attempt) + if d < 0 { + return 0 + } + return d + } + + base := p.cfg.DialerRetryTimeout + if base <= 0 { + base = 100 * time.Millisecond + } + return base +} + +// calcConnExpiresAt calculates the expiration time for a connection. +// It applies random jitter to prevent all connections from expiring simultaneously, +// avoiding the "thundering herd" problem where all connections expire at once. +// Returns noExpiration if ConnMaxLifetime is not set. +func (p *ConnPool) calcConnExpiresAt() time.Time { + if p.cfg.ConnMaxLifetime <= 0 { + return noExpiration + } + + if p.cfg.ConnMaxLifetimeJitter <= 0 { + return time.Now().Add(p.cfg.ConnMaxLifetime) + } + + jitter := p.cfg.ConnMaxLifetimeJitter + jitterRange := jitter.Nanoseconds() * 2 + jitterNs := rand.Int63n(jitterRange) - jitter.Nanoseconds() + return time.Now().Add(p.cfg.ConnMaxLifetime + time.Duration(jitterNs)) } func (p *ConnPool) tryDial() { @@ -234,19 +762,26 @@ func (p *ConnPool) tryDial() { return } - ctx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout) + // Probe dialing even when dialErrorsNum is saturated. Apply DialTimeout per probe + // attempt so custom dialers can't hang indefinitely. + ctx := context.Background() + var cancel context.CancelFunc + if p.cfg.DialTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, p.cfg.DialTimeout) + } conn, err := p.cfg.Dialer(ctx) + if cancel != nil { + cancel() + } if err != nil { p.setLastDialError(err) time.Sleep(time.Second) - cancel() continue } atomic.StoreUint32(&p.dialErrorsNum, 0) _ = conn.Close() - cancel() return } } @@ -265,17 +800,79 @@ func (p *ConnPool) getLastDialError() error { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { + return p.getConn(ctx) +} + +// getConn returns a connection from the pool. +func (p *ConnPool) getConn(ctx context.Context) (cn *Conn, err error) { if p.closed() { return nil, ErrClosed } - if err := p.waitTurn(ctx); err != nil { + // Track pending requests in pool stats + atomic.AddUint32(&p.stats.PendingRequests, 1) + // Record pending request increment (UpDownCounter) + // Pass pool name explicitly since we don't have a connection yet + poolName := p.cfg.Name + if cb := getMetricPendingRequestsCallback(); cb != nil { + cb(ctx, 1, nil, poolName) + } + defer func() { + if err != nil { + // Failed to get connection, decrement pending requests + atomic.AddUint32(&p.stats.PendingRequests, ^uint32(0)) // -1 + // Record pending request decrement on failure + if cb := getMetricPendingRequestsCallback(); cb != nil { + cb(ctx, -1, nil, poolName) + } + } + }() + + // Track wait time - only call time.Now() if callback is registered + var waitStart time.Time + waitTimeCallback := getMetricConnectionWaitTimeCallback() + if waitTimeCallback != nil { + waitStart = time.Now() + } + if err = p.waitTurn(ctx); err != nil { + // Record timeout if applicable + if err == ErrPoolTimeout { + if cb := getMetricConnectionTimeoutCallback(); cb != nil { + cb(ctx, nil, "pool") + } + // Record general error metric for pool timeout + if cb := GetMetricErrorCallback(); cb != nil { + cb(ctx, "POOL_TIMEOUT", nil, "POOL_TIMEOUT", true, 0) + } + } return nil, err } + var waitDuration time.Duration + if waitTimeCallback != nil { + waitDuration = time.Since(waitStart) + } + + // Use cached time for health checks (max 50ms staleness is acceptable) + nowNs := getCachedTimeNs() + + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() + + for attempts := 0; attempts < getAttempts; attempts++ { - for { p.connsMu.Lock() - cn, err := p.popIdle() + cn, err = p.popIdle() + if cn != nil { + // Emit idle→used transition inside the lock so Close() sees + // consistent state (conn removed from idleConns = "used"). + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateIdle, MetricStateUsed) + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "idle", false) + cb(ctx, 1, cn, "used", false) + } + } p.connsMu.Unlock() if err != nil { @@ -287,152 +884,634 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn) { - _ = p.CloseConn(cn) + if !p.isHealthyConn(cn, nowNs) { + // Connection was already transitioned to MetricStateUsed under the lock above. + _ = p.CloseConn(ctx, cn, CloseReasonStale, MetricStateUsed) continue } + // Process connection using the hooks system + // Combine error and rejection checks to reduce branches + if hookManager != nil { + acceptConn, hookErr := hookManager.ProcessOnGet(ctx, cn, false) + if hookErr != nil || !acceptConn { + if hookErr != nil { + internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", hookErr) + // Connection was already transitioned to MetricStateUsed under the lock above. + _ = p.CloseConn(ctx, cn, CloseReasonHookError, MetricStateUsed) + } else { + internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + // Connection is already in MetricStateUsed (transitioned under the lock above). + // Return connection to pool without freeing the turn that this Get() call holds. + // putConnWithoutTurn will emit used→idle transition. + p.putConnWithoutTurn(ctx, cn) + cn = nil + } + continue + } + } + atomic.AddUint32(&p.stats.Hits, 1) + + // Record wait time (use cached callback from above) + if waitTimeCallback != nil { + waitTimeCallback(ctx, waitDuration, cn) + } + + // Decrement pending requests (connection acquired successfully) + atomic.AddUint32(&p.stats.PendingRequests, ^uint32(0)) // -1 + // Record pending request decrement (UpDownCounter) + if cb := getMetricPendingRequestsCallback(); cb != nil { + cb(ctx, -1, cn, poolName) + } + return cn, nil } atomic.AddUint32(&p.stats.Misses, 1) - newcn, err := p.newConn(ctx, true) + var newcn *Conn + newcn, err = p.queuedNewConn(ctx) if err != nil { - p.freeTurn() return nil, err } + // Process connection using the hooks system + // This includes the handshake (HELLO/AUTH) via initConn hook + if hookManager != nil { + var acceptConn bool + acceptConn, err = hookManager.ProcessOnGet(ctx, newcn, true) + // both errors and accept=false mean a hook rejected the connection + // this should not happen with a new connection, but we handle it gracefully + if err != nil || !acceptConn { + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) + // newConn emitted +1 used; CloseConn will emit -1 used if we own the removal. + _ = p.CloseConn(ctx, newcn, CloseReasonHookError, MetricStateUsed) + return nil, err + } + + // Record connection creation time metric when hooks are used. + // When hookManager is set, ProcessOnGet initializes the connection (AUTH/HELLO), + // causing IsInited()=true. This means _getConn() in redis.go will take the + // early return path and never reach its create time recording. + // When hookManager is nil, _getConn() handles both initialization and create time recording. + if dialStartNs := newcn.GetDialStartNs(); dialStartNs > 0 { + if cb := GetMetricConnectionCreateTimeCallback(); cb != nil { + duration := time.Duration(time.Now().UnixNano() - dialStartNs) + cb(ctx, duration, newcn) + } + } + } + + // newConn already emitted +1 used, so no transition needed here. + + // Record wait time (use cached callback from above) + if waitTimeCallback != nil { + waitTimeCallback(ctx, waitDuration, newcn) + } + + // Decrement pending requests (connection acquired successfully) + atomic.AddUint32(&p.stats.PendingRequests, ^uint32(0)) // -1 + // Record pending request decrement (UpDownCounter) + if cb := getMetricPendingRequestsCallback(); cb != nil { + cb(ctx, -1, newcn, poolName) + } + return newcn, nil } -func (p *ConnPool) waitTurn(ctx context.Context) error { +func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) { select { + case p.dialsInProgress <- struct{}{}: + // Got permission, proceed to create connection case <-ctx.Done(): - return ctx.Err() - default: + p.freeTurn() + return nil, ctx.Err() } - select { - case p.queue <- struct{}{}: - return nil - default: + // Don't apply DialTimeout via context here; dialConn applies DialTimeout per attempt. + dialCtx, cancel := context.WithCancel(context.Background()) + + w := &wantConn{ + ctx: dialCtx, + cancelCtx: cancel, + result: make(chan wantConnResult, 1), } + var err error + defer func() { + if err != nil { + if cn := w.cancel(); cn != nil && p.putIdleConn(ctx, cn) { + p.freeTurn() + } + } + }() + + p.dialsQueue.discardDoneAtFront() + p.dialsQueue.enqueue(w) + + go func(w *wantConn) { + var freeTurnCalled bool + defer func() { + if err := recover(); err != nil { + w.tryDeliver(nil, errPanicInQueuedNewConn) + p.dialsQueue.discardDoneAtFront() + if !freeTurnCalled { + p.freeTurn() + } + internal.Logger.Printf(context.Background(), "queuedNewConn panic: %+v", err) + } + }() + + defer w.cancelCtx() + defer func() { <-p.dialsInProgress }() // Release connection creation permission + + dialCtx := w.getCtxForDial() + cn, cnErr := p.newConn(dialCtx, true) + if cnErr != nil { + w.tryDeliver(nil, cnErr) // deliver error to caller, notify connection creation failed + p.dialsQueue.discardDoneAtFront() + p.freeTurn() + freeTurnCalled = true + return + } - timer := timers.Get().(*time.Timer) - timer.Reset(p.cfg.PoolTimeout) + delivered := w.tryDeliver(cn, cnErr) + p.dialsQueue.discardDoneAtFront() + if !delivered && p.putIdleConn(dialCtx, cn) { + p.freeTurn() + freeTurnCalled = true + } + }(w) select { case <-ctx.Done(): - if !timer.Stop() { - <-timer.C + err = ctx.Err() + return nil, err + case result := <-w.result: + err = result.err + return result.cn, err + } +} + +// putIdleConn puts a connection back to the pool or passes it to the next waiting request. +// +// It returns true if the connection was put back to the pool, +// which means the turn needs to be freed directly by the caller, +// or false if the connection was passed to the next waiting request, +// which means the turn will be freed by the waiting goroutine after it returns. +func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) bool { + for { + w, ok := p.dialsQueue.dequeue() + if !ok { + break } - timers.Put(timer) - return ctx.Err() - case p.queue <- struct{}{}: - if !timer.Stop() { - <-timer.C + if w.tryDeliver(cn, nil) { + return false } - timers.Put(timer) + } + + p.connsMu.Lock() + defer p.connsMu.Unlock() + + if p.closed() { + // Don't close here — this connection is still in p.conns and Close() + // will handle closing it and emitting the correct metric decrements. + // We just skip adding it to idleConns. + return true + } + + p.idleConns = append(p.idleConns, cn) + p.idleConnsLen.Add(1) + + // Connection was created as "used" in newConn; transition to idle. + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, MetricStateIdle) + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + cb(ctx, 1, cn, "idle", false) + } + + return true +} + +func (p *ConnPool) waitTurn(ctx context.Context) error { + // Fast path: check context first + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Fast path: try to acquire without blocking + if p.semaphore.TryAcquire() { return nil - case <-timer.C: - timers.Put(timer) + } + + // Slow path: need to wait + start := time.Now() + err := p.semaphore.Acquire(ctx, p.cfg.PoolTimeout, ErrPoolTimeout) + + switch err { + case nil: + // Successfully acquired after waiting + p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) + atomic.AddUint32(&p.stats.WaitCount, 1) + case ErrPoolTimeout: atomic.AddUint32(&p.stats.Timeouts, 1) - return ErrPoolTimeout } + + return err } func (p *ConnPool) freeTurn() { - <-p.queue + p.semaphore.Release() } func (p *ConnPool) popIdle() (*Conn, error) { if p.closed() { return nil, ErrClosed } + defer p.checkMinIdleConns() + n := len(p.idleConns) if n == 0 { return nil, nil } var cn *Conn - if p.cfg.PoolFIFO { - cn = p.idleConns[0] - copy(p.idleConns, p.idleConns[1:]) - p.idleConns = p.idleConns[:n-1] - } else { - idx := n - 1 - cn = p.idleConns[idx] - p.idleConns = p.idleConns[:idx] + attempts := 0 + + maxAttempts := min(popAttempts, n) + for attempts < maxAttempts { + if len(p.idleConns) == 0 { + return nil, nil + } + + if p.cfg.PoolFIFO { + cn = p.idleConns[0] + copy(p.idleConns, p.idleConns[1:]) + p.idleConns = p.idleConns[:len(p.idleConns)-1] + } else { + idx := len(p.idleConns) - 1 + cn = p.idleConns[idx] + p.idleConns = p.idleConns[:idx] + } + attempts++ + + // Hot path optimization: try IDLE → IN_USE or CREATED → IN_USE transition + // Using inline TryAcquire() method for better performance (avoids pointer dereference) + if cn.TryAcquire() { + // Successfully acquired the connection + p.idleConnsLen.Add(-1) + break + } + + // Connection is in UNUSABLE, INITIALIZING, or other state - skip it + + // Connection is not in a valid state (might be UNUSABLE for handoff/re-auth, INITIALIZING, etc.) + // Put it back in the pool and try the next one + if p.cfg.PoolFIFO { + // FIFO: put at end (will be picked up last since we pop from front) + p.idleConns = append(p.idleConns, cn) + } else { + // LIFO: put at beginning (will be picked up last since we pop from end) + p.idleConns = append([]*Conn{cn}, p.idleConns...) + } + cn = nil } - p.idleConnsLen-- - p.checkMinIdleConns() + + // If we exhausted all attempts without finding a usable connection, return nil + if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { + internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) + return nil, nil + } + return cn, nil } func (p *ConnPool) Put(ctx context.Context, cn *Conn) { - if cn.rd.Buffered() > 0 { - internal.Logger.Printf(ctx, "Conn has unread data") - p.Remove(ctx, cn, BadConnError{}) + p.putConn(ctx, cn, true) +} + +// putConnWithoutTurn is an internal method that puts a connection back to the pool +// without freeing a turn. This is used when returning a rejected connection from +// within Get(), where the turn is still held by the Get() call. +func (p *ConnPool) putConnWithoutTurn(ctx context.Context, cn *Conn) { + p.putConn(ctx, cn, false) +} + +// putConn is the internal implementation of Put that optionally frees a turn. +func (p *ConnPool) putConn(ctx context.Context, cn *Conn, freeTurn bool) { + // Guard against nil connection + if cn == nil { + internal.Logger.Printf(ctx, "putConn called with nil connection") + if freeTurn { + p.freeTurn() + } + return + } + + // Process connection using the hooks system + shouldPool := true + shouldRemove := false + var err error + + if cn.HasBufferedData() { + // Peek at the reply type to check if it's a push notification + if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { + // Not a push notification or error peeking, remove connection + internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + p.removeConnInternal(ctx, cn, err, freeTurn) + return + } + // It's a push notification, allow pooling (client will handle it) + } + + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() + + if hookManager != nil { + shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) + if err != nil { + internal.Logger.Printf(ctx, "Connection hook error: %v", err) + p.removeConnInternal(ctx, cn, err, freeTurn) + return + } + } + + // Combine all removal checks into one - reduces branches + if shouldRemove || !shouldPool { + p.removeConnInternal(ctx, cn, errHookRequestedRemoval, freeTurn) return } if !cn.pooled { - p.Remove(ctx, cn, nil) + p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn) return } var shouldCloseConn bool + var removedFromPool bool + + if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { + // Hot path optimization: try fast IN_USE → IDLE transition + // Using inline Release() method for better performance (avoids pointer dereference) + transitionedToIdle := cn.Release() + + // Handle unexpected state changes + if !transitionedToIdle { + // Fast path failed - hook might have changed state (e.g., to UNUSABLE for handoff) + // Keep the state set by the hook and pool the connection anyway + sm := cn.GetStateMachine() + if sm == nil { + // State machine is nil - connection is in an invalid state, remove it + internal.Logger.Printf(ctx, "conn[%d] has nil state machine, removing it", cn.GetID()) + p.removeConnInternal(ctx, cn, errConnNotPooled, freeTurn) + return + } + currentState := sm.GetState() + switch currentState { + case StateUnusable: + // expected state, don't log it + case StateClosed: + internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, closing it", cn.GetID(), currentState) + shouldCloseConn = true + removedFromPool = p.removeConnWithLock(cn) + default: + // Pool as-is + internal.Logger.Printf(ctx, "Unexpected conn[%d] state changed by hook to %v, pooling as-is", cn.GetID(), currentState) + } + } - p.connsMu.Lock() + // unusable conns are expected to become usable at some point (background process is reconnecting them) + // put them at the opposite end of the queue + // Optimization: if we just transitioned to IDLE, we know it's usable - skip the check + if !transitionedToIdle && !cn.IsUsable() { + p.connsMu.Lock() + // Check if Close() already removed this connection from p.conns. + // If so, skip the append and metrics — Close() already accounted for it. + if _, inPool := p.conns[cn.GetID()]; inPool { + if p.cfg.PoolFIFO { + p.idleConns = append(p.idleConns, cn) + } else { + p.idleConns = append([]*Conn{cn}, p.idleConns...) + } + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, MetricStateIdle) + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + cb(ctx, 1, cn, "idle", false) + } + p.connsMu.Unlock() + p.idleConnsLen.Add(1) + } else { + shouldCloseConn = true + p.connsMu.Unlock() + } + } else if !shouldCloseConn { + p.connsMu.Lock() + if _, inPool := p.conns[cn.GetID()]; inPool { + p.idleConns = append(p.idleConns, cn) + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, MetricStateIdle) + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + cb(ctx, 1, cn, "idle", false) + } + p.connsMu.Unlock() + p.idleConnsLen.Add(1) + } else { + shouldCloseConn = true + p.connsMu.Unlock() + } + } - if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns { - p.idleConns = append(p.idleConns, cn) - p.idleConnsLen++ + if shouldCloseConn { + // Connection was removed (e.g., hook set state to StateClosed). + // Only emit if we actually removed it from the map (not already taken by Close()). + if removedFromPool { + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, "") + } + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + } + } + } } else { - p.removeConn(cn) shouldCloseConn = true - } + removedFromPool = p.removeConnWithLock(cn) - p.connsMu.Unlock() + // Only emit if we actually removed it from the map (not already taken by Close()). + if removedFromPool { + // Notify metrics: connection removed (used -> nothing) + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, "") + } + // Record connection count decrement (connection removed while in used state) + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + } + } + } - p.freeTurn() + if freeTurn { + p.freeTurn() + } if shouldCloseConn { + // Only emit connection closed if we actually owned the removal. + // If removedFromPool is false, Close() already emitted connectionClosed for this conn. + if removedFromPool { + if cb := getMetricConnectionClosedCallback(); cb != nil { + reason := "conn_pool_close" + if r := cn.closeReason.Load(); r != "" { + reason = r + } + cb(ctx, cn, reason, nil) + } + } _ = p.closeConn(cn) } + + cn.SetLastPutAtNs(getCachedTimeNs()) +} + +func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { + p.removeConnInternal(ctx, cn, reason, true) +} + +// RemoveWithoutTurn removes a connection from the pool without freeing a turn. +// This should be used when removing a connection from a context that didn't acquire +// a turn via Get() (e.g., background workers, cleanup tasks). +// For normal removal after Get(), use Remove() instead. +func (p *ConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) { + p.removeConnInternal(ctx, cn, reason, false) } -func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { - p.removeConnWithLock(cn) - p.freeTurn() +// removeConnInternal is the internal implementation of Remove that optionally frees a turn. +func (p *ConnPool) removeConnInternal(ctx context.Context, cn *Conn, reason error, freeTurn bool) { + // Lock-free atomic read - no mutex overhead! + hookManager := p.hookManager.Load() + + if hookManager != nil { + hookManager.ProcessOnRemove(ctx, cn, reason) + } + + removed := p.removeConnWithLock(cn) + + if freeTurn { + p.freeTurn() + } + + // Only emit metric decrements if we actually removed the connection from the map. + // If removed is false, Close() already removed it and emitted the -1 delta. + if removed { + // Notify metrics: connection removed (assume from used state) + if cb := getMetricConnectionStateChangeCallback(); cb != nil { + cb(ctx, cn, MetricStateUsed, "") + } + // Record connection count decrement (connection removed, assume from used state) + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(ctx, -1, cn, "used", false) + } + } + + // Only emit connection closed if we actually owned the removal. + // If removed is false, Close() already emitted connectionClosed for this conn. + if removed { + if cb := getMetricConnectionClosedCallback(); cb != nil { + reasonStr := "unknown" + if reason != nil { + reasonStr = reason.Error() + } + cb(ctx, cn, reasonStr, reason) + } + } + _ = p.closeConn(cn) + + // Check if we need to create new idle connections to maintain MinIdleConns + p.checkMinIdleConns() } -func (p *ConnPool) CloseConn(cn *Conn) error { - p.removeConnWithLock(cn) +// CloseConn closes a connection and records metrics. +// Parameters: +// - ctx: context for metric callbacks (enables trace-to-metric correlation) +// - cn: the connection to close +// - reason: why the connection is being closed (use CloseReason* constants) +// - fromState: the metric state the connection was in (use MetricState* constants) +func (p *ConnPool) CloseConn(ctx context.Context, cn *Conn, reason string, fromState string) error { + if hookManager := p.hookManager.Load(); hookManager != nil { + hookManager.ProcessOnRemove(ctx, cn, errors.New(reason)) + } + + removed := p.removeConnWithLock(cn) + + // Only emit UpDownCounter decrements if we actually removed the connection. + // If removed is false, Close() already removed it and emitted the -1 delta. + // Only emit connection closed if we actually owned the removal. + // If removed is false, Close() already emitted connectionClosed for this conn. + if removed { + p.recordConnectionMetrics(ctx, cn, reason, fromState) + } + return p.closeConn(cn) } -func (p *ConnPool) removeConnWithLock(cn *Conn) { +func (p *ConnPool) recordConnectionMetrics(ctx context.Context, cn *Conn, reason string, fromState string) { + // Record connection state change: connection is being removed from the specified state + if cb := getMetricConnectionStateChangeCallback(); cb != nil && fromState != "" { + cb(ctx, cn, fromState, "") + } + + // Record connection count decrement (UpDownCounter) for the state the connection was in + if cb := getMetricConnectionCountCallback(); cb != nil && fromState != "" { + cb(ctx, -1, cn, fromState, false) + } + + if cb := getMetricConnectionClosedCallback(); cb != nil { + cb(ctx, cn, reason, nil) + } +} + +// removeConnWithLock removes a connection from the pool under the connsMu lock. +// Returns true if the connection was actually present in p.conns and was removed, +// false if it was already gone (e.g., removed by Close()). Callers must use the +// return value to decide whether to emit metric decrements — this eliminates the +// shutdown race between Close() and concurrent removal paths. +func (p *ConnPool) removeConnWithLock(cn *Conn) bool { p.connsMu.Lock() defer p.connsMu.Unlock() - p.removeConn(cn) + return p.removeConn(cn) } -func (p *ConnPool) removeConn(cn *Conn) { - for i, c := range p.conns { - if c == cn { - p.conns = append(p.conns[:i], p.conns[i+1:]...) - if cn.pooled { - p.poolSize-- - p.checkMinIdleConns() +// removeConn removes a connection from the pool's internal data structures. +// Returns true if the connection was present and removed, false otherwise. +func (p *ConnPool) removeConn(cn *Conn) bool { + cid := cn.GetID() + if _, exists := p.conns[cid]; !exists { + return false + } + delete(p.conns, cid) + atomic.AddUint32(&p.stats.StaleConns, 1) + + // Decrement pool size counter when removing a connection + if cn.pooled { + p.poolSize.Add(-1) + // this can be idle conn + for idx, ic := range p.idleConns { + if ic == cn { + p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...) + p.idleConnsLen.Add(-1) + break } - break } } - atomic.AddUint32(&p.stats.StaleConns, 1) + return true } func (p *ConnPool) closeConn(cn *Conn) error { @@ -450,16 +1529,28 @@ func (p *ConnPool) Len() int { // IdleLen returns number of idle connections. func (p *ConnPool) IdleLen() int { p.connsMu.Lock() - n := p.idleConnsLen + n := p.idleConnsLen.Load() p.connsMu.Unlock() - return n + return int(n) +} + +// Size returns the maximum pool size (capacity). +// +// This is used by the streaming credentials manager to size the re-auth worker pool, +// ensuring that re-auth operations don't exhaust the connection pool. +func (p *ConnPool) Size() int { + return int(p.cfg.PoolSize) } func (p *ConnPool) Stats() *Stats { return &Stats{ - Hits: atomic.LoadUint32(&p.stats.Hits), - Misses: atomic.LoadUint32(&p.stats.Misses), - Timeouts: atomic.LoadUint32(&p.stats.Timeouts), + Hits: atomic.LoadUint32(&p.stats.Hits), + Misses: atomic.LoadUint32(&p.stats.Misses), + Timeouts: atomic.LoadUint32(&p.stats.Timeouts), + WaitCount: atomic.LoadUint32(&p.stats.WaitCount), + Unusable: atomic.LoadUint32(&p.stats.Unusable), + WaitDurationNs: p.waitDurationNs.Load(), + PendingRequests: atomic.LoadUint32(&p.stats.PendingRequests), TotalConns: uint32(p.Len()), IdleConns: uint32(p.IdleLen()), @@ -472,13 +1563,33 @@ func (p *ConnPool) closed() bool { } func (p *ConnPool) Filter(fn func(*Conn) bool) error { + ctx := context.Background() + p.connsMu.Lock() defer p.connsMu.Unlock() + idleConnSet := make(map[*Conn]struct{}, len(p.idleConns)) + for _, ic := range p.idleConns { + idleConnSet[ic] = struct{}{} + } + var firstErr error for _, cn := range p.conns { if fn(cn) { - if err := p.closeConn(cn); err != nil && firstErr == nil { + var err error + if _, isIdle := idleConnSet[cn]; isIdle { + // Idle connection - remove from pool and close. + p.removeConn(cn) + p.recordConnectionMetrics(ctx, cn, CloseReasonFailover, MetricStateIdle) + err = p.closeConn(cn) + } else { + // Used connection - set closeReason and close the connection. + // The connection remains in p.conns. When putConn() is called later, + // it will close the connection instead of pooling it. + cn.closeReason.Store(CloseReasonFailover) + err = cn.Close() + } + if err != nil && firstErr == nil { firstErr = err } } @@ -492,35 +1603,107 @@ func (p *ConnPool) Close() error { } var firstErr error + nowNs := time.Now().UnixNano() p.connsMu.Lock() + + // Emit -1 for each connection. Since all idle↔used transitions happen + // under connsMu, the idleConns slice is the source of truth for state. + cb := getMetricConnectionCountCallback() + idleSet := make(map[uint64]struct{}, len(p.idleConns)) + for _, cn := range p.idleConns { + idleSet[cn.GetID()] = struct{}{} + } + ctx := context.Background() for _, cn := range p.conns { + // Check health before closing, since closeConn invalidates the + // underlying fd and would make connCheck (inside isHealthyConn) + // always fail with EBADF. + // Only check health for idle connections to avoid data races when + // peeking at the socket/reader while another goroutine is reading from it. + // Non-idle connections are either in use or in transitional states and + // shouldn't be health-checked during shutdown. + _, isIdle := idleSet[cn.GetID()] + var healthy bool + if isIdle { + healthy = p.isHealthyConn(cn, nowNs) + } else { + healthy = true + } + if cb != nil { + if isIdle { + cb(ctx, -1, cn, "idle", false) + } else { + cb(ctx, -1, cn, "used", false) + } + } + if closedCb := getMetricConnectionClosedCallback(); closedCb != nil { + closedCb(ctx, cn, "pool_shutdown", nil) + } if err := p.closeConn(cn); err != nil && firstErr == nil { - firstErr = err + // Suppress close errors for stale connections, consistent + // with how Get() handles them (see CloseReasonStale path). + if healthy { + firstErr = err + } } } p.conns = nil - p.poolSize = 0 + p.poolSize.Store(0) p.idleConns = nil - p.idleConnsLen = 0 + p.idleConnsLen.Store(0) p.connsMu.Unlock() return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn) bool { - now := time.Now() +func (p *ConnPool) isHealthyConn(cn *Conn, nowNs int64) bool { + // Performance optimization: check conditions from cheapest to most expensive, + // and from most likely to fail to least likely to fail. - if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime { - return false + // Only fails if ConnMaxLifetime is set AND connection is old. + // Most pools don't set ConnMaxLifetime, so this rarely fails. + if p.cfg.ConnMaxLifetime > 0 { + if cn.expiresAt.UnixNano() < nowNs { + return false // Connection has exceeded max lifetime + } } - if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { - return false + + // Most pools set ConnMaxIdleTime, and idle connections are common. + // Checking this first allows us to fail fast without expensive syscalls. + if p.cfg.ConnMaxIdleTime > 0 { + if nowNs-cn.UsedAtNs() >= int64(p.cfg.ConnMaxIdleTime) { + return false // Connection has been idle too long + } } - if connCheck(cn.netConn) != nil { + // Only run this if the cheap checks passed. + if err := connCheck(cn.getNetConn()); err != nil { + // If there's unexpected data, it might be push notifications (RESP3) + if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { + // Peek at the reply type to check if it's a push notification + if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { + // For RESP3 connections with push notifications, we allow some buffered data + // The client will process these notifications before using the connection + internal.Logger.Printf( + context.Background(), + "push: conn[%d] has buffered data, likely push notifications - will be processed by client", + cn.GetID(), + ) + + // Update timestamp for healthy connection + cn.SetUsedAtNs(nowNs) + + // Connection is healthy, client will handle notifications + return true + } + // Not a push notification - treat as unhealthy + return false + } + // Connection failed health check return false } - cn.SetUsedAt(now) + // Only update UsedAt if connection is healthy (avoids unnecessary atomic store) + cn.SetUsedAtNs(nowNs) return true } diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/pool_single.go b/vendor/github.com/redis/go-redis/v9/internal/pool/pool_single.go index 5a3fde191..682959069 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/pool_single.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/pool_single.go @@ -1,7 +1,13 @@ package pool -import "context" +import ( + "context" + "time" +) +// SingleConnPool is a pool that always returns the same connection. +// Note: This pool is not thread-safe. +// It is intended to be used by clients that need a single connection. type SingleConnPool struct { pool Pooler cn *Conn @@ -10,6 +16,12 @@ type SingleConnPool struct { var _ Pooler = (*SingleConnPool)(nil) +// NewSingleConnPool creates a new single connection pool. +// The pool will always return the same connection. +// The pool will not: +// - Close the connection +// - Reconnect the connection +// - Track the connection in any way func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { return &SingleConnPool{ pool: pool, @@ -21,24 +33,51 @@ func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.pool.NewConn(ctx) } -func (p *SingleConnPool) CloseConn(cn *Conn) error { - return p.pool.CloseConn(cn) +func (p *SingleConnPool) CloseConn(ctx context.Context, cn *Conn, reason string, fromState string) error { + return p.pool.CloseConn(ctx, cn, reason, fromState) } -func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { +func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) { if p.stickyErr != nil { return nil, p.stickyErr } + if p.cn == nil { + return nil, ErrClosed + } + + // NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios: + // - During initialization (connection is in INITIALIZING state) + // - During re-authentication (connection is in UNUSABLE state) + // - For transactions (connection might be in various states) + // We use SetUsed() which forces the transition, rather than TryTransition() which + // would fail if the connection is not in IDLE/CREATED state. + p.cn.SetUsed(true) + p.cn.SetUsedAt(time.Now()) return p.cn, nil } -func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} +func (p *SingleConnPool) Put(_ context.Context, cn *Conn) { + if p.cn == nil { + return + } + if p.cn != cn { + return + } + p.cn.SetUsed(false) +} -func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { +func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) { + cn.SetUsed(false) p.cn = nil p.stickyErr = reason } +// RemoveWithoutTurn has the same behavior as Remove for SingleConnPool +// since SingleConnPool doesn't use a turn-based queue system. +func (p *SingleConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) { + p.Remove(ctx, cn, reason) +} + func (p *SingleConnPool) Close() error { p.cn = nil p.stickyErr = ErrClosed @@ -53,6 +92,13 @@ func (p *SingleConnPool) IdleLen() int { return 0 } +// Size returns the maximum pool size, which is always 1 for SingleConnPool. +func (p *SingleConnPool) Size() int { return 1 } + func (p *SingleConnPool) Stats() *Stats { return &Stats{} } + +func (p *SingleConnPool) AddPoolHook(_ PoolHook) {} + +func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/pool_sticky.go b/vendor/github.com/redis/go-redis/v9/internal/pool/pool_sticky.go index 3adb99bc8..6763299eb 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/pool/pool_sticky.go +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/pool_sticky.go @@ -61,8 +61,8 @@ func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.pool.NewConn(ctx) } -func (p *StickyConnPool) CloseConn(cn *Conn) error { - return p.pool.CloseConn(cn) +func (p *StickyConnPool) CloseConn(ctx context.Context, cn *Conn, reason string, fromState string) error { + return p.pool.CloseConn(ctx, cn, reason, fromState) } func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { @@ -123,6 +123,12 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { p.ch <- cn } +// RemoveWithoutTurn has the same behavior as Remove for StickyConnPool +// since StickyConnPool doesn't use a turn-based queue system. +func (p *StickyConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) { + p.Remove(ctx, cn, reason) +} + func (p *StickyConnPool) Close() error { if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { return nil @@ -196,6 +202,13 @@ func (p *StickyConnPool) IdleLen() int { return len(p.ch) } +// Size returns the maximum pool size, which is always 1 for StickyConnPool. +func (p *StickyConnPool) Size() int { return 1 } + func (p *StickyConnPool) Stats() *Stats { return &Stats{} } + +func (p *StickyConnPool) AddPoolHook(hook PoolHook) {} + +func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/pubsub.go b/vendor/github.com/redis/go-redis/v9/internal/pool/pubsub.go new file mode 100644 index 000000000..8cfa86788 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/pubsub.go @@ -0,0 +1,105 @@ +package pool + +import ( + "context" + "net" + "sync" + "sync/atomic" +) + +type PubSubStats struct { + Created uint32 + Untracked uint32 + Active uint32 +} + +// PubSubPool manages a pool of PubSub connections. +type PubSubPool struct { + opt *Options + netDialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Map to track active PubSub connections + activeConns sync.Map // map[uint64]*Conn (connID -> conn) + closed atomic.Bool + stats PubSubStats +} + +// NewPubSubPool implements a pool for PubSub connections. +// It intentionally does not implement the Pooler interface +func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { + return &PubSubPool{ + opt: opt, + netDialer: netDialer, + } +} + +func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) { + if p.closed.Load() { + return nil, ErrClosed + } + + netConn, err := p.netDialer(ctx, network, addr) + if err != nil { + return nil, err + } + cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize) + cn.pubsub = true + // Set pool name for metrics + cn.SetPoolName(p.opt.Name) + atomic.AddUint32(&p.stats.Created, 1) + return cn, nil +} + +func (p *PubSubPool) TrackConn(cn *Conn) { + atomic.AddUint32(&p.stats.Active, 1) + p.activeConns.Store(cn.GetID(), cn) + // Emit +1 used for PubSub connection + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(context.Background(), 1, cn, "used", true) + } +} + +func (p *PubSubPool) UntrackConn(cn *Conn) { + // LoadAndDelete ensures each connection is only decremented once, + // guarding against double-decrement if Close() already untracked it. + if _, loaded := p.activeConns.LoadAndDelete(cn.GetID()); !loaded { + return + } + atomic.AddUint32(&p.stats.Active, ^uint32(0)) + atomic.AddUint32(&p.stats.Untracked, 1) + // Emit -1 used for PubSub connection + if cb := getMetricConnectionCountCallback(); cb != nil { + cb(context.Background(), -1, cn, "used", true) + } +} + +func (p *PubSubPool) Close() error { + p.closed.Store(true) + cb := getMetricConnectionCountCallback() + p.activeConns.Range(func(key, value interface{}) bool { + cn := value.(*Conn) + // Use LoadAndDelete to atomically claim ownership of this entry. + // If a concurrent UntrackConn already removed it, skip to avoid double-decrement. + if _, loaded := p.activeConns.LoadAndDelete(key); !loaded { + return true + } + atomic.AddUint32(&p.stats.Active, ^uint32(0)) + atomic.AddUint32(&p.stats.Untracked, 1) + // Emit -1 used for each PubSub connection being closed + if cb != nil { + cb(context.Background(), -1, cn, "used", true) + } + _ = cn.Close() + return true + }) + return nil +} + +func (p *PubSubPool) Stats() *PubSubStats { + // load stats atomically + return &PubSubStats{ + Created: atomic.LoadUint32(&p.stats.Created), + Untracked: atomic.LoadUint32(&p.stats.Untracked), + Active: atomic.LoadUint32(&p.stats.Active), + } +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/pool/want_conn.go b/vendor/github.com/redis/go-redis/v9/internal/pool/want_conn.go new file mode 100644 index 000000000..78f86813f --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/pool/want_conn.go @@ -0,0 +1,115 @@ +package pool + +import ( + "context" + "sync" +) + +type wantConn struct { + mu sync.RWMutex // protects ctx, done and sending of the result + ctx context.Context // context for dial, cleared after delivered or canceled + cancelCtx context.CancelFunc + done bool // true after delivered or canceled + result chan wantConnResult // channel to deliver connection or error +} + +// getCtxForDial returns context for dial or nil if connection was delivered or canceled. +func (w *wantConn) getCtxForDial() context.Context { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.ctx +} + +func (w *wantConn) tryDeliver(cn *Conn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + if w.done { + return false + } + + w.done = true + w.ctx = nil + + w.result <- wantConnResult{cn: cn, err: err} + close(w.result) + + return true +} + +func (w *wantConn) cancel() *Conn { + w.mu.Lock() + var cn *Conn + if w.done { + select { + case result := <-w.result: + cn = result.cn + default: + } + } else { + close(w.result) + } + + w.done = true + w.ctx = nil + w.mu.Unlock() + + return cn +} + +func (w *wantConn) isOngoing() bool { + w.mu.RLock() + defer w.mu.RUnlock() + return !w.done +} + +type wantConnResult struct { + cn *Conn + err error +} + +type wantConnQueue struct { + mu sync.RWMutex + items []*wantConn +} + +func newWantConnQueue() *wantConnQueue { + return &wantConnQueue{ + items: make([]*wantConn, 0), + } +} + +func (q *wantConnQueue) enqueue(w *wantConn) { + q.mu.Lock() + defer q.mu.Unlock() + q.items = append(q.items, w) +} + +func (q *wantConnQueue) dequeue() (*wantConn, bool) { + q.mu.Lock() + defer q.mu.Unlock() + + if len(q.items) == 0 { + return nil, false + } + + item := q.items[0] + q.items = q.items[1:] + return item, true +} + +func (q *wantConnQueue) discardDoneAtFront() int { + q.mu.Lock() + defer q.mu.Unlock() + count := 0 + for len(q.items) > 0 { + if q.items[0].isOngoing() { + break + } + + q.items = q.items[1:] + count++ + } + + return count +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/proto/reader.go b/vendor/github.com/redis/go-redis/v9/internal/proto/reader.go index 8d23817fe..83f28e4da 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/proto/reader.go +++ b/vendor/github.com/redis/go-redis/v9/internal/proto/reader.go @@ -12,6 +12,9 @@ import ( "github.com/redis/go-redis/v9/internal/util" ) +// DefaultBufferSize is the default size for read/write buffers (32 KiB). +const DefaultBufferSize = 32 * 1024 + // redis resp protocol data type. const ( RespStatus = '+' // +\r\n @@ -47,7 +50,8 @@ func (e RedisError) Error() string { return string(e) } func (RedisError) RedisError() {} func ParseErrorReply(line []byte) error { - return RedisError(line[1:]) + msg := string(line[1:]) + return parseTypedRedisError(msg) } //------------------------------------------------------------------------------ @@ -58,7 +62,13 @@ type Reader struct { func NewReader(rd io.Reader) *Reader { return &Reader{ - rd: bufio.NewReader(rd), + rd: bufio.NewReaderSize(rd, DefaultBufferSize), + } +} + +func NewReaderSize(rd io.Reader, size int) *Reader { + return &Reader{ + rd: bufio.NewReaderSize(rd, size), } } @@ -90,6 +100,161 @@ func (r *Reader) PeekReplyType() (byte, error) { return b[0], nil } +// PeekPushNotificationName returns the notification name of the next RESP3 +// push frame without consuming it. The caller is expected to have already +// verified that the next reply is a push notification (e.g. via PeekReplyType +// returning RespPush). +// +// To identify the name the method may block briefly reading more bytes from +// the underlying connection. That is safe: once the push marker '>' has been +// observed, the server is committed to sending the rest of the frame, so +// fetching the next few header bytes does not race with anything the caller +// could be waiting on. Blocking is preferred to a truncated peek, which would +// silently misidentify the notification and cause the caller's ReadReply to +// consume (and drop) the frame; see issue #3839. +func (r *Reader) PeekPushNotificationName() (string, error) { + c, err := r.rd.Peek(1) + if err != nil { + return "", err + } + if c[0] != RespPush { + return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification") + } + + // Start with a peek window that covers every Redis-defined notification + // header (MOVING, MIGRATING, FAILED_OVER, message, pmessage, smessage, + // subscribe, unsubscribe, ...). If a longer name is encountered, grow + // the window up to maxPushHeaderPeek before giving up. + const initialPeek = 36 + const maxPushHeaderPeek = 4096 + + peekSize := initialPeek + for { + buf, peekErr := r.rd.Peek(peekSize) + name, complete, parseErr := parsePushNotificationName(buf) + if parseErr != nil { + return "", parseErr + } + if complete { + return name, nil + } + // Parser ran out of bytes. Surface a failed underlying read before + // growing further; otherwise grow the peek window and retry. + if peekErr != nil { + return "", peekErr + } + if peekSize >= maxPushHeaderPeek { + return "", fmt.Errorf("redis: push notification header exceeds %d bytes", maxPushHeaderPeek) + } + peekSize *= 2 + if peekSize > maxPushHeaderPeek { + peekSize = maxPushHeaderPeek + } + } +} + +// parsePushNotificationName extracts the notification name from a buffered +// RESP3 push frame prefix. The three return values are: +// +// - (name, true, nil): the full name is in buf. +// - ("", false, nil): buf is a valid prefix but too short to determine the +// name; the caller should fetch more bytes and retry. +// - ("", _, err): buf is malformed. +// +// This split lets PeekPushNotificationName tell "incomplete header" apart +// from "corrupt frame" without ever returning a truncated string. +func parsePushNotificationName(buf []byte) (string, bool, error) { + // Need at least ">N\r" before any meaningful work. + if len(buf) < 3 { + return "", false, nil + } + if buf[0] != RespPush { + return "", false, fmt.Errorf("redis: can't parse push notification: %q", buf) + } + + // Skip the array length line ">N\r\n". + const arrayLenStart = 1 // first byte after the '>' marker + pos, ok, err := skipDigitsThenCRLF(buf, arrayLenStart) + if err != nil { + return "", false, fmt.Errorf("redis: can't parse push notification: %w", err) + } + if !ok { + return "", false, nil + } + // Reject ">\r\n": RESP requires at least one digit for the array length. + // Without this check the empty length looks like a valid prefix and the + // caller would block fetching more bytes for a frame that is already + // malformed. + if pos-2 == arrayLenStart { + return "", false, fmt.Errorf("redis: empty push notification array length") + } + + // First element type byte: '$' (bulk) or '+' (simple-string). + if pos >= len(buf) { + return "", false, nil + } + typeOfName := buf[pos] + if typeOfName != RespString && typeOfName != RespStatus { + return "", false, fmt.Errorf("redis: can't parse push notification name: %q", buf[pos:]) + } + pos++ + + if typeOfName == RespString { + // Read "$M\r\n" then the M-byte name. + lenStart := pos + next, ok, err := skipDigitsThenCRLF(buf, pos) + if err != nil { + return "", false, fmt.Errorf("redis: can't parse push notification name length: %w", err) + } + if !ok { + return "", false, nil + } + if next-2 == lenStart { + return "", false, fmt.Errorf("redis: empty push notification name length") + } + nameLen, err := util.Atoi(buf[lenStart : next-2]) + if err != nil { + return "", false, fmt.Errorf("redis: invalid push notification name length %q: %w", buf[lenStart:next-2], err) + } + if nameLen < 0 { + return "", false, fmt.Errorf("redis: negative push notification name length: %d", nameLen) + } + // Compare against the remaining bytes instead of computing + // next+nameLen: a hugely advertised length on malformed input could + // overflow int, wrap negative, slip past an "end > len(buf)" guard and + // panic the slice below. next <= len(buf) here, so the subtraction is + // safe. + if nameLen > len(buf)-next { + return "", false, nil + } + return util.BytesToString(buf[next : next+nameLen]), true, nil + } + + // RespStatus: scan for the terminating CRLF. + for i := pos; i < len(buf)-1; i++ { + if buf[i] == '\r' && buf[i+1] == '\n' { + return util.BytesToString(buf[pos:i]), true, nil + } + } + return "", false, nil +} + +// skipDigitsThenCRLF advances past zero-or-more ASCII digits and the +// terminating "\r\n" starting at offset start in buf. It returns the position +// after the "\r\n" and true on success; (pos, false, nil) if buf is too +// short; or an error if a non-digit non-CR byte is encountered before the CRLF. +func skipDigitsThenCRLF(buf []byte, start int) (int, bool, error) { + for pos := start; pos < len(buf)-1; pos++ { + if buf[pos] == '\r' && buf[pos+1] == '\n' { + return pos + 2, true, nil + } + if buf[pos] < '0' || buf[pos] > '9' { + return pos, false, fmt.Errorf("expected digit or CRLF, got %q", buf[pos]) + } + } + return len(buf), false, nil +} + // ReadLine Return a valid reply, it will check the protocol or redis error, // and discard the attribute type. func (r *Reader) ReadLine() ([]byte, error) { @@ -106,7 +271,7 @@ func (r *Reader) ReadLine() ([]byte, error) { var blobErr string blobErr, err = r.readStringReply(line) if err == nil { - err = RedisError(blobErr) + err = parseTypedRedisError(blobErr) } return nil, err case RespAttr: @@ -183,8 +348,8 @@ func (r *Reader) ReadReply() (interface{}, error) { } func (r *Reader) readFloat(line []byte) (float64, error) { - v := string(line[1:]) - switch string(line[1:]) { + v := util.BytesToString(line[1:]) + switch v { case "inf": return math.Inf(1), nil case "-inf": @@ -196,7 +361,7 @@ func (r *Reader) readFloat(line []byte) (float64, error) { } func (r *Reader) readBool(line []byte) (bool, error) { - switch string(line[1:]) { + switch util.BytesToString(line[1:]) { case "t": return true, nil case "f": @@ -207,7 +372,7 @@ func (r *Reader) readBool(line []byte) (bool, error) { func (r *Reader) readBigInt(line []byte) (*big.Int, error) { i := new(big.Int) - if i, ok := i.SetString(string(line[1:]), 10); ok { + if i, ok := i.SetString(util.BytesToString(line[1:]), 10); ok { return i, nil } return nil, fmt.Errorf("redis: can't parse bigInt reply: %q", line) @@ -357,7 +522,7 @@ func (r *Reader) ReadFloat() (float64, error) { case RespFloat: return r.readFloat(line) case RespStatus: - return strconv.ParseFloat(string(line[1:]), 64) + return strconv.ParseFloat(util.BytesToString(line[1:]), 64) case RespString: s, err := r.readStringReply(line) if err != nil { @@ -550,3 +715,193 @@ func IsNilReply(line []byte) bool { (line[0] == RespString || line[0] == RespArray) && line[1] == '-' && line[2] == '1' } + +// ReadRawReply reads the next RESP reply and returns it as raw bytes without parsing. +func (r *Reader) ReadRawReply() ([]byte, error) { + return r.readRawReplyBuf(nil) +} + +func (r *Reader) readRawReplyBuf(buf []byte) ([]byte, error) { + line, err := r.readLine() + if err != nil { + return buf, err + } + + buf = append(buf, line...) + buf = append(buf, '\r', '\n') + + switch line[0] { + case RespStatus, RespError, RespInt, RespNil, RespFloat, RespBool, RespBigInt: + return buf, nil + + case RespString, RespVerbatim, RespBlobError: + n, err := replyLen(line) + if err != nil { + if err == Nil { + return buf, nil + } + return buf, err + } + curLen := len(buf) + buf = append(buf, make([]byte, n+2)...) + _, err = io.ReadFull(r.rd, buf[curLen:]) + return buf, err + + case RespArray, RespSet, RespPush: + n, err := replyLen(line) + if err != nil { + if err == Nil { + return buf, nil + } + return buf, err + } + for i := 0; i < n; i++ { + buf, err = r.readRawReplyBuf(buf) + if err != nil { + return buf, err + } + } + return buf, nil + + case RespMap: + n, err := replyLen(line) + if err != nil { + if err == Nil { + return buf, nil + } + return buf, err + } + for i := 0; i < n*2; i++ { + buf, err = r.readRawReplyBuf(buf) + if err != nil { + return buf, err + } + } + return buf, nil + + case RespAttr: + // Per RESP3 spec, an attribute is always followed by the actual command reply. + // We need to read the attribute's key-value pairs AND the following reply. + n, err := replyLen(line) + if err != nil { + if err == Nil { + return buf, nil + } + return buf, err + } + // Read the attribute key-value pairs + for i := 0; i < n*2; i++ { + buf, err = r.readRawReplyBuf(buf) + if err != nil { + return buf, err + } + } + // Read the command reply that follows the attribute + return r.readRawReplyBuf(buf) + } + + return buf, fmt.Errorf("redis: can't read raw reply: %.100q", line) +} + +var crlf = []byte{'\r', '\n'} + +// ReadRawReplyWriteTo streams the next RESP reply directly to w without intermediate allocations. +// Returns the number of bytes written and any error encountered. +func (r *Reader) ReadRawReplyWriteTo(w io.Writer) (int64, error) { + return r.readRawReplyWriteTo(w) +} + +func (r *Reader) readRawReplyWriteTo(w io.Writer) (int64, error) { + line, err := r.readLine() + if err != nil { + return 0, err + } + + var written int64 + n, err := w.Write(line) + written += int64(n) + if err != nil { + return written, err + } + n, err = w.Write(crlf) + written += int64(n) + if err != nil { + return written, err + } + + switch line[0] { + case RespStatus, RespError, RespInt, RespNil, RespFloat, RespBool, RespBigInt: + return written, nil + + case RespString, RespVerbatim, RespBlobError: + dataLen, err := replyLen(line) + if err != nil { + if err == Nil { + return written, nil + } + return written, err + } + copied, err := io.CopyN(w, r.rd, int64(dataLen)+2) + written += copied + return written, err + + case RespArray, RespSet, RespPush: + count, err := replyLen(line) + if err != nil { + if err == Nil { + return written, nil + } + return written, err + } + for i := 0; i < count; i++ { + n, err := r.readRawReplyWriteTo(w) + written += n + if err != nil { + return written, err + } + } + return written, nil + + case RespMap: + count, err := replyLen(line) + if err != nil { + if err == Nil { + return written, nil + } + return written, err + } + for i := 0; i < count*2; i++ { + n, err := r.readRawReplyWriteTo(w) + written += n + if err != nil { + return written, err + } + } + return written, nil + + case RespAttr: + // Per RESP3 spec, an attribute is always followed by the actual command reply. + // We need to read the attribute's key-value pairs AND the following reply. + count, err := replyLen(line) + if err != nil { + if err == Nil { + return written, nil + } + return written, err + } + // Read the attribute key-value pairs + for i := 0; i < count*2; i++ { + n, err := r.readRawReplyWriteTo(w) + written += n + if err != nil { + return written, err + } + } + // Read the command reply that follows the attribute + n, err := r.readRawReplyWriteTo(w) + written += n + return written, err + } + + return written, fmt.Errorf("redis: can't read raw reply: %.100q", line) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/proto/redis_errors.go b/vendor/github.com/redis/go-redis/v9/internal/proto/redis_errors.go new file mode 100644 index 000000000..a75370cf7 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/proto/redis_errors.go @@ -0,0 +1,539 @@ +package proto + +import ( + "errors" + "strings" +) + +// Typed Redis errors for better error handling with wrapping support. +// These errors maintain backward compatibility by keeping the same error messages. + +// LoadingError is returned when Redis is loading the dataset in memory. +type LoadingError struct { + msg string +} + +func (e *LoadingError) Error() string { + return e.msg +} + +func (e *LoadingError) RedisError() {} + +// NewLoadingError creates a new LoadingError with the given message. +func NewLoadingError(msg string) *LoadingError { + return &LoadingError{msg: msg} +} + +// ReadOnlyError is returned when trying to write to a read-only replica. +type ReadOnlyError struct { + msg string +} + +func (e *ReadOnlyError) Error() string { + return e.msg +} + +func (e *ReadOnlyError) RedisError() {} + +// NewReadOnlyError creates a new ReadOnlyError with the given message. +func NewReadOnlyError(msg string) *ReadOnlyError { + return &ReadOnlyError{msg: msg} +} + +// MovedError is returned when a key has been moved to a different node in a cluster. +type MovedError struct { + msg string + addr string +} + +func (e *MovedError) Error() string { + return e.msg +} + +func (e *MovedError) RedisError() {} + +// Addr returns the address of the node where the key has been moved. +func (e *MovedError) Addr() string { + return e.addr +} + +// NewMovedError creates a new MovedError with the given message and address. +func NewMovedError(msg string, addr string) *MovedError { + return &MovedError{msg: msg, addr: addr} +} + +// AskError is returned when a key is being migrated and the client should ask another node. +type AskError struct { + msg string + addr string +} + +func (e *AskError) Error() string { + return e.msg +} + +func (e *AskError) RedisError() {} + +// Addr returns the address of the node to ask. +func (e *AskError) Addr() string { + return e.addr +} + +// NewAskError creates a new AskError with the given message and address. +func NewAskError(msg string, addr string) *AskError { + return &AskError{msg: msg, addr: addr} +} + +// ClusterDownError is returned when the cluster is down. +type ClusterDownError struct { + msg string +} + +func (e *ClusterDownError) Error() string { + return e.msg +} + +func (e *ClusterDownError) RedisError() {} + +// NewClusterDownError creates a new ClusterDownError with the given message. +func NewClusterDownError(msg string) *ClusterDownError { + return &ClusterDownError{msg: msg} +} + +// TryAgainError is returned when a command cannot be processed and should be retried. +type TryAgainError struct { + msg string +} + +func (e *TryAgainError) Error() string { + return e.msg +} + +func (e *TryAgainError) RedisError() {} + +// NewTryAgainError creates a new TryAgainError with the given message. +func NewTryAgainError(msg string) *TryAgainError { + return &TryAgainError{msg: msg} +} + +// MasterDownError is returned when the master is down. +type MasterDownError struct { + msg string +} + +func (e *MasterDownError) Error() string { + return e.msg +} + +func (e *MasterDownError) RedisError() {} + +// NewMasterDownError creates a new MasterDownError with the given message. +func NewMasterDownError(msg string) *MasterDownError { + return &MasterDownError{msg: msg} +} + +// MaxClientsError is returned when the maximum number of clients has been reached. +type MaxClientsError struct { + msg string +} + +func (e *MaxClientsError) Error() string { + return e.msg +} + +func (e *MaxClientsError) RedisError() {} + +// NewMaxClientsError creates a new MaxClientsError with the given message. +func NewMaxClientsError(msg string) *MaxClientsError { + return &MaxClientsError{msg: msg} +} + +// AuthError is returned when authentication fails. +type AuthError struct { + msg string +} + +func (e *AuthError) Error() string { + return e.msg +} + +func (e *AuthError) RedisError() {} + +// NewAuthError creates a new AuthError with the given message. +func NewAuthError(msg string) *AuthError { + return &AuthError{msg: msg} +} + +// PermissionError is returned when a user lacks required permissions. +type PermissionError struct { + msg string +} + +func (e *PermissionError) Error() string { + return e.msg +} + +func (e *PermissionError) RedisError() {} + +// NewPermissionError creates a new PermissionError with the given message. +func NewPermissionError(msg string) *PermissionError { + return &PermissionError{msg: msg} +} + +// ExecAbortError is returned when a transaction is aborted. +type ExecAbortError struct { + msg string +} + +func (e *ExecAbortError) Error() string { + return e.msg +} + +func (e *ExecAbortError) RedisError() {} + +// NewExecAbortError creates a new ExecAbortError with the given message. +func NewExecAbortError(msg string) *ExecAbortError { + return &ExecAbortError{msg: msg} +} + +// OOMError is returned when Redis is out of memory. +type OOMError struct { + msg string +} + +func (e *OOMError) Error() string { + return e.msg +} + +func (e *OOMError) RedisError() {} + +// NewOOMError creates a new OOMError with the given message. +func NewOOMError(msg string) *OOMError { + return &OOMError{msg: msg} +} + +// NoReplicasError is returned when not enough replicas acknowledge a write. +// This error occurs when using WAIT/WAITAOF commands or CLUSTER SETSLOT with +// synchronous replication, and the required number of replicas cannot confirm +// the write within the timeout period. +type NoReplicasError struct { + msg string +} + +func (e *NoReplicasError) Error() string { + return e.msg +} + +func (e *NoReplicasError) RedisError() {} + +// NewNoReplicasError creates a new NoReplicasError with the given message. +func NewNoReplicasError(msg string) *NoReplicasError { + return &NoReplicasError{msg: msg} +} + +// parseTypedRedisError parses a Redis error message and returns a typed error if applicable. +// This function maintains backward compatibility by keeping the same error messages. +func parseTypedRedisError(msg string) error { + // Check for specific error patterns and return typed errors + switch { + case strings.HasPrefix(msg, "LOADING "): + return NewLoadingError(msg) + case strings.HasPrefix(msg, "READONLY "): + return NewReadOnlyError(msg) + case strings.HasPrefix(msg, "MOVED "): + // Extract address from "MOVED " + addr := extractAddr(msg) + return NewMovedError(msg, addr) + case strings.HasPrefix(msg, "ASK "): + // Extract address from "ASK " + addr := extractAddr(msg) + return NewAskError(msg, addr) + case strings.HasPrefix(msg, "CLUSTERDOWN "): + return NewClusterDownError(msg) + case strings.HasPrefix(msg, "TRYAGAIN "): + return NewTryAgainError(msg) + case strings.HasPrefix(msg, "MASTERDOWN "): + return NewMasterDownError(msg) + case strings.HasPrefix(msg, "NOREPLICAS "): + return NewNoReplicasError(msg) + case msg == "ERR max number of clients reached": + return NewMaxClientsError(msg) + case strings.HasPrefix(msg, "NOAUTH "), strings.HasPrefix(msg, "WRONGPASS "), strings.Contains(msg, "unauthenticated"): + return NewAuthError(msg) + case strings.HasPrefix(msg, "NOPERM "): + return NewPermissionError(msg) + case strings.HasPrefix(msg, "EXECABORT "): + return NewExecAbortError(msg) + case strings.HasPrefix(msg, "OOM "): + return NewOOMError(msg) + default: + // Return generic RedisError for unknown error types + return RedisError(msg) + } +} + +// extractAddr extracts the address from MOVED/ASK error messages. +// Format: "MOVED " or "ASK " +func extractAddr(msg string) string { + ind := strings.LastIndex(msg, " ") + if ind == -1 { + return "" + } + return msg[ind+1:] +} + +// IsLoadingError checks if an error is a LoadingError, even if wrapped. +func IsLoadingError(err error) bool { + if err == nil { + return false + } + var loadingErr *LoadingError + if errors.As(err, &loadingErr) { + return true + } + // Check if wrapped error is a RedisError with LOADING prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "LOADING ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "LOADING ") +} + +// IsReadOnlyError checks if an error is a ReadOnlyError, even if wrapped. +func IsReadOnlyError(err error) bool { + if err == nil { + return false + } + var readOnlyErr *ReadOnlyError + if errors.As(err, &readOnlyErr) { + return true + } + // Check if wrapped error is a RedisError with READONLY prefix or Lua script READONLY + var redisErr RedisError + if errors.As(err, &redisErr) { + s := redisErr.Error() + if strings.HasPrefix(s, "READONLY ") { + return true + } + // Lua script wrapped READONLY errors: + // "ERR Error running script (call to f_): @user_script:N: -READONLY You can't write against a read only replica." + if strings.Contains(s, "-READONLY You can't write against a read only replica") { + return true + } + } + // Fallback to string checking for backward compatibility + s := err.Error() + if strings.HasPrefix(s, "READONLY ") { + return true + } + return strings.Contains(s, "-READONLY You can't write against a read only replica") +} + +// IsMovedError checks if an error is a MovedError, even if wrapped. +// Returns the error and a boolean indicating if it's a MovedError. +func IsMovedError(err error) (*MovedError, bool) { + if err == nil { + return nil, false + } + var movedErr *MovedError + if errors.As(err, &movedErr) { + return movedErr, true + } + // Fallback to string checking for backward compatibility + s := err.Error() + if strings.HasPrefix(s, "MOVED ") { + // Parse: MOVED 3999 127.0.0.1:6381 + parts := strings.Split(s, " ") + if len(parts) == 3 { + return &MovedError{msg: s, addr: parts[2]}, true + } + } + return nil, false +} + +// IsAskError checks if an error is an AskError, even if wrapped. +// Returns the error and a boolean indicating if it's an AskError. +func IsAskError(err error) (*AskError, bool) { + if err == nil { + return nil, false + } + var askErr *AskError + if errors.As(err, &askErr) { + return askErr, true + } + // Fallback to string checking for backward compatibility + s := err.Error() + if strings.HasPrefix(s, "ASK ") { + // Parse: ASK 3999 127.0.0.1:6381 + parts := strings.Split(s, " ") + if len(parts) == 3 { + return &AskError{msg: s, addr: parts[2]}, true + } + } + return nil, false +} + +// IsClusterDownError checks if an error is a ClusterDownError, even if wrapped. +func IsClusterDownError(err error) bool { + if err == nil { + return false + } + var clusterDownErr *ClusterDownError + if errors.As(err, &clusterDownErr) { + return true + } + // Check if wrapped error is a RedisError with CLUSTERDOWN prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "CLUSTERDOWN ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "CLUSTERDOWN ") +} + +// IsTryAgainError checks if an error is a TryAgainError, even if wrapped. +func IsTryAgainError(err error) bool { + if err == nil { + return false + } + var tryAgainErr *TryAgainError + if errors.As(err, &tryAgainErr) { + return true + } + // Check if wrapped error is a RedisError with TRYAGAIN prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "TRYAGAIN ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "TRYAGAIN ") +} + +// IsMasterDownError checks if an error is a MasterDownError, even if wrapped. +func IsMasterDownError(err error) bool { + if err == nil { + return false + } + var masterDownErr *MasterDownError + if errors.As(err, &masterDownErr) { + return true + } + // Check if wrapped error is a RedisError with MASTERDOWN prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "MASTERDOWN ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "MASTERDOWN ") +} + +// IsMaxClientsError checks if an error is a MaxClientsError, even if wrapped. +func IsMaxClientsError(err error) bool { + if err == nil { + return false + } + var maxClientsErr *MaxClientsError + if errors.As(err, &maxClientsErr) { + return true + } + // Check if wrapped error is a RedisError with max clients prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "ERR max number of clients reached") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "ERR max number of clients reached") +} + +// IsAuthError checks if an error is an AuthError, even if wrapped. +func IsAuthError(err error) bool { + if err == nil { + return false + } + var authErr *AuthError + if errors.As(err, &authErr) { + return true + } + // Check if wrapped error is a RedisError with auth error prefix + var redisErr RedisError + if errors.As(err, &redisErr) { + s := redisErr.Error() + return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated") + } + // Fallback to string checking for backward compatibility + s := err.Error() + return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated") +} + +// IsPermissionError checks if an error is a PermissionError, even if wrapped. +func IsPermissionError(err error) bool { + if err == nil { + return false + } + var permErr *PermissionError + if errors.As(err, &permErr) { + return true + } + // Check if wrapped error is a RedisError with NOPERM prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "NOPERM ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "NOPERM ") +} + +// IsExecAbortError checks if an error is an ExecAbortError, even if wrapped. +func IsExecAbortError(err error) bool { + if err == nil { + return false + } + var execAbortErr *ExecAbortError + if errors.As(err, &execAbortErr) { + return true + } + // Check if wrapped error is a RedisError with EXECABORT prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "EXECABORT ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "EXECABORT ") +} + +// IsOOMError checks if an error is an OOMError, even if wrapped. +func IsOOMError(err error) bool { + if err == nil { + return false + } + var oomErr *OOMError + if errors.As(err, &oomErr) { + return true + } + // Check if wrapped error is a RedisError with OOM prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "OOM ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "OOM ") +} + +// IsNoReplicasError checks if an error is a NoReplicasError, even if wrapped. +func IsNoReplicasError(err error) bool { + if err == nil { + return false + } + var noReplicasErr *NoReplicasError + if errors.As(err, &noReplicasErr) { + return true + } + // Check if wrapped error is a RedisError with NOREPLICAS prefix + var redisErr RedisError + if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "NOREPLICAS ") { + return true + } + // Fallback to string checking for backward compatibility + return strings.HasPrefix(err.Error(), "NOREPLICAS ") +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/rand/rand.go b/vendor/github.com/redis/go-redis/v9/internal/rand/rand.go deleted file mode 100644 index 2edccba94..000000000 --- a/vendor/github.com/redis/go-redis/v9/internal/rand/rand.go +++ /dev/null @@ -1,50 +0,0 @@ -package rand - -import ( - "math/rand" - "sync" -) - -// Int returns a non-negative pseudo-random int. -func Int() int { return pseudo.Int() } - -// Intn returns, as an int, a non-negative pseudo-random number in [0,n). -// It panics if n <= 0. -func Intn(n int) int { return pseudo.Intn(n) } - -// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n). -// It panics if n <= 0. -func Int63n(n int64) int64 { return pseudo.Int63n(n) } - -// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n). -func Perm(n int) []int { return pseudo.Perm(n) } - -// Seed uses the provided seed value to initialize the default Source to a -// deterministic state. If Seed is not called, the generator behaves as if -// seeded by Seed(1). -func Seed(n int64) { pseudo.Seed(n) } - -var pseudo = rand.New(&source{src: rand.NewSource(1)}) - -type source struct { - src rand.Source - mu sync.Mutex -} - -func (s *source) Int63() int64 { - s.mu.Lock() - n := s.src.Int63() - s.mu.Unlock() - return n -} - -func (s *source) Seed(seed int64) { - s.mu.Lock() - s.src.Seed(seed) - s.mu.Unlock() -} - -// Shuffle pseudo-randomizes the order of elements. -// n is the number of elements. -// swap swaps the elements with indexes i and j. -func Shuffle(n int, swap func(i, j int)) { pseudo.Shuffle(n, swap) } diff --git a/vendor/github.com/redis/go-redis/v9/internal/redis.go b/vendor/github.com/redis/go-redis/v9/internal/redis.go new file mode 100644 index 000000000..190bbebea --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/redis.go @@ -0,0 +1,3 @@ +package internal + +const RedisNull = "" diff --git a/vendor/github.com/redis/go-redis/v9/internal/routing/aggregator.go b/vendor/github.com/redis/go-redis/v9/internal/routing/aggregator.go new file mode 100644 index 000000000..0d6321ec1 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/routing/aggregator.go @@ -0,0 +1,1000 @@ +package routing + +import ( + "errors" + "fmt" + "math" + "sync" + + "sync/atomic" + + "github.com/redis/go-redis/v9/internal/util" + uberAtomic "go.uber.org/atomic" +) + +var ( + ErrMaxAggregation = errors.New("redis: no valid results to aggregate for max operation") + ErrMinAggregation = errors.New("redis: no valid results to aggregate for min operation") + ErrAndAggregation = errors.New("redis: no valid results to aggregate for logical AND operation") + ErrOrAggregation = errors.New("redis: no valid results to aggregate for logical OR operation") +) + +// ResponseAggregator defines the interface for aggregating responses from multiple shards. +type ResponseAggregator interface { + // Add processes a single shard response. + Add(result interface{}, err error) error + + // AddWithKey processes a single shard response for a specific key (used by keyed aggregators). + AddWithKey(key string, result interface{}, err error) error + + BatchAdd(map[string]AggregatorResErr) error + + BatchSlice([]AggregatorResErr) error + + // Result returns the final aggregated result and any error. + Result() (interface{}, error) +} + +type AggregatorResErr struct { + Result interface{} + Err error +} + +// NewResponseAggregator creates an aggregator based on the response policy. +func NewResponseAggregator(policy ResponsePolicy, cmdName string) ResponseAggregator { + switch policy { + case RespDefaultKeyless: + return &DefaultKeylessAggregator{results: make([]interface{}, 0)} + case RespDefaultHashSlot: + return &DefaultKeyedAggregator{results: make(map[string]interface{})} + case RespAllSucceeded: + return &AllSucceededAggregator{} + case RespOneSucceeded: + return &OneSucceededAggregator{} + case RespAggSum: + return &AggSumAggregator{ + // res: + } + case RespAggMin: + return &AggMinAggregator{ + res: util.NewAtomicMin(), + } + case RespAggMax: + return &AggMaxAggregator{ + res: util.NewAtomicMax(), + } + case RespAggLogicalAnd: + andAgg := &AggLogicalAndAggregator{} + andAgg.res.Store(true) + + return andAgg + case RespAggLogicalOr: + return &AggLogicalOrAggregator{} + case RespSpecial: + return NewSpecialAggregator(cmdName) + default: + return &AllSucceededAggregator{} + } +} + +func NewDefaultAggregator(isKeyed bool) ResponseAggregator { + if isKeyed { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + } + } + return &DefaultKeylessAggregator{} +} + +// AllSucceededAggregator returns one non-error reply if every shard succeeded, +// propagates the first error otherwise. +type AllSucceededAggregator struct { + err atomic.Value + res atomic.Value +} + +func (a *AllSucceededAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + if result != nil { + a.res.CompareAndSwap(nil, result) + } + + return nil +} + +func (a *AllSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error { + for _, res := range results { + err := a.Add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *AllSucceededAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { + err := a.Add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *AllSucceededAggregator) Result() (interface{}, error) { + var err error + res, e := a.res.Load(), a.err.Load() + if e != nil { + err = e.(error) + } + + return res, err +} + +func (a *AllSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +// OneSucceededAggregator returns the first non-error reply, +// if all shards errored, returns any one of those errors. +type OneSucceededAggregator struct { + err atomic.Value + res atomic.Value +} + +func (a *OneSucceededAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + if result != nil { + a.res.CompareAndSwap(nil, result) + } + + return nil +} + +func (a *OneSucceededAggregator) BatchAdd(results map[string]AggregatorResErr) error { + for _, res := range results { + err := a.Add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err == nil { + return nil + } + } + + return nil +} + +func (a *OneSucceededAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *OneSucceededAggregator) BatchSlice(results []AggregatorResErr) error { + for _, res := range results { + err := a.Add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err == nil { + return nil + } + } + + return nil +} + +func (a *OneSucceededAggregator) Result() (interface{}, error) { + res, e := a.res.Load(), a.err.Load() + if res == nil { + return nil, e.(error) + } + + return res, nil +} + +// AggSumAggregator sums numeric replies from all shards. +type AggSumAggregator struct { + err atomic.Value + res uberAtomic.Float64 +} + +func (a *AggSumAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + } + + if result != nil { + val, err := toFloat64(result) + if err != nil { + a.err.CompareAndSwap(nil, err) + return err + } + a.res.Add(val) + } + + return nil +} + +func (a *AggSumAggregator) BatchAdd(results map[string]AggregatorResErr) error { + var sum int64 + + for _, res := range results { + if res.Err != nil { + return a.Add(res.Result, res.Err) + } + + intRes, err := toInt64(res.Result) + if err != nil { + return a.Add(nil, err) + } + + sum += intRes + } + + return a.Add(sum, nil) +} + +func (a *AggSumAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggSumAggregator) BatchSlice(results []AggregatorResErr) error { + var sum int64 + + for _, res := range results { + if res.Err != nil { + return a.Add(res.Result, res.Err) + } + + intRes, err := toInt64(res.Result) + if err != nil { + return a.Add(nil, err) + } + + sum += intRes + } + + return a.Add(sum, nil) +} + +func (a *AggSumAggregator) Result() (interface{}, error) { + res, err := a.res.Load(), a.err.Load() + if err != nil { + return nil, err.(error) + } + + return res, nil +} + +// AggMinAggregator returns the minimum numeric value from all shards. +type AggMinAggregator struct { + err atomic.Value + res *util.AtomicMin +} + +func (a *AggMinAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + floatVal, e := toFloat64(result) + if e != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + a.res.Value(floatVal) + + return nil +} + +func (a *AggMinAggregator) BatchAdd(results map[string]AggregatorResErr) error { + min := int64(math.MaxInt64) + + for _, res := range results { + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + if resInt < min { + min = resInt + } + + } + + return a.Add(min, nil) +} + +func (a *AggMinAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMinAggregator) BatchSlice(results []AggregatorResErr) error { + min := float64(math.MaxFloat64) + + for _, res := range results { + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + floatVal, err := toFloat64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + if floatVal < min { + min = floatVal + } + + } + + return a.Add(min, nil) +} + +func (a *AggMinAggregator) Result() (interface{}, error) { + err := a.err.Load() + if err != nil { + return nil, err.(error) + } + + val, hasVal := a.res.Min() + if !hasVal { + return nil, ErrMinAggregation + } + return val, nil +} + +// AggMaxAggregator returns the maximum numeric value from all shards. +type AggMaxAggregator struct { + err atomic.Value + res *util.AtomicMax +} + +func (a *AggMaxAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + floatVal, e := toFloat64(result) + if e != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + a.res.Value(floatVal) + + return nil +} + +func (a *AggMaxAggregator) BatchAdd(results map[string]AggregatorResErr) error { + max := int64(math.MinInt64) + + for _, res := range results { + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + if resInt > max { + max = resInt + } + + } + + return a.Add(max, nil) +} + +func (a *AggMaxAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggMaxAggregator) BatchSlice(results []AggregatorResErr) error { + max := int64(math.MinInt64) + + for _, res := range results { + if res.Err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + resInt, err := toInt64(res.Result) + if err != nil { + _ = a.Add(nil, res.Err) + return nil + } + + if resInt > max { + max = resInt + } + + } + + return a.Add(max, nil) +} + +func (a *AggMaxAggregator) Result() (interface{}, error) { + err := a.err.Load() + if err != nil { + return nil, err.(error) + } + + val, hasVal := a.res.Max() + if !hasVal { + return nil, ErrMaxAggregation + } + return val, nil +} + +// AggLogicalAndAggregator performs logical AND on boolean values. +type AggLogicalAndAggregator struct { + err atomic.Value + res atomic.Bool + hasResult atomic.Bool +} + +func (a *AggLogicalAndAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + val, e := toBool(result) + if e != nil { + a.err.CompareAndSwap(nil, e) + return e + } + + // Atomic AND operation: if val is false, result is always false + if !val { + a.res.Store(false) + } + + a.hasResult.Store(true) + + return nil +} + +func (a *AggLogicalAndAggregator) BatchAdd(results map[string]AggregatorResErr) error { + result := true + + for _, res := range results { + if res.Err != nil { + return a.Add(nil, res.Err) + } + + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) + } + + result = result && boolRes + } + + return a.Add(result, nil) +} + +func (a *AggLogicalAndAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalAndAggregator) BatchSlice(results []AggregatorResErr) error { + result := true + + for _, res := range results { + if res.Err != nil { + return a.Add(nil, res.Err) + } + + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) + } + + result = result && boolRes + } + + return a.Add(result, nil) +} + +func (a *AggLogicalAndAggregator) Result() (interface{}, error) { + err := a.err.Load() + if err != nil { + return nil, err.(error) + } + + if !a.hasResult.Load() { + return nil, ErrAndAggregation + } + return a.res.Load(), nil +} + +// AggLogicalOrAggregator performs logical OR on boolean values. +type AggLogicalOrAggregator struct { + err atomic.Value + res atomic.Bool + hasResult atomic.Bool +} + +func (a *AggLogicalOrAggregator) Add(result interface{}, err error) error { + if err != nil { + a.err.CompareAndSwap(nil, err) + return nil + } + + val, e := toBool(result) + if e != nil { + a.err.CompareAndSwap(nil, e) + return e + } + + // Atomic OR operation: if val is true, result is always true + if val { + a.res.Store(true) + } + + a.hasResult.Store(true) + + return nil +} + +func (a *AggLogicalOrAggregator) BatchAdd(results map[string]AggregatorResErr) error { + result := false + + for _, res := range results { + if res.Err != nil { + return a.Add(nil, res.Err) + } + + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) + } + + result = result || boolRes + } + + return a.Add(result, nil) +} + +func (a *AggLogicalOrAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *AggLogicalOrAggregator) BatchSlice(results []AggregatorResErr) error { + result := false + + for _, res := range results { + if res.Err != nil { + return a.Add(nil, res.Err) + } + + boolRes, err := toBool(res.Result) + if err != nil { + return a.Add(nil, err) + } + + result = result || boolRes + } + + return a.Add(result, nil) +} + +func (a *AggLogicalOrAggregator) Result() (interface{}, error) { + err := a.err.Load() + if err != nil { + return nil, err.(error) + } + + if !a.hasResult.Load() { + return nil, ErrOrAggregation + } + return a.res.Load(), nil +} + +func toInt64(val interface{}) (int64, error) { + if val == nil { + return 0, nil + } + switch v := val.(type) { + case int64: + return v, nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case float64: + if v != math.Trunc(v) { + return 0, fmt.Errorf("cannot convert float %f to int64", v) + } + return int64(v), nil + default: + return 0, fmt.Errorf("cannot convert %T to int64", val) + } +} + +func toFloat64(val interface{}) (float64, error) { + if val == nil { + return 0, nil + } + + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case int32: + return float64(v), nil + case int64: + return float64(v), nil + case float32: + return float64(v), nil + default: + return 0, fmt.Errorf("cannot convert %T to float64", val) + } +} + +func toBool(val interface{}) (bool, error) { + if val == nil { + return false, nil + } + switch v := val.(type) { + case bool: + return v, nil + case int64: + return v != 0, nil + case int: + return v != 0, nil + default: + return false, fmt.Errorf("cannot convert %T to bool", val) + } +} + +// DefaultKeylessAggregator collects all results in an array, order doesn't matter. +type DefaultKeylessAggregator struct { + mu sync.Mutex + results []interface{} + firstErr error +} + +func (a *DefaultKeylessAggregator) add(result interface{}, err error) error { + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results = append(a.results, result) + } + return nil +} + +func (a *DefaultKeylessAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + return a.add(result, err) +} + +func (a *DefaultKeylessAggregator) BatchAdd(results map[string]AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *DefaultKeylessAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *DefaultKeylessAggregator) BatchSlice(results []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *DefaultKeylessAggregator) Result() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + return a.results, nil +} + +// DefaultKeyedAggregator reassembles replies in the exact key order of the original request. +type DefaultKeyedAggregator struct { + mu sync.Mutex + results map[string]interface{} + keyOrder []string + firstErr error +} + +func NewDefaultKeyedAggregator(keyOrder []string) *DefaultKeyedAggregator { + return &DefaultKeyedAggregator{ + results: make(map[string]interface{}), + keyOrder: keyOrder, + } +} + +func (a *DefaultKeyedAggregator) add(result interface{}, err error) error { + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + // For non-keyed Add, just collect the result without ordering + if err == nil { + a.results["__default__"] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + return a.add(result, err) +} + +func (a *DefaultKeyedAggregator) BatchAdd(results map[string]AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *DefaultKeyedAggregator) addWithKey(key string, result interface{}, err error) error { + if err != nil && a.firstErr == nil { + a.firstErr = err + return nil + } + if err == nil { + a.results[key] = result + } + return nil +} + +func (a *DefaultKeyedAggregator) AddWithKey(key string, result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + return a.addWithKey(key, result, err) +} + +func (a *DefaultKeyedAggregator) BatchAddWithKeyOrder(results map[string]AggregatorResErr, keyOrder []string) error { + a.mu.Lock() + defer a.mu.Unlock() + + a.keyOrder = keyOrder + for key, res := range results { + err := a.addWithKey(key, res.Result, res.Err) + if err != nil { + return nil + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *DefaultKeyedAggregator) SetKeyOrder(keyOrder []string) { + a.mu.Lock() + defer a.mu.Unlock() + a.keyOrder = keyOrder +} + +func (a *DefaultKeyedAggregator) BatchSlice(results []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *DefaultKeyedAggregator) Result() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.firstErr != nil { + return nil, a.firstErr + } + + // If no explicit key order is set, return results in any order + if len(a.keyOrder) == 0 { + orderedResults := make([]interface{}, 0, len(a.results)) + for _, result := range a.results { + orderedResults = append(orderedResults, result) + } + return orderedResults, nil + } + + // Return results in the exact key order + orderedResults := make([]interface{}, len(a.keyOrder)) + for i, key := range a.keyOrder { + if result, exists := a.results[key]; exists { + orderedResults[i] = result + } + } + return orderedResults, nil +} + +// SpecialAggregator provides a registry for command-specific aggregation logic. +type SpecialAggregator struct { + mu sync.Mutex + aggregatorFunc func([]interface{}, []error) (interface{}, error) + results []interface{} + errors []error +} + +func (a *SpecialAggregator) add(result interface{}, err error) error { + a.results = append(a.results, result) + a.errors = append(a.errors, err) + return nil +} + +func (a *SpecialAggregator) Add(result interface{}, err error) error { + a.mu.Lock() + defer a.mu.Unlock() + + return a.add(result, err) +} + +func (a *SpecialAggregator) BatchAdd(results map[string]AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *SpecialAggregator) AddWithKey(key string, result interface{}, err error) error { + return a.Add(result, err) +} + +func (a *SpecialAggregator) BatchSlice(results []AggregatorResErr) error { + a.mu.Lock() + defer a.mu.Unlock() + + for _, res := range results { + err := a.add(res.Result, res.Err) + if err != nil { + return err + } + + if res.Err != nil { + return nil + } + } + + return nil +} + +func (a *SpecialAggregator) Result() (interface{}, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.aggregatorFunc != nil { + return a.aggregatorFunc(a.results, a.errors) + } + // Default behavior: return first non-error result or first error + for i, err := range a.errors { + if err == nil { + return a.results[i], nil + } + } + if len(a.errors) > 0 { + return nil, a.errors[0] + } + return nil, nil +} + +// SpecialAggregatorRegistry holds custom aggregation functions for specific commands. +var SpecialAggregatorRegistry = make(map[string]func([]interface{}, []error) (interface{}, error)) + +// RegisterSpecialAggregator registers a custom aggregation function for a command. +func RegisterSpecialAggregator(cmdName string, fn func([]interface{}, []error) (interface{}, error)) { + SpecialAggregatorRegistry[cmdName] = fn +} + +// NewSpecialAggregator creates a special aggregator with command-specific logic if available. +func NewSpecialAggregator(cmdName string) *SpecialAggregator { + agg := &SpecialAggregator{} + if fn, exists := SpecialAggregatorRegistry[cmdName]; exists { + agg.aggregatorFunc = fn + } + return agg +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/routing/policy.go b/vendor/github.com/redis/go-redis/v9/internal/routing/policy.go new file mode 100644 index 000000000..7f784b506 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/routing/policy.go @@ -0,0 +1,144 @@ +package routing + +import ( + "fmt" + "strings" +) + +type RequestPolicy uint8 + +const ( + ReqDefault RequestPolicy = iota + + ReqAllNodes + + ReqAllShards + + ReqMultiShard + + ReqSpecial +) + +const ( + ReadOnlyCMD string = "readonly" +) + +func (p RequestPolicy) String() string { + switch p { + case ReqDefault: + return "default" + case ReqAllNodes: + return "all_nodes" + case ReqAllShards: + return "all_shards" + case ReqMultiShard: + return "multi_shard" + case ReqSpecial: + return "special" + default: + return fmt.Sprintf("unknown_request_policy(%d)", p) + } +} + +func ParseRequestPolicy(raw string) (RequestPolicy, error) { + switch strings.ToLower(raw) { + case "", "default", "none": + return ReqDefault, nil + case "all_nodes": + return ReqAllNodes, nil + case "all_shards": + return ReqAllShards, nil + case "multi_shard": + return ReqMultiShard, nil + case "special": + return ReqSpecial, nil + default: + return ReqDefault, fmt.Errorf("routing: unknown request_policy %q", raw) + } +} + +type ResponsePolicy uint8 + +const ( + RespDefaultKeyless ResponsePolicy = iota + RespDefaultHashSlot + RespAllSucceeded + RespOneSucceeded + RespAggSum + RespAggMin + RespAggMax + RespAggLogicalAnd + RespAggLogicalOr + RespSpecial +) + +func (p ResponsePolicy) String() string { + switch p { + case RespDefaultKeyless: + return "default(keyless)" + case RespDefaultHashSlot: + return "default(hashslot)" + case RespAllSucceeded: + return "all_succeeded" + case RespOneSucceeded: + return "one_succeeded" + case RespAggSum: + return "agg_sum" + case RespAggMin: + return "agg_min" + case RespAggMax: + return "agg_max" + case RespAggLogicalAnd: + return "agg_logical_and" + case RespAggLogicalOr: + return "agg_logical_or" + case RespSpecial: + return "special" + default: + return "all_succeeded" + } +} + +func ParseResponsePolicy(raw string) (ResponsePolicy, error) { + switch strings.ToLower(raw) { + case "default(keyless)": + return RespDefaultKeyless, nil + case "default(hashslot)": + return RespDefaultHashSlot, nil + case "all_succeeded": + return RespAllSucceeded, nil + case "one_succeeded": + return RespOneSucceeded, nil + case "agg_sum": + return RespAggSum, nil + case "agg_min": + return RespAggMin, nil + case "agg_max": + return RespAggMax, nil + case "agg_logical_and": + return RespAggLogicalAnd, nil + case "agg_logical_or": + return RespAggLogicalOr, nil + case "special": + return RespSpecial, nil + default: + return RespDefaultKeyless, fmt.Errorf("routing: unknown response_policy %q", raw) + } +} + +type CommandPolicy struct { + Request RequestPolicy + Response ResponsePolicy + // Tips that are not request_policy or response_policy + // e.g nondeterministic_output, nondeterministic_output_order. + Tips map[string]string +} + +func (p *CommandPolicy) CanBeUsedInPipeline() bool { + return p.Request != ReqAllNodes && p.Request != ReqAllShards && p.Request != ReqMultiShard +} + +func (p *CommandPolicy) IsReadOnly() bool { + _, readOnly := p.Tips[ReadOnlyCMD] + return readOnly +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/routing/shard_picker.go b/vendor/github.com/redis/go-redis/v9/internal/routing/shard_picker.go new file mode 100644 index 000000000..8e6228dd2 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/routing/shard_picker.go @@ -0,0 +1,57 @@ +package routing + +import ( + "math/rand" + "sync/atomic" +) + +// ShardPicker chooses “one arbitrary shard” when the request_policy is +// ReqDefault and the command has no keys. +type ShardPicker interface { + Next(total int) int // returns an index in [0,total) +} + +// StaticShardPicker always returns the same shard index. +type StaticShardPicker struct { + index int +} + +func NewStaticShardPicker(index int) *StaticShardPicker { + return &StaticShardPicker{index: index} +} + +func (p *StaticShardPicker) Next(total int) int { + if total == 0 || p.index >= total { + return 0 + } + return p.index +} + +/*─────────────────────────────── + Round-robin (default) +────────────────────────────────*/ + +type RoundRobinPicker struct { + cnt atomic.Uint32 +} + +func (p *RoundRobinPicker) Next(total int) int { + if total == 0 { + return 0 + } + i := p.cnt.Add(1) + return int(i-1) % total +} + +/*─────────────────────────────── + Random +────────────────────────────────*/ + +type RandomPicker struct{} + +func (RandomPicker) Next(total int) int { + if total == 0 { + return 0 + } + return rand.Intn(total) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/semaphore.go b/vendor/github.com/redis/go-redis/v9/internal/semaphore.go new file mode 100644 index 000000000..a7f40466c --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/semaphore.go @@ -0,0 +1,193 @@ +package internal + +import ( + "context" + "sync" + "time" +) + +var semTimers = sync.Pool{ + New: func() interface{} { + t := time.NewTimer(time.Hour) + t.Stop() + return t + }, +} + +// FastSemaphore is a channel-based semaphore optimized for performance. +// It uses a fast path that avoids timer allocation when tokens are available. +// The channel is pre-filled with tokens: Acquire = receive, Release = send. +// Closing the semaphore unblocks all waiting goroutines. +// +// Performance: ~30 ns/op with zero allocations on fast path. +// Fairness: Eventual fairness (no starvation) but not strict FIFO. +type FastSemaphore struct { + tokens chan struct{} + max int32 +} + +// NewFastSemaphore creates a new fast semaphore with the given capacity. +func NewFastSemaphore(capacity int32) *FastSemaphore { + ch := make(chan struct{}, capacity) + // Pre-fill with tokens + for i := int32(0); i < capacity; i++ { + ch <- struct{}{} + } + return &FastSemaphore{ + tokens: ch, + max: capacity, + } +} + +// TryAcquire attempts to acquire a token without blocking. +// Returns true if successful, false if no tokens available. +func (s *FastSemaphore) TryAcquire() bool { + select { + case <-s.tokens: + return true + default: + return false + } +} + +// Acquire acquires a token, blocking if necessary until one is available. +// Returns an error if the context is cancelled or the timeout expires. +// Uses a fast path to avoid timer allocation when tokens are immediately available. +func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error { + // Check context first + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Try fast path first (no timer needed) + select { + case <-s.tokens: + return nil + default: + } + + // Slow path: need to wait with timeout + timer := semTimers.Get().(*time.Timer) + defer semTimers.Put(timer) + timer.Reset(timeout) + + select { + case <-s.tokens: + if !timer.Stop() { + <-timer.C + } + return nil + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + case <-timer.C: + return timeoutErr + } +} + +// AcquireBlocking acquires a token, blocking indefinitely until one is available. +func (s *FastSemaphore) AcquireBlocking() { + <-s.tokens +} + +// Release releases a token back to the semaphore. +func (s *FastSemaphore) Release() { + s.tokens <- struct{}{} +} + +// Close closes the semaphore, unblocking all waiting goroutines. +// After close, all Acquire calls will receive a closed channel signal. +func (s *FastSemaphore) Close() { + close(s.tokens) +} + +// Len returns the current number of acquired tokens. +func (s *FastSemaphore) Len() int32 { + return s.max - int32(len(s.tokens)) +} + +// FIFOSemaphore is a channel-based semaphore with strict FIFO ordering. +// Unlike FastSemaphore, this guarantees that threads are served in the exact order they call Acquire(). +// The channel is pre-filled with tokens: Acquire = receive, Release = send. +// Closing the semaphore unblocks all waiting goroutines. +// +// Performance: ~115 ns/op with zero allocations (slower than FastSemaphore due to timer allocation). +// Fairness: Strict FIFO ordering guaranteed by Go runtime. +type FIFOSemaphore struct { + tokens chan struct{} + max int32 +} + +// NewFIFOSemaphore creates a new FIFO semaphore with the given capacity. +func NewFIFOSemaphore(capacity int32) *FIFOSemaphore { + ch := make(chan struct{}, capacity) + // Pre-fill with tokens + for i := int32(0); i < capacity; i++ { + ch <- struct{}{} + } + return &FIFOSemaphore{ + tokens: ch, + max: capacity, + } +} + +// TryAcquire attempts to acquire a token without blocking. +// Returns true if successful, false if no tokens available. +func (s *FIFOSemaphore) TryAcquire() bool { + select { + case <-s.tokens: + return true + default: + return false + } +} + +// Acquire acquires a token, blocking if necessary until one is available. +// Returns an error if the context is cancelled or the timeout expires. +// Always uses timer to guarantee FIFO ordering (no fast path). +func (s *FIFOSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error { + // No fast path - always use timer to guarantee FIFO + timer := semTimers.Get().(*time.Timer) + defer semTimers.Put(timer) + timer.Reset(timeout) + + select { + case <-s.tokens: + if !timer.Stop() { + <-timer.C + } + return nil + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return ctx.Err() + case <-timer.C: + return timeoutErr + } +} + +// AcquireBlocking acquires a token, blocking indefinitely until one is available. +func (s *FIFOSemaphore) AcquireBlocking() { + <-s.tokens +} + +// Release releases a token back to the semaphore. +func (s *FIFOSemaphore) Release() { + s.tokens <- struct{}{} +} + +// Close closes the semaphore, unblocking all waiting goroutines. +// After close, all Acquire calls will receive a closed channel signal. +func (s *FIFOSemaphore) Close() { + close(s.tokens) +} + +// Len returns the current number of acquired tokens. +func (s *FIFOSemaphore) Len() int32 { + return s.max - int32(len(s.tokens)) +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/util.go b/vendor/github.com/redis/go-redis/v9/internal/util.go index cc1bff24e..00516075b 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/util.go +++ b/vendor/github.com/redis/go-redis/v9/internal/util.go @@ -2,6 +2,7 @@ package internal import ( "context" + "math" "net" "strconv" "strings" @@ -10,6 +11,29 @@ import ( "github.com/redis/go-redis/v9/internal/util" ) +// String representations of special float values. +// Values are lowercase for consistency with Redis RESP2 protocol responses. +const ( + NaN = "nan" // Not a Number + Inf = "inf" // Positive infinity + NInf = "-inf" // Negative infinity +) + +// FormatFloat formats a float64 to string, normalizing special values +// (NaN, Inf) to lowercase for consistency with Redis RESP2 protocol. +func FormatFloat(f float64) string { + switch { + case math.IsNaN(f): + return NaN + case math.IsInf(f, 1): + return Inf + case math.IsInf(f, -1): + return NInf + default: + return strconv.FormatFloat(f, 'f', -1, 64) + } +} + func Sleep(ctx context.Context, dur time.Duration) error { t := time.NewTimer(dur) defer t.Stop() @@ -49,22 +73,7 @@ func isLower(s string) bool { } func ReplaceSpaces(s string) string { - // Pre-allocate a builder with the same length as s to minimize allocations. - // This is a basic optimization; adjust the initial size based on your use case. - var builder strings.Builder - builder.Grow(len(s)) - - for _, char := range s { - if char == ' ' { - // Replace space with a hyphen. - builder.WriteRune('-') - } else { - // Copy the character as-is. - builder.WriteRune(char) - } - } - - return builder.String() + return strings.ReplaceAll(s, " ", "-") } func GetAddr(addr string) string { diff --git a/vendor/github.com/redis/go-redis/v9/internal/util/atomic_max.go b/vendor/github.com/redis/go-redis/v9/internal/util/atomic_max.go new file mode 100644 index 000000000..6c621ba85 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/util/atomic_max.go @@ -0,0 +1,97 @@ +/* +© 2023–present Harald Rudell (https://haraldrudell.github.io/haraldrudell/) +ISC License + +Modified by htemelski-redis +Removed the treshold, adapted it to work with float64 +*/ + +package util + +import ( + "math" + + "go.uber.org/atomic" +) + +// AtomicMax is a thread-safe max container +// - hasValue indicator true if a value was equal to or greater than threshold +// - optional threshold for minimum accepted max value +// - if threshold is not used, initialization-free +// - — +// - wait-free CompareAndSwap mechanic +type AtomicMax struct { + + // value is current max + value atomic.Float64 + // whether [AtomicMax.Value] has been invoked + // with value equal or greater to threshold + hasValue atomic.Bool +} + +// NewAtomicMax returns a thread-safe max container +// - if threshold is not used, AtomicMax is initialization-free +func NewAtomicMax() (atomicMax *AtomicMax) { + m := AtomicMax{} + m.value.Store((-math.MaxFloat64)) + return &m +} + +// Value updates the container with a possible max value +// - isNewMax is true if: +// - — value is equal to or greater than any threshold and +// - — invocation recorded the first 0 or +// - — a new max +// - upon return, Max and Max1 are guaranteed to reflect the invocation +// - the return order of concurrent Value invocations is not guaranteed +// - Thread-safe +func (m *AtomicMax) Value(value float64) (isNewMax bool) { + // -math.MaxFloat64 as max case + var hasValue0 = m.hasValue.Load() + if value == (-math.MaxFloat64) { + if !hasValue0 { + isNewMax = m.hasValue.CompareAndSwap(false, true) + } + return // -math.MaxFloat64 as max: isNewMax true for first 0 writer + } + + // check against present value + var current = m.value.Load() + if isNewMax = value > current; !isNewMax { + return // not a new max return: isNewMax false + } + + // store the new max + for { + + // try to write value to *max + if isNewMax = m.value.CompareAndSwap(current, value); isNewMax { + if !hasValue0 { + // may be rarely written multiple times + // still faster than CompareAndSwap + m.hasValue.Store(true) + } + return // new max written return: isNewMax true + } + if current = m.value.Load(); current >= value { + return // no longer a need to write return: isNewMax false + } + } +} + +// Max returns current max and value-present flag +// - hasValue true indicates that value reflects a Value invocation +// - hasValue false: value is zero-value +// - Thread-safe +func (m *AtomicMax) Max() (value float64, hasValue bool) { + if hasValue = m.hasValue.Load(); !hasValue { + return + } + value = m.value.Load() + return +} + +// Max1 returns current maximum whether zero-value or set by Value +// - threshold is ignored +// - Thread-safe +func (m *AtomicMax) Max1() (value float64) { return m.value.Load() } diff --git a/vendor/github.com/redis/go-redis/v9/internal/util/atomic_min.go b/vendor/github.com/redis/go-redis/v9/internal/util/atomic_min.go new file mode 100644 index 000000000..e33d29cc2 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/util/atomic_min.go @@ -0,0 +1,96 @@ +package util + +/* +© 2023–present Harald Rudell (https://haraldrudell.github.io/haraldrudell/) +ISC License + +Modified by htemelski-redis +Adapted from the modified atomic_max, but with inverted logic +*/ + +import ( + "math" + + "go.uber.org/atomic" +) + +// AtomicMin is a thread-safe Min container +// - hasValue indicator true if a value was equal to or greater than threshold +// - optional threshold for minimum accepted Min value +// - — +// - wait-free CompareAndSwap mechanic +type AtomicMin struct { + + // value is current Min + value atomic.Float64 + // whether [AtomicMin.Value] has been invoked + // with value equal or greater to threshold + hasValue atomic.Bool +} + +// NewAtomicMin returns a thread-safe Min container +// - if threshold is not used, AtomicMin is initialization-free +func NewAtomicMin() (atomicMin *AtomicMin) { + m := AtomicMin{} + m.value.Store(math.MaxFloat64) + return &m +} + +// Value updates the container with a possible Min value +// - isNewMin is true if: +// - — value is equal to or greater than any threshold and +// - — invocation recorded the first 0 or +// - — a new Min +// - upon return, Min and Min1 are guaranteed to reflect the invocation +// - the return order of concurrent Value invocations is not guaranteed +// - Thread-safe +func (m *AtomicMin) Value(value float64) (isNewMin bool) { + // math.MaxFloat64 as Min case + var hasValue0 = m.hasValue.Load() + if value == math.MaxFloat64 { + if !hasValue0 { + isNewMin = m.hasValue.CompareAndSwap(false, true) + } + return // math.MaxFloat64 as Min: isNewMin true for first 0 writer + } + + // check against present value + var current = m.value.Load() + if isNewMin = value < current; !isNewMin { + return // not a new Min return: isNewMin false + } + + // store the new Min + for { + + // try to write value to *Min + if isNewMin = m.value.CompareAndSwap(current, value); isNewMin { + if !hasValue0 { + // may be rarely written multiple times + // still faster than CompareAndSwap + m.hasValue.Store(true) + } + return // new Min written return: isNewMin true + } + if current = m.value.Load(); current <= value { + return // no longer a need to write return: isNewMin false + } + } +} + +// Min returns current min and value-present flag +// - hasValue true indicates that value reflects a Value invocation +// - hasValue false: value is zero-value +// - Thread-safe +func (m *AtomicMin) Min() (value float64, hasValue bool) { + if hasValue = m.hasValue.Load(); !hasValue { + return + } + value = m.value.Load() + return +} + +// Min1 returns current Minimum whether zero-value or set by Value +// - threshold is ignored +// - Thread-safe +func (m *AtomicMin) Min1() (value float64) { return m.value.Load() } diff --git a/vendor/github.com/redis/go-redis/v9/internal/util/convert.go b/vendor/github.com/redis/go-redis/v9/internal/util/convert.go new file mode 100644 index 000000000..b743a4f0e --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/internal/util/convert.go @@ -0,0 +1,41 @@ +package util + +import ( + "fmt" + "math" + "strconv" +) + +// ParseFloat parses a Redis RESP3 float reply into a Go float64, +// handling "inf", "-inf", "nan" per Redis conventions. +func ParseStringToFloat(s string) (float64, error) { + switch s { + case "inf": + return math.Inf(1), nil + case "-inf": + return math.Inf(-1), nil + case "nan", "-nan": + return math.NaN(), nil + } + return strconv.ParseFloat(s, 64) +} + +// MustParseFloat is like ParseFloat but panics on parse errors. +func MustParseFloat(s string) float64 { + f, err := ParseStringToFloat(s) + if err != nil { + panic(fmt.Sprintf("redis: failed to parse float %q: %v", s, err)) + } + return f +} + +// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur. +func SafeIntToInt32(value int, fieldName string) (int32, error) { + if value > math.MaxInt32 { + return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32) + } + if value < math.MinInt32 { + return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32) + } + return int32(value), nil +} diff --git a/vendor/github.com/redis/go-redis/v9/internal/util/unsafe.go b/vendor/github.com/redis/go-redis/v9/internal/util/unsafe.go index cbcd2cc09..f4c3c3f33 100644 --- a/vendor/github.com/redis/go-redis/v9/internal/util/unsafe.go +++ b/vendor/github.com/redis/go-redis/v9/internal/util/unsafe.go @@ -8,15 +8,10 @@ import ( // BytesToString converts byte slice to string. func BytesToString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) + return unsafe.String(unsafe.SliceData(b), len(b)) } // StringToBytes converts string to byte slice. func StringToBytes(s string) []byte { - return *(*[]byte)(unsafe.Pointer( - &struct { - string - Cap int - }{s, len(s)}, - )) + return unsafe.Slice(unsafe.StringData(s), len(s)) } diff --git a/vendor/github.com/redis/go-redis/v9/json.go b/vendor/github.com/redis/go-redis/v9/json.go index b3cadf4b7..2bcad0b79 100644 --- a/vendor/github.com/redis/go-redis/v9/json.go +++ b/vendor/github.com/redis/go-redis/v9/json.go @@ -35,6 +35,7 @@ type JSONCmdable interface { JSONObjLen(ctx context.Context, key, path string) *IntPointerSliceCmd JSONSet(ctx context.Context, key, path string, value interface{}) *StatusCmd JSONSetMode(ctx context.Context, key, path string, value interface{}, mode string) *StatusCmd + JSONSetWithArgs(ctx context.Context, key, path string, value interface{}, options *JSONSetArgsOptions) *StatusCmd JSONStrAppend(ctx context.Context, key, path, value string) *IntPointerSliceCmd JSONStrLen(ctx context.Context, key, path string) *IntPointerSliceCmd JSONToggle(ctx context.Context, key, path string) *IntPointerSliceCmd @@ -57,6 +58,25 @@ type JSONArrTrimArgs struct { Stop *int } +// FPHAType is the floating-point type used for storing FP homogeneous arrays +// in JSON.SET (Redis 8.8+). +type FPHAType string + +const ( + FPHATypeBF16 FPHAType = "BF16" + FPHATypeFP16 FPHAType = "FP16" + FPHATypeFP32 FPHAType = "FP32" + FPHATypeFP64 FPHAType = "FP64" +) + +// JSONSetArgsOptions are the optional arguments for JSONSetWithArgs. +// Mode is "NX" or "XX" (case-insensitive). FPHA, when set, forces Redis to +// store all FP homogeneous arrays using the specified floating-point type. +type JSONSetArgsOptions struct { + Mode string + FPHA FPHAType +} + type JSONCmd struct { baseCmd val string @@ -68,8 +88,9 @@ var _ Cmder = (*JSONCmd)(nil) func newJSONCmd(ctx context.Context, args ...interface{}) *JSONCmd { return &JSONCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSON, }, } } @@ -82,6 +103,7 @@ func (cmd *JSONCmd) SetVal(val string) { cmd.val = val } +// Val returns the result of the JSON.GET command as a string. func (cmd *JSONCmd) Val() string { if len(cmd.val) == 0 && cmd.expanded != nil { val, err := json.Marshal(cmd.expanded) @@ -100,6 +122,7 @@ func (cmd *JSONCmd) Result() (string, error) { return cmd.Val(), cmd.Err() } +// Expanded returns the result of the JSON.GET command as unmarshalled JSON. func (cmd *JSONCmd) Expanded() (interface{}, error) { if len(cmd.val) != 0 && cmd.expanded == nil { err := json.Unmarshal([]byte(cmd.val), &cmd.expanded) @@ -113,11 +136,17 @@ func (cmd *JSONCmd) Expanded() (interface{}, error) { func (cmd *JSONCmd) readReply(rd *proto.Reader) error { // nil response from JSON.(M)GET (cmd.baseCmd.err will be "redis: nil") + // This happens when the key doesn't exist if cmd.baseCmd.Err() == Nil { cmd.val = "" return Nil } + // Handle other base command errors + if cmd.baseCmd.Err() != nil { + return cmd.baseCmd.Err() + } + if readType, err := rd.PeekReplyType(); err != nil { return err } else if readType == proto.RespArray { @@ -127,6 +156,13 @@ func (cmd *JSONCmd) readReply(rd *proto.Reader) error { return err } + // Empty array means no results found for JSON path, but key exists + // This should return "[]", not an error + if size == 0 { + cmd.val = "[]" + return nil + } + expanded := make([]interface{}, size) for i := 0; i < size; i++ { @@ -141,6 +177,7 @@ func (cmd *JSONCmd) readReply(rd *proto.Reader) error { return err } else if str == "" || err == Nil { cmd.val = "" + return Nil } else { cmd.val = str } @@ -149,6 +186,14 @@ func (cmd *JSONCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONCmd) Clone() Cmder { + return &JSONCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, + expanded: cmd.expanded, // interface{} can be shared as it should be immutable after parsing + } +} + // ------------------------------------------- type JSONSliceCmd struct { @@ -159,8 +204,9 @@ type JSONSliceCmd struct { func NewJSONSliceCmd(ctx context.Context, args ...interface{}) *JSONSliceCmd { return &JSONSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeJSONSlice, }, } } @@ -217,6 +263,18 @@ func (cmd *JSONSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *JSONSliceCmd) Clone() Cmder { + var val []interface{} + if cmd.val != nil { + val = make([]interface{}, len(cmd.val)) + copy(val, cmd.val) + } + return &JSONSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + /******************************************************************************* * * IntPointerSliceCmd @@ -233,8 +291,9 @@ type IntPointerSliceCmd struct { func NewIntPointerSliceCmd(ctx context.Context, args ...interface{}) *IntPointerSliceCmd { return &IntPointerSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeIntPointerSlice, }, } } @@ -274,6 +333,18 @@ func (cmd *IntPointerSliceCmd) readReply(rd *proto.Reader) error { return nil } +func (cmd *IntPointerSliceCmd) Clone() Cmder { + var val []*int64 + if cmd.val != nil { + val = make([]*int64, len(cmd.val)) + copy(val, cmd.val) + } + return &IntPointerSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + //------------------------------------------------------------------------------ // JSONArrAppend adds the provided JSON values to the end of the array at the given path. @@ -533,6 +604,15 @@ func (c cmdable) JSONSet(ctx context.Context, key, path string, value interface{ // the argument is a string or []byte when we assume that it can be passed directly as JSON. // For more information, see https://redis.io/commands/json.set func (c cmdable) JSONSetMode(ctx context.Context, key, path string, value interface{}, mode string) *StatusCmd { + return c.JSONSetWithArgs(ctx, key, path, value, &JSONSetArgsOptions{Mode: mode}) +} + +// JSONSetWithArgs sets the JSON value at the given path in the given key with optional arguments +// for setting mode (NX/XX) and the FPHA (Floating-Point Homogeneous Array) type used for storing +// FP arrays. The value must be something that can be marshaled to JSON (using encoding/JSON) unless +// the argument is a string or []byte when we assume that it can be passed directly as JSON. +// For more information, see https://redis.io/commands/json.set +func (c cmdable) JSONSetWithArgs(ctx context.Context, key, path string, value interface{}, options *JSONSetArgsOptions) *StatusCmd { var bytes []byte var err error switch v := value.(type) { @@ -544,13 +624,17 @@ func (c cmdable) JSONSetMode(ctx context.Context, key, path string, value interf bytes, err = json.Marshal(v) } args := []interface{}{"JSON.SET", key, path, util.BytesToString(bytes)} - if mode != "" { - switch strings.ToUpper(mode) { - case "XX", "NX": - args = append(args, strings.ToUpper(mode)) - - default: - panic("redis: JSON.SET mode must be NX or XX") + if options != nil { + if options.Mode != "" { + switch strings.ToUpper(options.Mode) { + case "XX", "NX": + args = append(args, strings.ToUpper(options.Mode)) + default: + panic("redis: JSON.SET mode must be NX or XX") + } + } + if options.FPHA != "" { + args = append(args, "FPHA", string(options.FPHA)) } } cmd := NewStatusCmd(ctx, args...) diff --git a/vendor/github.com/redis/go-redis/v9/list_commands.go b/vendor/github.com/redis/go-redis/v9/list_commands.go index 24a0de081..9d9e16c65 100644 --- a/vendor/github.com/redis/go-redis/v9/list_commands.go +++ b/vendor/github.com/redis/go-redis/v9/list_commands.go @@ -77,6 +77,10 @@ func (c cmdable) BRPop(ctx context.Context, timeout time.Duration, keys ...strin return cmd } +// BRPopLPush pops an element from a list, pushes it to another list and returns it. +// Blocks until an element is available or timeout is reached. +// +// Deprecated: Use BLMove with RIGHT and LEFT arguments instead as of Redis 6.2.0. func (c cmdable) BRPopLPush(ctx context.Context, source, destination string, timeout time.Duration) *StringCmd { cmd := NewStringCmd( ctx, @@ -247,6 +251,10 @@ func (c cmdable) RPopCount(ctx context.Context, key string, count int) *StringSl return cmd } +// RPopLPush atomically returns and removes the last element of the source list, +// and pushes the element as the first element of the destination list. +// +// Deprecated: Use LMove with RIGHT and LEFT arguments instead as of Redis 6.2.0. func (c cmdable) RPopLPush(ctx context.Context, source, destination string) *StringCmd { cmd := NewStringCmd(ctx, "rpoplpush", source, destination) _ = c(ctx, cmd) diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/FEATURES.md b/vendor/github.com/redis/go-redis/v9/maintnotifications/FEATURES.md new file mode 100644 index 000000000..03bbd3918 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/FEATURES.md @@ -0,0 +1,235 @@ +# Maintenance Notifications - FEATURES + +## Overview + +The Maintenance Notifications feature enables seamless Redis connection handoffs during cluster maintenance operations without dropping active connections. This feature leverages Redis RESP3 push notifications to provide zero-downtime maintenance for Redis Enterprise and compatible Redis deployments. + +## Important + +Using Maintenance Notifications may affect the read and write timeouts by relaxing them during maintenance operations. +This is necessary to prevent false failures due to increased latency during handoffs. The relaxed timeouts are automatically applied and removed as needed. + +## Key Features + +### Seamless Connection Handoffs +- **Zero-Downtime Maintenance**: Automatically handles connection transitions during cluster operations +- **Active Operation Preservation**: Transfers in-flight operations to new connections without interruption +- **Graceful Degradation**: Falls back to standard reconnection if handoff fails + +### Push Notification Support +Supports all Redis Enterprise maintenance notification types: +- **MOVING** - Slot moving to a new node +- **MIGRATING** - Slot in migration state +- **MIGRATED** - Migration completed +- **FAILING_OVER** - Node failing over +- **FAILED_OVER** - Failover completed + +### Circuit Breaker Pattern +- **Endpoint-Specific Failure Tracking**: Prevents repeated connection attempts to failing endpoints +- **Automatic Recovery Testing**: Half-open state allows gradual recovery validation +- **Configurable Thresholds**: Customize failure thresholds and reset timeouts + +### Flexible Configuration +- **Auto-Detection Mode**: Automatically detects server support for maintenance notifications +- **Multiple Endpoint Types**: Support for internal/external IP/FQDN endpoint resolution +- **Auto-Scaling Workers**: Automatically sizes worker pool based on connection pool size +- **Timeout Management**: Separate timeouts for relaxed (during maintenance) and normal operations + +### Extensible Hook System +- **Pre/Post Processing Hooks**: Monitor and customize notification handling +- **Built-in Hooks**: Logging and metrics collection hooks included +- **Custom Hook Support**: Implement custom business logic around maintenance events + +### Comprehensive Monitoring +- **Metrics Collection**: Track notification counts, processing times, and error rates +- **Circuit Breaker Stats**: Monitor endpoint health and circuit breaker states +- **Operation Tracking**: Track active handoff operations and their lifecycle + +## Architecture Highlights + +### Event-Driven Handoff System +- **Asynchronous Processing**: Non-blocking handoff operations using worker pool pattern +- **Queue-Based Architecture**: Configurable queue size with auto-scaling support +- **Retry Mechanism**: Configurable retry attempts with exponential backoff + +### Connection Pool Integration +- **Pool Hook Interface**: Seamless integration with go-redis connection pool +- **Connection State Management**: Atomic flags for connection usability tracking +- **Graceful Shutdown**: Ensures all in-flight handoffs complete before shutdown + +### Thread-Safe Design +- **Lock-Free Operations**: Atomic operations for high-performance state tracking +- **Concurrent-Safe Maps**: sync.Map for tracking active operations +- **Minimal Lock Contention**: Read-write locks only where necessary + +## Configuration Options + +### Operation Modes +- **`ModeDisabled`**: Maintenance notifications completely disabled +- **`ModeEnabled`**: Forcefully enabled (fails if server doesn't support) +- **`ModeAuto`**: Auto-detect server support (recommended default) + +### Endpoint Types +- **`EndpointTypeAuto`**: Auto-detect based on current connection +- **`EndpointTypeInternalIP`**: Use internal IP addresses +- **`EndpointTypeInternalFQDN`**: Use internal fully qualified domain names +- **`EndpointTypeExternalIP`**: Use external IP addresses +- **`EndpointTypeExternalFQDN`**: Use external fully qualified domain names +- **`EndpointTypeNone`**: No endpoint (reconnect with current configuration) + +### Timeout Configuration +- **`RelaxedTimeout`**: Extended timeout during maintenance operations (default: 10s) +- **`HandoffTimeout`**: Maximum time for handoff completion (default: 15s) +- **`PostHandoffRelaxedDuration`**: Relaxed period after handoff (default: 2×RelaxedTimeout) + +### Worker Pool Configuration +- **`MaxWorkers`**: Maximum concurrent handoff workers (auto-calculated if 0) +- **`HandoffQueueSize`**: Handoff queue capacity (auto-calculated if 0) +- **`MaxHandoffRetries`**: Maximum retry attempts for failed handoffs (default: 3) + +### Circuit Breaker Configuration +- **`CircuitBreakerFailureThreshold`**: Failures before opening circuit (default: 5) +- **`CircuitBreakerResetTimeout`**: Time before testing recovery (default: 60s) +- **`CircuitBreakerMaxRequests`**: Max requests in half-open state (default: 3) + +## Auto-Scaling Formulas + +### Worker Pool Sizing +When `MaxWorkers = 0` (auto-calculate): +``` +MaxWorkers = min(PoolSize/2, max(10, PoolSize/3)) +``` + +### Queue Sizing +When `HandoffQueueSize = 0` (auto-calculate): +``` +QueueSize = max(20 × MaxWorkers, PoolSize) +Capped by: min(MaxActiveConns + 1, 5 × PoolSize) +``` + +### Examples +- **Pool Size 100**: 33 workers, 660 queue (capped at 500) +- **Pool Size 100 + MaxActiveConns 150**: 33 workers, 151 queue +- **Pool Size 50**: 16 workers, 320 queue (capped at 250) + +## Performance Characteristics + +### Throughput +- **Non-Blocking Handoffs**: Client operations continue during handoffs +- **Concurrent Processing**: Multiple handoffs processed in parallel +- **Minimal Overhead**: Lock-free atomic operations for state tracking + +### Latency +- **Relaxed Timeouts**: Extended timeouts during maintenance prevent false failures +- **Fast Path**: Connections not undergoing handoff have zero overhead +- **Graceful Degradation**: Failed handoffs fall back to standard reconnection + +### Resource Usage +- **Memory Efficient**: Bounded queue sizes prevent memory exhaustion +- **Worker Pool**: Fixed worker count prevents goroutine explosion +- **Connection Reuse**: Handoff reuses existing connection objects + +## Testing + +### Unit Tests +- Comprehensive unit test coverage for all components +- Mock-based testing for isolation +- Concurrent operation testing + +### Integration Tests +- Pool integration tests with real connection handoffs +- Circuit breaker behavior validation +- Hook system integration testing + +### E2E Tests +- Real Redis Enterprise cluster testing +- Multiple scenario coverage (timeouts, endpoint types, stress tests) +- Fault injection testing +- TLS configuration testing + +## Compatibility + +### Requirements +- **Redis Protocol**: RESP3 required for push notifications +- **Redis Version**: Redis Enterprise or compatible Redis with maintenance notifications +- **Go Version**: Go 1.18+ (uses generics and atomic types) + +### Client Support +#### Currently Supported +- **Standalone Client** (`redis.NewClient`) - Full support for MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER notifications +- **Cluster Client** (`redis.NewClusterClient`) - Support for SMIGRATING and SMIGRATED notifications for hitless slot migrations + +#### Will Not Support +- **Failover Client** (no planned support) +- **Ring Client** (no planned support) + +## Migration Guide + +### Enabling Maintenance Notifications (Standalone Client) + +**Before:** +```go +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 2, // RESP2 +}) +``` + +**After:** +```go +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // RESP3 required + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeAuto, + }, +}) +``` + +### Enabling Hitless Upgrades (Cluster Client) + +For Redis Cluster with hitless slot migration support: + +```go +client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{"localhost:7000", "localhost:7001", "localhost:7002"}, + Protocol: 3, // RESP3 required for push notifications + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeAuto, + RelaxedTimeout: 10 * time.Second, // Extended timeout during slot migrations + }, +}) +``` + +The cluster client automatically handles: +- **SMIGRATING**: Relaxes timeouts when slots are being migrated +- **SMIGRATED**: Triggers lazy cluster state reload when migration completes +- **SeqID Deduplication**: Same notification from multiple nodes triggers only one reload + +### Adding Monitoring + +```go +// Get the manager from the client +manager := client.GetMaintNotificationsManager() +if manager != nil { + // Add logging hook + loggingHook := maintnotifications.NewLoggingHook(2) // Info level + manager.AddNotificationHook(loggingHook) + + // Add metrics hook + metricsHook := maintnotifications.NewMetricsHook() + manager.AddNotificationHook(metricsHook) +} +``` + +## Known Limitations + +1. **RESP3 Required**: Push notifications require RESP3 protocol +2. **Server Support**: Requires Redis Enterprise or compatible Redis with maintenance notifications +3. **Single Connection Commands**: Some commands (MULTI/EXEC, WATCH) may need special handling +4. **No Failover/Ring Client Support**: Failover and Ring clients are not supported and there are no plans to add support + +## Future Enhancements + +- Enhanced metrics and observability +- TTL-based cleanup for SeqID deduplication map \ No newline at end of file diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/README.md b/vendor/github.com/redis/go-redis/v9/maintnotifications/README.md new file mode 100644 index 000000000..2f354ef6a --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/README.md @@ -0,0 +1,73 @@ +# Maintenance Notifications + +Seamless Redis connection handoffs during cluster maintenance operations without dropping connections. + +## Cluster Support + +**Cluster notifications are now supported for ClusterClient!** + +- **SMIGRATING**: `["SMIGRATING", SeqID, slot/range, ...]` - Relaxes timeouts when slots are being migrated +- **SMIGRATED**: `["SMIGRATED", SeqID, src host:port, dst host:port, slot/range, ...]` - Reloads cluster state when slot migration completes + +**Note:** Other maintenance notifications (MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER) are supported only in standalone Redis clients. Cluster clients support SMIGRATING and SMIGRATED for cluster-specific slot migration handling. + +## Quick Start + +```go +client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // RESP3 required + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeEnabled, + }, +}) +``` + +## Modes + +- **`ModeDisabled`** - Maintenance notifications disabled +- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support) +- **`ModeAuto`** - Auto-detect server support (default) + +## Configuration + +```go +&maintnotifications.Config{ + Mode: maintnotifications.ModeAuto, + EndpointType: maintnotifications.EndpointTypeAuto, + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxHandoffRetries: 3, + MaxWorkers: 0, // Auto-calculated + HandoffQueueSize: 0, // Auto-calculated + PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout +} +``` + +### Endpoint Types + +- **`EndpointTypeAuto`** - Auto-detect based on connection (default) +- **`EndpointTypeInternalIP`** - Internal IP address +- **`EndpointTypeInternalFQDN`** - Internal FQDN +- **`EndpointTypeExternalIP`** - External IP address +- **`EndpointTypeExternalFQDN`** - External FQDN +- **`EndpointTypeNone`** - No endpoint (reconnect with current config) + +### Auto-Scaling + +**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated +**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize` + +**Examples:** +- Pool 100: 33 workers, 660 queue (capped at 500) +- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue + +## How It Works + +1. Redis sends push notifications about cluster maintenance operations +2. Client creates new connections to updated endpoints +3. Active operations transfer to new connections +4. Old connections close gracefully + + +## For more information, see [FEATURES](FEATURES.md) diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/circuit_breaker.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/circuit_breaker.go new file mode 100644 index 000000000..cb76b6447 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/circuit_breaker.go @@ -0,0 +1,353 @@ +package maintnotifications + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" +) + +// CircuitBreakerState represents the state of a circuit breaker +type CircuitBreakerState int32 + +const ( + // CircuitBreakerClosed - normal operation, requests allowed + CircuitBreakerClosed CircuitBreakerState = iota + // CircuitBreakerOpen - failing fast, requests rejected + CircuitBreakerOpen + // CircuitBreakerHalfOpen - testing if service recovered + CircuitBreakerHalfOpen +) + +func (s CircuitBreakerState) String() string { + switch s { + case CircuitBreakerClosed: + return "closed" + case CircuitBreakerOpen: + return "open" + case CircuitBreakerHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling +type CircuitBreaker struct { + // Configuration + failureThreshold int // Number of failures before opening + resetTimeout time.Duration // How long to stay open before testing + maxRequests int // Max requests allowed in half-open state + + // State tracking (atomic for lock-free access) + state atomic.Int32 // CircuitBreakerState + failures atomic.Int64 // Current failure count + successes atomic.Int64 // Success count in half-open state + requests atomic.Int64 // Request count in half-open state + lastFailureTime atomic.Int64 // Unix timestamp of last failure + lastSuccessTime atomic.Int64 // Unix timestamp of last success + + // Endpoint identification + endpoint string + config *Config +} + +// newCircuitBreaker creates a new circuit breaker for an endpoint +func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker { + // Use configuration values with sensible defaults + failureThreshold := 5 + resetTimeout := 60 * time.Second + maxRequests := 3 + + if config != nil { + failureThreshold = config.CircuitBreakerFailureThreshold + resetTimeout = config.CircuitBreakerResetTimeout + maxRequests = config.CircuitBreakerMaxRequests + } + + return &CircuitBreaker{ + failureThreshold: failureThreshold, + resetTimeout: resetTimeout, + maxRequests: maxRequests, + endpoint: endpoint, + config: config, + state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0) + } +} + +// IsOpen returns true if the circuit breaker is open (rejecting requests) +func (cb *CircuitBreaker) IsOpen() bool { + state := CircuitBreakerState(cb.state.Load()) + return state == CircuitBreakerOpen +} + +// shouldAttemptReset checks if enough time has passed to attempt reset +func (cb *CircuitBreaker) shouldAttemptReset() bool { + lastFailure := time.Unix(cb.lastFailureTime.Load(), 0) + return time.Since(lastFailure) >= cb.resetTimeout +} + +// Execute runs the given function with circuit breaker protection +func (cb *CircuitBreaker) Execute(fn func() error) error { + // Single atomic state load for consistency + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerOpen: + if cb.shouldAttemptReset() { + // Attempt transition to half-open + if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) { + cb.requests.Store(0) + cb.successes.Store(0) + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) + } + // Fall through to half-open logic + } else { + return ErrCircuitBreakerOpen + } + } else { + return ErrCircuitBreakerOpen + } + fallthrough + case CircuitBreakerHalfOpen: + requests := cb.requests.Add(1) + if requests > int64(cb.maxRequests) { + cb.requests.Add(-1) // Revert the increment + return ErrCircuitBreakerOpen + } + } + + // Execute the function with consistent state + err := fn() + + if err != nil { + cb.recordFailure() + return err + } + + cb.recordSuccess() + return nil +} + +// recordFailure records a failure and potentially opens the circuit +func (cb *CircuitBreaker) recordFailure() { + cb.lastFailureTime.Store(time.Now().Unix()) + failures := cb.failures.Add(1) + + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerClosed: + if failures >= int64(cb.failureThreshold) { + if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) + } + } + } + case CircuitBreakerHalfOpen: + // Any failure in half-open state immediately opens the circuit + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) + } + } + } +} + +// recordSuccess records a success and potentially closes the circuit +func (cb *CircuitBreaker) recordSuccess() { + cb.lastSuccessTime.Store(time.Now().Unix()) + + state := CircuitBreakerState(cb.state.Load()) + + switch state { + case CircuitBreakerClosed: + // Reset failure count on success in closed state + cb.failures.Store(0) + case CircuitBreakerHalfOpen: + successes := cb.successes.Add(1) + + // If we've had enough successful requests, close the circuit + if successes >= int64(cb.maxRequests) { + if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { + cb.failures.Store(0) + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) + } + } + } + } +} + +// GetState returns the current state of the circuit breaker +func (cb *CircuitBreaker) GetState() CircuitBreakerState { + return CircuitBreakerState(cb.state.Load()) +} + +// GetStats returns current statistics for monitoring +func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { + return CircuitBreakerStats{ + Endpoint: cb.endpoint, + State: cb.GetState(), + Failures: cb.failures.Load(), + Successes: cb.successes.Load(), + Requests: cb.requests.Load(), + LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0), + LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0), + } +} + +// CircuitBreakerStats provides statistics about a circuit breaker +type CircuitBreakerStats struct { + Endpoint string + State CircuitBreakerState + Failures int64 + Successes int64 + Requests int64 + LastFailureTime time.Time + LastSuccessTime time.Time +} + +// CircuitBreakerEntry wraps a circuit breaker with access tracking +type CircuitBreakerEntry struct { + breaker *CircuitBreaker + lastAccess atomic.Int64 // Unix timestamp + created time.Time +} + +// CircuitBreakerManager manages circuit breakers for multiple endpoints +type CircuitBreakerManager struct { + breakers sync.Map // map[string]*CircuitBreakerEntry + config *Config + cleanupStop chan struct{} + cleanupMu sync.Mutex + lastCleanup atomic.Int64 // Unix timestamp +} + +// newCircuitBreakerManager creates a new circuit breaker manager +func newCircuitBreakerManager(config *Config) *CircuitBreakerManager { + cbm := &CircuitBreakerManager{ + config: config, + cleanupStop: make(chan struct{}), + } + cbm.lastCleanup.Store(time.Now().Unix()) + + // Start background cleanup goroutine + go cbm.cleanupLoop() + + return cbm +} + +// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary +func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker { + now := time.Now().Unix() + + if entry, ok := cbm.breakers.Load(endpoint); ok { + cbEntry := entry.(*CircuitBreakerEntry) + cbEntry.lastAccess.Store(now) + return cbEntry.breaker + } + + // Create new circuit breaker with metadata + newBreaker := newCircuitBreaker(endpoint, cbm.config) + newEntry := &CircuitBreakerEntry{ + breaker: newBreaker, + created: time.Now(), + } + newEntry.lastAccess.Store(now) + + actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry) + return actual.(*CircuitBreakerEntry).breaker +} + +// GetAllStats returns statistics for all circuit breakers +func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats { + var stats []CircuitBreakerStats + cbm.breakers.Range(func(key, value interface{}) bool { + entry := value.(*CircuitBreakerEntry) + stats = append(stats, entry.breaker.GetStats()) + return true + }) + return stats +} + +// cleanupLoop runs background cleanup of unused circuit breakers +func (cbm *CircuitBreakerManager) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes + defer ticker.Stop() + + for { + select { + case <-ticker.C: + cbm.cleanup() + case <-cbm.cleanupStop: + return + } + } +} + +// cleanup removes circuit breakers that haven't been accessed recently +func (cbm *CircuitBreakerManager) cleanup() { + // Prevent concurrent cleanups + if !cbm.cleanupMu.TryLock() { + return + } + defer cbm.cleanupMu.Unlock() + + now := time.Now() + cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL + + var toDelete []string + count := 0 + + cbm.breakers.Range(func(key, value interface{}) bool { + endpoint := key.(string) + entry := value.(*CircuitBreakerEntry) + + count++ + + // Remove if not accessed recently + if entry.lastAccess.Load() < cutoff { + toDelete = append(toDelete, endpoint) + } + + return true + }) + + // Delete expired entries + for _, endpoint := range toDelete { + cbm.breakers.Delete(endpoint) + } + + // Log cleanup results + if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) + } + + cbm.lastCleanup.Store(now.Unix()) +} + +// Shutdown stops the cleanup goroutine +func (cbm *CircuitBreakerManager) Shutdown() { + close(cbm.cleanupStop) +} + +// Reset resets all circuit breakers (useful for testing) +func (cbm *CircuitBreakerManager) Reset() { + cbm.breakers.Range(func(key, value interface{}) bool { + entry := value.(*CircuitBreakerEntry) + breaker := entry.breaker + breaker.state.Store(int32(CircuitBreakerClosed)) + breaker.failures.Store(0) + breaker.successes.Store(0) + breaker.requests.Store(0) + breaker.lastFailureTime.Store(0) + breaker.lastSuccessTime.Store(0) + return true + }) +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/config.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/config.go new file mode 100644 index 000000000..70d5acdca --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/config.go @@ -0,0 +1,502 @@ +package maintnotifications + +import ( + "context" + "net" + "runtime" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" +) + +// Mode represents the maintenance notifications mode +type Mode string + +// Constants for maintenance push notifications modes +const ( + ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command + ModeEnabled Mode = "enabled" // Client forcefully sends command, interrupts connection on error + ModeAuto Mode = "auto" // Client tries to send command, disables feature on error +) + +// IsValid returns true if the maintenance notifications mode is valid +func (m Mode) IsValid() bool { + switch m { + case ModeDisabled, ModeEnabled, ModeAuto: + return true + default: + return false + } +} + +// String returns the string representation of the mode +func (m Mode) String() string { + return string(m) +} + +// EndpointType represents the type of endpoint to request in MOVING notifications +type EndpointType string + +// Constants for endpoint types +const ( + EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection + EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address + EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN + EndpointTypeExternalIP EndpointType = "external-ip" // External IP address + EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN + EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config) +) + +// IsValid returns true if the endpoint type is valid +func (e EndpointType) IsValid() bool { + switch e { + case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone: + return true + default: + return false + } +} + +// String returns the string representation of the endpoint type +func (e EndpointType) String() string { + return string(e) +} + +// Config provides configuration options for maintenance notifications +type Config struct { + // Mode controls how client maintenance notifications are handled. + // Valid values: ModeDisabled, ModeEnabled, ModeAuto + // Default: ModeAuto + Mode Mode + + // EndpointType specifies the type of endpoint to request in MOVING notifications. + // Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + // EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone + // Default: EndpointTypeAuto + EndpointType EndpointType + + // RelaxedTimeout is the concrete timeout value to use during + // MIGRATING/FAILING_OVER states to accommodate increased latency. + // This applies to both read and write timeouts. + // Default: 10 seconds + RelaxedTimeout time.Duration + + // HandoffTimeout is the maximum time to wait for connection handoff to complete. + // If handoff takes longer than this, the old connection will be forcibly closed. + // Default: 15 seconds (matches server-side eviction timeout) + HandoffTimeout time.Duration + + // MaxWorkers is the maximum number of worker goroutines for processing handoff requests. + // Workers are created on-demand and automatically cleaned up when idle. + // If zero, defaults to min(10, PoolSize/2) to handle bursts effectively. + // If explicitly set, enforces minimum of PoolSize/2 + // + // Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2 + MaxWorkers int + + // HandoffQueueSize is the size of the buffered channel used to queue handoff requests. + // If the queue is full, new handoff requests will be rejected. + // Scales with both worker count and pool size for better burst handling. + // + // Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize + // When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize + HandoffQueueSize int + + // PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection + // after a handoff completes. This provides additional resilience during cluster transitions. + // Default: 2 * RelaxedTimeout + PostHandoffRelaxedDuration time.Duration + + // Circuit breaker configuration for endpoint failure handling + // CircuitBreakerFailureThreshold is the number of failures before opening the circuit. + // Default: 5 + CircuitBreakerFailureThreshold int + + // CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered. + // Default: 60 seconds + CircuitBreakerResetTimeout time.Duration + + // CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state. + // Default: 3 + CircuitBreakerMaxRequests int + + // MaxHandoffRetries is the maximum number of times to retry a failed handoff. + // After this many retries, the connection will be removed from the pool. + // Default: 3 + MaxHandoffRetries int +} + +func (c *Config) IsEnabled() bool { + return c != nil && c.Mode != ModeDisabled +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ + Mode: ModeAuto, // Enable by default for Redis Cloud + EndpointType: EndpointTypeAuto, // Auto-detect based on connection + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxWorkers: 0, // Auto-calculated based on pool size + HandoffQueueSize: 0, // Auto-calculated based on max workers + PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout + + // Circuit breaker configuration + CircuitBreakerFailureThreshold: 5, + CircuitBreakerResetTimeout: 60 * time.Second, + CircuitBreakerMaxRequests: 3, + + // Connection Handoff Configuration + MaxHandoffRetries: 3, + } +} + +// Validate checks if the configuration is valid. +func (c *Config) Validate() error { + if c.RelaxedTimeout <= 0 { + return ErrInvalidRelaxedTimeout + } + if c.HandoffTimeout <= 0 { + return ErrInvalidHandoffTimeout + } + // Validate worker configuration + // Allow 0 for auto-calculation, but negative values are invalid + if c.MaxWorkers < 0 { + return ErrInvalidHandoffWorkers + } + // HandoffQueueSize validation - allow 0 for auto-calculation + if c.HandoffQueueSize < 0 { + return ErrInvalidHandoffQueueSize + } + if c.PostHandoffRelaxedDuration < 0 { + return ErrInvalidPostHandoffRelaxedDuration + } + + // Circuit breaker validation + if c.CircuitBreakerFailureThreshold < 1 { + return ErrInvalidCircuitBreakerFailureThreshold + } + if c.CircuitBreakerResetTimeout < 0 { + return ErrInvalidCircuitBreakerResetTimeout + } + if c.CircuitBreakerMaxRequests < 1 { + return ErrInvalidCircuitBreakerMaxRequests + } + + // Validate Mode (maintenance notifications mode) + if !c.Mode.IsValid() { + return ErrInvalidMaintNotifications + } + + // Validate EndpointType + if !c.EndpointType.IsValid() { + return ErrInvalidEndpointType + } + + // Validate configuration fields + if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 { + return ErrInvalidHandoffRetries + } + + return nil +} + +// ApplyDefaults applies default values to any zero-value fields in the configuration. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaults() *Config { + return c.ApplyDefaultsWithPoolSize(0) +} + +// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration, +// using the provided pool size to calculate worker defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { + return c.ApplyDefaultsWithPoolConfig(poolSize, 0) +} + +// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration, +// using the provided pool size and max active connections to calculate worker and queue defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config { + if c == nil { + return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize) + } + + defaults := DefaultConfig() + result := &Config{} + + // Apply defaults for enum fields (empty/zero means not set) + result.Mode = defaults.Mode + if c.Mode != "" { + result.Mode = c.Mode + } + + result.EndpointType = defaults.EndpointType + if c.EndpointType != "" { + result.EndpointType = c.EndpointType + } + + // Apply defaults for duration fields (zero means not set) + result.RelaxedTimeout = defaults.RelaxedTimeout + if c.RelaxedTimeout > 0 { + result.RelaxedTimeout = c.RelaxedTimeout + } + + result.HandoffTimeout = defaults.HandoffTimeout + if c.HandoffTimeout > 0 { + result.HandoffTimeout = c.HandoffTimeout + } + + // Copy worker configuration + result.MaxWorkers = c.MaxWorkers + + // Apply worker defaults based on pool size + result.applyWorkerDefaults(poolSize) + + // Apply queue size defaults with new scaling approach + // Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size + workerBasedSize := result.MaxWorkers * 20 + poolBasedSize := poolSize + result.HandoffQueueSize = max(workerBasedSize, poolBasedSize) + if c.HandoffQueueSize > 0 { + // When explicitly set: enforce minimum of 200 + result.HandoffQueueSize = max(200, c.HandoffQueueSize) + } + + // Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size + var queueCap int + if maxActiveConns > 0 { + queueCap = maxActiveConns + 1 + // Ensure queue cap is at least 2 for very small maxActiveConns + if queueCap < 2 { + queueCap = 2 + } + } else { + queueCap = poolSize * 5 + } + result.HandoffQueueSize = min(result.HandoffQueueSize, queueCap) + + // Ensure minimum queue size of 2 (fallback for very small pools) + if result.HandoffQueueSize < 2 { + result.HandoffQueueSize = 2 + } + + result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2 + if c.PostHandoffRelaxedDuration > 0 { + result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration + } + + // Apply defaults for configuration fields + result.MaxHandoffRetries = defaults.MaxHandoffRetries + if c.MaxHandoffRetries > 0 { + result.MaxHandoffRetries = c.MaxHandoffRetries + } + + // Circuit breaker configuration + result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold + if c.CircuitBreakerFailureThreshold > 0 { + result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold + } + + result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout + if c.CircuitBreakerResetTimeout > 0 { + result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout + } + + result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests + if c.CircuitBreakerMaxRequests > 0 { + result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests + } + + if internal.LogLevel.DebugOrAbove() { + internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled()) + internal.Logger.Printf(context.Background(), logs.ConfigDebug(result)) + } + return result +} + +// Clone creates a deep copy of the configuration. +func (c *Config) Clone() *Config { + if c == nil { + return DefaultConfig() + } + + return &Config{ + Mode: c.Mode, + EndpointType: c.EndpointType, + RelaxedTimeout: c.RelaxedTimeout, + HandoffTimeout: c.HandoffTimeout, + MaxWorkers: c.MaxWorkers, + HandoffQueueSize: c.HandoffQueueSize, + PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration, + + // Circuit breaker configuration + CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold, + CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout, + CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests, + + // Configuration fields + MaxHandoffRetries: c.MaxHandoffRetries, + } +} + +// applyWorkerDefaults calculates and applies worker defaults based on pool size +func (c *Config) applyWorkerDefaults(poolSize int) { + // Calculate defaults based on pool size + if poolSize <= 0 { + poolSize = 10 * runtime.GOMAXPROCS(0) + } + + // When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach + originalMaxWorkers := c.MaxWorkers + c.MaxWorkers = min(poolSize/2, max(10, poolSize/3)) + if originalMaxWorkers != 0 { + // When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers + c.MaxWorkers = max(poolSize/2, originalMaxWorkers) + } + + // Ensure minimum of 1 worker (fallback for very small pools) + if c.MaxWorkers < 1 { + c.MaxWorkers = 1 + } +} + +// endpointDetectResolveTimeout bounds the DNS lookup performed by +// DetectEndpointType so a slow or broken resolver cannot block client +// construction for the full system resolver timeout (often 5-30s). +const endpointDetectResolveTimeout = 2 * time.Second + +// cgnatNet is RFC6598 shared address space (100.64.0.0/10), used by many +// cloud/carrier NATs and not covered by net.IP.IsPrivate. +var cgnatNet = &net.IPNet{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)} + +// isPrivateIP reports whether ip belongs to a range that should be treated +// as "internal" for the purpose of endpoint type detection. It extends +// net.IP.IsPrivate (RFC1918 + RFC4193) with loopback, link-local and +// RFC6598 shared address space (CGNAT). +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return false + } + if ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() { + return true + } + if v4 := ip.To4(); v4 != nil && cgnatNet.Contains(v4) { + return true + } + return false +} + +// DetectEndpointType automatically detects the appropriate endpoint type +// based on the connection address and TLS configuration. +// +// TLS behaviour: +// - If TLS is enabled: requests FQDN for proper certificate validation +// (SNI / hostname verification). +// - If TLS is disabled: always requests IP for better performance, even +// when the configured address is a hostname. In that case the hostname +// is resolved to determine whether it belongs to an internal or +// external network range. +// +// Internal vs External detection: +// - For IPs: uses private IP range detection +// - For hostnames: resolves the hostname to an IP address and uses the IP range detection +func DetectEndpointType(addr string, tlsEnabled bool) EndpointType { + // Extract host from "host:port" format + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr // Assume no port + } + + // An empty host (e.g., ":6379") conventionally means the loopback + // interface and is treated as internal. With TLS off we return an IP + // endpoint; with TLS on the caller still needs an FQDN for SNI. + if host == "" { + if tlsEnabled { + return EndpointTypeInternalFQDN + } + return EndpointTypeInternalIP + } + + // Check if the host is an IP address or hostname + ip := net.ParseIP(host) + isIPAddress := ip != nil + var endpointType EndpointType + + if isIPAddress { + // Address is an IP - determine if it's private or public + isPrivate := isPrivateIP(ip) + + if tlsEnabled { + // TLS with IP addresses - still prefer FQDN for certificate validation + if isPrivate { + endpointType = EndpointTypeInternalFQDN + } else { + endpointType = EndpointTypeExternalFQDN + } + } else { + // No TLS - can use IP addresses directly + if isPrivate { + endpointType = EndpointTypeInternalIP + } else { + endpointType = EndpointTypeExternalIP + } + } + } else { + // Address is a hostname - resolve it under a bounded timeout so a + // slow/broken DNS server cannot stall client construction. + ctx, cancel := context.WithTimeout(context.Background(), endpointDetectResolveTimeout) + defer cancel() + + isInternal, err := isInternalHostname(ctx, host) + // Will fallback to external classification if we can't determine + // whether the hostname is internal. + if err != nil && internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(ctx, "Failed to determine if hostname %q is internal: %v", host, err) + } + + if tlsEnabled { + // With TLS the server name must be preserved for certificate + // validation, so request an FQDN endpoint. + if isInternal { + endpointType = EndpointTypeInternalFQDN + } else { + endpointType = EndpointTypeExternalFQDN + } + } else { + // Without TLS we always prefer IP endpoints for performance, + // even if the configured address is a hostname. + if isInternal { + endpointType = EndpointTypeInternalIP + } else { + endpointType = EndpointTypeExternalIP + } + } + } + + return endpointType +} + +// isInternalHostname resolves the hostname (both IPv4 and IPv6) under the +// given context and reports whether every resolved address is in a +// private/internal range. If any address is public the hostname is treated +// as external. A resolution error returns (false, err). An empty result set +// returns (false, nil); callers are expected to fall back to an external +// classification when the hostname cannot be determined to be internal. +func isInternalHostname(ctx context.Context, hostname string) (bool, error) { + ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname) + if err != nil { + return false, err + } + if len(ips) == 0 { + return false, nil + } + for _, ia := range ips { + if !isPrivateIP(ia.IP) { + return false, nil + } + } + return true, nil +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/errors.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/errors.go new file mode 100644 index 000000000..049656bdd --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/errors.go @@ -0,0 +1,76 @@ +package maintnotifications + +import ( + "errors" + + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" +) + +// Configuration errors +var ( + ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError()) + ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError()) + ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError()) + ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError()) + ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError()) + ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError()) + ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError()) + ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError()) + + // Configuration validation errors + + // ErrInvalidHandoffRetries is returned when the number of handoff retries is invalid + ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError()) +) + +// Integration errors +var ( + // ErrInvalidClient is returned when the client does not support push notifications + ErrInvalidClient = errors.New(logs.InvalidClientError()) +) + +// Handoff errors +var ( + // ErrHandoffQueueFull is returned when the handoff queue is full + ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError()) +) + +// Notification errors +var ( + // ErrInvalidNotification is returned when a notification is in an invalid format + ErrInvalidNotification = errors.New(logs.InvalidNotificationError()) +) + +// connection handoff errors +var ( + // ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff + // and should not be used until the handoff is complete + ErrConnectionMarkedForHandoff = errors.New(logs.ConnectionMarkedForHandoffErrorMessage) + // ErrConnectionMarkedForHandoffWithState is returned when a connection is marked for handoff + // and should not be used until the handoff is complete + ErrConnectionMarkedForHandoffWithState = errors.New(logs.ConnectionMarkedForHandoffErrorMessage + " with state") + // ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff + ErrConnectionInvalidHandoffState = errors.New(logs.ConnectionInvalidHandoffStateErrorMessage) +) + +// shutdown errors +var ( + // ErrShutdown is returned when the maintnotifications manager is shutdown + ErrShutdown = errors.New(logs.ShutdownError()) +) + +// circuit breaker errors +var ( + // ErrCircuitBreakerOpen is returned when the circuit breaker is open + ErrCircuitBreakerOpen = errors.New(logs.CircuitBreakerOpenErrorMessage) +) + +// circuit breaker configuration errors +var ( + // ErrInvalidCircuitBreakerFailureThreshold is returned when the circuit breaker failure threshold is invalid + ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError()) + // ErrInvalidCircuitBreakerResetTimeout is returned when the circuit breaker reset timeout is invalid + ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError()) + // ErrInvalidCircuitBreakerMaxRequests is returned when the circuit breaker max requests is invalid + ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError()) +) diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/example_hooks.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/example_hooks.go new file mode 100644 index 000000000..3a3465571 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/example_hooks.go @@ -0,0 +1,101 @@ +package maintnotifications + +import ( + "context" + "fmt" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + startTimeKey contextKey = "maint_notif_start_time" +) + +// MetricsHook collects metrics about notification processing. +type MetricsHook struct { + NotificationCounts map[string]int64 + ProcessingTimes map[string]time.Duration + ErrorCounts map[string]int64 + HandoffCounts int64 // Total handoffs initiated + HandoffSuccesses int64 // Successful handoffs + HandoffFailures int64 // Failed handoffs +} + +// NewMetricsHook creates a new metrics collection hook. +func NewMetricsHook() *MetricsHook { + return &MetricsHook{ + NotificationCounts: make(map[string]int64), + ProcessingTimes: make(map[string]time.Duration), + ErrorCounts: make(map[string]int64), + } +} + +// PreHook records the start time for processing metrics. +func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + mh.NotificationCounts[notificationType]++ + + // Log connection information if available + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID())) + } + + // Store start time in context for duration calculation + startTime := time.Now() + _ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further + + return notification, true +} + +// PostHook records processing completion and any errors. +func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + // Calculate processing duration + if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok { + duration := time.Since(startTime) + mh.ProcessingTimes[notificationType] = duration + } + + // Record errors + if result != nil { + mh.ErrorCounts[notificationType]++ + + // Log error details with connection information + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result)) + } + } +} + +// GetMetrics returns a summary of collected metrics. +func (mh *MetricsHook) GetMetrics() map[string]interface{} { + return map[string]interface{}{ + "notification_counts": mh.NotificationCounts, + "processing_times": mh.ProcessingTimes, + "error_counts": mh.ErrorCounts, + } +} + +// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status +func ExampleCircuitBreakerMonitor(poolHook *PoolHook) { + // Get circuit breaker statistics + stats := poolHook.GetCircuitBreakerStats() + + for _, stat := range stats { + fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint) + fmt.Printf(" State: %s\n", stat.State) + fmt.Printf(" Failures: %d\n", stat.Failures) + fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime) + fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime) + + // Alert if circuit breaker is open + if stat.State.String() == "open" { + fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint) + } + } +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/handoff_worker.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/handoff_worker.go new file mode 100644 index 000000000..d66542ffc --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/handoff_worker.go @@ -0,0 +1,525 @@ +package maintnotifications + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" +) + +// PoolNameMain is the name used for the main connection pool in metrics. +const PoolNameMain = "main" + +// handoffWorkerManager manages background workers and queue for connection handoffs +type handoffWorkerManager struct { + // Event-driven handoff support + handoffQueue chan HandoffRequest // Queue for handoff requests + shutdown chan struct{} // Shutdown signal + shutdownOnce sync.Once // Ensure clean shutdown + workerWg sync.WaitGroup // Track worker goroutines + + // On-demand worker management + maxWorkers int + activeWorkers atomic.Int32 + workerTimeout time.Duration // How long workers wait for work before exiting + workersScaling atomic.Bool + + // Simple state tracking + pending sync.Map // map[uint64]int64 (connID -> seqID) + + // Configuration for the maintenance notifications + config *Config + + // Pool hook reference for handoff processing + poolHook *PoolHook + + // Circuit breaker manager for endpoint failure handling + circuitBreakerManager *CircuitBreakerManager +} + +// newHandoffWorkerManager creates a new handoff worker manager +func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager { + return &handoffWorkerManager{ + handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize), + shutdown: make(chan struct{}), + maxWorkers: config.MaxWorkers, + activeWorkers: atomic.Int32{}, // Start with no workers - create on demand + workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity + config: config, + poolHook: poolHook, + circuitBreakerManager: newCircuitBreakerManager(config), + } +} + +// getCurrentWorkers returns the current number of active workers (for testing) +func (hwm *handoffWorkerManager) getCurrentWorkers() int { + return int(hwm.activeWorkers.Load()) +} + +// getPendingMap returns the pending map for testing purposes +func (hwm *handoffWorkerManager) getPendingMap() *sync.Map { + return &hwm.pending +} + +// getMaxWorkers returns the max workers for testing purposes +func (hwm *handoffWorkerManager) getMaxWorkers() int { + return hwm.maxWorkers +} + +// getHandoffQueue returns the handoff queue for testing purposes +func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest { + return hwm.handoffQueue +} + +// getCircuitBreakerStats returns circuit breaker statistics for monitoring +func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats { + return hwm.circuitBreakerManager.GetAllStats() +} + +// resetCircuitBreakers resets all circuit breakers (useful for testing) +func (hwm *handoffWorkerManager) resetCircuitBreakers() { + hwm.circuitBreakerManager.Reset() +} + +// isHandoffPending returns true if the given connection has a pending handoff +func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool { + _, pending := hwm.pending.Load(conn.GetID()) + return pending +} + +// ensureWorkerAvailable ensures at least one worker is available to process requests +// Creates a new worker if needed and under the max limit +func (hwm *handoffWorkerManager) ensureWorkerAvailable() { + select { + case <-hwm.shutdown: + return + default: + if hwm.workersScaling.CompareAndSwap(false, true) { + defer hwm.workersScaling.Store(false) + // Check if we need a new worker + currentWorkers := hwm.activeWorkers.Load() + workersWas := currentWorkers + for currentWorkers < int32(hwm.maxWorkers) { + hwm.workerWg.Add(1) + go hwm.onDemandWorker() + currentWorkers++ + } + // workersWas is always <= currentWorkers + // currentWorkers will be maxWorkers, but if we have a worker that was closed + // while we were creating new workers, just add the difference between + // the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created) + hwm.activeWorkers.Add(currentWorkers - workersWas) + } + } +} + +// onDemandWorker processes handoff requests and exits when idle +func (hwm *handoffWorkerManager) onDemandWorker() { + defer func() { + // Handle panics to ensure proper cleanup + if r := recover(); r != nil { + internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r)) + } + + // Decrement active worker count when exiting + hwm.activeWorkers.Add(-1) + hwm.workerWg.Done() + }() + + // Create reusable timer to prevent timer leaks + timer := time.NewTimer(hwm.workerTimeout) + defer timer.Stop() + + for { + // Reset timer for next iteration + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + timer.Reset(hwm.workerTimeout) + + select { + case <-hwm.shutdown: + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown()) + } + return + case <-timer.C: + // Worker has been idle for too long, exit to save resources + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) + } + return + case request := <-hwm.handoffQueue: + // Check for shutdown before processing + select { + case <-hwm.shutdown: + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) + } + // Clean up the request before exiting + hwm.pending.Delete(request.ConnID) + return + default: + // Process the request + hwm.processHandoffRequest(request) + } + } + } +} + +// processHandoffRequest processes a single handoff request +func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) + } + + // Create a context with handoff timeout from config + handoffTimeout := 15 * time.Second // Default timeout + if hwm.config != nil && hwm.config.HandoffTimeout > 0 { + handoffTimeout = hwm.config.HandoffTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout) + defer cancel() + + // Create a context that also respects the shutdown signal + shutdownCtx, shutdownCancel := context.WithCancel(ctx) + defer shutdownCancel() + + // Monitor shutdown signal in a separate goroutine + go func() { + select { + case <-hwm.shutdown: + shutdownCancel() + case <-shutdownCtx.Done(): + } + }() + + // Perform the handoff with cancellable context + shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn) + minRetryBackoff := 500 * time.Millisecond + if err != nil { + if shouldRetry { + now := time.Now() + deadline, ok := shutdownCtx.Deadline() + thirdOfTimeout := handoffTimeout / 3 + if !ok || deadline.Before(now) { + // wait half the timeout before retrying if no deadline or deadline has passed + deadline = now.Add(thirdOfTimeout) + } + afterTime := deadline.Sub(now) + if afterTime < minRetryBackoff { + afterTime = minRetryBackoff + } + + if internal.LogLevel.InfoOrAbove() { + // Get current retry count for better logging + currentRetries := request.Conn.HandoffRetries() + maxRetries := 3 // Default fallback + if hwm.config != nil { + maxRetries = hwm.config.MaxHandoffRetries + } + internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) + } + // Schedule retry - keep connection in pending map until retry is queued + time.AfterFunc(afterTime, func() { + if err := hwm.queueHandoff(request.Conn); err != nil { + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err)) + } + // Failed to queue retry - remove from pending and close connection + hwm.pending.Delete(request.Conn.GetID()) + hwm.closeConnFromRequest(context.Background(), request, err) + } else { + // Successfully queued retry - remove from pending (will be re-added by queueHandoff) + hwm.pending.Delete(request.Conn.GetID()) + } + }) + return + } else { + // Won't retry - remove from pending and close connection + hwm.pending.Delete(request.Conn.GetID()) + go hwm.closeConnFromRequest(ctx, request, err) + } + + // Clear handoff state if not returned for retry + seqID := request.Conn.GetMovingSeqID() + connID := request.Conn.GetID() + if hwm.poolHook.operationsManager != nil { + hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID) + } + } else { + // Success - remove from pending map + hwm.pending.Delete(request.Conn.GetID()) + } +} + +// queueHandoff queues a handoff request for processing +// if err is returned, connection will be removed from pool +func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { + // Get handoff info atomically to prevent race conditions + shouldHandoff, endpoint, seqID := conn.GetHandoffInfo() + + // on retries the connection will not be marked for handoff, but it will have retries > 0 + // if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff + if !shouldHandoff && conn.HandoffRetries() == 0 { + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) + } + return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID())) + } + + // Create handoff request with atomically retrieved data + request := HandoffRequest{ + Conn: conn, + ConnID: conn.GetID(), + Endpoint: endpoint, + SeqID: seqID, + Pool: hwm.poolHook.pool, // Include pool for connection removal on failure + } + + select { + // priority to shutdown + case <-hwm.shutdown: + return ErrShutdown + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + default: + select { + case <-hwm.shutdown: + return ErrShutdown + case hwm.handoffQueue <- request: + // Store in pending map + hwm.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + hwm.ensureWorkerAvailable() + return nil + case <-time.After(100 * time.Millisecond): // give workers a chance to process + // Queue is full - log and attempt scaling + queueLen := len(hwm.handoffQueue) + queueCap := cap(hwm.handoffQueue) + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) + } + } + } + } + + // Ensure we have workers available to handle the load + hwm.ensureWorkerAvailable() + return ErrHandoffQueueFull +} + +// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete +func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error { + hwm.shutdownOnce.Do(func() { + close(hwm.shutdown) + // workers will exit when they finish their current request + + // Shutdown circuit breaker manager cleanup goroutine + if hwm.circuitBreakerManager != nil { + hwm.circuitBreakerManager.Shutdown() + } + }) + + // Wait for workers to complete + done := make(chan struct{}) + go func() { + hwm.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// performConnectionHandoff performs the actual connection handoff +// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached +func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) { + // Clear handoff state after successful handoff + connID := conn.GetID() + + newEndpoint := conn.GetHandoffEndpoint() + if newEndpoint == "" { + return false, ErrConnectionInvalidHandoffState + } + + // Use circuit breaker to protect against failing endpoints + circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint) + + // Check if circuit breaker is open before attempting handoff + if circuitBreaker.IsOpen() { + internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) + return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open + } + + // Perform the handoff + shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID) + + // Update circuit breaker based on result + if err != nil { + // Only track dial/network errors in circuit breaker, not initialization errors + if shouldRetry { + circuitBreaker.recordFailure() + } + return shouldRetry, err + } + + // Success - record in circuit breaker + circuitBreaker.recordSuccess() + return false, nil +} + +// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration) +func (hwm *handoffWorkerManager) performHandoffInternal( + ctx context.Context, + conn *pool.Conn, + newEndpoint string, + connID uint64, +) (shouldRetry bool, err error) { + retries := conn.IncrementAndGetHandoffRetries(1) + internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) + maxRetries := 3 // Default fallback + if hwm.config != nil { + maxRetries = hwm.config.MaxHandoffRetries + } + + if retries > maxRetries { + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) + } + // won't retry on ErrMaxHandoffRetriesReached + return false, ErrMaxHandoffRetriesReached + } + + // Create endpoint-specific dialer + endpointDialer := hwm.createEndpointDialer(newEndpoint) + + // Create new connection to the new endpoint + newNetConn, err := endpointDialer(ctx) + if err != nil { + internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) + // will retry + // Maybe a network error - retry after a delay + return true, err + } + + // Get the old connection + oldConn := conn.GetNetConn() + + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + // Setting this here (before initing the connection) ensures that the connection is going + // to use the relaxed timeout for the first operation (auth/ACL select) + if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := hwm.config.RelaxedTimeout + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration) + conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + // Record relaxed timeout metric (post-handoff) + if relaxedTimeoutCallback := pool.GetMetricConnectionRelaxedTimeoutCallback(); relaxedTimeoutCallback != nil { + relaxedTimeoutCallback(ctx, 1, conn, PoolNameMain, "HANDOFF") + } + + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) + } + } + + // Replace the connection and execute initialization + err = conn.SetNetConnAndInitConn(ctx, newNetConn) + if err != nil { + // won't retry + // Initialization failed - remove the connection + return false, err + } + defer func() { + if oldConn != nil { + oldConn.Close() + } + }() + + // Clear handoff state will: + // - set the connection as usable again + // - clear the handoff state (shouldHandoff, endpoint, seqID) + // - reset the handoff retries to 0 + // Note: Theoretically there may be a short window where the connection is in the pool + // and IDLE (initConn completed) but still has handoff state set. + conn.ClearHandoffState() + internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + + // successfully completed the handoff, no retry needed and no error + // Notify metrics: connection handoff succeeded + if handoffCallback := pool.GetMetricConnectionHandoffCallback(); handoffCallback != nil { + handoffCallback(ctx, conn, PoolNameMain) + } + + return false, nil +} + +// createEndpointDialer creates a dialer function that connects to a specific endpoint +func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + // Parse endpoint to extract host and port + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + // If no port specified, assume default Redis port + host = endpoint + if port == "" { + port = "6379" + } + } + + // Use the base dialer to connect to the new endpoint + return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port)) + } +} + +// closeConnFromRequest closes the connection and logs the reason +func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) { + pooler := request.Pool + conn := request.Conn + + // Clear handoff state before closing + conn.ClearHandoffState() + + if pooler != nil { + // Use RemoveWithoutTurn instead of Remove to avoid freeing a turn that we don't have. + // The handoff worker doesn't call Get(), so it doesn't have a turn to free. + // Remove() is meant to be called after Get() and frees a turn. + // RemoveWithoutTurn() removes and closes the connection without affecting the queue. + pooler.RemoveWithoutTurn(ctx, conn, err) + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) + } + } else { + errClose := conn.Close() // Close the connection if no pool provided + if errClose != nil { + internal.Logger.Printf(ctx, "redis: failed to close connection: %v", errClose) + } + if internal.LogLevel.WarnOrAbove() { + internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) + } + } +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/hooks.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/hooks.go new file mode 100644 index 000000000..ee3c3819c --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/hooks.go @@ -0,0 +1,60 @@ +package maintnotifications + +import ( + "context" + "slices" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// LoggingHook is an example hook implementation that logs all notifications. +type LoggingHook struct { + LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug +} + +// PreHook logs the notification before processing and allows modification. +func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + if lh.LogLevel >= 2 { // Info level + // Log the notification type and content + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } + seqID := int64(0) + if slices.Contains(maintenanceNotificationTypes, notificationType) { + // seqID is the second element in the notification array + if len(notification) > 1 { + if parsedSeqID, ok := notification[1].(int64); !ok { + seqID = 0 + } else { + seqID = parsedSeqID + } + } + + } + internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification)) + } + return notification, true // Continue processing with unmodified notification +} + +// PostHook logs the result after processing. +func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } + if result != nil && lh.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification)) + } else if lh.LogLevel >= 3 { // Debug level + internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType)) + } +} + +// NewLoggingHook creates a new logging hook with the specified log level. +// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug +func NewLoggingHook(logLevel int) *LoggingHook { + return &LoggingHook{LogLevel: logLevel} +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/manager.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/manager.go new file mode 100644 index 000000000..3f9478e1b --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/manager.go @@ -0,0 +1,362 @@ +package maintnotifications + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// Push notification type constants for maintenance +const ( + NotificationMoving = "MOVING" // Per-connection handoff notification + NotificationMigrating = "MIGRATING" // Per-connection migration start notification - relaxes timeouts + NotificationMigrated = "MIGRATED" // Per-connection migration complete notification - clears relaxed timeouts + NotificationFailingOver = "FAILING_OVER" // Per-connection failover start notification - relaxes timeouts + NotificationFailedOver = "FAILED_OVER" // Per-connection failover complete notification - clears relaxed timeouts + NotificationSMigrating = "SMIGRATING" // Cluster slot migrating notification - relaxes timeouts + NotificationSMigrated = "SMIGRATED" // Cluster slot migrated notification - unrelaxes timeouts and triggers cluster state reload +) + +// maintenanceNotificationTypes contains all notification types that maintenance handles +var maintenanceNotificationTypes = []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, + NotificationSMigrating, + NotificationSMigrated, +} + +// NotificationHook is called before and after notification processing +// PreHook can modify the notification and return false to skip processing +// PostHook is called after successful processing +type NotificationHook interface { + PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) +} + +// MovingOperationKey provides a unique key for tracking MOVING operations +// that combines sequence ID with connection identifier to handle duplicate +// sequence IDs across multiple connections to the same node. +type MovingOperationKey struct { + SeqID int64 // Sequence ID from MOVING notification + ConnID uint64 // Unique connection identifier +} + +// String returns a string representation of the key for debugging +func (k MovingOperationKey) String() string { + return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID) +} + +// Manager provides a simplified upgrade functionality with hooks and atomic state. +type Manager struct { + client interfaces.ClientInterface + config *Config + options interfaces.OptionsInterface + pool pool.Pooler + + // MOVING operation tracking - using sync.Map for better concurrent performance + activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation + + // SMIGRATED notification deduplication - tracks processed SeqIDs + // Multiple connections may receive the same SMIGRATED notification + processedSMigratedSeqIDs sync.Map // map[int64]bool + + // Atomic state tracking - no locks needed for state queries + activeOperationCount atomic.Int64 // Number of active operations + closed atomic.Bool // Manager closed state + + // Notification hooks for extensibility + hooks []NotificationHook + hooksMu sync.RWMutex // Protects hooks slice + poolHooksRef *PoolHook + + // Cluster state reload callback for SMIGRATED notifications + clusterStateReloadCallback ClusterStateReloadCallback +} + +// MovingOperation tracks an active MOVING operation. +type MovingOperation struct { + SeqID int64 + NewEndpoint string + StartTime time.Time + Deadline time.Time +} + +// ClusterStateReloadCallback is a callback function that triggers cluster state reload. +// This is used by node clients to notify their parent ClusterClient about SMIGRATED notifications. +// The hostPort parameter indicates the destination node (e.g., "127.0.0.1:6379"). +// The slotRanges parameter contains the migrated slots (e.g., ["1234", "5000-6000"]). +// Currently, implementations typically reload the entire cluster state, but in the future +// this could be optimized to reload only the specific slots. +type ClusterStateReloadCallback func(ctx context.Context, hostPort string, slotRanges []string) + +// NewManager creates a new simplified manager. +func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) { + if client == nil { + return nil, ErrInvalidClient + } + + hm := &Manager{ + client: client, + pool: pool, + options: client.GetOptions(), + config: config.Clone(), + hooks: make([]NotificationHook, 0), + } + + // Set up push notification handling + if err := hm.setupPushNotifications(); err != nil { + return nil, err + } + + return hm, nil +} + +// GetPoolHook creates a pool hook with a custom dialer. +func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) { + poolHook := hm.createPoolHook(baseDialer) + hm.pool.AddPoolHook(poolHook) +} + +// setupPushNotifications sets up push notification handling by registering with the client's processor. +func (hm *Manager) setupPushNotifications() error { + processor := hm.client.GetPushProcessor() + if processor == nil { + return ErrInvalidClient // Client doesn't support push notifications + } + + // Create our notification handler + handler := &NotificationHandler{manager: hm, operationsManager: hm} + + // Register handlers for all upgrade notifications with the client's processor + for _, notificationType := range maintenanceNotificationTypes { + if err := processor.RegisterHandler(notificationType, handler, true); err != nil { + return errors.New(logs.FailedToRegisterHandler(notificationType, err)) + } + } + + return nil +} + +// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID. +func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Create MOVING operation record + movingOp := &MovingOperation{ + SeqID: seqID, + NewEndpoint: newEndpoint, + StartTime: time.Now(), + Deadline: deadline, + } + + // Use LoadOrStore for atomic check-and-set operation + if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { + // Duplicate MOVING notification, ignore + if internal.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) + } + return nil + } + if internal.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) + } + + // Increment active operation count atomically + hm.activeOperationCount.Add(1) + + return nil +} + +// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID. +func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Remove from active operations atomically + if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { + if internal.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) + } + // Decrement active operation count only if operation existed + hm.activeOperationCount.Add(-1) + } else { + if internal.LogLevel.DebugOrAbove() { // Debug level + internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID)) + } + } +} + +// GetActiveMovingOperations returns active operations with composite keys. +// WARNING: This method creates a new map and copies all operations on every call. +// Use sparingly, especially in hot paths or high-frequency logging. +func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation { + result := make(map[MovingOperationKey]*MovingOperation) + + // Iterate over sync.Map to build result + hm.activeMovingOps.Range(func(key, value interface{}) bool { + k := key.(MovingOperationKey) + op := value.(*MovingOperation) + + // Create a copy to avoid sharing references + result[k] = &MovingOperation{ + SeqID: op.SeqID, + NewEndpoint: op.NewEndpoint, + StartTime: op.StartTime, + Deadline: op.Deadline, + } + return true // Continue iteration + }) + + return result +} + +// IsHandoffInProgress returns true if any handoff is in progress. +// Uses atomic counter for lock-free operation. +func (hm *Manager) IsHandoffInProgress() bool { + return hm.activeOperationCount.Load() > 0 +} + +// GetActiveOperationCount returns the number of active operations. +// Uses atomic counter for lock-free operation. +func (hm *Manager) GetActiveOperationCount() int64 { + return hm.activeOperationCount.Load() +} + +// MarkSMigratedSeqIDProcessed attempts to mark a SMIGRATED SeqID as processed. +// Returns true if this is the first time processing this SeqID (should process), +// false if it was already processed (should skip). +// This prevents duplicate processing when multiple connections receive the same notification. +func (hm *Manager) MarkSMigratedSeqIDProcessed(seqID int64) bool { + _, alreadyProcessed := hm.processedSMigratedSeqIDs.LoadOrStore(seqID, true) + return !alreadyProcessed // Return true if NOT already processed +} + +// Close closes the manager. +func (hm *Manager) Close() error { + // Use atomic operation for thread-safe close check + if !hm.closed.CompareAndSwap(false, true) { + return nil // Already closed + } + + // Shutdown the pool hook if it exists + if hm.poolHooksRef != nil { + // Use a timeout to prevent hanging indefinitely + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := hm.poolHooksRef.Shutdown(shutdownCtx) + if err != nil { + // was not able to close pool hook, keep closed state false + hm.closed.Store(false) + return err + } + // Remove the pool hook from the pool + if hm.pool != nil { + hm.pool.RemovePoolHook(hm.poolHooksRef) + } + } + + // Clear all active operations + hm.activeMovingOps.Range(func(key, value interface{}) bool { + hm.activeMovingOps.Delete(key) + return true + }) + + // Reset counter + hm.activeOperationCount.Store(0) + + return nil +} + +// GetState returns current state using atomic counter for lock-free operation. +func (hm *Manager) GetState() State { + if hm.activeOperationCount.Load() > 0 { + return StateMoving + } + return StateIdle +} + +// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing. +func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + currentNotification := notification + + for _, hook := range hm.hooks { + modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification) + if !shouldContinue { + return modifiedNotification, false + } + currentNotification = modifiedNotification + } + + return currentNotification, true +} + +// processPostHooks calls all post-hooks with the processing result. +func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + for _, hook := range hm.hooks { + hook.PostHook(ctx, notificationCtx, notificationType, notification, result) + } +} + +// createPoolHook creates a pool hook with this manager already set. +func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook { + if hm.poolHooksRef != nil { + return hm.poolHooksRef + } + // Get pool size from client options for better worker defaults + poolSize := 0 + if hm.options != nil { + poolSize = hm.options.GetPoolSize() + } + + hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize) + hm.poolHooksRef.SetPool(hm.pool) + + return hm.poolHooksRef +} + +func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) { + hm.hooksMu.Lock() + defer hm.hooksMu.Unlock() + hm.hooks = append(hm.hooks, notificationHook) +} + +// SetClusterStateReloadCallback sets the callback function that will be called when a SMIGRATED notification is received. +// This allows node clients to notify their parent ClusterClient to reload cluster state. +func (hm *Manager) SetClusterStateReloadCallback(callback ClusterStateReloadCallback) { + hm.clusterStateReloadCallback = callback +} + +// TriggerClusterStateReload calls the cluster state reload callback if it's set. +// This is called when a SMIGRATED notification is received. +func (hm *Manager) TriggerClusterStateReload(ctx context.Context, hostPort string, slotRanges []string) { + if hm.clusterStateReloadCallback != nil { + hm.clusterStateReloadCallback(ctx, hostPort, slotRanges) + } +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/pool_hook.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/pool_hook.go new file mode 100644 index 000000000..9ea0558bf --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/pool_hook.go @@ -0,0 +1,182 @@ +package maintnotifications + +import ( + "context" + "net" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" +) + +// OperationsManagerInterface defines the interface for completing handoff operations +type OperationsManagerInterface interface { + TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error + UntrackOperationWithConnID(seqID int64, connID uint64) +} + +// HandoffRequest represents a request to handoff a connection to a new endpoint +type HandoffRequest struct { + Conn *pool.Conn + ConnID uint64 // Unique connection identifier + Endpoint string + SeqID int64 + Pool pool.Pooler // Pool to remove connection from on failure +} + +// PoolHook implements pool.PoolHook for Redis-specific connection handling +// with maintenance notifications support. +type PoolHook struct { + // Base dialer for creating connections to new endpoints during handoffs + // args are network and address + baseDialer func(context.Context, string, string) (net.Conn, error) + + // Network type (e.g., "tcp", "unix") + network string + + // Worker manager for background handoff processing + workerManager *handoffWorkerManager + + // Configuration for the maintenance notifications + config *Config + + // Operations manager interface for operation completion tracking + operationsManager OperationsManagerInterface + + // Pool interface for removing connections on handoff failure + pool pool.Pooler +} + +// NewPoolHook creates a new pool hook +func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook { + return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0) +} + +// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults +func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook { + // Apply defaults if config is nil or has zero values + if config == nil { + config = config.ApplyDefaultsWithPoolSize(poolSize) + } + + ph := &PoolHook{ + // baseDialer is used to create connections to new endpoints during handoffs + baseDialer: baseDialer, + network: network, + config: config, + operationsManager: operationsManager, + } + + // Create worker manager + ph.workerManager = newHandoffWorkerManager(config, ph) + + return ph +} + +// SetPool sets the pool interface for removing connections on handoff failure +func (ph *PoolHook) SetPool(pooler pool.Pooler) { + ph.pool = pooler +} + +// GetCurrentWorkers returns the current number of active workers (for testing) +func (ph *PoolHook) GetCurrentWorkers() int { + return ph.workerManager.getCurrentWorkers() +} + +// IsHandoffPending returns true if the given connection has a pending handoff +func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool { + return ph.workerManager.isHandoffPending(conn) +} + +// GetPendingMap returns the pending map for testing purposes +func (ph *PoolHook) GetPendingMap() *sync.Map { + return ph.workerManager.getPendingMap() +} + +// GetMaxWorkers returns the max workers for testing purposes +func (ph *PoolHook) GetMaxWorkers() int { + return ph.workerManager.getMaxWorkers() +} + +// GetHandoffQueue returns the handoff queue for testing purposes +func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest { + return ph.workerManager.getHandoffQueue() +} + +// GetCircuitBreakerStats returns circuit breaker statistics for monitoring +func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats { + return ph.workerManager.getCircuitBreakerStats() +} + +// ResetCircuitBreakers resets all circuit breakers (useful for testing) +func (ph *PoolHook) ResetCircuitBreakers() { + ph.workerManager.resetCircuitBreakers() +} + +// OnGet is called when a connection is retrieved from the pool +func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) { + // Check if connection is marked for handoff + // This prevents using connections that have received MOVING notifications + if conn.ShouldHandoff() { + return false, ErrConnectionMarkedForHandoffWithState + } + + // Check if connection is usable (not in UNUSABLE or CLOSED state) + // This ensures we don't return connections that are currently being handed off or re-authenticated. + if !conn.IsUsable() { + return false, ErrConnectionMarkedForHandoff + } + + return true, nil +} + +// OnPut is called when a connection is returned to the pool +func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) { + // first check if we should handoff for faster rejection + if !conn.ShouldHandoff() { + // Default behavior (no handoff): pool the connection + return true, false, nil + } + + // check pending handoff to not queue the same connection twice + if ph.workerManager.isHandoffPending(conn) { + // Default behavior (pending handoff): pool the connection + return true, false, nil + } + + if err := ph.workerManager.queueHandoff(conn); err != nil { + // Failed to queue handoff, remove the connection + internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) + // Don't pool, remove connection, no error to caller + return false, true, nil + } + + // Check if handoff was already processed by a worker before we can mark it as queued + if !conn.ShouldHandoff() { + // Handoff was already processed - this is normal and the connection should be pooled + return true, false, nil + } + + if err := conn.MarkQueuedForHandoff(); err != nil { + // If marking fails, check if handoff was processed in the meantime + if !conn.ShouldHandoff() { + // Handoff was processed - this is normal, pool the connection + return true, false, nil + } + // Other error - remove the connection + return false, true, nil + } + internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID())) + return true, false, nil +} + +func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { + // Not used +} + +// Shutdown gracefully shuts down the processor, waiting for workers to complete +func (ph *PoolHook) Shutdown(ctx context.Context) error { + return ph.workerManager.shutdownWorkers(ctx) +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/push_notification_handler.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/push_notification_handler.go new file mode 100644 index 000000000..7108265b2 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/push_notification_handler.go @@ -0,0 +1,524 @@ +package maintnotifications + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// NotificationHandler handles push notifications for the simplified manager. +type NotificationHandler struct { + manager *Manager + operationsManager OperationsManagerInterface +} + +// HandlePushNotification processes push notifications with hook support. +func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) == 0 { + internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification)) + return ErrInvalidNotification + } + + notificationType, ok := notification[0].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) + return ErrInvalidNotification + } + + // Process pre-hooks - they can modify the notification or skip processing + modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification) + if !shouldContinue { + return nil // Hooks decided to skip processing + } + + var err error + switch notificationType { + case NotificationMoving: + err = snh.handleMoving(ctx, handlerCtx, modifiedNotification) + case NotificationMigrating: + err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification) + case NotificationMigrated: + err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification) + case NotificationFailingOver: + err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification) + case NotificationFailedOver: + err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification) + case NotificationSMigrating: + err = snh.handleSMigrating(ctx, handlerCtx, modifiedNotification) + case NotificationSMigrated: + err = snh.handleSMigrated(ctx, handlerCtx, modifiedNotification) + default: + // Ignore other notification types (e.g., pub/sub messages) + err = nil + } + + // Record maintenance notification metric + if maintenanceCallback := pool.GetMetricMaintenanceNotificationCallback(); maintenanceCallback != nil { + if conn, ok := handlerCtx.Conn.(*pool.Conn); ok { + maintenanceCallback(ctx, conn, notificationType) + } + } + + // Process post-hooks with the result + snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err) + + return err +} + +// handleMoving processes MOVING notifications. +// MOVING indicates that a connection should be handed off to a new endpoint. +// This is a per-connection notification that triggers connection handoff. +// Expected format: ["MOVING", seqNum, timeS, endpoint] +func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification)) + return ErrInvalidNotification + } + seqID, ok := notification[1].(int64) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) + return ErrInvalidNotification + } + + // Extract timeS + timeS, ok := notification[2].(int64) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) + return ErrInvalidNotification + } + + newEndpoint := "" + if len(notification) > 3 { + // Extract new endpoint + newEndpoint, ok = notification[3].(string) + if !ok { + stringified := fmt.Sprintf("%v", notification[3]) + // this could be which is valid + if notification[3] == nil || stringified == internal.RedisNull { + newEndpoint = "" + } else { + internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) + return ErrInvalidNotification + } + } + } + + // Get the connection that received this notification + conn := handlerCtx.Conn + if conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING")) + return ErrInvalidNotification + } + + // Type assert to get the underlying pool connection + var poolConn *pool.Conn + if pc, ok := conn.(*pool.Conn); ok { + poolConn = pc + } else { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) + return ErrInvalidNotification + } + + // If the connection is closed or not pooled, we can ignore the notification + // this connection won't be remembered by the pool and will be garbage collected + // Keep pubsub connections around since they are not pooled but are long-lived + // and should be allowed to handoff (the pubsub instance will reconnect and change + // the underlying *pool.Conn) + if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() { + return nil + } + + deadline := time.Now().Add(time.Duration(timeS) * time.Second) + // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds + if newEndpoint == "" || newEndpoint == internal.RedisNull { + if internal.LogLevel.DebugOrAbove() { + internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) + } + // same as current endpoint + newEndpoint = snh.manager.options.GetAddr() + // delay the handoff for timeS/2 seconds to the same endpoint + // do this in a goroutine to avoid blocking the notification handler + // NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff + // and there should be no possibility of a race condition or double handoff. + time.AfterFunc(time.Duration(timeS/2)*time.Second, func() { + if poolConn == nil || poolConn.IsClosed() { + return + } + if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { + // Log error but don't fail the goroutine - use background context since original may be cancelled + internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) + return + } + + // Queue the handoff immediately if the connection is idle in the pool. + // If the connection is in use (StateInUse), it will be queued when returned to the pool via OnPut. + // This handles the case where the connection is idle and might never be retrieved again. + if poolConn.GetStateMachine().GetState() == pool.StateIdle { + if snh.manager.poolHooksRef != nil && snh.manager.poolHooksRef.workerManager != nil { + if err := snh.manager.poolHooksRef.workerManager.queueHandoff(poolConn); err != nil { + internal.Logger.Printf(context.Background(), logs.FailedToQueueHandoff(poolConn.GetID(), err)) + } else { + // Mark the connection as queued for handoff to prevent it from being retrieved + // This transitions the connection to StateUnusable + if err := poolConn.MarkQueuedForHandoff(); err != nil { + internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) + } else { + internal.Logger.Printf(context.Background(), logs.MarkedForHandoff(poolConn.GetID())) + } + } + } + } + // If connection is StateInUse, the handoff will be queued when it's returned to the pool + }) + return nil + } + + return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline) +} + +func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { + if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { + internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) + // Connection is already marked for handoff, which is acceptable + // This can happen if multiple MOVING notifications are received for the same connection + return nil + } + // Optionally track in m + if snh.operationsManager != nil { + connID := conn.GetID() + // Track the operation (ignore errors since this is optional) + _ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID) + } else { + return errors.New(logs.ManagerNotInitialized()) + } + return nil +} + +// handleMigrating processes MIGRATING notifications. +// MIGRATING indicates that a connection migration is starting. +// This is a per-connection notification that applies relaxed timeouts. +// Expected format: ["MIGRATING", ...] +func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification)) + return ErrInvalidNotification + } + + if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) + return ErrInvalidNotification + } + + conn, ok := handlerCtx.Conn.(*pool.Conn) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) + } + conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + + // Record relaxed timeout metric + if relaxedTimeoutCallback := pool.GetMetricConnectionRelaxedTimeoutCallback(); relaxedTimeoutCallback != nil { + relaxedTimeoutCallback(ctx, 1, conn, PoolNameMain, "MIGRATING") + } + + return nil +} + +// handleMigrated processes MIGRATED notifications. +// MIGRATED indicates that a connection migration has completed. +// This is a per-connection notification that clears relaxed timeouts. +// Expected format: ["MIGRATED", ...] +func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification)) + return ErrInvalidNotification + } + + if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) + return ErrInvalidNotification + } + + conn, ok := handlerCtx.Conn.(*pool.Conn) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx)) + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + if internal.LogLevel.InfoOrAbove() { + connID := conn.GetID() + internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) + } + conn.ClearRelaxedTimeout() + return nil +} + +// handleFailingOver processes FAILING_OVER notifications. +// FAILING_OVER indicates that a failover is starting. +// This is a per-connection notification that applies relaxed timeouts. +// Expected format: ["FAILING_OVER", ...] +func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) + return ErrInvalidNotification + } + + if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) + return ErrInvalidNotification + } + + conn, ok := handlerCtx.Conn.(*pool.Conn) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + if internal.LogLevel.InfoOrAbove() { + connID := conn.GetID() + internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) + } + conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + + // Record relaxed timeout metric + if relaxedTimeoutCallback := pool.GetMetricConnectionRelaxedTimeoutCallback(); relaxedTimeoutCallback != nil { + relaxedTimeoutCallback(ctx, 1, conn, PoolNameMain, "FAILING_OVER") + } + + return nil +} + +// handleFailedOver processes FAILED_OVER notifications. +// FAILED_OVER indicates that a failover has completed. +// This is a per-connection notification that clears relaxed timeouts. +// Expected format: ["FAILED_OVER", ...] +func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) + return ErrInvalidNotification + } + + if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) + return ErrInvalidNotification + } + + conn, ok := handlerCtx.Conn.(*pool.Conn) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + if internal.LogLevel.InfoOrAbove() { + connID := conn.GetID() + internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) + } + conn.ClearRelaxedTimeout() + return nil +} + +// handleSMigrating processes SMIGRATING notifications. +// SMIGRATING indicates that a cluster slot is in the process of migrating to a different node. +// This is a per-connection notification that applies relaxed timeouts during slot migration. +// Expected format: ["SMIGRATING", SeqID, slot/range1-range2, ...] +func (snh *NotificationHandler) handleSMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATING", notification)) + return ErrInvalidNotification + } + + // Validate SeqID (position 1) + if _, ok := notification[1].(int64); !ok { + internal.Logger.Printf(ctx, logs.InvalidSeqIDInSMigratingNotification(notification[1])) + return ErrInvalidNotification + } + + if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("SMIGRATING")) + return ErrInvalidNotification + } + + conn, ok := handlerCtx.Conn.(*pool.Conn) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("SMIGRATING", handlerCtx.Conn, handlerCtx)) + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "SMIGRATING", snh.manager.config.RelaxedTimeout)) + } + conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + return nil +} + +// handleSMigrated processes SMIGRATED notifications. +// SMIGRATED indicates that a cluster slot has finished migrating to a different node. +// This is a cluster-level notification that triggers cluster state reload. +// +// Expected RESP3 format: +// +// >3 +// +SMIGRATED +// :SeqID +// * <- array of triplet arrays +// *3 <- each triplet is a 3-element array +// + <- node from which slots are migrating FROM +// + <- node to which slots are migrating TO +// + <- comma-separated slots and/or ranges (e.g., "123,789-1000") +// +// A source and target endpoint may appear in multiple triplets. +// The notification is only processed if the connection's NodeAddress matches one of the source endpoints. +// +// Note: Multiple connections may receive the same notification, so we deduplicate by SeqID before triggering reload. +// but we still process the notification on each connection to clear the relaxed timeout. +// In the case when the connection is from MOVED/ASK, the connection's original endpoint is not set, +// so we will not be able to match the source endpoint. In such case, we will trigger the reload callback with the first target endpoint. +func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // Expected: ["SMIGRATED", SeqID, [[source, target, slots], ...]] + // Minimum 3 elements: SMIGRATED, SeqID, and the array of triplets + if len(notification) < 3 { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED", notification)) + return ErrInvalidNotification + } + + // Extract SeqID (position 1) + seqID, ok := notification[1].(int64) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidSeqIDInSMigratedNotification(notification[1])) + return ErrInvalidNotification + } + + // Extract the array of triplets (position 2) + triplets, ok := notification[2].([]interface{}) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (triplets array)", notification[2])) + return ErrInvalidNotification + } + + if len(triplets) == 0 { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (empty triplets)", notification)) + return ErrInvalidNotification + } + + // Get the connection's endpoints to check if this notification is relevant + // We check against both nodeAddress (from CLUSTER SLOTS) and addr (after resolution) + // since we cannot be certain which format the notification source will use + var connectionNodeAddress string + var connectionAddr string + if snh.manager.options != nil { + connectionNodeAddress = snh.manager.options.GetNodeAddress() + connectionAddr = snh.manager.options.GetAddr() + } + + // Helper function to check if source matches either of our endpoints + // notification source can be either the node address or the addr after resolution + sourceMatchesConnection := func(source string) bool { + if source == connectionNodeAddress { + return true + } + if source == connectionAddr { + return true + } + return false + } + + // Parse triplets and check if any source matches our connection's endpoints + var matchingTriplets []struct { + source string + target string + slots string + } + var allSlotRanges []string + + for _, tripletInterface := range triplets { + // Each triplet should be a 3-element array: [source, target, slots] + triplet, ok := tripletInterface.([]interface{}) + if !ok || len(triplet) != 3 { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (triplet format)", tripletInterface)) + continue + } + + // Extract source endpoint + source, ok := triplet[0].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (source)", triplet[0])) + continue + } + + // Extract target endpoint + target, ok := triplet[1].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (target)", triplet[1])) + continue + } + + // Extract slots + slots, ok := triplet[2].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (slots)", triplet[2])) + continue + } + + // Check if this triplet's source matches our connection's endpoints + if sourceMatchesConnection(source) { + matchingTriplets = append(matchingTriplets, struct { + source string + target string + slots string + }{source, target, slots}) + slotRanges := strings.Split(slots, ",") + allSlotRanges = append(allSlotRanges, slotRanges...) + } + } + + var connID uint64 + // Reset relaxed timeout for this specific connection + if handlerCtx.Conn != nil { + conn, ok := handlerCtx.Conn.(*pool.Conn) + if ok { + if internal.LogLevel.InfoOrAbove() { + connID = conn.GetID() + internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) + } + conn.ClearRelaxedTimeout() + } + } + + // If no matching triplets, this notification is not relevant to this connection + if len(matchingTriplets) == 0 { + return nil + } + + // Deduplicate by SeqID - multiple connections may receive the same notification + // Only trigger cluster state reload once per seqID + if snh.manager.MarkSMigratedSeqIDProcessed(seqID) { + // Use the first matching triplet + target := matchingTriplets[0].target + slotsForLog := allSlotRanges + + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(ctx, logs.TriggeringClusterStateReload(seqID, target, slotsForLog)) + } + + // Trigger cluster state reload via callback + snh.manager.TriggerClusterStateReload(ctx, target, slotsForLog) + } + + return nil +} diff --git a/vendor/github.com/redis/go-redis/v9/maintnotifications/state.go b/vendor/github.com/redis/go-redis/v9/maintnotifications/state.go new file mode 100644 index 000000000..8180bcd97 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/maintnotifications/state.go @@ -0,0 +1,24 @@ +package maintnotifications + +// State represents the current state of a maintenance operation +type State int + +const ( + // StateIdle indicates no upgrade is in progress + StateIdle State = iota + + // StateHandoff indicates a connection handoff is in progress + StateMoving +) + +// String returns a string representation of the state. +func (s State) String() string { + switch s { + case StateIdle: + return "idle" + case StateMoving: + return "moving" + default: + return "unknown" + } +} diff --git a/vendor/github.com/redis/go-redis/v9/options.go b/vendor/github.com/redis/go-redis/v9/options.go index 3ffcd07ed..ba45a0cb8 100644 --- a/vendor/github.com/redis/go-redis/v9/options.go +++ b/vendor/github.com/redis/go-redis/v9/options.go @@ -5,17 +5,34 @@ import ( "crypto/tls" "errors" "fmt" + "maps" "net" "net/url" "runtime" - "sort" + "slices" "strconv" "strings" + "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/maintnotifications" + "github.com/redis/go-redis/v9/push" ) +// poolIDCounter is a global auto-increment counter for generating unique pool IDs. +var poolIDCounter atomic.Uint64 + +// generateUniqueID generates a short unique identifier for pool names using auto-increment. +// This makes it easier to identify and track pools in order of creation. +func generateUniqueID() string { + id := poolIDCounter.Add(1) + return strconv.FormatUint(id, 10) +} + // Limiter is the interface of a rate limiter or a circuit breaker. type Limiter interface { // Allow returns nil if operation is allowed or an error otherwise. @@ -29,12 +46,25 @@ type Limiter interface { // Options keeps the settings to set up redis connection. type Options struct { - // The network type, either tcp or unix. - // Default is tcp. + // Network type, either tcp or unix. + // + // default: is tcp. Network string - // host:port address. + + // Addr is the address formated as host:port Addr string + // NodeAddress is the address of the Redis node as reported by the server. + // For cluster clients, this is the exact endpoint string returned by CLUSTER SLOTS + // before any resolution or transformation (e.g., loopback replacement). + // For standalone clients, this defaults to Addr. + // + // This is used to match the source endpoint in maintenance notifications + // (e.g. SMIGRATED). + // + // Use Client.NodeAddress() to access this value. + NodeAddress string + // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. ClientName string @@ -46,17 +76,21 @@ type Options struct { OnConnect func(ctx context.Context, cn *Conn) error // Protocol 2 or 3. Use the version to negotiate RESP version with redis-server. - // Default is 3. + // + // default: 3. Protocol int - // Use the specified Username to authenticate the current connection + + // Username is used to authenticate the current connection // with one of the connections defined in the ACL list when connecting // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. Username string - // Optional password. Must match the password specified in the - // requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), + + // Password is an optional password. Must match the password specified in the + // `requirepass` server configuration option (if connecting to a Redis 5.0 instance, or lower), // or the User Password when connecting to a Redis 6.0 instance, or greater, // that is using the Redis ACL system. Password string + // CredentialsProvider allows the username and password to be updated // before reconnecting. It should return the current username and password. CredentialsProvider func() (username string, password string) @@ -67,85 +101,177 @@ type Options struct { // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) - // Database to be selected after connecting to the server. + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + + // DB is the database to be selected after connecting to the server. DB int - // Maximum number of retries before giving up. - // Default is 3 retries; -1 (not 0) disables retries. + // MaxRetries is the maximum number of retries before giving up. + // -1 (not 0) disables retries. + // + // default: 3 retries MaxRetries int - // Minimum backoff between each retry. - // Default is 8 milliseconds; -1 disables backoff. + + // MinRetryBackoff is the minimum backoff between each retry. + // -1 disables backoff. + // + // default: 8 milliseconds MinRetryBackoff time.Duration - // Maximum backoff between each retry. - // Default is 512 milliseconds; -1 disables backoff. + + // MaxRetryBackoff is the maximum backoff between each retry. + // -1 disables backoff. + // default: 512 milliseconds; MaxRetryBackoff time.Duration - // Dial timeout for establishing new connections. - // Default is 5 seconds. + // DialTimeout for establishing new connections. + // + // default: 5 seconds DialTimeout time.Duration - // Timeout for socket reads. If reached, commands will fail + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + + // DialerRetryBackoff controls the delay between dial retry attempts. + // + // attempt is 0-based: attempt=0 is the delay after the 1st failed dial (before the 2nd attempt). + // + // If nil, dial retry backoff is constant and equals DialerRetryTimeout (default: 100ms). + DialerRetryBackoff func(attempt int) time.Duration + + // ReadTimeout for socket reads. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetReadDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetReadDeadline calls completely. + // + // default: 3 seconds ReadTimeout time.Duration - // Timeout for socket writes. If reached, commands will fail + + // WriteTimeout for socket writes. If reached, commands will fail // with a timeout instead of blocking. Supported values: - // - `0` - default timeout (3 seconds). - // - `-1` - no timeout (block indefinitely). - // - `-2` - disables SetWriteDeadline calls completely. + // + // - `-1` - no timeout (block indefinitely). + // - `-2` - disables SetWriteDeadline calls completely. + // + // default: 3 seconds WriteTimeout time.Duration + // ContextTimeoutEnabled controls whether the client respects context timeouts and deadlines. // See https://redis.uptrace.dev/guide/go-redis-debugging.html#timeouts ContextTimeoutEnabled bool - // Type of connection pool. - // true for FIFO pool, false for LIFO pool. + // ReadBufferSize is the size of the bufio.Reader buffer for each connection. + // Larger buffers can improve performance for commands that return large responses. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + ReadBufferSize int + + // WriteBufferSize is the size of the bufio.Writer buffer for each connection. + // Larger buffers can improve performance for large pipelines and commands with many arguments. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + WriteBufferSize int + + // PoolFIFO type of connection pool. + // + // - true for FIFO pool + // - false for LIFO pool. + // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. + // default: false PoolFIFO bool - // Base number of socket connections. + + // PoolSize is the base number of socket connections. // Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. // If there is not enough connections in the pool, new connections will be allocated in excess of PoolSize, // you can limit it through MaxActiveConns + // + // default: 10 * runtime.GOMAXPROCS(0) PoolSize int - // Amount of time client waits for connection if all connections + + // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines. + // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize. + MaxConcurrentDials int + + // PoolTimeout is the amount of time client waits for connection if all connections // are busy before returning an error. - // Default is ReadTimeout + 1 second. + // + // default: ReadTimeout + 1 second PoolTimeout time.Duration - // Minimum number of idle connections which is useful when establishing - // new connection is slow. - // Default is 0. the idle connections are not closed by default. + + // MinIdleConns is the minimum number of idle connections which is useful when establishing + // new connection is slow. The idle connections are not closed by default. + // + // default: 0 MinIdleConns int - // Maximum number of idle connections. - // Default is 0. the idle connections are not closed by default. + + // MaxIdleConns is the maximum number of idle connections. + // The idle connections are not closed by default. + // + // default: 0 MaxIdleConns int - // Maximum number of connections allocated by the pool at a given time. + + // MaxActiveConns is the maximum number of connections allocated by the pool at a given time. // When zero, there is no limit on the number of connections in the pool. + // If the pool is full, the next call to Get() will block until a connection is released. + // + // default: 0 MaxActiveConns int + // ConnMaxIdleTime is the maximum amount of time a connection may be idle. // Should be less than server's timeout. // // Expired connections may be closed lazily before reuse. // If d <= 0, connections are not closed due to a connection's idle time. + // -1 disables idle timeout check. // - // Default is 30 minutes. -1 disables idle timeout check. + // default: 30 minutes ConnMaxIdleTime time.Duration + // ConnMaxLifetime is the maximum amount of time a connection may be reused. // // Expired connections may be closed lazily before reuse. // If <= 0, connections are not closed due to a connection's age. // - // Default is to not close idle connections. + // default: 0 ConnMaxLifetime time.Duration - // TLS Config to use. When set, TLS will be negotiated. + // ConnMaxLifetimeJitter is the absolute jitter duration applied to ConnMaxLifetime + // to prevent all connections from expiring simultaneously. + // + // The jitter is applied as a random offset in the range [-jitter, +jitter]. + // For example, if ConnMaxLifetime is 1 hour and ConnMaxLifetimeJitter is 6 minutes, + // connections will expire between 54 minutes and 66 minutes. + // + // If <= 0, no jitter is applied. + // If > ConnMaxLifetime, it will be capped at ConnMaxLifetime. + // + // default: 0 + ConnMaxLifetimeJitter time.Duration + + // TLSConfig to use. When set, TLS will be negotiated. TLSConfig *tls.Config // Limiter interface used to implement circuit breaker or rate limiter. Limiter Limiter - // Enables read only queries on slave/follower nodes. + // readOnly enables read only queries on slave/follower nodes. readOnly bool // DisableIndentity - Disable set-lib on connect. @@ -161,10 +287,32 @@ type Options struct { DisableIdentity bool // Add suffix to client name. Default is empty. + // IdentitySuffix - add suffix to client name. IdentitySuffix string - // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. + // Deprecated: All RediSearch commands now have stable RESP3 parsing and this + // flag is a no-op. It is kept for backwards compatibility and will be removed + // in a future release. UnstableResp3 bool + + // Push notifications are always enabled for RESP3 connections (Protocol: 3) + // and are not available for RESP2 connections. No configuration option is needed. + + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created for RESP3 connections. + PushNotificationProcessor push.NotificationProcessor + + // FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing. + // When a node is marked as failing, it will be avoided for this duration. + // Default is 15 seconds. + FailingTimeoutSeconds int + + // MaintNotificationsConfig provides custom configuration for maintnotifications. + // When MaintNotificationsConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it. + MaintNotificationsConfig *maintnotifications.Config } func (opt *Options) init() { @@ -178,15 +326,41 @@ func (opt *Options) init() { opt.Network = "tcp" } } + // For standalone clients, default NodeAddress to Addr if not set. + // This ensures maintenance notifications (SMIGRATED, etc.) can match + // the connection's endpoint even for non-cluster clients. + if opt.NodeAddress == "" { + opt.NodeAddress = opt.Addr + } + if opt.Protocol < 2 { + opt.Protocol = 3 + } if opt.DialTimeout == 0 { opt.DialTimeout = 5 * time.Second } + if opt.DialerRetries == 0 { + opt.DialerRetries = 5 + } + if opt.DialerRetryTimeout == 0 { + opt.DialerRetryTimeout = 100 * time.Millisecond + } if opt.Dialer == nil { opt.Dialer = NewDialer(opt) } if opt.PoolSize == 0 { opt.PoolSize = 10 * runtime.GOMAXPROCS(0) } + if opt.MaxConcurrentDials <= 0 { + opt.MaxConcurrentDials = opt.PoolSize + } else if opt.MaxConcurrentDials > opt.PoolSize { + opt.MaxConcurrentDials = opt.PoolSize + } + if opt.ReadBufferSize == 0 { + opt.ReadBufferSize = proto.DefaultBufferSize + } + if opt.WriteBufferSize == 0 { + opt.WriteBufferSize = proto.DefaultBufferSize + } switch opt.ReadTimeout { case -2: opt.ReadTimeout = -1 @@ -214,6 +388,8 @@ func (opt *Options) init() { opt.ConnMaxIdleTime = 30 * time.Minute } + opt.ConnMaxLifetimeJitter = min(opt.ConnMaxLifetimeJitter, opt.ConnMaxLifetime) + switch opt.MaxRetries { case -1: opt.MaxRetries = 0 @@ -232,13 +408,40 @@ func (opt *Options) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + + if opt.FailingTimeoutSeconds == 0 { + opt.FailingTimeoutSeconds = 15 + } + + opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns) + + // auto-detect endpoint type if not specified + endpointType := opt.MaintNotificationsConfig.EndpointType + if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto { + // Auto-detect endpoint type if not specified + endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil) + } + opt.MaintNotificationsConfig.EndpointType = endpointType } func (opt *Options) clone() *Options { clone := *opt + + // Deep clone MaintNotificationsConfig to avoid sharing between clients + if opt.MaintNotificationsConfig != nil { + configClone := *opt.MaintNotificationsConfig + clone.MaintNotificationsConfig = &configClone + } + return &clone } +// NewDialer returns a function that will be used as the default dialer +// when none is specified in Options.Dialer. +func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) { + return NewDialer(opt) +} + // NewDialer returns a function that will be used as the default dialer // when none is specified in Options.Dialer. func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) { @@ -450,11 +653,8 @@ func (o *queryOptions) remaining() []string { if len(o.q) == 0 { return nil } - keys := make([]string, 0, len(o.q)) - for k := range o.q { - keys = append(keys, k) - } - sort.Strings(keys) + keys := slices.Collect(maps.Keys(o.q)) + slices.Sort(keys) return keys } @@ -485,6 +685,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) { o.MinIdleConns = q.int("min_idle_conns") o.MaxIdleConns = q.int("max_idle_conns") o.MaxActiveConns = q.int("max_active_conns") + o.MaxConcurrentDials = q.int("max_concurrent_dials") if q.has("conn_max_idle_time") { o.ConnMaxIdleTime = q.duration("conn_max_idle_time") } else { @@ -495,6 +696,9 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) { } else { o.ConnMaxLifetime = q.duration("max_conn_age") } + if q.has("conn_max_lifetime_jitter") { + o.ConnMaxLifetimeJitter = min(q.duration("conn_max_lifetime_jitter"), o.ConnMaxLifetime) + } if q.err != nil { return nil, q.err } @@ -524,19 +728,96 @@ func getUserPassword(u *url.URL) (string, string) { func newConnPool( opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), -) *pool.ConnPool { + poolName string, +) (*pool.ConnPool, error) { + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { return dialer(ctx, opt.Network, opt.Addr) }, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - DialTimeout: opt.DialTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, - }) + PoolFIFO: opt.PoolFIFO, + PoolSize: poolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + PushNotificationsEnabled: opt.Protocol == 3, + Name: poolName, + }), nil +} + +func newPubSubPool( + opt *Options, + dialer func(ctx context.Context, network, addr string) (net.Conn, error), + poolName string, +) (*pool.PubSubPool, error) { + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + + return pool.NewPubSubPool(&pool.Options{ + PoolFIFO: opt.PoolFIFO, + PoolSize: poolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, + ReadBufferSize: 32 * 1024, + WriteBufferSize: 32 * 1024, + PushNotificationsEnabled: opt.Protocol == 3, + Name: poolName, + }, dialer), nil } diff --git a/vendor/github.com/redis/go-redis/v9/osscluster.go b/vendor/github.com/redis/go-redis/v9/osscluster.go index c0278ed05..efd52960c 100644 --- a/vendor/github.com/redis/go-redis/v9/osscluster.go +++ b/vendor/github.com/redis/go-redis/v9/osscluster.go @@ -1,31 +1,43 @@ package redis import ( + "cmp" "context" "crypto/tls" + "errors" "fmt" "math" + "math/rand" "net" "net/url" "runtime" + "slices" "sort" "strings" "sync" "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/otel" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" - "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/internal/routing" + "github.com/redis/go-redis/v9/maintnotifications" + "github.com/redis/go-redis/v9/push" ) const ( minLatencyMeasurementInterval = 10 * time.Second ) -var errClusterNoNodes = fmt.Errorf("redis: cluster has no nodes") +var ( + errClusterNoNodes = errors.New("redis: cluster has no nodes") + errNoWatchKeys = errors.New("redis: Watch requires at least one key") + errWatchCrosslot = errors.New("redis: Watch requires all keys to be in the same slot") +) // ClusterOptions are used to configure a cluster client and should be // passed to NewClusterClient. @@ -37,6 +49,7 @@ type ClusterOptions struct { ClientName string // NewClient creates a cluster node client with provided name and options. + // If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications. NewClient func(opt *Options) *Client // The maximum number of retries before giving up. Command is retried @@ -66,32 +79,76 @@ type ClusterOptions struct { OnConnect func(ctx context.Context, cn *Conn) error - Protocol int - Username string - Password string - CredentialsProvider func() (username string, password string) - CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) - + Protocol int + Username string + Password string + CredentialsProvider func() (username string, password string) + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + StreamingCredentialsProvider auth.StreamingCredentialsProvider + + // MaxRetries is the maximum number of retries before giving up. + // For ClusterClient, retries are disabled by default (set to -1), + // because the cluster client handles all kinds of retries internally. + // This is intentional and differs from the standalone Options default. MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration - DialTimeout time.Duration + DialTimeout time.Duration + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + + // DialerRetryBackoff controls the delay between dial retry attempts. + // See Options.DialerRetryBackoff for details. + DialerRetryBackoff func(attempt int) time.Duration + ReadTimeout time.Duration WriteTimeout time.Duration ContextTimeoutEnabled bool - PoolFIFO bool - PoolSize int // applies per cluster node and not for the whole cluster - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int // applies per cluster node and not for the whole cluster - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration + // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines. + // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize. + MaxConcurrentDials int + + PoolFIFO bool + PoolSize int // applies per cluster node and not for the whole cluster + PoolTimeout time.Duration + MinIdleConns int + MaxIdleConns int + MaxActiveConns int // applies per cluster node and not for the whole cluster + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + ConnMaxLifetimeJitter time.Duration + + // ReadBufferSize is the size of the bufio.Reader buffer for each connection. + // Larger buffers can improve performance for commands that return large responses. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + ReadBufferSize int + + // WriteBufferSize is the size of the bufio.Writer buffer for each connection. + // Larger buffers can improve performance for large pipelines and commands with many arguments. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + WriteBufferSize int TLSConfig *tls.Config + // DisableRoutingPolicies disables the request/response policy routing system. + // When disabled, all commands use the legacy routing behavior. + // Experimental. Will be removed when shard picker is fully implemented. + DisableRoutingPolicies bool + // DisableIndentity - Disable set-lib on connect. // // default: false @@ -106,8 +163,35 @@ type ClusterOptions struct { IdentitySuffix string // Add suffix to client name. Default is empty. - // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. + // Deprecated: All RediSearch commands now have stable RESP3 parsing and this + // flag is a no-op. It is kept for backwards compatibility and will be removed + // in a future release. UnstableResp3 bool + + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created for RESP3 connections. + PushNotificationProcessor push.NotificationProcessor + + // FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing. + // When a node is marked as failing, it will be avoided for this duration. + // Default is 15 seconds. + FailingTimeoutSeconds int + + // MaintNotificationsConfig provides custom configuration for maintnotifications upgrades. + // When MaintNotificationsConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it. + // The ClusterClient supports SMIGRATING and SMIGRATED notifications for cluster state management. + // Individual node clients handle other maintenance notifications (MOVING, MIGRATING, etc.). + MaintNotificationsConfig *maintnotifications.Config + // ShardPicker is used to pick a shard when the request_policy is + // ReqDefault and the command has no keys. + ShardPicker routing.ShardPicker + + // ClusterStateReloadInterval is the interval for reloading the cluster state. + // Default is 10 seconds. + ClusterStateReloadInterval time.Duration } func (opt *ClusterOptions) init() { @@ -122,9 +206,30 @@ func (opt *ClusterOptions) init() { opt.ReadOnly = true } + if opt.DialTimeout == 0 { + opt.DialTimeout = 5 * time.Second + } + if opt.DialerRetries == 0 { + opt.DialerRetries = 5 + } + if opt.DialerRetryTimeout == 0 { + opt.DialerRetryTimeout = 100 * time.Millisecond + } + if opt.PoolSize == 0 { opt.PoolSize = 5 * runtime.GOMAXPROCS(0) } + if opt.MaxConcurrentDials <= 0 { + opt.MaxConcurrentDials = opt.PoolSize + } else if opt.MaxConcurrentDials > opt.PoolSize { + opt.MaxConcurrentDials = opt.PoolSize + } + if opt.ReadBufferSize == 0 { + opt.ReadBufferSize = proto.DefaultBufferSize + } + if opt.WriteBufferSize == 0 { + opt.WriteBufferSize = proto.DefaultBufferSize + } switch opt.ReadTimeout { case -1: @@ -158,6 +263,18 @@ func (opt *ClusterOptions) init() { if opt.NewClient == nil { opt.NewClient = NewClient } + + if opt.FailingTimeoutSeconds == 0 { + opt.FailingTimeoutSeconds = 15 + } + + if opt.ShardPicker == nil { + opt.ShardPicker = &routing.RoundRobinPicker{} + } + + if opt.ClusterStateReloadInterval == 0 { + opt.ClusterStateReloadInterval = 10 * time.Second + } } // ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis. @@ -252,16 +369,23 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er o.MinRetryBackoff = q.duration("min_retry_backoff") o.MaxRetryBackoff = q.duration("max_retry_backoff") o.DialTimeout = q.duration("dial_timeout") + o.DialerRetries = q.int("dialer_retries") + o.DialerRetryTimeout = q.duration("dialer_retry_timeout") o.ReadTimeout = q.duration("read_timeout") o.WriteTimeout = q.duration("write_timeout") o.PoolFIFO = q.bool("pool_fifo") o.PoolSize = q.int("pool_size") + o.MaxConcurrentDials = q.int("max_concurrent_dials") o.MinIdleConns = q.int("min_idle_conns") o.MaxIdleConns = q.int("max_idle_conns") o.MaxActiveConns = q.int("max_active_conns") o.PoolTimeout = q.duration("pool_timeout") o.ConnMaxLifetime = q.duration("conn_max_lifetime") + if q.has("conn_max_lifetime_jitter") { + o.ConnMaxLifetimeJitter = min(q.duration("conn_max_lifetime_jitter"), o.ConnMaxLifetime) + } o.ConnMaxIdleTime = q.duration("conn_max_idle_time") + o.FailingTimeoutSeconds = q.int("failing_timeout_seconds") if q.err != nil { return nil, q.err @@ -287,45 +411,64 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er } func (opt *ClusterOptions) clientOptions() *Options { + // Clone MaintNotificationsConfig to avoid sharing between cluster node clients + var maintNotificationsConfig *maintnotifications.Config + if opt.MaintNotificationsConfig != nil { + configClone := *opt.MaintNotificationsConfig + maintNotificationsConfig = &configClone + } + return &Options{ ClientName: opt.ClientName, Dialer: opt.Dialer, OnConnect: opt.OnConnect, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, - CredentialsProvider: opt.CredentialsProvider, - CredentialsProviderContext: opt.CredentialsProviderContext, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, - DialTimeout: opt.DialTimeout, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, - DisableIdentity: opt.DisableIdentity, - DisableIndentity: opt.DisableIdentity, - IdentitySuffix: opt.IdentitySuffix, - TLSConfig: opt.TLSConfig, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + DisableIdentity: opt.DisableIdentity, + DisableIndentity: opt.DisableIdentity, + IdentitySuffix: opt.IdentitySuffix, + FailingTimeoutSeconds: opt.FailingTimeoutSeconds, + TLSConfig: opt.TLSConfig, // If ClusterSlots is populated, then we probably have an artificial // cluster whose nodes are not in clustering mode (otherwise there isn't // much use for ClusterSlots config). This means we cannot execute the // READONLY command against that node -- setting readOnly to false in such // situations in the options below will prevent that from happening. - readOnly: opt.ReadOnly && opt.ClusterSlots == nil, - UnstableResp3: opt.UnstableResp3, + readOnly: opt.ReadOnly && opt.ClusterSlots == nil, + UnstableResp3: opt.UnstableResp3, + MaintNotificationsConfig: maintNotificationsConfig, + PushNotificationProcessor: opt.PushNotificationProcessor, } } @@ -337,15 +480,16 @@ type clusterNode struct { latency uint32 // atomic generation uint32 // atomic failing uint32 // atomic + loaded uint32 // atomic - // last time the latency measurement was performed for the node, stored in nanoseconds - // from epoch + // last time the latency measurement was performed for the node, stored in nanoseconds from epoch lastLatencyMeasurement int64 // atomic } -func newClusterNode(clOpt *ClusterOptions, addr string) *clusterNode { +func newClusterNodeWithNodeAddress(clOpt *ClusterOptions, addr, nodeAddress string) *clusterNode { opt := clOpt.clientOptions() opt.Addr = addr + opt.NodeAddress = nodeAddress node := clusterNode{ Client: clOpt.NewClient(opt), } @@ -388,7 +532,7 @@ func (n *clusterNode) updateLatency() { if successes == 0 { // If none of the pings worked, set latency to some arbitrarily high value so this node gets // least priority. - latency = float64((maximumNodeLatency) / time.Microsecond) + latency = float64(maximumNodeLatency / time.Microsecond) } else { latency = float64(dur) / float64(successes) } @@ -403,10 +547,11 @@ func (n *clusterNode) Latency() time.Duration { func (n *clusterNode) MarkAsFailing() { atomic.StoreUint32(&n.failing, uint32(time.Now().Unix())) + atomic.StoreUint32(&n.loaded, 0) } func (n *clusterNode) Failing() bool { - const timeout = 15 // 15 seconds + timeout := int64(n.Client.opt.FailingTimeoutSeconds) failing := atomic.LoadUint32(&n.failing) if failing == 0 { @@ -445,6 +590,24 @@ func (n *clusterNode) SetLastLatencyMeasurement(t time.Time) { } } +func (n *clusterNode) Loading() bool { + loaded := atomic.LoadUint32(&n.loaded) + if loaded == 1 { + return false + } + + // check if the node is loading + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := n.Client.Ping(ctx).Err() + loading := err != nil && isLoadingError(err) + if !loading { + atomic.StoreUint32(&n.loaded, 1) + } + return loading +} + //------------------------------------------------------------------------------ type clusterNodes struct { @@ -457,13 +620,12 @@ type clusterNodes struct { closed bool onNewNode []func(rdb *Client) - _generation uint32 // atomic + generation uint32 // atomic } func newClusterNodes(opt *ClusterOptions) *clusterNodes { return &clusterNodes{ - opt: opt, - + opt: opt, addrs: opt.Addrs, nodes: make(map[string]*clusterNode), } @@ -523,12 +685,11 @@ func (c *clusterNodes) Addrs() ([]string, error) { } func (c *clusterNodes) NextGeneration() uint32 { - return atomic.AddUint32(&c._generation, 1) + return atomic.AddUint32(&c.generation, 1) } // GC removes unused nodes. func (c *clusterNodes) GC(generation uint32) { - //nolint:prealloc var collected []*clusterNode c.mu.Lock() @@ -556,6 +717,10 @@ func (c *clusterNodes) GC(generation uint32) { } func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { + return c.GetOrCreateWithNodeAddress(addr, "") +} + +func (c *clusterNodes) GetOrCreateWithNodeAddress(addr, nodeAddress string) (*clusterNode, error) { node, err := c.get(addr) if err != nil { return nil, err @@ -576,28 +741,25 @@ func (c *clusterNodes) GetOrCreate(addr string) (*clusterNode, error) { return node, nil } - node = newClusterNode(c.opt, addr) + node = newClusterNodeWithNodeAddress(c.opt, addr, nodeAddress) for _, fn := range c.onNewNode { fn(node.Client) } - c.addrs = appendIfNotExists(c.addrs, addr) + c.addrs = appendIfNotExist(c.addrs, addr) c.nodes[addr] = node return node, nil } func (c *clusterNodes) get(addr string) (*clusterNode, error) { - var node *clusterNode - var err error c.mu.RLock() + defer c.mu.RUnlock() + if c.closed { - err = pool.ErrClosed - } else { - node = c.nodes[addr] + return nil, pool.ErrClosed } - c.mu.RUnlock() - return node, err + return c.nodes[addr], nil } func (c *clusterNodes) All() ([]*clusterNode, error) { @@ -628,22 +790,9 @@ func (c *clusterNodes) Random() (*clusterNode, error) { //------------------------------------------------------------------------------ type clusterSlot struct { - start, end int - nodes []*clusterNode -} - -type clusterSlotSlice []*clusterSlot - -func (p clusterSlotSlice) Len() int { - return len(p) -} - -func (p clusterSlotSlice) Less(i, j int) bool { - return p[i].start < p[j].start -} - -func (p clusterSlotSlice) Swap(i, j int) { - p[i], p[j] = p[j], p[i] + start int + end int + nodes []*clusterNode } type clusterState struct { @@ -669,18 +818,25 @@ func newClusterState( createdAt: time.Now(), } - originHost, _, _ := net.SplitHostPort(origin) + originHost, originPort, _ := net.SplitHostPort(origin) isLoopbackOrigin := isLoopback(originHost) for _, slot := range slots { var nodes []*clusterNode for i, slotNode := range slot.Nodes { - addr := slotNode.Addr + // slotNode.Addr is the node address from CLUSTER SLOTS + nodeAddress := slotNode.Addr + addr := nodeAddress if !isLoopbackOrigin { addr = replaceLoopbackHost(addr, originHost) } + // TLS-only clusters (`--port 0 --tls-port 6379`) report port 0 + // in CLUSTER SLOTS. Fall back to the origin port — by definition + // reachable, since it is the port that returned this slot map. + // See https://github.com/redis/go-redis/issues/3726. + addr = replaceZeroPort(addr, originPort) - node, err := c.nodes.GetOrCreate(addr) + node, err := c.nodes.GetOrCreateWithNodeAddress(addr, nodeAddress) if err != nil { return nil, err } @@ -689,9 +845,9 @@ func newClusterState( nodes = append(nodes, node) if i == 0 { - c.Masters = appendUniqueNode(c.Masters, node) + c.Masters = appendIfNotExist(c.Masters, node) } else { - c.Slaves = appendUniqueNode(c.Slaves, node) + c.Slaves = appendIfNotExist(c.Slaves, node) } } @@ -702,7 +858,9 @@ func newClusterState( }) } - sort.Sort(clusterSlotSlice(c.slots)) + slices.SortFunc(c.slots, func(a, b *clusterSlot) int { + return cmp.Compare(a.start, b.start) + }) time.AfterFunc(time.Minute, func() { nodes.GC(c.generation) @@ -730,12 +888,40 @@ func replaceLoopbackHost(nodeAddr, originHost string) string { return net.JoinHostPort(originHost, nodePort) } +// replaceZeroPort substitutes originPort for a node port of "0", which is +// what CLUSTER SLOTS reports for TLS-only clusters started with +// `--port 0 --tls-port `. Non-zero ports and addresses without a +// recoverable origin port are returned unchanged. +func replaceZeroPort(nodeAddr, originPort string) string { + if originPort == "" || originPort == "0" { + return nodeAddr + } + nodeHost, nodePort, err := net.SplitHostPort(nodeAddr) + if err != nil || nodePort != "0" { + return nodeAddr + } + return net.JoinHostPort(nodeHost, originPort) +} + +// isLoopback returns true if the host is a loopback address. +// For IP addresses, it uses net.IP.IsLoopback(). +// For hostnames, it recognizes well-known loopback hostnames like "localhost" +// and Docker-specific loopback patterns like "*.docker.internal". func isLoopback(host string) bool { ip := net.ParseIP(host) - if ip == nil { + if ip != nil { + return ip.IsLoopback() + } + + if strings.ToLower(host) == "localhost" { + return true + } + + if strings.HasSuffix(strings.ToLower(host), ".docker.internal") { return true } - return ip.IsLoopback() + + return false } func (c *clusterState) slotMasterNode(slot int) (*clusterNode, error) { @@ -754,7 +940,8 @@ func (c *clusterState) slotSlaveNode(slot int) (*clusterNode, error) { case 1: return nodes[0], nil case 2: - if slave := nodes[1]; !slave.Failing() { + slave := nodes[1] + if !slave.Failing() && !slave.Loading() { return slave, nil } return nodes[0], nil @@ -763,7 +950,7 @@ func (c *clusterState) slotSlaveNode(slot int) (*clusterNode, error) { for i := 0; i < 10; i++ { n := rand.Intn(len(nodes)-1) + 1 slave = nodes[n] - if !slave.Failing() { + if !slave.Failing() && !slave.Loading() { return slave, nil } } @@ -779,7 +966,7 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { return c.nodes.Random() } - var allNodesFailing = true + allNodesFailing := true var ( closestNonFailingNode *clusterNode closestNode *clusterNode @@ -833,6 +1020,29 @@ func (c *clusterState) slotRandomNode(slot int) (*clusterNode, error) { return nodes[randomNodes[0]], nil } +func (c *clusterState) slotShardPickerSlaveNode(slot int, shardPicker routing.ShardPicker) (*clusterNode, error) { + nodes := c.slotNodes(slot) + if len(nodes) == 0 { + return c.nodes.Random() + } + + // nodes[0] is master, nodes[1:] are slaves + // First, try all slave nodes for this slot using ShardPicker order + slaves := nodes[1:] + if len(slaves) > 0 { + for i := 0; i < len(slaves); i++ { + idx := shardPicker.Next(len(slaves)) + slave := slaves[idx] + if !slave.Failing() && !slave.Loading() { + return slave, nil + } + } + } + + // All slaves are failing or loading - return master + return nodes[0], nil +} + func (c *clusterState) slotNodes(slot int) []*clusterNode { i := sort.Search(len(c.slots), func(i int) bool { return c.slots[i].end >= slot @@ -852,13 +1062,16 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode { type clusterStateHolder struct { load func(ctx context.Context) (*clusterState, error) - state atomic.Value - reloading uint32 // atomic + reloadInterval time.Duration + state atomic.Value + reloading uint32 // atomic + reloadPending uint32 // atomic - set to 1 when reload is requested during active reload } -func newClusterStateHolder(fn func(ctx context.Context) (*clusterState, error)) *clusterStateHolder { +func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error), reloadInterval time.Duration) *clusterStateHolder { return &clusterStateHolder{ - load: fn, + load: load, + reloadInterval: reloadInterval, } } @@ -872,17 +1085,37 @@ func (c *clusterStateHolder) Reload(ctx context.Context) (*clusterState, error) } func (c *clusterStateHolder) LazyReload() { + // If already reloading, mark that another reload is pending if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) { + atomic.StoreUint32(&c.reloadPending, 1) return } + go func() { - defer atomic.StoreUint32(&c.reloading, 0) + for { + _, err := c.Reload(context.Background()) + if err != nil { + atomic.StoreUint32(&c.reloadPending, 0) + atomic.StoreUint32(&c.reloading, 0) + return + } - _, err := c.Reload(context.Background()) - if err != nil { - return + // Clear pending flag after reload completes, before cooldown + // This captures notifications that arrived during the reload + atomic.StoreUint32(&c.reloadPending, 0) + + // Wait cooldown period + time.Sleep(200 * time.Millisecond) + + // Check if another reload was requested during cooldown + if atomic.LoadUint32(&c.reloadPending) == 0 { + // No pending reload, we're done + atomic.StoreUint32(&c.reloading, 0) + return + } + + // Pending reload requested, loop to reload again } - time.Sleep(200 * time.Millisecond) }() } @@ -893,7 +1126,7 @@ func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) { } state := v.(*clusterState) - if time.Since(state.createdAt) > 10*time.Second { + if time.Since(state.createdAt) > c.reloadInterval { c.LazyReload() } return state, nil @@ -913,16 +1146,18 @@ func (c *clusterStateHolder) ReloadOrGet(ctx context.Context) (*clusterState, er // or more underlying connections. It's safe for concurrent use by // multiple goroutines. type ClusterClient struct { - opt *ClusterOptions - nodes *clusterNodes - state *clusterStateHolder - cmdsInfoCache *cmdsInfoCache + opt *ClusterOptions + nodes *clusterNodes + state *clusterStateHolder + cmdsInfoCache *cmdsInfoCache + cmdInfoResolver *commandInfoResolver cmdable hooksMixin } // NewClusterClient returns a Redis Cluster client as described in -// http://redis.io/topics/cluster-spec. +// https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec. +// Passing nil ClusterOptions will cause a panic. func NewClusterClient(opt *ClusterOptions) *ClusterClient { if opt == nil { panic("redis: NewClusterClient nil options") @@ -934,10 +1169,13 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { nodes: newClusterNodes(opt), } - c.state = newClusterStateHolder(c.loadState) c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) - c.cmdable = c.Process + c.state = newClusterStateHolder(c.loadState, opt.ClusterStateReloadInterval) + + c.SetCommandInfoResolver(NewDefaultCommandPolicyResolver()) + + c.cmdable = c.Process c.initHooks(hooks{ dial: nil, process: c.process, @@ -945,10 +1183,31 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { txPipeline: c.processTxPipeline, }) + // Set up SMIGRATED notification handling for cluster state reload + // When a node client receives a SMIGRATED notification, it should trigger + // cluster state reload on the parent ClusterClient + if opt.MaintNotificationsConfig != nil { + c.nodes.OnNewNode(func(nodeClient *Client) { + manager := nodeClient.GetMaintNotificationsManager() + if manager != nil { + manager.SetClusterStateReloadCallback(func(ctx context.Context, hostPort string, slotRanges []string) { + // Log the migration details for now + if internal.LogLevel.InfoOrAbove() { + internal.Logger.Printf(ctx, "cluster: slots %v migrated to %s, reloading cluster state", slotRanges, hostPort) + } + // Currently we reload the entire cluster state + // In the future, this could be optimized to reload only the specific slots + c.state.LazyReload() + }) + } + }) + } + return c } -// Options returns read-only Options that were used to create the client. +// Options returns read-only *ClusterOptions that were used to create the client. +// Any alteration of the returned *ClusterOptions may result in undefined behaviour. func (c *ClusterClient) Options() *ClusterOptions { return c.opt } @@ -967,13 +1226,6 @@ func (c *ClusterClient) Close() error { return c.nodes.Close() } -// Do create a Cmd from the args and processes the cmd. -func (c *ClusterClient) Do(ctx context.Context, args ...interface{}) *Cmd { - cmd := NewCmd(ctx, args...) - _ = c.Process(ctx, cmd) - return cmd -} - func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { err := c.processHook(ctx, cmd) cmd.SetErr(err) @@ -981,7 +1233,7 @@ func (c *ClusterClient) Process(ctx context.Context, cmd Cmder) error { } func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { - slot := c.cmdSlot(ctx, cmd) + slot := c.cmdSlot(cmd, -1) var node *clusterNode var moved bool var ask bool @@ -997,7 +1249,11 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { if node == nil { var err error - node, err = c.cmdNode(ctx, cmd.Name(), slot) + if !c.opt.DisableRoutingPolicies && c.opt.ShardPicker != nil { + node, err = c.cmdNodeWithShardPicker(ctx, cmd.Name(), slot, c.opt.ShardPicker) + } else { + node, err = c.cmdNode(ctx, cmd.Name(), slot) + } if err != nil { return err } @@ -1005,13 +1261,16 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { if ask { ask = false - pipe := node.Client.Pipeline() _ = pipe.Process(ctx, NewCmd(ctx, "asking")) _ = pipe.Process(ctx, cmd) _, lastErr = pipe.Exec(ctx) } else { - lastErr = node.Client.Process(ctx, cmd) + if !c.opt.DisableRoutingPolicies { + lastErr = c.routeAndRun(ctx, cmd, node) + } else { + lastErr = node.Client.Process(ctx, cmd) + } } // If there is no error - we are done. @@ -1038,6 +1297,18 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { if moved || ask { c.state.LazyReload() + // Record error metrics + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorType := "MOVED" + statusCode := "MOVED" + if ask { + errorType = "ASK" + statusCode = "ASK" + } + // MOVED/ASK are not internal errors, and this is the first attempt (retry count = 0) + errorCallback(ctx, errorType, nil, statusCode, false, 0) + } + var err error node, err = c.nodes.GetOrCreate(addr) if err != nil { @@ -1046,7 +1317,7 @@ func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error { continue } - if shouldRetry(lastErr, cmd.readTimeout() == nil) { + if shouldRetry(lastErr, cmd.readTimeout() == nil) && !cmd.NoRetry() { // First retry the same node. if attempt == 0 { continue @@ -1201,6 +1472,8 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.Hits += s.Hits acc.Misses += s.Misses acc.Timeouts += s.Timeouts + acc.WaitCount += s.WaitCount + acc.WaitDurationNs += s.WaitDurationNs acc.TotalConns += s.TotalConns acc.IdleConns += s.IdleConns @@ -1212,6 +1485,8 @@ func (c *ClusterClient) PoolStats() *PoolStats { acc.Hits += s.Hits acc.Misses += s.Misses acc.Timeouts += s.Timeouts + acc.WaitCount += s.WaitCount + acc.WaitDurationNs += s.WaitDurationNs acc.TotalConns += s.TotalConns acc.IdleConns += s.IdleConns @@ -1256,7 +1531,7 @@ func (c *ClusterClient) loadState(ctx context.Context) (*clusterState, error) { continue } - return newClusterState(c.nodes, slots, node.Client.opt.Addr) + return newClusterState(c.nodes, slots, addr) } /* @@ -1285,17 +1560,35 @@ func (c *ClusterClient) Pipelined(ctx context.Context, fn func(Pipeliner) error) } func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error { + // Only call time.Now() if pipeline operation duration callback is set to avoid overhead + var operationStart time.Time + pipelineOpDurationCallback := otel.GetPipelineOperationDurationCallback() + if pipelineOpDurationCallback != nil { + operationStart = time.Now() + } + totalAttempts := 0 + cmdsMap := newCmdsMap() if err := c.mapCmdsByNode(ctx, cmdsMap, cmds); err != nil { setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "PIPELINE", len(cmds), 1, err, nil, 0) + } return err } + var lastErr error for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + totalAttempts++ if attempt > 0 { if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "PIPELINE", len(cmds), totalAttempts, err, nil, 0) + } return err } } @@ -1316,6 +1609,17 @@ func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error break } cmdsMap = failedCmds + lastErr = cmdsFirstErr(cmds) + } + + // Record pipeline operation duration + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + finalErr := cmdsFirstErr(cmds) + if finalErr == nil { + finalErr = lastErr + } + pipelineOpDurationCallback(ctx, operationDuration, "PIPELINE", len(cmds), totalAttempts, finalErr, nil, 0) } return cmdsFirstErr(cmds) @@ -1329,10 +1633,31 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd if c.opt.ReadOnly && c.cmdsAreReadOnly(ctx, cmds) { for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) - node, err := c.slotReadOnlyNode(state, slot) - if err != nil { - return err + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) + } + slot := c.cmdSlot(cmd, -1) + var node *clusterNode + // For keyless commands (slot == -1), use ShardPicker if routing policies are enabled + if slot == -1 && !c.opt.DisableRoutingPolicies && c.opt.ShardPicker != nil { + if len(state.Masters) == 0 { + return errClusterNoNodes + } + // For read-only keyless commands, pick from all nodes (masters + slaves) + allNodes := append(state.Masters, state.Slaves...) + idx := c.opt.ShardPicker.Next(len(allNodes)) + node = allNodes[idx] + } else { + node, err = c.slotReadOnlyNode(state, slot) + if err != nil { + return err + } } cmdsMap.Add(node, cmd) } @@ -1340,10 +1665,29 @@ func (c *ClusterClient) mapCmdsByNode(ctx context.Context, cmdsMap *cmdsMap, cmd } for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) - node, err := state.slotMasterNode(slot) - if err != nil { - return err + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } + if policy != nil && !policy.CanBeUsedInPipeline() { + return fmt.Errorf( + "redis: cannot pipeline command %q with request policy ReqAllNodes/ReqAllShards/ReqMultiShard; Note: This behavior is subject to change in the future", cmd.Name(), + ) + } + slot := c.cmdSlot(cmd, -1) + var node *clusterNode + // For keyless commands (slot == -1), use ShardPicker if routing policies are enabled + if slot == -1 && !c.opt.DisableRoutingPolicies && c.opt.ShardPicker != nil { + if len(state.Masters) == 0 { + return errClusterNoNodes + } + idx := c.opt.ShardPicker.Next(len(state.Masters)) + node = state.Masters[idx] + } else { + node, err = state.slotMasterNode(slot) + if err != nil { + return err + } } cmdsMap.Add(node, cmd) } @@ -1393,7 +1737,7 @@ func (c *ClusterClient) processPipelineNodeConn( if isBadConn(err, false, node.Client.getAddr()) { node.MarkAsFailing() } - if shouldRetry(err, true) { + if shouldRetry(err, true) && !cmdsContainNoRetry(cmds) { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) } setCmdsErr(cmds, err) @@ -1429,7 +1773,7 @@ func (c *ClusterClient) pipelineReadCmds( } if !isRedisError(err) { - if shouldRetry(err, true) { + if shouldRetry(err, true) && !cmdsContainNoRetry(cmds) { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) } setCmdsErr(cmds[i+1:], err) @@ -1437,7 +1781,7 @@ func (c *ClusterClient) pipelineReadCmds( } } - if err := cmds[0].Err(); err != nil && shouldRetry(err, true) { + if err := cmds[0].Err(); err != nil && shouldRetry(err, true) && !cmdsContainNoRetry(cmds) { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) return err } @@ -1489,61 +1833,137 @@ func (c *ClusterClient) TxPipelined(ctx context.Context, fn func(Pipeliner) erro } func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { + // Only call time.Now() if pipeline operation duration callback is set to avoid overhead + var operationStart time.Time + pipelineOpDurationCallback := otel.GetPipelineOperationDurationCallback() + if pipelineOpDurationCallback != nil { + operationStart = time.Now() + } + totalAttempts := 0 + // Trim multi .. exec. cmds = cmds[1 : len(cmds)-1] + if len(cmds) == 0 { + return nil + } + state, err := c.state.Get(ctx) if err != nil { setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "MULTI", len(cmds), 1, err, nil, 0) + } return err } - cmdsMap := c.mapCmdsBySlot(ctx, cmds) - for slot, cmds := range cmdsMap { - node, err := state.slotMasterNode(slot) - if err != nil { - setCmdsErr(cmds, err) - continue + keyedCmdsBySlot := c.slottedKeyedCommands(ctx, cmds) + slot := -1 + switch len(keyedCmdsBySlot) { + case 0: + slot = hashtag.RandomSlot() + case 1: + for sl := range keyedCmdsBySlot { + slot = sl + break } + default: + // TxPipeline does not support cross slot transaction. + setCmdsErr(cmds, ErrCrossSlot) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "MULTI", len(cmds), 1, ErrCrossSlot, nil, 0) + } + return ErrCrossSlot + } - cmdsMap := map[*clusterNode][]Cmder{node: cmds} - for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { - if attempt > 0 { - if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { - setCmdsErr(cmds, err) - return err + node, err := state.slotMasterNode(slot) + if err != nil { + setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "MULTI", len(cmds), 1, err, nil, 0) + } + return err + } + + var lastErr error + cmdsMap := map[*clusterNode][]Cmder{node: cmds} + for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ { + totalAttempts++ + if attempt > 0 { + if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { + setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, "MULTI", len(cmds), totalAttempts, err, nil, 0) } + return err } + } - failedCmds := newCmdsMap() - var wg sync.WaitGroup + failedCmds := newCmdsMap() + var wg sync.WaitGroup - for node, cmds := range cmdsMap { - wg.Add(1) - go func(node *clusterNode, cmds []Cmder) { - defer wg.Done() - c.processTxPipelineNode(ctx, node, cmds, failedCmds) - }(node, cmds) - } + for node, cmds := range cmdsMap { + wg.Add(1) + go func(node *clusterNode, cmds []Cmder) { + defer wg.Done() + c.processTxPipelineNode(ctx, node, cmds, failedCmds) + }(node, cmds) + } - wg.Wait() - if len(failedCmds.m) == 0 { - break - } - cmdsMap = failedCmds.m + wg.Wait() + if len(failedCmds.m) == 0 { + break } + cmdsMap = failedCmds.m + lastErr = cmdsFirstErr(cmds) + } + + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + finalErr := cmdsFirstErr(cmds) + if finalErr == nil { + finalErr = lastErr + } + pipelineOpDurationCallback(ctx, operationDuration, "MULTI", len(cmds), totalAttempts, finalErr, nil, 0) } return cmdsFirstErr(cmds) } -func (c *ClusterClient) mapCmdsBySlot(ctx context.Context, cmds []Cmder) map[int][]Cmder { - cmdsMap := make(map[int][]Cmder) +// slottedKeyedCommands returns a map of slot to commands taking into account +// only commands that have keys. +func (c *ClusterClient) slottedKeyedCommands(ctx context.Context, cmds []Cmder) map[int][]Cmder { + cmdsSlots := map[int][]Cmder{} + + // Peek once outside the loop, one RLock for the whole batch instead of + // two per command (one for the keyless check, one inside cmdSlot). + cachedInfo := c.cmdsInfoCache.Peek() + + prefferedRandomSlot := -1 for _, cmd := range cmds { - slot := c.cmdSlot(ctx, cmd) - cmdsMap[slot] = append(cmdsMap[slot], cmd) + var info *CommandInfo + if cachedInfo != nil { + info = cachedInfo[cmd.Name()] + } + + pos := cmdFirstKeyPosWithInfo(cmd, info) + if pos == 0 { + continue + } + + slot := c.cmdSlotWithPos(cmd, pos, prefferedRandomSlot) + if prefferedRandomSlot == -1 { + prefferedRandomSlot = slot + } + + cmdsSlots[slot] = append(cmdsSlots[slot], cmd) } - return cmdsMap + + return cmdsSlots } func (c *ClusterClient) processTxPipelineNode( @@ -1574,7 +1994,7 @@ func (c *ClusterClient) processTxPipelineNodeConn( if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { - if shouldRetry(err, true) { + if shouldRetry(err, true) && !cmdsContainNoRetry(cmds) { _ = c.mapCmdsByNode(ctx, failedCmds, cmds) } setCmdsErr(cmds, err) @@ -1587,7 +2007,7 @@ func (c *ClusterClient) processTxPipelineNodeConn( trimmedCmds := cmds[1 : len(cmds)-1] if err := c.txPipelineReadQueued( - ctx, rd, statusCmd, trimmedCmds, failedCmds, + ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds, ); err != nil { setCmdsErr(cmds, err) @@ -1599,30 +2019,56 @@ func (c *ClusterClient) processTxPipelineNodeConn( return err } - return pipelineReadCmds(rd, trimmedCmds) + return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds) }) } func (c *ClusterClient) txPipelineReadQueued( ctx context.Context, + node *clusterNode, + cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder, failedCmds *cmdsMap, ) error { // Parse queued replies. + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + // Log the error but don't fail the command execution + // Push notification processing errors shouldn't break normal Redis operations + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } if err := statusCmd.readReply(rd); err != nil { return err } for _, cmd := range cmds { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + // Log the error but don't fail the command execution + // Push notification processing errors shouldn't break normal Redis operations + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } err := statusCmd.readReply(rd) - if err == nil || c.checkMovedErr(ctx, cmd, err, failedCmds) || isRedisError(err) { - continue + if err != nil { + if c.checkMovedErr(ctx, cmd, err, failedCmds) { + // will be processed later + continue + } + cmd.SetErr(err) + if !isRedisError(err) { + return err + } } - return err } + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + // Log the error but don't fail the command execution + // Push notification processing errors shouldn't break normal Redis operations + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } // Parse number of replies. line, err := rd.ReadLine() if err != nil { @@ -1670,14 +2116,13 @@ func (c *ClusterClient) cmdsMoved( func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { if len(keys) == 0 { - return fmt.Errorf("redis: Watch requires at least one key") + return errNoWatchKeys } slot := hashtag.Slot(keys[0]) for _, key := range keys[1:] { if hashtag.Slot(key) != slot { - err := fmt.Errorf("redis: Watch requires all keys to be in the same slot") - return err + return errWatchCrosslot } } @@ -1693,10 +2138,18 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s } } - err = node.Client.Watch(ctx, fn, keys...) + // Track callback errors separately to avoid retrying user failures through cluster retry classification. + var fnErr error + err = node.Client.Watch(ctx, func(tx *Tx) error { + fnErr = fn(tx) + return fnErr + }, keys...) if err == nil { break } + if fnErr != nil { + return fnErr + } moved, ask, addr := isMovedError(err) if moved || ask { @@ -1728,38 +2181,64 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s return err } +// maintenance notifications won't work here for now func (c *ClusterClient) pubSub() *PubSub { var node *clusterNode pubsub := &PubSub{ opt: c.opt.clientOptions(), - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { if node != nil { panic("node != nil") } var err error + if len(channels) > 0 { slot := hashtag.Slot(channels[0]) - node, err = c.slotMasterNode(ctx, slot) + + // newConn in PubSub is only used for subscription connections, so it is safe to + // assume that a slave node can always be used when client options specify ReadOnly. + if c.opt.ReadOnly { + state, err := c.state.Get(ctx) + if err != nil { + return nil, err + } + + node, err = c.slotReadOnlyNode(state, slot) + if err != nil { + return nil, err + } + } else { + node, err = c.slotMasterNode(ctx, slot) + if err != nil { + return nil, err + } + } } else { node, err = c.nodes.Random() + if err != nil { + return nil, err + } } + cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels) if err != nil { + node = nil return nil, err } - - cn, err := node.Client.newConn(context.TODO()) + // will return nil if already initialized + err = node.Client.initConn(ctx, cn) if err != nil { + _ = cn.Close() node = nil - return nil, err } - + node.Client.pubSubPool.TrackConn(cn) return cn, nil }, closeConn: func(cn *pool.Conn) error { - err := node.Client.connPool.CloseConn(cn) + // Untrack connection from PubSubPool + node.Client.pubSubPool.UntrackConn(cn) + err := cn.Close() node = nil return err }, @@ -1820,7 +2299,6 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, for _, idx := range perm { addr := addrs[idx] - node, err := c.nodes.GetOrCreate(addr) if err != nil { if firstErr == nil { @@ -1833,6 +2311,7 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, if err == nil { return info, nil } + if firstErr == nil { firstErr = err } @@ -1844,32 +2323,64 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, return nil, firstErr } +// cmdInfo will fetch and cache the command policies after the first execution func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { - cmdsInfo, err := c.cmdsInfoCache.Get(ctx) + // Use a separate context that won't be canceled to ensure command info lookup + // doesn't fail due to original context cancellation + cmdInfoCtx := c.context(ctx) + if c.opt.ContextTimeoutEnabled && ctx != nil { + // If context timeout is enabled, still use a reasonable timeout + var cancel context.CancelFunc + cmdInfoCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + + cmdsInfo, err := c.cmdsInfoCache.Get(cmdInfoCtx) if err != nil { - internal.Logger.Printf(context.TODO(), "getting command info: %s", err) + internal.Logger.Printf(cmdInfoCtx, "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(context.TODO(), "info for cmd=%s not found", name) + internal.Logger.Printf(cmdInfoCtx, "info for cmd=%s not found", name) } + return info } -func (c *ClusterClient) cmdSlot(ctx context.Context, cmd Cmder) int { +// cmdInfoPeek returns the cached CommandInfo for the named command without +// triggering a round-trip to Redis. It returns nil when the cache is cold. +func (c *ClusterClient) cmdInfoPeek(name string) *CommandInfo { + if cmds := c.cmdsInfoCache.Peek(); cmds != nil { + return cmds[name] + } + return nil +} + +func (c *ClusterClient) cmdSlot(cmd Cmder, prefferedSlot int) int { + info := c.cmdInfoPeek(cmd.Name()) + return c.cmdSlotWithPos(cmd, cmdFirstKeyPosWithInfo(cmd, info), prefferedSlot) +} + +// cmdSlotWithPos computes the cluster slot for cmd given a pre-resolved first key +// position. Separating pos resolution from slot computation lets callers that +// already know pos avoid a redundant Peek() call. +func (c *ClusterClient) cmdSlotWithPos(cmd Cmder, pos int, prefferedSlot int) int { args := cmd.Args() if args[0] == "cluster" && (args[1] == "getkeysinslot" || args[1] == "countkeysinslot") { return args[2].(int) } - - return cmdSlot(cmd, cmdFirstKeyPos(cmd)) + return cmdSlot(cmd, pos, prefferedSlot) } -func cmdSlot(cmd Cmder, pos int) int { +func cmdSlot(cmd Cmder, pos int, prefferedRandomSlot int) int { if pos == 0 { - return hashtag.RandomSlot() + if prefferedRandomSlot != -1 { + return prefferedRandomSlot + } + // Return -1 for keyless commands to signal that ShardPicker should be used + return -1 } firstKey := cmd.stringArg(pos) return hashtag.Slot(firstKey) @@ -1894,6 +2405,36 @@ func (c *ClusterClient) cmdNode( return state.slotMasterNode(slot) } +func (c *ClusterClient) cmdNodeWithShardPicker( + ctx context.Context, + cmdName string, + slot int, + shardPicker routing.ShardPicker, +) (*clusterNode, error) { + state, err := c.state.Get(ctx) + if err != nil { + return nil, err + } + + // For keyless commands (slot == -1), use ShardPicker to select a shard + // This respects the user's configured ShardPicker policy + if slot == -1 { + if len(state.Masters) == 0 { + return nil, errClusterNoNodes + } + idx := shardPicker.Next(len(state.Masters)) + return state.Masters[idx], nil + } + + if c.opt.ReadOnly { + cmdInfo := c.cmdInfo(ctx, cmdName) + if cmdInfo != nil && cmdInfo.ReadOnly { + return c.slotReadOnlyNode(state, slot) + } + } + return state.slotMasterNode(slot) +} + func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*clusterNode, error) { if c.opt.RouteByLatency { return state.slotClosestNode(slot) @@ -1901,6 +2442,11 @@ func (c *ClusterClient) slotReadOnlyNode(state *clusterState, slot int) (*cluste if c.opt.RouteRandomly { return state.slotRandomNode(slot) } + + if c.opt.ShardPicker != nil { + return state.slotShardPickerSlaveNode(slot, c.opt.ShardPicker) + } + return state.slotSlaveNode(slot) } @@ -1938,7 +2484,7 @@ func (c *ClusterClient) MasterForKey(ctx context.Context, key string) (*Client, if err != nil { return nil, err } - return node.Client, err + return node.Client, nil } func (c *ClusterClient) context(ctx context.Context) context.Context { @@ -1948,26 +2494,38 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } -func appendUniqueNode(nodes []*clusterNode, node *clusterNode) []*clusterNode { - for _, n := range nodes { - if n == node { - return nodes - } +func (c *ClusterClient) GetResolver() *commandInfoResolver { + return c.cmdInfoResolver +} + +func (c *ClusterClient) SetCommandInfoResolver(cmdInfoResolver *commandInfoResolver) { + c.cmdInfoResolver = cmdInfoResolver +} + +// extractCommandInfo retrieves the routing policy for a command +func (c *ClusterClient) extractCommandInfo(ctx context.Context, cmd Cmder) *routing.CommandPolicy { + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.CommandPolicy != nil { + return cmdInfo.CommandPolicy + } + + return nil +} + +// NewDynamicResolver returns a CommandInfoResolver +// that uses the underlying cmdInfo cache to resolve the policies +func (c *ClusterClient) NewDynamicResolver() *commandInfoResolver { + return &commandInfoResolver{ + resolveFunc: c.extractCommandInfo, } - return append(nodes, node) } -func appendIfNotExists(ss []string, es ...string) []string { -loop: - for _, e := range es { - for _, s := range ss { - if s == e { - continue loop - } +func appendIfNotExist[T comparable](vals []T, newVal T) []T { + for _, v := range vals { + if v == newVal { + return vals } - ss = append(ss, e) } - return ss + return append(vals, newVal) } //------------------------------------------------------------------------------ diff --git a/vendor/github.com/redis/go-redis/v9/osscluster_router.go b/vendor/github.com/redis/go-redis/v9/osscluster_router.go new file mode 100644 index 000000000..0da29530a --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/osscluster_router.go @@ -0,0 +1,1002 @@ +package redis + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal/hashtag" + "github.com/redis/go-redis/v9/internal/routing" +) + +var ( + errInvalidCmdPointer = errors.New("redis: invalid command pointer") + errNoCmdsToAggregate = errors.New("redis: no commands to aggregate") + errNoResToAggregate = errors.New("redis: no results to aggregate") + errInvalidCursorCmdArgsCount = errors.New("redis: FT.CURSOR command requires at least 3 arguments") + errInvalidCursorIdType = errors.New("redis: invalid cursor ID type") +) + +// slotResult represents the result of executing a command on a specific slot +type slotResult struct { + cmd Cmder + keys []string + err error +} + +// routeAndRun routes a command to the appropriate cluster nodes and executes it +func (c *ClusterClient) routeAndRun(ctx context.Context, cmd Cmder, node *clusterNode) error { + var policy *routing.CommandPolicy + if c.cmdInfoResolver != nil { + policy = c.cmdInfoResolver.GetCommandPolicy(ctx, cmd) + } + + // Set stepCount from cmdInfo if not already set + if cmd.stepCount() == 0 { + if cmdInfo := c.cmdInfo(ctx, cmd.Name()); cmdInfo != nil && cmdInfo.StepCount > 0 { + cmd.SetStepCount(cmdInfo.StepCount) + } + } + + if policy == nil { + return c.executeDefault(ctx, cmd, policy, node) + } + switch policy.Request { + case routing.ReqAllNodes: + return c.executeOnAllNodes(ctx, cmd, policy) + case routing.ReqAllShards: + return c.executeOnAllShards(ctx, cmd, policy) + case routing.ReqMultiShard: + return c.executeMultiShard(ctx, cmd, policy) + case routing.ReqSpecial: + return c.executeSpecialCommand(ctx, cmd, policy, node) + default: + return c.executeDefault(ctx, cmd, policy, node) + } +} + +// executeDefault handles standard command routing based on keys +func (c *ClusterClient) executeDefault(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { + if policy != nil && !c.hasKeys(cmd) { + if c.readOnlyEnabled() && policy.IsReadOnly() { + return c.executeOnArbitraryNode(ctx, cmd) + } + } + + return node.Client.Process(ctx, cmd) +} + +// executeOnArbitraryNode routes command to an arbitrary node +func (c *ClusterClient) executeOnArbitraryNode(ctx context.Context, cmd Cmder) error { + node := c.pickArbitraryNode(ctx) + if node == nil { + return errClusterNoNodes + } + return node.Client.Process(ctx, cmd) +} + +// executeOnAllNodes executes command on all nodes (masters and replicas) +func (c *ClusterClient) executeOnAllNodes(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + nodes := append(state.Masters, state.Slaves...) + if len(nodes) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, nodes, policy) +} + +// executeOnAllShards executes command on all master shards +func (c *ClusterClient) executeOnAllShards(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + state, err := c.state.Get(ctx) + if err != nil { + return err + } + + if len(state.Masters) == 0 { + return errClusterNoNodes + } + + return c.executeParallel(ctx, cmd, state.Masters, policy) +} + +// executeMultiShard handles commands that operate on multiple keys across shards +func (c *ClusterClient) executeMultiShard(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy) error { + args := cmd.Args() + firstKeyPos := cmdFirstKeyPosWithInfo(cmd, c.cmdInfoPeek(cmd.Name())) + stepCount := int(cmd.stepCount()) + if stepCount == 0 { + stepCount = 1 // Default to 1 if not set + } + + if firstKeyPos == 0 || firstKeyPos >= len(args) { + return fmt.Errorf("redis: multi-shard command %s has no key arguments", cmd.Name()) + } + + // Group keys by slot + slotMap := make(map[int][]string) + keyOrder := make([]string, 0) + + for i := firstKeyPos; i < len(args); i += stepCount { + key, ok := args[i].(string) + if !ok { + return fmt.Errorf("redis: non-string key at position %d: %v", i, args[i]) + } + + slot := hashtag.Slot(key) + slotMap[slot] = append(slotMap[slot], key) + for j := 1; j < stepCount; j++ { + if i+j >= len(args) { + break + } + slotMap[slot] = append(slotMap[slot], args[i+j].(string)) + } + keyOrder = append(keyOrder, key) + } + + return c.executeMultiSlot(ctx, cmd, slotMap, keyOrder, policy, firstKeyPos) +} + +// executeMultiSlot executes commands across multiple slots concurrently +func (c *ClusterClient) executeMultiSlot(ctx context.Context, cmd Cmder, slotMap map[int][]string, keyOrder []string, policy *routing.CommandPolicy, firstKeyPos int) error { + results := make(chan slotResult, len(slotMap)) + var wg sync.WaitGroup + + // Execute on each slot concurrently + for slot, keys := range slotMap { + wg.Add(1) + go func(slot int, keys []string) { + defer wg.Done() + + node, err := c.cmdNodeWithShardPicker(ctx, cmd.Name(), slot, c.opt.ShardPicker) + if err != nil { + results <- slotResult{nil, keys, err} + return + } + + // Create a command for this specific slot's keys + subCmd := c.createSlotSpecificCommand(ctx, cmd, keys, firstKeyPos) + err = node.Client.Process(ctx, subCmd) + results <- slotResult{subCmd, keys, err} + }(slot, keys) + } + + go func() { + wg.Wait() + close(results) + }() + + return c.aggregateMultiSlotResults(ctx, cmd, results, keyOrder, policy) +} + +// createSlotSpecificCommand creates a new command for a specific slot's keys. +// firstKeyPos is passed in from the caller (computed once in executeMultiShard) +// so this function never independently re-peeks the cache — avoids the +// cold --> warm inconsistency the reviewer flagged. +func (c *ClusterClient) createSlotSpecificCommand(ctx context.Context, originalCmd Cmder, keys []string, firstKeyPos int) Cmder { + originalArgs := originalCmd.Args() + + // Build new args with only the specified keys + newArgs := make([]interface{}, 0, firstKeyPos+len(keys)) + + // Copy command name and arguments before the keys + newArgs = append(newArgs, originalArgs[:firstKeyPos]...) + + // Add the slot-specific keys + for _, key := range keys { + newArgs = append(newArgs, key) + } + + // Create a new command of the same type using the helper function + return createCommandByType(ctx, originalCmd.GetCmdType(), newArgs...) +} + +// createCommandByType creates a new command of the specified type with the given arguments +func createCommandByType(ctx context.Context, cmdType CmdType, args ...interface{}) Cmder { + switch cmdType { + case CmdTypeString: + return NewStringCmd(ctx, args...) + case CmdTypeInt: + return NewIntCmd(ctx, args...) + case CmdTypeBool: + return NewBoolCmd(ctx, args...) + case CmdTypeFloat: + return NewFloatCmd(ctx, args...) + case CmdTypeStringSlice: + return NewStringSliceCmd(ctx, args...) + case CmdTypeIntSlice: + return NewIntSliceCmd(ctx, args...) + case CmdTypeFloatSlice: + return NewFloatSliceCmd(ctx, args...) + case CmdTypeBoolSlice: + return NewBoolSliceCmd(ctx, args...) + case CmdTypeStatus: + return NewStatusCmd(ctx, args...) + case CmdTypeTime: + return NewTimeCmd(ctx, args...) + case CmdTypeMapStringString: + return NewMapStringStringCmd(ctx, args...) + case CmdTypeMapStringInt: + return NewMapStringIntCmd(ctx, args...) + case CmdTypeMapStringInterface: + return NewMapStringInterfaceCmd(ctx, args...) + case CmdTypeMapStringInterfaceSlice: + return NewMapStringInterfaceSliceCmd(ctx, args...) + case CmdTypeSlice: + return NewSliceCmd(ctx, args...) + case CmdTypeStringStructMap: + return NewStringStructMapCmd(ctx, args...) + case CmdTypeXMessageSlice: + return NewXMessageSliceCmd(ctx, args...) + case CmdTypeXStreamSlice: + return NewXStreamSliceCmd(ctx, args...) + case CmdTypeXPending: + return NewXPendingCmd(ctx, args...) + case CmdTypeXPendingExt: + return NewXPendingExtCmd(ctx, args...) + case CmdTypeXAutoClaim: + return NewXAutoClaimCmd(ctx, args...) + case CmdTypeXAutoClaimWithDeleted: + return NewXAutoClaimWithDeletedCmd(ctx, args...) + case CmdTypeXAutoClaimJustID: + return NewXAutoClaimJustIDCmd(ctx, args...) + case CmdTypeXInfoStreamFull: + return NewXInfoStreamFullCmd(ctx, args...) + case CmdTypeZSlice: + return NewZSliceCmd(ctx, args...) + case CmdTypeZWithKey: + return NewZWithKeyCmd(ctx, args...) + case CmdTypeClusterSlots: + return NewClusterSlotsCmd(ctx, args...) + case CmdTypeGeoPos: + return NewGeoPosCmd(ctx, args...) + case CmdTypeCommandsInfo: + return NewCommandsInfoCmd(ctx, args...) + case CmdTypeSlowLog: + return NewSlowLogCmd(ctx, args...) + case CmdTypeKeyValues: + return NewKeyValuesCmd(ctx, args...) + case CmdTypeZSliceWithKey: + return NewZSliceWithKeyCmd(ctx, args...) + case CmdTypeFunctionList: + return NewFunctionListCmd(ctx, args...) + case CmdTypeFunctionStats: + return NewFunctionStatsCmd(ctx, args...) + case CmdTypeKeyFlags: + return NewKeyFlagsCmd(ctx, args...) + case CmdTypeDuration: + return NewDurationCmd(ctx, time.Millisecond, args...) + } + return NewCmd(ctx, args...) +} + +// executeSpecialCommand handles commands with special routing requirements +func (c *ClusterClient) executeSpecialCommand(ctx context.Context, cmd Cmder, policy *routing.CommandPolicy, node *clusterNode) error { + switch cmd.Name() { + case "ft.cursor": + return c.executeCursorCommand(ctx, cmd) + default: + return c.executeDefault(ctx, cmd, policy, node) + } +} + +// executeCursorCommand handles FT.CURSOR commands with sticky routing +func (c *ClusterClient) executeCursorCommand(ctx context.Context, cmd Cmder) error { + args := cmd.Args() + if len(args) < 4 { + return errInvalidCursorCmdArgsCount + } + + cursorID, ok := args[3].(string) + if !ok { + return errInvalidCursorIdType + } + + // Route based on cursor ID to maintain stickiness + slot := hashtag.Slot(cursorID) + node, err := c.cmdNodeWithShardPicker(ctx, cmd.Name(), slot, c.opt.ShardPicker) + if err != nil { + return err + } + + return node.Client.Process(ctx, cmd) +} + +// executeParallel executes a command on multiple nodes concurrently +func (c *ClusterClient) executeParallel(ctx context.Context, cmd Cmder, nodes []*clusterNode, policy *routing.CommandPolicy) error { + if len(nodes) == 0 { + return errClusterNoNodes + } + + if len(nodes) == 1 { + return nodes[0].Client.Process(ctx, cmd) + } + + type nodeResult struct { + cmd Cmder + err error + } + + results := make(chan nodeResult, len(nodes)) + var wg sync.WaitGroup + + for _, node := range nodes { + wg.Add(1) + go func(n *clusterNode) { + defer wg.Done() + cmdCopy := cmd.Clone() + err := n.Client.Process(ctx, cmdCopy) + results <- nodeResult{cmdCopy, err} + }(node) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results and check for errors + cmds := make([]Cmder, 0, len(nodes)) + var firstErr error + + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } + cmds = append(cmds, result.cmd) + } + + // If there was an error and no policy specified, fail fast + if firstErr != nil && (policy == nil || policy.Response == routing.RespDefaultKeyless) { + cmd.SetErr(firstErr) + return firstErr + } + + return c.aggregateResponses(cmd, cmds, policy) +} + +// aggregateMultiSlotResults aggregates results from multi-slot execution +func (c *ClusterClient) aggregateMultiSlotResults(ctx context.Context, cmd Cmder, results <-chan slotResult, keyOrder []string, policy *routing.CommandPolicy) error { + keyedResults := make(map[string]routing.AggregatorResErr) + var firstErr error + + for result := range results { + if result.err != nil && firstErr == nil { + firstErr = result.err + } + if result.cmd != nil && result.err == nil { + value, err := ExtractCommandValue(result.cmd) + + // Check if the result is a slice (e.g., from MGET) + if sliceValue, ok := value.([]interface{}); ok { + // Map each element to its corresponding key + for i, key := range result.keys { + if i < len(sliceValue) { + keyedResults[key] = routing.AggregatorResErr{Result: sliceValue[i], Err: err} + } else { + keyedResults[key] = routing.AggregatorResErr{Result: nil, Err: err} + } + } + } else { + // For non-slice results, map the entire result to each key + for _, key := range result.keys { + keyedResults[key] = routing.AggregatorResErr{Result: value, Err: err} + } + } + } + + // TODO: return multiple errors by order when we will implement multiple errors returning + if result.err != nil { + firstErr = result.err + } + } + + return c.aggregateKeyedValues(cmd, keyedResults, keyOrder, policy) +} + +// aggregateKeyedValues aggregates individual key-value pairs while preserving key order +func (c *ClusterClient) aggregateKeyedValues(cmd Cmder, keyedResults map[string]routing.AggregatorResErr, keyOrder []string, policy *routing.CommandPolicy) error { + if len(keyedResults) == 0 { + return errNoResToAggregate + } + + aggregator := c.createAggregator(policy, cmd, true) + + // Set key order for keyed aggregators + var keyedAgg *routing.DefaultKeyedAggregator + var isKeyedAgg bool + var err error + if keyedAgg, isKeyedAgg = aggregator.(*routing.DefaultKeyedAggregator); isKeyedAgg { + err = keyedAgg.BatchAddWithKeyOrder(keyedResults, keyOrder) + } else { + err = aggregator.BatchAdd(keyedResults) + } + + if err != nil { + return err + } + + return c.finishAggregation(cmd, aggregator) +} + +// aggregateResponses aggregates multiple shard responses +func (c *ClusterClient) aggregateResponses(cmd Cmder, cmds []Cmder, policy *routing.CommandPolicy) error { + if len(cmds) == 0 { + return errNoCmdsToAggregate + } + + if len(cmds) == 1 { + shardCmd := cmds[0] + if err := shardCmd.Err(); err != nil { + cmd.SetErr(err) + return err + } + value, _ := ExtractCommandValue(shardCmd) + return c.setCommandValue(cmd, value) + } + + aggregator := c.createAggregator(policy, cmd, false) + + batchWithErrs := []routing.AggregatorResErr{} + // Add all results to aggregator + for _, shardCmd := range cmds { + value, err := ExtractCommandValue(shardCmd) + batchWithErrs = append(batchWithErrs, routing.AggregatorResErr{ + Result: value, + Err: err, + }) + } + + err := aggregator.BatchSlice(batchWithErrs) + if err != nil { + return err + } + + return c.finishAggregation(cmd, aggregator) +} + +// createAggregator creates the appropriate response aggregator +func (c *ClusterClient) createAggregator(policy *routing.CommandPolicy, cmd Cmder, isKeyed bool) routing.ResponseAggregator { + if policy != nil { + return routing.NewResponseAggregator(policy.Response, cmd.Name()) + } + + if !isKeyed { + firstKeyPos := cmdFirstKeyPosWithInfo(cmd, c.cmdInfoPeek(cmd.Name())) + isKeyed = firstKeyPos > 0 + } + + return routing.NewDefaultAggregator(isKeyed) +} + +// finishAggregation completes the aggregation process and sets the result +func (c *ClusterClient) finishAggregation(cmd Cmder, aggregator routing.ResponseAggregator) error { + finalValue, finalErr := aggregator.Result() + if finalErr != nil { + cmd.SetErr(finalErr) + return finalErr + } + + return c.setCommandValue(cmd, finalValue) +} + +// pickArbitraryNode selects a master or slave shard using the configured ShardPicker +func (c *ClusterClient) pickArbitraryNode(ctx context.Context) *clusterNode { + state, err := c.state.Get(ctx) + if err != nil || len(state.Masters) == 0 { + return nil + } + + allNodes := append(state.Masters, state.Slaves...) + + idx := c.opt.ShardPicker.Next(len(allNodes)) + return allNodes[idx] +} + +// hasKeys checks if a command operates on keys +func (c *ClusterClient) hasKeys(cmd Cmder) bool { + firstKeyPos := cmdFirstKeyPosWithInfo(cmd, c.cmdInfoPeek(cmd.Name())) + return firstKeyPos > 0 +} + +func (c *ClusterClient) readOnlyEnabled() bool { + return c.opt.ReadOnly +} + +// setCommandValue sets the aggregated value on a command using the enum-based approach +func (c *ClusterClient) setCommandValue(cmd Cmder, value interface{}) error { + // If value is nil, it might mean ExtractCommandValue couldn't extract the value + // but the command might have executed successfully. In this case, don't set an error. + if value == nil { + // ExtractCommandValue returned nil - this means the command type is not supported + // in the aggregation flow. This is a programming error, not a runtime error. + if cmd.Err() != nil { + // Command already has an error, preserve it + return cmd.Err() + } + // Command executed successfully but we can't extract/set the aggregated value + // This indicates the command type needs to be added to ExtractCommandValue + return fmt.Errorf("redis: cannot aggregate command %s: unsupported command type %d", + cmd.Name(), cmd.GetCmdType()) + } + + switch cmd.GetCmdType() { + case CmdTypeGeneric: + if c, ok := cmd.(*Cmd); ok { + c.SetVal(value) + } + case CmdTypeString: + if c, ok := cmd.(*StringCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeInt: + if c, ok := cmd.(*IntCmd); ok { + if v, ok := value.(int64); ok { + c.SetVal(v) + } else if v, ok := value.(float64); ok { + c.SetVal(int64(v)) + } + } + case CmdTypeBool: + if c, ok := cmd.(*BoolCmd); ok { + if v, ok := value.(bool); ok { + c.SetVal(v) + } + } + case CmdTypeFloat: + if c, ok := cmd.(*FloatCmd); ok { + if v, ok := value.(float64); ok { + c.SetVal(v) + } + } + case CmdTypeStringSlice: + if c, ok := cmd.(*StringSliceCmd); ok { + if v, ok := value.([]string); ok { + c.SetVal(v) + } + } + case CmdTypeIntSlice: + if c, ok := cmd.(*IntSliceCmd); ok { + if v, ok := value.([]int64); ok { + c.SetVal(v) + } else if v, ok := value.([]float64); ok { + els := len(v) + intSlc := make([]int, els) + for i := range v { + intSlc[i] = int(v[i]) + } + } + } + case CmdTypeFloatSlice: + if c, ok := cmd.(*FloatSliceCmd); ok { + if v, ok := value.([]float64); ok { + c.SetVal(v) + } + } + case CmdTypeBoolSlice: + if c, ok := cmd.(*BoolSliceCmd); ok { + if v, ok := value.([]bool); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringString: + if c, ok := cmd.(*MapStringStringCmd); ok { + if v, ok := value.(map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInt: + if c, ok := cmd.(*MapStringIntCmd); ok { + if v, ok := value.(map[string]int64); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterface: + if c, ok := cmd.(*MapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeSlice: + if c, ok := cmd.(*SliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeStatus: + if c, ok := cmd.(*StatusCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeDuration: + if c, ok := cmd.(*DurationCmd); ok { + if v, ok := value.(time.Duration); ok { + c.SetVal(v) + } + } + case CmdTypeTime: + if c, ok := cmd.(*TimeCmd); ok { + if v, ok := value.(time.Time); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValueSlice: + if c, ok := cmd.(*KeyValueSliceCmd); ok { + if v, ok := value.([]KeyValue); ok { + c.SetVal(v) + } + } + case CmdTypeStringStructMap: + if c, ok := cmd.(*StringStructMapCmd); ok { + if v, ok := value.(map[string]struct{}); ok { + c.SetVal(v) + } + } + case CmdTypeXMessageSlice: + if c, ok := cmd.(*XMessageSliceCmd); ok { + if v, ok := value.([]XMessage); ok { + c.SetVal(v) + } + } + case CmdTypeXStreamSlice: + if c, ok := cmd.(*XStreamSliceCmd); ok { + if v, ok := value.([]XStream); ok { + c.SetVal(v) + } + } + case CmdTypeXPending: + if c, ok := cmd.(*XPendingCmd); ok { + if v, ok := value.(*XPending); ok { + c.SetVal(v) + } + } + case CmdTypeXPendingExt: + if c, ok := cmd.(*XPendingExtCmd); ok { + if v, ok := value.([]XPendingExt); ok { + c.SetVal(v) + } + } + case CmdTypeXAutoClaim: + if c, ok := cmd.(*XAutoClaimCmd); ok { + if v, ok := value.(CmdTypeXAutoClaimValue); ok { + c.SetVal(v.messages, v.start) + } + } + case CmdTypeXAutoClaimWithDeleted: + if c, ok := cmd.(*XAutoClaimWithDeletedCmd); ok { + if v, ok := value.(CmdTypeXAutoClaimWithDeletedValue); ok { + c.SetVal(v.messages, v.start, v.deletedIDs) + } + } + case CmdTypeXAutoClaimJustID: + if c, ok := cmd.(*XAutoClaimJustIDCmd); ok { + if v, ok := value.(CmdTypeXAutoClaimJustIDValue); ok { + c.SetVal(v.ids, v.start) + } + } + case CmdTypeXInfoConsumers: + if c, ok := cmd.(*XInfoConsumersCmd); ok { + if v, ok := value.([]XInfoConsumer); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoGroups: + if c, ok := cmd.(*XInfoGroupsCmd); ok { + if v, ok := value.([]XInfoGroup); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStream: + if c, ok := cmd.(*XInfoStreamCmd); ok { + if v, ok := value.(*XInfoStream); ok { + c.SetVal(v) + } + } + case CmdTypeXInfoStreamFull: + if c, ok := cmd.(*XInfoStreamFullCmd); ok { + if v, ok := value.(*XInfoStreamFull); ok { + c.SetVal(v) + } + } + case CmdTypeZSlice: + if c, ok := cmd.(*ZSliceCmd); ok { + if v, ok := value.([]Z); ok { + c.SetVal(v) + } + } + case CmdTypeZWithKey: + if c, ok := cmd.(*ZWithKeyCmd); ok { + if v, ok := value.(*ZWithKey); ok { + c.SetVal(v) + } + } + case CmdTypeScan: + if c, ok := cmd.(*ScanCmd); ok { + if v, ok := value.(CmdTypeScanValue); ok { + c.SetVal(v.keys, v.cursor) + } + } + case CmdTypeClusterSlots: + if c, ok := cmd.(*ClusterSlotsCmd); ok { + if v, ok := value.([]ClusterSlot); ok { + c.SetVal(v) + } + } + case CmdTypeGeoLocation: + if c, ok := cmd.(*GeoLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoSearchLocation: + if c, ok := cmd.(*GeoSearchLocationCmd); ok { + if v, ok := value.([]GeoLocation); ok { + c.SetVal(v) + } + } + case CmdTypeGeoPos: + if c, ok := cmd.(*GeoPosCmd); ok { + if v, ok := value.([]*GeoPos); ok { + c.SetVal(v) + } + } + case CmdTypeCommandsInfo: + if c, ok := cmd.(*CommandsInfoCmd); ok { + if v, ok := value.(map[string]*CommandInfo); ok { + c.SetVal(v) + } + } + case CmdTypeSlowLog: + if c, ok := cmd.(*SlowLogCmd); ok { + if v, ok := value.([]SlowLog); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringStringSlice: + if c, ok := cmd.(*MapStringStringSliceCmd); ok { + if v, ok := value.([]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMapMapStringInterface: + if c, ok := cmd.(*MapMapStringInterfaceCmd); ok { + if v, ok := value.(map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeMapStringInterfaceSlice: + if c, ok := cmd.(*MapStringInterfaceSliceCmd); ok { + if v, ok := value.([]map[string]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeKeyValues: + if c, ok := cmd.(*KeyValuesCmd); ok { + // KeyValuesCmd needs a key string and values slice + if v, ok := value.(CmdTypeKeyValuesValue); ok { + c.SetVal(v.key, v.values) + } + } + case CmdTypeZSliceWithKey: + if c, ok := cmd.(*ZSliceWithKeyCmd); ok { + // ZSliceWithKeyCmd needs a key string and Z slice + if v, ok := value.(CmdTypeZSliceWithKeyValue); ok { + c.SetVal(v.key, v.zSlice) + } + } + case CmdTypeFunctionList: + if c, ok := cmd.(*FunctionListCmd); ok { + if v, ok := value.([]Library); ok { + c.SetVal(v) + } + } + case CmdTypeFunctionStats: + if c, ok := cmd.(*FunctionStatsCmd); ok { + if v, ok := value.(FunctionStats); ok { + c.SetVal(v) + } + } + case CmdTypeLCS: + if c, ok := cmd.(*LCSCmd); ok { + if v, ok := value.(*LCSMatch); ok { + c.SetVal(v) + } + } + case CmdTypeKeyFlags: + if c, ok := cmd.(*KeyFlagsCmd); ok { + if v, ok := value.([]KeyFlags); ok { + c.SetVal(v) + } + } + case CmdTypeClusterLinks: + if c, ok := cmd.(*ClusterLinksCmd); ok { + if v, ok := value.([]ClusterLink); ok { + c.SetVal(v) + } + } + case CmdTypeClusterShards: + if c, ok := cmd.(*ClusterShardsCmd); ok { + if v, ok := value.([]ClusterShard); ok { + c.SetVal(v) + } + } + case CmdTypeRankWithScore: + if c, ok := cmd.(*RankWithScoreCmd); ok { + if v, ok := value.(RankScore); ok { + c.SetVal(v) + } + } + case CmdTypeClientInfo: + if c, ok := cmd.(*ClientInfoCmd); ok { + if v, ok := value.(*ClientInfo); ok { + c.SetVal(v) + } + } + case CmdTypeACLLog: + if c, ok := cmd.(*ACLLogCmd); ok { + if v, ok := value.([]*ACLLogEntry); ok { + c.SetVal(v) + } + } + case CmdTypeInfo: + if c, ok := cmd.(*InfoCmd); ok { + if v, ok := value.(map[string]map[string]string); ok { + c.SetVal(v) + } + } + case CmdTypeMonitor: + // MonitorCmd doesn't have SetVal method + // Skip setting value for MonitorCmd + case CmdTypeJSON: + if c, ok := cmd.(*JSONCmd); ok { + if v, ok := value.(string); ok { + c.SetVal(v) + } + } + case CmdTypeJSONSlice: + if c, ok := cmd.(*JSONSliceCmd); ok { + if v, ok := value.([]interface{}); ok { + c.SetVal(v) + } + } + case CmdTypeIntPointerSlice: + if c, ok := cmd.(*IntPointerSliceCmd); ok { + if v, ok := value.([]*int64); ok { + c.SetVal(v) + } + } + case CmdTypeScanDump: + if c, ok := cmd.(*ScanDumpCmd); ok { + if v, ok := value.(ScanDump); ok { + c.SetVal(v) + } + } + case CmdTypeBFInfo: + if c, ok := cmd.(*BFInfoCmd); ok { + if v, ok := value.(BFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCFInfo: + if c, ok := cmd.(*CFInfoCmd); ok { + if v, ok := value.(CFInfo); ok { + c.SetVal(v) + } + } + case CmdTypeCMSInfo: + if c, ok := cmd.(*CMSInfoCmd); ok { + if v, ok := value.(CMSInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTopKInfo: + if c, ok := cmd.(*TopKInfoCmd); ok { + if v, ok := value.(TopKInfo); ok { + c.SetVal(v) + } + } + case CmdTypeTDigestInfo: + if c, ok := cmd.(*TDigestInfoCmd); ok { + if v, ok := value.(TDigestInfo); ok { + c.SetVal(v) + } + } + case CmdTypeFTSynDump: + if c, ok := cmd.(*FTSynDumpCmd); ok { + if v, ok := value.([]FTSynDumpResult); ok { + c.SetVal(v) + } + } + case CmdTypeAggregate: + if c, ok := cmd.(*AggregateCmd); ok { + if v, ok := value.(*FTAggregateResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTInfo: + if c, ok := cmd.(*FTInfoCmd); ok { + if v, ok := value.(FTInfoResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSpellCheck: + if c, ok := cmd.(*FTSpellCheckCmd); ok { + if v, ok := value.([]SpellCheckResult); ok { + c.SetVal(v) + } + } + case CmdTypeFTSearch: + if c, ok := cmd.(*FTSearchCmd); ok { + if v, ok := value.(FTSearchResult); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValue: + if c, ok := cmd.(*TSTimestampValueCmd); ok { + if v, ok := value.(TSTimestampValue); ok { + c.SetVal(v) + } + } + case CmdTypeTSTimestampValueSlice: + if c, ok := cmd.(*TSTimestampValueSliceCmd); ok { + if v, ok := value.([]TSTimestampValue); ok { + c.SetVal(v) + } + } + default: + // Fallback to reflection for unknown types + return c.setCommandValueReflection(cmd, value) + } + + return nil +} + +// setCommandValueReflection is a fallback function that uses reflection +func (c *ClusterClient) setCommandValueReflection(cmd Cmder, value interface{}) error { + cmdValue := reflect.ValueOf(cmd) + if cmdValue.Kind() != reflect.Ptr || cmdValue.IsNil() { + return errInvalidCmdPointer + } + + setValMethod := cmdValue.MethodByName("SetVal") + if !setValMethod.IsValid() { + return fmt.Errorf("redis: command %T does not have SetVal method", cmd) + } + + args := []reflect.Value{reflect.ValueOf(value)} + + switch cmd.(type) { + case *XAutoClaimCmd, *XAutoClaimJustIDCmd: + args = append(args, reflect.ValueOf("")) + case *ScanCmd: + args = append(args, reflect.ValueOf(uint64(0))) + case *KeyValuesCmd, *ZSliceWithKeyCmd: + if key, ok := value.(string); ok { + args = []reflect.Value{reflect.ValueOf(key)} + if _, ok := cmd.(*ZSliceWithKeyCmd); ok { + args = append(args, reflect.ValueOf([]Z{})) + } else { + args = append(args, reflect.ValueOf([]string{})) + } + } + } + + defer func() { + if r := recover(); r != nil { + cmd.SetErr(fmt.Errorf("redis: failed to set command value: %v", r)) + } + }() + + setValMethod.Call(args) + return nil +} diff --git a/vendor/github.com/redis/go-redis/v9/otel.go b/vendor/github.com/redis/go-redis/v9/otel.go new file mode 100644 index 000000000..1ea359364 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/otel.go @@ -0,0 +1,235 @@ +package redis + +import ( + "context" + "net" + "time" + + "github.com/redis/go-redis/v9/internal/otel" + "github.com/redis/go-redis/v9/internal/pool" +) + +// ConnInfo provides information about a Redis connection for metrics. +type ConnInfo interface { + RemoteAddr() net.Addr + PoolName() string +} + +type Pooler interface { + PoolStats() *pool.Stats +} + +type PubSubPooler interface { + Stats() *pool.PubSubStats +} + +// OTelRecorder is the interface for recording OpenTelemetry metrics. + +type OTelRecorder interface { + // RecordOperationDuration records the total operation duration (including all retries) + RecordOperationDuration(ctx context.Context, duration time.Duration, cmd Cmder, attempts int, err error, cn ConnInfo, dbIndex int) + + // RecordPipelineOperationDuration records the total pipeline/transaction duration. + // operationName should be "PIPELINE" for regular pipelines or "MULTI" for transactions. + RecordPipelineOperationDuration(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn ConnInfo, dbIndex int) + + // RecordConnectionCreateTime records the time it took to create a new connection + RecordConnectionCreateTime(ctx context.Context, duration time.Duration, cn ConnInfo) + + // RecordConnectionRelaxedTimeout records when connection timeout is relaxed/unrelaxed + // delta: +1 for relaxed, -1 for unrelaxed + // poolName: name of the connection pool (e.g., "main", "pubsub") + // notificationType: the notification type that triggered the timeout relaxation (e.g., "MOVING", "HANDOFF") + RecordConnectionRelaxedTimeout(ctx context.Context, delta int, cn ConnInfo, poolName, notificationType string) + + // RecordConnectionHandoff records when a connection is handed off to another node + // poolName: name of the connection pool (e.g., "main", "pubsub") + RecordConnectionHandoff(ctx context.Context, cn ConnInfo, poolName string) + + // RecordError records client errors (ASK, MOVED, handshake failures, etc.) + // errorType: type of error (e.g., "ASK", "MOVED", "HANDSHAKE_FAILED") + // statusCode: Redis response status code if available (e.g., "MOVED", "ASK") + // isInternal: whether this is an internal error + // retryAttempts: number of retry attempts made + RecordError(ctx context.Context, errorType string, cn ConnInfo, statusCode string, isInternal bool, retryAttempts int) + + // RecordMaintenanceNotification records when a maintenance notification is received + // notificationType: the type of notification (e.g., "MOVING", "MIGRATING", etc.) + RecordMaintenanceNotification(ctx context.Context, cn ConnInfo, notificationType string) + + // RecordConnectionWaitTime records the time spent waiting for a connection from the pool + RecordConnectionWaitTime(ctx context.Context, duration time.Duration, cn ConnInfo) + + // RecordConnectionClosed records when a connection is closed + // reason: reason for closing (e.g., "idle", "max_lifetime", "error", "pool_closed") + // err: the error that caused the close (nil for non-error closures) + RecordConnectionClosed(ctx context.Context, cn ConnInfo, reason string, err error) + + // RecordPubSubMessage records a Pub/Sub message + // direction: "sent" or "received" + // channel: channel name (may be hidden for cardinality reduction) + // sharded: true for sharded pub/sub (SPUBLISH/SSUBSCRIBE) + RecordPubSubMessage(ctx context.Context, cn ConnInfo, direction, channel string, sharded bool) + + // RecordStreamLag records the lag for stream consumer group processing + // lag: time difference between message creation and consumption + // streamName: name of the stream (may be hidden for cardinality reduction) + // consumerGroup: name of the consumer group + // consumerName: name of the consumer + RecordStreamLag(ctx context.Context, lag time.Duration, cn ConnInfo, streamName, consumerGroup, consumerName string) +} + +// OTelConnectionCounter is an optional capability interface for recording +// connection count and pending request changes via UpDownCounters. +// Implementations of OTelRecorder can optionally implement this interface +// to receive connection count and pending request delta notifications. +// This is kept separate from OTelRecorder to avoid breaking existing +// third-party implementations when new methods are added. +type OTelConnectionCounter interface { + // RecordConnectionCount records a change in connection count (UpDownCounter) + // delta: +1 when connection added, -1 when connection removed + // state: connection state (e.g., "idle", "used") + // isPubSub: true if this is a PubSub connection + RecordConnectionCount(ctx context.Context, delta int, cn ConnInfo, state string, isPubSub bool) + + // RecordPendingRequests records a change in pending requests (UpDownCounter) + // delta: +1 when request starts waiting, -1 when request stops waiting + // poolName is passed explicitly because we may not have a connection yet when request starts + RecordPendingRequests(ctx context.Context, delta int, cn ConnInfo, poolName string) +} + +// This is used for async gauge metrics that need to pull stats from pools periodically. +type OTelPoolRegistrar interface { + // RegisterPool is called when a new client is created with its main connection pool. + // poolName: unique identifier for the pool (e.g., "main_abc123") + RegisterPool(poolName string, pool Pooler) + // UnregisterPool is called when a client is closed to remove its pool from the registry. + UnregisterPool(pool Pooler) + // RegisterPubSubPool is called when a new client is created with a PubSub pool. + // poolName: unique identifier for the pool (e.g., "main_abc123_pubsub") + RegisterPubSubPool(poolName string, pool PubSubPooler) + // UnregisterPubSubPool is called when a PubSub client is closed to remove its pool. + UnregisterPubSubPool(pool PubSubPooler) +} + +// SetOTelRecorder sets the global OpenTelemetry recorder. +func SetOTelRecorder(r OTelRecorder) { + if r == nil { + otel.SetGlobalRecorder(nil) + return + } + otel.SetGlobalRecorder(&otelRecorderAdapter{r}) +} + +type otelRecorderAdapter struct { + recorder OTelRecorder +} + +// toConnInfo converts *pool.Conn to ConnInfo interface properly. +// This ensures that a nil *pool.Conn becomes a true nil interface, +// not a non-nil interface containing a nil pointer. +func toConnInfo(cn *pool.Conn) ConnInfo { + if cn == nil { + return nil + } + return cn +} + +func (a *otelRecorderAdapter) RecordOperationDuration(ctx context.Context, duration time.Duration, cmd otel.Cmder, attempts int, err error, cn *pool.Conn, dbIndex int) { + // Convert internal Cmder to public Cmder + if publicCmd, ok := cmd.(Cmder); ok { + a.recorder.RecordOperationDuration(ctx, duration, publicCmd, attempts, err, toConnInfo(cn), dbIndex) + } +} + +func (a *otelRecorderAdapter) RecordPipelineOperationDuration(ctx context.Context, duration time.Duration, operationName string, cmdCount int, attempts int, err error, cn *pool.Conn, dbIndex int) { + a.recorder.RecordPipelineOperationDuration(ctx, duration, operationName, cmdCount, attempts, err, toConnInfo(cn), dbIndex) +} + +func (a *otelRecorderAdapter) RecordConnectionCreateTime(ctx context.Context, duration time.Duration, cn *pool.Conn) { + a.recorder.RecordConnectionCreateTime(ctx, duration, toConnInfo(cn)) +} + +func (a *otelRecorderAdapter) RecordConnectionRelaxedTimeout(ctx context.Context, delta int, cn *pool.Conn, poolName, notificationType string) { + a.recorder.RecordConnectionRelaxedTimeout(ctx, delta, toConnInfo(cn), poolName, notificationType) +} + +func (a *otelRecorderAdapter) RecordConnectionHandoff(ctx context.Context, cn *pool.Conn, poolName string) { + a.recorder.RecordConnectionHandoff(ctx, toConnInfo(cn), poolName) +} + +func (a *otelRecorderAdapter) RecordError(ctx context.Context, errorType string, cn *pool.Conn, statusCode string, isInternal bool, retryAttempts int) { + a.recorder.RecordError(ctx, errorType, toConnInfo(cn), statusCode, isInternal, retryAttempts) +} + +func (a *otelRecorderAdapter) RecordMaintenanceNotification(ctx context.Context, cn *pool.Conn, notificationType string) { + a.recorder.RecordMaintenanceNotification(ctx, toConnInfo(cn), notificationType) +} + +func (a *otelRecorderAdapter) RecordConnectionWaitTime(ctx context.Context, duration time.Duration, cn *pool.Conn) { + a.recorder.RecordConnectionWaitTime(ctx, duration, toConnInfo(cn)) +} + +func (a *otelRecorderAdapter) RecordConnectionClosed(ctx context.Context, cn *pool.Conn, reason string, err error) { + a.recorder.RecordConnectionClosed(ctx, toConnInfo(cn), reason, err) +} + +func (a *otelRecorderAdapter) RecordPubSubMessage(ctx context.Context, cn *pool.Conn, direction, channel string, sharded bool) { + a.recorder.RecordPubSubMessage(ctx, toConnInfo(cn), direction, channel, sharded) +} + +func (a *otelRecorderAdapter) RecordStreamLag(ctx context.Context, lag time.Duration, cn *pool.Conn, streamName, consumerGroup, consumerName string) { + a.recorder.RecordStreamLag(ctx, lag, toConnInfo(cn), streamName, consumerGroup, consumerName) +} + +func (a *otelRecorderAdapter) RecordConnectionCount(ctx context.Context, delta int, cn *pool.Conn, state string, isPubSub bool) { + if counter, ok := a.recorder.(OTelConnectionCounter); ok { + counter.RecordConnectionCount(ctx, delta, toConnInfo(cn), state, isPubSub) + } +} + +func (a *otelRecorderAdapter) RecordPendingRequests(ctx context.Context, delta int, cn *pool.Conn, poolName string) { + if counter, ok := a.recorder.(OTelConnectionCounter); ok { + counter.RecordPendingRequests(ctx, delta, toConnInfo(cn), poolName) + } +} + +func (a *otelRecorderAdapter) RegisterPool(poolName string, p pool.Pooler) { + if registrar, ok := a.recorder.(OTelPoolRegistrar); ok { + registrar.RegisterPool(poolName, &poolerAdapter{p}) + } +} + +func (a *otelRecorderAdapter) UnregisterPool(p pool.Pooler) { + if registrar, ok := a.recorder.(OTelPoolRegistrar); ok { + registrar.UnregisterPool(&poolerAdapter{p}) + } +} + +func (a *otelRecorderAdapter) RegisterPubSubPool(poolName string, p otel.PubSubPooler) { + if registrar, ok := a.recorder.(OTelPoolRegistrar); ok { + registrar.RegisterPubSubPool(poolName, &pubSubPoolerAdapter{p}) + } +} + +func (a *otelRecorderAdapter) UnregisterPubSubPool(p otel.PubSubPooler) { + if registrar, ok := a.recorder.(OTelPoolRegistrar); ok { + registrar.UnregisterPubSubPool(&pubSubPoolerAdapter{p}) + } +} + +type poolerAdapter struct { + p pool.Pooler +} + +func (a *poolerAdapter) PoolStats() *pool.Stats { + return a.p.Stats() +} + +type pubSubPoolerAdapter struct { + p otel.PubSubPooler +} + +func (a *pubSubPoolerAdapter) Stats() *pool.PubSubStats { + return a.p.Stats() +} diff --git a/vendor/github.com/redis/go-redis/v9/pipeline.go b/vendor/github.com/redis/go-redis/v9/pipeline.go index 1c114205c..41b832213 100644 --- a/vendor/github.com/redis/go-redis/v9/pipeline.go +++ b/vendor/github.com/redis/go-redis/v9/pipeline.go @@ -7,7 +7,7 @@ import ( type pipelineExecer func(context.Context, []Cmder) error -// Pipeliner is an mechanism to realise Redis Pipeline technique. +// Pipeliner is a mechanism to realise Redis Pipeline technique. // // Pipelining is a technique to extremely speed up processing by packing // operations to batches, send them at once to Redis and read a replies in a @@ -23,27 +23,33 @@ type pipelineExecer func(context.Context, []Cmder) error type Pipeliner interface { StatefulCmdable - // Len is to obtain the number of commands in the pipeline that have not yet been executed. + // Len obtains the number of commands in the pipeline that have not yet been executed. Len() int // Do is an API for executing any command. // If a certain Redis command is not yet supported, you can use Do to execute it. Do(ctx context.Context, args ...interface{}) *Cmd - // Process is to put the commands to be executed into the pipeline buffer. + // Process queues the cmd for later execution. Process(ctx context.Context, cmd Cmder) error - // Discard is to discard all commands in the cache that have not yet been executed. + // BatchProcess adds multiple commands to be executed into the pipeline buffer. + BatchProcess(ctx context.Context, cmd ...Cmder) error + + // Discard discards all commands in the pipeline buffer that have not yet been executed. Discard() - // Exec is to send all the commands buffered in the pipeline to the redis-server. + // Exec sends all the commands buffered in the pipeline to the redis server. Exec(ctx context.Context) ([]Cmder, error) + + // Cmds returns the list of queued commands. + Cmds() []Cmder } var _ Pipeliner = (*Pipeline)(nil) // Pipeline implements pipelining as described in -// http://redis.io/topics/pipelining. +// https://redis.io/docs/latest/develop/using-commands/pipelining. // Please note: it is not safe for concurrent use by multiple goroutines. type Pipeline struct { cmdable @@ -76,7 +82,12 @@ func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd { // Process queues the cmd for later execution. func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error { - c.cmds = append(c.cmds, cmd) + return c.BatchProcess(ctx, cmd) +} + +// BatchProcess queues multiple cmds for later execution. +func (c *Pipeline) BatchProcess(ctx context.Context, cmd ...Cmder) error { + c.cmds = append(c.cmds, cmd...) return nil } @@ -119,3 +130,7 @@ func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([ func (c *Pipeline) TxPipeline() Pipeliner { return c } + +func (c *Pipeline) Cmds() []Cmder { + return c.cmds +} diff --git a/vendor/github.com/redis/go-redis/v9/probabilistic.go b/vendor/github.com/redis/go-redis/v9/probabilistic.go index 02ca263cb..ee67911e6 100644 --- a/vendor/github.com/redis/go-redis/v9/probabilistic.go +++ b/vendor/github.com/redis/go-redis/v9/probabilistic.go @@ -225,8 +225,9 @@ type ScanDumpCmd struct { func newScanDumpCmd(ctx context.Context, args ...interface{}) *ScanDumpCmd { return &ScanDumpCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeScanDump, }, } } @@ -270,6 +271,13 @@ func (cmd *ScanDumpCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *ScanDumpCmd) Clone() Cmder { + return &ScanDumpCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // ScanDump is a simple struct, can be copied directly + } +} + // Returns information about a Bloom filter. // For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfo(ctx context.Context, key string) *BFInfoCmd { @@ -296,8 +304,9 @@ type BFInfoCmd struct { func NewBFInfoCmd(ctx context.Context, args ...interface{}) *BFInfoCmd { return &BFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeBFInfo, }, } } @@ -388,6 +397,13 @@ func (cmd *BFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *BFInfoCmd) Clone() Cmder { + return &BFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // BFInfo is a simple struct, can be copied directly + } +} + // BFInfoCapacity returns information about the capacity of a Bloom filter. // For more information - https://redis.io/commands/bf.info/ func (c cmdable) BFInfoCapacity(ctx context.Context, key string) *BFInfoCmd { @@ -625,8 +641,9 @@ type CFInfoCmd struct { func NewCFInfoCmd(ctx context.Context, args ...interface{}) *CFInfoCmd { return &CFInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCFInfo, }, } } @@ -692,6 +709,13 @@ func (cmd *CFInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CFInfoCmd) Clone() Cmder { + return &CFInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CFInfo is a simple struct, can be copied directly + } +} + // CFInfo returns information about a Cuckoo filter. // For more information - https://redis.io/commands/cf.info/ func (c cmdable) CFInfo(ctx context.Context, key string) *CFInfoCmd { @@ -787,8 +811,9 @@ type CMSInfoCmd struct { func NewCMSInfoCmd(ctx context.Context, args ...interface{}) *CMSInfoCmd { return &CMSInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeCMSInfo, }, } } @@ -843,6 +868,13 @@ func (cmd *CMSInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *CMSInfoCmd) Clone() Cmder { + return &CMSInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // CMSInfo is a simple struct, can be copied directly + } +} + // CMSInfo returns information about a Count-Min Sketch filter. // For more information - https://redis.io/commands/cms.info/ func (c cmdable) CMSInfo(ctx context.Context, key string) *CMSInfoCmd { @@ -980,8 +1012,9 @@ type TopKInfoCmd struct { func NewTopKInfoCmd(ctx context.Context, args ...interface{}) *TopKInfoCmd { return &TopKInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTopKInfo, }, } } @@ -1038,6 +1071,13 @@ func (cmd *TopKInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TopKInfoCmd) Clone() Cmder { + return &TopKInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TopKInfo is a simple struct, can be copied directly + } +} + // TopKInfo returns information about a Top-K filter. // For more information - https://redis.io/commands/topk.info/ func (c cmdable) TopKInfo(ctx context.Context, key string) *TopKInfoCmd { @@ -1116,18 +1156,14 @@ func (c cmdable) TopKListWithCount(ctx context.Context, key string) *MapStringIn // Returns OK on success or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.add/ func (c cmdable) TDigestAdd(ctx context.Context, key string, elements ...float64) *StatusCmd { - args := make([]interface{}, 2, 2+len(elements)) + args := make([]interface{}, 2+len(elements)) args[0] = "TDIGEST.ADD" args[1] = key - // Convert floatSlice to []interface{} - interfaceSlice := make([]interface{}, len(elements)) for i, v := range elements { - interfaceSlice[i] = v + args[2+i] = v } - args = append(args, interfaceSlice...) - cmd := NewStatusCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1138,18 +1174,14 @@ func (c cmdable) TDigestAdd(ctx context.Context, key string, elements ...float64 // Returns an array of floats representing the values at the specified ranks or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.byrank/ func (c cmdable) TDigestByRank(ctx context.Context, key string, rank ...uint64) *FloatSliceCmd { - args := make([]interface{}, 2, 2+len(rank)) + args := make([]interface{}, 2+len(rank)) args[0] = "TDIGEST.BYRANK" args[1] = key - // Convert uint slice to []interface{} - interfaceSlice := make([]interface{}, len(rank)) - for i, v := range rank { - interfaceSlice[i] = v + for i, r := range rank { + args[2+i] = r } - args = append(args, interfaceSlice...) - cmd := NewFloatSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1160,18 +1192,14 @@ func (c cmdable) TDigestByRank(ctx context.Context, key string, rank ...uint64) // Returns an array of floats representing the values at the specified ranks or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.byrevrank/ func (c cmdable) TDigestByRevRank(ctx context.Context, key string, rank ...uint64) *FloatSliceCmd { - args := make([]interface{}, 2, 2+len(rank)) + args := make([]interface{}, 2+len(rank)) args[0] = "TDIGEST.BYREVRANK" args[1] = key - // Convert uint slice to []interface{} - interfaceSlice := make([]interface{}, len(rank)) - for i, v := range rank { - interfaceSlice[i] = v + for i, r := range rank { + args[2+i] = r } - args = append(args, interfaceSlice...) - cmd := NewFloatSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1182,18 +1210,14 @@ func (c cmdable) TDigestByRevRank(ctx context.Context, key string, rank ...uint6 // Returns an array of floats representing the CDF values for each element or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.cdf/ func (c cmdable) TDigestCDF(ctx context.Context, key string, elements ...float64) *FloatSliceCmd { - args := make([]interface{}, 2, 2+len(elements)) + args := make([]interface{}, 2+len(elements)) args[0] = "TDIGEST.CDF" args[1] = key - // Convert floatSlice to []interface{} - interfaceSlice := make([]interface{}, len(elements)) for i, v := range elements { - interfaceSlice[i] = v + args[2+i] = v } - args = append(args, interfaceSlice...) - cmd := NewFloatSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1243,8 +1267,9 @@ type TDigestInfoCmd struct { func NewTDigestInfoCmd(ctx context.Context, args ...interface{}) *TDigestInfoCmd { return &TDigestInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTDigestInfo, }, } } @@ -1311,6 +1336,13 @@ func (cmd *TDigestInfoCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TDigestInfoCmd) Clone() Cmder { + return &TDigestInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: cmd.val, // TDigestInfo is a simple struct, can be copied directly + } +} + // TDigestInfo returns information about a t-Digest data structure. // For more information - https://redis.io/commands/tdigest.info/ func (c cmdable) TDigestInfo(ctx context.Context, key string) *TDigestInfoCmd { @@ -1376,18 +1408,14 @@ func (c cmdable) TDigestMin(ctx context.Context, key string) *FloatCmd { // Returns an array of floats representing the quantile values for each element or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.quantile/ func (c cmdable) TDigestQuantile(ctx context.Context, key string, elements ...float64) *FloatSliceCmd { - args := make([]interface{}, 2, 2+len(elements)) + args := make([]interface{}, 2+len(elements)) args[0] = "TDIGEST.QUANTILE" args[1] = key - // Convert floatSlice to []interface{} - interfaceSlice := make([]interface{}, len(elements)) for i, v := range elements { - interfaceSlice[i] = v + args[2+i] = v } - args = append(args, interfaceSlice...) - cmd := NewFloatSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1398,18 +1426,14 @@ func (c cmdable) TDigestQuantile(ctx context.Context, key string, elements ...fl // Returns an array of integers representing the rank values for each element or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.rank/ func (c cmdable) TDigestRank(ctx context.Context, key string, values ...float64) *IntSliceCmd { - args := make([]interface{}, 2, 2+len(values)) + args := make([]interface{}, 2+len(values)) args[0] = "TDIGEST.RANK" args[1] = key - // Convert floatSlice to []interface{} - interfaceSlice := make([]interface{}, len(values)) for i, v := range values { - interfaceSlice[i] = v + args[i+2] = v } - args = append(args, interfaceSlice...) - cmd := NewIntSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -1431,18 +1455,14 @@ func (c cmdable) TDigestReset(ctx context.Context, key string) *StatusCmd { // Returns an array of integers representing the reverse rank values for each element or an error if the operation could not be completed. // For more information - https://redis.io/commands/tdigest.revrank/ func (c cmdable) TDigestRevRank(ctx context.Context, key string, values ...float64) *IntSliceCmd { - args := make([]interface{}, 2, 2+len(values)) + args := make([]interface{}, 2+len(values)) args[0] = "TDIGEST.REVRANK" args[1] = key - // Convert floatSlice to []interface{} - interfaceSlice := make([]interface{}, len(values)) for i, v := range values { - interfaceSlice[i] = v + args[2+i] = v } - args = append(args, interfaceSlice...) - cmd := NewIntSliceCmd(ctx, args...) _ = c(ctx, cmd) return cmd diff --git a/vendor/github.com/redis/go-redis/v9/pubsub.go b/vendor/github.com/redis/go-redis/v9/pubsub.go index 2a0e7a81e..9d6961059 100644 --- a/vendor/github.com/redis/go-redis/v9/pubsub.go +++ b/vendor/github.com/redis/go-redis/v9/pubsub.go @@ -3,17 +3,21 @@ package redis import ( "context" "fmt" + "maps" + "slices" "strings" "sync" "time" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/otel" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/push" ) // PubSub implements Pub/Sub commands as described in -// http://redis.io/topics/pubsub. Message receiving is NOT safe +// https://redis.io/docs/latest/develop/pubsub. Message receiving is NOT safe // for concurrent use by multiple goroutines. // // PubSub automatically reconnects to Redis Server and resubscribes @@ -21,7 +25,7 @@ import ( type PubSub struct { opt *Options - newConn func(ctx context.Context, channels []string) (*pool.Conn, error) + newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) closeConn func(*pool.Conn) error mu sync.Mutex @@ -38,6 +42,12 @@ type PubSub struct { chOnce sync.Once msgCh *channel allCh *channel + + // Push notification processor for handling generic push notifications + pushProcessor push.NotificationProcessor + + // Cleanup callback for maintenanceNotifications upgrade tracking + onClose func() } func (c *PubSub) init() { @@ -48,9 +58,9 @@ func (c *PubSub) String() string { c.mu.Lock() defer c.mu.Unlock() - channels := mapKeys(c.channels) - channels = append(channels, mapKeys(c.patterns)...) - channels = append(channels, mapKeys(c.schannels)...) + channels := slices.Collect(maps.Keys(c.channels)) + channels = append(channels, slices.Collect(maps.Keys(c.patterns))...) + channels = append(channels, slices.Collect(maps.Keys(c.schannels))...) return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) } @@ -69,10 +79,27 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er return c.cn, nil } - channels := mapKeys(c.channels) + if c.opt.Addr == "" { + // TODO(maintenanceNotifications): + // this is probably cluster client + // c.newConn will ignore the addr argument + // will be changed when we have maintenanceNotifications upgrades for cluster clients + c.opt.Addr = internal.RedisNull + } + + // Include c.schannels so reconnect-time routing of an SSubscribe-only + // PubSub picks the slot owner (channels[0] in ClusterClient.pubSub()'s + // newConn closure) instead of a random node. + // See https://github.com/redis/go-redis/issues/3806. + // c.patterns is intentionally NOT included: patterns are not slot- + // addressable, and adding them would force PSubscribe-only PubSubs to + // pin to a single node based on pattern-string hash, regressing the + // existing random-node behaviour. + channels := slices.Collect(maps.Keys(c.channels)) + channels = append(channels, slices.Collect(maps.Keys(c.schannels))...) channels = append(channels, newChannels...) - cn, err := c.newConn(ctx, channels) + cn, err := c.newConn(ctx, c.opt.Addr, channels) if err != nil { return nil, err } @@ -96,18 +123,18 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { var firstErr error if len(c.channels) > 0 { - firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels)) + firstErr = c._subscribe(ctx, cn, "subscribe", slices.Collect(maps.Keys(c.channels))) } if len(c.patterns) > 0 { - err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns)) + err := c._subscribe(ctx, cn, "psubscribe", slices.Collect(maps.Keys(c.patterns))) if err != nil && firstErr == nil { firstErr = err } } if len(c.schannels) > 0 { - err := c._subscribe(ctx, cn, "ssubscribe", mapKeys(c.schannels)) + err := c._subscribe(ctx, cn, "ssubscribe", slices.Collect(maps.Keys(c.schannels))) if err != nil && firstErr == nil { firstErr = err } @@ -116,16 +143,6 @@ func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { return firstErr } -func mapKeys(m map[string]struct{}) []string { - s := make([]string, len(m)) - i := 0 - for k := range m { - s[i] = k - i++ - } - return s -} - func (c *PubSub) _subscribe( ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, ) error { @@ -153,12 +170,32 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo if c.cn != cn { return } + + if !cn.IsUsable() || cn.ShouldHandoff() { + c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable")) + return + } + if isBadConn(err, allowTimeout, c.opt.Addr) { c.reconnect(ctx, err) } } func (c *PubSub) reconnect(ctx context.Context, reason error) { + if c.cn != nil && c.cn.ShouldHandoff() { + newEndpoint := c.cn.GetHandoffEndpoint() + // If new endpoint is NULL, use the original address + if newEndpoint == internal.RedisNull { + newEndpoint = c.opt.Addr + } + + if newEndpoint != "" { + // Update the address in the options + oldAddr := c.cn.RemoteAddr().String() + c.opt.Addr = newEndpoint + internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) + } + } _ = c.closeTheCn(reason) _, _ = c.conn(ctx, nil) } @@ -167,9 +204,6 @@ func (c *PubSub) closeTheCn(reason error) error { if c.cn == nil { return nil } - if !c.closed { - internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) - } err := c.closeConn(c.cn) c.cn = nil return err @@ -185,6 +219,11 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) + // Call cleanup callback if set + if c.onClose != nil { + c.onClose() + } + return c.closeTheCn(pool.ErrClosed) } @@ -247,9 +286,7 @@ func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { } } else { // Unsubscribe from all channels. - for channel := range c.channels { - delete(c.channels, channel) - } + clear(c.channels) } err := c.subscribe(ctx, "unsubscribe", channels...) @@ -268,9 +305,7 @@ func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { } } else { // Unsubscribe from all patterns. - for pattern := range c.patterns { - delete(c.patterns, pattern) - } + clear(c.patterns) } err := c.subscribe(ctx, "punsubscribe", patterns...) @@ -289,9 +324,7 @@ func (c *PubSub) SUnsubscribe(ctx context.Context, channels ...string) error { } } else { // Unsubscribe from all channels. - for channel := range c.schannels { - delete(c.schannels, channel) - } + clear(c.schannels) } err := c.subscribe(ctx, "sunsubscribe", channels...) @@ -329,6 +362,25 @@ func (c *PubSub) Ping(ctx context.Context, payload ...string) error { return err } +// ClientSetName assigns a namee to the PubSub connection using CLIENT SETNAME, +// The name is visible in CLIENT LIST output and is useful for debugging +// and identifying connections in a redis instance. +func (c *PubSub) ClientSetName(ctx context.Context, name string) error { + cmd := NewStatusCmd(ctx, "client", "setname", name) + + c.mu.Lock() + defer c.mu.Unlock() + + cn, err := c.conn(ctx, nil) + if err != nil { + return err + } + + err = c.writeCmd(ctx, cn, cmd) + c.releaseConn(ctx, cn, err, false) + return err +} + // Subscription received after a successful subscription to channel. type Subscription struct { // Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". @@ -367,7 +419,7 @@ func (p *Pong) String() string { return "Pong" } -func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { +func (c *PubSub) newMessage(ctx context.Context, cn *pool.Conn, reply interface{}) (interface{}, error) { switch reply := reply.(type) { case string: return &Pong{ @@ -384,30 +436,42 @@ func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { Count: int(reply[2].(int64)), }, nil case "message", "smessage": + channel := reply[1].(string) + sharded := kind == "smessage" switch payload := reply[2].(type) { case string: - return &Message{ - Channel: reply[1].(string), + msg := &Message{ + Channel: channel, Payload: payload, - }, nil + } + // Record PubSub message received + otel.RecordPubSubMessage(ctx, cn, "received", channel, sharded) + return msg, nil case []interface{}: ss := make([]string, len(payload)) for i, s := range payload { ss[i] = s.(string) } - return &Message{ - Channel: reply[1].(string), + msg := &Message{ + Channel: channel, PayloadSlice: ss, - }, nil + } + // Record PubSub message received + otel.RecordPubSubMessage(ctx, cn, "received", channel, sharded) + return msg, nil default: return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload) } case "pmessage": - return &Message{ + channel := reply[2].(string) + msg := &Message{ Pattern: reply[1].(string), - Channel: reply[2].(string), + Channel: channel, Payload: reply[3].(string), - }, nil + } + // Record PubSub message received (pattern message, not sharded) + otel.RecordPubSubMessage(ctx, cn, "received", channel, false) + return msg, nil case "pong": return &Pong{ Payload: reply[1].(string), @@ -429,28 +493,38 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int } // Don't hold the lock to allow subscriptions and pings. - cn, err := c.connWithLock(ctx) if err != nil { return nil, err } err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + // Log the error but don't fail the command execution + // Push notification processing errors shouldn't break normal Redis operations + internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) + } return c.cmd.readReply(rd) }) - c.releaseConnWithLock(ctx, cn, err, timeout > 0) if err != nil { return nil, err } - return c.newMessage(c.cmd.Val()) + return c.newMessage(ctx, cn, c.cmd.Val()) } // Receive returns a message as a Subscription, Message, Pong or error. // See PubSub example for details. This is low-level API and in most cases // Channel should be used instead. +// Receive returns a message as a Subscription, Message, Pong, or an error. +// See PubSub example for details. This is a low-level API and in most cases +// Channel should be used instead. +// This method blocks until a message is received or an error occurs. +// It may return early with an error if the context is canceled, the connection fails, +// or other internal errors occur. func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { return c.ReceiveTimeout(ctx, 0) } @@ -532,6 +606,27 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac return c.allCh.allCh } +func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error { + // Only process push notifications for RESP3 connections with a processor + if c.opt.Protocol != 3 || c.pushProcessor == nil { + return nil + } + + // Create handler context with client, connection pool, and connection information + handlerCtx := c.pushNotificationHandlerContext(cn) + return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd) +} + +func (c *PubSub) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext { + // PubSub doesn't have a client or connection pool, so we pass nil for those + // PubSub connections are blocking + return push.NotificationHandlerContext{ + PubSub: c, + Conn: cn, + IsBlocking: true, + } +} + type ChannelOption func(c *channel) // WithChannelSize specifies the Go chan size that is used to buffer incoming messages. @@ -667,7 +762,7 @@ func (c *channel) initMsgChan() { } case <-timer.C: internal.Logger.Printf( - ctx, "redis: %s channel is full for %s (message is dropped)", + ctx, "redis: %v channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: @@ -721,7 +816,7 @@ func (c *channel) initAllChan() { } case <-timer.C: internal.Logger.Printf( - ctx, "redis: %s channel is full for %s (message is dropped)", + ctx, "redis: %v channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: diff --git a/vendor/github.com/redis/go-redis/v9/pubsub_commands.go b/vendor/github.com/redis/go-redis/v9/pubsub_commands.go index 28622aa6b..ccc0ed524 100644 --- a/vendor/github.com/redis/go-redis/v9/pubsub_commands.go +++ b/vendor/github.com/redis/go-redis/v9/pubsub_commands.go @@ -1,6 +1,10 @@ package redis -import "context" +import ( + "context" + + "github.com/redis/go-redis/v9/internal/otel" +) type PubSubCmdable interface { Publish(ctx context.Context, channel string, message interface{}) *IntCmd @@ -16,12 +20,20 @@ type PubSubCmdable interface { func (c cmdable) Publish(ctx context.Context, channel string, message interface{}) *IntCmd { cmd := NewIntCmd(ctx, "publish", channel, message) _ = c(ctx, cmd) + // Record PubSub message sent (if command succeeded) + if cmd.Err() == nil { + otel.RecordPubSubMessage(ctx, nil, "sent", channel, false) + } return cmd } func (c cmdable) SPublish(ctx context.Context, channel string, message interface{}) *IntCmd { cmd := NewIntCmd(ctx, "spublish", channel, message) _ = c(ctx, cmd) + // Record PubSub message sent (if command succeeded) + if cmd.Err() == nil { + otel.RecordPubSubMessage(ctx, nil, "sent", channel, true) + } return cmd } diff --git a/vendor/github.com/redis/go-redis/v9/push/errors.go b/vendor/github.com/redis/go-redis/v9/push/errors.go new file mode 100644 index 000000000..c10c98aa8 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/errors.go @@ -0,0 +1,176 @@ +package push + +import ( + "errors" + "fmt" +) + +// Push notification error definitions +// This file contains all error types and messages used by the push notification system + +// Error reason constants +const ( + // HandlerReasons + ReasonHandlerNil = "handler cannot be nil" + ReasonHandlerExists = "cannot overwrite existing handler" + ReasonHandlerProtected = "handler is protected" + + // ProcessorReasons + ReasonPushNotificationsDisabled = "push notifications are disabled" +) + +// ProcessorType represents the type of processor involved in the error +// defined as a custom type for better readability and easier maintenance +type ProcessorType string + +const ( + // ProcessorTypes + ProcessorTypeProcessor = ProcessorType("processor") + ProcessorTypeVoidProcessor = ProcessorType("void_processor") + ProcessorTypeCustom = ProcessorType("custom") +) + +// ProcessorOperation represents the operation being performed by the processor +// defined as a custom type for better readability and easier maintenance +type ProcessorOperation string + +const ( + // ProcessorOperations + ProcessorOperationProcess = ProcessorOperation("process") + ProcessorOperationRegister = ProcessorOperation("register") + ProcessorOperationUnregister = ProcessorOperation("unregister") + ProcessorOperationUnknown = ProcessorOperation("unknown") +) + +// Common error variables for reuse +var ( + // ErrHandlerNil is returned when attempting to register a nil handler + ErrHandlerNil = errors.New(ReasonHandlerNil) +) + +// Registry errors + +// ErrHandlerExists creates an error for when attempting to overwrite an existing handler +func ErrHandlerExists(pushNotificationName string) error { + return NewHandlerError(ProcessorOperationRegister, pushNotificationName, ReasonHandlerExists, nil) +} + +// ErrProtectedHandler creates an error for when attempting to unregister a protected handler +func ErrProtectedHandler(pushNotificationName string) error { + return NewHandlerError(ProcessorOperationUnregister, pushNotificationName, ReasonHandlerProtected, nil) +} + +// VoidProcessor errors + +// ErrVoidProcessorRegister creates an error for when attempting to register a handler on void processor +func ErrVoidProcessorRegister(pushNotificationName string) error { + return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationRegister, pushNotificationName, ReasonPushNotificationsDisabled, nil) +} + +// ErrVoidProcessorUnregister creates an error for when attempting to unregister a handler on void processor +func ErrVoidProcessorUnregister(pushNotificationName string) error { + return NewProcessorError(ProcessorTypeVoidProcessor, ProcessorOperationUnregister, pushNotificationName, ReasonPushNotificationsDisabled, nil) +} + +// Error type definitions for advanced error handling + +// HandlerError represents errors related to handler operations +type HandlerError struct { + Operation ProcessorOperation + PushNotificationName string + Reason string + Err error +} + +func (e *HandlerError) Error() string { + if e.Err != nil { + return fmt.Sprintf("handler %s failed for '%s': %s (%v)", e.Operation, e.PushNotificationName, e.Reason, e.Err) + } + return fmt.Sprintf("handler %s failed for '%s': %s", e.Operation, e.PushNotificationName, e.Reason) +} + +func (e *HandlerError) Unwrap() error { + return e.Err +} + +// NewHandlerError creates a new HandlerError +func NewHandlerError(operation ProcessorOperation, pushNotificationName, reason string, err error) *HandlerError { + return &HandlerError{ + Operation: operation, + PushNotificationName: pushNotificationName, + Reason: reason, + Err: err, + } +} + +// ProcessorError represents errors related to processor operations +type ProcessorError struct { + ProcessorType ProcessorType // "processor", "void_processor" + Operation ProcessorOperation // "process", "register", "unregister" + PushNotificationName string // Name of the push notification involved + Reason string + Err error +} + +func (e *ProcessorError) Error() string { + notifInfo := "" + if e.PushNotificationName != "" { + notifInfo = fmt.Sprintf(" for '%s'", e.PushNotificationName) + } + if e.Err != nil { + return fmt.Sprintf("%s %s failed%s: %s (%v)", e.ProcessorType, e.Operation, notifInfo, e.Reason, e.Err) + } + return fmt.Sprintf("%s %s failed%s: %s", e.ProcessorType, e.Operation, notifInfo, e.Reason) +} + +func (e *ProcessorError) Unwrap() error { + return e.Err +} + +// NewProcessorError creates a new ProcessorError +func NewProcessorError(processorType ProcessorType, operation ProcessorOperation, pushNotificationName, reason string, err error) *ProcessorError { + return &ProcessorError{ + ProcessorType: processorType, + Operation: operation, + PushNotificationName: pushNotificationName, + Reason: reason, + Err: err, + } +} + +// Helper functions for common error scenarios + +// IsHandlerNilError checks if an error is due to a nil handler +func IsHandlerNilError(err error) bool { + return errors.Is(err, ErrHandlerNil) +} + +// IsHandlerExistsError checks if an error is due to attempting to overwrite an existing handler. +// This function works correctly even when the error is wrapped. +func IsHandlerExistsError(err error) bool { + var handlerErr *HandlerError + if errors.As(err, &handlerErr) { + return handlerErr.Operation == ProcessorOperationRegister && handlerErr.Reason == ReasonHandlerExists + } + return false +} + +// IsProtectedHandlerError checks if an error is due to attempting to unregister a protected handler. +// This function works correctly even when the error is wrapped. +func IsProtectedHandlerError(err error) bool { + var handlerErr *HandlerError + if errors.As(err, &handlerErr) { + return handlerErr.Operation == ProcessorOperationUnregister && handlerErr.Reason == ReasonHandlerProtected + } + return false +} + +// IsVoidProcessorError checks if an error is due to void processor operations. +// This function works correctly even when the error is wrapped. +func IsVoidProcessorError(err error) bool { + var procErr *ProcessorError + if errors.As(err, &procErr) { + return procErr.ProcessorType == ProcessorTypeVoidProcessor && procErr.Reason == ReasonPushNotificationsDisabled + } + return false +} diff --git a/vendor/github.com/redis/go-redis/v9/push/handler.go b/vendor/github.com/redis/go-redis/v9/push/handler.go new file mode 100644 index 000000000..815edce37 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/handler.go @@ -0,0 +1,14 @@ +package push + +import ( + "context" +) + +// NotificationHandler defines the interface for push notification handlers. +type NotificationHandler interface { + // HandlePushNotification processes a push notification with context information. + // The handlerCtx provides information about the client, connection pool, and connection + // on which the notification was received, allowing handlers to make informed decisions. + // Returns an error if the notification could not be handled. + HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error +} diff --git a/vendor/github.com/redis/go-redis/v9/push/handler_context.go b/vendor/github.com/redis/go-redis/v9/push/handler_context.go new file mode 100644 index 000000000..c39e186b0 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/handler_context.go @@ -0,0 +1,44 @@ +package push + +// No imports needed for this file + +// NotificationHandlerContext provides context information about where a push notification was received. +// This struct allows handlers to make informed decisions based on the source of the notification +// with strongly typed access to different client types using concrete types. +type NotificationHandlerContext struct { + // Client is the Redis client instance that received the notification. + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *redis.baseClient + // - *redis.Client + // - *redis.ClusterClient + // - *redis.Conn + Client interface{} + + // ConnPool is the connection pool from which the connection was obtained. + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *pool.ConnPool + // - *pool.SingleConnPool + // - *pool.StickyConnPool + ConnPool interface{} + + // PubSub is the PubSub instance that received the notification. + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *redis.PubSub + PubSub interface{} + + // Conn is the specific connection on which the notification was received. + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *pool.Conn + Conn interface{} + + // IsBlocking indicates if the notification was received on a blocking connection. + IsBlocking bool +} diff --git a/vendor/github.com/redis/go-redis/v9/push/processor.go b/vendor/github.com/redis/go-redis/v9/push/processor.go new file mode 100644 index 000000000..b8112ddc8 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/processor.go @@ -0,0 +1,203 @@ +package push + +import ( + "context" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/proto" +) + +// NotificationProcessor defines the interface for push notification processors. +type NotificationProcessor interface { + // GetHandler returns the handler for a specific push notification name. + GetHandler(pushNotificationName string) NotificationHandler + // ProcessPendingNotifications checks for and processes any pending push notifications. + // To be used when it is known that there are notifications on the socket. + // It will try to read from the socket and if it is empty - it may block. + ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error + // RegisterHandler registers a handler for a specific push notification name. + RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error + // UnregisterHandler removes a handler for a specific push notification name. + UnregisterHandler(pushNotificationName string) error +} + +// Processor handles push notifications with a registry of handlers +type Processor struct { + registry *Registry +} + +// NewProcessor creates a new push notification processor +func NewProcessor() *Processor { + return &Processor{ + registry: NewRegistry(), + } +} + +// GetHandler returns the handler for a specific push notification name +func (p *Processor) GetHandler(pushNotificationName string) NotificationHandler { + return p.registry.GetHandler(pushNotificationName) +} + +// RegisterHandler registers a handler for a specific push notification name +func (p *Processor) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error { + return p.registry.RegisterHandler(pushNotificationName, handler, protected) +} + +// UnregisterHandler removes a handler for a specific push notification name +func (p *Processor) UnregisterHandler(pushNotificationName string) error { + return p.registry.UnregisterHandler(pushNotificationName) +} + +// ProcessPendingNotifications checks for and processes any pending push notifications +// This method should be called by the client in WithReader before reading the reply +// It will try to read from the socket and if it is empty - it may block. +func (p *Processor) ProcessPendingNotifications(ctx context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error { + if rd == nil { + return nil + } + + for { + // Check if there's data available to read + replyType, err := rd.PeekReplyType() + if err != nil { + // No more data available or error reading + // if timeout, it will be handled by the caller + break + } + + // Only process push notifications (arrays starting with >) + if replyType != proto.RespPush { + break + } + + // see if we should skip this notification + notificationName, err := rd.PeekPushNotificationName() + if err != nil { + break + } + + if willHandleNotificationInClient(notificationName) { + break + } + + // Read the push notification + reply, err := rd.ReadReply() + if err != nil { + internal.Logger.Printf(ctx, "push: error reading push notification: %v", err) + break + } + + // Convert to slice of interfaces + notification, ok := reply.([]interface{}) + if !ok { + break + } + + // Handle the notification directly + if len(notification) > 0 { + // Extract the notification type (first element) + if notificationType, ok := notification[0].(string); ok { + // Get the handler for this notification type + if handler := p.registry.GetHandler(notificationType); handler != nil { + // Handle the notification + err := handler.HandlePushNotification(ctx, handlerCtx, notification) + if err != nil { + internal.Logger.Printf(ctx, "push: error handling push notification: %v", err) + } + } + } + } + } + + return nil +} + +// VoidProcessor discards all push notifications without processing them +type VoidProcessor struct{} + +// NewVoidProcessor creates a new void push notification processor +func NewVoidProcessor() *VoidProcessor { + return &VoidProcessor{} +} + +// GetHandler returns nil for void processor since it doesn't maintain handlers +func (v *VoidProcessor) GetHandler(_ string) NotificationHandler { + return nil +} + +// RegisterHandler returns an error for void processor since it doesn't maintain handlers +func (v *VoidProcessor) RegisterHandler(pushNotificationName string, _ NotificationHandler, _ bool) error { + return ErrVoidProcessorRegister(pushNotificationName) +} + +// UnregisterHandler returns an error for void processor since it doesn't maintain handlers +func (v *VoidProcessor) UnregisterHandler(pushNotificationName string) error { + return ErrVoidProcessorUnregister(pushNotificationName) +} + +// ProcessPendingNotifications for VoidProcessor does nothing since push notifications +// are only available in RESP3 and this processor is used for RESP2 connections. +// This avoids unnecessary buffer scanning overhead. +// It does however read and discard all push notifications from the buffer to avoid +// them being interpreted as a reply. +// This method should be called by the client in WithReader before reading the reply +// to be sure there are no buffered push notifications. +// It will try to read from the socket and if it is empty - it may block. +func (v *VoidProcessor) ProcessPendingNotifications(_ context.Context, handlerCtx NotificationHandlerContext, rd *proto.Reader) error { + // read and discard all push notifications + if rd == nil { + return nil + } + + for { + // Check if there's data available to read + replyType, err := rd.PeekReplyType() + if err != nil { + // No more data available or error reading + // if timeout, it will be handled by the caller + break + } + + // Only process push notifications (arrays starting with >) + if replyType != proto.RespPush { + break + } + // see if we should skip this notification + notificationName, err := rd.PeekPushNotificationName() + if err != nil { + break + } + + if willHandleNotificationInClient(notificationName) { + break + } + + // Read the push notification + _, err = rd.ReadReply() + if err != nil { + internal.Logger.Printf(context.Background(), "push: error reading push notification: %v", err) + return nil + } + } + return nil +} + +// willHandleNotificationInClient checks if a notification type should be ignored by the push notification +// processor and handled by other specialized systems instead (pub/sub, streams, keyspace, etc.). +func willHandleNotificationInClient(notificationType string) bool { + switch notificationType { + // Pub/Sub notifications - handled by pub/sub system + case "message", // Regular pub/sub message + "pmessage", // Pattern pub/sub message + "subscribe", // Subscription confirmation + "unsubscribe", // Unsubscription confirmation + "psubscribe", // Pattern subscription confirmation + "punsubscribe", // Pattern unsubscription confirmation + "smessage", // Sharded pub/sub message (Redis 7.0+) + "ssubscribe", // Sharded subscription confirmation + "sunsubscribe": // Sharded unsubscription confirmation + return true + default: + return false + } +} diff --git a/vendor/github.com/redis/go-redis/v9/push/push.go b/vendor/github.com/redis/go-redis/v9/push/push.go new file mode 100644 index 000000000..e6adeaa45 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/push.go @@ -0,0 +1,7 @@ +// Package push provides push notifications for Redis. +// This is an EXPERIMENTAL API for handling push notifications from Redis. +// It is not yet stable and may change in the future. +// Although this is in a public package, in its current form public use is not advised. +// Pending push notifications should be processed before executing any readReply from the connection +// as per RESP3 specification push notifications can be sent at any time. +package push diff --git a/vendor/github.com/redis/go-redis/v9/push/registry.go b/vendor/github.com/redis/go-redis/v9/push/registry.go new file mode 100644 index 000000000..a265ae92f --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push/registry.go @@ -0,0 +1,61 @@ +package push + +import ( + "sync" +) + +// Registry manages push notification handlers +type Registry struct { + mu sync.RWMutex + handlers map[string]NotificationHandler + protected map[string]bool +} + +// NewRegistry creates a new push notification registry +func NewRegistry() *Registry { + return &Registry{ + handlers: make(map[string]NotificationHandler), + protected: make(map[string]bool), + } +} + +// RegisterHandler registers a handler for a specific push notification name +func (r *Registry) RegisterHandler(pushNotificationName string, handler NotificationHandler, protected bool) error { + if handler == nil { + return ErrHandlerNil + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Check if handler already exists + if _, exists := r.protected[pushNotificationName]; exists { + return ErrHandlerExists(pushNotificationName) + } + + r.handlers[pushNotificationName] = handler + r.protected[pushNotificationName] = protected + return nil +} + +// GetHandler returns the handler for a specific push notification name +func (r *Registry) GetHandler(pushNotificationName string) NotificationHandler { + r.mu.RLock() + defer r.mu.RUnlock() + return r.handlers[pushNotificationName] +} + +// UnregisterHandler removes a handler for a specific push notification name +func (r *Registry) UnregisterHandler(pushNotificationName string) error { + r.mu.Lock() + defer r.mu.Unlock() + + // Check if handler is protected + if protected, exists := r.protected[pushNotificationName]; exists && protected { + return ErrProtectedHandler(pushNotificationName) + } + + delete(r.handlers, pushNotificationName) + delete(r.protected, pushNotificationName) + return nil +} diff --git a/vendor/github.com/redis/go-redis/v9/push_notifications.go b/vendor/github.com/redis/go-redis/v9/push_notifications.go new file mode 100644 index 000000000..572955fec --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/push_notifications.go @@ -0,0 +1,21 @@ +package redis + +import ( + "github.com/redis/go-redis/v9/push" +) + +// NewPushNotificationProcessor creates a new push notification processor +// This processor maintains a registry of handlers and processes push notifications +// It is used for RESP3 connections where push notifications are available +func NewPushNotificationProcessor() push.NotificationProcessor { + return push.NewProcessor() +} + +// NewVoidPushNotificationProcessor creates a new void push notification processor +// This processor does not maintain any handlers and always returns nil for all operations +// It is used for RESP2 connections where push notifications are not available +// It can also be used to disable push notifications for RESP3 connections, where +// it will discard all push notifications without processing them +func NewVoidPushNotificationProcessor() push.NotificationProcessor { + return push.NewVoidProcessor() +} diff --git a/vendor/github.com/redis/go-redis/v9/redis.go b/vendor/github.com/redis/go-redis/v9/redis.go index f50df5689..dd3451890 100644 --- a/vendor/github.com/redis/go-redis/v9/redis.go +++ b/vendor/github.com/redis/go-redis/v9/redis.go @@ -9,10 +9,15 @@ import ( "sync/atomic" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/auth/streaming" "github.com/redis/go-redis/v9/internal/hscan" + "github.com/redis/go-redis/v9/internal/otel" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/maintnotifications" + "github.com/redis/go-redis/v9/push" ) // Scanner internal/hscan.Scanner exposed interface. @@ -21,11 +26,29 @@ type Scanner = hscan.Scanner // Nil reply returned by Redis when key does not exist. const Nil = proto.Nil +// String representations of special float values. +// Values are lowercase for consistency with Redis RESP2 protocol responses. +const ( + NaN = internal.NaN // Not a Number + Inf = internal.Inf // Positive infinity + NInf = internal.NInf // Negative infinity +) + // SetLogger set custom log +// Use with VoidLogger to disable logging. +// If logger is nil, the call is ignored and the existing logger is kept. func SetLogger(logger internal.Logging) { + if logger == nil { + return + } internal.Logger = logger } +// SetLogLevel sets the log level for the library. +func SetLogLevel(logLevel internal.LogLevelT) { + internal.LogLevel = logLevel +} + //------------------------------------------------------------------------------ type Hook interface { @@ -200,20 +223,150 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e //------------------------------------------------------------------------------ +// Stable identifiers for baseClient.onClose hooks. Each component that +// registers a close callback owns a dedicated id here so the set of known +// hooks is discoverable in one place and id collisions are caught at +// compile time. New ids should be added as additional constants. +const ( + // onCloseHookIDSentinelFailover identifies the close callback installed + // by NewFailoverClient to tear down sentinel failover background work. + onCloseHookIDSentinelFailover = "sentinel-failover" +) + +// onCloseHooks is a small registry of named close callbacks attached to a +// baseClient. Each callback is identified by a stable string id; registering +// the same id twice replaces the previous callback rather than chaining onto +// it. This guarantees the registry stays bounded regardless of how often a +// hook is (re)registered and avoids the unbounded closure chain that +// motivated issue #3772. +// +// Hooks are invoked in registration order. All hooks run regardless of +// individual errors; the first non-nil error is returned. +// +// A zero-value onCloseHooks is ready to use. It is safe for concurrent use. +// Clones of a baseClient share the same *onCloseHooks so registrations and +// close semantics are preserved across WithTimeout / WithContext / etc. +type onCloseHooks struct { + mu sync.Mutex + order []string + hooks map[string]func() error +} + +// register adds or replaces the callback associated with id. Re-registering +// an existing id overwrites the previous callback in place; new ids are +// appended to the invocation order. +func (h *onCloseHooks) register(id string, fn func() error) { + h.mu.Lock() + defer h.mu.Unlock() + if h.hooks == nil { + h.hooks = make(map[string]func() error) + } + if _, exists := h.hooks[id]; !exists { + h.order = append(h.order, id) + } + h.hooks[id] = fn +} + +// unregister removes the callback associated with id, if any. It is kept +// for API symmetry with register so future callers (e.g. dynamic hook +// owners that need to detach before client Close) do not have to +// reinvent it. +// +//nolint:unused // kept for API symmetry with register; see comment above. +func (h *onCloseHooks) unregister(id string) { + h.mu.Lock() + defer h.mu.Unlock() + if _, exists := h.hooks[id]; !exists { + return + } + delete(h.hooks, id) + for i, x := range h.order { + if x == id { + h.order = append(h.order[:i], h.order[i+1:]...) + break + } + } +} + +// run invokes all registered callbacks in registration order and returns +// the first non-nil error encountered. All callbacks are executed even if +// an earlier one returns an error. +func (h *onCloseHooks) run() error { + if h == nil { + return nil + } + h.mu.Lock() + fns := make([]func() error, 0, len(h.order)) + for _, id := range h.order { + if fn := h.hooks[id]; fn != nil { + fns = append(fns, fn) + } + } + h.mu.Unlock() + + var firstErr error + for _, fn := range fns { + if err := fn(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + type baseClient struct { - opt *Options - connPool pool.Pooler + opt *Options + optLock sync.RWMutex + connPool pool.Pooler + pubSubPool *pool.PubSubPool + hooksMixin - onClose func() error // hook called when client is closed + // onClose holds named callbacks invoked when the client is closed. + // Registering a new callback never removes previously registered ones; + // only re-registering the same id replaces the existing callback. This + // lets composing components (e.g. sentinel failover) add close logic + // safely without fear of overwriting each other and without building + // unbounded closure chains on repeated registration. + onClose *onCloseHooks + + // Push notification processing + pushProcessor push.NotificationProcessor + + // Maintenance notifications manager + maintNotificationsManager *maintnotifications.Manager + maintNotificationsManagerLock sync.RWMutex + + // streamingCredentialsManager is used to manage streaming credentials + streamingCredentialsManager *streaming.Manager } func (c *baseClient) clone() *baseClient { - clone := *c - return &clone + c.maintNotificationsManagerLock.RLock() + maintNotificationsManager := c.maintNotificationsManager + c.maintNotificationsManagerLock.RUnlock() + + clone := &baseClient{ + opt: c.opt, + connPool: c.connPool, + pubSubPool: c.pubSubPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + maintNotificationsManager: maintNotificationsManager, + streamingCredentialsManager: c.streamingCredentialsManager, + } + return clone +} + +// cloneOpt clones c.opt while holding optLock to prevent races with initConn +// which writes to MaintNotificationsConfig.Mode under the same lock. +func (c *baseClient) cloneOpt() *Options { + c.optLock.RLock() + clone := c.opt.clone() + c.optLock.RUnlock() + return clone } func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { - opt := c.opt.clone() + opt := c.cloneOpt() opt.ReadTimeout = timeout opt.WriteTimeout = timeout @@ -227,21 +380,6 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { - cn, err := c.connPool.NewConn(ctx) - if err != nil { - return nil, err - } - - err = c.initConn(ctx, cn) - if err != nil { - _ = c.connPool.CloseConn(cn) - return nil, err - } - - return cn, nil -} - func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { if c.opt.Limiter != nil { err := c.opt.Limiter.Allow() @@ -267,7 +405,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return nil, err } - if cn.Inited { + if cn.IsInited() { return cn, nil } @@ -279,40 +417,209 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return nil, err } + if dialStartNs := cn.GetDialStartNs(); dialStartNs > 0 { + if cb := pool.GetMetricConnectionCreateTimeCallback(); cb != nil { + duration := time.Duration(time.Now().UnixNano() - dialStartNs) + cb(ctx, duration, cn) + } + } + + // initConn will transition to IDLE state, so we need to acquire it + // before returning it to the user. + if !cn.TryAcquire() { + return nil, fmt.Errorf("redis: connection is not usable") + } + return cn, nil } +func (c *baseClient) reAuthConnection() func(poolCn *pool.Conn, credentials auth.Credentials) error { + return func(poolCn *pool.Conn, credentials auth.Credentials) error { + var err error + username, password := credentials.BasicAuth() + + // Use background context - timeout is handled by ReadTimeout in WithReader/WithWriter + ctx := context.Background() + + connPool := pool.NewSingleConnPool(c.connPool, poolCn) + + // Pass hooks so that reauth commands are recorded/traced + cn := newConn(c.opt, connPool, &c.hooksMixin) + + if username != "" { + err = cn.AuthACL(ctx, username, password).Err() + } else { + err = cn.Auth(ctx, password).Err() + } + + return err + } +} +func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { + return func(poolCn *pool.Conn, err error) { + if err != nil { + if isBadConn(err, false, c.opt.Addr) { + // Close the connection to force a reconnection. + // Re-auth happens on connections that were idle in the pool (the pool hook + // waits for IDLE state before transitioning to UNUSABLE for re-auth). + // From metrics perspective, the connection was never "used" by a client. + // Note: Using context.Background() as this callback doesn't have access to caller's context. + err := c.connPool.CloseConn(context.Background(), poolCn, pool.CloseReasonAuthError, pool.MetricStateIdle) + if err != nil { + internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err) + // try to close the network connection directly + // so that no resource is leaked + err := poolCn.Close() + if err != nil { + internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err) + } + } + } + internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err) + } + } +} + func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if cn.Inited { + // This function is called in two scenarios: + // 1. First-time init: Connection is in CREATED state (from pool.Get()) + // - We need to transition CREATED → INITIALIZING and do the initialization + // - If another goroutine is already initializing, we WAIT for it to finish + // 2. Re-initialization: Connection is in INITIALIZING state (from SetNetConnAndInitConn()) + // - We're already in INITIALIZING, so just proceed with initialization + + currentState := cn.GetStateMachine().GetState() + + // Fast path: Check if already initialized (IDLE or IN_USE) + if currentState == pool.StateIdle || currentState == pool.StateInUse { return nil } - cn.Inited = true - var err error - username, password := c.opt.Username, c.opt.Password - if c.opt.CredentialsProviderContext != nil { - if username, password, err = c.opt.CredentialsProviderContext(ctx); err != nil { + // If in CREATED state, try to transition to INITIALIZING + if currentState == pool.StateCreated { + finalState, err := cn.GetStateMachine().TryTransition([]pool.ConnState{pool.StateCreated}, pool.StateInitializing) + if err != nil { + // Another goroutine is initializing or connection is in unexpected state + // Check what state we're in now + if finalState == pool.StateIdle || finalState == pool.StateInUse { + // Already initialized by another goroutine + return nil + } + + if finalState == pool.StateInitializing { + // Another goroutine is initializing - WAIT for it to complete + // Use a context with timeout = min(remaining command timeout, DialTimeout) + // This prevents waiting too long while respecting the caller's deadline + var waitCtx context.Context + var cancel context.CancelFunc + dialTimeout := c.opt.DialTimeout + + if cmdDeadline, hasCmdDeadline := ctx.Deadline(); hasCmdDeadline { + // Calculate remaining time until command deadline + remainingTime := time.Until(cmdDeadline) + // Use the minimum of remaining time and DialTimeout + if remainingTime < dialTimeout { + // Command deadline is sooner, use it + waitCtx = ctx + } else { + // DialTimeout is shorter, cap the wait at DialTimeout + waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) + } + } else { + // No command deadline, use DialTimeout to prevent waiting indefinitely + waitCtx, cancel = context.WithTimeout(ctx, dialTimeout) + } + if cancel != nil { + defer cancel() + } + + finalState, err := cn.GetStateMachine().AwaitAndTransition( + waitCtx, + []pool.ConnState{pool.StateIdle, pool.StateInUse}, + pool.StateIdle, // Target is IDLE (but we're already there, so this is a no-op) + ) + if err != nil { + return err + } + // Verify we're now initialized + if finalState == pool.StateIdle || finalState == pool.StateInUse { + return nil + } + // Unexpected state after waiting + return fmt.Errorf("connection in unexpected state after initialization: %s", finalState) + } + + // Unexpected state (CLOSED, UNUSABLE, etc.) return err } - } else if c.opt.CredentialsProvider != nil { - username, password = c.opt.CredentialsProvider() } + // At this point, we're in INITIALIZING state and we own the initialization + // If we fail, we must transition to CLOSED + var initErr error connPool := pool.NewSingleConnPool(c.connPool, cn) - conn := newConn(c.opt, connPool) + conn := newConn(c.opt, connPool, &c.hooksMixin) + + username, password := "", "" + if c.opt.StreamingCredentialsProvider != nil { + credListener, initErr := c.streamingCredentialsManager.Listener( + cn, + c.reAuthConnection(), + c.onAuthenticationErr(), + ) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to create credentials listener: %w", initErr) + } - var auth bool - protocol := c.opt.Protocol - // By default, use RESP3 in current version. - if protocol < 2 { - protocol = 3 + credentials, unsubscribeFromCredentialsProvider, initErr := c.opt.StreamingCredentialsProvider. + Subscribe(credListener) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to subscribe to streaming credentials: %w", initErr) + } + + // Per-connection unsubscribe is attached to the connection itself so it + // runs when this specific connection is closed. Do not register it on + // c.onClose: initConn runs for every (re)initialized connection, and + // attaching per-connection state to the shared baseClient registry would + // either leak entries (one per connection id, never trimmed) or — with + // the pre-fix wrappedOnClose approach — build an unbounded closure chain + // retaining every prior connection's unsubscribe (see issue #3772). + // + // Note: pool.Conn.SetOnClose OVERWRITES any prior callback (see the + // doc on that method). That is safe here because the streaming + // credentials Manager deduplicates listeners by connection id, so a + // second initConn on the same cn re-Subscribes the SAME listener and + // the returned unsubscribe is equivalent to the one already installed. + // Any future code path that could hand out a distinct unsubscribe on + // re-initialization must first invoke the existing one to avoid + // orphaning the old subscription on the credentials provider. + cn.SetOnClose(unsubscribeFromCredentialsProvider) + + username, password = credentials.BasicAuth() + } else if c.opt.CredentialsProviderContext != nil { + username, password, initErr = c.opt.CredentialsProviderContext(ctx) + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to get credentials from context provider: %w", initErr) + } + } else if c.opt.CredentialsProvider != nil { + username, password = c.opt.CredentialsProvider() + } else if c.opt.Username != "" || c.opt.Password != "" { + username, password = c.opt.Username, c.opt.Password } // for redis-server versions that do not support the HELLO command, // RESP2 will continue to be used. - if err = conn.Hello(ctx, protocol, username, password, c.opt.ClientName).Err(); err == nil { - auth = true - } else if !isRedisError(err) { + // helloOK tracks whether HELLO succeeded. If it did not, the connection + // falls back to RESP2 regardless of c.opt.Protocol, and features that + // require RESP3 (e.g. maintenance notifications) must be skipped. + helloOK := false + if initErr = conn.Hello(ctx, c.opt.Protocol, username, password, c.opt.ClientName).Err(); initErr == nil { + // Authentication successful with HELLO command + helloOK = true + } else if !isRedisError(initErr) { // When the server responds with the RESP protocol and the result is not a normal // execution result of the HELLO command, we consider it to be an indication that // the server does not support the HELLO command. @@ -320,18 +627,22 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { // or it could be DragonflyDB or a third-party redis-proxy. They all respond // with different error string results for unsupported commands, making it // difficult to rely on error strings to determine all results. - return err - } - - _, err = conn.Pipelined(ctx, func(pipe Pipeliner) error { - if !auth && password != "" { - if username != "" { - pipe.AuthACL(ctx, username, password) - } else { - pipe.Auth(ctx, password) - } + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr + } else if password != "" { + // Try legacy AUTH command if HELLO failed + if username != "" { + initErr = conn.AuthACL(ctx, username, password).Err() + } else { + initErr = conn.Auth(ctx, password).Err() } + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to authenticate: %w", initErr) + } + } + _, initErr = conn.Pipelined(ctx, func(pipe Pipeliner) error { if c.opt.DB > 0 { pipe.Select(ctx, c.opt.DB) } @@ -346,8 +657,95 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return nil }) - if err != nil { - return err + if initErr != nil { + cn.GetStateMachine().Transition(pool.StateClosed) + return fmt.Errorf("failed to initialize connection options: %w", initErr) + } + + // Enable maintnotifications if maintnotifications are configured + c.optLock.RLock() + maintNotifEnabled := c.opt.MaintNotificationsConfig != nil && c.opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled + protocol := c.opt.Protocol + var endpointType maintnotifications.EndpointType + var maintNotifMode maintnotifications.Mode + if maintNotifEnabled { + endpointType = c.opt.MaintNotificationsConfig.EndpointType + maintNotifMode = c.opt.MaintNotificationsConfig.Mode + } + c.optLock.RUnlock() + + // Maintenance notifications require RESP3 push frames. If HELLO failed + // and the connection fell back to RESP2, there is no point in sending + // CLIENT MAINT_NOTIFICATIONS: the server either rejects it (making the + // error misleading) or accepts it silently, leaving the client unable + // to receive any notifications. Decide based on the actual negotiated + // protocol rather than the requested one. + if maintNotifEnabled && protocol == 3 && !helloOK { + if maintNotifMode == maintnotifications.ModeEnabled { + // Explicitly requested - fail fast with a clear reason. + cn.GetStateMachine().Transition(pool.StateClosed) + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorCallback(ctx, "HANDSHAKE_FAILED", cn, "HANDSHAKE_FAILED", true, 0) + } + return fmt.Errorf("failed to enable maintnotifications: server does not support RESP3 (HELLO command failed)") + } + // auto/other modes: silently disable maintnotifications for this client. + c.optLock.Lock() + c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled + c.optLock.Unlock() + if err := c.disableMaintNotificationsUpgrades(); err != nil { + internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err) + } + maintNotifEnabled = false + } + + var maintNotifHandshakeErr error + if maintNotifEnabled && protocol == 3 { + maintNotifHandshakeErr = conn.ClientMaintNotifications( + ctx, + true, + endpointType.String(), + ).Err() + if maintNotifHandshakeErr != nil { + if !isRedisError(maintNotifHandshakeErr) { + // if not redis error, fail the connection + cn.GetStateMachine().Transition(pool.StateClosed) + return maintNotifHandshakeErr + } + c.optLock.Lock() + // handshake failed - check and modify config atomically + switch c.opt.MaintNotificationsConfig.Mode { + case maintnotifications.ModeEnabled: + // enabled mode, fail the connection + c.optLock.Unlock() + cn.GetStateMachine().Transition(pool.StateClosed) + + // Record handshake failure metric + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorCallback(ctx, "HANDSHAKE_FAILED", cn, "HANDSHAKE_FAILED", true, 0) + } + + return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) + default: // will handle auto and any other + // Disabling logging here as it's too noisy. + // TODO: Enable when we have a better logging solution for log levels + // internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) + c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled + c.optLock.Unlock() + // auto mode, disable maintnotifications and continue + if initErr := c.disableMaintNotificationsUpgrades(); initErr != nil { + // Log error but continue - auto mode should be resilient + internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", initErr) + } + } + } else { + // handshake was executed successfully + // to make sure that the handshake will be executed on other connections as well if it was successfully + // executed on this connection, we will force the handshake to be executed on all connections + c.optLock.Lock() + c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeEnabled + c.optLock.Unlock() + } } if !c.opt.DisableIdentity && !c.opt.DisableIndentity { @@ -361,14 +759,33 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { p.ClientSetInfo(ctx, WithLibraryVersion(libVer)) // Handle network errors (e.g. timeouts) in CLIENT SETINFO to avoid // out of order responses later on. - if _, err = p.Exec(ctx); err != nil && !isRedisError(err) { - return err + if _, initErr = p.Exec(ctx); initErr != nil && !isRedisError(initErr) { + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr } } + // Set the connection initialization function for potential reconnections + // This must be set before transitioning to IDLE so that handoff/reauth can use it + cn.SetInitConnFunc(c.createInitConnFunc()) + + // Initialization succeeded - transition to IDLE state + // This marks the connection as initialized and ready for use + // NOTE: The connection is still owned by the calling goroutine at this point + // and won't be available to other goroutines until it's Put() back into the pool + cn.GetStateMachine().Transition(pool.StateIdle) + + // Call OnConnect hook if configured + // The connection is in IDLE state but still owned by this goroutine + // If OnConnect needs to send commands, it can use the connection safely if c.opt.OnConnect != nil { - return c.opt.OnConnect(ctx, conn) + if initErr = c.opt.OnConnect(ctx, conn); initErr != nil { + // OnConnect failed - transition to closed + cn.GetStateMachine().Transition(pool.StateClosed) + return initErr + } } + return nil } @@ -380,6 +797,10 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) if isBadConn(err, false, c.opt.Addr) { c.connPool.Remove(ctx, cn, err) } else { + // process any pending push notifications before returning the connection to the pool + if err := c.processPushNotifications(ctx, cn); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err) + } c.connPool.Put(ctx, cn) } } @@ -407,42 +828,143 @@ func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, } func (c *baseClient) process(ctx context.Context, cmd Cmder) error { + // Start measuring total operation duration (includes all retries) + // Only call time.Now() if operation duration callback is set to avoid overhead + var operationStart time.Time + opDurationCallback := otel.GetOperationDurationCallback() + if opDurationCallback != nil { + operationStart = time.Now() + } + var lastConn *pool.Conn + var lastErr error + totalAttempts := 0 for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { + totalAttempts++ attempt := attempt - retry, err := c._process(ctx, cmd, attempt) - if err == nil || !retry { + retry, cn, err := c._process(ctx, cmd, attempt) + if cn != nil { + lastConn = cn + } + // Don't retry if command explicitly disables retries (e.g., RawWriteToCmd + // which writes directly to an io.Writer and cannot undo partial writes) + if err == nil || !retry || cmd.NoRetry() { + // Record total operation duration + if opDurationCallback != nil { + operationDuration := time.Since(operationStart) + opDurationCallback(ctx, operationDuration, cmd, totalAttempts, err, lastConn, c.opt.DB) + } + + if err != nil { + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorType, statusCode, isInternal := classifyCommandError(err) + errorCallback(ctx, errorType, lastConn, statusCode, isInternal, totalAttempts-1) + } + } return err } lastErr = err } + + // Record failed operation after all retries + if opDurationCallback != nil { + operationDuration := time.Since(operationStart) + opDurationCallback(ctx, operationDuration, cmd, totalAttempts, lastErr, lastConn, c.opt.DB) + } + + // Record error metric for exhausted retries + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorType, statusCode, isInternal := classifyCommandError(lastErr) + errorCallback(ctx, errorType, lastConn, statusCode, isInternal, totalAttempts-1) + } + return lastErr } -func (c *baseClient) assertUnstableCommand(cmd Cmder) bool { - switch cmd.(type) { - case *AggregateCmd, *FTInfoCmd, *FTSpellCheckCmd, *FTSearchCmd, *FTSynDumpCmd: - if c.opt.UnstableResp3 { - return true - } else { - panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.") +// classifyCommandError classifies an error for metrics reporting. +// Returns: errorType, statusCode, isInternal +// - errorType: A string describing the error type (e.g., "TIMEOUT", "NETWORK", "ERR") +// - statusCode: The Redis error prefix or error category +// - isInternal: true for network/timeout errors, false for Redis server errors +func classifyCommandError(err error) (errorType, statusCode string, isInternal bool) { + if err == nil { + return "", "", false + } + + errStr := err.Error() + + // Check for timeout errors + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return "TIMEOUT", "TIMEOUT", true + } + + // Check for network errors + if _, ok := err.(net.Error); ok { + return "NETWORK", "NETWORK", true + } + + // Check for context errors + if errors.Is(err, context.Canceled) { + return "CONTEXT_CANCELED", "CONTEXT_CANCELED", true + } + if errors.Is(err, context.DeadlineExceeded) { + return "CONTEXT_TIMEOUT", "CONTEXT_TIMEOUT", true + } + + // Check for Redis errors + // Examples: "ERR ...", "WRONGTYPE ...", "CLUSTERDOWN ..." + if len(errStr) > 0 { + // Find the first space to extract the prefix + spaceIdx := 0 + for i, c := range errStr { + if c == ' ' { + spaceIdx = i + break + } + } + if spaceIdx == 0 { + spaceIdx = len(errStr) + } + prefix := errStr[:spaceIdx] + isUppercase := true + for _, c := range prefix { + if c < 'A' || c > 'Z' { + isUppercase = false + break + } + } + if isUppercase && len(prefix) > 0 { + return prefix, prefix, false } - default: - return false } + + return "UNKNOWN", "UNKNOWN", true } -func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { +func (c *baseClient) assertUnstableCommand(cmd Cmder) (bool, error) { + // All search commands (FTSearchCmd, AggregateCmd, FTInfoCmd, FTSpellCheckCmd, FTSynDumpCmd) + // now have stable RESP3 parsing. No commands require the UnstableResp3 flag anymore. + return false, nil +} + +func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, *pool.Conn, error) { if attempt > 0 { if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { - return false, err + return false, nil, err } } + var usedConn *pool.Conn retryTimeout := uint32(0) if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + usedConn = cn + // Process any pending push notifications before executing the command + if err := c.processPushNotifications(ctx, cn); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) + } + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmd) }); err != nil { @@ -451,10 +973,22 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool } readReplyFunc := cmd.readReply // Apply unstable RESP3 search module. - if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) { - readReplyFunc = cmd.readRawReply + if c.opt.Protocol != 2 { + useRawReply, err := c.assertUnstableCommand(cmd) + if err != nil { + return err + } + if useRawReply { + readReplyFunc = cmd.readRawReply + } } - if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), readReplyFunc); err != nil { + if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } + return readReplyFunc(rd) + }); err != nil { if cmd.readTimeout() == nil { atomic.StoreUint32(&retryTimeout, 1) } else { @@ -466,10 +1000,10 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool return nil }); err != nil { retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) - return retry, err + return retry, usedConn, err } - return false, nil + return false, usedConn, nil } func (c *baseClient) retryBackoff(attempt int) time.Duration { @@ -487,19 +1021,88 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { return c.opt.ReadTimeout } +// context returns the context for the current connection. +// If the context timeout is enabled, it returns the original context. +// Otherwise, it returns a new background context. +func (c *baseClient) context(ctx context.Context) context.Context { + if c.opt.ContextTimeoutEnabled { + return ctx + } + return context.Background() +} + +// createInitConnFunc creates a connection initialization function that can be used for reconnections. +func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { + return func(ctx context.Context, cn *pool.Conn) error { + return c.initConn(ctx, cn) + } +} + +// enableMaintNotificationsUpgrades initializes the maintnotifications upgrade manager and pool hook. +// This function is called during client initialization. +// will register push notification handlers for all maintenance upgrade events. +// will start background workers for handoff processing in the pool hook. +func (c *baseClient) enableMaintNotificationsUpgrades() error { + // Create client adapter + clientAdapterInstance := newClientAdapter(c) + + // Create maintnotifications manager directly + manager, err := maintnotifications.NewManager(clientAdapterInstance, c.connPool, c.opt.MaintNotificationsConfig) + if err != nil { + return err + } + // Set the manager reference and initialize pool hook + c.maintNotificationsManagerLock.Lock() + c.maintNotificationsManager = manager + c.maintNotificationsManagerLock.Unlock() + + // Initialize pool hook (safe to call without lock since manager is now set) + manager.InitPoolHook(c.dialHook) + return nil +} + +func (c *baseClient) disableMaintNotificationsUpgrades() error { + c.maintNotificationsManagerLock.Lock() + defer c.maintNotificationsManagerLock.Unlock() + + // Close the maintnotifications manager + if c.maintNotificationsManager != nil { + // Closing the manager will also shutdown the pool hook + // and remove it from the pool + c.maintNotificationsManager.Close() + c.maintNotificationsManager = nil + } + return nil +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines. func (c *baseClient) Close() error { var firstErr error - if c.onClose != nil { - if err := c.onClose(); err != nil { + + // Close maintnotifications manager first + if err := c.disableMaintNotificationsUpgrades(); err != nil { + firstErr = err + } + + if err := c.onClose.run(); err != nil && firstErr == nil { + firstErr = err + } + + // Unregister pools from OTel before closing them + otel.UnregisterPools(c.connPool, c.pubSubPool) + + if c.connPool != nil { + if err := c.connPool.Close(); err != nil && firstErr == nil { firstErr = err } } - if err := c.connPool.Close(); err != nil && firstErr == nil { - firstErr = err + if c.pubSubPool != nil { + if err := c.pubSubPool.Close(); err != nil && firstErr == nil { + firstErr = err + } } return firstErr } @@ -509,14 +1112,14 @@ func (c *baseClient) getAddr() string { } func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { - if err := c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds); err != nil { + if err := c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds, "PIPELINE"); err != nil { return err } return cmdsFirstErr(cmds) } func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { - if err := c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds); err != nil { + if err := c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds, "MULTI"); err != nil { return err } return cmdsFirstErr(cmds) @@ -525,13 +1128,27 @@ func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) func (c *baseClient) generalProcessPipeline( - ctx context.Context, cmds []Cmder, p pipelineProcessor, + ctx context.Context, cmds []Cmder, p pipelineProcessor, operationName string, ) error { + // Only call time.Now() if pipeline operation duration callback is set to avoid overhead + var operationStart time.Time + pipelineOpDurationCallback := otel.GetPipelineOperationDurationCallback() + if pipelineOpDurationCallback != nil { + operationStart = time.Now() + } + var lastConn *pool.Conn + totalAttempts := 0 + var lastErr error for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { + totalAttempts++ if attempt > 0 { if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { setCmdsErr(cmds, err) + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, operationName, len(cmds), totalAttempts, err, lastConn, c.opt.DB) + } return err } } @@ -539,20 +1156,59 @@ func (c *baseClient) generalProcessPipeline( // Enable retries by default to retry dial errors returned by withConn. canRetry := true lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + lastConn = cn + // Process any pending push notifications before executing the pipeline + if err := c.processPushNotifications(ctx, cn); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) + } var err error canRetry, err = p(ctx, cn, cmds) return err }) - if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { + // Don't retry if any command in the pipeline explicitly disables retries + // (e.g., RawWriteToCmd which writes directly to an io.Writer and cannot + // undo partial writes on retry) + if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) || cmdsContainNoRetry(cmds) { + // The error should be set here only when failing to obtain the conn. + if !isRedisError(lastErr) { + setCmdsErr(cmds, lastErr) + } + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, operationName, len(cmds), totalAttempts, lastErr, lastConn, c.opt.DB) + } + + if lastErr != nil { + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorType, statusCode, isInternal := classifyCommandError(lastErr) + errorCallback(ctx, errorType, lastConn, statusCode, isInternal, totalAttempts-1) + } + } return lastErr } } + + if pipelineOpDurationCallback != nil { + operationDuration := time.Since(operationStart) + pipelineOpDurationCallback(ctx, operationDuration, operationName, len(cmds), totalAttempts, lastErr, lastConn, c.opt.DB) + } + + if errorCallback := pool.GetMetricErrorCallback(); errorCallback != nil { + errorType, statusCode, isInternal := classifyCommandError(lastErr) + errorCallback(ctx, errorType, lastConn, statusCode, isInternal, totalAttempts-1) + } + return lastErr } func (c *baseClient) pipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { + // Process any pending push notifications before executing the pipeline + if err := c.processPushNotifications(ctx, cn); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) + } + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { @@ -561,7 +1217,8 @@ func (c *baseClient) pipelineProcessCmds( } if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { - return pipelineReadCmds(rd, cmds) + // read all replies + return c.pipelineReadCmds(ctx, cn, rd, cmds) }); err != nil { return true, err } @@ -569,8 +1226,12 @@ func (c *baseClient) pipelineProcessCmds( return false, nil } -func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { +func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *proto.Reader, cmds []Cmder) error { for i, cmd := range cmds { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } err := cmd.readReply(rd) cmd.SetErr(err) if err != nil && !isRedisError(err) { @@ -585,6 +1246,11 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { func (c *baseClient) txPipelineProcessCmds( ctx context.Context, cn *pool.Conn, cmds []Cmder, ) (bool, error) { + // Process any pending push notifications before executing the transaction pipeline + if err := c.processPushNotifications(ctx, cn); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) + } + if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { @@ -597,12 +1263,13 @@ func (c *baseClient) txPipelineProcessCmds( // Trim multi and exec. trimmedCmds := cmds[1 : len(cmds)-1] - if err := txPipelineReadQueued(rd, statusCmd, trimmedCmds); err != nil { + if err := c.txPipelineReadQueued(ctx, cn, rd, statusCmd, trimmedCmds); err != nil { setCmdsErr(cmds, err) return err } - return pipelineReadCmds(rd, trimmedCmds) + // Read replies. + return c.pipelineReadCmds(ctx, cn, rd, trimmedCmds) }); err != nil { return false, err } @@ -610,19 +1277,36 @@ func (c *baseClient) txPipelineProcessCmds( return false, nil } -func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { +// txPipelineReadQueued reads queued replies from the Redis server. +// It returns an error if the server returns an error or if the number of replies does not match the number of commands. +func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } // Parse +OK. if err := statusCmd.readReply(rd); err != nil { return err } // Parse +QUEUED. - for range cmds { - if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { - return err + for _, cmd := range cmds { + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } + if err := statusCmd.readReply(rd); err != nil { + cmd.SetErr(err) + if !isRedisError(err) { + return err + } } } + // To be sure there are no buffered push notifications, we process them before reading the reply + if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { + internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + } // Parse number of replies. line, err := rd.ReadLine() if err != nil { @@ -639,13 +1323,6 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) return nil } -func (c *baseClient) context(ctx context.Context) context.Context { - if c.opt.ContextTimeoutEnabled { - return ctx - } - return context.Background() -} - //------------------------------------------------------------------------------ // Client is a Redis client representing a pool of zero or more underlying connections. @@ -656,23 +1333,79 @@ func (c *baseClient) context(ctx context.Context) context.Context { type Client struct { *baseClient cmdable - hooksMixin } // NewClient returns a client to the Redis Server specified by Options. +// Passing nil Options will cause a panic. func NewClient(opt *Options) *Client { if opt == nil { panic("redis: NewClient nil options") } + // clone to not share options with the caller + opt = opt.clone() opt.init() + // Push notifications are always enabled for RESP3 (cannot be disabled) + c := Client{ baseClient: &baseClient{ - opt: opt, + opt: opt, + onClose: &onCloseHooks{}, }, } c.init() - c.connPool = newConnPool(opt, c.dialHook) + + // Initialize push notification processor using shared helper + // Use void processor for RESP2 connections (push notifications not available) + c.pushProcessor = initializePushProcessor(opt) + // set opt push processor for child clients + c.opt.PushNotificationProcessor = c.pushProcessor + + // Generate unique pool names for metrics + uniqueID := generateUniqueID() + mainPoolName := opt.Addr + "_" + uniqueID + pubsubPoolName := opt.Addr + "_" + uniqueID + "_pubsub" + + // Create connection pools + var err error + c.connPool, err = newConnPool(opt, c.dialHook, mainPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook, pubsubPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } + + if opt.StreamingCredentialsProvider != nil { + c.streamingCredentialsManager = streaming.NewManager(c.connPool, c.opt.PoolTimeout) + c.connPool.AddPoolHook(c.streamingCredentialsManager.PoolHook()) + } + + // Initialize maintnotifications first if enabled and protocol is RESP3 + if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { + err := c.enableMaintNotificationsUpgrades() + if err != nil { + internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err) + if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled { + /* + Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested. + We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect + an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced + immediately, rather than allowing the client to continue in a partially initialized or inconsistent state. + Clients relying on maintnotifications should be aware that initialization errors will cause a panic, and should + handle this accordingly (e.g., via recover or by validating configuration before calling NewClient). + This approach is only used when MaintNotificationsConfig.Mode is MaintNotificationsEnabled, indicating that maintnotifications + upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic. + */ + panic(fmt.Errorf("failed to enable maintnotifications: %w", err)) + } + } + } + + // Register pools with OTel recorder if it supports pool registration + // This allows async gauge metrics to pull stats from pools periodically + otel.RegisterPools(c.connPool, c.pubSubPool, opt.Addr) return &c } @@ -695,14 +1428,7 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client { } func (c *Client) Conn() *Conn { - return newConn(c.opt, pool.NewStickyConnPool(c.connPool)) -} - -// Do create a Cmd from the args and processes the cmd. -func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { - cmd := NewCmd(ctx, args...) - _ = c.Process(ctx, cmd) - return cmd + return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin) } func (c *Client) Process(ctx context.Context, cmd Cmder) error { @@ -711,16 +1437,67 @@ func (c *Client) Process(ctx context.Context, cmd Cmder) error { return err } -// Options returns read-only Options that were used to create the client. +// Options returns read-only *Options that were used to create the client. +// Any alteration of the returned *Options may result in undefined behaviour. func (c *Client) Options() *Options { return c.opt } +// NodeAddress returns the address of the Redis node as reported by the server. +// For cluster clients, this is the endpoint from CLUSTER SLOTS before any transformation +// (e.g., loopback replacement). For standalone clients, this defaults to Addr. +// +// This is useful for matching the source field in maintenance notifications +// (e.g. SMIGRATED). +func (c *Client) NodeAddress() string { + return c.opt.NodeAddress +} + +// GetMaintNotificationsManager returns the maintnotifications manager instance for monitoring and control. +// Returns nil if maintnotifications are not enabled. +func (c *Client) GetMaintNotificationsManager() *maintnotifications.Manager { + c.maintNotificationsManagerLock.RLock() + defer c.maintNotificationsManagerLock.RUnlock() + return c.maintNotificationsManager +} + +// initializePushProcessor initializes the push notification processor for any client type. +// This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient. +func initializePushProcessor(opt *Options) push.NotificationProcessor { + // Always use custom processor if provided + if opt.PushNotificationProcessor != nil { + return opt.PushNotificationProcessor + } + + // Push notifications are always enabled for RESP3, disabled for RESP2 + if opt.Protocol == 3 { + // Create default processor for RESP3 connections + return NewPushNotificationProcessor() + } + + // Create void processor for RESP2 connections (push notifications not available) + return NewVoidPushNotificationProcessor() +} + +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *Client) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *Client) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + type PoolStats pool.Stats // PoolStats returns connection pool stats. func (c *Client) PoolStats() *PoolStats { stats := c.connPool.Stats() + stats.PubSubStats = *(c.pubSubPool.Stats()) return (*PoolStats)(stats) } @@ -755,13 +1532,31 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil }, - closeConn: c.connPool.CloseConn, + pushProcessor: c.pushProcessor, } pubsub.init() + return pubsub } @@ -828,17 +1623,28 @@ type Conn struct { baseClient cmdable statefulCmdable - hooksMixin } -func newConn(opt *Options, connPool pool.Pooler) *Conn { +// newConn is a helper func to create a new Conn instance. +// the Conn instance is not thread-safe and should not be shared between goroutines. +// the parentHooks will be cloned, no need to clone before passing it. +func newConn(opt *Options, connPool pool.Pooler, parentHooks *hooksMixin) *Conn { c := Conn{ baseClient: baseClient{ opt: opt, connPool: connPool, + onClose: &onCloseHooks{}, }, } + if parentHooks != nil { + c.hooksMixin = parentHooks.clone() + } + + // Initialize push notification processor using shared helper + // Use void processor for RESP2 connections (push notifications not available) + c.pushProcessor = initializePushProcessor(opt) + c.cmdable = c.Process c.statefulCmdable = c.Process c.initHooks(hooks{ @@ -857,6 +1663,13 @@ func (c *Conn) Process(ctx context.Context, cmd Cmder) error { return err } +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *Conn) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { return c.Pipeline().Pipelined(ctx, fn) } @@ -884,3 +1697,78 @@ func (c *Conn) TxPipeline() Pipeliner { pipe.init() return &pipe } + +// processPushNotifications processes all pending push notifications on a connection +// This ensures that cluster topology changes are handled immediately before the connection is used +// This method should be called by the client before using WithReader for command execution +// +// Performance optimization: Skip the expensive MaybeHasData() syscall if a health check +// was performed recently (within 5 seconds). The health check already verified the connection +// is healthy and checked for unexpected data (push notifications). +func (c *baseClient) processPushNotifications(ctx context.Context, cn *pool.Conn) error { + // Only process push notifications for RESP3 connections with a processor + if c.opt.Protocol != 3 || c.pushProcessor == nil { + return nil + } + + // Performance optimization: Skip MaybeHasData() syscall if health check was recent + // If the connection was health-checked within the last 5 seconds, we can skip the + // expensive syscall since the health check already verified no unexpected data. + // This is safe because: + // 0. lastHealthCheckNs is set in pool/conn.go:putConn() after a successful health check + // 1. Health check (connCheck) uses the same syscall (Recvfrom with MSG_PEEK) + // 2. If push notifications arrived, they would have been detected by health check + // 3. 5 seconds is short enough that connection state is still fresh + // 4. Push notifications will be processed by the next WithReader call + // used it is set on getConn, so we should use another timer (lastPutAt?) + lastHealthCheckNs := cn.LastPutAtNs() + if lastHealthCheckNs > 0 { + // Use pool's cached time to avoid expensive time.Now() syscall + nowNs := pool.GetCachedTimeNs() + if nowNs-lastHealthCheckNs < int64(5*time.Second) { + // Recent health check confirmed no unexpected data, skip the syscall + return nil + } + } + + // Check if there is any data to read before processing + // This is an optimization on UNIX systems where MaybeHasData is a syscall + // On Windows, MaybeHasData always returns true, so this check is a no-op + if !cn.MaybeHasData() { + return nil + } + + // Use WithReader to access the reader and process push notifications + // This is critical for maintnotifications to work properly + // NOTE: almost no timeouts are set for this read, so it should not block + // longer than necessary, 10us should be plenty of time to read if there are any push notifications + // on the socket. + return cn.WithReader(ctx, 10*time.Microsecond, func(rd *proto.Reader) error { + // Create handler context with client, connection pool, and connection information + handlerCtx := c.pushNotificationHandlerContext(cn) + return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd) + }) +} + +// processPendingPushNotificationWithReader processes all pending push notifications on a connection +// This method should be called by the client in WithReader before reading the reply +func (c *baseClient) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error { + // if we have the reader, we don't need to check for data on the socket, we are waiting + // for either a reply or a push notification, so we can block until we get a reply or reach the timeout + if c.opt.Protocol != 3 || c.pushProcessor == nil { + return nil + } + + // Create handler context with client, connection pool, and connection information + handlerCtx := c.pushNotificationHandlerContext(cn) + return c.pushProcessor.ProcessPendingNotifications(ctx, handlerCtx, rd) +} + +// pushNotificationHandlerContext creates a handler context for push notification processing +func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.NotificationHandlerContext { + return push.NotificationHandlerContext{ + Client: c, + ConnPool: c.connPool, + Conn: cn, // Wrap in adapter for easier interface access + } +} diff --git a/vendor/github.com/redis/go-redis/v9/result.go b/vendor/github.com/redis/go-redis/v9/result.go index cfd4cf92e..3e0d0a134 100644 --- a/vendor/github.com/redis/go-redis/v9/result.go +++ b/vendor/github.com/redis/go-redis/v9/result.go @@ -82,6 +82,14 @@ func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { return &cmd } +// NewFloatSliceResult returns a FloatSliceCmd initialised with val and err for testing. +func NewFloatSliceResult(val []float64, err error) *FloatSliceCmd { + var cmd FloatSliceCmd + cmd.val = val + cmd.SetErr(err) + return &cmd +} + // NewMapStringStringResult returns a MapStringStringCmd initialised with val and err for testing. func NewMapStringStringResult(val map[string]string, err error) *MapStringStringCmd { var cmd MapStringStringCmd diff --git a/vendor/github.com/redis/go-redis/v9/ring.go b/vendor/github.com/redis/go-redis/v9/ring.go index 555ea2a16..b60d3eab5 100644 --- a/vendor/github.com/redis/go-redis/v9/ring.go +++ b/vendor/github.com/redis/go-redis/v9/ring.go @@ -5,39 +5,36 @@ import ( "crypto/tls" "errors" "fmt" + "math/rand" "net" "strconv" "sync" "sync/atomic" "time" - "github.com/cespare/xxhash/v2" - "github.com/dgryski/go-rendezvous" //nolint - + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/hashtag" "github.com/redis/go-redis/v9/internal/pool" - "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/internal/proto" ) var errRingShardsDown = errors.New("redis: all ring shards are down") +// defaultHeartbeatFn is the default function used to check the shard liveness +var defaultHeartbeatFn = func(ctx context.Context, client *Client) bool { + err := client.Ping(ctx).Err() + return err == nil || err == pool.ErrPoolTimeout +} + //------------------------------------------------------------------------------ type ConsistentHash interface { Get(string) string } -type rendezvousWrapper struct { - *rendezvous.Rendezvous -} - -func (w rendezvousWrapper) Get(key string) string { - return w.Lookup(key) -} - func newRendezvous(shards []string) ConsistentHash { - return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)} + return hashtag.NewRendezvousHash(shards) } //------------------------------------------------------------------------------ @@ -54,10 +51,14 @@ type RingOptions struct { // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. ClientName string - // Frequency of PING commands sent to check shards availability. + // Frequency of executing HeartbeatFn to check shards availability. // Shard is considered down after 3 subsequent failed checks. HeartbeatFrequency time.Duration + // A function used to check the shard liveness + // if not set, defaults to defaultHeartbeatFn + HeartbeatFn func(ctx context.Context, client *Client) bool + // NewConsistentHash returns a consistent hash that is used // to distribute keys across the shards. // @@ -73,13 +74,45 @@ type RingOptions struct { Protocol int Username string Password string - DB int + // CredentialsProvider allows the username and password to be updated + // before reconnecting. It should return the current username and password. + CredentialsProvider func() (username string, password string) + + // CredentialsProviderContext is an enhanced parameter of CredentialsProvider, + // done to maintain API compatibility. In the future, + // there might be a merge between CredentialsProviderContext and CredentialsProvider. + // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + DB int MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration - DialTimeout time.Duration + DialTimeout time.Duration + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + + // DialerRetryBackoff controls the delay between dial retry attempts. + // See Options.DialerRetryBackoff for details. + DialerRetryBackoff func(attempt int) time.Duration + ReadTimeout time.Duration WriteTimeout time.Duration ContextTimeoutEnabled bool @@ -87,13 +120,28 @@ type RingOptions struct { // PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). PoolFIFO bool - PoolSize int - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration + PoolSize int + PoolTimeout time.Duration + MinIdleConns int + MaxIdleConns int + MaxActiveConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + ConnMaxLifetimeJitter time.Duration + + // ReadBufferSize is the size of the bufio.Reader buffer for each connection. + // Larger buffers can improve performance for commands that return large responses. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + ReadBufferSize int + + // WriteBufferSize is the size of the bufio.Writer buffer for each connection. + // Larger buffers can improve performance for large pipelines and commands with many arguments. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + WriteBufferSize int TLSConfig *tls.Config Limiter Limiter @@ -110,7 +158,11 @@ type RingOptions struct { // default: false DisableIdentity bool IdentitySuffix string - UnstableResp3 bool + + // Deprecated: All RediSearch commands now have stable RESP3 parsing and this + // flag is a no-op. It is kept for backwards compatibility and will be removed + // in a future release. + UnstableResp3 bool } func (opt *RingOptions) init() { @@ -124,6 +176,10 @@ func (opt *RingOptions) init() { opt.HeartbeatFrequency = 500 * time.Millisecond } + if opt.HeartbeatFn == nil { + opt.HeartbeatFn = defaultHeartbeatFn + } + if opt.NewConsistentHash == nil { opt.NewConsistentHash = newRendezvous } @@ -146,6 +202,13 @@ func (opt *RingOptions) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + + if opt.ReadBufferSize == 0 { + opt.ReadBufferSize = proto.DefaultBufferSize + } + if opt.WriteBufferSize == 0 { + opt.WriteBufferSize = proto.DefaultBufferSize + } } func (opt *RingOptions) clientOptions() *Options { @@ -154,26 +217,35 @@ func (opt *RingOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, - DB: opt.DB, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, + DB: opt.DB, MaxRetries: -1, DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, ReadTimeout: opt.ReadTimeout, WriteTimeout: opt.WriteTimeout, ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, TLSConfig: opt.TLSConfig, Limiter: opt.Limiter, @@ -349,17 +421,16 @@ func (c *ringSharding) newRingShards( return } +// Warning: External exposure of `c.shards.list` may cause data races. +// So keep internal or implement deep copy if exposed. func (c *ringSharding) List() []*ringShard { - var list []*ringShard - c.mu.RLock() - if !c.closed { - list = make([]*ringShard, len(c.shards.list)) - copy(list, c.shards.list) - } - c.mu.RUnlock() + defer c.mu.RUnlock() - return list + if c.closed { + return nil + } + return c.shards.list } func (c *ringSharding) Hash(key string) string { @@ -406,7 +477,12 @@ func (c *ringSharding) GetByName(shardName string) (*ringShard, error) { c.mu.RLock() defer c.mu.RUnlock() - return c.shards.m[shardName], nil + shard, ok := c.shards.m[shardName] + if !ok { + return nil, errors.New("redis: the shard is not in the ring") + } + + return shard, nil } func (c *ringSharding) Random() (*ringShard, error) { @@ -423,9 +499,9 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { case <-ticker.C: var rebalance bool + // note: `c.List()` return a shadow copy of `[]*ringShard`. for _, shard := range c.List() { - err := shard.Client.Ping(ctx).Err() - isUp := err == nil || err == pool.ErrPoolTimeout + isUp := c.opt.HeartbeatFn(ctx, shard.Client) if shard.Vote(isUp) { internal.Logger.Printf(ctx, "ring shard state changed: %s", shard) rebalance = true @@ -522,6 +598,8 @@ type Ring struct { heartbeatCancelFn context.CancelFunc } +// NewRing returns a Redis Ring client to the Redis Server specified by RingOptions. +// Passing nil RingOptions will cause a panic. func NewRing(opt *RingOptions) *Ring { if opt == nil { panic("redis: NewRing nil options") @@ -558,20 +636,14 @@ func (c *Ring) SetAddrs(addrs map[string]string) { c.sharding.SetAddrs(addrs) } -// Do create a Cmd from the args and processes the cmd. -func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { - cmd := NewCmd(ctx, args...) - _ = c.Process(ctx, cmd) - return cmd -} - func (c *Ring) Process(ctx context.Context, cmd Cmder) error { err := c.processHook(ctx, cmd) cmd.SetErr(err) return err } -// Options returns read-only Options that were used to create the client. +// Options returns read-only *RingOptions that were used to create the client. +// Any alteration of the returned *RingOptions may result in undefined behaviour. func (c *Ring) Options() *RingOptions { return c.opt } @@ -582,6 +654,7 @@ func (c *Ring) retryBackoff(attempt int) time.Duration { // PoolStats returns accumulated connection pool stats. func (c *Ring) PoolStats() *PoolStats { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var acc PoolStats for _, shard := range shards { @@ -651,6 +724,7 @@ func (c *Ring) ForEachShard( ctx context.Context, fn func(ctx context.Context, client *Client) error, ) error { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var wg sync.WaitGroup errCh := make(chan error, 1) @@ -682,6 +756,7 @@ func (c *Ring) ForEachShard( } func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { + // note: `c.List()` return a shadow copy of `[]*ringShard`. shards := c.sharding.List() var firstErr error for _, shard := range shards { @@ -699,8 +774,11 @@ func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { return nil, firstErr } -func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { - pos := cmdFirstKeyPos(cmd) +func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) { + // TODO: populate cmdsInfoCache lazily (via cmdsInfoCache.Get) so that + // the warm-cache branch in cmdFirstKeyPosWithInfo is reachable for Ring, + // mirroring how ClusterClient.cmdInfo works. For now pass nil + pos := cmdFirstKeyPosWithInfo(cmd, nil) if pos == 0 { return c.sharding.Random() } @@ -717,13 +795,13 @@ func (c *Ring) process(ctx context.Context, cmd Cmder) error { } } - shard, err := c.cmdShard(ctx, cmd) + shard, err := c.cmdShard(cmd) if err != nil { return err } lastErr = shard.Client.Process(ctx, cmd) - if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) { + if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) || cmd.NoRetry() { return lastErr } } @@ -768,7 +846,7 @@ func (c *Ring) generalProcessPipeline( cmdsMap := make(map[string][]Cmder) for _, cmd := range cmds { - hash := cmd.stringArg(cmdFirstKeyPos(cmd)) + hash := cmd.stringArg(cmdFirstKeyPosWithInfo(cmd, nil)) if hash != "" { hash = c.sharding.Hash(hash) } @@ -776,6 +854,8 @@ func (c *Ring) generalProcessPipeline( } var wg sync.WaitGroup + errs := make(chan error, len(cmdsMap)) + for hash, cmds := range cmdsMap { wg.Add(1) go func(hash string, cmds []Cmder) { @@ -788,16 +868,24 @@ func (c *Ring) generalProcessPipeline( return } + hook := shard.Client.processPipelineHook if tx { cmds = wrapMultiExec(ctx, cmds) - _ = shard.Client.processTxPipelineHook(ctx, cmds) - } else { - _ = shard.Client.processPipelineHook(ctx, cmds) + hook = shard.Client.processTxPipelineHook + } + + if err = hook(ctx, cmds); err != nil { + errs <- err } }(hash, cmds) } wg.Wait() + close(errs) + + if err := <-errs; err != nil { + return err + } return cmdsFirstErr(cmds) } @@ -810,7 +898,7 @@ func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) er for _, key := range keys { if key != "" { - shard, err := c.sharding.GetByKey(hashtag.Key(key)) + shard, err := c.sharding.GetByKey(key) if err != nil { return err } @@ -844,3 +932,26 @@ func (c *Ring) Close() error { return c.sharding.Close() } + +// GetShardClients returns a list of all shard clients in the ring. +// This can be used to create dedicated connections (e.g., PubSub) for each shard. +func (c *Ring) GetShardClients() []*Client { + shards := c.sharding.List() + clients := make([]*Client, 0, len(shards)) + for _, shard := range shards { + if shard.IsUp() { + clients = append(clients, shard.Client) + } + } + return clients +} + +// GetShardClientForKey returns the shard client that would handle the given key. +// This can be used to determine which shard a particular key/channel would be routed to. +func (c *Ring) GetShardClientForKey(key string) (*Client, error) { + shard, err := c.sharding.GetByKey(key) + if err != nil { + return nil, err + } + return shard.Client, nil +} diff --git a/vendor/github.com/redis/go-redis/v9/script.go b/vendor/github.com/redis/go-redis/v9/script.go index 626ab03bb..92d508f9a 100644 --- a/vendor/github.com/redis/go-redis/v9/script.go +++ b/vendor/github.com/redis/go-redis/v9/script.go @@ -4,7 +4,9 @@ import ( "context" "crypto/sha1" "encoding/hex" + "errors" "io" + "sync" ) type Scripter interface { @@ -23,28 +25,69 @@ var ( ) type Script struct { - src, hash string + src string + mu sync.RWMutex + hash string + serverSHA bool // if true: do not compute SHA-1 in Go; load digest from Redis (SCRIPT LOAD) } func NewScript(src string) *Script { h := sha1.New() _, _ = io.WriteString(h, src) + + return &Script{ + src: src, + hash: hex.EncodeToString(h.Sum(nil)), + serverSHA: false, + } +} + +// NewScriptServerSHA creates a Script that avoids computing SHA-1 in Go. +// The digest is obtained from Redis via SCRIPT LOAD (server-side hashing), +// then EVALSHA/EVALSHA_RO is used. +func NewScriptServerSHA(src string) *Script { return &Script{ - src: src, - hash: hex.EncodeToString(h.Sum(nil)), + src: src, + serverSHA: true, } } func (s *Script) Hash() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.hash } func (s *Script) Load(ctx context.Context, c Scripter) *StringCmd { - return c.ScriptLoad(ctx, s.src) + cmd := c.ScriptLoad(ctx, s.src) + if err := cmd.Err(); err == nil { + s.mu.Lock() + s.hash = cmd.Val() + s.mu.Unlock() + } + return cmd } func (s *Script) Exists(ctx context.Context, c Scripter) *BoolSliceCmd { - return c.ScriptExists(ctx, s.hash) + s.mu.RLock() + hash := s.hash + serverSHA := s.serverSHA + s.mu.RUnlock() + if hash == "" && serverSHA { + // For server-side scripts, obtain digest from Redis first. + // If hash is empty, it means SCRIPT LOAD was not called yet, so we check existence of empty hash which will return false. + // This avoids unnecessary SCRIPT LOAD just to check existence. + if err := s.ensureHash(ctx, c); err != nil { + return c.ScriptExists(ctx, "") + } + s.mu.RLock() + hash = s.hash + s.mu.RUnlock() + } + if hash == "" { + return c.ScriptExists(ctx, "") + } + return c.ScriptExists(ctx, hash) } func (s *Script) Eval(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { @@ -55,19 +98,101 @@ func (s *Script) EvalRO(ctx context.Context, c Scripter, keys []string, args ... return c.EvalRO(ctx, s.src, keys, args...) } +// ensureHash ensures that s.hash is populated by using SCRIPT LOAD. +// It never calls SHA-1 in Go; Redis computes and returns the digest. +func (s *Script) ensureHash(ctx context.Context, c Scripter) error { + // Fast path: read lock, return if hash is already set. + s.mu.RLock() + if s.hash != "" { + s.mu.RUnlock() + return nil + } + s.mu.RUnlock() + + // Slow path: acquire write lock and load. + s.mu.Lock() + if s.hash != "" { + s.mu.Unlock() + return nil + } + cmd := c.ScriptLoad(ctx, s.src) + if err := cmd.Err(); err != nil { + s.mu.Unlock() + return err + } + s.hash = cmd.Val() + s.mu.Unlock() + return nil +} + func (s *Script) EvalSha(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { - return c.EvalSha(ctx, s.hash, keys, args...) + // Default behavior: use client-side SHA-1 computed in NewScript. + if !s.serverSHA { + s.mu.RLock() + hash := s.hash + s.mu.RUnlock() + return c.EvalSha(ctx, hash, keys, args...) + } + + // Server-side SHA via SCRIPT LOAD + EVALSHA. + if err := s.ensureHash(ctx, c); err != nil { + return s.Eval(ctx, c, keys, args...) + } + + s.mu.RLock() + hash := s.hash + s.mu.RUnlock() + + r := c.EvalSha(ctx, hash, keys, args...) + if HasErrorPrefix(r.Err(), "NOSCRIPT") { + // Script cache was flushed; reload and retry once. + if err := s.ensureHash(ctx, c); err != nil { + return s.Eval(ctx, c, keys, args...) + } + s.mu.RLock() + hash = s.hash + s.mu.RUnlock() + return c.EvalSha(ctx, hash, keys, args...) + } + + return r } func (s *Script) EvalShaRO(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { - return c.EvalShaRO(ctx, s.hash, keys, args...) + if !s.serverSHA { + s.mu.RLock() + hash := s.hash + s.mu.RUnlock() + return c.EvalShaRO(ctx, hash, keys, args...) + } + + if err := s.ensureHash(ctx, c); err != nil { + return s.EvalRO(ctx, c, keys, args...) + } + + s.mu.RLock() + hash := s.hash + s.mu.RUnlock() + + r := c.EvalShaRO(ctx, hash, keys, args...) + if HasErrorPrefix(r.Err(), "NOSCRIPT") { + if err := s.ensureHash(ctx, c); err != nil { + return s.EvalRO(ctx, c, keys, args...) + } + s.mu.RLock() + hash = s.hash + s.mu.RUnlock() + return c.EvalShaRO(ctx, hash, keys, args...) + } + + return r } // Run optimistically uses EVALSHA to run the script. If script does not exist // it is retried using EVAL. func (s *Script) Run(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { r := s.EvalSha(ctx, c, keys, args...) - if HasErrorPrefix(r.Err(), "NOSCRIPT") { + if errors.Is(r.Err(), ErrNoScript) { return s.Eval(ctx, c, keys, args...) } return r @@ -77,7 +202,7 @@ func (s *Script) Run(ctx context.Context, c Scripter, keys []string, args ...int // it is retried using EVAL_RO. func (s *Script) RunRO(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { r := s.EvalShaRO(ctx, c, keys, args...) - if HasErrorPrefix(r.Err(), "NOSCRIPT") { + if errors.Is(r.Err(), ErrNoScript) { return s.EvalRO(ctx, c, keys, args...) } return r diff --git a/vendor/github.com/redis/go-redis/v9/scripting_commands.go b/vendor/github.com/redis/go-redis/v9/scripting_commands.go index af9c3397b..3310b9d0a 100644 --- a/vendor/github.com/redis/go-redis/v9/scripting_commands.go +++ b/vendor/github.com/redis/go-redis/v9/scripting_commands.go @@ -60,6 +60,11 @@ func (c cmdable) eval(ctx context.Context, name, payload string, keys []string, cmd.SetFirstKeyPos(3) } _ = c(ctx, cmd) + if err := cmd.Err(); err != nil { + if HasErrorPrefix(err, "NOSCRIPT") { + cmd.SetErr(ErrNoScript) + } + } return cmd } diff --git a/vendor/github.com/redis/go-redis/v9/search_builders.go b/vendor/github.com/redis/go-redis/v9/search_builders.go new file mode 100644 index 000000000..a6c6718c3 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/search_builders.go @@ -0,0 +1,858 @@ +package redis + +import ( + "context" + "fmt" +) + +// ---------------------- +// Search Module Builders +// ---------------------- + +// SearchBuilder provides a fluent API for FT.SEARCH +// (see original FTSearchOptions for all options). +// EXPERIMENTAL: this API is subject to change, use with caution. +type SearchBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTSearchOptions +} + +// NewSearchBuilder creates a new SearchBuilder for FT.SEARCH commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewSearchBuilder(ctx context.Context, index, query string) *SearchBuilder { + b := &SearchBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSearchOptions{LimitOffset: -1}} + return b +} + +// WithScores includes WITHSCORES. +func (b *SearchBuilder) WithScores() *SearchBuilder { + b.options.WithScores = true + return b +} + +// NoContent includes NOCONTENT. +func (b *SearchBuilder) NoContent() *SearchBuilder { b.options.NoContent = true; return b } + +// Verbatim includes VERBATIM. +func (b *SearchBuilder) Verbatim() *SearchBuilder { b.options.Verbatim = true; return b } + +// NoStopWords includes NOSTOPWORDS. +func (b *SearchBuilder) NoStopWords() *SearchBuilder { b.options.NoStopWords = true; return b } + +// WithPayloads includes WITHPAYLOADS. +func (b *SearchBuilder) WithPayloads() *SearchBuilder { + b.options.WithPayloads = true + return b +} + +// WithSortKeys includes WITHSORTKEYS. +func (b *SearchBuilder) WithSortKeys() *SearchBuilder { + b.options.WithSortKeys = true + return b +} + +// Filter adds a FILTER clause: FILTER . +func (b *SearchBuilder) Filter(field string, min, max interface{}) *SearchBuilder { + b.options.Filters = append(b.options.Filters, FTSearchFilter{ + FieldName: field, + Min: min, + Max: max, + }) + return b +} + +// GeoFilter adds a GEOFILTER clause: GEOFILTER . +func (b *SearchBuilder) GeoFilter(field string, lon, lat, radius float64, unit string) *SearchBuilder { + b.options.GeoFilter = append(b.options.GeoFilter, FTSearchGeoFilter{ + FieldName: field, + Longitude: lon, + Latitude: lat, + Radius: radius, + Unit: unit, + }) + return b +} + +// InKeys restricts the search to the given keys. +func (b *SearchBuilder) InKeys(keys ...interface{}) *SearchBuilder { + b.options.InKeys = append(b.options.InKeys, keys...) + return b +} + +// InFields restricts the search to the given fields. +func (b *SearchBuilder) InFields(fields ...interface{}) *SearchBuilder { + b.options.InFields = append(b.options.InFields, fields...) + return b +} + +// ReturnFields adds simple RETURN ... +func (b *SearchBuilder) ReturnFields(fields ...string) *SearchBuilder { + for _, f := range fields { + b.options.Return = append(b.options.Return, FTSearchReturn{FieldName: f}) + } + return b +} + +// ReturnAs adds RETURN AS . +func (b *SearchBuilder) ReturnAs(field, alias string) *SearchBuilder { + b.options.Return = append(b.options.Return, FTSearchReturn{FieldName: field, As: alias}) + return b +} + +// Slop adds SLOP . +func (b *SearchBuilder) Slop(slop int) *SearchBuilder { + b.options.Slop = slop + return b +} + +// Timeout adds TIMEOUT . +func (b *SearchBuilder) Timeout(timeout int) *SearchBuilder { + b.options.Timeout = timeout + return b +} + +// InOrder includes INORDER. +func (b *SearchBuilder) InOrder() *SearchBuilder { + b.options.InOrder = true + return b +} + +// Language sets LANGUAGE . +func (b *SearchBuilder) Language(lang string) *SearchBuilder { + b.options.Language = lang + return b +} + +// Expander sets EXPANDER . +func (b *SearchBuilder) Expander(expander string) *SearchBuilder { + b.options.Expander = expander + return b +} + +// Scorer sets SCORER . +func (b *SearchBuilder) Scorer(scorer string) *SearchBuilder { + b.options.Scorer = scorer + return b +} + +// ExplainScore includes EXPLAINSCORE. +func (b *SearchBuilder) ExplainScore() *SearchBuilder { + b.options.ExplainScore = true + return b +} + +// Payload sets PAYLOAD . +func (b *SearchBuilder) Payload(payload string) *SearchBuilder { + b.options.Payload = payload + return b +} + +// SortBy adds SORTBY ASC|DESC. +func (b *SearchBuilder) SortBy(field string, asc bool) *SearchBuilder { + b.options.SortBy = append(b.options.SortBy, FTSearchSortBy{ + FieldName: field, + Asc: asc, + Desc: !asc, + }) + return b +} + +// WithSortByCount includes WITHCOUNT (when used with SortBy). +func (b *SearchBuilder) WithSortByCount() *SearchBuilder { + b.options.SortByWithCount = true + return b +} + +// Param adds a single PARAMS . +func (b *SearchBuilder) Param(key string, value interface{}) *SearchBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, 1) + } + b.options.Params[key] = value + return b +} + +// ParamsMap adds multiple PARAMS at once. +func (b *SearchBuilder) ParamsMap(p map[string]interface{}) *SearchBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, len(p)) + } + for k, v := range p { + b.options.Params[k] = v + } + return b +} + +// Dialect sets DIALECT . +func (b *SearchBuilder) Dialect(version int) *SearchBuilder { + b.options.DialectVersion = version + return b +} + +// Limit sets OFFSET and COUNT. CountOnly uses LIMIT 0 0. +func (b *SearchBuilder) Limit(offset, count int) *SearchBuilder { + b.options.LimitOffset = offset + b.options.Limit = count + return b +} +func (b *SearchBuilder) CountOnly() *SearchBuilder { b.options.CountOnly = true; return b } + +// Run executes FT.SEARCH and returns a typed result. +func (b *SearchBuilder) Run() (FTSearchResult, error) { + cmd := b.c.FTSearchWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// AggregateBuilder for FT.AGGREGATE +// ---------------------- + +type AggregateBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTAggregateOptions + err error +} + +// NewAggregateBuilder creates a new AggregateBuilder for FT.AGGREGATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewAggregateBuilder(ctx context.Context, index, query string) *AggregateBuilder { + return &AggregateBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTAggregateOptions{LimitOffset: -1}} +} + +// setErr records the first error produced while building the pipeline. +// Subsequent errors are ignored; the first error is returned from Run. +func (b *AggregateBuilder) setErr(err error) { + if b.err == nil { + b.err = err + } +} + +// Verbatim includes VERBATIM. +func (b *AggregateBuilder) Verbatim() *AggregateBuilder { b.options.Verbatim = true; return b } + +// AddScores includes ADDSCORES. +func (b *AggregateBuilder) AddScores() *AggregateBuilder { b.options.AddScores = true; return b } + +// Scorer sets SCORER . +func (b *AggregateBuilder) Scorer(s string) *AggregateBuilder { + b.options.Scorer = s + return b +} + +// LoadAll includes LOAD * (mutually exclusive with Load). +func (b *AggregateBuilder) LoadAll() *AggregateBuilder { + b.options.LoadAll = true + return b +} + +// Load adds a LOAD [AS alias] step. +// You can call it multiple times; each call becomes a separate LOAD clause +// at its position in the pipeline. +func (b *AggregateBuilder) Load(field string, alias ...string) *AggregateBuilder { + l := &FTAggregateLoad{Field: field} + if len(alias) > 0 { + l.As = alias[0] + } + b.options.Steps = append(b.options.Steps, FTAggregateStep{Load: l}) + return b +} + +// Timeout sets TIMEOUT . +func (b *AggregateBuilder) Timeout(ms int) *AggregateBuilder { + b.options.Timeout = ms + return b +} + +// Apply adds an APPLY [AS alias] step. +func (b *AggregateBuilder) Apply(field string, alias ...string) *AggregateBuilder { + a := &FTAggregateApply{Field: field} + if len(alias) > 0 { + a.As = alias[0] + } + b.options.Steps = append(b.options.Steps, FTAggregateStep{Apply: a}) + return b +} + +// GroupBy adds a new GROUPBY step. +func (b *AggregateBuilder) GroupBy(fields ...interface{}) *AggregateBuilder { + b.options.Steps = append(b.options.Steps, FTAggregateStep{ + GroupBy: &FTAggregateGroupBy{Fields: fields}, + }) + return b +} + +// Reduce adds a REDUCE [<#args> ] clause to the last step, +// which must be a GROUPBY. If it is not, Run will return an error. +func (b *AggregateBuilder) Reduce(fn SearchAggregator, args ...interface{}) *AggregateBuilder { + n := len(b.options.Steps) + if n == 0 || b.options.Steps[n-1].GroupBy == nil { + b.setErr(fmt.Errorf("FT.AGGREGATE: Reduce must follow a GroupBy step")) + return b + } + g := b.options.Steps[n-1].GroupBy + g.Reduce = append(g.Reduce, FTAggregateReducer{Reducer: fn, Args: args}) + return b +} + +// ReduceAs does the same but also sets an alias: REDUCE … AS . +// The last step must be a GROUPBY; otherwise Run will return an error. +func (b *AggregateBuilder) ReduceAs(fn SearchAggregator, alias string, args ...interface{}) *AggregateBuilder { + n := len(b.options.Steps) + if n == 0 || b.options.Steps[n-1].GroupBy == nil { + b.setErr(fmt.Errorf("FT.AGGREGATE: ReduceAs must follow a GroupBy step")) + return b + } + g := b.options.Steps[n-1].GroupBy + g.Reduce = append(g.Reduce, FTAggregateReducer{Reducer: fn, Args: args, As: alias}) + return b +} + +// SortBy adds SORTBY ASC|DESC. Consecutive SortBy calls (with no +// other step in between) are merged into a single SORTBY clause so fields +// act as tiebreakers. A SortBy call after a non-SortBy step starts a new +// SORTBY step. +// +// Note: this is a semantics change from earlier experimental versions of +// the builder, where SortBy always accumulated into a single SORTBY clause +// regardless of position in the pipeline. +func (b *AggregateBuilder) SortBy(field string, asc bool) *AggregateBuilder { + sb := FTAggregateSortBy{FieldName: field, Asc: asc, Desc: !asc} + if n := len(b.options.Steps); n > 0 && b.options.Steps[n-1].SortBy != nil { + b.options.Steps[n-1].SortBy.Fields = append(b.options.Steps[n-1].SortBy.Fields, sb) + return b + } + b.options.Steps = append(b.options.Steps, FTAggregateStep{ + SortBy: &FTAggregateSortByStep{Fields: []FTAggregateSortBy{sb}}, + }) + return b +} + +// SortByMax sets MAX on the last SORTBY step. The last step must be a +// SORTBY; otherwise Run will return an error. +func (b *AggregateBuilder) SortByMax(max int) *AggregateBuilder { + n := len(b.options.Steps) + if n == 0 || b.options.Steps[n-1].SortBy == nil { + b.setErr(fmt.Errorf("FT.AGGREGATE: SortByMax must follow a SortBy step")) + return b + } + b.options.Steps[n-1].SortBy.Max = max + return b +} + +// Filter sets FILTER . +func (b *AggregateBuilder) Filter(expr string) *AggregateBuilder { + b.options.Filter = expr + return b +} + +// WithCursor enables WITHCURSOR [COUNT ] [MAXIDLE ]. +func (b *AggregateBuilder) WithCursor(count, maxIdle int) *AggregateBuilder { + b.options.WithCursor = true + if b.options.WithCursorOptions == nil { + b.options.WithCursorOptions = &FTAggregateWithCursor{} + } + b.options.WithCursorOptions.Count = count + b.options.WithCursorOptions.MaxIdle = maxIdle + return b +} + +// Params adds PARAMS pairs. +func (b *AggregateBuilder) Params(p map[string]interface{}) *AggregateBuilder { + if b.options.Params == nil { + b.options.Params = make(map[string]interface{}, len(p)) + } + for k, v := range p { + b.options.Params[k] = v + } + return b +} + +// Dialect sets DIALECT . +func (b *AggregateBuilder) Dialect(version int) *AggregateBuilder { + b.options.DialectVersion = version + return b +} + +// Run executes FT.AGGREGATE and returns a typed result. If the builder +// recorded a validation error while constructing the pipeline (for example, +// calling SortByMax when the last step is not a SortBy), that error is +// returned without issuing the command. +func (b *AggregateBuilder) Run() (*FTAggregateResult, error) { + if b.err != nil { + return nil, b.err + } + cmd := b.c.FTAggregateWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// CreateIndexBuilder for FT.CREATE +// ---------------------- +// CreateIndexBuilder is builder for FT.CREATE +// EXPERIMENTAL: this API is subject to change, use with caution. +type CreateIndexBuilder struct { + c *Client + ctx context.Context + index string + options *FTCreateOptions + schema []*FieldSchema +} + +// NewCreateIndexBuilder creates a new CreateIndexBuilder for FT.CREATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewCreateIndexBuilder(ctx context.Context, index string) *CreateIndexBuilder { + return &CreateIndexBuilder{c: c, ctx: ctx, index: index, options: &FTCreateOptions{}} +} + +// OnHash sets ON HASH. +func (b *CreateIndexBuilder) OnHash() *CreateIndexBuilder { b.options.OnHash = true; return b } + +// OnJSON sets ON JSON. +func (b *CreateIndexBuilder) OnJSON() *CreateIndexBuilder { b.options.OnJSON = true; return b } + +// Prefix sets PREFIX. +func (b *CreateIndexBuilder) Prefix(prefixes ...interface{}) *CreateIndexBuilder { + b.options.Prefix = prefixes + return b +} + +// Filter sets FILTER. +func (b *CreateIndexBuilder) Filter(filter string) *CreateIndexBuilder { + b.options.Filter = filter + return b +} + +// DefaultLanguage sets LANGUAGE. +func (b *CreateIndexBuilder) DefaultLanguage(lang string) *CreateIndexBuilder { + b.options.DefaultLanguage = lang + return b +} + +// LanguageField sets LANGUAGE_FIELD. +func (b *CreateIndexBuilder) LanguageField(field string) *CreateIndexBuilder { + b.options.LanguageField = field + return b +} + +// Score sets SCORE. +func (b *CreateIndexBuilder) Score(score float64) *CreateIndexBuilder { + b.options.Score = score + return b +} + +// ScoreField sets SCORE_FIELD. +func (b *CreateIndexBuilder) ScoreField(field string) *CreateIndexBuilder { + b.options.ScoreField = field + return b +} + +// PayloadField sets PAYLOAD_FIELD. +func (b *CreateIndexBuilder) PayloadField(field string) *CreateIndexBuilder { + b.options.PayloadField = field + return b +} + +// NoOffsets includes NOOFFSETS. +func (b *CreateIndexBuilder) NoOffsets() *CreateIndexBuilder { b.options.NoOffsets = true; return b } + +// Temporary sets TEMPORARY seconds. +func (b *CreateIndexBuilder) Temporary(sec int) *CreateIndexBuilder { + b.options.Temporary = sec + return b +} + +// NoHL includes NOHL. +func (b *CreateIndexBuilder) NoHL() *CreateIndexBuilder { b.options.NoHL = true; return b } + +// NoFields includes NOFIELDS. +func (b *CreateIndexBuilder) NoFields() *CreateIndexBuilder { b.options.NoFields = true; return b } + +// NoFreqs includes NOFREQS. +func (b *CreateIndexBuilder) NoFreqs() *CreateIndexBuilder { b.options.NoFreqs = true; return b } + +// StopWords sets STOPWORDS. +func (b *CreateIndexBuilder) StopWords(words ...interface{}) *CreateIndexBuilder { + b.options.StopWords = words + return b +} + +// SkipInitialScan includes SKIPINITIALSCAN. +func (b *CreateIndexBuilder) SkipInitialScan() *CreateIndexBuilder { + b.options.SkipInitialScan = true + return b +} + +// Schema adds a FieldSchema. +func (b *CreateIndexBuilder) Schema(field *FieldSchema) *CreateIndexBuilder { + b.schema = append(b.schema, field) + return b +} + +// Run executes FT.CREATE and returns the status. +func (b *CreateIndexBuilder) Run() (string, error) { + cmd := b.c.FTCreate(b.ctx, b.index, b.options, b.schema...) + return cmd.Result() +} + +// ---------------------- +// DropIndexBuilder for FT.DROPINDEX +// ---------------------- +// DropIndexBuilder is a builder for FT.DROPINDEX +// EXPERIMENTAL: this API is subject to change, use with caution. +type DropIndexBuilder struct { + c *Client + ctx context.Context + index string + options *FTDropIndexOptions +} + +// NewDropIndexBuilder creates a new DropIndexBuilder for FT.DROPINDEX commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewDropIndexBuilder(ctx context.Context, index string) *DropIndexBuilder { + return &DropIndexBuilder{c: c, ctx: ctx, index: index} +} + +// DeleteRuncs includes DD. +func (b *DropIndexBuilder) DeleteDocs() *DropIndexBuilder { b.options.DeleteDocs = true; return b } + +// Run executes FT.DROPINDEX. +func (b *DropIndexBuilder) Run() (string, error) { + cmd := b.c.FTDropIndexWithArgs(b.ctx, b.index, b.options) + return cmd.Result() +} + +// ---------------------- +// AliasBuilder for FT.ALIAS* commands +// ---------------------- +// AliasBuilder is builder for FT.ALIAS* commands +// EXPERIMENTAL: this API is subject to change, use with caution. +type AliasBuilder struct { + c *Client + ctx context.Context + alias string + index string + action string // add|del|update +} + +// NewAliasBuilder creates a new AliasBuilder for FT.ALIAS* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewAliasBuilder(ctx context.Context, alias string) *AliasBuilder { + return &AliasBuilder{c: c, ctx: ctx, alias: alias} +} + +// Action sets the action for the alias builder. +func (b *AliasBuilder) Action(action string) *AliasBuilder { + b.action = action + return b +} + +// Add sets the action to "add" and requires an index. +func (b *AliasBuilder) Add(index string) *AliasBuilder { + b.action = "add" + b.index = index + return b +} + +// Del sets the action to "del". +func (b *AliasBuilder) Del() *AliasBuilder { + b.action = "del" + return b +} + +// Update sets the action to "update" and requires an index. +func (b *AliasBuilder) Update(index string) *AliasBuilder { + b.action = "update" + b.index = index + return b +} + +// Run executes the configured alias command. +func (b *AliasBuilder) Run() (string, error) { + switch b.action { + case "add": + cmd := b.c.FTAliasAdd(b.ctx, b.index, b.alias) + return cmd.Result() + case "del": + cmd := b.c.FTAliasDel(b.ctx, b.alias) + return cmd.Result() + case "update": + cmd := b.c.FTAliasUpdate(b.ctx, b.index, b.alias) + return cmd.Result() + } + return "", nil +} + +// ---------------------- +// ExplainBuilder for FT.EXPLAIN +// ---------------------- +// ExplainBuilder is builder for FT.EXPLAIN +// EXPERIMENTAL: this API is subject to change, use with caution. +type ExplainBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTExplainOptions +} + +// NewExplainBuilder creates a new ExplainBuilder for FT.EXPLAIN commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewExplainBuilder(ctx context.Context, index, query string) *ExplainBuilder { + return &ExplainBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTExplainOptions{}} +} + +// Dialect sets dialect for EXPLAINCLI. +func (b *ExplainBuilder) Dialect(d string) *ExplainBuilder { b.options.Dialect = d; return b } + +// Run executes FT.EXPLAIN and returns the plan. +func (b *ExplainBuilder) Run() (string, error) { + cmd := b.c.FTExplainWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// InfoBuilder for FT.INFO +// ---------------------- + +type FTInfoBuilder struct { + c *Client + ctx context.Context + index string +} + +// NewSearchInfoBuilder creates a new FTInfoBuilder for FT.INFO commands. +func (c *Client) NewSearchInfoBuilder(ctx context.Context, index string) *FTInfoBuilder { + return &FTInfoBuilder{c: c, ctx: ctx, index: index} +} + +// Run executes FT.INFO and returns detailed info. +func (b *FTInfoBuilder) Run() (FTInfoResult, error) { + cmd := b.c.FTInfo(b.ctx, b.index) + return cmd.Result() +} + +// ---------------------- +// SpellCheckBuilder for FT.SPELLCHECK +// ---------------------- +// SpellCheckBuilder is builder for FT.SPELLCHECK +// EXPERIMENTAL: this API is subject to change, use with caution. +type SpellCheckBuilder struct { + c *Client + ctx context.Context + index string + query string + options *FTSpellCheckOptions +} + +// NewSpellCheckBuilder creates a new SpellCheckBuilder for FT.SPELLCHECK commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewSpellCheckBuilder(ctx context.Context, index, query string) *SpellCheckBuilder { + return &SpellCheckBuilder{c: c, ctx: ctx, index: index, query: query, options: &FTSpellCheckOptions{}} +} + +// Distance sets MAXDISTANCE. +func (b *SpellCheckBuilder) Distance(d int) *SpellCheckBuilder { b.options.Distance = d; return b } + +// Terms sets INCLUDE or EXCLUDE terms. +func (b *SpellCheckBuilder) Terms(include bool, dictionary string, terms ...interface{}) *SpellCheckBuilder { + if b.options.Terms == nil { + b.options.Terms = &FTSpellCheckTerms{} + } + if include { + b.options.Terms.Inclusion = "INCLUDE" + } else { + b.options.Terms.Inclusion = "EXCLUDE" + } + b.options.Terms.Dictionary = dictionary + b.options.Terms.Terms = terms + return b +} + +// Dialect sets dialect version. +func (b *SpellCheckBuilder) Dialect(d int) *SpellCheckBuilder { b.options.Dialect = d; return b } + +// Run executes FT.SPELLCHECK and returns suggestions. +func (b *SpellCheckBuilder) Run() ([]SpellCheckResult, error) { + cmd := b.c.FTSpellCheckWithArgs(b.ctx, b.index, b.query, b.options) + return cmd.Result() +} + +// ---------------------- +// DictBuilder for FT.DICT* commands +// ---------------------- +// DictBuilder is builder for FT.DICT* commands +// EXPERIMENTAL: this API is subject to change, use with caution. +type DictBuilder struct { + c *Client + ctx context.Context + dict string + terms []interface{} + action string // add|del|dump +} + +// NewDictBuilder creates a new DictBuilder for FT.DICT* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewDictBuilder(ctx context.Context, dict string) *DictBuilder { + return &DictBuilder{c: c, ctx: ctx, dict: dict} +} + +// Action sets the action for the dictionary builder. +func (b *DictBuilder) Action(action string) *DictBuilder { + b.action = action + return b +} + +// Add sets the action to "add" and requires terms. +func (b *DictBuilder) Add(terms ...interface{}) *DictBuilder { + b.action = "add" + b.terms = terms + return b +} + +// Del sets the action to "del" and requires terms. +func (b *DictBuilder) Del(terms ...interface{}) *DictBuilder { + b.action = "del" + b.terms = terms + return b +} + +// Dump sets the action to "dump". +func (b *DictBuilder) Dump() *DictBuilder { + b.action = "dump" + return b +} + +// Run executes the configured dictionary command. +func (b *DictBuilder) Run() (interface{}, error) { + switch b.action { + case "add": + cmd := b.c.FTDictAdd(b.ctx, b.dict, b.terms...) + return cmd.Result() + case "del": + cmd := b.c.FTDictDel(b.ctx, b.dict, b.terms...) + return cmd.Result() + case "dump": + cmd := b.c.FTDictDump(b.ctx, b.dict) + return cmd.Result() + } + return nil, nil +} + +// ---------------------- +// TagValsBuilder for FT.TAGVALS +// ---------------------- +// TagValsBuilder is builder for FT.TAGVALS +// EXPERIMENTAL: this API is subject to change, use with caution. +type TagValsBuilder struct { + c *Client + ctx context.Context + index string + field string +} + +// NewTagValsBuilder creates a new TagValsBuilder for FT.TAGVALS commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewTagValsBuilder(ctx context.Context, index, field string) *TagValsBuilder { + return &TagValsBuilder{c: c, ctx: ctx, index: index, field: field} +} + +// Run executes FT.TAGVALS and returns tag values. +func (b *TagValsBuilder) Run() ([]string, error) { + cmd := b.c.FTTagVals(b.ctx, b.index, b.field) + return cmd.Result() +} + +// ---------------------- +// CursorBuilder for FT.CURSOR* +// ---------------------- +// CursorBuilder is builder for FT.CURSOR* commands +// EXPERIMENTAL: this API is subject to change, use with caution. +type CursorBuilder struct { + c *Client + ctx context.Context + index string + cursorId int64 + count int + action string // read|del +} + +// NewCursorBuilder creates a new CursorBuilder for FT.CURSOR* commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewCursorBuilder(ctx context.Context, index string, cursorId int64) *CursorBuilder { + return &CursorBuilder{c: c, ctx: ctx, index: index, cursorId: cursorId} +} + +// Action sets the action for the cursor builder. +func (b *CursorBuilder) Action(action string) *CursorBuilder { + b.action = action + return b +} + +// Read sets the action to "read". +func (b *CursorBuilder) Read() *CursorBuilder { + b.action = "read" + return b +} + +// Del sets the action to "del". +func (b *CursorBuilder) Del() *CursorBuilder { + b.action = "del" + return b +} + +// Count for READ. +func (b *CursorBuilder) Count(count int) *CursorBuilder { b.count = count; return b } + +// Run executes the cursor command. +func (b *CursorBuilder) Run() (interface{}, error) { + switch b.action { + case "read": + cmd := b.c.FTCursorRead(b.ctx, b.index, int(b.cursorId), b.count) + return cmd.Result() + case "del": + cmd := b.c.FTCursorDel(b.ctx, b.index, int(b.cursorId)) + return cmd.Result() + } + return nil, nil +} + +// ---------------------- +// SynUpdateBuilder for FT.SYNUPDATE +// ---------------------- +// SyncUpdateBuilder is builder for FT.SYNCUPDATE +// EXPERIMENTAL: this API is subject to change, use with caution. +type SynUpdateBuilder struct { + c *Client + ctx context.Context + index string + groupId interface{} + options *FTSynUpdateOptions + terms []interface{} +} + +// NewSynUpdateBuilder creates a new SynUpdateBuilder for FT.SYNUPDATE commands. +// EXPERIMENTAL: this API is subject to change, use with caution. +func (c *Client) NewSynUpdateBuilder(ctx context.Context, index string, groupId interface{}) *SynUpdateBuilder { + return &SynUpdateBuilder{c: c, ctx: ctx, index: index, groupId: groupId, options: &FTSynUpdateOptions{}} +} + +// SkipInitialScan includes SKIPINITIALSCAN. +func (b *SynUpdateBuilder) SkipInitialScan() *SynUpdateBuilder { + b.options.SkipInitialScan = true + return b +} + +// Terms adds synonyms to the group. +func (b *SynUpdateBuilder) Terms(terms ...interface{}) *SynUpdateBuilder { b.terms = terms; return b } + +// Run executes FT.SYNUPDATE. +func (b *SynUpdateBuilder) Run() (string, error) { + cmd := b.c.FTSynUpdateWithArgs(b.ctx, b.index, b.groupId, b.options, b.terms) + return cmd.Result() +} diff --git a/vendor/github.com/redis/go-redis/v9/search_commands.go b/vendor/github.com/redis/go-redis/v9/search_commands.go index b31baaa76..b13aa5bef 100644 --- a/vendor/github.com/redis/go-redis/v9/search_commands.go +++ b/vendor/github.com/redis/go-redis/v9/search_commands.go @@ -3,7 +3,10 @@ package redis import ( "context" "fmt" + "maps" + "slices" "strconv" + "strings" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/proto" @@ -29,6 +32,8 @@ type SearchCmdable interface { FTDropIndexWithArgs(ctx context.Context, index string, options *FTDropIndexOptions) *StatusCmd FTExplain(ctx context.Context, index string, query string) *StringCmd FTExplainWithArgs(ctx context.Context, index string, query string, options *FTExplainOptions) *StringCmd + FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd + FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd FTInfo(ctx context.Context, index string) *FTInfoCmd FTSpellCheck(ctx context.Context, index string, query string) *FTSpellCheckCmd FTSpellCheckWithArgs(ctx context.Context, index string, query string, options *FTSpellCheckOptions) *FTSpellCheckCmd @@ -80,8 +85,9 @@ type FieldSchema struct { } type FTVectorArgs struct { - FlatOptions *FTFlatOptions - HNSWOptions *FTHNSWOptions + FlatOptions *FTFlatOptions + HNSWOptions *FTHNSWOptions + VamanaOptions *FTVamanaOptions } type FTFlatOptions struct { @@ -103,6 +109,19 @@ type FTHNSWOptions struct { Epsilon float64 } +type FTVamanaOptions struct { + Type string + Dim int + DistanceMetric string + Compression string + ConstructionWindowSize int + GraphMaxDegree int + SearchWindowSize int + Epsilon float64 + TrainingThreshold int + ReduceDim int +} + type FTDropIndexOptions struct { DeleteDocs bool } @@ -240,22 +259,42 @@ type FTAggregateWithCursor struct { MaxIdle int } +// FTAggregateSortByStep represents a SORTBY operation with optional MAX. +// Used inside FTAggregateStep to place SORTBY at an arbitrary position in +// the aggregation pipeline. +type FTAggregateSortByStep struct { + Fields []FTAggregateSortBy + Max int // 0 means no MAX +} + +// FTAggregateStep represents a single operation in the aggregation pipeline. +// LOAD, APPLY, SORTBY and GROUPBY can all appear multiple times in any order. +// Exactly one of the fields should be set per step. +type FTAggregateStep struct { + Load *FTAggregateLoad + Apply *FTAggregateApply + GroupBy *FTAggregateGroupBy + SortBy *FTAggregateSortByStep +} + type FTAggregateOptions struct { - Verbatim bool - LoadAll bool - Load []FTAggregateLoad - Timeout int - GroupBy []FTAggregateGroupBy - SortBy []FTAggregateSortBy - SortByMax int + Verbatim bool + LoadAll bool + Timeout int // Scorer is used to set scoring function, if not set passed, a default will be used. // The default scorer depends on the Redis version: // - `BM25` for Redis >= 8 // - `TFIDF` for Redis < 8 Scorer string // AddScores is available in Redis CE 8 - AddScores bool - Apply []FTAggregateApply + AddScores bool + + // Steps is the ordered sequence of aggregation pipeline operations. + // It can contain LOAD, APPLY, GROUPBY and SORTBY in any order, multiple times. + // Steps cannot be combined with the deprecated Load, Apply, GroupBy, SortBy + // and SortByMax fields: doing so returns an error. + Steps []FTAggregateStep + LimitOffset int Limit int Filter string @@ -264,6 +303,17 @@ type FTAggregateOptions struct { Params map[string]interface{} // Dialect 1,3 and 4 are deprecated since redis 8.0 DialectVersion int + + // Deprecated: Use Steps instead. + Load []FTAggregateLoad + // Deprecated: Use Steps instead. + GroupBy []FTAggregateGroupBy + // Deprecated: Use Steps instead. + SortBy []FTAggregateSortBy + // Deprecated: Use Steps instead. + SortByMax int + // Deprecated: Use Steps instead. + Apply []FTAggregateApply } type FTSearchFilter struct { @@ -330,6 +380,100 @@ type FTSearchOptions struct { DialectVersion int } +// FTHybridCombineMethod represents the fusion method for combining search and vector results +type FTHybridCombineMethod string + +const ( + FTHybridCombineRRF FTHybridCombineMethod = "RRF" + FTHybridCombineLinear FTHybridCombineMethod = "LINEAR" + FTHybridCombineFunction FTHybridCombineMethod = "FUNCTION" +) + +// FTHybridSearchExpression represents a search expression in hybrid search +type FTHybridSearchExpression struct { + Query string + Scorer string + ScorerParams []interface{} + YieldScoreAs string +} + +type FTHybridVectorMethod = string + +const ( + KNN FTHybridCombineMethod = "KNN" + RANGE FTHybridCombineMethod = "RANGE" +) + +// FTHybridVectorExpression represents a vector expression in hybrid search +type FTHybridVectorExpression struct { + VectorField string + VectorData Vector + // VectorParamName optionally specifies the parameter name used to pass the + // vector data via the PARAMS mechanism. + // Vector data is always passed via PARAMS because inline vector blobs are no + // longer supported by Redis. When left empty, the library generates a unique + // parameter name automatically (e.g. "__vector_param_0") without mutating + // FTHybridOptions.Params and without colliding with any explicit names. + // The vector blob is passed as: VSIM @field $VectorParamName ... PARAMS ... VectorParamName + VectorParamName string + Method FTHybridVectorMethod + MethodParams []interface{} + Filter string + YieldScoreAs string +} + +// FTHybridCombineOptions represents options for result fusion +type FTHybridCombineOptions struct { + Method FTHybridCombineMethod + Count int + Window int // For RRF + Constant float64 // For RRF + Alpha float64 // For LINEAR + Beta float64 // For LINEAR + YieldScoreAs string +} + +// FTHybridGroupBy represents GROUP BY functionality +type FTHybridGroupBy struct { + Count int + Fields []string + ReduceFunc string + ReduceCount int + ReduceParams []interface{} +} + +// FTHybridApply represents APPLY functionality +type FTHybridApply struct { + Expression string + AsField string +} + +// FTHybridWithCursor represents cursor configuration for hybrid search +type FTHybridWithCursor struct { + Count int // Number of results to return per cursor read + MaxIdle int // Maximum idle time in milliseconds before cursor is automatically deleted +} + +// FTHybridOptions hold options that can be passed to the FT.HYBRID command +type FTHybridOptions struct { + CountExpressions int // Number of search/vector expressions + SearchExpressions []FTHybridSearchExpression // Multiple search expressions + VectorExpressions []FTHybridVectorExpression // Multiple vector expressions + Combine *FTHybridCombineOptions // Fusion step options + Load []string // Projected fields + GroupBy *FTHybridGroupBy // Aggregation grouping + Apply []FTHybridApply // Field transformations + SortBy []FTSearchSortBy // Reuse from FTSearch + Filter string // Post-filter expression + LimitOffset int // Result limiting + Limit int + Params map[string]interface{} // Parameter substitution + ExplainScore bool // Include score explanations + Timeout int // Runtime timeout + WithCursor bool // Enable cursor support for large result sets + WithCursorOptions *FTHybridWithCursor // Cursor configuration options +} + type FTSynDumpResult struct { Term string Synonyms []string @@ -340,9 +484,12 @@ type FTSynDumpCmd struct { val []FTSynDumpResult } +// FTAggregateResult represents the result of an aggregate operation +// NOTE: For RESP3 Total is not reliable (before Redis 8.8) type FTAggregateResult struct { - Total int - Rows []AggregateRow + Total int + Rows []AggregateRow + Warnings []string } type AggregateRow struct { @@ -409,6 +556,14 @@ type FTAttribute struct { PhoneticMatcher string CaseSensitive bool WithSuffixtrie bool + + // Vector specific attributes + Algorithm string + DataType string + Dim int + DistanceMetric string + M int + EFConstruction int } type CursorStats struct { @@ -464,8 +619,9 @@ type SpellCheckSuggestion struct { } type FTSearchResult struct { - Total int - Docs []Document + Total int + Docs []Document + Warnings []string } type Document struct { @@ -474,6 +630,7 @@ type Document struct { Payload *string SortKey *string Fields map[string]string + Error error } type AggregateQuery []interface{} @@ -498,9 +655,112 @@ func (c cmdable) FTAggregate(ctx context.Context, index string, query string) *M return cmd } -func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery { +// validateFTAggregateOptions validates mutually exclusive combinations of +// FTAggregateOptions fields before any command arguments are constructed. +func validateFTAggregateOptions(options *FTAggregateOptions) error { + if len(options.Steps) > 0 { + if options.Load != nil || options.Apply != nil || options.GroupBy != nil || + options.SortBy != nil || options.SortByMax != 0 { + return fmt.Errorf("FT.AGGREGATE: Steps cannot be combined with the deprecated Load, Apply, GroupBy, SortBy and SortByMax fields") + } + if options.LoadAll { + for _, step := range options.Steps { + if step.Load != nil { + return fmt.Errorf("FT.AGGREGATE: LOADALL and LOAD are mutually exclusive") + } + } + } + } + if options.LoadAll && options.Load != nil { + return fmt.Errorf("FT.AGGREGATE: LOADALL and LOAD are mutually exclusive") + } + return nil +} + +// appendFTAggregateStep appends the Redis command arguments for a single +// aggregation pipeline step. Each step must set exactly one of Load, Apply, +// GroupBy or SortBy. +func appendFTAggregateStep(args []interface{}, step FTAggregateStep) ([]interface{}, error) { + set := 0 + if step.Load != nil { + set++ + } + if step.Apply != nil { + set++ + } + if step.GroupBy != nil { + set++ + } + if step.SortBy != nil { + set++ + } + if set != 1 { + return args, fmt.Errorf("FT.AGGREGATE: each step must set exactly one of Load, Apply, GroupBy, SortBy (got %d)", set) + } + + switch { + case step.Load != nil: + args = append(args, "LOAD") + countIdx := len(args) + args = append(args, 0) + count := 0 + args = append(args, step.Load.Field) + count++ + if step.Load.As != "" { + args = append(args, "AS", step.Load.As) + count += 2 + } + args[countIdx] = count + case step.Apply != nil: + args = append(args, "APPLY", step.Apply.Field) + if step.Apply.As != "" { + args = append(args, "AS", step.Apply.As) + } + case step.GroupBy != nil: + args = append(args, "GROUPBY", len(step.GroupBy.Fields)) + args = append(args, step.GroupBy.Fields...) + for _, reducer := range step.GroupBy.Reduce { + args = append(args, "REDUCE", reducer.Reducer.String()) + if reducer.Args != nil { + args = append(args, len(reducer.Args)) + args = append(args, reducer.Args...) + } else { + args = append(args, 0) + } + if reducer.As != "" { + args = append(args, "AS", reducer.As) + } + } + case step.SortBy != nil: + args = append(args, "SORTBY") + sortByOptions := []interface{}{} + for _, sortBy := range step.SortBy.Fields { + if sortBy.Asc && sortBy.Desc { + return args, fmt.Errorf("FT.AGGREGATE: ASC and DESC are mutually exclusive") + } + sortByOptions = append(sortByOptions, sortBy.FieldName) + if sortBy.Asc { + sortByOptions = append(sortByOptions, "ASC") + } + if sortBy.Desc { + sortByOptions = append(sortByOptions, "DESC") + } + } + args = append(args, len(sortByOptions)) + args = append(args, sortByOptions...) + if step.SortBy.Max > 0 { + args = append(args, "MAX", step.SortBy.Max) + } + } + return args, nil +} + +func FTAggregateQuery(query string, options *FTAggregateOptions) (AggregateQuery, error) { queryArgs := []interface{}{query} if options != nil { + if err := validateFTAggregateOptions(options); err != nil { + return nil, err + } if options.Verbatim { queryArgs = append(queryArgs, "VERBATIM") } @@ -513,13 +773,10 @@ func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery queryArgs = append(queryArgs, "ADDSCORES") } - if options.LoadAll && options.Load != nil { - panic("FT.AGGREGATE: LOADALL and LOAD are mutually exclusive") - } if options.LoadAll { queryArgs = append(queryArgs, "LOAD", "*") } - if options.Load != nil { + if len(options.Steps) == 0 && options.Load != nil { queryArgs = append(queryArgs, "LOAD", len(options.Load)) index, count := len(queryArgs)-1, 0 for _, load := range options.Load { @@ -537,53 +794,63 @@ func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery queryArgs = append(queryArgs, "TIMEOUT", options.Timeout) } - for _, apply := range options.Apply { - queryArgs = append(queryArgs, "APPLY", apply.Field) - if apply.As != "" { - queryArgs = append(queryArgs, "AS", apply.As) + if len(options.Steps) > 0 { + for _, step := range options.Steps { + var err error + queryArgs, err = appendFTAggregateStep(queryArgs, step) + if err != nil { + return nil, err + } + } + } else { + for _, apply := range options.Apply { + queryArgs = append(queryArgs, "APPLY", apply.Field) + if apply.As != "" { + queryArgs = append(queryArgs, "AS", apply.As) + } } - } - if options.GroupBy != nil { - for _, groupBy := range options.GroupBy { - queryArgs = append(queryArgs, "GROUPBY", len(groupBy.Fields)) - queryArgs = append(queryArgs, groupBy.Fields...) - - for _, reducer := range groupBy.Reduce { - queryArgs = append(queryArgs, "REDUCE") - queryArgs = append(queryArgs, reducer.Reducer.String()) - if reducer.Args != nil { - queryArgs = append(queryArgs, len(reducer.Args)) - queryArgs = append(queryArgs, reducer.Args...) - } else { - queryArgs = append(queryArgs, 0) - } - if reducer.As != "" { - queryArgs = append(queryArgs, "AS", reducer.As) + if options.GroupBy != nil { + for _, groupBy := range options.GroupBy { + queryArgs = append(queryArgs, "GROUPBY", len(groupBy.Fields)) + queryArgs = append(queryArgs, groupBy.Fields...) + + for _, reducer := range groupBy.Reduce { + queryArgs = append(queryArgs, "REDUCE") + queryArgs = append(queryArgs, reducer.Reducer.String()) + if reducer.Args != nil { + queryArgs = append(queryArgs, len(reducer.Args)) + queryArgs = append(queryArgs, reducer.Args...) + } else { + queryArgs = append(queryArgs, 0) + } + if reducer.As != "" { + queryArgs = append(queryArgs, "AS", reducer.As) + } } } } - } - if options.SortBy != nil { - queryArgs = append(queryArgs, "SORTBY") - sortByOptions := []interface{}{} - for _, sortBy := range options.SortBy { - sortByOptions = append(sortByOptions, sortBy.FieldName) - if sortBy.Asc && sortBy.Desc { - panic("FT.AGGREGATE: ASC and DESC are mutually exclusive") - } - if sortBy.Asc { - sortByOptions = append(sortByOptions, "ASC") - } - if sortBy.Desc { - sortByOptions = append(sortByOptions, "DESC") + if options.SortBy != nil { + queryArgs = append(queryArgs, "SORTBY") + sortByOptions := []interface{}{} + for _, sortBy := range options.SortBy { + sortByOptions = append(sortByOptions, sortBy.FieldName) + if sortBy.Asc && sortBy.Desc { + return nil, fmt.Errorf("FT.AGGREGATE: ASC and DESC are mutually exclusive") + } + if sortBy.Asc { + sortByOptions = append(sortByOptions, "ASC") + } + if sortBy.Desc { + sortByOptions = append(sortByOptions, "DESC") + } } + queryArgs = append(queryArgs, len(sortByOptions)) + queryArgs = append(queryArgs, sortByOptions...) + } + if options.SortByMax > 0 { + queryArgs = append(queryArgs, "MAX", options.SortByMax) } - queryArgs = append(queryArgs, len(sortByOptions)) - queryArgs = append(queryArgs, sortByOptions...) - } - if options.SortByMax > 0 { - queryArgs = append(queryArgs, "MAX", options.SortByMax) } if options.LimitOffset >= 0 && options.Limit > 0 { queryArgs = append(queryArgs, "LIMIT", options.LimitOffset, options.Limit) @@ -615,7 +882,7 @@ func FTAggregateQuery(query string, options *FTAggregateOptions) AggregateQuery queryArgs = append(queryArgs, "DIALECT", 2) } } - return queryArgs + return queryArgs, nil } func ProcessAggregateResult(data []interface{}) (*FTAggregateResult, error) { @@ -657,8 +924,9 @@ func ProcessAggregateResult(data []interface{}) (*FTAggregateResult, error) { func NewAggregateCmd(ctx context.Context, args ...interface{}) *AggregateCmd { return &AggregateCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeAggregate, }, } } @@ -688,15 +956,158 @@ func (cmd *AggregateCmd) String() string { } func (cmd *AggregateCmd) readReply(rd *proto.Reader) (err error) { - data, err := rd.ReadSlice() + readType, err := rd.PeekReplyType() if err != nil { return err } - cmd.val, err = ProcessAggregateResult(data) + + // RESP3 returns a map, RESP2 returns an array + if readType == proto.RespMap { + // Read raw response first for backwards compatibility + cmd.rawVal, err = rd.ReadReply() + if err != nil { + return err + } + // Parse the raw response into structured result + if mapVal, ok := cmd.rawVal.(map[interface{}]interface{}); ok { + cmd.val, err = parseFTAggregateMapRESP3(mapVal) + } else { + return fmt.Errorf("unexpected RESP3 response type: %T", cmd.rawVal) + } + return err + } + + // RESP2 format or error response - use ReadReply to handle errors properly + data, err := rd.ReadReply() if err != nil { return err } - return nil + cmd.rawVal = data // Store raw value for debugging + if dataSlice, ok := data.([]interface{}); ok { + cmd.val, err = ProcessAggregateResult(dataSlice) + return err + } + return fmt.Errorf("unexpected response type: %T", data) +} + +// parseFTAggregateMapRESP3 parses the RESP3 format response from FT.AGGREGATE. +// It takes a map[interface{}]interface{} which is the raw response from ReadReply(). +// RESP3 format: +// +// %5 +// $10 attributes => *0 +// $13 total_results => :N +// $6 format => $6 STRING +// $7 results => *N (array of maps with extra_attributes, values) +// $7 warning => *N (array of strings) +func parseFTAggregateMapRESP3(data map[interface{}]interface{}) (*FTAggregateResult, error) { + result := &FTAggregateResult{ + Rows: make([]AggregateRow, 0), + } + + for k, v := range data { + key, ok := k.(string) + if !ok { + continue + } + + switch key { + case "total_results": + result.Total = internal.ToInteger(v) + case "results": + if resultsData, ok := v.([]interface{}); ok { + rows, err := parseFTAggregateResultsMapRESP3(resultsData) + if err != nil { + return nil, err + } + result.Rows = rows + } + case "warning": + if warningsData, ok := v.([]interface{}); ok { + result.Warnings = make([]string, 0, len(warningsData)) + for _, w := range warningsData { + if ws, ok := w.(string); ok { + result.Warnings = append(result.Warnings, ws) + } + } + } + // Ignore "attributes", "format", and other fields as per the spec + } + } + + return result, nil +} + +// parseFTAggregateResultsMapRESP3 parses the results array from RESP3 FT.AGGREGATE response. +func parseFTAggregateResultsMapRESP3(resultsData []interface{}) ([]AggregateRow, error) { + rows := make([]AggregateRow, 0, len(resultsData)) + for _, item := range resultsData { + if itemMap, ok := item.(map[interface{}]interface{}); ok { + row, err := parseFTAggregateRowMapRESP3(itemMap) + if err != nil { + return nil, err + } + rows = append(rows, row) + } + } + return rows, nil +} + +// parseFTAggregateRowMapRESP3 parses a single row from RESP3 FT.AGGREGATE response. +func parseFTAggregateRowMapRESP3(itemMap map[interface{}]interface{}) (AggregateRow, error) { + row := AggregateRow{ + Fields: make(map[string]interface{}), + } + + for k, v := range itemMap { + key, ok := k.(string) + if !ok { + continue + } + + switch key { + case "extra_attributes": + if extraAttrs, ok := v.(map[interface{}]interface{}); ok { + for ek, ev := range extraAttrs { + if ekStr, ok := ek.(string); ok { + row.Fields[ekStr] = ev + } + } + } + // Ignore "values" and other fields as per the spec + } + } + + return row, nil +} + +func (cmd *AggregateCmd) Clone() Cmder { + var val *FTAggregateResult + if cmd.val != nil { + val = &FTAggregateResult{ + Total: cmd.val.Total, + } + if cmd.val.Rows != nil { + val.Rows = make([]AggregateRow, len(cmd.val.Rows)) + for i, row := range cmd.val.Rows { + val.Rows[i] = AggregateRow{} + if row.Fields != nil { + val.Rows[i].Fields = make(map[string]interface{}, len(row.Fields)) + for k, v := range row.Fields { + val.Rows[i].Fields[k] = v + } + } + } + } + if cmd.val.Warnings != nil { + val.Warnings = make([]string, len(cmd.val.Warnings)) + copy(val.Warnings, cmd.val.Warnings) + } + } + return &AggregateCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } } // FTAggregateWithArgs - Performs a search query on an index and applies a series of aggregate transformations to the result. @@ -707,6 +1118,11 @@ func (cmd *AggregateCmd) readReply(rd *proto.Reader) (err error) { func (c cmdable) FTAggregateWithArgs(ctx context.Context, index string, query string, options *FTAggregateOptions) *AggregateCmd { args := []interface{}{"FT.AGGREGATE", index, query} if options != nil { + if err := validateFTAggregateOptions(options); err != nil { + cmd := NewAggregateCmd(ctx, args...) + cmd.SetErr(err) + return cmd + } if options.Verbatim { args = append(args, "VERBATIM") } @@ -716,13 +1132,10 @@ func (c cmdable) FTAggregateWithArgs(ctx context.Context, index string, query st if options.AddScores { args = append(args, "ADDSCORES") } - if options.LoadAll && options.Load != nil { - panic("FT.AGGREGATE: LOADALL and LOAD are mutually exclusive") - } if options.LoadAll { args = append(args, "LOAD", "*") } - if options.Load != nil { + if len(options.Steps) == 0 && options.Load != nil { args = append(args, "LOAD", len(options.Load)) index, count := len(args)-1, 0 for _, load := range options.Load { @@ -738,52 +1151,66 @@ func (c cmdable) FTAggregateWithArgs(ctx context.Context, index string, query st if options.Timeout > 0 { args = append(args, "TIMEOUT", options.Timeout) } - for _, apply := range options.Apply { - args = append(args, "APPLY", apply.Field) - if apply.As != "" { - args = append(args, "AS", apply.As) - } - } - if options.GroupBy != nil { - for _, groupBy := range options.GroupBy { - args = append(args, "GROUPBY", len(groupBy.Fields)) - args = append(args, groupBy.Fields...) - - for _, reducer := range groupBy.Reduce { - args = append(args, "REDUCE") - args = append(args, reducer.Reducer.String()) - if reducer.Args != nil { - args = append(args, len(reducer.Args)) - args = append(args, reducer.Args...) - } else { - args = append(args, 0) - } - if reducer.As != "" { - args = append(args, "AS", reducer.As) - } + if len(options.Steps) > 0 { + for _, step := range options.Steps { + var err error + args, err = appendFTAggregateStep(args, step) + if err != nil { + cmd := NewAggregateCmd(ctx, args...) + cmd.SetErr(err) + return cmd } } - } - if options.SortBy != nil { - args = append(args, "SORTBY") - sortByOptions := []interface{}{} - for _, sortBy := range options.SortBy { - sortByOptions = append(sortByOptions, sortBy.FieldName) - if sortBy.Asc && sortBy.Desc { - panic("FT.AGGREGATE: ASC and DESC are mutually exclusive") + } else { + for _, apply := range options.Apply { + args = append(args, "APPLY", apply.Field) + if apply.As != "" { + args = append(args, "AS", apply.As) } - if sortBy.Asc { - sortByOptions = append(sortByOptions, "ASC") + } + if options.GroupBy != nil { + for _, groupBy := range options.GroupBy { + args = append(args, "GROUPBY", len(groupBy.Fields)) + args = append(args, groupBy.Fields...) + + for _, reducer := range groupBy.Reduce { + args = append(args, "REDUCE") + args = append(args, reducer.Reducer.String()) + if reducer.Args != nil { + args = append(args, len(reducer.Args)) + args = append(args, reducer.Args...) + } else { + args = append(args, 0) + } + if reducer.As != "" { + args = append(args, "AS", reducer.As) + } + } } - if sortBy.Desc { - sortByOptions = append(sortByOptions, "DESC") + } + if options.SortBy != nil { + args = append(args, "SORTBY") + sortByOptions := []interface{}{} + for _, sortBy := range options.SortBy { + sortByOptions = append(sortByOptions, sortBy.FieldName) + if sortBy.Asc && sortBy.Desc { + cmd := NewAggregateCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.AGGREGATE: ASC and DESC are mutually exclusive")) + return cmd + } + if sortBy.Asc { + sortByOptions = append(sortByOptions, "ASC") + } + if sortBy.Desc { + sortByOptions = append(sortByOptions, "DESC") + } } + args = append(args, len(sortByOptions)) + args = append(args, sortByOptions...) + } + if options.SortByMax > 0 { + args = append(args, "MAX", options.SortByMax) } - args = append(args, len(sortByOptions)) - args = append(args, sortByOptions...) - } - if options.SortByMax > 0 { - args = append(args, "MAX", options.SortByMax) } if options.LimitOffset >= 0 && options.Limit > 0 { args = append(args, "LIMIT", options.LimitOffset, options.Limit) @@ -918,7 +1345,9 @@ func (c cmdable) FTCreate(ctx context.Context, index string, options *FTCreateOp args = append(args, "ON", "JSON") } if options.OnHash && options.OnJSON { - panic("FT.CREATE: ON HASH and ON JSON are mutually exclusive") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: ON HASH and ON JSON are mutually exclusive")) + return cmd } if options.Prefix != nil { args = append(args, "PREFIX", len(options.Prefix)) @@ -969,12 +1398,16 @@ func (c cmdable) FTCreate(ctx context.Context, index string, options *FTCreateOp } } if schema == nil { - panic("FT.CREATE: SCHEMA is required") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: SCHEMA is required")) + return cmd } args = append(args, "SCHEMA") for _, schema := range schema { if schema.FieldName == "" || schema.FieldType == SearchFieldTypeInvalid { - panic("FT.CREATE: SCHEMA FieldName and FieldType are required") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: SCHEMA FieldName and FieldType are required")) + return cmd } args = append(args, schema.FieldName) if schema.As != "" { @@ -983,15 +1416,32 @@ func (c cmdable) FTCreate(ctx context.Context, index string, options *FTCreateOp args = append(args, schema.FieldType.String()) if schema.VectorArgs != nil { if schema.FieldType != SearchFieldTypeVector { - panic("FT.CREATE: SCHEMA FieldType VECTOR is required for VectorArgs") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: SCHEMA FieldType VECTOR is required for VectorArgs")) + return cmd + } + // Check mutual exclusivity of vector options + optionCount := 0 + if schema.VectorArgs.FlatOptions != nil { + optionCount++ } - if schema.VectorArgs.FlatOptions != nil && schema.VectorArgs.HNSWOptions != nil { - panic("FT.CREATE: SCHEMA VectorArgs FlatOptions and HNSWOptions are mutually exclusive") + if schema.VectorArgs.HNSWOptions != nil { + optionCount++ + } + if schema.VectorArgs.VamanaOptions != nil { + optionCount++ + } + if optionCount != 1 { + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: SCHEMA VectorArgs must have exactly one of FlatOptions, HNSWOptions, or VamanaOptions")) + return cmd } if schema.VectorArgs.FlatOptions != nil { args = append(args, "FLAT") if schema.VectorArgs.FlatOptions.Type == "" || schema.VectorArgs.FlatOptions.Dim == 0 || schema.VectorArgs.FlatOptions.DistanceMetric == "" { - panic("FT.CREATE: Type, Dim and DistanceMetric are required for VECTOR FLAT") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: Type, Dim and DistanceMetric are required for VECTOR FLAT")) + return cmd } flatArgs := []interface{}{ "TYPE", schema.VectorArgs.FlatOptions.Type, @@ -1010,7 +1460,9 @@ func (c cmdable) FTCreate(ctx context.Context, index string, options *FTCreateOp if schema.VectorArgs.HNSWOptions != nil { args = append(args, "HNSW") if schema.VectorArgs.HNSWOptions.Type == "" || schema.VectorArgs.HNSWOptions.Dim == 0 || schema.VectorArgs.HNSWOptions.DistanceMetric == "" { - panic("FT.CREATE: Type, Dim and DistanceMetric are required for VECTOR HNSW") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: Type, Dim and DistanceMetric are required for VECTOR HNSW")) + return cmd } hnswArgs := []interface{}{ "TYPE", schema.VectorArgs.HNSWOptions.Type, @@ -1035,10 +1487,48 @@ func (c cmdable) FTCreate(ctx context.Context, index string, options *FTCreateOp args = append(args, len(hnswArgs)) args = append(args, hnswArgs...) } + if schema.VectorArgs.VamanaOptions != nil { + args = append(args, "SVS-VAMANA") + if schema.VectorArgs.VamanaOptions.Type == "" || schema.VectorArgs.VamanaOptions.Dim == 0 || schema.VectorArgs.VamanaOptions.DistanceMetric == "" { + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: Type, Dim and DistanceMetric are required for VECTOR VAMANA")) + return cmd + } + vamanaArgs := []interface{}{ + "TYPE", schema.VectorArgs.VamanaOptions.Type, + "DIM", schema.VectorArgs.VamanaOptions.Dim, + "DISTANCE_METRIC", schema.VectorArgs.VamanaOptions.DistanceMetric, + } + if schema.VectorArgs.VamanaOptions.Compression != "" { + vamanaArgs = append(vamanaArgs, "COMPRESSION", schema.VectorArgs.VamanaOptions.Compression) + } + if schema.VectorArgs.VamanaOptions.ConstructionWindowSize > 0 { + vamanaArgs = append(vamanaArgs, "CONSTRUCTION_WINDOW_SIZE", schema.VectorArgs.VamanaOptions.ConstructionWindowSize) + } + if schema.VectorArgs.VamanaOptions.GraphMaxDegree > 0 { + vamanaArgs = append(vamanaArgs, "GRAPH_MAX_DEGREE", schema.VectorArgs.VamanaOptions.GraphMaxDegree) + } + if schema.VectorArgs.VamanaOptions.SearchWindowSize > 0 { + vamanaArgs = append(vamanaArgs, "SEARCH_WINDOW_SIZE", schema.VectorArgs.VamanaOptions.SearchWindowSize) + } + if schema.VectorArgs.VamanaOptions.Epsilon > 0 { + vamanaArgs = append(vamanaArgs, "EPSILON", schema.VectorArgs.VamanaOptions.Epsilon) + } + if schema.VectorArgs.VamanaOptions.TrainingThreshold > 0 { + vamanaArgs = append(vamanaArgs, "TRAINING_THRESHOLD", schema.VectorArgs.VamanaOptions.TrainingThreshold) + } + if schema.VectorArgs.VamanaOptions.ReduceDim > 0 { + vamanaArgs = append(vamanaArgs, "REDUCE", schema.VectorArgs.VamanaOptions.ReduceDim) + } + args = append(args, len(vamanaArgs)) + args = append(args, vamanaArgs...) + } } if schema.GeoShapeFieldType != "" { if schema.FieldType != SearchFieldTypeGeoShape { - panic("FT.CREATE: SCHEMA FieldType GEOSHAPE is required for GeoShapeFieldType") + cmd := NewStatusCmd(ctx, args...) + cmd.SetErr(fmt.Errorf("FT.CREATE: SCHEMA FieldType GEOSHAPE is required for GeoShapeFieldType")) + return cmd } args = append(args, schema.GeoShapeFieldType) } @@ -1196,100 +1686,325 @@ func (c cmdable) FTExplainWithArgs(ctx context.Context, index string, query stri // FTExplainCli - Returns the execution plan for a complex query. [Not Implemented] // For more information, see https://redis.io/commands/ft.explaincli/ func (c cmdable) FTExplainCli(ctx context.Context, key, path string) error { - panic("not implemented") + return fmt.Errorf("FTExplainCli is not implemented") +} + +// parseFTAttributeFromMap parses an FTAttribute from a RESP3 map format +func parseFTAttributeFromMap(attrMap map[interface{}]interface{}) FTAttribute { + att := FTAttribute{} + for k, v := range attrMap { + key := internal.ToLower(internal.ToString(k)) + switch key { + case "attribute": + att.Attribute = internal.ToString(v) + case "identifier": + att.Identifier = internal.ToString(v) + case "type": + att.Type = internal.ToString(v) + case "weight": + att.Weight = internal.ToFloat(v) + case "phonetic": + att.PhoneticMatcher = internal.ToString(v) + case "algorithm": + att.Algorithm = internal.ToString(v) + case "data_type": + att.DataType = internal.ToString(v) + case "dim": + att.Dim = internal.ToInteger(v) + case "distance_metric": + att.DistanceMetric = internal.ToString(v) + case "m": + att.M = internal.ToInteger(v) + case "ef_construction": + att.EFConstruction = internal.ToInteger(v) + case "flags": + // flags is an array of strings like ["SORTABLE", "NOSTEM"] + if flags, ok := v.([]interface{}); ok { + for _, flag := range flags { + flagStr := internal.ToLower(internal.ToString(flag)) + switch flagStr { + case "nostem": + att.NoStem = true + case "sortable": + att.Sortable = true + case "noindex": + att.NoIndex = true + case "unf": + att.UNF = true + case "case_sensitive": + att.CaseSensitive = true + case "withsuffixtrie": + att.WithSuffixtrie = true + } + } + } + } + } + return att +} + +// getMapStringKey extracts a string value from a map with interface{} keys +func getMapStringKey(m map[interface{}]interface{}, key string) interface{} { + if v, ok := m[key]; ok { + return v + } + return nil +} + +// parseIndexErrorsRESP3 parses Index Errors from RESP3 map format +func parseIndexErrorsRESP3(m map[interface{}]interface{}) IndexErrors { + return IndexErrors{ + IndexingFailures: internal.ToInteger(getMapStringKey(m, "indexing failures")), + LastIndexingError: internal.ToString(getMapStringKey(m, "last indexing error")), + LastIndexingErrorKey: internal.ToString(getMapStringKey(m, "last indexing error key")), + } +} + +// parseCursorStatsRESP3 parses cursor_stats from RESP3 map format +func parseCursorStatsRESP3(m map[interface{}]interface{}) CursorStats { + return CursorStats{ + GlobalIdle: internal.ToInteger(getMapStringKey(m, "global_idle")), + GlobalTotal: internal.ToInteger(getMapStringKey(m, "global_total")), + IndexCapacity: internal.ToInteger(getMapStringKey(m, "index_capacity")), + IndexTotal: internal.ToInteger(getMapStringKey(m, "index_total")), + } +} + +// parseGCStatsRESP3 parses gc_stats from RESP3 map format +func parseGCStatsRESP3(m map[interface{}]interface{}) GCStats { + // Handle average_cycle_time_ms which can be a float64 (including NaN) or string + avgCycleTime := "" + if v := getMapStringKey(m, "average_cycle_time_ms"); v != nil { + switch val := v.(type) { + case string: + // Normalize to lowercase for consistency with RESP2 + avgCycleTime = strings.ToLower(val) + case float64: + avgCycleTime = internal.FormatFloat(val) + } + } + + return GCStats{ + BytesCollected: ftInfoNumInt(getMapStringKey(m, "bytes_collected")), + TotalMsRun: ftInfoNumInt(getMapStringKey(m, "total_ms_run")), + TotalCycles: ftInfoNumInt(getMapStringKey(m, "total_cycles")), + AverageCycleTimeMs: avgCycleTime, + LastRunTimeMs: ftInfoNumInt(getMapStringKey(m, "last_run_time_ms")), + GCNumericTreesMissed: ftInfoNumInt(getMapStringKey(m, "gc_numeric_trees_missed")), + GCBlocksDenied: ftInfoNumInt(getMapStringKey(m, "gc_blocks_denied")), + } +} + +// parseIndexDefinitionRESP3 parses index_definition from RESP3 map format +func parseIndexDefinitionRESP3(m map[interface{}]interface{}) IndexDefinition { + def := IndexDefinition{ + KeyType: internal.ToString(getMapStringKey(m, "key_type")), + DefaultScore: internal.ToFloat(getMapStringKey(m, "default_score")), + } + if prefixes, ok := getMapStringKey(m, "prefixes").([]interface{}); ok { + def.Prefixes = internal.ToStringSlice(prefixes) + } + return def +} + +// parseDialectStatsRESP3 parses dialect_stats from RESP3 map format +func parseDialectStatsRESP3(m map[interface{}]interface{}) map[string]int { + result := make(map[string]int) + for k, v := range m { + if kStr, ok := k.(string); ok { + result[kStr] = internal.ToInteger(v) + } + } + return result +} + +// ftInfoNumString stringifies a value that RediSearch emits via REPLY_KVNUM +// (RedisModule_ReplyWithDouble): a bulk string in RESP2 but a native double +// in RESP3. Used for FTInfoResult fields whose public type is string. +// Special float values (NaN, +Inf, -Inf) are normalized to lowercase to match +// the RESP2 wire format. +func ftInfoNumString(val interface{}) string { + switch v := val.(type) { + case string: + return v + case float64: + return internal.FormatFloat(v) + case float32: + return internal.FormatFloat(float64(v)) + case int64: + return strconv.FormatInt(v, 10) + case int: + return strconv.Itoa(v) + default: + return "" + } +} + +// ftInfoNumInt converts a value that RediSearch emits via REPLY_KVNUM to int. +// In RESP2 the value is a bulk string; in RESP3 it is a native double, even +// for logically-integer fields (counters, byte sizes). This helper exists so +// the internal.ToInteger helper can remain strict about float-to-int coercion +// while still letting the RediSearch parsers read those values correctly. +func ftInfoNumInt(val interface{}) int { + switch v := val.(type) { + case float64: + return int(v) + case float32: + return int(v) + default: + return internal.ToInteger(v) + } } func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) { var ftInfo FTInfoResult - // Manually parse each field from the map + + // Parse Index Errors - handle both RESP2 (array) and RESP3 (map) formats if indexErrors, ok := data["Index Errors"].([]interface{}); ok { + // RESP2 format: array with key-value pairs ftInfo.IndexErrors = IndexErrors{ IndexingFailures: internal.ToInteger(indexErrors[1]), LastIndexingError: internal.ToString(indexErrors[3]), LastIndexingErrorKey: internal.ToString(indexErrors[5]), } + } else if indexErrors, ok := data["Index Errors"].(map[interface{}]interface{}); ok { + // RESP3 format: map + ftInfo.IndexErrors = parseIndexErrorsRESP3(indexErrors) } if attributes, ok := data["attributes"].([]interface{}); ok { for _, attr := range attributes { - if attrMap, ok := attr.([]interface{}); ok { - att := FTAttribute{} - for i := 0; i < len(attrMap); i++ { - if internal.ToLower(internal.ToString(attrMap[i])) == "attribute" { - att.Attribute = internal.ToString(attrMap[i+1]) + att := FTAttribute{} + // Handle RESP2 format: attribute is []interface{} + if attrSlice, ok := attr.([]interface{}); ok { + attrLen := len(attrSlice) + for i := 0; i < attrLen; i++ { + if internal.ToLower(internal.ToString(attrSlice[i])) == "attribute" && i+1 < attrLen { + att.Attribute = internal.ToString(attrSlice[i+1]) + i++ continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "identifier" { - att.Identifier = internal.ToString(attrMap[i+1]) + if internal.ToLower(internal.ToString(attrSlice[i])) == "identifier" && i+1 < attrLen { + att.Identifier = internal.ToString(attrSlice[i+1]) + i++ continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "type" { - att.Type = internal.ToString(attrMap[i+1]) + if internal.ToLower(internal.ToString(attrSlice[i])) == "type" && i+1 < attrLen { + att.Type = internal.ToString(attrSlice[i+1]) + i++ continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "weight" { - att.Weight = internal.ToFloat(attrMap[i+1]) + if internal.ToLower(internal.ToString(attrSlice[i])) == "weight" && i+1 < attrLen { + att.Weight = internal.ToFloat(attrSlice[i+1]) + i++ continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "nostem" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "nostem" { att.NoStem = true continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "sortable" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "sortable" { att.Sortable = true continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "noindex" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "noindex" { att.NoIndex = true continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "unf" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "unf" { att.UNF = true continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "phonetic" { - att.PhoneticMatcher = internal.ToString(attrMap[i+1]) + if internal.ToLower(internal.ToString(attrSlice[i])) == "phonetic" && i+1 < attrLen { + att.PhoneticMatcher = internal.ToString(attrSlice[i+1]) continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "case_sensitive" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "case_sensitive" { att.CaseSensitive = true continue } - if internal.ToLower(internal.ToString(attrMap[i])) == "withsuffixtrie" { + if internal.ToLower(internal.ToString(attrSlice[i])) == "withsuffixtrie" { att.WithSuffixtrie = true continue } + // vector specific attributes + if internal.ToLower(internal.ToString(attrSlice[i])) == "algorithm" && i+1 < attrLen { + att.Algorithm = internal.ToString(attrSlice[i+1]) + i++ + continue + } + if internal.ToLower(internal.ToString(attrSlice[i])) == "data_type" && i+1 < attrLen { + att.DataType = internal.ToString(attrSlice[i+1]) + i++ + continue + } + if internal.ToLower(internal.ToString(attrSlice[i])) == "dim" && i+1 < attrLen { + att.Dim = internal.ToInteger(attrSlice[i+1]) + i++ + continue + } + if internal.ToLower(internal.ToString(attrSlice[i])) == "distance_metric" && i+1 < attrLen { + att.DistanceMetric = internal.ToString(attrSlice[i+1]) + i++ + continue + } + if internal.ToLower(internal.ToString(attrSlice[i])) == "m" && i+1 < attrLen { + att.M = internal.ToInteger(attrSlice[i+1]) + i++ + continue + } + if internal.ToLower(internal.ToString(attrSlice[i])) == "ef_construction" && i+1 < attrLen { + att.EFConstruction = internal.ToInteger(attrSlice[i+1]) + i++ + continue + } } ftInfo.Attributes = append(ftInfo.Attributes, att) + } else if attrMap, ok := attr.(map[interface{}]interface{}); ok { + // Handle RESP3 format: attribute is map[interface{}]interface{} + att = parseFTAttributeFromMap(attrMap) + ftInfo.Attributes = append(ftInfo.Attributes, att) } } } - ftInfo.BytesPerRecordAvg = internal.ToString(data["bytes_per_record_avg"]) + ftInfo.BytesPerRecordAvg = ftInfoNumString(data["bytes_per_record_avg"]) ftInfo.Cleaning = internal.ToInteger(data["cleaning"]) + // Parse cursor_stats - handle both RESP2 (array) and RESP3 (map) formats if cursorStats, ok := data["cursor_stats"].([]interface{}); ok { + // RESP2 format ftInfo.CursorStats = CursorStats{ GlobalIdle: internal.ToInteger(cursorStats[1]), GlobalTotal: internal.ToInteger(cursorStats[3]), IndexCapacity: internal.ToInteger(cursorStats[5]), IndexTotal: internal.ToInteger(cursorStats[7]), } + } else if cursorStats, ok := data["cursor_stats"].(map[interface{}]interface{}); ok { + // RESP3 format + ftInfo.CursorStats = parseCursorStatsRESP3(cursorStats) } + // Parse dialect_stats - handle both RESP2 (array) and RESP3 (map) formats if dialectStats, ok := data["dialect_stats"].([]interface{}); ok { + // RESP2 format ftInfo.DialectStats = make(map[string]int) for i := 0; i < len(dialectStats); i += 2 { ftInfo.DialectStats[internal.ToString(dialectStats[i])] = internal.ToInteger(dialectStats[i+1]) } + } else if dialectStats, ok := data["dialect_stats"].(map[interface{}]interface{}); ok { + // RESP3 format + ftInfo.DialectStats = parseDialectStatsRESP3(dialectStats) } ftInfo.DocTableSizeMB = internal.ToFloat(data["doc_table_size_mb"]) + // Parse field statistics - handle both RESP2 and RESP3 formats if fieldStats, ok := data["field statistics"].([]interface{}); ok { for _, stat := range fieldStats { if statMap, ok := stat.([]interface{}); ok { + // RESP2 format ftInfo.FieldStatistics = append(ftInfo.FieldStatistics, FieldStatistic{ Identifier: internal.ToString(statMap[1]), Attribute: internal.ToString(statMap[3]), @@ -1299,11 +2014,23 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) { LastIndexingErrorKey: internal.ToString(statMap[5].([]interface{})[5]), }, }) + } else if statMap, ok := stat.(map[interface{}]interface{}); ok { + // RESP3 format + fs := FieldStatistic{ + Identifier: internal.ToString(getMapStringKey(statMap, "identifier")), + Attribute: internal.ToString(getMapStringKey(statMap, "attribute")), + } + if indexErrors, ok := getMapStringKey(statMap, "Index Errors").(map[interface{}]interface{}); ok { + fs.IndexErrors = parseIndexErrorsRESP3(indexErrors) + } + ftInfo.FieldStatistics = append(ftInfo.FieldStatistics, fs) } } } + // Parse gc_stats - handle both RESP2 (array) and RESP3 (map) formats if gcStats, ok := data["gc_stats"].([]interface{}); ok { + // RESP2 format ftInfo.GCStats = GCStats{} for i := 0; i < len(gcStats); i += 2 { if internal.ToLower(internal.ToString(gcStats[i])) == "bytes_collected" { @@ -1335,21 +2062,31 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) { continue } } + } else if gcStats, ok := data["gc_stats"].(map[interface{}]interface{}); ok { + // RESP3 format + ftInfo.GCStats = parseGCStatsRESP3(gcStats) } ftInfo.GeoshapesSzMB = internal.ToFloat(data["geoshapes_sz_mb"]) ftInfo.HashIndexingFailures = internal.ToInteger(data["hash_indexing_failures"]) + // Parse index_definition - handle both RESP2 (array) and RESP3 (map) formats if indexDef, ok := data["index_definition"].([]interface{}); ok { + // RESP2 format ftInfo.IndexDefinition = IndexDefinition{ KeyType: internal.ToString(indexDef[1]), Prefixes: internal.ToStringSlice(indexDef[3]), DefaultScore: internal.ToFloat(indexDef[5]), } + } else if indexDef, ok := data["index_definition"].(map[interface{}]interface{}); ok { + // RESP3 format + ftInfo.IndexDefinition = parseIndexDefinitionRESP3(indexDef) } ftInfo.IndexName = internal.ToString(data["index_name"]) - ftInfo.IndexOptions = internal.ToStringSlice(data["index_options"].([]interface{})) + if indexOptions, ok := data["index_options"].([]interface{}); ok { + ftInfo.IndexOptions = internal.ToStringSlice(indexOptions) + } ftInfo.Indexing = internal.ToInteger(data["indexing"]) ftInfo.InvertedSzMB = internal.ToFloat(data["inverted_sz_mb"]) ftInfo.KeyTableSizeMB = internal.ToFloat(data["key_table_size_mb"]) @@ -1358,16 +2095,16 @@ func parseFTInfo(data map[string]interface{}) (FTInfoResult, error) { ftInfo.NumRecords = internal.ToInteger(data["num_records"]) ftInfo.NumTerms = internal.ToInteger(data["num_terms"]) ftInfo.NumberOfUses = internal.ToInteger(data["number_of_uses"]) - ftInfo.OffsetBitsPerRecordAvg = internal.ToString(data["offset_bits_per_record_avg"]) + ftInfo.OffsetBitsPerRecordAvg = ftInfoNumString(data["offset_bits_per_record_avg"]) ftInfo.OffsetVectorsSzMB = internal.ToFloat(data["offset_vectors_sz_mb"]) - ftInfo.OffsetsPerTermAvg = internal.ToString(data["offsets_per_term_avg"]) + ftInfo.OffsetsPerTermAvg = ftInfoNumString(data["offsets_per_term_avg"]) ftInfo.PercentIndexed = internal.ToFloat(data["percent_indexed"]) - ftInfo.RecordsPerDocAvg = internal.ToString(data["records_per_doc_avg"]) + ftInfo.RecordsPerDocAvg = ftInfoNumString(data["records_per_doc_avg"]) ftInfo.SortableValuesSizeMB = internal.ToFloat(data["sortable_values_size_mb"]) ftInfo.TagOverheadSzMB = internal.ToFloat(data["tag_overhead_sz_mb"]) ftInfo.TextOverheadSzMB = internal.ToFloat(data["text_overhead_sz_mb"]) ftInfo.TotalIndexMemorySzMB = internal.ToFloat(data["total_index_memory_sz_mb"]) - ftInfo.TotalIndexingTime = internal.ToInteger(data["total_indexing_time"]) + ftInfo.TotalIndexingTime = ftInfoNumInt(data["total_indexing_time"]) ftInfo.TotalInvertedIndexBlocks = internal.ToInteger(data["total_inverted_index_blocks"]) ftInfo.VectorIndexSzMB = internal.ToFloat(data["vector_index_sz_mb"]) @@ -1382,8 +2119,9 @@ type FTInfoCmd struct { func newFTInfoCmd(ctx context.Context, args ...interface{}) *FTInfoCmd { return &FTInfoCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTInfo, }, } } @@ -1412,6 +2150,37 @@ func (cmd *FTInfoCmd) RawResult() (interface{}, error) { return cmd.rawVal, cmd.err } func (cmd *FTInfoCmd) readReply(rd *proto.Reader) (err error) { + readType, err := rd.PeekReplyType() + if err != nil { + return err + } + + // RESP3 returns a map, RESP2 returns an array + if readType == proto.RespMap { + // Read raw response first for backwards compatibility + cmd.rawVal, err = rd.ReadReply() + if err != nil { + return err + } + + // Convert map[interface{}]interface{} to map[string]interface{} + rawMap, ok := cmd.rawVal.(map[interface{}]interface{}) + if !ok { + return fmt.Errorf("unexpected RESP3 response type: %T", cmd.rawVal) + } + + data := make(map[string]interface{}, len(rawMap)) + for k, v := range rawMap { + if kStr, ok := k.(string); ok { + data[kStr] = v + } + } + + cmd.val, err = parseFTInfo(data) + return err + } + + // RESP2 format - read as map n, err := rd.ReadMapLen() if err != nil { return err @@ -1438,11 +2207,62 @@ func (cmd *FTInfoCmd) readReply(rd *proto.Reader) (err error) { data[k] = v } cmd.val, err = parseFTInfo(data) - if err != nil { - return err + return err +} + +func (cmd *FTInfoCmd) Clone() Cmder { + val := FTInfoResult{ + IndexErrors: cmd.val.IndexErrors, + BytesPerRecordAvg: cmd.val.BytesPerRecordAvg, + Cleaning: cmd.val.Cleaning, + CursorStats: cmd.val.CursorStats, + DocTableSizeMB: cmd.val.DocTableSizeMB, + GCStats: cmd.val.GCStats, + GeoshapesSzMB: cmd.val.GeoshapesSzMB, + HashIndexingFailures: cmd.val.HashIndexingFailures, + IndexDefinition: cmd.val.IndexDefinition, + IndexName: cmd.val.IndexName, + Indexing: cmd.val.Indexing, + InvertedSzMB: cmd.val.InvertedSzMB, + KeyTableSizeMB: cmd.val.KeyTableSizeMB, + MaxDocID: cmd.val.MaxDocID, + NumDocs: cmd.val.NumDocs, + NumRecords: cmd.val.NumRecords, + NumTerms: cmd.val.NumTerms, + NumberOfUses: cmd.val.NumberOfUses, + OffsetBitsPerRecordAvg: cmd.val.OffsetBitsPerRecordAvg, + OffsetVectorsSzMB: cmd.val.OffsetVectorsSzMB, + OffsetsPerTermAvg: cmd.val.OffsetsPerTermAvg, + PercentIndexed: cmd.val.PercentIndexed, + RecordsPerDocAvg: cmd.val.RecordsPerDocAvg, + SortableValuesSizeMB: cmd.val.SortableValuesSizeMB, + TagOverheadSzMB: cmd.val.TagOverheadSzMB, + TextOverheadSzMB: cmd.val.TextOverheadSzMB, + TotalIndexMemorySzMB: cmd.val.TotalIndexMemorySzMB, + TotalIndexingTime: cmd.val.TotalIndexingTime, + TotalInvertedIndexBlocks: cmd.val.TotalInvertedIndexBlocks, + VectorIndexSzMB: cmd.val.VectorIndexSzMB, + } + // Clone slices and maps + if cmd.val.Attributes != nil { + val.Attributes = slices.Clone(cmd.val.Attributes) + } + if cmd.val.DialectStats != nil { + val.DialectStats = maps.Clone(cmd.val.DialectStats) + } + if cmd.val.FieldStatistics != nil { + val.FieldStatistics = slices.Clone(cmd.val.FieldStatistics) + } + if cmd.val.IndexOptions != nil { + val.IndexOptions = slices.Clone(cmd.val.IndexOptions) + } + if cmd.val.IndexDefinition.Prefixes != nil { + val.IndexDefinition.Prefixes = slices.Clone(cmd.val.IndexDefinition.Prefixes) + } + return &FTInfoCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, } - - return nil } // FTInfo - Retrieves information about an index. @@ -1501,8 +2321,9 @@ type FTSpellCheckCmd struct { func newFTSpellCheckCmd(ctx context.Context, args ...interface{}) *FTSpellCheckCmd { return &FTSpellCheckCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSpellCheck, }, } } @@ -1532,15 +2353,117 @@ func (cmd *FTSpellCheckCmd) RawResult() (interface{}, error) { } func (cmd *FTSpellCheckCmd) readReply(rd *proto.Reader) (err error) { - data, err := rd.ReadSlice() + readType, err := rd.PeekReplyType() if err != nil { return err } - cmd.val, err = parseFTSpellCheck(data) + + // RESP3 returns a map, RESP2 returns an array + if readType == proto.RespMap { + // Read raw response first for backwards compatibility + cmd.rawVal, err = rd.ReadReply() + if err != nil { + return err + } + + // Parse the raw response into structured result + rawMap, ok := cmd.rawVal.(map[interface{}]interface{}) + if !ok { + return fmt.Errorf("unexpected RESP3 response type: %T", cmd.rawVal) + } + + cmd.val, err = parseFTSpellCheckRESP3(rawMap) + return err + } + + // RESP2 format + data, err := rd.ReadSlice() if err != nil { return err } - return nil + cmd.val, err = parseFTSpellCheck(data) + return err +} + +// parseFTSpellCheckRESP3 parses the RESP3 format response from FT.SPELLCHECK. +// RESP3 format: +// +// map{ +// "results": map{ +// "misspelled_term": [ +// map{"suggestion": score}, +// ... +// ], +// ... +// } +// } +func parseFTSpellCheckRESP3(data map[interface{}]interface{}) ([]SpellCheckResult, error) { + results := make([]SpellCheckResult, 0) + + resultsData, ok := data["results"] + if !ok { + return results, nil + } + + resultsMap, ok := resultsData.(map[interface{}]interface{}) + if !ok { + return nil, fmt.Errorf("invalid results format: expected map, got %T", resultsData) + } + + for termKey, suggestionsData := range resultsMap { + term, ok := termKey.(string) + if !ok { + continue + } + + suggestionsArray, ok := suggestionsData.([]interface{}) + if !ok { + continue + } + + suggestions := make([]SpellCheckSuggestion, 0, len(suggestionsArray)) + for _, suggestionData := range suggestionsArray { + suggestionMap, ok := suggestionData.(map[interface{}]interface{}) + if !ok { + continue + } + + for suggKey, scoreVal := range suggestionMap { + suggestion, ok := suggKey.(string) + if !ok { + continue + } + + var score float64 + switch v := scoreVal.(type) { + case float64: + score = v + case int64: + score = float64(v) + case string: + var err error + score, err = strconv.ParseFloat(v, 64) + if err != nil { + continue + } + default: + continue + } + + suggestions = append(suggestions, SpellCheckSuggestion{ + Score: score, + Suggestion: suggestion, + }) + } + } + + results = append(results, SpellCheckResult{ + Term: term, + Suggestions: suggestions, + }) + } + + return results, nil } func parseFTSpellCheck(data []interface{}) ([]SpellCheckResult, error) { @@ -1598,6 +2521,25 @@ func parseFTSpellCheck(data []interface{}) ([]SpellCheckResult, error) { return results, nil } +func (cmd *FTSpellCheckCmd) Clone() Cmder { + var val []SpellCheckResult + if cmd.val != nil { + val = make([]SpellCheckResult, len(cmd.val)) + for i, result := range cmd.val { + val[i] = SpellCheckResult{ + Term: result.Term, + } + if result.Suggestions != nil { + val[i].Suggestions = slices.Clone(result.Suggestions) + } + } + } + return &FTSpellCheckCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + func parseFTSearch(data []interface{}, noContent, withScores, withPayloads, withSortKeys bool) (FTSearchResult, error) { if len(data) < 1 { return FTSearchResult{}, fmt.Errorf("unexpected search result format") @@ -1654,7 +2596,13 @@ func parseFTSearch(data []interface{}, noContent, withScores, withPayloads, with if i < len(data) { fields, ok := data[i].([]interface{}) if !ok { - return FTSearchResult{}, fmt.Errorf("invalid document fields format") + if data[i] == proto.Nil || data[i] == nil { + doc.Error = proto.Nil + doc.Fields = map[string]string{} + fields = []interface{}{} + } else { + return FTSearchResult{}, fmt.Errorf("invalid document fields format") + } } for j := 0; j < len(fields); j += 2 { @@ -1688,8 +2636,9 @@ type FTSearchCmd struct { func newFTSearchCmd(ctx context.Context, options *FTSearchOptions, args ...interface{}) *FTSearchCmd { return &FTSearchCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSearch, }, options: options, } @@ -1720,17 +2669,530 @@ func (cmd *FTSearchCmd) RawResult() (interface{}, error) { } func (cmd *FTSearchCmd) readReply(rd *proto.Reader) (err error) { + readType, err := rd.PeekReplyType() + if err != nil { + return err + } + + // RESP3 returns a map, RESP2 returns an array + if readType == proto.RespMap { + // Read raw response first for backwards compatibility + cmd.rawVal, err = rd.ReadReply() + if err != nil { + return err + } + // Parse the raw response into structured result + if mapVal, ok := cmd.rawVal.(map[interface{}]interface{}); ok { + cmd.val, err = parseFTSearchMapRESP3(mapVal) + } else { + return fmt.Errorf("unexpected RESP3 response type: %T", cmd.rawVal) + } + return err + } + + // RESP2 format or error response - use ReadReply to handle errors properly + data, err := rd.ReadReply() + if err != nil { + return err + } + if dataSlice, ok := data.([]interface{}); ok { + cmd.val, err = parseFTSearch(dataSlice, cmd.options.NoContent, cmd.options.WithScores, cmd.options.WithPayloads, cmd.options.WithSortKeys) + return err + } + return fmt.Errorf("unexpected response type: %T", data) +} + +// parseFTSearchMapRESP3 parses the RESP3 format response from FT.SEARCH. +// It takes a map[interface{}]interface{} which is the raw response from ReadReply(). +// RESP3 format: +// +// %5 +// $10 attributes => *0 +// $13 total_results => :N +// $6 format => $6 STRING +// $7 results => *N (array of maps with id, score, extra_attributes, values) +// $7 warning => *N (array of strings) +func parseFTSearchMapRESP3(data map[interface{}]interface{}) (FTSearchResult, error) { + var result FTSearchResult + result.Docs = make([]Document, 0) + + for k, v := range data { + key, ok := k.(string) + if !ok { + continue + } + + switch key { + case "total_results": + result.Total = internal.ToInteger(v) + case "results": + if resultsData, ok := v.([]interface{}); ok { + docs, err := parseFTSearchResultsMapRESP3(resultsData) + if err != nil { + return FTSearchResult{}, err + } + result.Docs = docs + } + case "warning": + if warningsData, ok := v.([]interface{}); ok { + result.Warnings = make([]string, 0, len(warningsData)) + for _, w := range warningsData { + if ws, ok := w.(string); ok { + result.Warnings = append(result.Warnings, ws) + } + } + } + // Ignore "attributes", "format", and other fields as per the spec + } + } + + return result, nil +} + +// parseFTSearchResultsMapRESP3 parses the results array from RESP3 FT.SEARCH response. +func parseFTSearchResultsMapRESP3(resultsData []interface{}) ([]Document, error) { + docs := make([]Document, 0, len(resultsData)) + for _, item := range resultsData { + if itemMap, ok := item.(map[interface{}]interface{}); ok { + doc, err := parseFTSearchDocumentMapRESP3(itemMap) + if err != nil { + return nil, err + } + docs = append(docs, doc) + } + } + return docs, nil +} + +// parseFTSearchDocumentMapRESP3 parses a single document from RESP3 FT.SEARCH response. +func parseFTSearchDocumentMapRESP3(itemMap map[interface{}]interface{}) (Document, error) { + doc := Document{ + Fields: make(map[string]string), + } + + for k, v := range itemMap { + key, ok := k.(string) + if !ok { + continue + } + + switch key { + case "id": + if id, ok := v.(string); ok { + doc.ID = id + } + case "score": + if score, ok := v.(float64); ok { + doc.Score = &score + } + case "payload": + if payload, ok := v.(string); ok { + doc.Payload = &payload + } + case "sortkey": + if sortKey, ok := v.(string); ok { + doc.SortKey = &sortKey + } + case "extra_attributes": + if extraAttrs, ok := v.(map[interface{}]interface{}); ok { + for ek, ev := range extraAttrs { + if ekStr, ok := ek.(string); ok { + if evStr, ok := ev.(string); ok { + doc.Fields[ekStr] = evStr + } + } + } + } + // Ignore "values" and other fields as per the spec + } + } + + return doc, nil +} + +func (cmd *FTSearchCmd) Clone() Cmder { + val := FTSearchResult{ + Total: cmd.val.Total, + } + if cmd.val.Docs != nil { + val.Docs = make([]Document, len(cmd.val.Docs)) + for i, doc := range cmd.val.Docs { + val.Docs[i] = Document{ + ID: doc.ID, + Score: doc.Score, + Payload: doc.Payload, + SortKey: doc.SortKey, + } + if doc.Fields != nil { + val.Docs[i].Fields = make(map[string]string, len(doc.Fields)) + for k, v := range doc.Fields { + val.Docs[i].Fields[k] = v + } + } + } + } + if cmd.val.Warnings != nil { + val.Warnings = make([]string, len(cmd.val.Warnings)) + copy(val.Warnings, cmd.val.Warnings) + } + var options *FTSearchOptions + if cmd.options != nil { + options = &FTSearchOptions{ + NoContent: cmd.options.NoContent, + Verbatim: cmd.options.Verbatim, + NoStopWords: cmd.options.NoStopWords, + WithScores: cmd.options.WithScores, + WithPayloads: cmd.options.WithPayloads, + WithSortKeys: cmd.options.WithSortKeys, + Slop: cmd.options.Slop, + Timeout: cmd.options.Timeout, + InOrder: cmd.options.InOrder, + Language: cmd.options.Language, + Expander: cmd.options.Expander, + Scorer: cmd.options.Scorer, + ExplainScore: cmd.options.ExplainScore, + Payload: cmd.options.Payload, + SortByWithCount: cmd.options.SortByWithCount, + LimitOffset: cmd.options.LimitOffset, + Limit: cmd.options.Limit, + CountOnly: cmd.options.CountOnly, + DialectVersion: cmd.options.DialectVersion, + } + // Clone slices and maps + if cmd.options.Filters != nil { + options.Filters = slices.Clone(cmd.options.Filters) + } + if cmd.options.GeoFilter != nil { + options.GeoFilter = slices.Clone(cmd.options.GeoFilter) + } + if cmd.options.InKeys != nil { + options.InKeys = slices.Clone(cmd.options.InKeys) + } + if cmd.options.InFields != nil { + options.InFields = slices.Clone(cmd.options.InFields) + } + if cmd.options.Return != nil { + options.Return = slices.Clone(cmd.options.Return) + } + if cmd.options.SortBy != nil { + options.SortBy = slices.Clone(cmd.options.SortBy) + } + if cmd.options.Params != nil { + options.Params = maps.Clone(cmd.options.Params) + } + } + return &FTSearchCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + options: options, + } +} + +// FTHybridResult represents the result of a hybrid search operation +type FTHybridResult struct { + TotalResults int + Results []map[string]interface{} + Warnings []string + ExecutionTime float64 +} + +// FTHybridCursorResult represents cursor result for hybrid search +type FTHybridCursorResult struct { + SearchCursorID int + VsimCursorID int +} + +type FTHybridCmd struct { + baseCmd + val FTHybridResult + cursorVal *FTHybridCursorResult + options *FTHybridOptions + withCursor bool +} + +func newFTHybridCmd(ctx context.Context, options *FTHybridOptions, args ...interface{}) *FTHybridCmd { + var withCursor bool + if options != nil && options.WithCursor { + withCursor = true + } + return &FTHybridCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + options: options, + withCursor: withCursor, + } +} + +func (cmd *FTHybridCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *FTHybridCmd) SetVal(val FTHybridResult) { + cmd.val = val +} + +func (cmd *FTHybridCmd) Result() (FTHybridResult, error) { + return cmd.val, cmd.err +} + +func (cmd *FTHybridCmd) CursorResult() (*FTHybridCursorResult, error) { + return cmd.cursorVal, cmd.err +} + +func (cmd *FTHybridCmd) Val() FTHybridResult { + return cmd.val +} + +func (cmd *FTHybridCmd) CursorVal() *FTHybridCursorResult { + return cmd.cursorVal +} + +func (cmd *FTHybridCmd) RawVal() interface{} { + return cmd.rawVal +} + +func (cmd *FTHybridCmd) RawResult() (interface{}, error) { + return cmd.rawVal, cmd.err +} + +func parseFTHybrid(data []interface{}, withCursor bool) (FTHybridResult, *FTHybridCursorResult, error) { + // Convert to map + resultMap := make(map[string]interface{}) + for i := 0; i < len(data); i += 2 { + if i+1 < len(data) { + key, ok := data[i].(string) + if !ok { + return FTHybridResult{}, nil, fmt.Errorf("invalid key type at index %d", i) + } + resultMap[key] = data[i+1] + } + } + + // Handle cursor result + if withCursor { + searchCursorID, ok1 := resultMap["SEARCH"].(int64) + vsimCursorID, ok2 := resultMap["VSIM"].(int64) + if !ok1 || !ok2 { + return FTHybridResult{}, nil, fmt.Errorf("invalid cursor result format") + } + return FTHybridResult{}, &FTHybridCursorResult{ + SearchCursorID: int(searchCursorID), + VsimCursorID: int(vsimCursorID), + }, nil + } + + // Parse regular result + totalResults, ok := resultMap["total_results"].(int64) + if !ok { + return FTHybridResult{}, nil, fmt.Errorf("invalid total_results format") + } + + resultsData, ok := resultMap["results"].([]interface{}) + if !ok { + return FTHybridResult{}, nil, fmt.Errorf("invalid results format") + } + + // Parse each result item + results := make([]map[string]interface{}, 0, len(resultsData)) + for _, item := range resultsData { + // Try parsing as map[string]interface{} first (RESP3 format) + if itemMap, ok := item.(map[string]interface{}); ok { + results = append(results, itemMap) + continue + } + + // Try parsing as map[interface{}]interface{} (alternative RESP3 format) + if rawMap, ok := item.(map[interface{}]interface{}); ok { + itemMap := make(map[string]interface{}) + for k, v := range rawMap { + if keyStr, ok := k.(string); ok { + itemMap[keyStr] = v + } + } + results = append(results, itemMap) + continue + } + + // Fall back to array format (RESP2 format - key-value pairs) + itemData, ok := item.([]interface{}) + if !ok { + return FTHybridResult{}, nil, fmt.Errorf("invalid result item format") + } + + itemMap := make(map[string]interface{}) + for i := 0; i < len(itemData); i += 2 { + if i+1 < len(itemData) { + key, ok := itemData[i].(string) + if !ok { + return FTHybridResult{}, nil, fmt.Errorf("invalid item key format") + } + itemMap[key] = itemData[i+1] + } + } + results = append(results, itemMap) + } + + // Parse warnings (optional field) + var warnings []string + if warningsData, ok := resultMap["warnings"].([]interface{}); ok { + warnings = make([]string, 0, len(warningsData)) + for _, w := range warningsData { + if ws, ok := w.(string); ok { + warnings = append(warnings, ws) + } + } + } + + // Parse execution time (optional field) + var executionTime float64 + if execTimeVal, exists := resultMap["execution_time"]; exists { + switch v := execTimeVal.(type) { + case string: + var err error + executionTime, err = strconv.ParseFloat(v, 64) + if err != nil { + return FTHybridResult{}, nil, fmt.Errorf("invalid execution_time format: %v", err) + } + case float64: + executionTime = v + case int64: + executionTime = float64(v) + } + } + + return FTHybridResult{ + TotalResults: int(totalResults), + Results: results, + Warnings: warnings, + ExecutionTime: executionTime, + }, nil, nil +} + +func (cmd *FTHybridCmd) readReply(rd *proto.Reader) (err error) { data, err := rd.ReadSlice() if err != nil { return err } - cmd.val, err = parseFTSearch(data, cmd.options.NoContent, cmd.options.WithScores, cmd.options.WithPayloads, cmd.options.WithSortKeys) + + result, cursorResult, err := parseFTHybrid(data, cmd.withCursor) if err != nil { return err } + + if cmd.withCursor { + cmd.cursorVal = cursorResult + } else { + cmd.val = result + } return nil } +func (cmd *FTHybridCmd) Clone() Cmder { + val := FTHybridResult{ + TotalResults: cmd.val.TotalResults, + ExecutionTime: cmd.val.ExecutionTime, + } + if cmd.val.Results != nil { + val.Results = make([]map[string]interface{}, len(cmd.val.Results)) + for i, result := range cmd.val.Results { + val.Results[i] = make(map[string]interface{}, len(result)) + for k, v := range result { + val.Results[i][k] = v + } + } + } + if cmd.val.Warnings != nil { + val.Warnings = slices.Clone(cmd.val.Warnings) + } + + var cursorVal *FTHybridCursorResult + if cmd.cursorVal != nil { + cursorVal = &FTHybridCursorResult{ + SearchCursorID: cmd.cursorVal.SearchCursorID, + VsimCursorID: cmd.cursorVal.VsimCursorID, + } + } + + var options *FTHybridOptions + if cmd.options != nil { + options = &FTHybridOptions{ + CountExpressions: cmd.options.CountExpressions, + Load: cmd.options.Load, + Filter: cmd.options.Filter, + LimitOffset: cmd.options.LimitOffset, + Limit: cmd.options.Limit, + ExplainScore: cmd.options.ExplainScore, + Timeout: cmd.options.Timeout, + WithCursor: cmd.options.WithCursor, + } + // Clone slices and maps + if cmd.options.SearchExpressions != nil { + options.SearchExpressions = make([]FTHybridSearchExpression, len(cmd.options.SearchExpressions)) + copy(options.SearchExpressions, cmd.options.SearchExpressions) + } + if cmd.options.VectorExpressions != nil { + options.VectorExpressions = make([]FTHybridVectorExpression, len(cmd.options.VectorExpressions)) + copy(options.VectorExpressions, cmd.options.VectorExpressions) + } + if cmd.options.Combine != nil { + options.Combine = &FTHybridCombineOptions{ + Method: cmd.options.Combine.Method, + Count: cmd.options.Combine.Count, + Window: cmd.options.Combine.Window, + Constant: cmd.options.Combine.Constant, + Alpha: cmd.options.Combine.Alpha, + Beta: cmd.options.Combine.Beta, + YieldScoreAs: cmd.options.Combine.YieldScoreAs, + } + } + if cmd.options.GroupBy != nil { + options.GroupBy = &FTHybridGroupBy{ + Count: cmd.options.GroupBy.Count, + ReduceFunc: cmd.options.GroupBy.ReduceFunc, + ReduceCount: cmd.options.GroupBy.ReduceCount, + } + if cmd.options.GroupBy.Fields != nil { + options.GroupBy.Fields = make([]string, len(cmd.options.GroupBy.Fields)) + copy(options.GroupBy.Fields, cmd.options.GroupBy.Fields) + } + if cmd.options.GroupBy.ReduceParams != nil { + options.GroupBy.ReduceParams = make([]interface{}, len(cmd.options.GroupBy.ReduceParams)) + copy(options.GroupBy.ReduceParams, cmd.options.GroupBy.ReduceParams) + } + } + if cmd.options.Apply != nil { + options.Apply = make([]FTHybridApply, len(cmd.options.Apply)) + copy(options.Apply, cmd.options.Apply) + } + if cmd.options.SortBy != nil { + options.SortBy = make([]FTSearchSortBy, len(cmd.options.SortBy)) + copy(options.SortBy, cmd.options.SortBy) + } + if cmd.options.Params != nil { + options.Params = make(map[string]interface{}, len(cmd.options.Params)) + for k, v := range cmd.options.Params { + options.Params[k] = v + } + } + if cmd.options.WithCursorOptions != nil { + options.WithCursorOptions = &FTHybridWithCursor{ + MaxIdle: cmd.options.WithCursorOptions.MaxIdle, + Count: cmd.options.WithCursorOptions.Count, + } + } + } + + return &FTHybridCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + cursorVal: cursorVal, + options: options, + withCursor: cmd.withCursor, + } +} + // FTSearch - Executes a search query on an index. // The 'index' parameter specifies the index to search, and the 'query' parameter specifies the search query. // For more information, please refer to the Redis documentation about [FT.SEARCH]. @@ -1751,7 +3213,7 @@ type SearchQuery []interface{} // For more information, please refer to the Redis documentation about [FT.SEARCH]. // // [FT.SEARCH]: (https://redis.io/commands/ft.search/) -func FTSearchQuery(query string, options *FTSearchOptions) SearchQuery { +func FTSearchQuery(query string, options *FTSearchOptions) (SearchQuery, error) { queryArgs := []interface{}{query} if options != nil { if options.NoContent { @@ -1831,7 +3293,7 @@ func FTSearchQuery(query string, options *FTSearchOptions) SearchQuery { for _, sortBy := range options.SortBy { queryArgs = append(queryArgs, sortBy.FieldName) if sortBy.Asc && sortBy.Desc { - panic("FT.SEARCH: ASC and DESC are mutually exclusive") + return nil, fmt.Errorf("FT.SEARCH: ASC and DESC are mutually exclusive") } if sortBy.Asc { queryArgs = append(queryArgs, "ASC") @@ -1859,7 +3321,7 @@ func FTSearchQuery(query string, options *FTSearchOptions) SearchQuery { queryArgs = append(queryArgs, "DIALECT", 2) } } - return queryArgs + return queryArgs, nil } // FTSearchWithArgs - Executes a search query on an index with additional options. @@ -1948,7 +3410,9 @@ func (c cmdable) FTSearchWithArgs(ctx context.Context, index string, query strin for _, sortBy := range options.SortBy { args = append(args, sortBy.FieldName) if sortBy.Asc && sortBy.Desc { - panic("FT.SEARCH: ASC and DESC are mutually exclusive") + cmd := newFTSearchCmd(ctx, options, args...) + cmd.SetErr(fmt.Errorf("FT.SEARCH: ASC and DESC are mutually exclusive")) + return cmd } if sortBy.Asc { args = append(args, "ASC") @@ -1988,8 +3452,9 @@ func (c cmdable) FTSearchWithArgs(ctx context.Context, index string, query strin func NewFTSynDumpCmd(ctx context.Context, args ...interface{}) *FTSynDumpCmd { return &FTSynDumpCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeFTSynDump, }, } } @@ -2019,6 +3484,30 @@ func (cmd *FTSynDumpCmd) RawResult() (interface{}, error) { } func (cmd *FTSynDumpCmd) readReply(rd *proto.Reader) error { + readType, err := rd.PeekReplyType() + if err != nil { + return err + } + + // RESP3 returns a map, RESP2 returns an array + if readType == proto.RespMap { + // Read raw response first for backwards compatibility + cmd.rawVal, err = rd.ReadReply() + if err != nil { + return err + } + + // Parse the raw response into structured result + rawMap, ok := cmd.rawVal.(map[interface{}]interface{}) + if !ok { + return fmt.Errorf("unexpected RESP3 response type: %T", cmd.rawVal) + } + + cmd.val, err = parseFTSynDumpRESP3(rawMap) + return err + } + + // RESP2 format termSynonymPairs, err := rd.ReadSlice() if err != nil { return err @@ -2055,6 +3544,64 @@ func (cmd *FTSynDumpCmd) readReply(rd *proto.Reader) error { return nil } +// parseFTSynDumpRESP3 parses the RESP3 format response from FT.SYNDUMP. +// RESP3 format: +// +// map{ +// "term1": ["synonym_group_id1", ...], +// "term2": ["synonym_group_id2", ...], +// ... +// } +func parseFTSynDumpRESP3(data map[interface{}]interface{}) ([]FTSynDumpResult, error) { + results := make([]FTSynDumpResult, 0, len(data)) + + for termKey, synonymsData := range data { + term, ok := termKey.(string) + if !ok { + continue + } + + synonymsArray, ok := synonymsData.([]interface{}) + if !ok { + continue + } + + synonymList := make([]string, 0, len(synonymsArray)) + for _, syn := range synonymsArray { + if synonym, ok := syn.(string); ok { + synonymList = append(synonymList, synonym) + } + } + + results = append(results, FTSynDumpResult{ + Term: term, + Synonyms: synonymList, + }) + } + + return results, nil +} + +func (cmd *FTSynDumpCmd) Clone() Cmder { + var val []FTSynDumpResult + if cmd.val != nil { + val = make([]FTSynDumpResult, len(cmd.val)) + for i, result := range cmd.val { + val[i] = FTSynDumpResult{ + Term: result.Term, + } + if result.Synonyms != nil { + val[i].Synonyms = make([]string, len(result.Synonyms)) + copy(val[i].Synonyms, result.Synonyms) + } + } + } + return &FTSynDumpCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // FTSynDump - Dumps the contents of a synonym group. // The 'index' parameter specifies the index to dump. // For more information, please refer to the Redis documentation: @@ -2101,3 +3648,289 @@ func (c cmdable) FTTagVals(ctx context.Context, index string, field string) *Str _ = c(ctx, cmd) return cmd } + +// FTHybrid - Executes a hybrid search combining full-text search and vector similarity +// The 'index' parameter specifies the index to search, 'searchExpr' is the search query, +// 'vectorField' is the name of the vector field, and 'vectorData' is the vector to search with. +// FTHybrid is still experimental, the command behaviour and signature may change +func (c cmdable) FTHybrid(ctx context.Context, index string, searchExpr string, vectorField string, vectorData Vector) *FTHybridCmd { + options := &FTHybridOptions{ + CountExpressions: 2, + SearchExpressions: []FTHybridSearchExpression{ + {Query: searchExpr}, + }, + VectorExpressions: []FTHybridVectorExpression{ + {VectorField: vectorField, VectorData: vectorData}, + }, + } + return c.FTHybridWithArgs(ctx, index, options) +} + +func hybridVectorBlob(v Vector) (interface{}, error) { + if v == nil { + return nil, fmt.Errorf("FT.HYBRID: vector data is required") + } + + switch vector := v.(type) { + case *VectorFP32: + return hybridVectorBytes(vector.Val) + case *VectorFloat16: + return hybridVectorBytes(vector.Val) + case *VectorBFloat16: + return hybridVectorBytes(vector.Val) + case *VectorFloat64: + return hybridVectorBytes(vector.Val) + case *VectorInt8: + return hybridVectorBytes(vector.Val) + case *VectorUint8: + return hybridVectorBytes(vector.Val) + case *VectorValues, *VectorRef: + return nil, fmt.Errorf("FT.HYBRID: unsupported vector type %T", v) + default: + values := v.Value() + if len(values) < 2 { + return nil, fmt.Errorf("FT.HYBRID: vector Value must contain a blob at index 1") + } + return values[1], nil + } +} + +func hybridVectorBytes(blob []byte) ([]byte, error) { + if len(blob) == 0 { + return nil, fmt.Errorf("FT.HYBRID: vector blob is required") + } + return blob, nil +} + +// generateVectorParamName returns a parameter name that is not already present +// in params. It is used to pass vector data via the PARAMS mechanism when the +// caller does not provide a VectorParamName, since inline vector blobs are no +// longer supported by Redis. +func generateVectorParamName(params map[string]interface{}) string { + for i := 0; ; i++ { + name := fmt.Sprintf("__vector_param_%d", i) + if _, ok := params[name]; !ok { + return name + } + } +} + +// FTHybridWithArgs - Executes a hybrid search with advanced options +// FTHybridWithArgs is still experimental, the command behaviour and signature may change +// +// Vector data is always sent through the PARAMS mechanism, because inline vector +// blobs are no longer supported by Redis. For every vector expression whose +// VectorParamName is empty, a unique name is generated (e.g. "__vector_param_0") +// and the corresponding blob is passed via PARAMS. +// +// options.Params is never mutated: the command is built from a local copy that +// combines the caller-provided params with any generated vector parameters. This +// makes it safe to reuse the same *FTHybridOptions across multiple calls. Generated +// names are also reserved against all explicit VectorParamName values, so they never +// collide with explicit names (even those following the "__vector_param_N" pattern). +func (c cmdable) FTHybridWithArgs(ctx context.Context, index string, options *FTHybridOptions) *FTHybridCmd { + args := []interface{}{"FT.HYBRID", index} + + if options != nil { + // Add search expressions + for _, searchExpr := range options.SearchExpressions { + args = append(args, "SEARCH", searchExpr.Query) + + if searchExpr.Scorer != "" { + args = append(args, "SCORER", searchExpr.Scorer) + if len(searchExpr.ScorerParams) > 0 { + args = append(args, searchExpr.ScorerParams...) + } + } + + if searchExpr.YieldScoreAs != "" { + args = append(args, "YIELD_SCORE_AS", searchExpr.YieldScoreAs) + } + } + + // Vector data is always passed via the PARAMS mechanism (inline vector blobs + // are no longer supported by Redis). When vectors are present, build a local + // copy of the caller-provided params so options.Params is never mutated, and + // pre-reserve any explicit VectorParamName values so generated names never + // collide with them. + params := options.Params + if len(options.VectorExpressions) > 0 { + params = make(map[string]interface{}, len(options.Params)+len(options.VectorExpressions)) + for k, v := range options.Params { + params[k] = v + } + for _, vectorExpr := range options.VectorExpressions { + if vectorExpr.VectorParamName != "" { + params[vectorExpr.VectorParamName] = nil + } + } + } + // Add vector expressions + for _, vectorExpr := range options.VectorExpressions { + args = append(args, "VSIM", "@"+vectorExpr.VectorField) + + vectorBlob, err := hybridVectorBlob(vectorExpr.VectorData) + if err != nil { + cmd := newFTHybridCmd(ctx, options, args...) + cmd.SetErr(err) + return cmd + } + + // When VectorParamName is not provided, generate a unique name. Generated + // names are tracked only in the local params map, never written back to + // options.Params. + paramName := vectorExpr.VectorParamName + if paramName == "" { + paramName = generateVectorParamName(params) + } + args = append(args, "$"+paramName) + params[paramName] = vectorBlob + + if vectorExpr.Method != "" { + args = append(args, vectorExpr.Method) + if len(vectorExpr.MethodParams) > 0 { + // MethodParams should be key-value pairs, count them + args = append(args, len(vectorExpr.MethodParams)) + args = append(args, vectorExpr.MethodParams...) + } + } + + if vectorExpr.Filter != "" { + args = append(args, "FILTER", vectorExpr.Filter) + } + + if vectorExpr.YieldScoreAs != "" { + args = append(args, "YIELD_SCORE_AS", vectorExpr.YieldScoreAs) + } + } + + // Add combine/fusion options + if options.Combine != nil { + // Build combine parameters + combineParams := []interface{}{} + + switch options.Combine.Method { + case FTHybridCombineRRF: + if options.Combine.Window > 0 { + combineParams = append(combineParams, "WINDOW", options.Combine.Window) + } + if options.Combine.Constant > 0 { + combineParams = append(combineParams, "CONSTANT", options.Combine.Constant) + } + case FTHybridCombineLinear: + if options.Combine.Alpha > 0 { + combineParams = append(combineParams, "ALPHA", options.Combine.Alpha) + } + if options.Combine.Beta > 0 { + combineParams = append(combineParams, "BETA", options.Combine.Beta) + } + } + + if options.Combine.YieldScoreAs != "" { + combineParams = append(combineParams, "YIELD_SCORE_AS", options.Combine.YieldScoreAs) + } + + // Add COMBINE with method and parameter count + args = append(args, "COMBINE", string(options.Combine.Method)) + if len(combineParams) > 0 { + args = append(args, len(combineParams)) + args = append(args, combineParams...) + } + } + + // Add LOAD (projected fields) + if len(options.Load) > 0 { + args = append(args, "LOAD", len(options.Load)) + for _, field := range options.Load { + args = append(args, field) + } + } + + // Add GROUPBY + if options.GroupBy != nil { + args = append(args, "GROUPBY", options.GroupBy.Count) + for _, field := range options.GroupBy.Fields { + args = append(args, field) + } + if options.GroupBy.ReduceFunc != "" { + args = append(args, "REDUCE", options.GroupBy.ReduceFunc, options.GroupBy.ReduceCount) + args = append(args, options.GroupBy.ReduceParams...) + } + } + + // Add APPLY transformations + for _, apply := range options.Apply { + args = append(args, "APPLY", apply.Expression, "AS", apply.AsField) + } + + // Add SORTBY + if len(options.SortBy) > 0 { + sortByOptions := []interface{}{} + for _, sortBy := range options.SortBy { + sortByOptions = append(sortByOptions, sortBy.FieldName) + if sortBy.Asc && sortBy.Desc { + cmd := newFTHybridCmd(ctx, options, args...) + cmd.SetErr(fmt.Errorf("FT.HYBRID: ASC and DESC are mutually exclusive")) + return cmd + } + if sortBy.Asc { + sortByOptions = append(sortByOptions, "ASC") + } + if sortBy.Desc { + sortByOptions = append(sortByOptions, "DESC") + } + } + args = append(args, "SORTBY", len(sortByOptions)) + args = append(args, sortByOptions...) + } + + // Add FILTER (post-filter) + if options.Filter != "" { + args = append(args, "FILTER", options.Filter) + } + + // Add LIMIT + if options.LimitOffset >= 0 && options.Limit > 0 || options.LimitOffset > 0 && options.Limit == 0 { + args = append(args, "LIMIT", options.LimitOffset, options.Limit) + } + + // Add PARAMS + // Emit from the local params map, which contains the caller-provided params + // plus any generated vector parameter names. options.Params is left untouched. + if len(params) > 0 { + args = append(args, "PARAMS", len(params)*2) + for key, value := range params { + // PARAMS entries are passed without a '$' prefix; they are referenced in + // the query and clauses using "$". + args = append(args, key, value) + } + } + + // Add EXPLAINSCORE + if options.ExplainScore { + args = append(args, "EXPLAINSCORE") + } + + // Add TIMEOUT + if options.Timeout > 0 { + args = append(args, "TIMEOUT", options.Timeout) + } + + // Add WITHCURSOR support + if options.WithCursor { + args = append(args, "WITHCURSOR") + if options.WithCursorOptions != nil { + if options.WithCursorOptions.Count > 0 { + args = append(args, "COUNT", options.WithCursorOptions.Count) + } + if options.WithCursorOptions.MaxIdle > 0 { + args = append(args, "MAXIDLE", options.WithCursorOptions.MaxIdle) + } + } + } + } + + cmd := newFTHybridCmd(ctx, options, args...) + _ = c(ctx, cmd) + return cmd +} diff --git a/vendor/github.com/redis/go-redis/v9/sentinel.go b/vendor/github.com/redis/go-redis/v9/sentinel.go index cfc848cf0..055b3101f 100644 --- a/vendor/github.com/redis/go-redis/v9/sentinel.go +++ b/vendor/github.com/redis/go-redis/v9/sentinel.go @@ -5,14 +5,20 @@ import ( "crypto/tls" "errors" "fmt" + "math/rand" "net" + "net/url" + "slices" + "strconv" "strings" "sync" "time" + "github.com/redis/go-redis/v9/auth" "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" - "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/maintnotifications" + "github.com/redis/go-redis/v9/push" ) //------------------------------------------------------------------------------ @@ -58,26 +64,80 @@ type FailoverOptions struct { Protocol int Username string Password string - DB int + + // Push notifications are always enabled for RESP3 connections + // CredentialsProvider allows the username and password to be updated + // before reconnecting. It should return the current username and password. + CredentialsProvider func() (username string, password string) + + // CredentialsProviderContext is an enhanced parameter of CredentialsProvider, + // done to maintain API compatibility. In the future, + // there might be a merge between CredentialsProviderContext and CredentialsProvider. + // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + DB int MaxRetries int MinRetryBackoff time.Duration MaxRetryBackoff time.Duration - DialTimeout time.Duration + DialTimeout time.Duration + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + + // DialerRetryBackoff controls the delay between dial retry attempts. + // See Options.DialerRetryBackoff for details. + DialerRetryBackoff func(attempt int) time.Duration + ReadTimeout time.Duration WriteTimeout time.Duration ContextTimeoutEnabled bool + // ReadBufferSize is the size of the bufio.Reader buffer for each connection. + // Larger buffers can improve performance for commands that return large responses. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + ReadBufferSize int + + // WriteBufferSize is the size of the bufio.Writer buffer for each connection. + // Larger buffers can improve performance for large pipelines and commands with many arguments. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + WriteBufferSize int + PoolFIFO bool - PoolSize int - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration + PoolSize int + + // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines. + // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize. + MaxConcurrentDials int + + PoolTimeout time.Duration + MinIdleConns int + MaxIdleConns int + MaxActiveConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + ConnMaxLifetimeJitter time.Duration TLSConfig *tls.Config @@ -94,7 +154,29 @@ type FailoverOptions struct { DisableIdentity bool IdentitySuffix string - UnstableResp3 bool + + // FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing. + // When a node is marked as failing, it will be avoided for this duration. + // Only applies to failover cluster clients. Default is 15 seconds. + FailingTimeoutSeconds int + + // Deprecated: All RediSearch commands now have stable RESP3 parsing and this + // flag is a no-op. It is kept for backwards compatibility and will be removed + // in a future release. + UnstableResp3 bool + + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created for RESP3 connections. + PushNotificationProcessor push.NotificationProcessor + + // MaintNotificationsConfig is not supported for FailoverClients at the moment + // MaintNotificationsConfig provides custom configuration for maintnotifications upgrades. + // When MaintNotificationsConfig.Mode is not "disabled", the client will handle + // upgrade notifications gracefully and manage connection/pool state transitions + // seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, maintnotifications upgrades are disabled. + // (however if Mode is nil, it defaults to "auto" - enable if server supports it) + //MaintNotificationsConfig *maintnotifications.Config } func (opt *FailoverOptions) clientOptions() *Options { @@ -105,36 +187,53 @@ func (opt *FailoverOptions) clientOptions() *Options { Dialer: opt.Dialer, OnConnect: opt.OnConnect, - DB: opt.DB, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, + DB: opt.DB, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, MaxRetries: opt.MaxRetries, MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, - DialTimeout: opt.DialTimeout, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, TLSConfig: opt.TLSConfig, DisableIdentity: opt.DisableIdentity, DisableIndentity: opt.DisableIndentity, - IdentitySuffix: opt.IdentitySuffix, - UnstableResp3: opt.UnstableResp3, + IdentitySuffix: opt.IdentitySuffix, + UnstableResp3: opt.UnstableResp3, + PushNotificationProcessor: opt.PushNotificationProcessor, + + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeDisabled, + }, } } @@ -154,27 +253,42 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options { MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, - DialTimeout: opt.DialTimeout, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + // The sentinel client uses a 4KiB read/write buffer size. + ReadBufferSize: 4096, + WriteBufferSize: 4096, + + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ConnMaxLifetimeJitter: opt.ConnMaxLifetimeJitter, TLSConfig: opt.TLSConfig, DisableIdentity: opt.DisableIdentity, DisableIndentity: opt.DisableIndentity, - IdentitySuffix: opt.IdentitySuffix, - UnstableResp3: opt.UnstableResp3, + IdentitySuffix: opt.IdentitySuffix, + UnstableResp3: opt.UnstableResp3, + PushNotificationProcessor: opt.PushNotificationProcessor, + + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeDisabled, + }, } } @@ -185,44 +299,213 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { Dialer: opt.Dialer, OnConnect: opt.OnConnect, - Protocol: opt.Protocol, - Username: opt.Username, - Password: opt.Password, + Protocol: opt.Protocol, + Username: opt.Username, + Password: opt.Password, + CredentialsProvider: opt.CredentialsProvider, + CredentialsProviderContext: opt.CredentialsProviderContext, + StreamingCredentialsProvider: opt.StreamingCredentialsProvider, MaxRedirects: opt.MaxRetries, + ReadOnly: opt.ReplicaOnly, RouteByLatency: opt.RouteByLatency, RouteRandomly: opt.RouteRandomly, MinRetryBackoff: opt.MinRetryBackoff, MaxRetryBackoff: opt.MaxRetryBackoff, - DialTimeout: opt.DialTimeout, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + + DialTimeout: opt.DialTimeout, + DialerRetries: opt.DialerRetries, + DialerRetryTimeout: opt.DialerRetryTimeout, + DialerRetryBackoff: opt.DialerRetryBackoff, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, TLSConfig: opt.TLSConfig, - DisableIdentity: opt.DisableIdentity, - DisableIndentity: opt.DisableIndentity, + DisableIdentity: opt.DisableIdentity, + DisableIndentity: opt.DisableIndentity, + IdentitySuffix: opt.IdentitySuffix, + FailingTimeoutSeconds: opt.FailingTimeoutSeconds, + PushNotificationProcessor: opt.PushNotificationProcessor, + + MaintNotificationsConfig: &maintnotifications.Config{ + Mode: maintnotifications.ModeDisabled, + }, + } +} + +// ParseFailoverURL parses a URL into FailoverOptions that can be used to connect to Redis. +// The URL must be in the form: +// +// redis://:@:/ +// or +// rediss://:@:/ +// +// To add additional addresses, specify the query parameter, "addr" one or more times. e.g: +// +// redis://:@:/?addr=:&addr=: +// or +// rediss://:@:/?addr=:&addr=: +// +// Most Option fields can be set using query parameters, with the following restrictions: +// - field names are mapped using snake-case conversion: to set MaxRetries, use max_retries +// - only scalar type fields are supported (bool, int, time.Duration) +// - for time.Duration fields, values must be a valid input for time.ParseDuration(); +// additionally a plain integer as value (i.e. without unit) is interpreted as seconds +// - to disable a duration field, use value less than or equal to 0; to use the default +// value, leave the value blank or remove the parameter +// - only the last value is interpreted if a parameter is given multiple times +// - fields "network", "addr", "sentinel_username" and "sentinel_password" can only be set using other +// URL attributes (scheme, host, userinfo, resp.), query parameters using these +// names will be treated as unknown parameters +// - unknown parameter names will result in an error +// - use "skip_verify=true" to ignore TLS certificate validation +// +// Example: +// +// redis://user:password@localhost:6789?master_name=mymaster&dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791 +// is equivalent to: +// &FailoverOptions{ +// MasterName: "mymaster", +// Addr: ["localhost:6789", "localhost:6790", "localhost:6791"] +// DialTimeout: 3 * time.Second, // no time unit = seconds +// ReadTimeout: 6 * time.Second, +// } +func ParseFailoverURL(redisURL string) (*FailoverOptions, error) { + u, err := url.Parse(redisURL) + if err != nil { + return nil, err + } + return setupFailoverConn(u) +} + +func setupFailoverConn(u *url.URL) (*FailoverOptions, error) { + o := &FailoverOptions{} + + o.SentinelUsername, o.SentinelPassword = getUserPassword(u) + + h, p := getHostPortWithDefaults(u) + o.SentinelAddrs = append(o.SentinelAddrs, net.JoinHostPort(h, p)) - IdentitySuffix: opt.IdentitySuffix, + switch u.Scheme { + case "rediss": + o.TLSConfig = &tls.Config{ServerName: h, MinVersion: tls.VersionTLS12} + case "redis": + o.TLSConfig = nil + default: + return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme) } + + f := strings.FieldsFunc(u.Path, func(r rune) bool { + return r == '/' + }) + switch len(f) { + case 0: + o.DB = 0 + case 1: + var err error + if o.DB, err = strconv.Atoi(f[0]); err != nil { + return nil, fmt.Errorf("redis: invalid database number: %q", f[0]) + } + default: + return nil, fmt.Errorf("redis: invalid URL path: %s", u.Path) + } + + return setupFailoverConnParams(u, o) +} + +func setupFailoverConnParams(u *url.URL, o *FailoverOptions) (*FailoverOptions, error) { + q := queryOptions{q: u.Query()} + + o.MasterName = q.string("master_name") + o.ClientName = q.string("client_name") + o.RouteByLatency = q.bool("route_by_latency") + o.RouteRandomly = q.bool("route_randomly") + o.ReplicaOnly = q.bool("replica_only") + o.UseDisconnectedReplicas = q.bool("use_disconnected_replicas") + o.Protocol = q.int("protocol") + o.Username = q.string("username") + o.Password = q.string("password") + o.MaxRetries = q.int("max_retries") + o.MinRetryBackoff = q.duration("min_retry_backoff") + o.MaxRetryBackoff = q.duration("max_retry_backoff") + o.DialTimeout = q.duration("dial_timeout") + o.DialerRetries = q.int("dialer_retries") + o.DialerRetryTimeout = q.duration("dialer_retry_timeout") + o.ReadTimeout = q.duration("read_timeout") + o.WriteTimeout = q.duration("write_timeout") + o.ContextTimeoutEnabled = q.bool("context_timeout_enabled") + o.PoolFIFO = q.bool("pool_fifo") + o.PoolSize = q.int("pool_size") + o.MaxConcurrentDials = q.int("max_concurrent_dials") + o.MinIdleConns = q.int("min_idle_conns") + o.MaxIdleConns = q.int("max_idle_conns") + o.MaxActiveConns = q.int("max_active_conns") + o.ConnMaxLifetime = q.duration("conn_max_lifetime") + if q.has("conn_max_lifetime_jitter") { + o.ConnMaxLifetimeJitter = min(q.duration("conn_max_lifetime_jitter"), o.ConnMaxLifetime) + } + o.ConnMaxIdleTime = q.duration("conn_max_idle_time") + o.PoolTimeout = q.duration("pool_timeout") + o.DisableIdentity = q.bool("disableIdentity") + o.IdentitySuffix = q.string("identitySuffix") + o.UnstableResp3 = q.bool("unstable_resp3") + + if q.err != nil { + return nil, q.err + } + + if tmp := q.string("db"); tmp != "" { + db, err := strconv.Atoi(tmp) + if err != nil { + return nil, fmt.Errorf("redis: invalid database number: %w", err) + } + o.DB = db + } + + addrs := q.strings("addr") + for _, addr := range addrs { + h, p, err := net.SplitHostPort(addr) + if err != nil || h == "" || p == "" { + return nil, fmt.Errorf("redis: unable to parse addr param: %s", addr) + } + + o.SentinelAddrs = append(o.SentinelAddrs, net.JoinHostPort(h, p)) + } + + if o.TLSConfig != nil && q.has("skip_verify") { + o.TLSConfig.InsecureSkipVerify = q.bool("skip_verify") + } + + // any parameters left? + if r := q.remaining(); len(r) > 0 { + return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", ")) + } + + return o, nil } // NewFailoverClient returns a Redis client that uses Redis Sentinel // for automatic failover. It's safe for concurrent use by multiple // goroutines. +// Passing nil FailoverOptions will cause a panic. func NewFailoverClient(failoverOpt *FailoverOptions) *Client { if failoverOpt == nil { panic("redis: NewFailoverClient nil options") @@ -251,24 +534,42 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt.Dialer = masterReplicaDialer(failover) opt.init() - var connPool *pool.ConnPool - rdb := &Client{ baseClient: &baseClient{ - opt: opt, + opt: opt, + onClose: &onCloseHooks{}, }, } rdb.init() - connPool = newConnPool(opt, rdb.dialHook) - rdb.connPool = connPool - rdb.onClose = failover.Close + // Initialize push notification processor using shared helper + // Use void processor by default for RESP2 connections + rdb.pushProcessor = initializePushProcessor(opt) + + // Generate unique pool names for metrics + uniqueID := generateUniqueID() + mainPoolName := opt.Addr + "_" + uniqueID + pubsubPoolName := opt.Addr + "_" + uniqueID + "_pubsub" + + var err error + rdb.connPool, err = newConnPool(opt, rdb.dialHook, mainPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook, pubsubPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } + + rdb.onClose.register(onCloseHookIDSentinelFailover, failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { - _ = connPool.Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != addr - }) + if connPool, ok := rdb.connPool.(*pool.ConnPool); ok { + _ = connPool.Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr + }) + } } failover.mu.Unlock() @@ -313,9 +614,10 @@ func masterReplicaDialer( // SentinelClient is a client for a Redis Sentinel. type SentinelClient struct { *baseClient - hooksMixin } +// NewSentinelClient returns a Redis Sentinel client. +// Passing nil Options will cause a panic. func NewSentinelClient(opt *Options) *SentinelClient { if opt == nil { panic("redis: NewSentinelClient nil options") @@ -323,19 +625,51 @@ func NewSentinelClient(opt *Options) *SentinelClient { opt.init() c := &SentinelClient{ baseClient: &baseClient{ - opt: opt, + opt: opt, + onClose: &onCloseHooks{}, }, } + // Initialize push notification processor using shared helper + // Use void processor for Sentinel clients + c.pushProcessor = NewVoidPushNotificationProcessor() + c.initHooks(hooks{ dial: c.baseClient.dial, process: c.baseClient.process, }) - c.connPool = newConnPool(opt, c.dialHook) + + // Generate unique pool names for metrics + uniqueID := generateUniqueID() + mainPoolName := opt.Addr + "_" + uniqueID + pubsubPoolName := opt.Addr + "_" + uniqueID + "_pubsub" + + var err error + c.connPool, err = newConnPool(opt, c.dialHook, mainPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook, pubsubPoolName) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } return c } +// GetPushNotificationHandler returns the handler for a specific push notification name. +// Returns nil if no handler is registered for the given name. +func (c *SentinelClient) GetPushNotificationHandler(pushNotificationName string) push.NotificationHandler { + return c.pushProcessor.GetHandler(pushNotificationName) +} + +// RegisterPushNotificationHandler registers a handler for a specific push notification name. +// Returns an error if a handler is already registered for this push notification name. +// If protected is true, the handler cannot be unregistered. +func (c *SentinelClient) RegisterPushNotificationHandler(pushNotificationName string, handler push.NotificationHandler, protected bool) error { + return c.pushProcessor.RegisterHandler(pushNotificationName, handler, protected) +} + func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { err := c.processHook(ctx, cmd) cmd.SetErr(err) @@ -345,13 +679,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { func (c *SentinelClient) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil }, - closeConn: c.connPool.CloseConn, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil + }, + pushProcessor: c.pushProcessor, } pubsub.init() + return pubsub } @@ -486,10 +838,10 @@ type sentinelFailover struct { onFailover func(ctx context.Context, addr string) onUpdate func(ctx context.Context) - mu sync.RWMutex - _masterAddr string - sentinel *SentinelClient - pubsub *PubSub + mu sync.RWMutex + masterAddr string + sentinel *SentinelClient + pubsub *PubSub } func (c *sentinelFailover) Close() error { @@ -545,7 +897,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { if sentinel != nil { addr, err := c.getMasterAddr(ctx, sentinel) if err != nil { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if isContextError(ctx.Err()) { return "", err } // Continue on other errors @@ -563,7 +915,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { addr, err := c.getMasterAddr(ctx, c.sentinel) if err != nil { _ = c.closeSentinel() - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if isContextError(ctx.Err()) { return "", err } // Continue on other errors @@ -574,6 +926,11 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { } } + // short circuit if no sentinels configured + if len(c.sentinelAddrs) == 0 { + return "", errors.New("redis: no sentinels configured") + } + var ( masterAddr string wg sync.WaitGroup @@ -597,6 +954,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { errCh <- err return } + once.Do(func() { masterAddr = net.JoinHostPort(addrVal[0], addrVal[1]) // Push working sentinel to the top @@ -605,6 +963,10 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { internal.Logger.Printf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) cancel() }) + + if sentinelCli != c.sentinel { + _ = sentinelCli.Close() + } }(i, sentinelAddr) } @@ -628,7 +990,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo if sentinel != nil { addrs, err := c.getReplicaAddrs(ctx, sentinel) if err != nil { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if isContextError(ctx.Err()) { return nil, err } // Continue on other errors @@ -646,7 +1008,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo addrs, err := c.getReplicaAddrs(ctx, c.sentinel) if err != nil { _ = c.closeSentinel() - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if isContextError(ctx.Err()) { return nil, err } // Continue on other errors @@ -654,8 +1016,16 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil + } else if !useDisconnected { + // No error and no replicas — valid steady state for master-only setups. + // Preserve the sentinel connection for master discovery and failover + // pub/sub monitoring. Only return early when useDisconnected is false; + // when true, fall through to the discovery loop which passes + // useDisconnected to parseReplicaAddrs (getReplicaAddrs hardcodes false). + return []string{}, nil } else { - // No error and no replicas. + // useDisconnected=true: close sentinel so the discovery loop can call + // setSentinel if it finds disconnected replicas. _ = c.closeSentinel() } } @@ -668,7 +1038,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo replicas, err := sentinel.Replicas(ctx, c.opt.MasterName).Result() if err != nil { _ = sentinel.Close() - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if isContextError(ctx.Err()) { return nil, err } internal.Logger.Printf(ctx, "sentinel: Replicas master=%q failed: %s", @@ -688,9 +1058,9 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo } if sentinelReachable { - return []string{}, nil + return nil, nil } - return []string{}, errors.New("redis: all sentinels specified in configuration are unreachable") + return nil, errors.New("redis: all sentinels specified in configuration are unreachable") } func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) (string, error) { @@ -737,7 +1107,7 @@ func parseReplicaAddrs(addrs []map[string]string, keepDisconnected bool) []strin func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { c.mu.RLock() - currentAddr := c._masterAddr //nolint:ifshort + currentAddr := c.masterAddr //nolint:ifshort c.mu.RUnlock() if addr == currentAddr { @@ -747,10 +1117,10 @@ func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { c.mu.Lock() defer c.mu.Unlock() - if addr == c._masterAddr { + if addr == c.masterAddr { return } - c._masterAddr = addr + c.masterAddr = addr internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", c.opt.MasterName, addr) @@ -787,7 +1157,7 @@ func (c *sentinelFailover) discoverSentinels(ctx context.Context) { } if ip != "" && port != "" { sentinelAddr := net.JoinHostPort(ip, port) - if !contains(c.sentinelAddrs, sentinelAddr) { + if !slices.Contains(c.sentinelAddrs, sentinelAddr) { internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q", sentinelAddr, c.opt.MasterName) c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) @@ -821,19 +1191,11 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { } } -func contains(slice []string, str string) bool { - for _, s := range slice { - if s == str { - return true - } - } - return false -} - //------------------------------------------------------------------------------ // NewFailoverClusterClient returns a client that supports routing read-only commands // to a replica node. +// Passing nil FailoverOptions will cause a panic. func NewFailoverClusterClient(failoverOpt *FailoverOptions) *ClusterClient { if failoverOpt == nil { panic("redis: NewFailoverClusterClient nil options") diff --git a/vendor/github.com/redis/go-redis/v9/set_commands.go b/vendor/github.com/redis/go-redis/v9/set_commands.go index cef8ad6d8..2a465728b 100644 --- a/vendor/github.com/redis/go-redis/v9/set_commands.go +++ b/vendor/github.com/redis/go-redis/v9/set_commands.go @@ -1,7 +1,13 @@ package redis -import "context" +import ( + "context" + "github.com/redis/go-redis/v9/internal/hashtag" +) + +// SetCmdable is an interface for Redis set commands. +// Sets are unordered collections of unique strings. type SetCmdable interface { SAdd(ctx context.Context, key string, members ...interface{}) *IntCmd SCard(ctx context.Context, key string) *IntCmd @@ -25,8 +31,12 @@ type SetCmdable interface { SUnionStore(ctx context.Context, destination string, keys ...string) *IntCmd } -//------------------------------------------------------------------------------ - +// Returns the number of elements that were added to the set, not including all +// the elements already present in the set. +// +// For more information about the command please refer to [SADD]. +// +// [SADD]: (https://redis.io/docs/latest/commands/sadd/) func (c cmdable) SAdd(ctx context.Context, key string, members ...interface{}) *IntCmd { args := make([]interface{}, 2, 2+len(members)) args[0] = "sadd" @@ -37,12 +47,25 @@ func (c cmdable) SAdd(ctx context.Context, key string, members ...interface{}) * return cmd } +// Returns the set cardinality (number of elements) of the set stored at key. +// Returns 0 if key does not exist. +// +// For more information about the command please refer to [SCARD]. +// +// [SCARD]: (https://redis.io/docs/latest/commands/scard/) func (c cmdable) SCard(ctx context.Context, key string) *IntCmd { cmd := NewIntCmd(ctx, "scard", key) _ = c(ctx, cmd) return cmd } +// Returns the members of the set resulting from the difference between the first set +// and all the successive sets. +// Keys that do not exist are considered to be empty sets. +// +// For more information about the command please refer to [SDIFF]. +// +// [SDIFF]: (https://redis.io/docs/latest/commands/sdiff/) func (c cmdable) SDiff(ctx context.Context, keys ...string) *StringSliceCmd { args := make([]interface{}, 1+len(keys)) args[0] = "sdiff" @@ -54,6 +77,13 @@ func (c cmdable) SDiff(ctx context.Context, keys ...string) *StringSliceCmd { return cmd } +// Stores the members of the set resulting from the difference between the first set +// and all the successive sets into destination. +// If destination already exists, it is overwritten. +// +// For more information about the command please refer to [SDIFFSTORE]. +// +// [SDIFFSTORE]: (https://redis.io/docs/latest/commands/sdiffstore/) func (c cmdable) SDiffStore(ctx context.Context, destination string, keys ...string) *IntCmd { args := make([]interface{}, 2+len(keys)) args[0] = "sdiffstore" @@ -66,6 +96,13 @@ func (c cmdable) SDiffStore(ctx context.Context, destination string, keys ...str return cmd } +// Returns the members of the set resulting from the intersection of all the given sets. +// Keys that do not exist are considered to be empty sets. +// With one of the keys being an empty set, the resulting set is also empty. +// +// For more information about the command please refer to [SINTER]. +// +// [SINTER]: (https://redis.io/docs/latest/commands/sinter/) func (c cmdable) SInter(ctx context.Context, keys ...string) *StringSliceCmd { args := make([]interface{}, 1+len(keys)) args[0] = "sinter" @@ -77,22 +114,38 @@ func (c cmdable) SInter(ctx context.Context, keys ...string) *StringSliceCmd { return cmd } +// Returns the cardinality of the set resulting from the intersection of all the given sets. +// Keys that do not exist are considered to be empty sets. +// With one of the keys being an empty set, the resulting set is also empty. +// +// The limit parameter sets an upper bound on the number of results returned. +// If limit is 0, no limit is applied. +// +// For more information about the command please refer to [SINTERCARD]. +// +// [SINTERCARD]: (https://redis.io/docs/latest/commands/sintercard/) func (c cmdable) SInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd { - args := make([]interface{}, 4+len(keys)) + numKeys := len(keys) + args := make([]interface{}, 4+numKeys) args[0] = "sintercard" - numkeys := int64(0) + args[1] = numKeys for i, key := range keys { args[2+i] = key - numkeys++ } - args[1] = numkeys - args[2+numkeys] = "limit" - args[3+numkeys] = limit + args[2+numKeys] = "limit" + args[3+numKeys] = limit cmd := NewIntCmd(ctx, args...) _ = c(ctx, cmd) return cmd } +// Stores the members of the set resulting from the intersection of all the given sets +// into destination. +// If destination already exists, it is overwritten. +// +// For more information about the command please refer to [SINTERSTORE]. +// +// [SINTERSTORE]: (https://redis.io/docs/latest/commands/sinterstore/) func (c cmdable) SInterStore(ctx context.Context, destination string, keys ...string) *IntCmd { args := make([]interface{}, 2+len(keys)) args[0] = "sinterstore" @@ -105,13 +158,26 @@ func (c cmdable) SInterStore(ctx context.Context, destination string, keys ...st return cmd } +// Returns if member is a member of the set stored at key. +// Returns true if the element is a member of the set, false if it is not a member +// or if key does not exist. +// +// For more information about the command please refer to [SISMEMBER]. +// +// [SISMEMBER]: (https://redis.io/docs/latest/commands/sismember/) func (c cmdable) SIsMember(ctx context.Context, key string, member interface{}) *BoolCmd { cmd := NewBoolCmd(ctx, "sismember", key, member) _ = c(ctx, cmd) return cmd } -// SMIsMember Redis `SMISMEMBER key member [member ...]` command. +// Returns whether each member is a member of the set stored at key. +// For each member, returns true if the element is a member of the set, false if it is not +// a member or if key does not exist. +// +// For more information about the command please refer to [SMISMEMBER]. +// +// [SMISMEMBER]: (https://redis.io/docs/latest/commands/smismember/) func (c cmdable) SMIsMember(ctx context.Context, key string, members ...interface{}) *BoolSliceCmd { args := make([]interface{}, 2, 2+len(members)) args[0] = "smismember" @@ -122,54 +188,100 @@ func (c cmdable) SMIsMember(ctx context.Context, key string, members ...interfac return cmd } -// SMembers Redis `SMEMBERS key` command output as a slice. +// Returns all the members of the set value stored at key. +// Returns an empty slice if key does not exist. +// +// For more information about the command please refer to [SMEMBERS]. +// +// [SMEMBERS]: (https://redis.io/docs/latest/commands/smembers/) func (c cmdable) SMembers(ctx context.Context, key string) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "smembers", key) _ = c(ctx, cmd) return cmd } -// SMembersMap Redis `SMEMBERS key` command output as a map. +// Returns all the members of the set value stored at key as a map. +// Returns an empty map if key does not exist. +// +// For more information about the command please refer to [SMEMBERS]. +// +// [SMEMBERS]: (https://redis.io/docs/latest/commands/smembers/) func (c cmdable) SMembersMap(ctx context.Context, key string) *StringStructMapCmd { cmd := NewStringStructMapCmd(ctx, "smembers", key) _ = c(ctx, cmd) return cmd } +// Moves member from the set at source to the set at destination. +// This operation is atomic. In every given moment the element will appear to be a member +// of source or destination for other clients. +// +// For more information about the command please refer to [SMOVE]. +// +// [SMOVE]: (https://redis.io/docs/latest/commands/smove/) func (c cmdable) SMove(ctx context.Context, source, destination string, member interface{}) *BoolCmd { cmd := NewBoolCmd(ctx, "smove", source, destination, member) _ = c(ctx, cmd) return cmd } -// SPop Redis `SPOP key` command. +// Removes and returns one or more random members from the set value stored at key. +// This version returns a single random member. +// +// For more information about the command please refer to [SPOP]. +// +// [SPOP]: (https://redis.io/docs/latest/commands/spop/) func (c cmdable) SPop(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "spop", key) _ = c(ctx, cmd) return cmd } -// SPopN Redis `SPOP key count` command. +// Removes and returns one or more random members from the set value stored at key. +// This version returns up to count random members. +// +// For more information about the command please refer to [SPOP]. +// +// [SPOP]: (https://redis.io/docs/latest/commands/spop/) func (c cmdable) SPopN(ctx context.Context, key string, count int64) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "spop", key, count) _ = c(ctx, cmd) return cmd } -// SRandMember Redis `SRANDMEMBER key` command. +// Returns a random member from the set value stored at key. +// This version returns a single random member without removing it. +// +// For more information about the command please refer to [SRANDMEMBER]. +// +// [SRANDMEMBER]: (https://redis.io/docs/latest/commands/srandmember/) func (c cmdable) SRandMember(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "srandmember", key) _ = c(ctx, cmd) return cmd } -// SRandMemberN Redis `SRANDMEMBER key count` command. +// Returns an array of random members from the set value stored at key. +// This version returns up to count random members without removing them. +// When called with a positive count, returns distinct elements. +// When called with a negative count, allows for repeated elements. +// +// For more information about the command please refer to [SRANDMEMBER]. +// +// [SRANDMEMBER]: (https://redis.io/docs/latest/commands/srandmember/) func (c cmdable) SRandMemberN(ctx context.Context, key string, count int64) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "srandmember", key, count) _ = c(ctx, cmd) return cmd } +// Removes the specified members from the set stored at key. +// Specified members that are not a member of this set are ignored. +// If key does not exist, it is treated as an empty set and this command returns 0. +// +// For more information about the command please refer to [SREM]. +// +// [SREM]: (https://redis.io/docs/latest/commands/srem/) func (c cmdable) SRem(ctx context.Context, key string, members ...interface{}) *IntCmd { args := make([]interface{}, 2, 2+len(members)) args[0] = "srem" @@ -180,6 +292,12 @@ func (c cmdable) SRem(ctx context.Context, key string, members ...interface{}) * return cmd } +// Returns the members of the set resulting from the union of all the given sets. +// Keys that do not exist are considered to be empty sets. +// +// For more information about the command please refer to [SUNION]. +// +// [SUNION]: (https://redis.io/docs/latest/commands/sunion/) func (c cmdable) SUnion(ctx context.Context, keys ...string) *StringSliceCmd { args := make([]interface{}, 1+len(keys)) args[0] = "sunion" @@ -191,6 +309,13 @@ func (c cmdable) SUnion(ctx context.Context, keys ...string) *StringSliceCmd { return cmd } +// Stores the members of the set resulting from the union of all the given sets +// into destination. +// If destination already exists, it is overwritten. +// +// For more information about the command please refer to [SUNIONSTORE]. +// +// [SUNIONSTORE]: (https://redis.io/docs/latest/commands/sunionstore/) func (c cmdable) SUnionStore(ctx context.Context, destination string, keys ...string) *IntCmd { args := make([]interface{}, 2+len(keys)) args[0] = "sunionstore" @@ -203,6 +328,17 @@ func (c cmdable) SUnionStore(ctx context.Context, destination string, keys ...st return cmd } +// Incrementally iterates the set elements stored at key. +// This is a cursor-based iterator that allows scanning large sets efficiently. +// +// Parameters: +// - cursor: The cursor value for the iteration (use 0 to start a new scan) +// - match: Optional pattern to match elements (empty string means no pattern) +// - count: Optional hint about how many elements to return per iteration +// +// For more information about the command please refer to [SSCAN]. +// +// [SSCAN]: (https://redis.io/docs/latest/commands/sscan/) func (c cmdable) SScan(ctx context.Context, key string, cursor uint64, match string, count int64) *ScanCmd { args := []interface{}{"sscan", key, cursor} if match != "" { @@ -212,6 +348,9 @@ func (c cmdable) SScan(ctx context.Context, key string, cursor uint64, match str args = append(args, "count", count) } cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(4) + } _ = c(ctx, cmd) return cmd } diff --git a/vendor/github.com/redis/go-redis/v9/sortedset_commands.go b/vendor/github.com/redis/go-redis/v9/sortedset_commands.go index 670140270..b171d7ac1 100644 --- a/vendor/github.com/redis/go-redis/v9/sortedset_commands.go +++ b/vendor/github.com/redis/go-redis/v9/sortedset_commands.go @@ -2,8 +2,11 @@ package redis import ( "context" + "errors" "strings" "time" + + "github.com/redis/go-redis/v9/internal/hashtag" ) type SortedSetCmdable interface { @@ -257,16 +260,15 @@ func (c cmdable) ZInterWithScores(ctx context.Context, store *ZStore) *ZSliceCmd } func (c cmdable) ZInterCard(ctx context.Context, limit int64, keys ...string) *IntCmd { - args := make([]interface{}, 4+len(keys)) + numKeys := len(keys) + args := make([]interface{}, 4+numKeys) args[0] = "zintercard" - numkeys := int64(0) + args[1] = numKeys for i, key := range keys { args[2+i] = key - numkeys++ } - args[1] = numkeys - args[2+numkeys] = "limit" - args[3+numkeys] = limit + args[2+numKeys] = "limit" + args[3+numKeys] = limit cmd := NewIntCmd(ctx, args...) _ = c(ctx, cmd) return cmd @@ -312,7 +314,9 @@ func (c cmdable) ZPopMax(ctx context.Context, key string, count ...int64) *ZSlic case 1: args = append(args, count[0]) default: - panic("too many arguments") + cmd := NewZSliceCmd(ctx) + cmd.SetErr(errors.New("too many arguments")) + return cmd } cmd := NewZSliceCmd(ctx, args...) @@ -332,7 +336,9 @@ func (c cmdable) ZPopMin(ctx context.Context, key string, count ...int64) *ZSlic case 1: args = append(args, count[0]) default: - panic("too many arguments") + cmd := NewZSliceCmd(ctx) + cmd.SetErr(errors.New("too many arguments")) + return cmd } cmd := NewZSliceCmd(ctx, args...) @@ -367,6 +373,17 @@ type ZRangeArgs struct { // } // cmd: "ZRange example-key (3 8 ByScore" (3 < score <= 8). // + // When the Rev option is also provided, should be the higher score value and + // should be the lower score value (i.e. reversed order): + // ZRangeArgs{ + // Key: "example-key", + // Start: 8, + // Stop: "(3", + // ByScore: true, + // Rev: true, + // } + // cmd: "ZRange example-key 8 (3 ByScore Rev" (8 >= score > 3, in reverse order). + // // For the ByLex option, it is similar to the deprecated(6.2.0+) ZRangeByLex command. // You can set the and options as follows: // ZRangeArgs{ @@ -377,6 +394,17 @@ type ZRangeArgs struct { // } // cmd: "ZRange example-key [abc (def ByLex" // + // When the Rev option is also provided, should be the lexicographically higher + // value and should be the lower value: + // ZRangeArgs{ + // Key: "example-key", + // Start: "(def", + // Stop: "[abc", + // ByLex: true, + // Rev: true, + // } + // cmd: "ZRange example-key (def [abc ByLex Rev" + // // For normal cases (ByScore==false && ByLex==false), and should be set to the index range (int). // You can read the documentation for more information: https://redis.io/commands/zrange Start interface{} @@ -394,12 +422,7 @@ type ZRangeArgs struct { } func (z ZRangeArgs) appendArgs(args []interface{}) []interface{} { - // For Rev+ByScore/ByLex, we need to adjust the position of and . - if z.Rev && (z.ByScore || z.ByLex) { - args = append(args, z.Key, z.Stop, z.Start) - } else { - args = append(args, z.Key, z.Start, z.Stop) - } + args = append(args, z.Key, z.Start, z.Stop) if z.ByScore { args = append(args, "byscore") @@ -473,10 +496,16 @@ func (c cmdable) zRangeBy(ctx context.Context, zcmd, key string, opt *ZRangeBy, return cmd } +// ZRangeByScore returns members in a sorted set within a range of scores. +// +// Deprecated: Use ZRangeArgs with ByScore option instead as of Redis 6.2.0. func (c cmdable) ZRangeByScore(ctx context.Context, key string, opt *ZRangeBy) *StringSliceCmd { return c.zRangeBy(ctx, "zrangebyscore", key, opt, false) } +// ZRangeByLex returns members in a sorted set within a lexicographical range. +// +// Deprecated: Use ZRangeArgs with ByLex option instead as of Redis 6.2.0. func (c cmdable) ZRangeByLex(ctx context.Context, key string, opt *ZRangeBy) *StringSliceCmd { return c.zRangeBy(ctx, "zrangebylex", key, opt, false) } @@ -553,6 +582,9 @@ func (c cmdable) ZRemRangeByLex(ctx context.Context, key, min, max string) *IntC return cmd } +// ZRevRange returns members in a sorted set within a range of indexes in reverse order. +// +// Deprecated: Use ZRangeArgs with Rev option instead as of Redis 6.2.0. func (c cmdable) ZRevRange(ctx context.Context, key string, start, stop int64) *StringSliceCmd { cmd := NewStringSliceCmd(ctx, "zrevrange", key, start, stop) _ = c(ctx, cmd) @@ -582,10 +614,16 @@ func (c cmdable) zRevRangeBy(ctx context.Context, zcmd, key string, opt *ZRangeB return cmd } +// ZRevRangeByScore returns members in a sorted set within a range of scores in reverse order. +// +// Deprecated: Use ZRangeArgs with Rev and ByScore options instead as of Redis 6.2.0. func (c cmdable) ZRevRangeByScore(ctx context.Context, key string, opt *ZRangeBy) *StringSliceCmd { return c.zRevRangeBy(ctx, "zrevrangebyscore", key, opt) } +// ZRevRangeByLex returns members in a sorted set within a lexicographical range in reverse order. +// +// Deprecated: Use ZRangeArgs with Rev and ByLex options instead as of Redis 6.2.0. func (c cmdable) ZRevRangeByLex(ctx context.Context, key string, opt *ZRangeBy) *StringSliceCmd { return c.zRevRangeBy(ctx, "zrevrangebylex", key, opt) } @@ -720,6 +758,9 @@ func (c cmdable) ZScan(ctx context.Context, key string, cursor uint64, match str args = append(args, "count", count) } cmd := NewScanCmd(ctx, c, args...) + if hashtag.Present(match) { + cmd.SetFirstKeyPos(4) + } _ = c(ctx, cmd) return cmd } @@ -740,7 +781,7 @@ type ZWithKey struct { type ZStore struct { Keys []string Weights []float64 - // Can be SUM, MIN or MAX. + // Can be SUM, MIN, MAX or COUNT. Aggregate string } diff --git a/vendor/github.com/redis/go-redis/v9/stream_commands.go b/vendor/github.com/redis/go-redis/v9/stream_commands.go index 6d7b22922..71191aec4 100644 --- a/vendor/github.com/redis/go-redis/v9/stream_commands.go +++ b/vendor/github.com/redis/go-redis/v9/stream_commands.go @@ -2,12 +2,18 @@ package redis import ( "context" + "strconv" + "strings" "time" + + "github.com/redis/go-redis/v9/internal/otel" ) type StreamCmdable interface { XAdd(ctx context.Context, a *XAddArgs) *StringCmd + XAckDel(ctx context.Context, stream string, group string, mode string, ids ...string) *SliceCmd XDel(ctx context.Context, stream string, ids ...string) *IntCmd + XDelEx(ctx context.Context, stream string, mode string, ids ...string) *SliceCmd XLen(ctx context.Context, stream string) *IntCmd XRange(ctx context.Context, stream, start, stop string) *XMessageSliceCmd XRangeN(ctx context.Context, stream, start, stop string, count int64) *XMessageSliceCmd @@ -23,20 +29,27 @@ type StreamCmdable interface { XGroupDelConsumer(ctx context.Context, stream, group, consumer string) *IntCmd XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSliceCmd XAck(ctx context.Context, stream, group string, ids ...string) *IntCmd + XNack(ctx context.Context, a *XNackArgs) *IntCmd XPending(ctx context.Context, stream, group string) *XPendingCmd XPendingExt(ctx context.Context, a *XPendingExtArgs) *XPendingExtCmd XClaim(ctx context.Context, a *XClaimArgs) *XMessageSliceCmd XClaimJustID(ctx context.Context, a *XClaimArgs) *StringSliceCmd XAutoClaim(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimCmd + XAutoClaimWithDeleted(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimWithDeletedCmd XAutoClaimJustID(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimJustIDCmd XTrimMaxLen(ctx context.Context, key string, maxLen int64) *IntCmd XTrimMaxLenApprox(ctx context.Context, key string, maxLen, limit int64) *IntCmd + XTrimMaxLenMode(ctx context.Context, key string, maxLen int64, mode string) *IntCmd + XTrimMaxLenApproxMode(ctx context.Context, key string, maxLen, limit int64, mode string) *IntCmd XTrimMinID(ctx context.Context, key string, minID string) *IntCmd XTrimMinIDApprox(ctx context.Context, key string, minID string, limit int64) *IntCmd + XTrimMinIDMode(ctx context.Context, key string, minID string, mode string) *IntCmd + XTrimMinIDApproxMode(ctx context.Context, key string, minID string, limit int64, mode string) *IntCmd XInfoGroups(ctx context.Context, key string) *XInfoGroupsCmd XInfoStream(ctx context.Context, key string) *XInfoStreamCmd XInfoStreamFull(ctx context.Context, key string, count int) *XInfoStreamFullCmd XInfoConsumers(ctx context.Context, key string, group string) *XInfoConsumersCmd + XCfgSet(ctx context.Context, a *XCfgSetArgs) *StatusCmd } // XAddArgs accepts values in the following formats: @@ -46,41 +59,69 @@ type StreamCmdable interface { // // Note that map will not preserve the order of key-value pairs. // MaxLen/MaxLenApprox and MinID are in conflict, only one of them can be used. +// +// For idempotent production (at-most-once production): +// - ProducerID: A unique identifier for the producer (required for both IDMP and IDMPAUTO) +// - IdempotentID: A unique identifier for the message (used with IDMP) +// - IdempotentAuto: If true, Redis will auto-generate an idempotent ID based on message content (IDMPAUTO) +// +// ProducerID and IdempotentID are mutually exclusive with IdempotentAuto. +// When using idempotent production, ID must be "*" or empty. type XAddArgs struct { Stream string NoMkStream bool MaxLen int64 // MAXLEN N MinID string // Approx causes MaxLen and MinID to use "~" matcher (instead of "="). - Approx bool - Limit int64 - ID string - Values interface{} + Approx bool + Limit int64 + Mode string + ID string + Values interface{} + ProducerID string // Producer ID for idempotent production (IDMP or IDMPAUTO) + IdempotentID string // Idempotent ID for IDMP + IdempotentAuto bool // Use IDMPAUTO to auto-generate idempotent ID based on content } func (c cmdable) XAdd(ctx context.Context, a *XAddArgs) *StringCmd { - args := make([]interface{}, 0, 11) + args := make([]interface{}, 0, 15) args = append(args, "xadd", a.Stream) if a.NoMkStream { args = append(args, "nomkstream") } + + if a.Mode != "" { + args = append(args, a.Mode) + } + + if a.ProducerID != "" { + if a.IdempotentAuto { + // IDMPAUTO pid + args = append(args, "idmpauto", a.ProducerID) + } else if a.IdempotentID != "" { + // IDMP pid iid + args = append(args, "idmp", a.ProducerID, a.IdempotentID) + } + } + switch { case a.MaxLen > 0: if a.Approx { args = append(args, "maxlen", "~", a.MaxLen) } else { - args = append(args, "maxlen", a.MaxLen) + args = append(args, "maxlen", "=", a.MaxLen) } case a.MinID != "": if a.Approx { args = append(args, "minid", "~", a.MinID) } else { - args = append(args, "minid", a.MinID) + args = append(args, "minid", "=", a.MinID) } } if a.Limit > 0 { args = append(args, "limit", a.Limit) } + if a.ID != "" { args = append(args, a.ID) } else { @@ -93,6 +134,16 @@ func (c cmdable) XAdd(ctx context.Context, a *XAddArgs) *StringCmd { return cmd } +func (c cmdable) XAckDel(ctx context.Context, stream string, group string, mode string, ids ...string) *SliceCmd { + args := []interface{}{"xackdel", stream, group, mode, "ids", len(ids)} + for _, id := range ids { + args = append(args, id) + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) XDel(ctx context.Context, stream string, ids ...string) *IntCmd { args := []interface{}{"xdel", stream} for _, id := range ids { @@ -103,6 +154,16 @@ func (c cmdable) XDel(ctx context.Context, stream string, ids ...string) *IntCmd return cmd } +func (c cmdable) XDelEx(ctx context.Context, stream string, mode string, ids ...string) *SliceCmd { + args := []interface{}{"xdelex", stream, mode, "ids", len(ids)} + for _, id := range ids { + args = append(args, id) + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) XLen(ctx context.Context, stream string) *IntCmd { cmd := NewIntCmd(ctx, "xlen", stream) _ = c(ctx, cmd) @@ -231,6 +292,7 @@ type XReadGroupArgs struct { Count int64 Block time.Duration NoAck bool + Claim time.Duration // Claim idle pending entries older than this duration } func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSliceCmd { @@ -250,6 +312,10 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic args = append(args, "noack") keyPos++ } + if a.Claim > 0 { + args = append(args, "claim", int64(a.Claim/time.Millisecond)) + keyPos += 2 + } args = append(args, "streams") keyPos++ for _, s := range a.Streams { @@ -262,6 +328,26 @@ func (c cmdable) XReadGroup(ctx context.Context, a *XReadGroupArgs) *XStreamSlic } cmd.SetFirstKeyPos(keyPos) _ = c(ctx, cmd) + + // Record stream lag for each message (if command succeeded) + if cmd.Err() == nil { + streams := cmd.Val() + for _, stream := range streams { + for _, msg := range stream.Messages { + // Parse message ID to extract timestamp (format: "millisecondsTime-sequenceNumber") + if parts := strings.SplitN(msg.ID, "-", 2); len(parts) == 2 { + if timestampMs, err := strconv.ParseInt(parts[0], 10, 64); err == nil { + // Calculate lag (time since message was created) + messageTime := time.Unix(0, timestampMs*int64(time.Millisecond)) + lag := time.Since(messageTime) + // Record lag metric + otel.RecordStreamLag(ctx, lag, nil, stream.Stream, a.Group, a.Consumer) + } + } + } + } + } + return cmd } @@ -275,6 +361,71 @@ func (c cmdable) XAck(ctx context.Context, stream, group string, ids ...string) return cmd } +// XNACK modes. See [XNackArgs.Mode]. +const ( + XNackModeSilent = "SILENT" + XNackModeFail = "FAIL" + XNackModeFatal = "FATAL" +) + +// XNackArgs represents the arguments for the XNACK command (Redis >= 8.8). +// +// XNACK negatively acknowledges one or more messages in a consumer group's +// Pending Entries List (PEL), releasing them back to the group so they can be +// redelivered to another consumer via XREADGROUP. +type XNackArgs struct { + Stream string + Group string + + // Mode controls how the delivery counter is adjusted for each NACKed entry. + // Must be one of [XNackModeSilent], [XNackModeFail], or [XNackModeFatal]: + // - SILENT: the consumer is shutting down or experiencing internal errors + // unrelated to the message. The delivery counter is decremented by 1, + // undoing the increment that happened when the message was delivered. + // - FAIL: the consumer could not process the message (e.g. insufficient + // memory), but another consumer might succeed. The delivery counter is + // left unchanged. + // - FATAL: the message is invalid or suspected malicious. The delivery + // counter is set to MAXINT, which will immediately move the message to + // the Dead Letter Queue (DLQ) if one is configured for the group. + Mode string + + // IDs is the list of message IDs to NACK. All IDs must already be in the + // group's PEL (i.e. previously delivered via XREADGROUP), unless Force is set. + IDs []string + + // RetryCount sets the delivery counter to an explicit value, overriding the + // counter adjustment that would otherwise be applied by Mode. + // Leave nil to let Mode control the counter (the common case). + RetryCount *uint64 + + // Force allows NACKing message IDs that are not yet in the group's PEL, + // creating new unowned NACKed PEL entries for them directly. + // This is analogous to the FORCE flag in XCLAIM. + // Primarily used internally by Redis during AOF rewrite to reconstruct + // NACKed entries, but can also be used to manually inject entries. + Force bool +} + +// XNack executes the XNACK command. See [XNackArgs] for the full argument documentation. +// Requires Redis >= 8.8. +func (c cmdable) XNack(ctx context.Context, a *XNackArgs) *IntCmd { + args := make([]interface{}, 0, 9+len(a.IDs)) + args = append(args, "xnack", a.Stream, a.Group, a.Mode, "ids", len(a.IDs)) + for _, id := range a.IDs { + args = append(args, id) + } + if a.RetryCount != nil { + args = append(args, "retrycount", *a.RetryCount) + } + if a.Force { + args = append(args, "force") + } + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) XPending(ctx context.Context, stream, group string) *XPendingCmd { cmd := NewXPendingCmd(ctx, "xpending", stream, group) _ = c(ctx, cmd) @@ -322,6 +473,13 @@ func (c cmdable) XAutoClaim(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimC return cmd } +func (c cmdable) XAutoClaimWithDeleted(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimWithDeletedCmd { + args := xAutoClaimArgs(ctx, a) + cmd := NewXAutoClaimWithDeletedCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) XAutoClaimJustID(ctx context.Context, a *XAutoClaimArgs) *XAutoClaimJustIDCmd { args := xAutoClaimArgs(ctx, a) args = append(args, "justid") @@ -375,6 +533,8 @@ func xClaimArgs(a *XClaimArgs) []interface{} { return args } +// TODO: refactor xTrim, xTrimMode and the wrappers over the functions + // xTrim If approx is true, add the "~" parameter, otherwise it is the default "=" (redis default). // example: // @@ -390,6 +550,8 @@ func (c cmdable) xTrim( args = append(args, "xtrim", key, strategy) if approx { args = append(args, "~") + } else { + args = append(args, "=") } args = append(args, threshold) if limit > 0 { @@ -418,6 +580,44 @@ func (c cmdable) XTrimMinIDApprox(ctx context.Context, key string, minID string, return c.xTrim(ctx, key, "minid", true, minID, limit) } +func (c cmdable) xTrimMode( + ctx context.Context, key, strategy string, + approx bool, threshold interface{}, limit int64, + mode string, +) *IntCmd { + args := make([]interface{}, 0, 7) + args = append(args, "xtrim", key, strategy) + if approx { + args = append(args, "~") + } else { + args = append(args, "=") + } + args = append(args, threshold) + if limit > 0 { + args = append(args, "limit", limit) + } + args = append(args, mode) + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func (c cmdable) XTrimMaxLenMode(ctx context.Context, key string, maxLen int64, mode string) *IntCmd { + return c.xTrimMode(ctx, key, "maxlen", false, maxLen, 0, mode) +} + +func (c cmdable) XTrimMaxLenApproxMode(ctx context.Context, key string, maxLen, limit int64, mode string) *IntCmd { + return c.xTrimMode(ctx, key, "maxlen", true, maxLen, limit, mode) +} + +func (c cmdable) XTrimMinIDMode(ctx context.Context, key string, minID string, mode string) *IntCmd { + return c.xTrimMode(ctx, key, "minid", false, minID, 0, mode) +} + +func (c cmdable) XTrimMinIDApproxMode(ctx context.Context, key string, minID string, limit int64, mode string) *IntCmd { + return c.xTrimMode(ctx, key, "minid", true, minID, limit, mode) +} + func (c cmdable) XInfoConsumers(ctx context.Context, key string, group string) *XInfoConsumersCmd { cmd := NewXInfoConsumersCmd(ctx, key, group) _ = c(ctx, cmd) @@ -448,3 +648,28 @@ func (c cmdable) XInfoStreamFull(ctx context.Context, key string, count int) *XI _ = c(ctx, cmd) return cmd } + +// XCfgSetArgs represents the arguments for the XCFGSET command. +// Duration is the duration, in seconds, that Redis keeps each idempotent ID. +// MaxSize is the maximum number of most recent idempotent IDs that Redis keeps for each producer ID. +type XCfgSetArgs struct { + Stream string + Duration int64 + MaxSize int64 +} + +// XCfgSet sets the idempotent production configuration for a stream. +// XCFGSET key [IDMP-DURATION duration] [IDMP-MAXSIZE maxsize] +func (c cmdable) XCfgSet(ctx context.Context, a *XCfgSetArgs) *StatusCmd { + args := make([]interface{}, 0, 6) + args = append(args, "xcfgset", a.Stream) + if a.Duration > 0 { + args = append(args, "idmp-duration", a.Duration) + } + if a.MaxSize > 0 { + args = append(args, "idmp-maxsize", a.MaxSize) + } + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} diff --git a/vendor/github.com/redis/go-redis/v9/string_commands.go b/vendor/github.com/redis/go-redis/v9/string_commands.go index eff5880dc..88a80844e 100644 --- a/vendor/github.com/redis/go-redis/v9/string_commands.go +++ b/vendor/github.com/redis/go-redis/v9/string_commands.go @@ -2,6 +2,8 @@ package redis import ( "context" + "fmt" + "strings" "time" ) @@ -9,6 +11,8 @@ type StringCmdable interface { Append(ctx context.Context, key, value string) *IntCmd Decr(ctx context.Context, key string) *IntCmd DecrBy(ctx context.Context, key string, decrement int64) *IntCmd + DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd + Digest(ctx context.Context, key string) *DigestCmd Get(ctx context.Context, key string) *StringCmd GetRange(ctx context.Context, key string, start, end int64) *StringCmd GetSet(ctx context.Context, key string, value interface{}) *StringCmd @@ -17,13 +21,24 @@ type StringCmdable interface { Incr(ctx context.Context, key string) *IntCmd IncrBy(ctx context.Context, key string, value int64) *IntCmd IncrByFloat(ctx context.Context, key string, value float64) *FloatCmd + IncrEXInt(ctx context.Context, key string, args IncrEXIntArgs) *IncrEXIntCmd + IncrEXFloat(ctx context.Context, key string, args IncrEXFloatArgs) *IncrEXFloatCmd LCS(ctx context.Context, q *LCSQuery) *LCSCmd MGet(ctx context.Context, keys ...string) *SliceCmd MSet(ctx context.Context, values ...interface{}) *StatusCmd MSetNX(ctx context.Context, values ...interface{}) *BoolCmd + MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd SetArgs(ctx context.Context, key string, value interface{}, a SetArgs) *StatusCmd SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd + SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd + SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd + SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd + SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd + SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd + SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd + SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd + SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd SetXX(ctx context.Context, key string, value interface{}, expiration time.Duration) *BoolCmd SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd @@ -48,6 +63,76 @@ func (c cmdable) DecrBy(ctx context.Context, key string, decrement int64) *IntCm return cmd } +// DelExArgs provides arguments for the DelExArgs function. +type DelExArgs struct { + // Mode can be `IFEQ`, `IFNE`, `IFDEQ`, or `IFDNE`. + Mode string + + // MatchValue is used with IFEQ/IFNE modes for compare-and-delete operations. + // - IFEQ: only delete if current value equals MatchValue + // - IFNE: only delete if current value does not equal MatchValue + MatchValue interface{} + + // MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-delete. + // - IFDEQ: only delete if current value's digest equals MatchDigest + // - IFDNE: only delete if current value's digest does not equal MatchDigest + // + // The digest is a uint64 xxh3 hash value. + // + // For examples of client-side digest generation, see: + // example/digest-optimistic-locking/ + MatchDigest uint64 +} + +// DelExArgs Redis `DELEX key [IFEQ|IFNE|IFDEQ|IFDNE] match-value` command. +// Compare-and-delete with flexible conditions. +// +// Returns the number of keys that were removed (0 or 1). +// +// NOTE DelExArgs is still experimental +// it's signature and behaviour may change +func (c cmdable) DelExArgs(ctx context.Context, key string, a DelExArgs) *IntCmd { + args := []interface{}{"delex", key} + + if a.Mode != "" { + args = append(args, a.Mode) + + // Add match value/digest based on mode + switch a.Mode { + case "ifeq", "IFEQ", "ifne", "IFNE": + if a.MatchValue != nil { + args = append(args, a.MatchValue) + } + case "ifdeq", "IFDEQ", "ifdne", "IFDNE": + if a.MatchDigest != 0 { + args = append(args, fmt.Sprintf("%016x", a.MatchDigest)) + } + } + } + + cmd := NewIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// Digest returns the xxh3 hash (uint64) of the specified key's value. +// +// The digest is a 64-bit xxh3 hash that can be used for optimistic locking +// with SetIFDEQ, SetIFDNE, and DelExArgs commands. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/digest/ +// +// NOTE Digest is still experimental +// it's signature and behaviour may change +func (c cmdable) Digest(ctx context.Context, key string) *DigestCmd { + cmd := NewDigestCmd(ctx, "digest", key) + _ = c(ctx, cmd) + return cmd +} + // Get Redis `GET key` command. It returns redis.Nil error when key does not exist. func (c cmdable) Get(ctx context.Context, key string) *StringCmd { cmd := NewStringCmd(ctx, "get", key) @@ -61,6 +146,9 @@ func (c cmdable) GetRange(ctx context.Context, key string, start, end int64) *St return cmd } +// GetSet returns the old value stored at key and sets it to the new value. +// +// Deprecated: Use SetArgs with Get option instead as of Redis 6.2.0. func (c cmdable) GetSet(ctx context.Context, key string, value interface{}) *StringCmd { cmd := NewStringCmd(ctx, "getset", key, value) _ = c(ctx, cmd) @@ -112,6 +200,160 @@ func (c cmdable) IncrByFloat(ctx context.Context, key string, value float64) *Fl return cmd } +// IncrEXIntArgs are the arguments to IncrEXInt (the BYINT variant of INCREX). +// +// If By is zero and HasBy is false, the server increments by 1. +// HasLBound/HasUBound gate the optional LBOUND/UBOUND clauses so that 0 is a +// valid bound. Expiration is shared with the SET command via ExpirationOption. +type IncrEXIntArgs struct { + By int64 + HasBy bool + + LBound, UBound int64 + HasLBound, HasUBound bool + + // Saturate clamps the result to LBOUND/UBOUND (or LLONG_MAX/MIN when no + // explicit bound is given) when the increment would exceed it. Without + // this flag, out-of-bounds operations are rejected: the key and TTL are + // left unchanged and the reply is [current_value, 0]. + Saturate bool + + // Expiration sets the TTL semantics: EX, PX, EXAT, PXAT, or PERSIST. + Expiration *ExpirationOption + + // ENX applies the expiration only when the key does not already have an + // expiration. Requires Expiration to set one of EX/PX/EXAT/PXAT. + ENX bool +} + +// IncrEXFloatArgs are the arguments to IncrEXFloat (the BYFLOAT variant of +// INCREX). BYFLOAT is always sent — even when By is zero — to keep the +// operation in float mode on the server side; omitting BYFLOAT would cause +// the server to treat the call as an integer increment by 1. +// HasLBound/HasUBound gate the optional LBOUND/UBOUND clauses so that 0 is +// a valid bound. +type IncrEXFloatArgs struct { + By float64 + + LBound, UBound float64 + HasLBound, HasUBound bool + + // Saturate clamps the result to LBOUND/UBOUND (or ±LDBL_MAX when no + // explicit bound is given) when the increment would exceed it. Without + // this flag, out-of-bounds operations are rejected: the key and TTL are + // left unchanged and the reply is [current_value, 0]. + Saturate bool + + Expiration *ExpirationOption + + ENX bool +} + +// IncrEXInt Redis `INCREX key [BYINT amount] [LBOUND value] [UBOUND value] +// [SATURATE] [EX seconds | PX ms | EXAT ts | PXAT ts | PERSIST] [ENX]` +// command. +// +// Atomically increments the integer value stored at key, optionally +// constraining the result to a range and applying expiration semantics. +// Returns the new value and the increment that was actually applied. When +// the increment would exceed LBOUND/UBOUND and SATURATE is not set, the key +// and TTL are left unchanged and the reply is [current_value, 0]. +// +// Available since Redis 8.8. +// For more information, see https://redis.io/commands/increx +func (c cmdable) IncrEXInt(ctx context.Context, key string, a IncrEXIntArgs) *IncrEXIntCmd { + args := make([]interface{}, 0, 14) + args = append(args, "increx", key) + if a.HasBy { + args = append(args, "byint", a.By) + } + if a.HasLBound { + args = append(args, "lbound", a.LBound) + } + if a.HasUBound { + args = append(args, "ubound", a.UBound) + } + if a.Saturate { + args = append(args, "saturate") + } + args = appendIncrEXTail(args, a.Expiration, a.ENX) + + cmd := NewIncrEXIntCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// IncrEXFloat Redis `INCREX key [BYFLOAT amount] [LBOUND value] [UBOUND value] +// [SATURATE] [EX seconds | PX ms | EXAT ts | PXAT ts | PERSIST] [ENX]` +// command. +// +// Available since Redis 8.8. +// For more information, see https://redis.io/commands/increx +func (c cmdable) IncrEXFloat(ctx context.Context, key string, a IncrEXFloatArgs) *IncrEXFloatCmd { + args := make([]interface{}, 0, 14) + args = append(args, "increx", key, "byfloat", a.By) + if a.HasLBound { + args = append(args, "lbound", a.LBound) + } + if a.HasUBound { + args = append(args, "ubound", a.UBound) + } + if a.Saturate { + args = append(args, "saturate") + } + args = appendIncrEXTail(args, a.Expiration, a.ENX) + + cmd := NewIncrEXFloatCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +func appendIncrEXTail(args []interface{}, exp *ExpirationOption, enx bool) []interface{} { + if exp != nil { + switch exp.Mode { + case EX, PX, EXAT, PXAT: + args = append(args, strings.ToLower(string(exp.Mode)), exp.Value) + case PERSIST: + args = append(args, "persist") + } + } + if enx { + args = append(args, "enx") + } + return args +} + +type SetCondition string + +const ( + // NX only set the keys and their expiration if none exist + NX SetCondition = "NX" + // XX only set the keys and their expiration if all already exist + XX SetCondition = "XX" +) + +type ExpirationMode string + +const ( + // EX sets expiration in seconds + EX ExpirationMode = "EX" + // PX sets expiration in milliseconds + PX ExpirationMode = "PX" + // EXAT sets expiration as Unix timestamp in seconds + EXAT ExpirationMode = "EXAT" + // PXAT sets expiration as Unix timestamp in milliseconds + PXAT ExpirationMode = "PXAT" + // KEEPTTL keeps the existing TTL + KEEPTTL ExpirationMode = "KEEPTTL" + // PERSIST removes the existing TTL. Used by INCREX. + PERSIST ExpirationMode = "PERSIST" +) + +type ExpirationOption struct { + Mode ExpirationMode + Value int64 +} + func (c cmdable) LCS(ctx context.Context, q *LCSQuery) *LCSCmd { cmd := NewLCSCmd(ctx, q) _ = c(ctx, cmd) @@ -157,6 +399,49 @@ func (c cmdable) MSetNX(ctx context.Context, values ...interface{}) *BoolCmd { return cmd } +type MSetEXArgs struct { + Condition SetCondition + Expiration *ExpirationOption +} + +// MSetEX sets the given keys to their respective values. +// This command is an extension of the MSETNX that adds expiration and XX options. +// Available since Redis 8.4 +// Important: When this method is used with Cluster clients, all keys +// must be in the same hash slot, otherwise CROSSSLOT error will be returned. +// For more information, see https://redis.io/commands/msetex +func (c cmdable) MSetEX(ctx context.Context, args MSetEXArgs, values ...interface{}) *IntCmd { + expandedArgs := appendArgs([]interface{}{}, values) + numkeys := len(expandedArgs) / 2 + + cmdArgs := make([]interface{}, 0, 2+len(expandedArgs)+3) + cmdArgs = append(cmdArgs, "msetex", numkeys) + cmdArgs = append(cmdArgs, expandedArgs...) + + if args.Condition != "" { + cmdArgs = append(cmdArgs, string(args.Condition)) + } + + if args.Expiration != nil { + switch args.Expiration.Mode { + case EX: + cmdArgs = append(cmdArgs, "ex", args.Expiration.Value) + case PX: + cmdArgs = append(cmdArgs, "px", args.Expiration.Value) + case EXAT: + cmdArgs = append(cmdArgs, "exat", args.Expiration.Value) + case PXAT: + cmdArgs = append(cmdArgs, "pxat", args.Expiration.Value) + case KEEPTTL: + cmdArgs = append(cmdArgs, "keepttl") + } + } + + cmd := NewIntCmd(ctx, cmdArgs...) + _ = c(ctx, cmd) + return cmd +} + // Set Redis `SET key value [expiration]` command. // Use expiration for `SETEx`-like behavior. // @@ -185,9 +470,24 @@ func (c cmdable) Set(ctx context.Context, key string, value interface{}, expirat // SetArgs provides arguments for the SetArgs function. type SetArgs struct { - // Mode can be `NX` or `XX` or empty. + // Mode can be `NX`, `XX`, `IFEQ`, `IFNE`, `IFDEQ`, `IFDNE` or empty. Mode string + // MatchValue is used with IFEQ/IFNE modes for compare-and-set operations. + // - IFEQ: only set if current value equals MatchValue + // - IFNE: only set if current value does not equal MatchValue + MatchValue interface{} + + // MatchDigest is used with IFDEQ/IFDNE modes for digest-based compare-and-set. + // - IFDEQ: only set if current value's digest equals MatchDigest + // - IFDNE: only set if current value's digest does not equal MatchDigest + // + // The digest is a uint64 xxh3 hash value. + // + // For examples of client-side digest generation, see: + // example/digest-optimistic-locking/ + MatchDigest uint64 + // Zero `TTL` or `Expiration` means that the key has no expiration time. TTL time.Duration ExpireAt time.Time @@ -223,6 +523,18 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S if a.Mode != "" { args = append(args, a.Mode) + + // Add match value/digest for CAS modes + switch a.Mode { + case "ifeq", "IFEQ", "ifne", "IFNE": + if a.MatchValue != nil { + args = append(args, a.MatchValue) + } + case "ifdeq", "IFDEQ", "ifdne", "IFDNE": + if a.MatchDigest != 0 { + args = append(args, fmt.Sprintf("%016x", a.MatchDigest)) + } + } } if a.Get { @@ -234,14 +546,16 @@ func (c cmdable) SetArgs(ctx context.Context, key string, value interface{}, a S return cmd } -// SetEx Redis `SETEx key expiration value` command. +// SetEx sets the value and expiration of a key. +// +// Deprecated: Use Set with expiration instead as of Redis 2.6.12. func (c cmdable) SetEx(ctx context.Context, key string, value interface{}, expiration time.Duration) *StatusCmd { cmd := NewStatusCmd(ctx, "setex", key, formatSec(ctx, expiration), value) _ = c(ctx, cmd) return cmd } -// SetNX Redis `SET key value [expiration] NX` command. +// SetNX sets the value of a key only if the key does not exist. // // Zero expiration means the key has no expiration time. // KeepTTL is a Redis KEEPTTL option to keep existing TTL, it requires your redis-server version >= 6.0, @@ -250,8 +564,7 @@ func (c cmdable) SetNX(ctx context.Context, key string, value interface{}, expir var cmd *BoolCmd switch expiration { case 0: - // Use old `SETNX` to support old Redis versions. - cmd = NewBoolCmd(ctx, "setnx", key, value) + cmd = NewBoolCmd(ctx, "set", key, value, "nx") case KeepTTL: cmd = NewBoolCmd(ctx, "set", key, value, "keepttl", "nx") default: @@ -290,6 +603,270 @@ func (c cmdable) SetXX(ctx context.Context, key string, value interface{}, expir return cmd } +// SetIFEQ Redis `SET key value [expiration] IFEQ match-value` command. +// Compare-and-set: only sets the value if the current value equals matchValue. +// +// Returns "OK" on success. +// Returns nil if the operation was aborted due to condition not matching. +// Zero expiration means the key has no expiration time. +// +// NOTE SetIFEQ is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFEQ(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifeq", matchValue) + + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFEQGet Redis `SET key value [expiration] IFEQ match-value GET` command. +// Compare-and-set with GET: only sets the value if the current value equals matchValue, +// and returns the previous value. +// +// Returns the previous value on success. +// Returns nil if the operation was aborted due to condition not matching. +// Zero expiration means the key has no expiration time. +// +// NOTE SetIFEQGet is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFEQGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifeq", matchValue, "get") + + cmd := NewStringCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFNE Redis `SET key value [expiration] IFNE match-value` command. +// Compare-and-set: only sets the value if the current value does not equal matchValue. +// +// Returns "OK" on success. +// Returns nil if the operation was aborted due to condition not matching. +// Zero expiration means the key has no expiration time. +// +// NOTE SetIFNE is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFNE(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StatusCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifne", matchValue) + + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFNEGet Redis `SET key value [expiration] IFNE match-value GET` command. +// Compare-and-set with GET: only sets the value if the current value does not equal matchValue, +// and returns the previous value. +// +// Returns the previous value on success. +// Returns nil if the operation was aborted due to condition not matching. +// Zero expiration means the key has no expiration time. +// +// NOTE SetIFNEGet is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFNEGet(ctx context.Context, key string, value interface{}, matchValue interface{}, expiration time.Duration) *StringCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifne", matchValue, "get") + + cmd := NewStringCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFDEQ sets the value only if the current value's digest equals matchDigest. +// +// This is a compare-and-set operation using xxh3 digest for optimistic locking. +// The matchDigest parameter is a uint64 xxh3 hash value. +// +// Returns "OK" on success. +// Returns redis.Nil if the digest doesn't match (value was modified). +// Zero expiration means the key has no expiration time. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/set/ +// +// NOTE SetIFNEQ is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFDEQ(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest)) + + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFDEQGet sets the value only if the current value's digest equals matchDigest, +// and returns the previous value. +// +// This is a compare-and-set operation using xxh3 digest for optimistic locking. +// The matchDigest parameter is a uint64 xxh3 hash value. +// +// Returns the previous value on success. +// Returns redis.Nil if the digest doesn't match (value was modified). +// Zero expiration means the key has no expiration time. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/set/ +// +// NOTE SetIFNEQGet is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFDEQGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifdeq", fmt.Sprintf("%016x", matchDigest), "get") + + cmd := NewStringCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFDNE sets the value only if the current value's digest does NOT equal matchDigest. +// +// This is a compare-and-set operation using xxh3 digest for optimistic locking. +// The matchDigest parameter is a uint64 xxh3 hash value. +// +// Returns "OK" on success (digest didn't match, value was set). +// Returns redis.Nil if the digest matches (value was not modified). +// Zero expiration means the key has no expiration time. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/set/ +// +// NOTE SetIFDNE is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFDNE(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StatusCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest)) + + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// SetIFDNEGet sets the value only if the current value's digest does NOT equal matchDigest, +// and returns the previous value. +// +// This is a compare-and-set operation using xxh3 digest for optimistic locking. +// The matchDigest parameter is a uint64 xxh3 hash value. +// +// Returns the previous value on success (digest didn't match, value was set). +// Returns redis.Nil if the digest matches (value was not modified). +// Zero expiration means the key has no expiration time. +// +// For examples of client-side digest generation and usage patterns, see: +// example/digest-optimistic-locking/ +// +// Redis 8.4+. See https://redis.io/commands/set/ +// +// NOTE SetIFDNEGet is still experimental +// it's signature and behaviour may change +func (c cmdable) SetIFDNEGet(ctx context.Context, key string, value interface{}, matchDigest uint64, expiration time.Duration) *StringCmd { + args := []interface{}{"set", key, value} + + if expiration > 0 { + if usePrecise(expiration) { + args = append(args, "px", formatMs(ctx, expiration)) + } else { + args = append(args, "ex", formatSec(ctx, expiration)) + } + } else if expiration == KeepTTL { + args = append(args, "keepttl") + } + + args = append(args, "ifdne", fmt.Sprintf("%016x", matchDigest), "get") + + cmd := NewStringCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + func (c cmdable) SetRange(ctx context.Context, key string, offset int64, value string) *IntCmd { cmd := NewIntCmd(ctx, "setrange", key, offset, value) _ = c(ctx, cmd) diff --git a/vendor/github.com/redis/go-redis/v9/timeseries_commands.go b/vendor/github.com/redis/go-redis/v9/timeseries_commands.go index 82d8cdfcf..db00db80f 100644 --- a/vendor/github.com/redis/go-redis/v9/timeseries_commands.go +++ b/vendor/github.com/redis/go-redis/v9/timeseries_commands.go @@ -2,9 +2,12 @@ package redis import ( "context" - "strconv" + "errors" + "fmt" + "strings" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/util" ) type TimeseriesCmdable interface { @@ -96,6 +99,8 @@ const ( VarP VarS Twa + CountNaN + CountAll ) func (a Aggregator) String() string { @@ -128,44 +133,97 @@ func (a Aggregator) String() string { return "VAR.S" case Twa: return "TWA" + case CountNaN: + return "COUNTNAN" + case CountAll: + return "COUNTALL" default: return "" } } +var ( + errTSMultiAggregationGroupBy = errors.New("redis: GROUPBY is not allowed when multiple aggregators are specified") + errTSAggregationConflict = errors.New("redis: setting both Aggregator and Aggregators is not allowed; use Aggregators instead because Aggregator is deprecated") +) + +func formatAggregationArgs(aggregator Aggregator, aggregators []Aggregator) (string, int, error) { + if aggregator != Invalid && len(aggregators) > 0 { + return "", 0, errTSAggregationConflict + } + if len(aggregators) == 0 { + if aggregator == Invalid { + return "", 0, nil + } + aggregationArg, err := formatAggregatorArg(aggregator) + if err != nil { + return "", 0, err + } + return aggregationArg, 1, nil + } + + parts := make([]string, len(aggregators)) + for i, agg := range aggregators { + if agg == Invalid { + return "", 0, fmt.Errorf("redis: invalid timeseries aggregator at index %d: Invalid (%d)", i, agg) + } + aggregationArg, err := formatAggregatorArg(agg) + if err != nil { + return "", 0, fmt.Errorf("redis: invalid timeseries aggregator at index %d: %d", i, agg) + } + parts[i] = aggregationArg + } + + return strings.Join(parts, ","), len(parts), nil +} + +func formatAggregatorArg(aggregator Aggregator) (string, error) { + aggregationArg := aggregator.String() + if aggregationArg == "" { + return "", fmt.Errorf("redis: invalid timeseries aggregator: %d", aggregator) + } + return aggregationArg, nil +} + type TSRangeOptions struct { - Latest bool - FilterByTS []int - FilterByValue []int - Count int - Align interface{} + Latest bool + FilterByTS []int + FilterByValue []int + Count int + Align interface{} + // Deprecated: use Aggregators instead. Aggregator Aggregator + Aggregators []Aggregator BucketDuration int BucketTimestamp interface{} Empty bool } type TSRevRangeOptions struct { - Latest bool - FilterByTS []int - FilterByValue []int - Count int - Align interface{} + Latest bool + FilterByTS []int + FilterByValue []int + Count int + Align interface{} + // Deprecated: use Aggregators instead. Aggregator Aggregator + Aggregators []Aggregator BucketDuration int BucketTimestamp interface{} Empty bool } type TSMRangeOptions struct { - Latest bool - FilterByTS []int - FilterByValue []int - WithLabels bool - SelectedLabels []interface{} - Count int - Align interface{} + Latest bool + FilterByTS []int + FilterByValue []int + WithLabels bool + SelectedLabels []interface{} + Count int + Align interface{} + // Deprecated: use Aggregators instead. Aggregator Aggregator + Aggregators []Aggregator BucketDuration int BucketTimestamp interface{} Empty bool @@ -174,14 +232,16 @@ type TSMRangeOptions struct { } type TSMRevRangeOptions struct { - Latest bool - FilterByTS []int - FilterByValue []int - WithLabels bool - SelectedLabels []interface{} - Count int - Align interface{} + Latest bool + FilterByTS []int + FilterByValue []int + WithLabels bool + SelectedLabels []interface{} + Count int + Align interface{} + // Deprecated: use Aggregators instead. Aggregator Aggregator + Aggregators []Aggregator BucketDuration int BucketTimestamp interface{} Empty bool @@ -477,7 +537,16 @@ func (c cmdable) TSGet(ctx context.Context, key string) *TSTimestampValueCmd { type TSTimestampValue struct { Timestamp int64 Value float64 + Values []float64 } + +func (tv TSTimestampValue) String() string { + if len(tv.Values) > 0 { + return fmt.Sprintf("{%d %v}", tv.Timestamp, tv.Values) + } + return fmt.Sprintf("{%d %v}", tv.Timestamp, tv.Value) +} + type TSTimestampValueCmd struct { baseCmd val TSTimestampValue @@ -486,8 +555,9 @@ type TSTimestampValueCmd struct { func newTSTimestampValueCmd(ctx context.Context, args ...interface{}) *TSTimestampValueCmd { return &TSTimestampValueCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTSTimestampValue, }, } } @@ -524,7 +594,7 @@ func (cmd *TSTimestampValueCmd) readReply(rd *proto.Reader) (err error) { return err } cmd.val.Timestamp = timestamp - cmd.val.Value, err = strconv.ParseFloat(value, 64) + cmd.val.Value, err = util.ParseStringToFloat(value) if err != nil { return err } @@ -533,6 +603,18 @@ func (cmd *TSTimestampValueCmd) readReply(rd *proto.Reader) (err error) { return nil } +func (cmd *TSTimestampValueCmd) Clone() Cmder { + val := cmd.val + if cmd.val.Values != nil { + val.Values = make([]float64, len(cmd.val.Values)) + copy(val.Values, cmd.val.Values) + } + return &TSTimestampValueCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // TSInfo - Returns information about a time-series key. // For more information - https://redis.io/commands/ts.info/ func (c cmdable) TSInfo(ctx context.Context, key string) *MapStringInterfaceCmd { @@ -622,8 +704,14 @@ func (c cmdable) TSRevRangeWithArgs(ctx context.Context, key string, fromTimesta if options.Align != nil { args = append(args, "ALIGN", options.Align) } - if options.Aggregator != 0 { - args = append(args, "AGGREGATION", options.Aggregator.String()) + aggregationArg, _, err := formatAggregationArgs(options.Aggregator, options.Aggregators) + if err != nil { + cmd := newTSTimestampValueSliceCmd(ctx, args...) + cmd.SetErr(err) + return cmd + } + if aggregationArg != "" { + args = append(args, "AGGREGATION", aggregationArg) } if options.BucketDuration != 0 { args = append(args, options.BucketDuration) @@ -678,8 +766,14 @@ func (c cmdable) TSRangeWithArgs(ctx context.Context, key string, fromTimestamp if options.Align != nil { args = append(args, "ALIGN", options.Align) } - if options.Aggregator != 0 { - args = append(args, "AGGREGATION", options.Aggregator.String()) + aggregationArg, _, err := formatAggregationArgs(options.Aggregator, options.Aggregators) + if err != nil { + cmd := newTSTimestampValueSliceCmd(ctx, args...) + cmd.SetErr(err) + return cmd + } + if aggregationArg != "" { + args = append(args, "AGGREGATION", aggregationArg) } if options.BucketDuration != 0 { args = append(args, options.BucketDuration) @@ -704,8 +798,9 @@ type TSTimestampValueSliceCmd struct { func newTSTimestampValueSliceCmd(ctx context.Context, args ...interface{}) *TSTimestampValueSliceCmd { return &TSTimestampValueSliceCmd{ baseCmd: baseCmd{ - ctx: ctx, - args: args, + ctx: ctx, + args: args, + cmdType: CmdTypeTSTimestampValueSlice, }, } } @@ -733,25 +828,62 @@ func (cmd *TSTimestampValueSliceCmd) readReply(rd *proto.Reader) (err error) { } cmd.val = make([]TSTimestampValue, n) for i := 0; i < n; i++ { - _, _ = rd.ReadArrayLen() - timestamp, err := rd.ReadInt() + itemLen, err := rd.ReadArrayLen() if err != nil { return err } - value, err := rd.ReadString() + + timestamp, err := rd.ReadInt() if err != nil { return err } cmd.val[i].Timestamp = timestamp - cmd.val[i].Value, err = strconv.ParseFloat(value, 64) - if err != nil { - return err + if itemLen == 2 { + value, err := rd.ReadString() + if err != nil { + return err + } + cmd.val[i].Value, err = util.ParseStringToFloat(value) + if err != nil { + return err + } + continue + } + + cmd.val[i].Values = make([]float64, itemLen-1) + for j := 0; j < itemLen-1; j++ { + value, err := rd.ReadString() + if err != nil { + return err + } + cmd.val[i].Values[j], err = util.ParseStringToFloat(value) + if err != nil { + return err + } } } return nil } +func (cmd *TSTimestampValueSliceCmd) Clone() Cmder { + var val []TSTimestampValue + if cmd.val != nil { + val = make([]TSTimestampValue, len(cmd.val)) + copy(val, cmd.val) + for i := range cmd.val { + if cmd.val[i].Values != nil { + val[i].Values = make([]float64, len(cmd.val[i].Values)) + copy(val[i].Values, cmd.val[i].Values) + } + } + } + return &TSTimestampValueSliceCmd{ + baseCmd: cmd.cloneBaseCmd(), + val: val, + } +} + // TSMRange - Returns a range of samples from multiple time-series keys. // For more information - https://redis.io/commands/ts.mrange/ func (c cmdable) TSMRange(ctx context.Context, fromTimestamp int, toTimestamp int, filterExpr []string) *MapStringSliceInterfaceCmd { @@ -772,6 +904,7 @@ func (c cmdable) TSMRange(ctx context.Context, fromTimestamp int, toTimestamp in // For more information - https://redis.io/commands/ts.mrange/ func (c cmdable) TSMRangeWithArgs(ctx context.Context, fromTimestamp int, toTimestamp int, filterExpr []string, options *TSMRangeOptions) *MapStringSliceInterfaceCmd { args := []interface{}{"TS.MRANGE", fromTimestamp, toTimestamp} + multiAggregationCount := 0 if options != nil { if options.Latest { args = append(args, "LATEST") @@ -801,8 +934,15 @@ func (c cmdable) TSMRangeWithArgs(ctx context.Context, fromTimestamp int, toTime if options.Align != nil { args = append(args, "ALIGN", options.Align) } - if options.Aggregator != 0 { - args = append(args, "AGGREGATION", options.Aggregator.String()) + aggregationArg, count, err := formatAggregationArgs(options.Aggregator, options.Aggregators) + if err != nil { + cmd := NewMapStringSliceInterfaceCmd(ctx, args...) + cmd.SetErr(err) + return cmd + } + multiAggregationCount = count + if aggregationArg != "" { + args = append(args, "AGGREGATION", aggregationArg) } if options.BucketDuration != 0 { args = append(args, options.BucketDuration) @@ -819,6 +959,11 @@ func (c cmdable) TSMRangeWithArgs(ctx context.Context, fromTimestamp int, toTime args = append(args, f) } if options != nil { + if multiAggregationCount > 1 && (options.GroupByLabel != nil || options.Reducer != nil) { + cmd := NewMapStringSliceInterfaceCmd(ctx, args...) + cmd.SetErr(errTSMultiAggregationGroupBy) + return cmd + } if options.GroupByLabel != nil { args = append(args, "GROUPBY", options.GroupByLabel) } @@ -851,6 +996,7 @@ func (c cmdable) TSMRevRange(ctx context.Context, fromTimestamp int, toTimestamp // For more information - https://redis.io/commands/ts.mrevrange/ func (c cmdable) TSMRevRangeWithArgs(ctx context.Context, fromTimestamp int, toTimestamp int, filterExpr []string, options *TSMRevRangeOptions) *MapStringSliceInterfaceCmd { args := []interface{}{"TS.MREVRANGE", fromTimestamp, toTimestamp} + multiAggregationCount := 0 if options != nil { if options.Latest { args = append(args, "LATEST") @@ -880,8 +1026,15 @@ func (c cmdable) TSMRevRangeWithArgs(ctx context.Context, fromTimestamp int, toT if options.Align != nil { args = append(args, "ALIGN", options.Align) } - if options.Aggregator != 0 { - args = append(args, "AGGREGATION", options.Aggregator.String()) + aggregationArg, count, err := formatAggregationArgs(options.Aggregator, options.Aggregators) + if err != nil { + cmd := NewMapStringSliceInterfaceCmd(ctx, args...) + cmd.SetErr(err) + return cmd + } + multiAggregationCount = count + if aggregationArg != "" { + args = append(args, "AGGREGATION", aggregationArg) } if options.BucketDuration != 0 { args = append(args, options.BucketDuration) @@ -898,6 +1051,11 @@ func (c cmdable) TSMRevRangeWithArgs(ctx context.Context, fromTimestamp int, toT args = append(args, f) } if options != nil { + if multiAggregationCount > 1 && (options.GroupByLabel != nil || options.Reducer != nil) { + cmd := NewMapStringSliceInterfaceCmd(ctx, args...) + cmd.SetErr(errTSMultiAggregationGroupBy) + return cmd + } if options.GroupByLabel != nil { args = append(args, "GROUPBY", options.GroupByLabel) } diff --git a/vendor/github.com/redis/go-redis/v9/tx.go b/vendor/github.com/redis/go-redis/v9/tx.go index 039eaf351..b433b4024 100644 --- a/vendor/github.com/redis/go-redis/v9/tx.go +++ b/vendor/github.com/redis/go-redis/v9/tx.go @@ -11,7 +11,7 @@ import ( const TxFailedErr = proto.RedisError("redis: transaction failed") // Tx implements Redis transactions as described in -// http://redis.io/topics/transactions. It's NOT safe for concurrent use +// https://redis.io/docs/latest/develop/using-commands/transactions. It's NOT safe for concurrent use // by multiple goroutines, because Exec resets list of watched keys. // // If you don't need WATCH, use Pipeline instead. @@ -19,16 +19,17 @@ type Tx struct { baseClient cmdable statefulCmdable - hooksMixin } func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt, - connPool: pool.NewStickyConnPool(c.connPool), + opt: c.cloneOpt(), // Clone options under optLock to avoid race with initConn + connPool: pool.NewStickyConnPool(c.connPool), + hooksMixin: c.hooksMixin.clone(), + pushProcessor: c.pushProcessor, // Copy push processor from parent client + onClose: &onCloseHooks{}, }, - hooksMixin: c.hooksMixin.clone(), } tx.init() return &tx diff --git a/vendor/github.com/redis/go-redis/v9/universal.go b/vendor/github.com/redis/go-redis/v9/universal.go index a1ce17bac..b623460cb 100644 --- a/vendor/github.com/redis/go-redis/v9/universal.go +++ b/vendor/github.com/redis/go-redis/v9/universal.go @@ -5,6 +5,10 @@ import ( "crypto/tls" "net" "time" + + "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/maintnotifications" + "github.com/redis/go-redis/v9/push" ) // UniversalOptions information is required by UniversalClient to establish @@ -26,9 +30,27 @@ type UniversalOptions struct { Dialer func(ctx context.Context, network, addr string) (net.Conn, error) OnConnect func(ctx context.Context, cn *Conn) error - Protocol int - Username string - Password string + Protocol int + Username string + Password string + // CredentialsProvider allows the username and password to be updated + // before reconnecting. It should return the current username and password. + CredentialsProvider func() (username string, password string) + + // CredentialsProviderContext is an enhanced parameter of CredentialsProvider, + // done to maintain API compatibility. In the future, + // there might be a merge between CredentialsProviderContext and CredentialsProvider. + // There will be a conflict between them; if CredentialsProviderContext exists, we will ignore CredentialsProvider. + CredentialsProviderContext func(ctx context.Context) (username string, password string, err error) + + // StreamingCredentialsProvider is used to retrieve the credentials + // for the connection from an external source. Those credentials may change + // during the connection lifetime. This is useful for managed identity + // scenarios where the credentials are retrieved from an external source. + // + // Currently, this is a placeholder for the future implementation. + StreamingCredentialsProvider auth.StreamingCredentialsProvider + SentinelUsername string SentinelPassword string @@ -36,21 +58,52 @@ type UniversalOptions struct { MinRetryBackoff time.Duration MaxRetryBackoff time.Duration - DialTimeout time.Duration + DialTimeout time.Duration + + // DialerRetries is the maximum number of retry attempts when dialing fails. + // + // default: 5 + DialerRetries int + + // DialerRetryTimeout is the backoff duration between retry attempts. + // + // default: 100 milliseconds + DialerRetryTimeout time.Duration + ReadTimeout time.Duration WriteTimeout time.Duration ContextTimeoutEnabled bool + // ReadBufferSize is the size of the bufio.Reader buffer for each connection. + // Larger buffers can improve performance for commands that return large responses. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + ReadBufferSize int + + // WriteBufferSize is the size of the bufio.Writer buffer for each connection. + // Larger buffers can improve performance for large pipelines and commands with many arguments. + // Smaller buffers can improve memory usage for larger pools. + // + // default: 32KiB (32768 bytes) + WriteBufferSize int + // PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). PoolFIFO bool - PoolSize int - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration + PoolSize int + + // MaxConcurrentDials is the maximum number of concurrent connection creation goroutines. + // If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize. + MaxConcurrentDials int + + PoolTimeout time.Duration + MinIdleConns int + MaxIdleConns int + MaxActiveConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + ConnMaxLifetimeJitter time.Duration TLSConfig *tls.Config @@ -78,10 +131,26 @@ type UniversalOptions struct { DisableIdentity bool IdentitySuffix string - UnstableResp3 bool + + // FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing. + // When a node is marked as failing, it will be avoided for this duration. + // Only applies to cluster clients. Default is 15 seconds. + FailingTimeoutSeconds int + + // Deprecated: All RediSearch commands now have stable RESP3 parsing and this + // flag is a no-op. It is kept for backwards compatibility and will be removed + // in a future release. + UnstableResp3 bool + + // PushNotificationProcessor is the processor for handling push notifications. + // If nil, a default processor will be created for RESP3 connections. + PushNotificationProcessor push.NotificationProcessor // IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint). IsClusterMode bool + + // MaintNotificationsConfig provides configuration for maintnotifications upgrades. + MaintNotificationsConfig *maintnotifications.Config } // Cluster returns cluster options created from the universal options. @@ -96,9 +165,12 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { Dialer: o.Dialer, OnConnect: o.OnConnect, - Protocol: o.Protocol, - Username: o.Username, - Password: o.Password, + Protocol: o.Protocol, + Username: o.Username, + Password: o.Password, + CredentialsProvider: o.CredentialsProvider, + CredentialsProviderContext: o.CredentialsProviderContext, + StreamingCredentialsProvider: o.StreamingCredentialsProvider, MaxRedirects: o.MaxRedirects, ReadOnly: o.ReadOnly, @@ -109,27 +181,37 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, - DialTimeout: o.DialTimeout, - ReadTimeout: o.ReadTimeout, - WriteTimeout: o.WriteTimeout, + DialTimeout: o.DialTimeout, + DialerRetries: o.DialerRetries, + DialerRetryTimeout: o.DialerRetryTimeout, + ReadTimeout: o.ReadTimeout, + WriteTimeout: o.WriteTimeout, + ContextTimeoutEnabled: o.ContextTimeoutEnabled, - PoolFIFO: o.PoolFIFO, + ReadBufferSize: o.ReadBufferSize, + WriteBufferSize: o.WriteBufferSize, - PoolSize: o.PoolSize, - PoolTimeout: o.PoolTimeout, - MinIdleConns: o.MinIdleConns, - MaxIdleConns: o.MaxIdleConns, - MaxActiveConns: o.MaxActiveConns, - ConnMaxIdleTime: o.ConnMaxIdleTime, - ConnMaxLifetime: o.ConnMaxLifetime, + PoolFIFO: o.PoolFIFO, + PoolSize: o.PoolSize, + MaxConcurrentDials: o.MaxConcurrentDials, + PoolTimeout: o.PoolTimeout, + MinIdleConns: o.MinIdleConns, + MaxIdleConns: o.MaxIdleConns, + MaxActiveConns: o.MaxActiveConns, + ConnMaxIdleTime: o.ConnMaxIdleTime, + ConnMaxLifetime: o.ConnMaxLifetime, + ConnMaxLifetimeJitter: o.ConnMaxLifetimeJitter, TLSConfig: o.TLSConfig, - DisableIdentity: o.DisableIdentity, - DisableIndentity: o.DisableIndentity, - IdentitySuffix: o.IdentitySuffix, - UnstableResp3: o.UnstableResp3, + DisableIdentity: o.DisableIdentity, + DisableIndentity: o.DisableIndentity, + IdentitySuffix: o.IdentitySuffix, + FailingTimeoutSeconds: o.FailingTimeoutSeconds, + UnstableResp3: o.UnstableResp3, + PushNotificationProcessor: o.PushNotificationProcessor, + MaintNotificationsConfig: o.MaintNotificationsConfig, } } @@ -147,10 +229,14 @@ func (o *UniversalOptions) Failover() *FailoverOptions { Dialer: o.Dialer, OnConnect: o.OnConnect, - DB: o.DB, - Protocol: o.Protocol, - Username: o.Username, - Password: o.Password, + DB: o.DB, + Protocol: o.Protocol, + Username: o.Username, + Password: o.Password, + CredentialsProvider: o.CredentialsProvider, + CredentialsProviderContext: o.CredentialsProviderContext, + StreamingCredentialsProvider: o.StreamingCredentialsProvider, + SentinelUsername: o.SentinelUsername, SentinelPassword: o.SentinelPassword, @@ -161,28 +247,38 @@ func (o *UniversalOptions) Failover() *FailoverOptions { MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, - DialTimeout: o.DialTimeout, - ReadTimeout: o.ReadTimeout, - WriteTimeout: o.WriteTimeout, + DialTimeout: o.DialTimeout, + DialerRetries: o.DialerRetries, + DialerRetryTimeout: o.DialerRetryTimeout, + ReadTimeout: o.ReadTimeout, + WriteTimeout: o.WriteTimeout, + ContextTimeoutEnabled: o.ContextTimeoutEnabled, - PoolFIFO: o.PoolFIFO, - PoolSize: o.PoolSize, - PoolTimeout: o.PoolTimeout, - MinIdleConns: o.MinIdleConns, - MaxIdleConns: o.MaxIdleConns, - MaxActiveConns: o.MaxActiveConns, - ConnMaxIdleTime: o.ConnMaxIdleTime, - ConnMaxLifetime: o.ConnMaxLifetime, + ReadBufferSize: o.ReadBufferSize, + WriteBufferSize: o.WriteBufferSize, + + PoolFIFO: o.PoolFIFO, + PoolSize: o.PoolSize, + MaxConcurrentDials: o.MaxConcurrentDials, + PoolTimeout: o.PoolTimeout, + MinIdleConns: o.MinIdleConns, + MaxIdleConns: o.MaxIdleConns, + MaxActiveConns: o.MaxActiveConns, + ConnMaxIdleTime: o.ConnMaxIdleTime, + ConnMaxLifetime: o.ConnMaxLifetime, + ConnMaxLifetimeJitter: o.ConnMaxLifetimeJitter, TLSConfig: o.TLSConfig, ReplicaOnly: o.ReadOnly, - DisableIdentity: o.DisableIdentity, - DisableIndentity: o.DisableIndentity, - IdentitySuffix: o.IdentitySuffix, - UnstableResp3: o.UnstableResp3, + DisableIdentity: o.DisableIdentity, + DisableIndentity: o.DisableIndentity, + IdentitySuffix: o.IdentitySuffix, + UnstableResp3: o.UnstableResp3, + PushNotificationProcessor: o.PushNotificationProcessor, + // Note: MaintNotificationsConfig not supported for FailoverOptions } } @@ -199,35 +295,48 @@ func (o *UniversalOptions) Simple() *Options { Dialer: o.Dialer, OnConnect: o.OnConnect, - DB: o.DB, - Protocol: o.Protocol, - Username: o.Username, - Password: o.Password, + DB: o.DB, + Protocol: o.Protocol, + Username: o.Username, + Password: o.Password, + CredentialsProvider: o.CredentialsProvider, + CredentialsProviderContext: o.CredentialsProviderContext, + StreamingCredentialsProvider: o.StreamingCredentialsProvider, MaxRetries: o.MaxRetries, MinRetryBackoff: o.MinRetryBackoff, MaxRetryBackoff: o.MaxRetryBackoff, - DialTimeout: o.DialTimeout, - ReadTimeout: o.ReadTimeout, - WriteTimeout: o.WriteTimeout, + DialTimeout: o.DialTimeout, + DialerRetries: o.DialerRetries, + DialerRetryTimeout: o.DialerRetryTimeout, + ReadTimeout: o.ReadTimeout, + WriteTimeout: o.WriteTimeout, + ContextTimeoutEnabled: o.ContextTimeoutEnabled, - PoolFIFO: o.PoolFIFO, - PoolSize: o.PoolSize, - PoolTimeout: o.PoolTimeout, - MinIdleConns: o.MinIdleConns, - MaxIdleConns: o.MaxIdleConns, - MaxActiveConns: o.MaxActiveConns, - ConnMaxIdleTime: o.ConnMaxIdleTime, - ConnMaxLifetime: o.ConnMaxLifetime, + ReadBufferSize: o.ReadBufferSize, + WriteBufferSize: o.WriteBufferSize, + + PoolFIFO: o.PoolFIFO, + PoolSize: o.PoolSize, + MaxConcurrentDials: o.MaxConcurrentDials, + PoolTimeout: o.PoolTimeout, + MinIdleConns: o.MinIdleConns, + MaxIdleConns: o.MaxIdleConns, + MaxActiveConns: o.MaxActiveConns, + ConnMaxIdleTime: o.ConnMaxIdleTime, + ConnMaxLifetime: o.ConnMaxLifetime, + ConnMaxLifetimeJitter: o.ConnMaxLifetimeJitter, TLSConfig: o.TLSConfig, - DisableIdentity: o.DisableIdentity, - DisableIndentity: o.DisableIndentity, - IdentitySuffix: o.IdentitySuffix, - UnstableResp3: o.UnstableResp3, + DisableIdentity: o.DisableIdentity, + DisableIndentity: o.DisableIndentity, + IdentitySuffix: o.IdentitySuffix, + UnstableResp3: o.UnstableResp3, + PushNotificationProcessor: o.PushNotificationProcessor, + MaintNotificationsConfig: o.MaintNotificationsConfig, } } @@ -266,6 +375,8 @@ var ( // 3. If the number of Addrs is two or more, or IsClusterMode option is specified, // a ClusterClient is returned. // 4. Otherwise, a single-node Client is returned. +// +// Passing nil UniversalOptions will cause a panic. func NewUniversalClient(opts *UniversalOptions) UniversalClient { if opts == nil { panic("redis: NewUniversalClient nil options") diff --git a/vendor/github.com/redis/go-redis/v9/vectorset_commands.go b/vendor/github.com/redis/go-redis/v9/vectorset_commands.go new file mode 100644 index 000000000..a88bb1cd9 --- /dev/null +++ b/vendor/github.com/redis/go-redis/v9/vectorset_commands.go @@ -0,0 +1,480 @@ +package redis + +import ( + "context" + "encoding/json" + "strconv" +) + +// note: the APIs is experimental and may be subject to change. +type VectorSetCmdable interface { + VAdd(ctx context.Context, key, element string, val Vector) *BoolCmd + VAddWithArgs(ctx context.Context, key, element string, val Vector, addArgs *VAddArgs) *BoolCmd + VCard(ctx context.Context, key string) *IntCmd + VDim(ctx context.Context, key string) *IntCmd + VEmb(ctx context.Context, key, element string, raw bool) *SliceCmd + VGetAttr(ctx context.Context, key, element string) *StringCmd + VInfo(ctx context.Context, key string) *MapStringInterfaceCmd + VLinks(ctx context.Context, key, element string) *StringSliceSliceCmd + VLinksWithScores(ctx context.Context, key, element string) *VectorScoreSliceSliceCmd + VRandMember(ctx context.Context, key string) *StringCmd + VRandMemberCount(ctx context.Context, key string, count int) *StringSliceCmd + VRem(ctx context.Context, key, element string) *BoolCmd + VSetAttr(ctx context.Context, key, element string, attr interface{}) *BoolCmd + VClearAttributes(ctx context.Context, key, element string) *BoolCmd + VSim(ctx context.Context, key string, val Vector) *StringSliceCmd + VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd + VSimWithArgs(ctx context.Context, key string, val Vector, args *VSimArgs) *StringSliceCmd + VSimWithArgsWithScores(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorScoreSliceCmd + VSimWithArgsWithAttribs(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorAttribSliceCmd + VSimWithArgsWithScoresWithAttribs(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorScoreAttribSliceCmd + VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd + VIsMember(ctx context.Context, key, element string) *BoolCmd +} + +type Vector interface { + Value() []any +} + +const ( + vectorFormatFP32 string = "FP32" + vectorFormatValues string = "Values" + vectorFormatF16 string = "FLOAT16" + vectorFormatBF16 string = "BFLOAT16" + vectorFormatF64 string = "FLOAT64" + vectorFormatI8 string = "INT8" + vectorFormatU8 string = "UINT8" +) + +type VectorFP32 struct { + Val []byte +} + +func (v *VectorFP32) Value() []any { + return []any{vectorFormatFP32, v.Val} +} + +var _ Vector = (*VectorFP32)(nil) + +// VectorFloat16 represents a FLOAT16-encoded vector blob. +// note: intended for search/index query commands such as FT.HYBRID. +type VectorFloat16 struct { + Val []byte +} + +func (v *VectorFloat16) Value() []any { + return []any{vectorFormatF16, v.Val} +} + +var _ Vector = (*VectorFloat16)(nil) + +// VectorBFloat16 represents a BFLOAT16-encoded vector blob. +// note: intended for search/index query commands such as FT.HYBRID. +type VectorBFloat16 struct { + Val []byte +} + +func (v *VectorBFloat16) Value() []any { + return []any{vectorFormatBF16, v.Val} +} + +var _ Vector = (*VectorBFloat16)(nil) + +// VectorFloat64 represents a FLOAT64-encoded vector blob. +// note: intended for search/index query commands such as FT.HYBRID. +type VectorFloat64 struct { + Val []byte +} + +func (v *VectorFloat64) Value() []any { + return []any{vectorFormatF64, v.Val} +} + +var _ Vector = (*VectorFloat64)(nil) + +// VectorInt8 represents an INT8-encoded vector blob. +// note: intended for search/index query commands such as FT.HYBRID. +type VectorInt8 struct { + Val []byte +} + +func (v *VectorInt8) Value() []any { + return []any{vectorFormatI8, v.Val} +} + +var _ Vector = (*VectorInt8)(nil) + +// VectorUint8 represents a UINT8-encoded vector blob. +// note: intended for search/index query commands such as FT.HYBRID. +type VectorUint8 struct { + Val []byte +} + +func (v *VectorUint8) Value() []any { + return []any{vectorFormatU8, v.Val} +} + +var _ Vector = (*VectorUint8)(nil) + +type VectorValues struct { + Val []float64 +} + +func (v *VectorValues) Value() []any { + res := make([]any, 2+len(v.Val)) + res[0] = vectorFormatValues + res[1] = len(v.Val) + for i, v := range v.Val { + res[2+i] = v + } + return res +} + +var _ Vector = (*VectorValues)(nil) + +type VectorRef struct { + Name string // the name of the referent vector +} + +func (v *VectorRef) Value() []any { + return []any{"ele", v.Name} +} + +var _ Vector = (*VectorRef)(nil) + +type VectorScore struct { + Name string + Score float64 +} + +type VectorAttrib struct { + Name string + Attribs *string +} + +type VectorScoreAttrib struct { + Name string + Score float64 + Attribs *string +} + +// `VADD key (FP32 | VALUES num) vector element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VAdd(ctx context.Context, key, element string, val Vector) *BoolCmd { + return c.VAddWithArgs(ctx, key, element, val, &VAddArgs{}) +} + +type VAddArgs struct { + // the REDUCE option must be passed immediately after the key + Reduce int64 + Cas bool + + // The NoQuant, Q8 and Bin options are mutually exclusive. + NoQuant bool + Q8 bool + Bin bool + + EF int64 + SetAttr string + M int64 +} + +func (v VAddArgs) reduce() int64 { + return v.Reduce +} + +func (v VAddArgs) appendArgs(args []any) []any { + if v.Cas { + args = append(args, "cas") + } + + if v.NoQuant { + args = append(args, "noquant") + } else if v.Q8 { + args = append(args, "q8") + } else if v.Bin { + args = append(args, "bin") + } + + if v.EF > 0 { + args = append(args, "ef", strconv.FormatInt(v.EF, 10)) + } + if len(v.SetAttr) > 0 { + args = append(args, "setattr", v.SetAttr) + } + if v.M > 0 { + args = append(args, "m", strconv.FormatInt(v.M, 10)) + } + return args +} + +// `VADD key [REDUCE dim] (FP32 | VALUES num) vector element [CAS] [NOQUANT | Q8 | BIN] [EF build-exploration-factor] [SETATTR attributes] [M numlinks]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VAddWithArgs(ctx context.Context, key, element string, val Vector, addArgs *VAddArgs) *BoolCmd { + if addArgs == nil { + addArgs = &VAddArgs{} + } + args := []any{"vadd", key} + if addArgs.reduce() > 0 { + args = append(args, "reduce", addArgs.reduce()) + } + args = append(args, val.Value()...) + args = append(args, element) + args = addArgs.appendArgs(args) + cmd := NewBoolCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VCARD key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VCard(ctx context.Context, key string) *IntCmd { + cmd := NewIntCmd(ctx, "vcard", key) + _ = c(ctx, cmd) + return cmd +} + +// `VDIM key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VDim(ctx context.Context, key string) *IntCmd { + cmd := NewIntCmd(ctx, "vdim", key) + _ = c(ctx, cmd) + return cmd +} + +// `VEMB key element [RAW]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VEmb(ctx context.Context, key, element string, raw bool) *SliceCmd { + args := []any{"vemb", key, element} + if raw { + args = append(args, "raw") + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VGETATTR key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VGetAttr(ctx context.Context, key, element string) *StringCmd { + cmd := NewStringCmd(ctx, "vgetattr", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VINFO key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VInfo(ctx context.Context, key string) *MapStringInterfaceCmd { + cmd := NewMapStringInterfaceCmd(ctx, "vinfo", key) + _ = c(ctx, cmd) + return cmd +} + +// `VLINKS key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VLinks(ctx context.Context, key, element string) *StringSliceSliceCmd { + cmd := NewStringSliceSliceCmd(ctx, "vlinks", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VLINKS key element WITHSCORES` +// note: the API is experimental and may be subject to change. +func (c cmdable) VLinksWithScores(ctx context.Context, key, element string) *VectorScoreSliceSliceCmd { + cmd := NewVectorScoreSliceSliceCmd(ctx, "vlinks", key, element, "withscores") + _ = c(ctx, cmd) + return cmd +} + +// `VRANDMEMBER key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRandMember(ctx context.Context, key string) *StringCmd { + cmd := NewStringCmd(ctx, "vrandmember", key) + _ = c(ctx, cmd) + return cmd +} + +// `VRANDMEMBER key [count]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRandMemberCount(ctx context.Context, key string, count int) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "vrandmember", key, count) + _ = c(ctx, cmd) + return cmd +} + +// `VREM key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRem(ctx context.Context, key, element string) *BoolCmd { + cmd := NewBoolCmd(ctx, "vrem", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VSETATTR key element "{ JSON obj }"` +// The `attr` must be something that can be marshaled to JSON (using encoding/JSON) unless +// the argument is a string or []byte when we assume that it can be passed directly as JSON. +// +// note: the API is experimental and may be subject to change. +func (c cmdable) VSetAttr(ctx context.Context, key, element string, attr interface{}) *BoolCmd { + var attrStr string + var err error + switch v := attr.(type) { + case string: + attrStr = v + case []byte: + attrStr = string(v) + default: + var bytes []byte + bytes, err = json.Marshal(v) + if err != nil { + // If marshalling fails, create the command and set the error; this command won't be executed. + cmd := NewBoolCmd(ctx, "vsetattr", key, element, "") + cmd.SetErr(err) + return cmd + } + attrStr = string(bytes) + } + cmd := NewBoolCmd(ctx, "vsetattr", key, element, attrStr) + _ = c(ctx, cmd) + return cmd +} + +// `VClearAttributes` clear attributes on a vector set element. +// The implementation of `VClearAttributes` is execute command `VSETATTR key element ""`. +// note: the API is experimental and may be subject to change. +func (c cmdable) VClearAttributes(ctx context.Context, key, element string) *BoolCmd { + cmd := NewBoolCmd(ctx, "vsetattr", key, element, "") + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element)` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSim(ctx context.Context, key string, val Vector) *StringSliceCmd { + return c.VSimWithArgs(ctx, key, val, &VSimArgs{}) +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) WITHSCORES` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd { + return c.VSimWithArgsWithScores(ctx, key, val, &VSimArgs{}) +} + +type VSimArgs struct { + Count int64 + EF int64 + Filter string + FilterEF int64 + Truth bool + NoThread bool + Epsilon float64 +} + +func (v VSimArgs) appendArgs(args []any) []any { + if v.Count > 0 { + args = append(args, "count", v.Count) + } + if v.EF > 0 { + args = append(args, "ef", v.EF) + } + if len(v.Filter) > 0 { + args = append(args, "filter", v.Filter) + } + if v.FilterEF > 0 { + args = append(args, "filter-ef", v.FilterEF) + } + if v.Truth { + args = append(args, "truth") + } + if v.NoThread { + args = append(args, "nothread") + } + if v.Epsilon > 0 { + args = append(args, "epsilon", v.Epsilon) + } + return args +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [COUNT num] [EPSILON delta] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgs(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *StringSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = simArgs.appendArgs(args) + cmd := NewStringSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [WITHSCORES] [COUNT num] [EPSILON delta] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgsWithScores(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *VectorScoreSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = append(args, "withscores") + args = simArgs.appendArgs(args) + cmd := NewVectorInfoSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [WITHATTRIBS] [COUNT num] [EPSILON delta] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// WITHATTRIBS is only available in Redis v8.2.0+ +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgsWithAttribs(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *VectorAttribSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = append(args, "withattribs") + args = simArgs.appendArgs(args) + cmd := NewVectorAttribSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON delta] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// WITHATTRIBS is only available in Redis v8.2.0+ +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgsWithScoresWithAttribs(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *VectorScoreAttribSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = append(args, "withscores", "withattribs") + args = simArgs.appendArgs(args) + cmd := NewVectorScoreAttribSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VRANGE key start end count` +// a negative count means to return all the elements in the vector set. +// note: the API is experimental and may be subject to change. +func (c cmdable) VRange(ctx context.Context, key, start, end string, count int64) *StringSliceCmd { + args := []any{"vrange", key, start, end, count} + cmd := NewStringSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VISMEMBER key element` +// Check if an element exists in a vector set. +// note: the API is experimental and may be subject to change. +func (c cmdable) VIsMember(ctx context.Context, key, element string) *BoolCmd { + cmd := NewBoolCmd(ctx, "vismember", key, element) + _ = c(ctx, cmd) + return cmd +} diff --git a/vendor/github.com/redis/go-redis/v9/version.go b/vendor/github.com/redis/go-redis/v9/version.go index c56e04ff1..d59381170 100644 --- a/vendor/github.com/redis/go-redis/v9/version.go +++ b/vendor/github.com/redis/go-redis/v9/version.go @@ -2,5 +2,5 @@ package redis // Version is the current release version. func Version() string { - return "9.8.0" + return "9.20.1" } diff --git a/vendor/go.uber.org/atomic/.codecov.yml b/vendor/go.uber.org/atomic/.codecov.yml new file mode 100644 index 000000000..571116cc3 --- /dev/null +++ b/vendor/go.uber.org/atomic/.codecov.yml @@ -0,0 +1,19 @@ +coverage: + range: 80..100 + round: down + precision: 2 + + status: + project: # measuring the overall project coverage + default: # context, you can create multiple ones with custom titles + enabled: yes # must be yes|true to enable this status + target: 100 # specify the target coverage for each commit status + # option: "auto" (must increase from parent commit or pull request base) + # option: "X%" a static target percentage to hit + if_not_found: success # if parent is not found report status as success, error, or failure + if_ci_failed: error # if ci fails report status as success, error, or failure + +# Also update COVER_IGNORE_PKGS in the Makefile. +ignore: + - /internal/gen-atomicint/ + - /internal/gen-valuewrapper/ diff --git a/vendor/go.uber.org/atomic/.gitignore b/vendor/go.uber.org/atomic/.gitignore new file mode 100644 index 000000000..2e337a0ed --- /dev/null +++ b/vendor/go.uber.org/atomic/.gitignore @@ -0,0 +1,15 @@ +/bin +.DS_Store +/vendor +cover.html +cover.out +lint.log + +# Binaries +*.test + +# Profiling output +*.prof + +# Output of fossa analyzer +/fossa diff --git a/vendor/go.uber.org/atomic/CHANGELOG.md b/vendor/go.uber.org/atomic/CHANGELOG.md new file mode 100644 index 000000000..6f87f33fa --- /dev/null +++ b/vendor/go.uber.org/atomic/CHANGELOG.md @@ -0,0 +1,127 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.11.0] - 2023-05-02 +### Fixed +- Fix initialization of `Value` wrappers. + +### Added +- Add `String` method to `atomic.Pointer[T]` type allowing users to safely print +underlying values of pointers. + +[1.11.0]: https://github.com/uber-go/atomic/compare/v1.10.0...v1.11.0 + +## [1.10.0] - 2022-08-11 +### Added +- Add `atomic.Float32` type for atomic operations on `float32`. +- Add `CompareAndSwap` and `Swap` methods to `atomic.String`, `atomic.Error`, + and `atomic.Value`. +- Add generic `atomic.Pointer[T]` type for atomic operations on pointers of any + type. This is present only for Go 1.18 or higher, and is a drop-in for + replacement for the standard library's `sync/atomic.Pointer` type. + +### Changed +- Deprecate `CAS` methods on all types in favor of corresponding + `CompareAndSwap` methods. + +Thanks to @eNV25 and @icpd for their contributions to this release. + +[1.10.0]: https://github.com/uber-go/atomic/compare/v1.9.0...v1.10.0 + +## [1.9.0] - 2021-07-15 +### Added +- Add `Float64.Swap` to match int atomic operations. +- Add `atomic.Time` type for atomic operations on `time.Time` values. + +[1.9.0]: https://github.com/uber-go/atomic/compare/v1.8.0...v1.9.0 + +## [1.8.0] - 2021-06-09 +### Added +- Add `atomic.Uintptr` type for atomic operations on `uintptr` values. +- Add `atomic.UnsafePointer` type for atomic operations on `unsafe.Pointer` values. + +[1.8.0]: https://github.com/uber-go/atomic/compare/v1.7.0...v1.8.0 + +## [1.7.0] - 2020-09-14 +### Added +- Support JSON serialization and deserialization of primitive atomic types. +- Support Text marshalling and unmarshalling for string atomics. + +### Changed +- Disallow incorrect comparison of atomic values in a non-atomic way. + +### Removed +- Remove dependency on `golang.org/x/{lint, tools}`. + +[1.7.0]: https://github.com/uber-go/atomic/compare/v1.6.0...v1.7.0 + +## [1.6.0] - 2020-02-24 +### Changed +- Drop library dependency on `golang.org/x/{lint, tools}`. + +[1.6.0]: https://github.com/uber-go/atomic/compare/v1.5.1...v1.6.0 + +## [1.5.1] - 2019-11-19 +- Fix bug where `Bool.CAS` and `Bool.Toggle` do work correctly together + causing `CAS` to fail even though the old value matches. + +[1.5.1]: https://github.com/uber-go/atomic/compare/v1.5.0...v1.5.1 + +## [1.5.0] - 2019-10-29 +### Changed +- With Go modules, only the `go.uber.org/atomic` import path is supported now. + If you need to use the old import path, please add a `replace` directive to + your `go.mod`. + +[1.5.0]: https://github.com/uber-go/atomic/compare/v1.4.0...v1.5.0 + +## [1.4.0] - 2019-05-01 +### Added + - Add `atomic.Error` type for atomic operations on `error` values. + +[1.4.0]: https://github.com/uber-go/atomic/compare/v1.3.2...v1.4.0 + +## [1.3.2] - 2018-05-02 +### Added +- Add `atomic.Duration` type for atomic operations on `time.Duration` values. + +[1.3.2]: https://github.com/uber-go/atomic/compare/v1.3.1...v1.3.2 + +## [1.3.1] - 2017-11-14 +### Fixed +- Revert optimization for `atomic.String.Store("")` which caused data races. + +[1.3.1]: https://github.com/uber-go/atomic/compare/v1.3.0...v1.3.1 + +## [1.3.0] - 2017-11-13 +### Added +- Add `atomic.Bool.CAS` for compare-and-swap semantics on bools. + +### Changed +- Optimize `atomic.String.Store("")` by avoiding an allocation. + +[1.3.0]: https://github.com/uber-go/atomic/compare/v1.2.0...v1.3.0 + +## [1.2.0] - 2017-04-12 +### Added +- Shadow `atomic.Value` from `sync/atomic`. + +[1.2.0]: https://github.com/uber-go/atomic/compare/v1.1.0...v1.2.0 + +## [1.1.0] - 2017-03-10 +### Added +- Add atomic `Float64` type. + +### Changed +- Support new `go.uber.org/atomic` import path. + +[1.1.0]: https://github.com/uber-go/atomic/compare/v1.0.0...v1.1.0 + +## [1.0.0] - 2016-07-18 + +- Initial release. + +[1.0.0]: https://github.com/uber-go/atomic/releases/tag/v1.0.0 diff --git a/vendor/github.com/dgryski/go-rendezvous/LICENSE b/vendor/go.uber.org/atomic/LICENSE.txt similarity index 92% rename from vendor/github.com/dgryski/go-rendezvous/LICENSE rename to vendor/go.uber.org/atomic/LICENSE.txt index 22080f736..8765c9fbc 100644 --- a/vendor/github.com/dgryski/go-rendezvous/LICENSE +++ b/vendor/go.uber.org/atomic/LICENSE.txt @@ -1,6 +1,4 @@ -The MIT License (MIT) - -Copyright (c) 2017-2020 Damian Gryski +Copyright (c) 2016 Uber Technologies, Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vendor/go.uber.org/atomic/Makefile b/vendor/go.uber.org/atomic/Makefile new file mode 100644 index 000000000..46c945b32 --- /dev/null +++ b/vendor/go.uber.org/atomic/Makefile @@ -0,0 +1,79 @@ +# Directory to place `go install`ed binaries into. +export GOBIN ?= $(shell pwd)/bin + +GOLINT = $(GOBIN)/golint +GEN_ATOMICINT = $(GOBIN)/gen-atomicint +GEN_ATOMICWRAPPER = $(GOBIN)/gen-atomicwrapper +STATICCHECK = $(GOBIN)/staticcheck + +GO_FILES ?= $(shell find . '(' -path .git -o -path vendor ')' -prune -o -name '*.go' -print) + +# Also update ignore section in .codecov.yml. +COVER_IGNORE_PKGS = \ + go.uber.org/atomic/internal/gen-atomicint \ + go.uber.org/atomic/internal/gen-atomicwrapper + +.PHONY: build +build: + go build ./... + +.PHONY: test +test: + go test -race ./... + +.PHONY: gofmt +gofmt: + $(eval FMT_LOG := $(shell mktemp -t gofmt.XXXXX)) + gofmt -e -s -l $(GO_FILES) > $(FMT_LOG) || true + @[ ! -s "$(FMT_LOG)" ] || (echo "gofmt failed:" && cat $(FMT_LOG) && false) + +$(GOLINT): + cd tools && go install golang.org/x/lint/golint + +$(STATICCHECK): + cd tools && go install honnef.co/go/tools/cmd/staticcheck + +$(GEN_ATOMICWRAPPER): $(wildcard ./internal/gen-atomicwrapper/*) + go build -o $@ ./internal/gen-atomicwrapper + +$(GEN_ATOMICINT): $(wildcard ./internal/gen-atomicint/*) + go build -o $@ ./internal/gen-atomicint + +.PHONY: golint +golint: $(GOLINT) + $(GOLINT) ./... + +.PHONY: staticcheck +staticcheck: $(STATICCHECK) + $(STATICCHECK) ./... + +.PHONY: lint +lint: gofmt golint staticcheck generatenodirty + +# comma separated list of packages to consider for code coverage. +COVER_PKG = $(shell \ + go list -find ./... | \ + grep -v $(foreach pkg,$(COVER_IGNORE_PKGS),-e "^$(pkg)$$") | \ + paste -sd, -) + +.PHONY: cover +cover: + go test -coverprofile=cover.out -coverpkg $(COVER_PKG) -v ./... + go tool cover -html=cover.out -o cover.html + +.PHONY: generate +generate: $(GEN_ATOMICINT) $(GEN_ATOMICWRAPPER) + go generate ./... + +.PHONY: generatenodirty +generatenodirty: + @[ -z "$$(git status --porcelain)" ] || ( \ + echo "Working tree is dirty. Commit your changes first."; \ + git status; \ + exit 1 ) + @make generate + @status=$$(git status --porcelain); \ + [ -z "$$status" ] || ( \ + echo "Working tree is dirty after `make generate`:"; \ + echo "$$status"; \ + echo "Please ensure that the generated code is up-to-date." ) diff --git a/vendor/go.uber.org/atomic/README.md b/vendor/go.uber.org/atomic/README.md new file mode 100644 index 000000000..96b47a1f1 --- /dev/null +++ b/vendor/go.uber.org/atomic/README.md @@ -0,0 +1,63 @@ +# atomic [![GoDoc][doc-img]][doc] [![Build Status][ci-img]][ci] [![Coverage Status][cov-img]][cov] [![Go Report Card][reportcard-img]][reportcard] + +Simple wrappers for primitive types to enforce atomic access. + +## Installation + +```shell +$ go get -u go.uber.org/atomic@v1 +``` + +### Legacy Import Path + +As of v1.5.0, the import path `go.uber.org/atomic` is the only supported way +of using this package. If you are using Go modules, this package will fail to +compile with the legacy import path path `github.com/uber-go/atomic`. + +We recommend migrating your code to the new import path but if you're unable +to do so, or if your dependencies are still using the old import path, you +will have to add a `replace` directive to your `go.mod` file downgrading the +legacy import path to an older version. + +``` +replace github.com/uber-go/atomic => github.com/uber-go/atomic v1.4.0 +``` + +You can do so automatically by running the following command. + +```shell +$ go mod edit -replace github.com/uber-go/atomic=github.com/uber-go/atomic@v1.4.0 +``` + +## Usage + +The standard library's `sync/atomic` is powerful, but it's easy to forget which +variables must be accessed atomically. `go.uber.org/atomic` preserves all the +functionality of the standard library, but wraps the primitive types to +provide a safer, more convenient API. + +```go +var atom atomic.Uint32 +atom.Store(42) +atom.Sub(2) +atom.CAS(40, 11) +``` + +See the [documentation][doc] for a complete API specification. + +## Development Status + +Stable. + +--- + +Released under the [MIT License](LICENSE.txt). + +[doc-img]: https://godoc.org/github.com/uber-go/atomic?status.svg +[doc]: https://godoc.org/go.uber.org/atomic +[ci-img]: https://github.com/uber-go/atomic/actions/workflows/go.yml/badge.svg +[ci]: https://github.com/uber-go/atomic/actions/workflows/go.yml +[cov-img]: https://codecov.io/gh/uber-go/atomic/branch/master/graph/badge.svg +[cov]: https://codecov.io/gh/uber-go/atomic +[reportcard-img]: https://goreportcard.com/badge/go.uber.org/atomic +[reportcard]: https://goreportcard.com/report/go.uber.org/atomic diff --git a/vendor/go.uber.org/atomic/bool.go b/vendor/go.uber.org/atomic/bool.go new file mode 100644 index 000000000..f0a2ddd14 --- /dev/null +++ b/vendor/go.uber.org/atomic/bool.go @@ -0,0 +1,88 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" +) + +// Bool is an atomic type-safe wrapper for bool values. +type Bool struct { + _ nocmp // disallow non-atomic comparison + + v Uint32 +} + +var _zeroBool bool + +// NewBool creates a new Bool. +func NewBool(val bool) *Bool { + x := &Bool{} + if val != _zeroBool { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped bool. +func (x *Bool) Load() bool { + return truthy(x.v.Load()) +} + +// Store atomically stores the passed bool. +func (x *Bool) Store(val bool) { + x.v.Store(boolToInt(val)) +} + +// CAS is an atomic compare-and-swap for bool values. +// +// Deprecated: Use CompareAndSwap. +func (x *Bool) CAS(old, new bool) (swapped bool) { + return x.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for bool values. +func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) { + return x.v.CompareAndSwap(boolToInt(old), boolToInt(new)) +} + +// Swap atomically stores the given bool and returns the old +// value. +func (x *Bool) Swap(val bool) (old bool) { + return truthy(x.v.Swap(boolToInt(val))) +} + +// MarshalJSON encodes the wrapped bool into JSON. +func (x *Bool) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a bool from JSON. +func (x *Bool) UnmarshalJSON(b []byte) error { + var v bool + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/vendor/go.uber.org/atomic/bool_ext.go b/vendor/go.uber.org/atomic/bool_ext.go new file mode 100644 index 000000000..a2e60e987 --- /dev/null +++ b/vendor/go.uber.org/atomic/bool_ext.go @@ -0,0 +1,53 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Bool -type=bool -wrapped=Uint32 -pack=boolToInt -unpack=truthy -cas -swap -json -file=bool.go + +func truthy(n uint32) bool { + return n == 1 +} + +func boolToInt(b bool) uint32 { + if b { + return 1 + } + return 0 +} + +// Toggle atomically negates the Boolean and returns the previous value. +func (b *Bool) Toggle() (old bool) { + for { + old := b.Load() + if b.CAS(old, !old) { + return old + } + } +} + +// String encodes the wrapped value as a string. +func (b *Bool) String() string { + return strconv.FormatBool(b.Load()) +} diff --git a/vendor/go.uber.org/atomic/doc.go b/vendor/go.uber.org/atomic/doc.go new file mode 100644 index 000000000..ae7390ee6 --- /dev/null +++ b/vendor/go.uber.org/atomic/doc.go @@ -0,0 +1,23 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Package atomic provides simple wrappers around numerics to enforce atomic +// access. +package atomic diff --git a/vendor/go.uber.org/atomic/duration.go b/vendor/go.uber.org/atomic/duration.go new file mode 100644 index 000000000..7c23868fc --- /dev/null +++ b/vendor/go.uber.org/atomic/duration.go @@ -0,0 +1,89 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "time" +) + +// Duration is an atomic type-safe wrapper for time.Duration values. +type Duration struct { + _ nocmp // disallow non-atomic comparison + + v Int64 +} + +var _zeroDuration time.Duration + +// NewDuration creates a new Duration. +func NewDuration(val time.Duration) *Duration { + x := &Duration{} + if val != _zeroDuration { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped time.Duration. +func (x *Duration) Load() time.Duration { + return time.Duration(x.v.Load()) +} + +// Store atomically stores the passed time.Duration. +func (x *Duration) Store(val time.Duration) { + x.v.Store(int64(val)) +} + +// CAS is an atomic compare-and-swap for time.Duration values. +// +// Deprecated: Use CompareAndSwap. +func (x *Duration) CAS(old, new time.Duration) (swapped bool) { + return x.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for time.Duration values. +func (x *Duration) CompareAndSwap(old, new time.Duration) (swapped bool) { + return x.v.CompareAndSwap(int64(old), int64(new)) +} + +// Swap atomically stores the given time.Duration and returns the old +// value. +func (x *Duration) Swap(val time.Duration) (old time.Duration) { + return time.Duration(x.v.Swap(int64(val))) +} + +// MarshalJSON encodes the wrapped time.Duration into JSON. +func (x *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a time.Duration from JSON. +func (x *Duration) UnmarshalJSON(b []byte) error { + var v time.Duration + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/vendor/go.uber.org/atomic/duration_ext.go b/vendor/go.uber.org/atomic/duration_ext.go new file mode 100644 index 000000000..4c18b0a9e --- /dev/null +++ b/vendor/go.uber.org/atomic/duration_ext.go @@ -0,0 +1,40 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "time" + +//go:generate bin/gen-atomicwrapper -name=Duration -type=time.Duration -wrapped=Int64 -pack=int64 -unpack=time.Duration -cas -swap -json -imports time -file=duration.go + +// Add atomically adds to the wrapped time.Duration and returns the new value. +func (d *Duration) Add(delta time.Duration) time.Duration { + return time.Duration(d.v.Add(int64(delta))) +} + +// Sub atomically subtracts from the wrapped time.Duration and returns the new value. +func (d *Duration) Sub(delta time.Duration) time.Duration { + return time.Duration(d.v.Sub(int64(delta))) +} + +// String encodes the wrapped value as a string. +func (d *Duration) String() string { + return d.Load().String() +} diff --git a/vendor/go.uber.org/atomic/error.go b/vendor/go.uber.org/atomic/error.go new file mode 100644 index 000000000..b7e3f1291 --- /dev/null +++ b/vendor/go.uber.org/atomic/error.go @@ -0,0 +1,72 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// Error is an atomic type-safe wrapper for error values. +type Error struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroError error + +// NewError creates a new Error. +func NewError(val error) *Error { + x := &Error{} + if val != _zeroError { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped error. +func (x *Error) Load() error { + return unpackError(x.v.Load()) +} + +// Store atomically stores the passed error. +func (x *Error) Store(val error) { + x.v.Store(packError(val)) +} + +// CompareAndSwap is an atomic compare-and-swap for error values. +func (x *Error) CompareAndSwap(old, new error) (swapped bool) { + if x.v.CompareAndSwap(packError(old), packError(new)) { + return true + } + + if old == _zeroError { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packError(new)) + } + + return false +} + +// Swap atomically stores the given error and returns the old +// value. +func (x *Error) Swap(val error) (old error) { + return unpackError(x.v.Swap(packError(val))) +} diff --git a/vendor/go.uber.org/atomic/error_ext.go b/vendor/go.uber.org/atomic/error_ext.go new file mode 100644 index 000000000..d31fb633b --- /dev/null +++ b/vendor/go.uber.org/atomic/error_ext.go @@ -0,0 +1,39 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// atomic.Value panics on nil inputs, or if the underlying type changes. +// Stabilize by always storing a custom struct that we control. + +//go:generate bin/gen-atomicwrapper -name=Error -type=error -wrapped=Value -pack=packError -unpack=unpackError -compareandswap -swap -file=error.go + +type packedError struct{ Value error } + +func packError(v error) interface{} { + return packedError{v} +} + +func unpackError(v interface{}) error { + if err, ok := v.(packedError); ok { + return err.Value + } + return nil +} diff --git a/vendor/go.uber.org/atomic/float32.go b/vendor/go.uber.org/atomic/float32.go new file mode 100644 index 000000000..62c36334f --- /dev/null +++ b/vendor/go.uber.org/atomic/float32.go @@ -0,0 +1,77 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" +) + +// Float32 is an atomic type-safe wrapper for float32 values. +type Float32 struct { + _ nocmp // disallow non-atomic comparison + + v Uint32 +} + +var _zeroFloat32 float32 + +// NewFloat32 creates a new Float32. +func NewFloat32(val float32) *Float32 { + x := &Float32{} + if val != _zeroFloat32 { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped float32. +func (x *Float32) Load() float32 { + return math.Float32frombits(x.v.Load()) +} + +// Store atomically stores the passed float32. +func (x *Float32) Store(val float32) { + x.v.Store(math.Float32bits(val)) +} + +// Swap atomically stores the given float32 and returns the old +// value. +func (x *Float32) Swap(val float32) (old float32) { + return math.Float32frombits(x.v.Swap(math.Float32bits(val))) +} + +// MarshalJSON encodes the wrapped float32 into JSON. +func (x *Float32) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a float32 from JSON. +func (x *Float32) UnmarshalJSON(b []byte) error { + var v float32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/vendor/go.uber.org/atomic/float32_ext.go b/vendor/go.uber.org/atomic/float32_ext.go new file mode 100644 index 000000000..b0cd8d9c8 --- /dev/null +++ b/vendor/go.uber.org/atomic/float32_ext.go @@ -0,0 +1,76 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "math" + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Float32 -type=float32 -wrapped=Uint32 -pack=math.Float32bits -unpack=math.Float32frombits -swap -json -imports math -file=float32.go + +// Add atomically adds to the wrapped float32 and returns the new value. +func (f *Float32) Add(delta float32) float32 { + for { + old := f.Load() + new := old + delta + if f.CAS(old, new) { + return new + } + } +} + +// Sub atomically subtracts from the wrapped float32 and returns the new value. +func (f *Float32) Sub(delta float32) float32 { + return f.Add(-delta) +} + +// CAS is an atomic compare-and-swap for float32 values. +// +// Deprecated: Use CompareAndSwap +func (f *Float32) CAS(old, new float32) (swapped bool) { + return f.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for float32 values. +// +// Note: CompareAndSwap handles NaN incorrectly. NaN != NaN using Go's inbuilt operators +// but CompareAndSwap allows a stored NaN to compare equal to a passed in NaN. +// This avoids typical CompareAndSwap loops from blocking forever, e.g., +// +// for { +// old := atom.Load() +// new = f(old) +// if atom.CompareAndSwap(old, new) { +// break +// } +// } +// +// If CompareAndSwap did not match NaN to match, then the above would loop forever. +func (f *Float32) CompareAndSwap(old, new float32) (swapped bool) { + return f.v.CompareAndSwap(math.Float32bits(old), math.Float32bits(new)) +} + +// String encodes the wrapped value as a string. +func (f *Float32) String() string { + // 'g' is the behavior for floats with %v. + return strconv.FormatFloat(float64(f.Load()), 'g', -1, 32) +} diff --git a/vendor/go.uber.org/atomic/float64.go b/vendor/go.uber.org/atomic/float64.go new file mode 100644 index 000000000..5bc11caab --- /dev/null +++ b/vendor/go.uber.org/atomic/float64.go @@ -0,0 +1,77 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" +) + +// Float64 is an atomic type-safe wrapper for float64 values. +type Float64 struct { + _ nocmp // disallow non-atomic comparison + + v Uint64 +} + +var _zeroFloat64 float64 + +// NewFloat64 creates a new Float64. +func NewFloat64(val float64) *Float64 { + x := &Float64{} + if val != _zeroFloat64 { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped float64. +func (x *Float64) Load() float64 { + return math.Float64frombits(x.v.Load()) +} + +// Store atomically stores the passed float64. +func (x *Float64) Store(val float64) { + x.v.Store(math.Float64bits(val)) +} + +// Swap atomically stores the given float64 and returns the old +// value. +func (x *Float64) Swap(val float64) (old float64) { + return math.Float64frombits(x.v.Swap(math.Float64bits(val))) +} + +// MarshalJSON encodes the wrapped float64 into JSON. +func (x *Float64) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a float64 from JSON. +func (x *Float64) UnmarshalJSON(b []byte) error { + var v float64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/vendor/go.uber.org/atomic/float64_ext.go b/vendor/go.uber.org/atomic/float64_ext.go new file mode 100644 index 000000000..48c52b0ab --- /dev/null +++ b/vendor/go.uber.org/atomic/float64_ext.go @@ -0,0 +1,76 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "math" + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Float64 -type=float64 -wrapped=Uint64 -pack=math.Float64bits -unpack=math.Float64frombits -swap -json -imports math -file=float64.go + +// Add atomically adds to the wrapped float64 and returns the new value. +func (f *Float64) Add(delta float64) float64 { + for { + old := f.Load() + new := old + delta + if f.CAS(old, new) { + return new + } + } +} + +// Sub atomically subtracts from the wrapped float64 and returns the new value. +func (f *Float64) Sub(delta float64) float64 { + return f.Add(-delta) +} + +// CAS is an atomic compare-and-swap for float64 values. +// +// Deprecated: Use CompareAndSwap +func (f *Float64) CAS(old, new float64) (swapped bool) { + return f.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for float64 values. +// +// Note: CompareAndSwap handles NaN incorrectly. NaN != NaN using Go's inbuilt operators +// but CompareAndSwap allows a stored NaN to compare equal to a passed in NaN. +// This avoids typical CompareAndSwap loops from blocking forever, e.g., +// +// for { +// old := atom.Load() +// new = f(old) +// if atom.CompareAndSwap(old, new) { +// break +// } +// } +// +// If CompareAndSwap did not match NaN to match, then the above would loop forever. +func (f *Float64) CompareAndSwap(old, new float64) (swapped bool) { + return f.v.CompareAndSwap(math.Float64bits(old), math.Float64bits(new)) +} + +// String encodes the wrapped value as a string. +func (f *Float64) String() string { + // 'g' is the behavior for floats with %v. + return strconv.FormatFloat(f.Load(), 'g', -1, 64) +} diff --git a/vendor/go.uber.org/atomic/gen.go b/vendor/go.uber.org/atomic/gen.go new file mode 100644 index 000000000..1e9ef4f87 --- /dev/null +++ b/vendor/go.uber.org/atomic/gen.go @@ -0,0 +1,27 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +//go:generate bin/gen-atomicint -name=Int32 -wrapped=int32 -file=int32.go +//go:generate bin/gen-atomicint -name=Int64 -wrapped=int64 -file=int64.go +//go:generate bin/gen-atomicint -name=Uint32 -wrapped=uint32 -unsigned -file=uint32.go +//go:generate bin/gen-atomicint -name=Uint64 -wrapped=uint64 -unsigned -file=uint64.go +//go:generate bin/gen-atomicint -name=Uintptr -wrapped=uintptr -unsigned -file=uintptr.go diff --git a/vendor/go.uber.org/atomic/int32.go b/vendor/go.uber.org/atomic/int32.go new file mode 100644 index 000000000..5320eac10 --- /dev/null +++ b/vendor/go.uber.org/atomic/int32.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Int32 is an atomic wrapper around int32. +type Int32 struct { + _ nocmp // disallow non-atomic comparison + + v int32 +} + +// NewInt32 creates a new Int32. +func NewInt32(val int32) *Int32 { + return &Int32{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Int32) Load() int32 { + return atomic.LoadInt32(&i.v) +} + +// Add atomically adds to the wrapped int32 and returns the new value. +func (i *Int32) Add(delta int32) int32 { + return atomic.AddInt32(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped int32 and returns the new value. +func (i *Int32) Sub(delta int32) int32 { + return atomic.AddInt32(&i.v, -delta) +} + +// Inc atomically increments the wrapped int32 and returns the new value. +func (i *Int32) Inc() int32 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped int32 and returns the new value. +func (i *Int32) Dec() int32 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Int32) CAS(old, new int32) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Int32) CompareAndSwap(old, new int32) (swapped bool) { + return atomic.CompareAndSwapInt32(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Int32) Store(val int32) { + atomic.StoreInt32(&i.v, val) +} + +// Swap atomically swaps the wrapped int32 and returns the old value. +func (i *Int32) Swap(val int32) (old int32) { + return atomic.SwapInt32(&i.v, val) +} + +// MarshalJSON encodes the wrapped int32 into JSON. +func (i *Int32) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped int32. +func (i *Int32) UnmarshalJSON(b []byte) error { + var v int32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Int32) String() string { + v := i.Load() + return strconv.FormatInt(int64(v), 10) +} diff --git a/vendor/go.uber.org/atomic/int64.go b/vendor/go.uber.org/atomic/int64.go new file mode 100644 index 000000000..460821d00 --- /dev/null +++ b/vendor/go.uber.org/atomic/int64.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Int64 is an atomic wrapper around int64. +type Int64 struct { + _ nocmp // disallow non-atomic comparison + + v int64 +} + +// NewInt64 creates a new Int64. +func NewInt64(val int64) *Int64 { + return &Int64{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Int64) Load() int64 { + return atomic.LoadInt64(&i.v) +} + +// Add atomically adds to the wrapped int64 and returns the new value. +func (i *Int64) Add(delta int64) int64 { + return atomic.AddInt64(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped int64 and returns the new value. +func (i *Int64) Sub(delta int64) int64 { + return atomic.AddInt64(&i.v, -delta) +} + +// Inc atomically increments the wrapped int64 and returns the new value. +func (i *Int64) Inc() int64 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped int64 and returns the new value. +func (i *Int64) Dec() int64 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Int64) CAS(old, new int64) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Int64) CompareAndSwap(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Int64) Store(val int64) { + atomic.StoreInt64(&i.v, val) +} + +// Swap atomically swaps the wrapped int64 and returns the old value. +func (i *Int64) Swap(val int64) (old int64) { + return atomic.SwapInt64(&i.v, val) +} + +// MarshalJSON encodes the wrapped int64 into JSON. +func (i *Int64) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped int64. +func (i *Int64) UnmarshalJSON(b []byte) error { + var v int64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Int64) String() string { + v := i.Load() + return strconv.FormatInt(int64(v), 10) +} diff --git a/vendor/go.uber.org/atomic/nocmp.go b/vendor/go.uber.org/atomic/nocmp.go new file mode 100644 index 000000000..54b74174a --- /dev/null +++ b/vendor/go.uber.org/atomic/nocmp.go @@ -0,0 +1,35 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// nocmp is an uncomparable struct. Embed this inside another struct to make +// it uncomparable. +// +// type Foo struct { +// nocmp +// // ... +// } +// +// This DOES NOT: +// +// - Disallow shallow copies of structs +// - Disallow comparison of pointers to uncomparable structs +type nocmp [0]func() diff --git a/vendor/go.uber.org/atomic/pointer_go118.go b/vendor/go.uber.org/atomic/pointer_go118.go new file mode 100644 index 000000000..1fb6c03b2 --- /dev/null +++ b/vendor/go.uber.org/atomic/pointer_go118.go @@ -0,0 +1,31 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:build go1.18 +// +build go1.18 + +package atomic + +import "fmt" + +// String returns a human readable representation of a Pointer's underlying value. +func (p *Pointer[T]) String() string { + return fmt.Sprint(p.Load()) +} diff --git a/vendor/go.uber.org/atomic/pointer_go118_pre119.go b/vendor/go.uber.org/atomic/pointer_go118_pre119.go new file mode 100644 index 000000000..e0f47dba4 --- /dev/null +++ b/vendor/go.uber.org/atomic/pointer_go118_pre119.go @@ -0,0 +1,60 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:build go1.18 && !go1.19 +// +build go1.18,!go1.19 + +package atomic + +import "unsafe" + +type Pointer[T any] struct { + _ nocmp // disallow non-atomic comparison + p UnsafePointer +} + +// NewPointer creates a new Pointer. +func NewPointer[T any](v *T) *Pointer[T] { + var p Pointer[T] + if v != nil { + p.p.Store(unsafe.Pointer(v)) + } + return &p +} + +// Load atomically loads the wrapped value. +func (p *Pointer[T]) Load() *T { + return (*T)(p.p.Load()) +} + +// Store atomically stores the passed value. +func (p *Pointer[T]) Store(val *T) { + p.p.Store(unsafe.Pointer(val)) +} + +// Swap atomically swaps the wrapped pointer and returns the old value. +func (p *Pointer[T]) Swap(val *T) (old *T) { + return (*T)(p.p.Swap(unsafe.Pointer(val))) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (p *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) { + return p.p.CompareAndSwap(unsafe.Pointer(old), unsafe.Pointer(new)) +} diff --git a/vendor/go.uber.org/atomic/pointer_go119.go b/vendor/go.uber.org/atomic/pointer_go119.go new file mode 100644 index 000000000..6726f17ad --- /dev/null +++ b/vendor/go.uber.org/atomic/pointer_go119.go @@ -0,0 +1,61 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:build go1.19 +// +build go1.19 + +package atomic + +import "sync/atomic" + +// Pointer is an atomic pointer of type *T. +type Pointer[T any] struct { + _ nocmp // disallow non-atomic comparison + p atomic.Pointer[T] +} + +// NewPointer creates a new Pointer. +func NewPointer[T any](v *T) *Pointer[T] { + var p Pointer[T] + if v != nil { + p.p.Store(v) + } + return &p +} + +// Load atomically loads the wrapped value. +func (p *Pointer[T]) Load() *T { + return p.p.Load() +} + +// Store atomically stores the passed value. +func (p *Pointer[T]) Store(val *T) { + p.p.Store(val) +} + +// Swap atomically swaps the wrapped pointer and returns the old value. +func (p *Pointer[T]) Swap(val *T) (old *T) { + return p.p.Swap(val) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (p *Pointer[T]) CompareAndSwap(old, new *T) (swapped bool) { + return p.p.CompareAndSwap(old, new) +} diff --git a/vendor/go.uber.org/atomic/string.go b/vendor/go.uber.org/atomic/string.go new file mode 100644 index 000000000..061466c5b --- /dev/null +++ b/vendor/go.uber.org/atomic/string.go @@ -0,0 +1,72 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// String is an atomic type-safe wrapper for string values. +type String struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroString string + +// NewString creates a new String. +func NewString(val string) *String { + x := &String{} + if val != _zeroString { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped string. +func (x *String) Load() string { + return unpackString(x.v.Load()) +} + +// Store atomically stores the passed string. +func (x *String) Store(val string) { + x.v.Store(packString(val)) +} + +// CompareAndSwap is an atomic compare-and-swap for string values. +func (x *String) CompareAndSwap(old, new string) (swapped bool) { + if x.v.CompareAndSwap(packString(old), packString(new)) { + return true + } + + if old == _zeroString { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packString(new)) + } + + return false +} + +// Swap atomically stores the given string and returns the old +// value. +func (x *String) Swap(val string) (old string) { + return unpackString(x.v.Swap(packString(val))) +} diff --git a/vendor/go.uber.org/atomic/string_ext.go b/vendor/go.uber.org/atomic/string_ext.go new file mode 100644 index 000000000..019109c86 --- /dev/null +++ b/vendor/go.uber.org/atomic/string_ext.go @@ -0,0 +1,54 @@ +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped Value -pack packString -unpack unpackString -compareandswap -swap -file=string.go + +func packString(s string) interface{} { + return s +} + +func unpackString(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +// String returns the wrapped value. +func (s *String) String() string { + return s.Load() +} + +// MarshalText encodes the wrapped string into a textual form. +// +// This makes it encodable as JSON, YAML, XML, and more. +func (s *String) MarshalText() ([]byte, error) { + return []byte(s.Load()), nil +} + +// UnmarshalText decodes text and replaces the wrapped string with it. +// +// This makes it decodable from JSON, YAML, XML, and more. +func (s *String) UnmarshalText(b []byte) error { + s.Store(string(b)) + return nil +} diff --git a/vendor/go.uber.org/atomic/time.go b/vendor/go.uber.org/atomic/time.go new file mode 100644 index 000000000..cc2a230c0 --- /dev/null +++ b/vendor/go.uber.org/atomic/time.go @@ -0,0 +1,55 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "time" +) + +// Time is an atomic type-safe wrapper for time.Time values. +type Time struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroTime time.Time + +// NewTime creates a new Time. +func NewTime(val time.Time) *Time { + x := &Time{} + if val != _zeroTime { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped time.Time. +func (x *Time) Load() time.Time { + return unpackTime(x.v.Load()) +} + +// Store atomically stores the passed time.Time. +func (x *Time) Store(val time.Time) { + x.v.Store(packTime(val)) +} diff --git a/vendor/go.uber.org/atomic/time_ext.go b/vendor/go.uber.org/atomic/time_ext.go new file mode 100644 index 000000000..1e3dc978a --- /dev/null +++ b/vendor/go.uber.org/atomic/time_ext.go @@ -0,0 +1,36 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "time" + +//go:generate bin/gen-atomicwrapper -name=Time -type=time.Time -wrapped=Value -pack=packTime -unpack=unpackTime -imports time -file=time.go + +func packTime(t time.Time) interface{} { + return t +} + +func unpackTime(v interface{}) time.Time { + if t, ok := v.(time.Time); ok { + return t + } + return time.Time{} +} diff --git a/vendor/go.uber.org/atomic/uint32.go b/vendor/go.uber.org/atomic/uint32.go new file mode 100644 index 000000000..4adc294ac --- /dev/null +++ b/vendor/go.uber.org/atomic/uint32.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uint32 is an atomic wrapper around uint32. +type Uint32 struct { + _ nocmp // disallow non-atomic comparison + + v uint32 +} + +// NewUint32 creates a new Uint32. +func NewUint32(val uint32) *Uint32 { + return &Uint32{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uint32) Load() uint32 { + return atomic.LoadUint32(&i.v) +} + +// Add atomically adds to the wrapped uint32 and returns the new value. +func (i *Uint32) Add(delta uint32) uint32 { + return atomic.AddUint32(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uint32 and returns the new value. +func (i *Uint32) Sub(delta uint32) uint32 { + return atomic.AddUint32(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uint32 and returns the new value. +func (i *Uint32) Inc() uint32 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uint32 and returns the new value. +func (i *Uint32) Dec() uint32 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uint32) CAS(old, new uint32) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uint32) CompareAndSwap(old, new uint32) (swapped bool) { + return atomic.CompareAndSwapUint32(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uint32) Store(val uint32) { + atomic.StoreUint32(&i.v, val) +} + +// Swap atomically swaps the wrapped uint32 and returns the old value. +func (i *Uint32) Swap(val uint32) (old uint32) { + return atomic.SwapUint32(&i.v, val) +} + +// MarshalJSON encodes the wrapped uint32 into JSON. +func (i *Uint32) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uint32. +func (i *Uint32) UnmarshalJSON(b []byte) error { + var v uint32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uint32) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/vendor/go.uber.org/atomic/uint64.go b/vendor/go.uber.org/atomic/uint64.go new file mode 100644 index 000000000..0e2eddb30 --- /dev/null +++ b/vendor/go.uber.org/atomic/uint64.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uint64 is an atomic wrapper around uint64. +type Uint64 struct { + _ nocmp // disallow non-atomic comparison + + v uint64 +} + +// NewUint64 creates a new Uint64. +func NewUint64(val uint64) *Uint64 { + return &Uint64{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uint64) Load() uint64 { + return atomic.LoadUint64(&i.v) +} + +// Add atomically adds to the wrapped uint64 and returns the new value. +func (i *Uint64) Add(delta uint64) uint64 { + return atomic.AddUint64(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uint64 and returns the new value. +func (i *Uint64) Sub(delta uint64) uint64 { + return atomic.AddUint64(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uint64 and returns the new value. +func (i *Uint64) Inc() uint64 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uint64 and returns the new value. +func (i *Uint64) Dec() uint64 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uint64) CAS(old, new uint64) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uint64) CompareAndSwap(old, new uint64) (swapped bool) { + return atomic.CompareAndSwapUint64(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uint64) Store(val uint64) { + atomic.StoreUint64(&i.v, val) +} + +// Swap atomically swaps the wrapped uint64 and returns the old value. +func (i *Uint64) Swap(val uint64) (old uint64) { + return atomic.SwapUint64(&i.v, val) +} + +// MarshalJSON encodes the wrapped uint64 into JSON. +func (i *Uint64) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uint64. +func (i *Uint64) UnmarshalJSON(b []byte) error { + var v uint64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uint64) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/vendor/go.uber.org/atomic/uintptr.go b/vendor/go.uber.org/atomic/uintptr.go new file mode 100644 index 000000000..7d5b000d6 --- /dev/null +++ b/vendor/go.uber.org/atomic/uintptr.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uintptr is an atomic wrapper around uintptr. +type Uintptr struct { + _ nocmp // disallow non-atomic comparison + + v uintptr +} + +// NewUintptr creates a new Uintptr. +func NewUintptr(val uintptr) *Uintptr { + return &Uintptr{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uintptr) Load() uintptr { + return atomic.LoadUintptr(&i.v) +} + +// Add atomically adds to the wrapped uintptr and returns the new value. +func (i *Uintptr) Add(delta uintptr) uintptr { + return atomic.AddUintptr(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uintptr and returns the new value. +func (i *Uintptr) Sub(delta uintptr) uintptr { + return atomic.AddUintptr(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uintptr and returns the new value. +func (i *Uintptr) Inc() uintptr { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uintptr and returns the new value. +func (i *Uintptr) Dec() uintptr { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uintptr) CAS(old, new uintptr) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) { + return atomic.CompareAndSwapUintptr(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uintptr) Store(val uintptr) { + atomic.StoreUintptr(&i.v, val) +} + +// Swap atomically swaps the wrapped uintptr and returns the old value. +func (i *Uintptr) Swap(val uintptr) (old uintptr) { + return atomic.SwapUintptr(&i.v, val) +} + +// MarshalJSON encodes the wrapped uintptr into JSON. +func (i *Uintptr) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uintptr. +func (i *Uintptr) UnmarshalJSON(b []byte) error { + var v uintptr + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uintptr) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/vendor/go.uber.org/atomic/unsafe_pointer.go b/vendor/go.uber.org/atomic/unsafe_pointer.go new file mode 100644 index 000000000..34868baf6 --- /dev/null +++ b/vendor/go.uber.org/atomic/unsafe_pointer.go @@ -0,0 +1,65 @@ +// Copyright (c) 2021-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "sync/atomic" + "unsafe" +) + +// UnsafePointer is an atomic wrapper around unsafe.Pointer. +type UnsafePointer struct { + _ nocmp // disallow non-atomic comparison + + v unsafe.Pointer +} + +// NewUnsafePointer creates a new UnsafePointer. +func NewUnsafePointer(val unsafe.Pointer) *UnsafePointer { + return &UnsafePointer{v: val} +} + +// Load atomically loads the wrapped value. +func (p *UnsafePointer) Load() unsafe.Pointer { + return atomic.LoadPointer(&p.v) +} + +// Store atomically stores the passed value. +func (p *UnsafePointer) Store(val unsafe.Pointer) { + atomic.StorePointer(&p.v, val) +} + +// Swap atomically swaps the wrapped unsafe.Pointer and returns the old value. +func (p *UnsafePointer) Swap(val unsafe.Pointer) (old unsafe.Pointer) { + return atomic.SwapPointer(&p.v, val) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap +func (p *UnsafePointer) CAS(old, new unsafe.Pointer) (swapped bool) { + return p.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (p *UnsafePointer) CompareAndSwap(old, new unsafe.Pointer) (swapped bool) { + return atomic.CompareAndSwapPointer(&p.v, old, new) +} diff --git a/vendor/go.uber.org/atomic/value.go b/vendor/go.uber.org/atomic/value.go new file mode 100644 index 000000000..52caedb9a --- /dev/null +++ b/vendor/go.uber.org/atomic/value.go @@ -0,0 +1,31 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "sync/atomic" + +// Value shadows the type of the same name from sync/atomic +// https://godoc.org/sync/atomic#Value +type Value struct { + _ nocmp // disallow non-atomic comparison + + atomic.Value +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 78d21eef9..33f70192d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -272,9 +272,6 @@ github.com/davecgh/go-spew/spew # github.com/davidmz/go-pageant v1.0.2 ## explicit; go 1.16 github.com/davidmz/go-pageant -# github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f -## explicit -github.com/dgryski/go-rendezvous # github.com/distribution/reference v0.6.0 ## explicit; go 1.20 github.com/distribution/reference @@ -735,16 +732,23 @@ github.com/prometheus/common/model github.com/prometheus/procfs github.com/prometheus/procfs/internal/fs github.com/prometheus/procfs/internal/util -# github.com/redis/go-redis/v9 v9.8.0 -## explicit; go 1.18 +# github.com/redis/go-redis/v9 v9.20.1 +## explicit; go 1.24 github.com/redis/go-redis/v9 +github.com/redis/go-redis/v9/auth github.com/redis/go-redis/v9/internal +github.com/redis/go-redis/v9/internal/auth/streaming github.com/redis/go-redis/v9/internal/hashtag github.com/redis/go-redis/v9/internal/hscan +github.com/redis/go-redis/v9/internal/interfaces +github.com/redis/go-redis/v9/internal/maintnotifications/logs +github.com/redis/go-redis/v9/internal/otel github.com/redis/go-redis/v9/internal/pool github.com/redis/go-redis/v9/internal/proto -github.com/redis/go-redis/v9/internal/rand +github.com/redis/go-redis/v9/internal/routing github.com/redis/go-redis/v9/internal/util +github.com/redis/go-redis/v9/maintnotifications +github.com/redis/go-redis/v9/push # github.com/robfig/cron/v3 v3.0.2-0.20210106135023-bc59245fe10e ## explicit; go 1.12 github.com/robfig/cron/v3 @@ -820,6 +824,9 @@ go.opentelemetry.io/otel/semconv/v1.40.0 go.opentelemetry.io/otel/trace go.opentelemetry.io/otel/trace/embedded go.opentelemetry.io/otel/trace/internal/telemetry +# go.uber.org/atomic v1.11.0 +## explicit; go 1.18 +go.uber.org/atomic # go.uber.org/mock v0.6.0 ## explicit; go 1.23.0 go.uber.org/mock/gomock