diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..a37af17 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,47 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + +jobs: + build_and_test: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + check-latest: true + go-version-file: "go.mod" + + - name: Display Go version + run: go version + + - name: Test + run: make test + + golangci-lint: + runs-on: ubuntu-latest + env: + GOLANGCI_LINT_VERSION: v1.64.6 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + check-latest: true + go-version-file: "go.mod" + + - name: Lint + uses: golangci/golangci-lint-action@v6 + with: + version: ${{ env.GOLANGCI_LINT_VERSION }} diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..b1e7e9a --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,96 @@ +issues: + exclude: + - 'Error return value of .(\w+\.Rollback(.*)). is not checked' + +linters: + presets: + - bugs + - comment + - format + - performance + - style + - test + - unused + + disable: + # disabled, but which we should enable with discussion + - wrapcheck # checks that errors are wrapped; currently not done anywhere + + # disabled because we're not compliant, but which we should think about + - exhaustruct # checks that properties in structs are exhaustively defined; may be a good idea + - testpackage # requires tests in test packages like `river_test` + + # disabled because they're annoying/bad + - interfacebloat # we do in fact want >10 methods on the Adapter interface or wherever we see fit. + - godox # bans TODO statements; total non-starter at the moment + - err113 # wants all errors to be defined as variables at the package level; quite obnoxious + - mnd # detects "magic numbers", which it defines as any number; annoying + - ireturn # bans returning interfaces; questionable as is, but also buggy as hell; very, very annoying + - lll # restricts maximum line length; annoying + - nlreturn # requires a blank line before returns; annoying + - wsl # a bunch of style/whitespace stuff; annoying + +linters-settings: + depguard: + rules: + all: + files: ["$all"] + deny: + - desc: "Use `github.com/google/uuid` package for UUIDs instead." + pkg: "github.com/xtgo/uuid" + + forbidigo: + forbid: + - msg: "Use `require` variants instead." + p: '^assert\.' + - msg: "Use `Func` suffix for function variables instead." + p: 'Fn\b' + - msg: "Use built-in `max` function instead." + p: '\bmath\.Max\b' + - msg: "Use built-in `min` function instead." + p: '\bmath\.Min\b' + + gci: + sections: + - Standard + - Default + - Prefix(github.com/riverqueue) + + gomoddirectives: + replace-local: true + + gosec: + excludes: + - G404 # use of non-crypto random; overly broad for our use case + + revive: + rules: + - name: unused-parameter + disabled: true + + tagliatelle: + case: + rules: + json: snake + + testifylint: + enable-all: true + disable: + - go-require + + varnamelen: + ignore-names: + - db + - eg + - f + - i + - id + - j + - mu + - r + - sb # common convention for string builder + - t + - tt # common convention for table tests + - tx + - w + - wg diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..0ac62ef --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,7 @@ +Copyright 2025 Blake Gentry, Brandur Leach + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c257ff2 --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +.DEFAULT_GOAL := help + +# Looks at comments using ## on targets and uses them to produce a help output. +.PHONY: help +help: ALIGN=22 +help: ## Print this message + @awk -F '::? .*## ' -- "/^[^':]+::? .*## /"' { printf "'$$(tput bold)'%-$(ALIGN)s'$$(tput sgr0)' %s\n", $$1, $$2 }' $(MAKEFILE_LIST) + +.PHONY: lint +lint:: ## Run linter + golangci-lint run --fix + +.PHONY: test +test:: ## Run test suite + go test ./... + +.PHONY: test/race +test/race:: ## Run test suite with race detector + go test ./... -race + +.PHONY: tidy +tidy:: ## Run `go mod tidy` + go mod tidy \ No newline at end of file diff --git a/apiendpoint/api_endpoint.go b/apiendpoint/api_endpoint.go new file mode 100644 index 0000000..ee02a00 --- /dev/null +++ b/apiendpoint/api_endpoint.go @@ -0,0 +1,235 @@ +// Package apiendpoint provides a lightweight API framework extracted from its +// original use in River projects. It lets API endpoints be defined, then +// mounted into an http.ServeMux. +package apiendpoint + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/riverqueue/riverapiframe/apierror" + "github.com/riverqueue/riverapiframe/internal/validate" +) + +// Endpoint is a struct that should be embedded on an API endpoint, and which +// provides a partial implementation for EndpointInterface. +type Endpoint[TReq any, TResp any] struct { + // Logger used to log information about endpoint execution. + logger *slog.Logger + + // Metadata about the endpoint. This is not available until SetMeta is + // invoked on the endpoint, which is usually done in Mount. + meta *EndpointMeta +} + +func (e *Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } +func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } + +type EndpointInterface interface { + // Meta returns metadata about an API endpoint, like the path it should be + // mounted at, and the status code it returns on success. + // + // This should be implemented by each specific API endpoint. + Meta() *EndpointMeta + + // SetLogger sets a logger on the endpoint. + // + // Implementation inherited from an embedded Endpoint struct. + SetLogger(logger *slog.Logger) + + // SetMeta sets metadata on an Endpoint struct after it's extracted from a + // call to an endpoint's Meta function. + // + // Implementation inherited from an embedded Endpoint struct. + SetMeta(meta *EndpointMeta) +} + +// EndpointExecuteInterface is an interface to an API endpoint. Some of it is +// implemented by an embedded Endpoint struct, and some of it should be +// implemented by the endpoint itself. +type EndpointExecuteInterface[TReq any, TResp any] interface { + EndpointInterface + + // Execute executes the API endpoint. + // + // This should be implemented by each specific API endpoint. + Execute(ctx context.Context, req *TReq) (*TResp, error) +} + +// EndpointMeta is metadata about an API endpoint. +type EndpointMeta struct { + // Pattern is the API endpoint's HTTP method and path where it should be + // mounted, which is passed to http.ServeMux by Mount. It should start with + // a verb like `GET` or `POST`, and may contain Go 1.22 path variables like + // `{name}`, whose values should be extracted by an endpoint request + // struct's custom ExtractRaw implementation. + Pattern string + + // StatusCode is the status code to be set on a successful response. + StatusCode int +} + +func (m *EndpointMeta) validate() { + if m.Pattern == "" { + panic("Endpoint.Path is required") + } + if m.StatusCode == 0 { + panic("Endpoint.StatusCode is required") + } +} + +// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log +// information about endpoint execution. +func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { + apiEndpoint.SetLogger(logger) + + meta := apiEndpoint.Meta() + meta.validate() // panic on problem + apiEndpoint.SetMeta(meta) + + mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { + executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) + }) + + return apiEndpoint +} + +func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Run as much code as we can in a sub-function that can return an error. + // This is more convenient to write, but is also safer because unlike when + // writing errors to ResponseWriter, there's no danger of a missing return. + err := func() error { + var req TReq + if r.Method != http.MethodGet { + reqData, err := io.ReadAll(r.Body) + if err != nil { + return fmt.Errorf("error reading request body: %w", err) + } + + if len(reqData) > 0 { + if err := json.Unmarshal(reqData, &req); err != nil { + return apierror.NewBadRequestf("Error unmarshaling request body: %s.", err) + } + } + } + + if rawExtractor, ok := any(&req).(RawExtractor); ok { + if err := rawExtractor.ExtractRaw(r); err != nil { + return err + } + } + + if err := validate.StructCtx(ctx, &req); err != nil { + return apierror.NewBadRequest(validate.PublicFacingMessage(err)) + } + + resp, err := execute(ctx, &req) + if err != nil { + return err + } + + if rawExtractor, ok := any(resp).(RawResponder); ok { + return rawExtractor.RespondRaw(w) + } + + respData, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("error marshaling response JSON: %w", err) + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(meta.StatusCode) + + if _, err := w.Write(respData); err != nil { + return fmt.Errorf("error writing response: %w", err) + } + + return nil + }() + if err != nil { + // Convert certain types of Postgres errors into something more + // user-friendly than an internal server error. + err = maybeInterpretInternalError(err) + + var apiErr apierror.Interface + if errors.As(err, &apiErr) { + logAttrs := []any{ + slog.String("error", apiErr.Error()), + } + + if internalErr := apiErr.GetInternalError(); internalErr != nil { + logAttrs = append(logAttrs, slog.String("internal_error", internalErr.Error())) + } + + // Logged at info level because API errors are normal. + logger.InfoContext(ctx, "API error response", logAttrs...) + + apiErr.Write(ctx, logger, w) + return + } + + if errors.Is(err, context.DeadlineExceeded) { + logger.ErrorContext(ctx, "request timeout", slog.String("error", err.Error())) + apierror.NewServiceUnavailable("Request timed out. Retrying the request might work.").Write(ctx, logger, w) + return + } + + // Internal server error. The error goes to logs but should not be + // included in the response in case there's something sensitive in + // the error string. + logger.ErrorContext(ctx, "error running API route", slog.String("error", err.Error())) + apierror.NewInternalServerError("Internal server error. Check logs for more information.").Write(ctx, logger, w) + } +} + +// RawExtractor is an interface that can be implemented by request structs that +// allows them to extract information from a raw request, like path values. +type RawExtractor interface { + ExtractRaw(r *http.Request) error +} + +// RawResponder is an interface that can be implemented by response structs that +// allow them to respond directly to a ResponseWriter instead of emitting the +// normal JSON format. +type RawResponder interface { + RespondRaw(w http.ResponseWriter) error +} + +// Make some broad categories of internal error back into something public +// facing because in some cases they can be a vast help for debugging. +func maybeInterpretInternalError(err error) error { + var ( + apiErr apierror.Interface + connectErr *pgconn.ConnectError + pgErr *pgconn.PgError + ) + + switch { + case errors.As(err, &connectErr): + apiErr = apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details.") + + case errors.As(err, &pgErr): + if pgErr.Code == pgerrcode.InsufficientPrivilege { + apiErr = apierror.NewBadRequest("Insufficient database privilege to perform this operation.") + } else { + return err + } + + default: + return err + } + + return apierror.WithInternalError(apiErr, err) +} diff --git a/apiendpoint/api_endpoint_test.go b/apiendpoint/api_endpoint_test.go new file mode 100644 index 0000000..2717e5e --- /dev/null +++ b/apiendpoint/api_endpoint_test.go @@ -0,0 +1,316 @@ +package apiendpoint + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivershared/riversharedtest" + "github.com/riverqueue/riverapiframe/apierror" +) + +func TestMountAndServe(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + recorder *httptest.ResponseRecorder + } + + setup := func(t *testing.T) (*http.ServeMux, *testBundle) { + t.Helper() + + var ( + logger = riversharedtest.Logger(t) + mux = http.NewServeMux() + ) + + Mount(mux, logger, &getEndpoint{}) + Mount(mux, logger, &postEndpoint{}) + + return mux, &testBundle{ + recorder: httptest.NewRecorder(), + } + } + + t.Run("GetEndpointAndExtractRaw", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("BodyIgnoredOnGet", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", + bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("MethodNotAllowed", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil) + mux.ServeHTTP(bundle.recorder, req) + + // This error comes from net/http. + requireStatusAndResponse(t, http.StatusMethodNotAllowed, "Method Not Allowed\n", bundle.recorder) + }) + + t.Run("PostEndpoint", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("ValidationError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) + }) + + t.Run("APIError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Bad request."}, bundle.recorder) + }) + + t.Run("InterpretedError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Insufficient database privilege to perform this operation."}, bundle.recorder) + }) + + t.Run("Timeout", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + ctx, cancel := context.WithDeadline(ctx, time.Now()) + t.Cleanup(cancel) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + require.NoError(t, err) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusServiceUnavailable, &apierror.APIError{Message: "Request timed out. Retrying the request might work."}, bundle.recorder) + }) + + t.Run("InternalServerError", func(t *testing.T) { + t.Parallel() + + mux, bundle := setup(t) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) + }) +} + +func TestMaybeInterpretInternalError(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("ConnectError", func(t *testing.T) { + t.Parallel() + + _, err := pgconn.Connect(ctx, "postgres://user@127.0.0.1:37283/does_not_exist") + + require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details."), err), maybeInterpretInternalError(err)) + }) + + t.Run("ConnectError", func(t *testing.T) { + t.Parallel() + + err := &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege} + + require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("Insufficient database privilege to perform this operation."), err), maybeInterpretInternalError(err)) + }) + + t.Run("OtherPGError", func(t *testing.T) { + t.Parallel() + + err := &pgconn.PgError{Code: pgerrcode.CardinalityViolation} + + require.Equal(t, err, maybeInterpretInternalError(err)) + }) + + t.Run("ConnectError", func(t *testing.T) { + t.Parallel() + + err := errors.New("other error") + + require.Equal(t, err, maybeInterpretInternalError(err)) + }) +} + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} + +func mustUnmarshalJSON[T any](t *testing.T, data []byte) *T { + t.Helper() + + var val T + err := json.Unmarshal(data, &val) + require.NoError(t, err) + return &val +} + +// Shortcut for requiring an HTTP status code and a JSON-marshaled response +// equivalent to expectedResp. The important thing that is does is that in the +// event of a failure on status code, it prints the response body as additional +// context to help debug the problem. +func requireStatusAndJSONResponse[T any](t *testing.T, expectedStatusCode int, expectedResp *T, recorder *httptest.ResponseRecorder) { + t.Helper() + + require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) + require.Equal(t, expectedResp, mustUnmarshalJSON[T](t, recorder.Body.Bytes())) + require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) +} + +// Same as the above, but for a non-JSON response. +func requireStatusAndResponse(t *testing.T, expectedStatusCode int, expectedResp string, recorder *httptest.ResponseRecorder) { + t.Helper() + + require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) + require.Equal(t, expectedResp, recorder.Body.String()) +} + +// +// getEndpoint +// + +type getEndpoint struct { + Endpoint[getRequest, getResponse] +} + +func (*getEndpoint) Meta() *EndpointMeta { + return &EndpointMeta{ + Pattern: "GET /api/get-endpoint/{message}", + StatusCode: http.StatusOK, + } +} + +type getRequest struct { + IgnoredJSONMessage string `json:"ignored_json" validate:"-"` + Message string `json:"-" validate:"required"` +} + +func (req *getRequest) ExtractRaw(r *http.Request) error { + req.Message = r.PathValue("message") + return nil +} + +type getResponse struct { + Message string `json:"message" validate:"required"` +} + +func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { + // This branch never gets taken because request bodies are ignored on GET. + if req.IgnoredJSONMessage != "" { + return &getResponse{Message: req.IgnoredJSONMessage}, nil + } + + return &getResponse{Message: req.Message}, nil +} + +// +// postEndpoint +// + +type postEndpoint struct { + Endpoint[postRequest, postResponse] +} + +func (*postEndpoint) Meta() *EndpointMeta { + return &EndpointMeta{ + Pattern: "POST /api/post-endpoint", + StatusCode: http.StatusCreated, + } +} + +type postRequest struct { + MakeAPIError bool `json:"make_api_error" validate:"-"` + MakeInternalError bool `json:"make_internal_error" validate:"-"` + MakePostgresError bool `json:"make_postgres_error" validate:"-"` + Message string `json:"message" validate:"required"` +} + +type postResponse struct { + Message string `json:"message"` +} + +func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + if req.MakeAPIError { + return nil, apierror.NewBadRequest("Bad request.") + } + + if req.MakeInternalError { + return nil, errors.New("an internal error occurred") + } + + if req.MakePostgresError { + // Wrap the error to make it more realistic. + return nil, fmt.Errorf("error running Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) + } + + return &postResponse{Message: req.Message}, nil +} diff --git a/apierror/api_error.go b/apierror/api_error.go new file mode 100644 index 0000000..4923dde --- /dev/null +++ b/apierror/api_error.go @@ -0,0 +1,153 @@ +// Package apierror contains a variety of marshalable API errors that adhere to +// a unified error response convention. +package apierror + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" +) + +// APIError is a struct that's embedded on a more specific API error struct (as +// seen below), and which provides a JSON serialization and a wait to +// conveniently write itself to an HTTP response. +// +// APIErrorInterface should be used with errors.As instead of this struct. +type APIError struct { + // InternalError is an additional error that might be associated with the + // API error. It's not returned in the API error response, but is logged in + // API endpoint execution to provide extra information for operators. + InternalError error `json:"-"` + + // Message is a descriptive, human-friendly message indicating what went + // wrong. Try to make error messages as actionable as possible to help the + // caller easily fix what went wrong. + Message string `json:"message"` + + // StatusCode is the API error's HTTP status code. It's not marshaled to + // JSON, but determines how the error is written to a response. + StatusCode int `json:"-"` +} + +func (e *APIError) Error() string { return e.Message } +func (e *APIError) GetInternalError() error { return e.InternalError } +func (e *APIError) SetInternalError(internalErr error) { e.InternalError = internalErr } + +// Write writes the API error to an HTTP response, writing to the given logger +// in case of a problem. +func (e *APIError) Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(e.StatusCode) + + respData, err := json.Marshal(e) + if err != nil { + logger.ErrorContext(ctx, "error marshaling API error", slog.String("error", err.Error())) + } + + if _, err := w.Write(respData); err != nil { + logger.ErrorContext(ctx, "error writing API error", slog.String("error", err.Error())) + } +} + +// Interface is an interface to an API error. This is needed for use with +// errors.As because APIError itself is embedded on another error struct, and +// won't be usable as an errors.As target. +type Interface interface { + Error() string + GetInternalError() error + SetInternalError(internalErr error) + Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) +} + +// WithInternalError is a convenience function for assigning an internal error +// to the given API error and returning it. +func WithInternalError[TAPIError Interface](apiErr TAPIError, internalErr error) TAPIError { + apiErr.SetInternalError(internalErr) + return apiErr +} + +// +// BadRequest +// + +type BadRequest struct { //nolint:errname + APIError +} + +func NewBadRequest(message string) *BadRequest { + return &BadRequest{ + APIError: APIError{ + Message: message, + StatusCode: http.StatusBadRequest, + }, + } +} + +func NewBadRequestf(format string, a ...any) *BadRequest { + return NewBadRequest(fmt.Sprintf(format, a...)) +} + +// +// InternalServerError +// + +type InternalServerError struct { + APIError +} + +func NewInternalServerError(message string) *InternalServerError { + return &InternalServerError{ + APIError: APIError{ + Message: message, + StatusCode: http.StatusInternalServerError, + }, + } +} + +func NewInternalServerErrorf(format string, a ...any) *InternalServerError { + return NewInternalServerError(fmt.Sprintf(format, a...)) +} + +// +// NotFound +// + +type NotFound struct { //nolint:errname + APIError +} + +func NewNotFound(message string) *NotFound { + return &NotFound{ + APIError: APIError{ + Message: message, + StatusCode: http.StatusNotFound, + }, + } +} + +func NewNotFoundf(format string, a ...any) *NotFound { + return NewNotFound(fmt.Sprintf(format, a...)) +} + +// +// ServiceUnavailable +// + +type ServiceUnavailable struct { //nolint:errname + APIError +} + +func NewServiceUnavailable(message string) *ServiceUnavailable { + return &ServiceUnavailable{ + APIError: APIError{ + Message: message, + StatusCode: http.StatusServiceUnavailable, + }, + } +} + +func NewServiceUnavailablef(format string, a ...any) *ServiceUnavailable { + return NewServiceUnavailable(fmt.Sprintf(format, a...)) +} diff --git a/apierror/api_error_test.go b/apierror/api_error_test.go new file mode 100644 index 0000000..d2fa195 --- /dev/null +++ b/apierror/api_error_test.go @@ -0,0 +1,74 @@ +package apierror + +import ( + "context" + "encoding/json" + "errors" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/rivershared/riversharedtest" +) + +func TestAPIError(t *testing.T) { + t.Parallel() + + var ( + anErr = errors.New("an error") + apiErr = NewBadRequest("Bad request.") + ) + + apiErr.SetInternalError(anErr) + require.Equal(t, anErr, apiErr.GetInternalError()) +} + +func TestAPIErrorJSON(t *testing.T) { + t.Parallel() + + require.JSONEq(t, + `{"message":"Bad request. Try sending JSON next time."}`, + string(mustMarshalJSON( + t, NewBadRequest("Bad request. Try sending JSON next time."))), + ) +} + +func TestAPIErrorWrite(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + logger = riversharedtest.Logger(t) + recorder = httptest.NewRecorder() + ) + + NewBadRequest("Bad request. Try sending JSON next time.").Write(ctx, logger, recorder) + + require.Equal(t, 400, recorder.Result().StatusCode) + require.JSONEq(t, + `{"message":"Bad request. Try sending JSON next time."}`, + recorder.Body.String(), + ) + require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) +} + +func TestWithInternalError(t *testing.T) { + t.Parallel() + + var ( + anErr = errors.New("an error") + apiErr = NewBadRequest("Bad request.") + ) + + apiErr = WithInternalError(apiErr, anErr) + require.Equal(t, anErr, apiErr.InternalError) +} + +func mustMarshalJSON(t *testing.T, v any) []byte { + t.Helper() + + data, err := json.Marshal(v) + require.NoError(t, err) + return data +} diff --git a/apimiddleware/api_middleware.go b/apimiddleware/api_middleware.go new file mode 100644 index 0000000..9149d8f --- /dev/null +++ b/apimiddleware/api_middleware.go @@ -0,0 +1,71 @@ +package apimiddleware + +import ( + "net/http" +) + +// middlewareInterface is an interface to be implemented by middleware. +type middlewareInterface interface { + Middleware(next http.Handler) http.Handler +} + +// MiddlewareFunc allows a simple middleware to be defined as only a function. +type MiddlewareFunc func(next http.Handler) http.Handler + +// Middleware allows MiddlewareFunc to implement middlewareInterface. +func (f MiddlewareFunc) Middleware(next http.Handler) http.Handler { + return f(next) +} + +// MiddlewareStack builds a stack of middleware that a request will be +// dispatched through before sending it to the underlying handler. Middlewares +// are added to it before getting a handler with a call to Mount on another +// handler like a ServeMux. +// used like: +// +// middlewares := &MiddlewareStack +// middlewares.Use(middleware1) +// middlewares.Use(middleware2) +// ... +// handler := middlewares.Mount(mux) +// +// Besides some slight syntactic nicety, the entire reason this type exists is +// because it will mount middlewares in a more human-friendly/intuitive order. +// When mounting middlewares (not using MiddlewareStack) like: +// +// handler := mux +// handler = middleware1.Wrapper(handler) +// handler = middleware2.Wrapper(handler) +// ... +// +// One must be very careful because the middlewares will be "backwards" +// according to the list in that when a request enters the stack, the middleware +// that was mounted first will be called _last_ because it's nested the deepest +// down. +// +// MiddlewareStack fixes this problem by enabling any number of middlewares to +// be specified, and then mounting them in inverted order when Mount is called. +type MiddlewareStack struct { + middlewares []middlewareInterface +} + +// NewMiddlewareStack is a helper that can act as a shortcut to initialize a +// middleware stack by passing a series of middlewares as variadic args. +func NewMiddlewareStack(middlewares ...middlewareInterface) *MiddlewareStack { + stack := &MiddlewareStack{} + for _, mw := range middlewares { + stack.Use(mw) + } + return stack +} + +func (s *MiddlewareStack) Mount(handler http.Handler) http.Handler { + for i := len(s.middlewares) - 1; i >= 0; i-- { + handler = s.middlewares[i].Middleware(handler) + } + return handler +} + +func (s *MiddlewareStack) Use(middleware middlewareInterface) { + s.middlewares = append(s.middlewares, middleware) +} diff --git a/apimiddleware/api_middleware_test.go b/apimiddleware/api_middleware_test.go new file mode 100644 index 0000000..a0c323e --- /dev/null +++ b/apimiddleware/api_middleware_test.go @@ -0,0 +1,84 @@ +package apimiddleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +// Verify MiddlewareFunc complies with middlewareInterface. +var _ middlewareInterface = MiddlewareFunc(func(next http.Handler) http.Handler { + return next +}) + +type contextTrailContextKey struct{} + +// Adds the configured segment to trail in context. +type contextTrailMiddleware struct { + segment string +} + +func newContextTrailMiddleware(segment string) *contextTrailMiddleware { + return &contextTrailMiddleware{segment: segment} +} + +func (m *contextTrailMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract + var contextTrail []string + if existingTrail, ok := ctx.Value(contextTrailContextKey{}).([]string); ok { + contextTrail = existingTrail + } + contextTrail = append(contextTrail, m.segment) + + next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, contextTrailContextKey{}, contextTrail))) + }) +} + +func TestMiddlewareStack(t *testing.T) { + t.Parallel() + + makeRequestAndExtractTrail := func(stack *MiddlewareStack) []string { + var contextTrail []string + + handler := stack.Mount(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + contextTrail = r.Context().Value(contextTrailContextKey{}).([]string) //nolint:forcetypeassert + })) + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) + handler.ServeHTTP(recorder, req) + + return contextTrail + } + + t.Run("NewMiddlewareStack", func(t *testing.T) { + t.Parallel() + + stack := NewMiddlewareStack( + newContextTrailMiddleware("1st"), + newContextTrailMiddleware("2nd"), + newContextTrailMiddleware("3rd"), + ) + + contextTrail := makeRequestAndExtractTrail(stack) + require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) + }) + + t.Run("Use", func(t *testing.T) { + t.Parallel() + + stack := &MiddlewareStack{} + stack.Use(newContextTrailMiddleware("1st")) + stack.Use(newContextTrailMiddleware("2nd")) + stack.Use(newContextTrailMiddleware("3rd")) + + contextTrail := makeRequestAndExtractTrail(stack) + require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) + }) +} diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..1d69fe4 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,3 @@ +# riverapiframe [![Build Status](https://github.com/riverqueue/riverapiframe/actions/workflows/ci.yaml/badge.svg?branch=master)](https://github.com/riverqueue/riverapiframe/actions) + +A minimal API framework for various River projects. diff --git a/docs/development.md b/docs/development.md new file mode 100644 index 0000000..20fcd96 --- /dev/null +++ b/docs/development.md @@ -0,0 +1,14 @@ +# River API framework development + +## Run tests + +```sh +$ go test ./... +``` + +## Run lint + +```sh +$ brew install golangci-lint +$ golangci-lint run --fix +``` diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8d2747d --- /dev/null +++ b/go.mod @@ -0,0 +1,30 @@ +module github.com/riverqueue/riverapiframe + +go 1.23.0 + +require ( + github.com/go-playground/validator/v10 v10.25.0 + github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 + github.com/jackc/pgx/v5 v5.7.2 + github.com/riverqueue/river/rivershared v0.18.0 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gabriel-vasile/mimetype v1.4.8 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/riverqueue/river/rivertype v0.18.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect + go.uber.org/goleak v1.3.0 // indirect + golang.org/x/crypto v0.32.0 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.22.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7203c18 --- /dev/null +++ b/go.sum @@ -0,0 +1,56 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= +github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.25.0 h1:5Dh7cjvzR7BRZadnsVOzPhWsrwUr0nmsZJxEAnFLNO8= +github.com/go-playground/validator/v10 v10.25.0/go.mod h1:GGzBIJMuE98Ic/kJsBXbz1x/7cByt++cQ+YOuDM5wus= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI= +github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/river/rivershared v0.18.0 h1:hBfyaoTAvogs7lSw4vr6A2ZdZmmtTlew10P4MRXuaDg= +github.com/riverqueue/river/rivershared v0.18.0/go.mod h1:wyJw90ILEYNcYCoXr4B6iPHnSyRH0WKGQuPzjdEwou8= +github.com/riverqueue/river/rivertype v0.18.0 h1:YsXR5NbLAzniurGO0+zcISWMKq7Y71xkIe2oi86OAsE= +github.com/riverqueue/river/rivertype v0.18.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/validate/validate.go b/internal/validate/validate.go new file mode 100644 index 0000000..fd8cee5 --- /dev/null +++ b/internal/validate/validate.go @@ -0,0 +1,118 @@ +// Package validate internalizes Go Playground's Validator framework, setting +// some common options that we use everywhere, providing some useful helpers, +// and exporting a simplified API. +package validate + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" +) + +// WithRequiredStructEnabled can be removed once validator/v11 is released. +var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals + +func init() { //nolint:gochecknoinits + validate.RegisterTagNameFunc(preferPublicName) +} + +// PublicFacingMessage builds a complete error message from a validator error +// that's suitable for public-facing consumption. +// +// I only added a few possible validations to start. We'll probably need to add +// more as we go and expand our usage. +func PublicFacingMessage(validatorErr error) string { + var message string + + //nolint:errorlint + if validationErrs, ok := validatorErr.(validator.ValidationErrors); ok { + for _, fieldErr := range validationErrs { + switch fieldErr.Tag() { + case "lte": + fallthrough // lte and max are synonyms + case "max": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { //nolint:exhaustive + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be less than or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at most %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at most %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "gte": + fallthrough // gte and min are synonyms + case "min": + kind := fieldErr.Kind() + if kind == reflect.Ptr { + kind = fieldErr.Type().Elem().Kind() + } + + switch kind { //nolint:exhaustive + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: + message += fmt.Sprintf(" Field `%s` must be greater or equal to %s.", + fieldErr.Field(), fieldErr.Param()) + + case reflect.Slice, reflect.Map: + message += fmt.Sprintf(" Field `%s` must contain at least %s element(s).", + fieldErr.Field(), fieldErr.Param()) + + case reflect.String: + message += fmt.Sprintf(" Field `%s` must be at least %s character(s) long.", + fieldErr.Field(), fieldErr.Param()) + + default: + message += fieldErr.Error() + } + + case "oneof": + message += fmt.Sprintf(" Field `%s` should be one of the following values: %s.", + fieldErr.Field(), fieldErr.Param()) + + case "required": + message += fmt.Sprintf(" Field `%s` is required.", fieldErr.Field()) + + default: + message += fmt.Sprintf(" Validation on field `%s` failed on the `%s` tag.", fieldErr.Field(), fieldErr.Tag()) + } + } + } + + return strings.TrimSpace(message) +} + +// StructCtx validates a structs exposed fields, and automatically validates +// nested structs, unless otherwise specified and also allows passing of +// context.Context for contextual validation information. +func StructCtx(ctx context.Context, s any) error { + return validate.StructCtx(ctx, s) +} + +// preferPublicName is a validator tag naming function that uses public names +// like a field's JSON tag instead of actual field names in structs. +// This is important because we sent these back as user-facing errors (and the +// users submitted them as JSON/path parameters). +func preferPublicName(fld reflect.StructField) string { + name, _, _ := strings.Cut(fld.Tag.Get("json"), ",") + if name != "" && name != "-" { + return name + } + + return fld.Name +} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go new file mode 100644 index 0000000..81e2d69 --- /dev/null +++ b/internal/validate/validate_test.go @@ -0,0 +1,135 @@ +package validate + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFromValidator(t *testing.T) { + t.Parallel() + + // Fields have JSON tags so we can verify those are used over the + // property name. + type TestStruct struct { + MinInt int `json:"min_int" validate:"min=1"` + MinSlice []string `json:"min_slice" validate:"min=1"` + MinString string `json:"min_string" validate:"min=1"` + MaxInt int `json:"max_int" validate:"max=0"` + MaxSlice []string `json:"max_slice" validate:"max=0"` + MaxString string `json:"max_string" validate:"max=0"` + OneOf string `json:"one_of" validate:"oneof=blue green"` + Required string `json:"required" validate:"required"` + Unsupported string `json:"unsupported" validate:"e164"` + } + + validTestStruct := func() *TestStruct { + return &TestStruct{ + MinInt: 1, + MinSlice: []string{"1"}, + MinString: "value", + MaxInt: 0, + MaxSlice: []string{}, + MaxString: "", + OneOf: "blue", + Required: "value", + Unsupported: "+1123456789", + } + } + + t.Run("MaxInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxInt = 1 + require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxSlice = []string{"1"} + require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MaxString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MaxString = "value" + require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinInt", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinSlice", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinSlice = nil + require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MinString", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinString = "" + require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("OneOf", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.OneOf = "red" + require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Required", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Required = "" + require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("Unsupported", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.Unsupported = "abc" + require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) + }) + + t.Run("MultipleErrors", func(t *testing.T) { + t.Parallel() + + testStruct := validTestStruct() + testStruct.MinInt = 0 + testStruct.Required = "" + require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) + }) +} + +func TestPreferPublicNames(t *testing.T) { + t.Parallel() + + type testStruct struct { + JSONNameField string `json:"json_name"` + StructNameField string `apiquery:"-"` + } + + require.Equal(t, "json_name", + preferPublicName(reflect.TypeOf(testStruct{}).Field(0))) + require.Equal(t, "StructNameField", + preferPublicName(reflect.TypeOf(testStruct{}).Field(1))) +}