diff --git a/json_test.go b/json_test.go new file mode 100644 index 0000000..1acf7d2 --- /dev/null +++ b/json_test.go @@ -0,0 +1,52 @@ +package version + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVersionJSONSerialization(t *testing.T) { + v := MustParse("v20.1.2-alpha.3-cloudonly.4") + + blob, err := json.Marshal(v) + require.NoError(t, err) + + var parsed Version + err = json.Unmarshal(blob, &parsed) + require.NoError(t, err) + + require.Equal(t, v, parsed) +} + +func TestNullVersionJSONSerialization(t *testing.T) { + t.Run("valid", func(t *testing.T) { + v := MustParse("v20.1.2-alpha.3-cloudonly.4") + nv := NewNullVersion(v) + + blob, err := json.Marshal(nv) + require.NoError(t, err) + + var parsed NullVersion + err = json.Unmarshal(blob, &parsed) + require.NoError(t, err) + + require.True(t, parsed.Valid) + require.Equal(t, v, parsed.Version) + }) + + t.Run("invalid", func(t *testing.T) { + v := NullVersion{Valid: false} + + blob, err := json.Marshal(v) + require.NoError(t, err) + + var parsed NullVersion + err = json.Unmarshal(blob, &parsed) + require.NoError(t, err) + + require.False(t, parsed.Valid) + require.Equal(t, Version{}, parsed.Version) + }) +} diff --git a/null_version.go b/null_version.go new file mode 100644 index 0000000..f158ec9 --- /dev/null +++ b/null_version.go @@ -0,0 +1,74 @@ +package version + +import ( + "database/sql/driver" + "encoding/json" + + "github.com/cockroachdb/errors" +) + +// Represents a NULLable version when stored in the database. The zero +// value of NullVersion serializes as database NULL (and vice-versa). +type NullVersion struct { + Valid bool + Version Version +} + +func NewNullVersion(v Version) NullVersion { + return NullVersion{ + Valid: !v.Empty(), + Version: v, + } +} + +// Value is used when serializing a NullVersion for storage in the db. +func (n NullVersion) Value() (driver.Value, error) { + if n.Valid { + return n.Version.String(), nil + } else { + return nil, nil + } +} + +// Scan implements sql.Scanner, and is used when deserializing a NullVersion from the db. +func (n *NullVersion) Scan(value interface{}) error { + if value == nil { + *n = NullVersion{Valid: false, Version: Version{}} + return nil + } + err := n.Version.Scan(value) + if err != nil { + return err + } + n.Valid = true + return nil +} + +// We must implement json.Unmarshaler, because the invalid NullVersion stores an empty +// string in the version field, and we don't want to make Version unmarshal successfully +// from empty string (it should and does maintain the same behavior as Parse). +func (n *NullVersion) UnmarshalJSON(data []byte) error { + var rawMap map[string]interface{} + if err := json.Unmarshal(data, &rawMap); err != nil { + return err + } + if valid, ok := rawMap["Valid"].(bool); ok && !valid { + n.Valid = false + n.Version = Version{} + return nil + } else if ok && valid { + // then Version is a map like {"$raw": "vX.Y.Z"} + if versionMap, ok := rawMap["Version"].(map[string]interface{}); ok { + if rawVersion, ok := versionMap["$raw"].(string); ok { + parsed, err := Parse(rawVersion) + if err != nil { + return err + } + n.Valid = true + n.Version = parsed + return nil + } + } + } + return errors.Newf("cannot parse '%s' as NullVersion", data) +} diff --git a/sql_test.go b/sql_test.go new file mode 100644 index 0000000..d79d55c --- /dev/null +++ b/sql_test.go @@ -0,0 +1,53 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVersionScan(t *testing.T) { + t.Run("valid", func(t *testing.T) { + v := MustParse("v20.1.2-alpha.3-cloudonly.4") + + var scanned Version + err := scanned.Scan(v.String()) + require.NoError(t, err) + require.Equal(t, v, scanned) + }) + + t.Run("empty", func(t *testing.T) { + var scanned Version + err := scanned.Scan("") + require.NoError(t, err) + require.True(t, scanned.Empty()) + }) +} + +func TestNullVersionScan(t *testing.T) { + t.Run("valid", func(t *testing.T) { + v := MustParse("v20.1.2-alpha.3-cloudonly.4") + nv := NewNullVersion(v) + + var scanned NullVersion + err := scanned.Scan(v.String()) + require.NoError(t, err) + require.Equal(t, nv, scanned) + }) + + t.Run("null", func(t *testing.T) { + var scanned NullVersion + err := scanned.Scan(nil) + require.NoError(t, err) + require.False(t, scanned.Valid) + require.Equal(t, Version{}, scanned.Version) + }) + + t.Run("empty", func(t *testing.T) { + var scanned NullVersion + err := scanned.Scan("") + require.NoError(t, err) + require.True(t, scanned.Valid) + require.True(t, scanned.Version.Empty()) + }) +}