From af8889df0719e57f9f89e7d770869f696b6511a2 Mon Sep 17 00:00:00 2001 From: petruki <31597636+petruki@users.noreply.github.com> Date: Mon, 11 May 2026 20:19:45 -0700 Subject: [PATCH] fix: simplified remote APIs, timeout settings, replaced tests --- README.md | 8 +- client_test.go | 16 ++++ context.go | 20 +--- context_test.go | 26 ++--- remote.go | 35 ++----- remote_test.go | 241 ++++++++++++++++++++++++++++------------------- switcher.go | 30 +++--- switcher_test.go | 15 --- 8 files changed, 200 insertions(+), 191 deletions(-) diff --git a/README.md b/README.md index 2c55409..df88759 100644 --- a/README.md +++ b/README.md @@ -189,9 +189,7 @@ func main() { Remote: client.RemoteOptions{ CertPath: "./certs/ca.pem", ConnectTimeout: 300 * time.Millisecond, - ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, - PoolTimeout: 5 * time.Second, + Timeout: 5 * time.Second, }, }, }) @@ -223,9 +221,7 @@ func main() { |--------|------|-------------|---------| | `CertPath` | `string` | Path to custom certificate for secure API connections | `""` | | `ConnectTimeout` | `time.Duration` | Max time to establish a remote connection before failing fast | `300ms` | -| `ReadTimeout` | `time.Duration` | Max time to wait for remote response data | `5s` | -| `WriteTimeout` | `time.Duration` | Max time to send remote request data | `5s` | -| `PoolTimeout` | `time.Duration` | Max time to wait for a pooled HTTP connection | `5s` | +| `Timeout` | `time.Duration` | Max time for remote request/response and idle connection reuse | `5s` | **Under development:** transport errors are normalized into typed SDK errors, and silent mode uses the configured remote timeouts to fail fast and switch back to local evaluation. diff --git a/client_test.go b/client_test.go index eb16392..37eb7b6 100644 --- a/client_test.go +++ b/client_test.go @@ -22,6 +22,22 @@ func TestClientGetSwitcher(t *testing.T) { assert.NotSame(t, switcher1, switcher3, "expected different instances for different keys") }) + t.Run("should replace the cached switchers after rebuilding the default context", func(t *testing.T) { + BuildContext(Context{ + Domain: "First Domain", + }) + + first := GetSwitcher("switcher1") + + BuildContext(Context{ + Domain: "Second Domain", + }) + + second := GetSwitcher("switcher1") + + assert.NotSame(t, first, second, "expected a new cached switcher after rebuilding the default context") + }) + t.Run("should return the cached instance after concurrent insert", func(t *testing.T) { client := NewClient(Context{Domain: "My Domain"}) diff --git a/context.go b/context.go index 46a1551..e9426e0 100644 --- a/context.go +++ b/context.go @@ -11,17 +11,13 @@ const ( DefaultRegexMaxBlacklist = 100 DefaultRegexMaxTimeLimit = 3 * time.Second DefaultRemoteConnectTimeout = 300 * time.Millisecond - DefaultRemoteReadTimeout = 5 * time.Second - DefaultRemoteWriteTimeout = 5 * time.Second - DefaultRemotePoolTimeout = 5 * time.Second + DefaultRemoteTimeout = 5 * time.Second ) type RemoteOptions struct { CertPath string ConnectTimeout time.Duration - ReadTimeout time.Duration - WriteTimeout time.Duration - PoolTimeout time.Duration + Timeout time.Duration } type ContextOptions struct { @@ -78,16 +74,8 @@ func (o RemoteOptions) withDefaults() RemoteOptions { o.ConnectTimeout = DefaultRemoteConnectTimeout } - if o.ReadTimeout == 0 { - o.ReadTimeout = DefaultRemoteReadTimeout - } - - if o.WriteTimeout == 0 { - o.WriteTimeout = DefaultRemoteWriteTimeout - } - - if o.PoolTimeout == 0 { - o.PoolTimeout = DefaultRemotePoolTimeout + if o.Timeout == 0 { + o.Timeout = DefaultRemoteTimeout } return o diff --git a/context_test.go b/context_test.go index a147b90..997c940 100644 --- a/context_test.go +++ b/context_test.go @@ -8,7 +8,7 @@ import ( func TestBuildContext(t *testing.T) { t.Run("should preserve optional context options", func(t *testing.T) { - BuildContext(Context{ + client := NewClient(Context{ Domain: "My Domain", Options: ContextOptions{ Local: true, @@ -16,41 +16,41 @@ func TestBuildContext(t *testing.T) { }, }) - options := defaultClient().context.Options + options := client.Context().Options assert.True(t, options.Local, "expected Local option to be true") assert.Equal(t, "./tests/snapshots", options.SnapshotLocation, "expected SnapshotLocation to be './tests/snapshots'") }) t.Run("should create fresh default options on rebuild", func(t *testing.T) { - BuildContext(Context{Domain: "First Domain"}) + BuildContext(Context{ + Domain: "First Domain", + Options: ContextOptions{ + Local: true, + }, + }) firstClient := defaultClient() - firstClient.mu.Lock() - firstClient.context.Options.Local = true - firstClient.mu.Unlock() - BuildContext(Context{Domain: "Second Domain"}) secondClient := defaultClient() assert.NotSame(t, firstClient, secondClient, "expected different clients for different contexts") - assert.False(t, secondClient.context.Options.Local, "expected Local option to be false for the new client") + assert.True(t, firstClient.Context().Options.Local, "expected Local option to remain true for the first client") + assert.False(t, secondClient.Context().Options.Local, "expected Local option to be false for the new client") }) t.Run("should apply default values when omitted", func(t *testing.T) { - BuildContext(Context{ + client := NewClient(Context{ Domain: "My Domain", }) - ctx := defaultClient().context + ctx := client.Context() assert.Equal(t, DefaultEnvironment, ctx.Environment) assert.True(t, ctx.Options.RestrictRelay) assert.Equal(t, DefaultRegexMaxBlacklist, ctx.Options.RegexMaxBlacklist) assert.Equal(t, DefaultRegexMaxTimeLimit, ctx.Options.RegexMaxTimeLimit) assert.Equal(t, DefaultRemoteConnectTimeout, ctx.Options.Remote.ConnectTimeout) - assert.Equal(t, DefaultRemoteReadTimeout, ctx.Options.Remote.ReadTimeout) - assert.Equal(t, DefaultRemoteWriteTimeout, ctx.Options.Remote.WriteTimeout) - assert.Equal(t, DefaultRemotePoolTimeout, ctx.Options.Remote.PoolTimeout) + assert.Equal(t, DefaultRemoteTimeout, ctx.Options.Remote.Timeout) }) } diff --git a/remote.go b/remote.go index b2c86c5..335afc9 100644 --- a/remote.go +++ b/remote.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "cmp" "context" "crypto/tls" "encoding/json" @@ -134,15 +135,8 @@ func (c *Client) checkCriteria(token string, switcher *Switcher, showDetails boo } func (c *Client) doJSONRequest(method, endpoint string, payload any, headers map[string]string) (*http.Response, error) { - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - request, err := http.NewRequestWithContext(context.Background(), method, endpoint, bytes.NewReader(body)) - if err != nil { - return nil, err - } + body, _ := json.Marshal(payload) + request, _ := http.NewRequestWithContext(context.Background(), method, endpoint, bytes.NewReader(body)) for key, value := range headers { request.Header.Set(key, value) @@ -165,10 +159,8 @@ func (c *Client) httpClient() *http.Client { } transport := &http.Transport{ - DialContext: dialer.DialContext, - ResponseHeaderTimeout: ctx.Options.Remote.ReadTimeout, - TLSHandshakeTimeout: ctx.Options.Remote.ConnectTimeout, - IdleConnTimeout: ctx.Options.Remote.PoolTimeout, + DialContext: dialer.DialContext, + TLSHandshakeTimeout: ctx.Options.Remote.ConnectTimeout, TLSClientConfig: &tls.Config{ MinVersion: tls.VersionTLS12, }, @@ -176,21 +168,12 @@ func (c *Client) httpClient() *http.Client { c.httpClient_ = &http.Client{ Transport: transport, - Timeout: requestTimeout(ctx.Options.Remote), + Timeout: cmp.Or(ctx.Options.Remote.Timeout, DefaultRemoteTimeout), } return c.httpClient_ } -func requestTimeout(options RemoteOptions) time.Duration { - timeout := options.ConnectTimeout + options.ReadTimeout + options.WriteTimeout - if timeout <= 0 { - return DefaultRemoteConnectTimeout + DefaultRemoteReadTimeout + DefaultRemoteWriteTimeout - } - - return timeout -} - func missingTokenError(token string) error { if strings.TrimSpace(token) != "" { return nil @@ -200,11 +183,7 @@ func missingTokenError(token string) error { } func parseTokenExpiration(value json.Number) int64 { - parsed, err := value.Int64() - if err != nil { - return 0 - } - + parsed, _ := value.Int64() return parsed } diff --git a/remote_test.go b/remote_test.go index e4a6db1..7233ac2 100644 --- a/remote_test.go +++ b/remote_test.go @@ -12,11 +12,15 @@ import ( func TestSwitcherRemoteEvaluation(t *testing.T) { t.Run("should call the remote API with success", func(t *testing.T) { + var captured map[string]any server := newRemoteTestServer(t, remoteTestHandlers{ authStatus: http.StatusOK, authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, criteriaStatus: http.StatusOK, criteriaBody: map[string]any{"result": true}, + onCriteriaRequest: func(body map[string]any, _ *http.Request) { + captured = body + }, }) defer server.Close() @@ -26,6 +30,9 @@ func TestSwitcherRemoteEvaluation(t *testing.T) { assert.NoError(t, err) assert.True(t, got) + assert.Equal(t, map[string]any{ + "entry": []any{}, + }, captured) }) t.Run("should send input parameters to the remote criteria endpoint", func(t *testing.T) { @@ -57,6 +64,28 @@ func TestSwitcherRemoteEvaluation(t *testing.T) { }, captured) }) + t.Run("should return response from the remote API without requesting details", func(t *testing.T) { + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{ + "result": true, + }, + onCriteriaRequest: func(_ map[string]any, request *http.Request) { + assert.Equal(t, "false", request.URL.Query().Get("showReason")) + }, + }) + defer server.Close() + + client := newRemoteTestClient(server.URL) + + got, err := client.GetSwitcher("MY_SWITCHER").IsOn() + + assert.NoError(t, err) + assert.True(t, got) + }) + t.Run("should return detailed response from the remote API", func(t *testing.T) { server := newRemoteTestServer(t, remoteTestHandlers{ authStatus: http.StatusOK, @@ -88,6 +117,31 @@ func TestSwitcherRemoteEvaluation(t *testing.T) { }, got.ToMap()) }) + t.Run("should request details only for the detailed call on the same switcher", func(t *testing.T) { + showReasonValues := make([]string, 0, 2) + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true, "reason": "Success"}, + onCriteriaRequest: func(_ map[string]any, request *http.Request) { + showReasonValues = append(showReasonValues, request.URL.Query().Get("showReason")) + }, + }) + defer server.Close() + + client := newRemoteTestClient(server.URL) + switcher := client.GetSwitcher("MY_SWITCHER") + + _, detailErr := switcher.IsOnWithDetails() + got, err := switcher.IsOn() + + assert.NoError(t, detailErr) + assert.NoError(t, err) + assert.True(t, got) + assert.Equal(t, []string{"true", "false"}, showReasonValues) + }) + t.Run("should authenticate during prepare and reuse the prepared key", func(t *testing.T) { server := newRemoteTestServer(t, remoteTestHandlers{ authStatus: http.StatusOK, @@ -111,6 +165,94 @@ func TestSwitcherRemoteEvaluation(t *testing.T) { assert.True(t, got) }) + t.Run("should keep only the latest value check input", func(t *testing.T) { + var captured map[string]any + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + onCriteriaRequest: func(body map[string]any, _ *http.Request) { + captured = body + }, + }) + defer server.Close() + + client := newRemoteTestClient(server.URL) + + got, err := client.GetSwitcher("MY_SWITCHER").CheckValue("first").CheckValue("second").IsOn() + + assert.NoError(t, err) + assert.True(t, got) + assert.Equal(t, map[string]any{ + "entry": []any{ + map[string]any{ + "strategy": StrategyValue, + "input": "second", + }, + }, + }, captured) + }) + + t.Run("should reuse the token while a millisecond expiration is still valid", func(t *testing.T) { + authRequests := 0 + server := newRemoteTestServer(t, remoteTestHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).UnixMilli()}, + criteriaStatus: http.StatusOK, + criteriaBody: map[string]any{"result": true}, + onAuthRequest: func(_ *http.Request) { + authRequests++ + }, + }) + defer server.Close() + + client := newRemoteTestClient(server.URL) + switcher := client.GetSwitcher("MY_SWITCHER") + + first, firstErr := switcher.IsOn() + second, secondErr := switcher.IsOn() + + assert.NoError(t, firstErr) + assert.NoError(t, secondErr) + assert.True(t, first) + assert.True(t, second) + assert.Equal(t, 1, authRequests) + }) + + t.Run("should renew the token when the cached expiration is zero", func(t *testing.T) { + authRequests := 0 + mux := http.NewServeMux() + mux.HandleFunc("/criteria/auth", func(writer http.ResponseWriter, request *http.Request) { + authRequests++ + + payload := map[string]any{"token": "[new_token]", "exp": time.Now().Add(time.Hour).Unix()} + if authRequests == 1 { + payload = map[string]any{"token": "[expired_token]", "exp": 0} + } + + writeJSONResponse(t, writer, http.StatusOK, payload) + }) + mux.HandleFunc("/criteria", func(writer http.ResponseWriter, request *http.Request) { + writeJSONResponse(t, writer, http.StatusOK, map[string]any{"result": true}) + }) + + server := httptest.NewServer(mux) + defer server.Close() + + client := newRemoteTestClient(server.URL) + switcher := client.GetSwitcher("MY_SWITCHER") + + first, firstErr := switcher.IsOn() + second, secondErr := switcher.IsOn() + + assert.NoError(t, firstErr) + assert.NoError(t, secondErr) + assert.True(t, first) + assert.True(t, second) + assert.Equal(t, 2, authRequests) + }) + t.Run("should return an auth error when the API key is invalid", func(t *testing.T) { server := newRemoteTestServer(t, remoteTestHandlers{ authStatus: http.StatusUnauthorized, @@ -237,101 +379,6 @@ func TestSwitcherRemoteEvaluation(t *testing.T) { }) } -func TestClientDoJSONRequest(t *testing.T) { - t.Run("should return an error when the payload cannot be marshaled", func(t *testing.T) { - client := newRemoteTestClient("https://api.switcherapi.com") - - response, err := client.doJSONRequest( - http.MethodPost, - "https://api.switcherapi.com/criteria/auth", - map[string]any{ - "invalid": func() {}, - }, - map[string]string{ - "Content-Type": "application/json", - }, - ) - - assert.Nil(t, response) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unsupported type") - }) - - t.Run("should return an error when the request cannot be created", func(t *testing.T) { - client := newRemoteTestClient("https://api.switcherapi.com") - - response, err := client.doJSONRequest( - http.MethodPost, - "://bad-url", - map[string]any{ - "domain": "My Domain", - }, - map[string]string{ - "Content-Type": "application/json", - }, - ) - - assert.Nil(t, response) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing protocol scheme") - }) -} - -func TestRequestTimeout(t *testing.T) { - t.Run("should return the default combined timeout when configured timeout is zero", func(t *testing.T) { - got := requestTimeout(RemoteOptions{}) - - assert.Equal( - t, - DefaultRemoteConnectTimeout+DefaultRemoteReadTimeout+DefaultRemoteWriteTimeout, - got, - ) - }) - - t.Run("should return the default combined timeout when configured timeout is negative", func(t *testing.T) { - got := requestTimeout(RemoteOptions{ - ConnectTimeout: -time.Second, - }) - - assert.Equal( - t, - DefaultRemoteConnectTimeout+DefaultRemoteReadTimeout+DefaultRemoteWriteTimeout, - got, - ) - }) -} - -func TestParseTokenExpiration(t *testing.T) { - t.Run("should parse the expiration from a json number", func(t *testing.T) { - got := parseTokenExpiration(json.Number("1700000000")) - - assert.Equal(t, int64(1700000000), got) - }) - - t.Run("should return zero when the json number is invalid", func(t *testing.T) { - got := parseTokenExpiration(json.Number("invalid")) - - assert.Zero(t, got) - }) -} - -func TestTokenExpired(t *testing.T) { - t.Run("should treat zero expiration as expired", func(t *testing.T) { - assert.True(t, tokenExpired(0)) - }) - - t.Run("should compare millisecond expirations against the current time", func(t *testing.T) { - assert.False(t, tokenExpired(time.Now().Add(time.Minute).UnixMilli())) - assert.True(t, tokenExpired(time.Now().Add(-time.Minute).UnixMilli())) - }) -} - -func TestStrconvFormatBool(t *testing.T) { - t.Run("should return false when the input is false", func(t *testing.T) { - assert.Equal(t, "false", strconvFormatBool(false)) - }) -} - type remoteTestHandlers struct { authStatus int authBody map[string]any @@ -339,6 +386,7 @@ type remoteTestHandlers struct { criteriaStatus int criteriaBody map[string]any criteriaRawBody *string + onAuthRequest func(request *http.Request) onCriteriaRequest func(body map[string]any, request *http.Request) } @@ -357,6 +405,9 @@ func newRemoteTestServer(t *testing.T, handlers remoteTestHandlers) *httptest.Se mux := http.NewServeMux() mux.HandleFunc("/criteria/auth", func(writer http.ResponseWriter, request *http.Request) { assert.Equal(t, http.MethodPost, request.Method) + if handlers.onAuthRequest != nil { + handlers.onAuthRequest(request) + } if handlers.authRawBody != nil { writer.Header().Set("Content-Type", "application/json") writer.WriteHeader(handlers.authStatus) diff --git a/switcher.go b/switcher.go index efc247e..48a1b73 100644 --- a/switcher.go +++ b/switcher.go @@ -42,11 +42,12 @@ func (s *Switcher) Validate() error { } func (s *Switcher) CheckValue(input string) *Switcher { - s.entries = appendFilteredEntries(s.entries, StrategyValue) - s.entries = append(s.entries, criteriaEntry{ - Strategy: StrategyValue, - Input: input, - }) + s.entries = []criteriaEntry{ + { + Strategy: StrategyValue, + Input: input, + }, + } return s } @@ -69,7 +70,7 @@ func (s *Switcher) Prepare(key string) error { } func (s *Switcher) IsOn() (bool, error) { - result, err := s.IsOnWithDetails() + result, err := s.submit(false) if err != nil { return false, err } @@ -78,6 +79,10 @@ func (s *Switcher) IsOn() (bool, error) { } func (s *Switcher) IsOnWithDetails() (ResultDetail, error) { + return s.submit(true) +} + +func (s *Switcher) submit(showDetails bool) (ResultDetail, error) { if err := s.Validate(); err != nil { return ResultDetail{}, err } @@ -91,16 +96,5 @@ func (s *Switcher) IsOnWithDetails() (ResultDetail, error) { return ResultDetail{}, err } - return s.client.checkCriteria(token, s, true) -} - -func appendFilteredEntries(entries []criteriaEntry, strategy string) []criteriaEntry { - filtered := entries[:0] - for _, entry := range entries { - if entry.Strategy != strategy { - filtered = append(filtered, entry) - } - } - - return filtered + return s.client.checkCriteria(token, s, showDetails) } diff --git a/switcher_test.go b/switcher_test.go index b6f763c..fda3b4e 100644 --- a/switcher_test.go +++ b/switcher_test.go @@ -88,18 +88,3 @@ func TestSwitcherIsOnWithDetails(t *testing.T) { assert.EqualError(t, err, "something went wrong: missing or empty required fields (url, component, api_key)") }) } - -func TestAppendFilteredEntries(t *testing.T) { - t.Run("should preserve entries whose strategy does not match the filtered strategy", func(t *testing.T) { - entries := []criteriaEntry{ - {Strategy: StrategyValue, Input: "user_id"}, - {Strategy: "NETWORK_VALIDATION", Input: "127.0.0.1"}, - } - - got := appendFilteredEntries(entries, StrategyValue) - - assert.Equal(t, []criteriaEntry{ - {Strategy: "NETWORK_VALIDATION", Input: "127.0.0.1"}, - }, got) - }) -}