From 63e17d01328d9444520e570cb1a7c28245c6e497 Mon Sep 17 00:00:00 2001 From: Cosmin Staicu Date: Mon, 15 Jun 2026 12:24:53 +0300 Subject: [PATCH] redis: add credentials provider option and built-in Entra ID auth --- go.mod | 7 ++++ go.sum | 15 ++++++++ redis/redis.go | 69 ++++++++++++++++++++++++++++------ redis/redis_test.go | 92 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 12 deletions(-) create mode 100644 redis/redis_test.go diff --git a/go.mod b/go.mod index 7aad1df98..8083be557 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/prometheus/client_golang v1.23.2 github.com/prometheus/procfs v0.20.1 github.com/puzpuzpuz/xsync/v4 v4.5.0 + github.com/redis/go-redis-entraid v1.0.7 github.com/redis/go-redis/v9 v9.20.0 github.com/stretchr/testify v1.11.1 github.com/twitchtv/twirp v8.1.3+incompatible @@ -52,6 +53,10 @@ require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20260415201107-50325440f8f2.1 // indirect buf.build/go/protovalidate v1.2.0 // indirect cel.dev/expr v0.25.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect @@ -64,6 +69,7 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/klauspost/compress v1.18.6 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nats-io/nats.go v1.52.0 // indirect github.com/nats-io/nkeys v0.4.16 // indirect @@ -81,6 +87,7 @@ require ( github.com/pion/stun/v3 v3.1.4 // indirect github.com/pion/transport/v4 v4.0.2 // indirect github.com/pion/turn/v5 v5.0.8 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.68.1 // indirect diff --git a/go.sum b/go.sum index 23714658a..8d5724577 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,14 @@ buf.build/go/protoyaml v0.7.0 h1:z4oVoFicbpPefhT7WAykxUdfp0yEQlhMQ2mCZOY5V38= buf.build/go/protoyaml v0.7.0/go.mod h1:+a0cavd0uMvirb87xdu2ZMMmjlIQoiH/N2Ich5MGSQ0= cel.dev/expr v0.25.2 h1:K6j46C81hXtZQfuX60cVWQFBJahKSE2gfRbNuvr5bFs= cel.dev/expr v0.25.2/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0 h1:j8BorDEigD8UFOSZQiSqAMOOleyQOOQPnUAwV+Ls1gA= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.0/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= @@ -83,6 +91,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/lithammer/shortuuid/v4 v4.2.0 h1:LMFOzVB3996a7b8aBuEXxqOBflbfPQAiVzkIcHO0h8c= github.com/lithammer/shortuuid/v4 v4.2.0/go.mod h1:D5noHZ2oFw/YaKCfGy0YxyE7M0wMbezmMjPdhyEFe6Y= github.com/livekit/mageutil v0.0.0-20250511045019-0f1ff63f7731 h1:9x+U2HGLrSw5ATTo469PQPkqzdoU7be46ryiCDO3boc= @@ -159,6 +169,8 @@ github.com/pion/turn/v5 v5.0.8 h1:pZUCtmwWCMkrRKqh/8pL3WoGADXBe0/lOPkN7oqFjK8= github.com/pion/turn/v5 v5.0.8/go.mod h1:1VwvxElZaOdJU0liJ/WUSm/Tsh+n2OxS5ISSDxgOWxU= github.com/pion/webrtc/v4 v4.2.11 h1:QUX1QZKlNIn4O7U5JxLPGP0sV5RTncZkzu9SPR3jVNU= github.com/pion/webrtc/v4 v4.2.11/go.mod h1:s/rAiyy77GyRFrZMx+Ls6aua26dIBPudH8/ZHYbIRWY= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= @@ -173,6 +185,8 @@ github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEy github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo= github.com/puzpuzpuz/xsync/v4 v4.5.0 h1:vOSWu6b57/emh+L/Cw0BeQfvxa/cogFywXHeGUxQxAg= github.com/puzpuzpuz/xsync/v4 v4.5.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= +github.com/redis/go-redis-entraid v1.0.7 h1:cU+XXeCyZ8w1AcRSlCFHXMNJNmB4TA3hKwMe/HQQLTM= +github.com/redis/go-redis-entraid v1.0.7/go.mod h1:OS6s3V1DdSRzOJEIjpK38/w4chZpl/Sy+1pzby+6nEk= github.com/redis/go-redis/v9 v9.20.0 h1:WnQYxLkgO2xiXTCJY0ldIiI8dNqCDlQAG+AtaH7a2a0= github.com/redis/go-redis/v9 v9.20.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA= github.com/rodaine/protogofakeit v0.1.1 h1:ZKouljuRM3A+TArppfBqnH8tGZHOwM/pjvtXe9DaXH8= @@ -237,6 +251,7 @@ golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= diff --git a/redis/redis.go b/redis/redis.go index 6881ba1d0..b1da0189e 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -21,7 +21,9 @@ import ( "fmt" "time" + entraid "github.com/redis/go-redis-entraid" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/auth" "github.com/livekit/protocol/xtls" @@ -50,6 +52,7 @@ type RedisConfig struct { MaxRedirects *int `yaml:"max_redirects,omitempty"` PoolTimeout time.Duration `yaml:"pool_timeout,omitempty"` PoolSize int `yaml:"pool_size,omitempty"` + AzureEntra bool `yaml:"azure_entra,omitempty"` } func (r *RedisConfig) IsConfigured() bool { @@ -72,19 +75,26 @@ func (r *RedisConfig) GetMaxRedirects() int { return 2 } -func GetRedisClient(conf *RedisConfig) (redis.UniversalClient, error) { - if conf == nil { - return nil, nil - } +type clientOptions struct { + streamingCredentialsProvider auth.StreamingCredentialsProvider +} - if !conf.IsConfigured() { - return nil, ErrNotConfigured +type Option func(*clientOptions) + +func WithStreamingCredentialsProvider(p auth.StreamingCredentialsProvider) Option { + return func(o *clientOptions) { + o.streamingCredentialsProvider = p } +} - var rcOptions *redis.UniversalOptions - var rc redis.UniversalClient - var tlsConfig *tls.Config +var azureEntraProviderFactory = newAzureEntraCredentialsProvider + +func newAzureEntraCredentialsProvider() (auth.StreamingCredentialsProvider, error) { + return entraid.NewDefaultAzureCredentialsProvider(entraid.DefaultAzureCredentialsProviderOptions{}) +} +func buildRedisOptions(conf *RedisConfig, co clientOptions) (*redis.UniversalOptions, error) { + var tlsConfig *tls.Config if conf.TLS != nil && conf.TLS.Enabled { var err error tlsConfig, err = conf.TLS.ClientTLSConfig() @@ -97,6 +107,7 @@ func GetRedisClient(conf *RedisConfig) (redis.UniversalClient, error) { } } + var rcOptions *redis.UniversalOptions if len(conf.SentinelAddresses) > 0 { logger.Infow("connecting to redis", "sentinel", true, "addr", conf.SentinelAddresses, "masterName", conf.MasterName) @@ -153,12 +164,46 @@ func GetRedisClient(conf *RedisConfig) (redis.UniversalClient, error) { PoolSize: conf.PoolSize, } } - rc = redis.NewUniversalClient(rcOptions) - if err := rc.Ping(context.Background()).Err(); err != nil { - err = fmt.Errorf("unable to connect to redis: %w", err) + provider := co.streamingCredentialsProvider + if provider == nil && conf.AzureEntra { + p, err := azureEntraProviderFactory() + if err != nil { + return nil, fmt.Errorf("unable to create Azure Entra credentials provider: %w", err) + } + provider = p + } + if provider != nil { + rcOptions.StreamingCredentialsProvider = provider + } + + return rcOptions, nil +} + +func GetRedisClient(conf *RedisConfig, opts ...Option) (redis.UniversalClient, error) { + if conf == nil { + return nil, nil + } + + if !conf.IsConfigured() { + return nil, ErrNotConfigured + } + + var co clientOptions + for _, opt := range opts { + opt(&co) + } + + rcOptions, err := buildRedisOptions(conf, co) + if err != nil { return nil, err } + rc := redis.NewUniversalClient(rcOptions) + + if err := rc.Ping(context.Background()).Err(); err != nil { + return nil, fmt.Errorf("unable to connect to redis: %w", err) + } + return rc, nil } diff --git a/redis/redis_test.go b/redis/redis_test.go new file mode 100644 index 000000000..24442c9e2 --- /dev/null +++ b/redis/redis_test.go @@ -0,0 +1,92 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package redis + +import ( + "testing" + + "github.com/redis/go-redis/v9/auth" + "github.com/stretchr/testify/require" +) + +type fakeStreamingProvider struct{} + +func (f *fakeStreamingProvider) Subscribe(_ auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) { + return nil, nil, nil +} + +func TestBuildRedisOptions_ClusterWithoutProvider(t *testing.T) { + opts, err := buildRedisOptions(&RedisConfig{ + ClusterAddresses: []string{"host:10000"}, + Username: "user", + Password: "pass", + UseTLS: true, + }, clientOptions{}) + require.NoError(t, err) + require.True(t, opts.IsClusterMode) + require.Nil(t, opts.StreamingCredentialsProvider) + require.Equal(t, "user", opts.Username) + require.Equal(t, "pass", opts.Password) + require.NotNil(t, opts.TLSConfig) +} + +func TestBuildRedisOptions_WithStreamingCredentialsProvider(t *testing.T) { + fake := &fakeStreamingProvider{} + var co clientOptions + WithStreamingCredentialsProvider(fake)(&co) + + opts, err := buildRedisOptions(&RedisConfig{ + ClusterAddresses: []string{"host:10000"}, + }, co) + require.NoError(t, err) + require.Same(t, fake, opts.StreamingCredentialsProvider) +} + +func TestBuildRedisOptions_AzureEntraFlag(t *testing.T) { + fake := &fakeStreamingProvider{} + orig := azureEntraProviderFactory + azureEntraProviderFactory = func() (auth.StreamingCredentialsProvider, error) { + return fake, nil + } + t.Cleanup(func() { azureEntraProviderFactory = orig }) + + opts, err := buildRedisOptions(&RedisConfig{ + ClusterAddresses: []string{"host:10000"}, + UseTLS: true, + AzureEntra: true, + }, clientOptions{}) + require.NoError(t, err) + require.Same(t, fake, opts.StreamingCredentialsProvider) +} + +func TestBuildRedisOptions_ExplicitProviderBeatsAzureEntraFlag(t *testing.T) { + explicit := &fakeStreamingProvider{} + azureFromFactory := &fakeStreamingProvider{} + orig := azureEntraProviderFactory + azureEntraProviderFactory = func() (auth.StreamingCredentialsProvider, error) { + return azureFromFactory, nil + } + t.Cleanup(func() { azureEntraProviderFactory = orig }) + + var co clientOptions + WithStreamingCredentialsProvider(explicit)(&co) + + opts, err := buildRedisOptions(&RedisConfig{ + ClusterAddresses: []string{"host:10000"}, + AzureEntra: true, + }, co) + require.NoError(t, err) + require.Same(t, explicit, opts.StreamingCredentialsProvider) +}