diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 5fcde1f..ee37b4b 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,5 +1,5 @@ ARG GO_VERSION=1.26.3 -ARG ALPINE_VERSION=3.22 +ARG ALPINE_VERSION=3.23 FROM golang:${GO_VERSION}-alpine${ALPINE_VERSION} diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 76ffd21..dc1be9d 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.9' services: - switcher-gitops: + switcher-client-go: build: context: . dockerfile: Dockerfile diff --git a/Makefile b/Makefile index 4c45cb3..40bb773 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test cover cover-html lint lint-install +.PHONY: test fmt cover cover-html lint lint-install GOLANGCI_LINT_VERSION=v2.12.2 @@ -8,6 +8,9 @@ test-clean: test: go test -p 1 -v ./... +fmt: + gofmt -s -w . + lint-install: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@$(GOLANGCI_LINT_VERSION) diff --git a/client.go b/client.go index dc49c43..32072e7 100644 --- a/client.go +++ b/client.go @@ -135,6 +135,42 @@ func (c *Client) SnapshotVersion() int { return c.snapshot.Domain.Version } +func CheckSnapshot() (bool, error) { + return defaultClient().CheckSnapshot() +} + +func (c *Client) CheckSnapshot() (bool, error) { + token, err := c.ensureToken() + if err != nil { + return false, err + } + + if err := missingTokenError(token); err != nil { + return false, err + } + + upToDate, err := c.checkSnapshotVersion(token, c.SnapshotVersion()) + if err != nil { + return false, err + } + + if upToDate { + return false, nil + } + + snapshot, err := c.resolveSnapshot(token) + if err != nil { + return false, err + } + + if err := saveSnapshotToFile(c.Context(), snapshot); err != nil { + return false, err + } + + c.setSnapshot(snapshot) + return true, nil +} + func (c *Client) snapshotState() *Snapshot { c.mu.RLock() defer c.mu.RUnlock() @@ -154,6 +190,21 @@ func (c *Client) stopBackgroundTasks() { c.UnwatchSnapshot() } +func (c *Client) shouldCheckSnapshot(fetchRemote bool) bool { + ctx := c.Context() + return c.SnapshotVersion() == 0 && (fetchRemote || !ctx.Options.Local) +} + +func (c *Client) loadSnapshotFromCurrentFile() (*Snapshot, error) { + snapshot, err := loadSnapshotFromFile(c.Context()) + if err != nil { + return nil, err + } + + c.setSnapshot(snapshot) + return snapshot, nil +} + func defaultClient() *Client { if client := globalClient.Load(); client != nil { return client diff --git a/local_test.go b/local_test.go index 8391885..cfe3fcb 100644 --- a/local_test.go +++ b/local_test.go @@ -1,199 +1,12 @@ package client import ( - "encoding/json" - "net/http" - "os" "path/filepath" "testing" - "time" "github.com/stretchr/testify/assert" ) -func TestSnapshotLoading(t *testing.T) { - t.Run("should load snapshot version from a local file", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotFixtureDir(), - }, - }) - - assert.Equal(t, 0, SnapshotVersion()) - - version, err := LoadSnapshot(nil) - - assert.NoError(t, err) - assert.Equal(t, 1, version) - assert.Equal(t, 1, SnapshotVersion()) - }) - - t.Run("should return an error when the snapshot file is malformed", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotFixtureDir(), - }, - Environment: "default_malformed", - }) - - version, err := LoadSnapshot(nil) - - assert.Error(t, err) - assert.Zero(t, version) - }) - - t.Run("should return an error when the snapshot file is not accessible", func(t *testing.T) { - snapshotLocation := filepath.Join(t.TempDir(), "snapshot-location-file") - writeErr := os.WriteFile(snapshotLocation, []byte("not-a-directory"), 0o644) - assert.NoError(t, writeErr) - - BuildContext(Context{ - Domain: "My Domain", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotLocation, - }, - Environment: "default", - }) - - version, err := LoadSnapshot(nil) - - assert.Error(t, err) - assert.Zero(t, version) - }) - - t.Run("should return an error when the snapshot file is not accessible while saving remote updates", func(t *testing.T) { - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") - - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": false}, - resolveStatus: http.StatusOK, - resolveDomain: loadSnapshotFixture(t, "default_load_2"), - }) - defer server.Close() - - BuildContext(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "default_load_1", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - - version, loadErr := LoadSnapshot(nil) - assert.NoError(t, loadErr) - assert.Equal(t, 1588557288040, version) - - removeErr := os.RemoveAll(snapshotDir) - assert.NoError(t, removeErr) - blockErr := os.WriteFile(snapshotDir, []byte("not-a-directory"), 0o644) - assert.NoError(t, blockErr) - - updated, err := CheckSnapshot() - - assert.Error(t, err) - assert.False(t, updated) - }) - - t.Run("should return an error when the snapshot file path cannot be created", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - Options: ContextOptions{ - Local: true, - SnapshotLocation: t.TempDir(), - }, - Environment: filepath.Join("nested", "missing"), - }) - - version, err := LoadSnapshot(nil) - - assert.Error(t, err) - assert.Zero(t, version) - }) - - t.Run("should return an error when check snapshot fails during load snapshot", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusInternalServerError, - snapshotCheckBody: map[string]any{"status": false}, - }) - defer server.Close() - - BuildContext(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Options: ContextOptions{ - Local: false, - }, - }) - - version, err := LoadSnapshot(nil) - - assert.Error(t, err) - assert.Zero(t, version) - assert.EqualError(t, err, "[check_snapshot_version] failed with status: 500") - }) - - t.Run("should return an error when watch snapshot fails during load snapshot", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - Options: ContextOptions{ - Local: true, - }, - }) - - version, err := LoadSnapshot(&LoadSnapshotOptions{ - WatchSnapshot: true, - }) - - assert.Error(t, err) - assert.Zero(t, version) - assert.EqualError(t, err, "snapshot location is not defined in the context options") - }) - - t.Run("should create a clean snapshot when no file exists", func(t *testing.T) { - snapshotDir := t.TempDir() - - BuildContext(Context{ - Domain: "My Domain", - Environment: "generated-clean", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - - version, err := LoadSnapshot(nil) - - assert.NoError(t, err) - assert.Equal(t, 0, version) - assert.Equal(t, 0, SnapshotVersion()) - - content, readErr := os.ReadFile(filepath.Join(snapshotDir, "generated-clean.json")) - assert.NoError(t, readErr) - - var snapshot Snapshot - unmarshalErr := json.Unmarshal(content, &snapshot) - assert.NoError(t, unmarshalErr) - assert.Equal(t, 0, snapshot.Domain.Version) - }) -} - func TestSwitcherLocalEvaluation(t *testing.T) { t.Run("should use local snapshot to evaluate a switcher without strategies", func(t *testing.T) { useLocalSnapshotFixture(t, "default") diff --git a/remote.go b/remote.go index c68041e..449878c 100644 --- a/remote.go +++ b/remote.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "encoding/json" "errors" + "fmt" "io" "net" "net/http" @@ -31,6 +32,25 @@ type criteriaResponse struct { Metadata map[string]any `json:"metadata"` } +type snapshotCheckResponse struct { + Status bool `json:"status"` +} + +type resolveSnapshotResponse struct { + Data struct { + Domain SnapshotDomain `json:"domain"` + } `json:"data"` +} + +const contentTypeJSON = "application/json" + +func (c *Client) authHeaders(token string) map[string]string { + return map[string]string{ + "Authorization": "Bearer " + token, + "Content-Type": contentTypeJSON, + } +} + func (c *Client) ensureToken() (string, error) { c.authMu.Lock() defer c.authMu.Unlock() @@ -50,10 +70,7 @@ func (c *Client) ensureToken() (string, error) { "component": ctx.Component, "environment": ctx.Environment, }, - map[string]string{ - "switcher-api-key": ctx.APIKey, - "Content-Type": "application/json", - }, + c.authHeaders(""), ) if err != nil { return "", newRemoteAuthError("[auth] remote unavailable") @@ -103,10 +120,7 @@ func (c *Client) checkCriteria(token string, switcher *Switcher, showDetails boo map[string]any{ "entry": entries, }, - map[string]string{ - "Authorization": "Bearer " + token, - "Content-Type": "application/json", - }, + c.authHeaders(token), ) if err != nil { return ResultDetail{}, newRemoteCriteriaError("[check_criteria] remote unavailable") @@ -131,6 +145,78 @@ func (c *Client) checkCriteria(token string, switcher *Switcher, showDetails boo return ResultDetail(payload), nil } +func (c *Client) checkSnapshotVersion(token string, snapshotVersion int) (bool, error) { + ctx := c.Context() + endpoint := fmt.Sprintf("%s/criteria/snapshot_check/%d", strings.TrimRight(ctx.URL, "/"), snapshotVersion) + + response, err := c.doJSONRequest( + http.MethodGet, + endpoint, + nil, + c.authHeaders(token), + ) + if err != nil { + return false, newRemoteSnapshotError("[check_snapshot_version] remote unavailable") + } + defer func() { + _ = response.Body.Close() + }() + + if response.StatusCode != http.StatusOK { + return false, newRemoteSnapshotError("[check_snapshot_version] failed with status: %d", response.StatusCode) + } + + var payload snapshotCheckResponse + if err := json.NewDecoder(response.Body).Decode(&payload); err != nil { + return false, newRemoteSnapshotError("[check_snapshot_version] failed to decode response: %v", err) + } + + return payload.Status, nil +} + +func (c *Client) resolveSnapshot(token string) (*Snapshot, error) { + ctx := c.Context() + endpoint := strings.TrimRight(ctx.URL, "/") + "/graphql" + + response, err := c.doJSONRequest( + http.MethodPost, + endpoint, + map[string]string{ + "query": fmt.Sprintf(` + query domain { + domain(name: %q, environment: %q, _component: %q) { + name version activated + group { name activated + config { key activated + strategies { strategy activated operation values } + relay { type activated } + } + } + } + } + `, ctx.Domain, ctx.Environment, ctx.Component), + }, + c.authHeaders(token), + ) + if err != nil { + return nil, newRemoteSnapshotError("[resolve_snapshot] remote unavailable") + } + defer func() { + _ = response.Body.Close() + }() + + if response.StatusCode != http.StatusOK { + return nil, newRemoteSnapshotError("[resolve_snapshot] failed with status: %d", response.StatusCode) + } + + var payload resolveSnapshotResponse + if err := json.NewDecoder(response.Body).Decode(&payload); err != nil { + return nil, newRemoteSnapshotError("[resolve_snapshot] failed to decode response: %v", err) + } + + return &Snapshot{Domain: payload.Data.Domain}, nil +} + func (c *Client) doJSONRequest(method, endpoint string, payload any, headers map[string]string) (*http.Response, error) { var bodyReader io.Reader if payload != nil { diff --git a/snapshot_auto_updater.go b/snapshot_auto_updater.go new file mode 100644 index 0000000..2c7b351 --- /dev/null +++ b/snapshot_auto_updater.go @@ -0,0 +1,107 @@ +package client + +import ( + "sync" + "time" +) + +type snapshotAutoUpdater struct { + mu sync.Mutex + stop chan struct{} + done chan struct{} +} + +func newSnapshotAutoUpdater() *snapshotAutoUpdater { + return &snapshotAutoUpdater{} +} + +func ScheduleSnapshotAutoUpdate(interval time.Duration, callback func(error, bool)) { + defaultClient().ScheduleSnapshotAutoUpdate(interval, callback) +} + +func (c *Client) ScheduleSnapshotAutoUpdate(interval time.Duration, callback func(error, bool)) { + if interval > 0 { + c.mu.Lock() + c.context.Options.SnapshotAutoUpdateInterval = interval + c.mu.Unlock() + } + + effectiveInterval := interval + if effectiveInterval <= 0 { + effectiveInterval = c.Context().Options.SnapshotAutoUpdateInterval + } + + if effectiveInterval <= 0 || c.snapshotAutoUpdater == nil { + return + } + + c.snapshotAutoUpdater.Start(c, effectiveInterval, callback) +} + +func TerminateSnapshotAutoUpdate() { + defaultClient().TerminateSnapshotAutoUpdate() +} + +func (c *Client) TerminateSnapshotAutoUpdate() { + if c.snapshotAutoUpdater != nil { + c.snapshotAutoUpdater.Stop() + } +} + +func (u *snapshotAutoUpdater) Start(client *Client, interval time.Duration, callback func(error, bool)) { + u.Stop() + + stop := make(chan struct{}) + done := make(chan struct{}) + + u.mu.Lock() + u.stop = stop + u.done = done + u.mu.Unlock() + + go func() { + defer close(done) + + timer := time.NewTimer(interval) + defer timer.Stop() + + select { + case <-stop: + return + case <-timer.C: + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + updated, err := client.CheckSnapshot() + if callback != nil { + callback(err, updated) + } + + select { + case <-stop: + return + case <-ticker.C: + } + } + }() +} + +func (u *snapshotAutoUpdater) Stop() { + u.mu.Lock() + stop := u.stop + done := u.done + u.stop = nil + u.done = nil + u.mu.Unlock() + + if stop != nil { + close(stop) + } + + if done != nil { + <-done + } +} diff --git a/snapshot_auto_updater_test.go b/snapshot_auto_updater_test.go new file mode 100644 index 0000000..004843f --- /dev/null +++ b/snapshot_auto_updater_test.go @@ -0,0 +1,195 @@ +package client + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSnapshotAutoUpdater(t *testing.T) { + t.Run("should load the snapshot from remote when fetch remote is enabled", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": false}, + resolveStatus: http.StatusOK, + resolveDomain: loadSnapshotFixture(t, "default_load_1"), + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "default_load_1", + Options: ContextOptions{ + Local: true, + }, + }) + + version, err := client.LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) + enabled, enabledErr := client.GetSwitcher("FF2FOR2030").IsOn() + + assert.NoError(t, err) + assert.NoError(t, enabledErr) + assert.Equal(t, 1588557288040, version) + assert.Equal(t, 1588557288040, client.SnapshotVersion()) + assert.True(t, enabled) + }) + + t.Run("should auto update the snapshot on schedule", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotChecks: []snapshotCheckStep{ + {status: http.StatusOK, body: map[string]any{"status": false}}, + {status: http.StatusOK, body: map[string]any{"status": false}}, + }, + resolveSteps: []resolveSnapshotStep{ + {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_1")}, + {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_2")}, + }, + }) + defer server.Close() + + snapshotDir := t.TempDir() + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "generated-auto-update", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(client.TerminateSnapshotAutoUpdate) + + version, err := client.LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) + require.NoError(t, err) + require.Equal(t, 1588557288040, version) + + callbacks := make(chan struct { + err error + updated bool + }, 1) + client.ScheduleSnapshotAutoUpdate(50*time.Millisecond, func(err error, updated bool) { + select { + case callbacks <- struct { + err error + updated bool + }{err: err, updated: updated}: + default: + } + }) + + select { + case callback := <-callbacks: + assert.NoError(t, callback.err) + assert.True(t, callback.updated) + case <-time.After(5 * time.Second): + t.Fatal("expected scheduled snapshot update callback") + } + + assert.Eventually(t, func() bool { + got, gotErr := client.GetSwitcher("FF2FOR2030").IsOn() + return gotErr == nil && !got && client.SnapshotVersion() == 1588557288041 + }, 5*time.Second, 100*time.Millisecond) + }) + + t.Run("should schedule snapshot auto update using package function", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotChecks: []snapshotCheckStep{ + {status: http.StatusOK, body: map[string]any{"status": false}}, + {status: http.StatusOK, body: map[string]any{"status": false}}, + }, + resolveSteps: []resolveSnapshotStep{ + {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_1")}, + {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_2")}, + }, + }) + defer server.Close() + + BuildContext(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "generated-auto-update-global", + Options: ContextOptions{ + Local: true, + SnapshotLocation: t.TempDir(), + }, + }) + t.Cleanup(TerminateSnapshotAutoUpdate) + + version, err := LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) + require.NoError(t, err) + require.Equal(t, 1588557288040, version) + + callbacks := make(chan struct { + err error + updated bool + }, 1) + ScheduleSnapshotAutoUpdate(50*time.Millisecond, func(err error, updated bool) { + select { + case callbacks <- struct { + err error + updated bool + }{err: err, updated: updated}: + default: + } + }) + + select { + case callback := <-callbacks: + assert.NoError(t, callback.err) + assert.True(t, callback.updated) + case <-time.After(5 * time.Second): + t.Fatal("expected scheduled snapshot update callback") + } + + assert.Eventually(t, func() bool { + got, gotErr := GetSwitcher("FF2FOR2030").IsOn() + return gotErr == nil && !got && SnapshotVersion() == 1588557288041 + }, 5*time.Second, 100*time.Millisecond) + }) + + t.Run("should terminate snapshot auto update using package function", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + Environment: "default", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotFixtureDir(), + }, + }) + t.Cleanup(TerminateSnapshotAutoUpdate) + + _, err := LoadSnapshot(nil) + require.NoError(t, err) + + callbackCh := make(chan struct{}, 1) + ScheduleSnapshotAutoUpdate(200*time.Millisecond, func(_ error, _ bool) { + select { + case callbackCh <- struct{}{}: + default: + } + }) + TerminateSnapshotAutoUpdate() + + select { + case <-callbackCh: + t.Fatal("did not expect auto update callback after terminate") + case <-time.After(400 * time.Millisecond): + } + }) +} diff --git a/snapshot_checker_test.go b/snapshot_checker_test.go new file mode 100644 index 0000000..d656e90 --- /dev/null +++ b/snapshot_checker_test.go @@ -0,0 +1,301 @@ +package client + +import ( + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSnapshotChecker(t *testing.T) { + t.Run("should return an error when the snapshot file is not accessible while saving remote updates", func(t *testing.T) { + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") + + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": false}, + resolveStatus: http.StatusOK, + resolveDomain: loadSnapshotFixture(t, "default_load_2"), + }) + defer server.Close() + + BuildContext(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "default_load_1", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + + version, loadErr := LoadSnapshot(nil) + assert.NoError(t, loadErr) + assert.Equal(t, 1588557288040, version) + + removeErr := os.RemoveAll(snapshotDir) + assert.NoError(t, removeErr) + blockErr := os.WriteFile(snapshotDir, []byte("not-a-directory"), 0o644) + assert.NoError(t, blockErr) + + updated, err := CheckSnapshot() + + assert.Error(t, err) + assert.False(t, updated) + }) + + t.Run("should not update the snapshot when the current version is still valid", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": true}, + }) + defer server.Close() + + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "default_load_1", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + + version, loadErr := client.LoadSnapshot(nil) + updated, err := client.CheckSnapshot() + + assert.NoError(t, loadErr) + assert.NoError(t, err) + assert.Equal(t, 1588557288040, version) + assert.Equal(t, 1588557288040, client.SnapshotVersion()) + assert.False(t, updated) + }) + + t.Run("should return an error when the snapshot version check fails", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusInternalServerError, + snapshotCheckBody: map[string]any{"status": false}, + }) + defer server.Close() + + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + Environment: "default_load_1", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + + _, loadErr := client.LoadSnapshot(nil) + updated, err := client.CheckSnapshot() + + assert.NoError(t, loadErr) + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.EqualError(t, err, "[check_snapshot_version] failed with status: 500") + }) + + t.Run("should return an error when snapshot version response cannot be decoded", func(t *testing.T) { + rawBody := "{invalid-json" + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckRawBody: &rawBody, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.Contains(t, err.Error(), "[check_snapshot_version] failed to decode response") + }) + + t.Run("should return an error when snapshot version endpoint is unavailable", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": true}, + }) + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + _, tokenErr := client.ensureToken() + assert.NoError(t, tokenErr) + + server.Close() + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.EqualError(t, err, "[check_snapshot_version] remote unavailable") + }) + + t.Run("should return an error when ensure token fails during snapshot check", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusUnauthorized, + authBody: map[string]any{}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var authErr *RemoteAuthError + assert.ErrorAs(t, err, &authErr) + assert.EqualError(t, err, "invalid API key") + }) + + t.Run("should return an error when auth response does not include a token", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": nil, "exp": time.Now().Add(time.Hour).Unix()}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.EqualError(t, err, "something went wrong: missing token field") + }) + + t.Run("should return an error when resolving snapshot fails", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": false}, + resolveStatus: http.StatusInternalServerError, + resolveDomain: map[string]any{}, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.EqualError(t, err, "[resolve_snapshot] failed with status: 500") + }) + + t.Run("should return an error when snapshot resolve response cannot be decoded", func(t *testing.T) { + rawBody := "{invalid-json" + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": false}, + resolveStatus: http.StatusOK, + resolveRawBody: &rawBody, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.Contains(t, err.Error(), "[resolve_snapshot] failed to decode response") + }) + + t.Run("should return an error when snapshot resolve endpoint is unavailable", func(t *testing.T) { + server := newSnapshotTestServer(t, snapshotRemoteHandlers{ + authStatus: http.StatusOK, + authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, + snapshotCheckStatus: http.StatusOK, + snapshotCheckBody: map[string]any{"status": false}, + resolveUnavailable: true, + }) + defer server.Close() + + client := NewClient(Context{ + Domain: "My Domain", + URL: server.URL, + APIKey: "[YOUR_API_KEY]", + Component: "MyApp", + }) + + updated, err := client.CheckSnapshot() + + assert.False(t, updated) + assert.Error(t, err) + var snapshotErr *RemoteSnapshotError + assert.ErrorAs(t, err, &snapshotErr) + assert.EqualError(t, err, "[resolve_snapshot] remote unavailable") + }) +} diff --git a/snapshot_lifecycle.go b/snapshot_lifecycle.go deleted file mode 100644 index 702e91b..0000000 --- a/snapshot_lifecycle.go +++ /dev/null @@ -1,377 +0,0 @@ -package client - -import ( - "encoding/json" - "fmt" - "net/http" - "os" - "strings" - "sync" - "time" -) - -const snapshotWatcherPollInterval = 100 * time.Millisecond - -type WatchSnapshotCallback struct { - Success func() - Reject func(error) -} - -type snapshotCheckResponse struct { - Status bool `json:"status"` -} - -type resolveSnapshotResponse struct { - Data struct { - Domain SnapshotDomain `json:"domain"` - } `json:"data"` -} - -type snapshotWatcher struct { - mu sync.Mutex - stop chan struct{} - done chan struct{} -} - -type snapshotAutoUpdater struct { - mu sync.Mutex - stop chan struct{} - done chan struct{} -} - -func newSnapshotWatcher() *snapshotWatcher { - return &snapshotWatcher{} -} - -func newSnapshotAutoUpdater() *snapshotAutoUpdater { - return &snapshotAutoUpdater{} -} - -func CheckSnapshot() (bool, error) { - return defaultClient().CheckSnapshot() -} - -func (c *Client) CheckSnapshot() (bool, error) { - token, err := c.ensureToken() - if err != nil { - return false, err - } - - if err := missingTokenError(token); err != nil { - return false, err - } - - upToDate, err := c.checkSnapshotVersion(token, c.SnapshotVersion()) - if err != nil { - return false, err - } - - if upToDate { - return false, nil - } - - snapshot, err := c.resolveSnapshot(token) - if err != nil { - return false, err - } - - if err := saveSnapshotToFile(c.Context(), snapshot); err != nil { - return false, err - } - - c.setSnapshot(snapshot) - return true, nil -} - -func WatchSnapshot(callback WatchSnapshotCallback) error { - return defaultClient().WatchSnapshot(callback) -} - -func (c *Client) WatchSnapshot(callback WatchSnapshotCallback) error { - return c.snapshotWatcher.Start(c, callback) -} - -func UnwatchSnapshot() { - defaultClient().UnwatchSnapshot() -} - -func (c *Client) UnwatchSnapshot() { - if c.snapshotWatcher != nil { - c.snapshotWatcher.Stop() - } -} - -func ScheduleSnapshotAutoUpdate(interval time.Duration, callback func(error, bool)) { - defaultClient().ScheduleSnapshotAutoUpdate(interval, callback) -} - -func (c *Client) ScheduleSnapshotAutoUpdate(interval time.Duration, callback func(error, bool)) { - if interval > 0 { - c.mu.Lock() - c.context.Options.SnapshotAutoUpdateInterval = interval - c.mu.Unlock() - } - - effectiveInterval := interval - if effectiveInterval <= 0 { - effectiveInterval = c.Context().Options.SnapshotAutoUpdateInterval - } - - if effectiveInterval <= 0 || c.snapshotAutoUpdater == nil { - return - } - - c.snapshotAutoUpdater.Start(c, effectiveInterval, callback) -} - -func TerminateSnapshotAutoUpdate() { - defaultClient().TerminateSnapshotAutoUpdate() -} - -func (c *Client) TerminateSnapshotAutoUpdate() { - if c.snapshotAutoUpdater != nil { - c.snapshotAutoUpdater.Stop() - } -} - -func (c *Client) shouldCheckSnapshot(fetchRemote bool) bool { - ctx := c.Context() - return c.SnapshotVersion() == 0 && (fetchRemote || !ctx.Options.Local) -} - -func (c *Client) loadSnapshotFromCurrentFile() (*Snapshot, error) { - snapshot, err := loadSnapshotFromFile(c.Context()) - if err != nil { - return nil, err - } - - c.setSnapshot(snapshot) - return snapshot, nil -} - -func (c *Client) checkSnapshotVersion(token string, snapshotVersion int) (bool, error) { - ctx := c.Context() - endpoint := fmt.Sprintf("%s/criteria/snapshot_check/%d", strings.TrimRight(ctx.URL, "/"), snapshotVersion) - - response, err := c.doJSONRequest( - http.MethodGet, - endpoint, - nil, - map[string]string{ - "Authorization": "Bearer " + token, - "Content-Type": "application/json", - }, - ) - if err != nil { - return false, newRemoteSnapshotError("[check_snapshot_version] remote unavailable") - } - defer func() { - _ = response.Body.Close() - }() - - if response.StatusCode != http.StatusOK { - return false, newRemoteSnapshotError("[check_snapshot_version] failed with status: %d", response.StatusCode) - } - - var payload snapshotCheckResponse - if err := json.NewDecoder(response.Body).Decode(&payload); err != nil { - return false, newRemoteSnapshotError("[check_snapshot_version] failed to decode response: %v", err) - } - - return payload.Status, nil -} - -func (c *Client) resolveSnapshot(token string) (*Snapshot, error) { - ctx := c.Context() - endpoint := strings.TrimRight(ctx.URL, "/") + "/graphql" - - response, err := c.doJSONRequest( - http.MethodPost, - endpoint, - map[string]string{ - "query": fmt.Sprintf(` - query domain { - domain(name: %q, environment: %q, _component: %q) { - name version activated - group { name activated - config { key activated - strategies { strategy activated operation values } - relay { type activated } - } - } - } - } - `, ctx.Domain, ctx.Environment, ctx.Component), - }, - map[string]string{ - "Authorization": "Bearer " + token, - "Content-Type": "application/json", - }, - ) - if err != nil { - return nil, newRemoteSnapshotError("[resolve_snapshot] remote unavailable") - } - defer func() { - _ = response.Body.Close() - }() - - if response.StatusCode != http.StatusOK { - return nil, newRemoteSnapshotError("[resolve_snapshot] failed with status: %d", response.StatusCode) - } - - var payload resolveSnapshotResponse - if err := json.NewDecoder(response.Body).Decode(&payload); err != nil { - return nil, newRemoteSnapshotError("[resolve_snapshot] failed to decode response: %v", err) - } - - return &Snapshot{Domain: payload.Data.Domain}, nil -} - -func (w *snapshotWatcher) Start(client *Client, callback WatchSnapshotCallback) error { - snapshotLocation := strings.TrimSpace(client.Context().Options.SnapshotLocation) - if snapshotLocation == "" { - return fmt.Errorf("snapshot location is not defined in the context options") - } - - snapshotFile := snapshotFilePath(client.Context()) - info, err := os.Stat(snapshotFile) - if err != nil { - return err - } - - w.Stop() - - stop := make(chan struct{}) - done := make(chan struct{}) - - w.mu.Lock() - w.stop = stop - w.done = done - w.mu.Unlock() - - go func() { - defer close(done) - - ticker := time.NewTicker(snapshotWatcherPollInterval) - defer ticker.Stop() - - lastModified := info.ModTime() - lastSize := info.Size() - - for { - select { - case <-stop: - return - case <-ticker.C: - currentInfo, statErr := os.Stat(snapshotFile) - if statErr != nil { - invokeWatchReject(callback, statErr) - continue - } - - if currentInfo.ModTime().Equal(lastModified) && currentInfo.Size() == lastSize { - continue - } - - lastModified = currentInfo.ModTime() - lastSize = currentInfo.Size() - - if _, loadErr := client.loadSnapshotFromCurrentFile(); loadErr != nil { - invokeWatchReject(callback, loadErr) - continue - } - - invokeWatchSuccess(callback) - } - } - }() - - return nil -} - -func (w *snapshotWatcher) Stop() { - w.mu.Lock() - stop := w.stop - done := w.done - w.stop = nil - w.done = nil - w.mu.Unlock() - - if stop != nil { - close(stop) - } - - if done != nil { - <-done - } -} - -func (u *snapshotAutoUpdater) Start(client *Client, interval time.Duration, callback func(error, bool)) { - u.Stop() - - stop := make(chan struct{}) - done := make(chan struct{}) - - u.mu.Lock() - u.stop = stop - u.done = done - u.mu.Unlock() - - go func() { - defer close(done) - - timer := time.NewTimer(interval) - defer timer.Stop() - - select { - case <-stop: - return - case <-timer.C: - } - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - updated, err := client.CheckSnapshot() - if callback != nil { - callback(err, updated) - } - - select { - case <-stop: - return - case <-ticker.C: - } - } - }() -} - -func (u *snapshotAutoUpdater) Stop() { - u.mu.Lock() - stop := u.stop - done := u.done - u.stop = nil - u.done = nil - u.mu.Unlock() - - if stop != nil { - close(stop) - } - - if done != nil { - <-done - } -} - -func invokeWatchSuccess(callback WatchSnapshotCallback) { - if callback.Success != nil { - callback.Success() - } -} - -func invokeWatchReject(callback WatchSnapshotCallback, err error) { - if callback.Reject != nil { - callback.Reject(err) - } -} diff --git a/snapshot_test.go b/snapshot_test.go index a68016a..c2b3b77 100644 --- a/snapshot_test.go +++ b/snapshot_test.go @@ -2,7 +2,6 @@ package client import ( "encoding/json" - "errors" "net/http" "net/http/httptest" "os" @@ -15,663 +14,128 @@ import ( "github.com/stretchr/testify/require" ) -func TestSnapshotLifecycle(t *testing.T) { - t.Run("should load the snapshot from remote when fetch remote is enabled", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": false}, - resolveStatus: http.StatusOK, - resolveDomain: loadSnapshotFixture(t, "default_load_1"), - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "default_load_1", +func TestSnapshotLoading(t *testing.T) { + t.Run("should load snapshot version from a local file", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", Options: ContextOptions{ - Local: true, + Local: true, + SnapshotLocation: snapshotFixtureDir(), }, }) - version, err := client.LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) - enabled, enabledErr := client.GetSwitcher("FF2FOR2030").IsOn() + assert.Equal(t, 0, SnapshotVersion()) + + version, err := LoadSnapshot(nil) assert.NoError(t, err) - assert.NoError(t, enabledErr) - assert.Equal(t, 1588557288040, version) - assert.Equal(t, 1588557288040, client.SnapshotVersion()) - assert.True(t, enabled) + assert.Equal(t, 1, version) + assert.Equal(t, 1, SnapshotVersion()) }) - t.Run("should not update the snapshot when the current version is still valid", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": true}, - }) - defer server.Close() - - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "default_load_1", + t.Run("should return an error when the snapshot file is malformed", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", Options: ContextOptions{ Local: true, - SnapshotLocation: snapshotDir, + SnapshotLocation: snapshotFixtureDir(), }, + Environment: "default_malformed", }) - version, loadErr := client.LoadSnapshot(nil) - updated, err := client.CheckSnapshot() + version, err := LoadSnapshot(nil) - assert.NoError(t, loadErr) - assert.NoError(t, err) - assert.Equal(t, 1588557288040, version) - assert.Equal(t, 1588557288040, client.SnapshotVersion()) - assert.False(t, updated) + assert.Error(t, err) + assert.Zero(t, version) }) - t.Run("should return an error when the snapshot version check fails", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusInternalServerError, - snapshotCheckBody: map[string]any{"status": false}, - }) - defer server.Close() - - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "default_load_1", "default_load_1") + t.Run("should return an error when the snapshot file is not accessible", func(t *testing.T) { + snapshotLocation := filepath.Join(t.TempDir(), "snapshot-location-file") + writeErr := os.WriteFile(snapshotLocation, []byte("not-a-directory"), 0o644) + assert.NoError(t, writeErr) - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "default_load_1", + BuildContext(Context{ + Domain: "My Domain", Options: ContextOptions{ Local: true, - SnapshotLocation: snapshotDir, + SnapshotLocation: snapshotLocation, }, + Environment: "default", }) - _, loadErr := client.LoadSnapshot(nil) - updated, err := client.CheckSnapshot() - - assert.NoError(t, loadErr) - assert.False(t, updated) - assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.EqualError(t, err, "[check_snapshot_version] failed with status: 500") - }) - - t.Run("should return an error when snapshot version response cannot be decoded", func(t *testing.T) { - rawBody := "{invalid-json" - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckRawBody: &rawBody, - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - }) - - updated, err := client.CheckSnapshot() - - assert.False(t, updated) - assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.Contains(t, err.Error(), "[check_snapshot_version] failed to decode response") - }) - - t.Run("should return an error when snapshot version endpoint is unavailable", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": true}, - }) - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - }) - - _, tokenErr := client.ensureToken() - assert.NoError(t, tokenErr) - - server.Close() - - updated, err := client.CheckSnapshot() - - assert.False(t, updated) - assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.EqualError(t, err, "[check_snapshot_version] remote unavailable") - }) - - t.Run("should return an error when ensure token fails during snapshot check", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusUnauthorized, - authBody: map[string]any{}, - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - }) - - updated, err := client.CheckSnapshot() - - assert.False(t, updated) - assert.Error(t, err) - var authErr *RemoteAuthError - assert.ErrorAs(t, err, &authErr) - assert.EqualError(t, err, "invalid API key") - }) - - t.Run("should return an error when auth response does not include a token", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": nil, "exp": time.Now().Add(time.Hour).Unix()}, - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - }) - - updated, err := client.CheckSnapshot() - - assert.False(t, updated) - assert.EqualError(t, err, "something went wrong: missing token field") - }) - - t.Run("should return an error when resolving snapshot fails", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": false}, - resolveStatus: http.StatusInternalServerError, - resolveDomain: map[string]any{}, - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - }) - - updated, err := client.CheckSnapshot() + version, err := LoadSnapshot(nil) - assert.False(t, updated) assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.EqualError(t, err, "[resolve_snapshot] failed with status: 500") + assert.Zero(t, version) }) - t.Run("should return an error when snapshot resolve response cannot be decoded", func(t *testing.T) { - rawBody := "{invalid-json" - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, - snapshotCheckBody: map[string]any{"status": false}, - resolveStatus: http.StatusOK, - resolveRawBody: &rawBody, - }) - defer server.Close() - - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", + t.Run("should return an error when the snapshot file path cannot be created", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + Options: ContextOptions{ + Local: true, + SnapshotLocation: t.TempDir(), + }, + Environment: filepath.Join("nested", "missing"), }) - updated, err := client.CheckSnapshot() + version, err := LoadSnapshot(nil) - assert.False(t, updated) assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.Contains(t, err.Error(), "[resolve_snapshot] failed to decode response") + assert.Zero(t, version) }) - t.Run("should return an error when snapshot resolve endpoint is unavailable", func(t *testing.T) { + t.Run("should return an error when check snapshot fails during load snapshot", func(t *testing.T) { server := newSnapshotTestServer(t, snapshotRemoteHandlers{ authStatus: http.StatusOK, authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotCheckStatus: http.StatusOK, + snapshotCheckStatus: http.StatusInternalServerError, snapshotCheckBody: map[string]any{"status": false}, - resolveUnavailable: true, }) defer server.Close() - client := NewClient(Context{ + BuildContext(Context{ Domain: "My Domain", URL: server.URL, APIKey: "[YOUR_API_KEY]", Component: "MyApp", - }) - - updated, err := client.CheckSnapshot() - - assert.False(t, updated) - assert.Error(t, err) - var snapshotErr *RemoteSnapshotError - assert.ErrorAs(t, err, &snapshotErr) - assert.EqualError(t, err, "[resolve_snapshot] remote unavailable") - }) - - t.Run("should return an error when watch snapshot cannot stat the file at startup", func(t *testing.T) { - snapshotDir := t.TempDir() - - client := NewClient(Context{ - Domain: "My Domain", - Environment: "missing-watch-file", Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, + Local: false, }, }) - t.Cleanup(client.UnwatchSnapshot) - err := client.WatchSnapshot(WatchSnapshotCallback{}) + version, err := LoadSnapshot(nil) assert.Error(t, err) - assert.ErrorIs(t, err, os.ErrNotExist) - }) - - t.Run("should watch the snapshot file when load snapshot enables watch mode", func(t *testing.T) { - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "watched", "default_load_1") - - client := NewClient(Context{ - Domain: "My Domain", - Environment: "watched", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - t.Cleanup(client.UnwatchSnapshot) - - version, err := client.LoadSnapshot(&LoadSnapshotOptions{WatchSnapshot: true}) - require.NoError(t, err) - require.Equal(t, 1588557288040, version) - - enabled, enabledErr := client.GetSwitcher("FF2FOR2030").IsOn() - require.NoError(t, enabledErr) - require.True(t, enabled) - - writeSnapshotFixture(t, snapshotDir, "watched", "default_load_2") - - require.Eventually(t, func() bool { - got, gotErr := client.GetSwitcher("FF2FOR2030").IsOn() - return gotErr == nil && !got && client.SnapshotVersion() == 1588557288041 - }, 5*time.Second, 100*time.Millisecond) - }) - - t.Run("should reject watch updates when the modified snapshot is malformed", func(t *testing.T) { - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "watched", "default_load_1") - - client := NewClient(Context{ - Domain: "My Domain", - Environment: "watched", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - t.Cleanup(client.UnwatchSnapshot) - - _, err := client.LoadSnapshot(nil) - require.NoError(t, err) - - rejectCh := make(chan error, 1) - watchErr := client.WatchSnapshot(WatchSnapshotCallback{ - Reject: func(err error) { - select { - case rejectCh <- err: - default: - } - }, - }) - require.NoError(t, watchErr) - - content, readErr := os.ReadFile(filepath.Join(snapshotFixtureDir(), "default_malformed.json")) - require.NoError(t, readErr) - writeErr := os.WriteFile(filepath.Join(snapshotDir, "watched.json"), content, 0o644) - require.NoError(t, writeErr) - - select { - case rejectErr := <-rejectCh: - assert.Error(t, rejectErr) - case <-time.After(5 * time.Second): - t.Fatal("expected malformed snapshot watch callback") - } - }) - - t.Run("should reject watch updates when watched snapshot file becomes unavailable", func(t *testing.T) { - snapshotDir := t.TempDir() - environment := "watched-missing-runtime" - writeSnapshotFixture(t, snapshotDir, environment, "default_load_1") - - client := NewClient(Context{ - Domain: "My Domain", - Environment: environment, - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - t.Cleanup(client.UnwatchSnapshot) - - _, err := client.LoadSnapshot(nil) - require.NoError(t, err) - - rejectCh := make(chan error, 1) - watchErr := client.WatchSnapshot(WatchSnapshotCallback{ - Reject: func(err error) { - select { - case rejectCh <- err: - default: - } - }, - }) - require.NoError(t, watchErr) - - removeErr := os.Remove(filepath.Join(snapshotDir, environment+".json")) - require.NoError(t, removeErr) - - select { - case rejectErr := <-rejectCh: - assert.Error(t, rejectErr) - assert.True(t, errors.Is(rejectErr, os.ErrNotExist)) - case <-time.After(5 * time.Second): - t.Fatal("expected stat error callback when watched snapshot file is removed") - } - }) - - t.Run("should auto update the snapshot on schedule", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotChecks: []snapshotCheckStep{ - {status: http.StatusOK, body: map[string]any{"status": false}}, - {status: http.StatusOK, body: map[string]any{"status": false}}, - }, - resolveSteps: []resolveSnapshotStep{ - {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_1")}, - {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_2")}, - }, - }) - defer server.Close() - - snapshotDir := t.TempDir() - client := NewClient(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "generated-auto-update", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - t.Cleanup(client.TerminateSnapshotAutoUpdate) - - version, err := client.LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) - require.NoError(t, err) - require.Equal(t, 1588557288040, version) - - callbacks := make(chan struct { - err error - updated bool - }, 1) - client.ScheduleSnapshotAutoUpdate(50*time.Millisecond, func(err error, updated bool) { - select { - case callbacks <- struct { - err error - updated bool - }{err: err, updated: updated}: - default: - } - }) - - select { - case callback := <-callbacks: - assert.NoError(t, callback.err) - assert.True(t, callback.updated) - case <-time.After(5 * time.Second): - t.Fatal("expected scheduled snapshot update callback") - } - - assert.Eventually(t, func() bool { - got, gotErr := client.GetSwitcher("FF2FOR2030").IsOn() - return gotErr == nil && !got && client.SnapshotVersion() == 1588557288041 - }, 5*time.Second, 100*time.Millisecond) - }) -} - -func TestSnapshotLifecycleGlobalFunctions(t *testing.T) { - t.Run("should return an error when watch snapshot has no snapshot location", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - }) - t.Cleanup(UnwatchSnapshot) - - err := WatchSnapshot(WatchSnapshotCallback{}) - - assert.EqualError(t, err, "snapshot location is not defined in the context options") + assert.Zero(t, version) + assert.EqualError(t, err, "[check_snapshot_version] failed with status: 500") }) - t.Run("should watch snapshot using package function", func(t *testing.T) { + t.Run("should create a clean snapshot when no file exists", func(t *testing.T) { snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "watched-global", "default_load_1") BuildContext(Context{ Domain: "My Domain", - Environment: "watched-global", + Environment: "generated-clean", Options: ContextOptions{ Local: true, SnapshotLocation: snapshotDir, }, }) - t.Cleanup(UnwatchSnapshot) version, err := LoadSnapshot(nil) - require.NoError(t, err) - require.Equal(t, 1588557288040, version) - - successCh := make(chan struct{}, 1) - watchErr := WatchSnapshot(WatchSnapshotCallback{ - Success: func() { - select { - case successCh <- struct{}{}: - default: - } - }, - }) - require.NoError(t, watchErr) - - writeSnapshotFixture(t, snapshotDir, "watched-global", "default_load_2") - - select { - case <-successCh: - case <-time.After(5 * time.Second): - t.Fatal("expected watch snapshot callback") - } - - assert.Eventually(t, func() bool { - got, gotErr := GetSwitcher("FF2FOR2030").IsOn() - return gotErr == nil && !got && SnapshotVersion() == 1588557288041 - }, 5*time.Second, 100*time.Millisecond) - }) - - t.Run("should stop watching snapshot using package unwatch function", func(t *testing.T) { - snapshotDir := t.TempDir() - writeSnapshotFixture(t, snapshotDir, "watched-global-stop", "default_load_1") - - BuildContext(Context{ - Domain: "My Domain", - Environment: "watched-global-stop", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotDir, - }, - }) - t.Cleanup(UnwatchSnapshot) - - _, err := LoadSnapshot(nil) - require.NoError(t, err) - - callbackCh := make(chan struct{}, 1) - watchErr := WatchSnapshot(WatchSnapshotCallback{ - Success: func() { - select { - case callbackCh <- struct{}{}: - default: - } - }, - }) - require.NoError(t, watchErr) - - UnwatchSnapshot() - writeSnapshotFixture(t, snapshotDir, "watched-global-stop", "default_load_2") - - select { - case <-callbackCh: - t.Fatal("did not expect watch callback after unwatch") - case <-time.After(400 * time.Millisecond): - } - - assert.Equal(t, 1588557288040, SnapshotVersion()) - }) - t.Run("should schedule snapshot auto update using package function", func(t *testing.T) { - server := newSnapshotTestServer(t, snapshotRemoteHandlers{ - authStatus: http.StatusOK, - authBody: map[string]any{"token": "[token]", "exp": time.Now().Add(time.Hour).Unix()}, - snapshotChecks: []snapshotCheckStep{ - {status: http.StatusOK, body: map[string]any{"status": false}}, - {status: http.StatusOK, body: map[string]any{"status": false}}, - }, - resolveSteps: []resolveSnapshotStep{ - {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_1")}, - {status: http.StatusOK, domain: loadSnapshotFixture(t, "default_load_2")}, - }, - }) - defer server.Close() - - BuildContext(Context{ - Domain: "My Domain", - URL: server.URL, - APIKey: "[YOUR_API_KEY]", - Component: "MyApp", - Environment: "generated-auto-update-global", - Options: ContextOptions{ - Local: true, - SnapshotLocation: t.TempDir(), - }, - }) - t.Cleanup(TerminateSnapshotAutoUpdate) - - version, err := LoadSnapshot(&LoadSnapshotOptions{FetchRemote: true}) - require.NoError(t, err) - require.Equal(t, 1588557288040, version) - - callbacks := make(chan struct { - err error - updated bool - }, 1) - ScheduleSnapshotAutoUpdate(50*time.Millisecond, func(err error, updated bool) { - select { - case callbacks <- struct { - err error - updated bool - }{err: err, updated: updated}: - default: - } - }) - - select { - case callback := <-callbacks: - assert.NoError(t, callback.err) - assert.True(t, callback.updated) - case <-time.After(5 * time.Second): - t.Fatal("expected scheduled snapshot update callback") - } - - assert.Eventually(t, func() bool { - got, gotErr := GetSwitcher("FF2FOR2030").IsOn() - return gotErr == nil && !got && SnapshotVersion() == 1588557288041 - }, 5*time.Second, 100*time.Millisecond) - }) - - t.Run("should terminate snapshot auto update using package function", func(t *testing.T) { - BuildContext(Context{ - Domain: "My Domain", - Environment: "default", - Options: ContextOptions{ - Local: true, - SnapshotLocation: snapshotFixtureDir(), - }, - }) - t.Cleanup(TerminateSnapshotAutoUpdate) - - _, err := LoadSnapshot(nil) - require.NoError(t, err) + assert.NoError(t, err) + assert.Equal(t, 0, version) + assert.Equal(t, 0, SnapshotVersion()) - callbackCh := make(chan struct{}, 1) - ScheduleSnapshotAutoUpdate(200*time.Millisecond, func(_ error, _ bool) { - select { - case callbackCh <- struct{}{}: - default: - } - }) - TerminateSnapshotAutoUpdate() + content, readErr := os.ReadFile(filepath.Join(snapshotDir, "generated-clean.json")) + assert.NoError(t, readErr) - select { - case <-callbackCh: - t.Fatal("did not expect auto update callback after terminate") - case <-time.After(400 * time.Millisecond): - } + var snapshot Snapshot + unmarshalErr := json.Unmarshal(content, &snapshot) + assert.NoError(t, unmarshalErr) + assert.Equal(t, 0, snapshot.Domain.Version) }) } diff --git a/snapshot_watcher.go b/snapshot_watcher.go new file mode 100644 index 0000000..00e587f --- /dev/null +++ b/snapshot_watcher.go @@ -0,0 +1,135 @@ +package client + +import ( + "fmt" + "os" + "strings" + "sync" + "time" +) + +const snapshotWatcherPollInterval = 100 * time.Millisecond + +type WatchSnapshotCallback struct { + Success func() + Reject func(error) +} + +type snapshotWatcher struct { + mu sync.Mutex + stop chan struct{} + done chan struct{} +} + +func newSnapshotWatcher() *snapshotWatcher { + return &snapshotWatcher{} +} + +func WatchSnapshot(callback WatchSnapshotCallback) error { + return defaultClient().WatchSnapshot(callback) +} + +func (c *Client) WatchSnapshot(callback WatchSnapshotCallback) error { + return c.snapshotWatcher.Start(c, callback) +} + +func UnwatchSnapshot() { + defaultClient().UnwatchSnapshot() +} + +func (c *Client) UnwatchSnapshot() { + if c.snapshotWatcher != nil { + c.snapshotWatcher.Stop() + } +} + +func (w *snapshotWatcher) Start(client *Client, callback WatchSnapshotCallback) error { + snapshotLocation := strings.TrimSpace(client.Context().Options.SnapshotLocation) + if snapshotLocation == "" { + return fmt.Errorf("snapshot location is not defined in the context options") + } + + snapshotFile := snapshotFilePath(client.Context()) + info, err := os.Stat(snapshotFile) + if err != nil { + return err + } + + w.Stop() + + stop := make(chan struct{}) + done := make(chan struct{}) + + w.mu.Lock() + w.stop = stop + w.done = done + w.mu.Unlock() + + go func() { + defer close(done) + + ticker := time.NewTicker(snapshotWatcherPollInterval) + defer ticker.Stop() + + lastModified := info.ModTime() + lastSize := info.Size() + + for { + select { + case <-stop: + return + case <-ticker.C: + currentInfo, statErr := os.Stat(snapshotFile) + if statErr != nil { + invokeWatchReject(callback, statErr) + continue + } + + if currentInfo.ModTime().Equal(lastModified) && currentInfo.Size() == lastSize { + continue + } + + lastModified = currentInfo.ModTime() + lastSize = currentInfo.Size() + + if _, loadErr := client.loadSnapshotFromCurrentFile(); loadErr != nil { + invokeWatchReject(callback, loadErr) + continue + } + + invokeWatchSuccess(callback) + } + } + }() + + return nil +} + +func (w *snapshotWatcher) Stop() { + w.mu.Lock() + stop := w.stop + done := w.done + w.stop = nil + w.done = nil + w.mu.Unlock() + + if stop != nil { + close(stop) + } + + if done != nil { + <-done + } +} + +func invokeWatchSuccess(callback WatchSnapshotCallback) { + if callback.Success != nil { + callback.Success() + } +} + +func invokeWatchReject(callback WatchSnapshotCallback, err error) { + if callback.Reject != nil { + callback.Reject(err) + } +} diff --git a/snapshot_watcher_test.go b/snapshot_watcher_test.go new file mode 100644 index 0000000..03eecbe --- /dev/null +++ b/snapshot_watcher_test.go @@ -0,0 +1,259 @@ +package client + +import ( + "errors" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSnapshotWatcher(t *testing.T) { + t.Run("should return an error when watch snapshot fails during load snapshot", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + Options: ContextOptions{ + Local: true, + }, + }) + + version, err := LoadSnapshot(&LoadSnapshotOptions{ + WatchSnapshot: true, + }) + + assert.Error(t, err) + assert.Zero(t, version) + assert.EqualError(t, err, "snapshot location is not defined in the context options") + }) + + t.Run("should return an error when watch snapshot has no snapshot location", func(t *testing.T) { + BuildContext(Context{ + Domain: "My Domain", + }) + t.Cleanup(UnwatchSnapshot) + + err := WatchSnapshot(WatchSnapshotCallback{}) + + assert.EqualError(t, err, "snapshot location is not defined in the context options") + }) + + t.Run("should watch snapshot using package function", func(t *testing.T) { + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "watched-global", "default_load_1") + + BuildContext(Context{ + Domain: "My Domain", + Environment: "watched-global", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(UnwatchSnapshot) + + version, err := LoadSnapshot(nil) + require.NoError(t, err) + require.Equal(t, 1588557288040, version) + + successCh := make(chan struct{}, 1) + watchErr := WatchSnapshot(WatchSnapshotCallback{ + Success: func() { + select { + case successCh <- struct{}{}: + default: + } + }, + }) + require.NoError(t, watchErr) + + writeSnapshotFixture(t, snapshotDir, "watched-global", "default_load_2") + + select { + case <-successCh: + case <-time.After(5 * time.Second): + t.Fatal("expected watch snapshot callback") + } + + assert.Eventually(t, func() bool { + got, gotErr := GetSwitcher("FF2FOR2030").IsOn() + return gotErr == nil && !got && SnapshotVersion() == 1588557288041 + }, 5*time.Second, 100*time.Millisecond) + }) + + t.Run("should stop watching snapshot using package unwatch function", func(t *testing.T) { + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "watched-global-stop", "default_load_1") + + BuildContext(Context{ + Domain: "My Domain", + Environment: "watched-global-stop", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(UnwatchSnapshot) + + _, err := LoadSnapshot(nil) + require.NoError(t, err) + + callbackCh := make(chan struct{}, 1) + watchErr := WatchSnapshot(WatchSnapshotCallback{ + Success: func() { + select { + case callbackCh <- struct{}{}: + default: + } + }, + }) + require.NoError(t, watchErr) + + UnwatchSnapshot() + writeSnapshotFixture(t, snapshotDir, "watched-global-stop", "default_load_2") + + select { + case <-callbackCh: + t.Fatal("did not expect watch callback after unwatch") + case <-time.After(400 * time.Millisecond): + } + + assert.Equal(t, 1588557288040, SnapshotVersion()) + }) + + t.Run("should return an error when watch snapshot cannot stat the file at startup", func(t *testing.T) { + snapshotDir := t.TempDir() + + client := NewClient(Context{ + Domain: "My Domain", + Environment: "missing-watch-file", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(client.UnwatchSnapshot) + + err := client.WatchSnapshot(WatchSnapshotCallback{}) + + assert.Error(t, err) + assert.ErrorIs(t, err, os.ErrNotExist) + }) + + t.Run("should watch the snapshot file when load snapshot enables watch mode", func(t *testing.T) { + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "watched", "default_load_1") + + client := NewClient(Context{ + Domain: "My Domain", + Environment: "watched", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(client.UnwatchSnapshot) + + version, err := client.LoadSnapshot(&LoadSnapshotOptions{WatchSnapshot: true}) + require.NoError(t, err) + require.Equal(t, 1588557288040, version) + + enabled, enabledErr := client.GetSwitcher("FF2FOR2030").IsOn() + require.NoError(t, enabledErr) + require.True(t, enabled) + + // delay to ensure the watcher goroutine is running before the file is updated + time.Sleep(100 * time.Millisecond) + + writeSnapshotFixture(t, snapshotDir, "watched", "default_load_2") + + require.Eventually(t, func() bool { + got, gotErr := client.GetSwitcher("FF2FOR2030").IsOn() + return gotErr == nil && !got && client.SnapshotVersion() == 1588557288041 + }, 5*time.Second, 100*time.Millisecond) + }) + + t.Run("should reject watch updates when the modified snapshot is malformed", func(t *testing.T) { + snapshotDir := t.TempDir() + writeSnapshotFixture(t, snapshotDir, "watched", "default_load_1") + + client := NewClient(Context{ + Domain: "My Domain", + Environment: "watched", + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(client.UnwatchSnapshot) + + _, err := client.LoadSnapshot(nil) + require.NoError(t, err) + + rejectCh := make(chan error, 1) + watchErr := client.WatchSnapshot(WatchSnapshotCallback{ + Reject: func(err error) { + select { + case rejectCh <- err: + default: + } + }, + }) + require.NoError(t, watchErr) + + content, readErr := os.ReadFile(filepath.Join(snapshotFixtureDir(), "default_malformed.json")) + require.NoError(t, readErr) + writeErr := os.WriteFile(filepath.Join(snapshotDir, "watched.json"), content, 0o644) + require.NoError(t, writeErr) + + select { + case rejectErr := <-rejectCh: + assert.Error(t, rejectErr) + case <-time.After(5 * time.Second): + t.Fatal("expected malformed snapshot watch callback") + } + }) + + t.Run("should reject watch updates when watched snapshot file becomes unavailable", func(t *testing.T) { + snapshotDir := t.TempDir() + environment := "watched-missing-runtime" + writeSnapshotFixture(t, snapshotDir, environment, "default_load_1") + + client := NewClient(Context{ + Domain: "My Domain", + Environment: environment, + Options: ContextOptions{ + Local: true, + SnapshotLocation: snapshotDir, + }, + }) + t.Cleanup(client.UnwatchSnapshot) + + _, err := client.LoadSnapshot(nil) + require.NoError(t, err) + + rejectCh := make(chan error, 1) + watchErr := client.WatchSnapshot(WatchSnapshotCallback{ + Reject: func(err error) { + select { + case rejectCh <- err: + default: + } + }, + }) + require.NoError(t, watchErr) + + removeErr := os.Remove(filepath.Join(snapshotDir, environment+".json")) + require.NoError(t, removeErr) + + select { + case rejectErr := <-rejectCh: + assert.Error(t, rejectErr) + assert.True(t, errors.Is(rejectErr, os.ErrNotExist)) + case <-time.After(5 * time.Second): + t.Fatal("expected stat error callback when watched snapshot file is removed") + } + }) +}