diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7bfaf979..0e7824ef 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -81,7 +81,7 @@ jobs: name: Go lint runs-on: ubuntu-latest env: - GOLANGCI_LINT_VERSION: v1.63.4 + GOLANGCI_LINT_VERSION: v1.64.6 permissions: contents: read # allow read access to pull request. Use with `only-new-issues` option. diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index bf0cc51a..0ef10a01 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -16,11 +16,11 @@ import ( "github.com/rs/cors" sloghttp "github.com/samber/slog-http" + "github.com/riverqueue/apiframe/apimiddleware" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver/riverpgxv5" "riverqueue.com/riverui" - "riverqueue.com/riverui/internal/apimiddleware" ) func main() { diff --git a/dist/.gitkeep b/dist/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/development.md b/docs/development.md index 2c80864e..5ec1213e 100644 --- a/docs/development.md +++ b/docs/development.md @@ -4,12 +4,6 @@ River UI consists of two apps: a Go backend API, and a TypeScript UI frontend. ## Environment -The project uses a combination of direnv and a `.env` file (to suit Vite conventions). Copy the example and edit as necessary: - -```sh -cp .env.example .env.local -``` - ## Install dependencies ```sh @@ -17,10 +11,12 @@ go get ./... npm install ``` -## Install Reflex - This project uses [Reflex](https://github.com/cespare/reflex) for local dev. Install it. +``` sh +go install github.com/cespare/reflex@latest +``` + ## Running the UI and API together ```sh diff --git a/go.mod b/go.mod index 3eb02f07..bb2dffa5 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,13 @@ module riverqueue.com/riverui -go 1.22.0 +go 1.23.0 -toolchain go1.23.5 +toolchain go1.24.1 require ( - github.com/go-playground/validator/v10 v10.25.0 github.com/google/uuid v1.6.0 - github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx/v5 v5.7.2 + github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f github.com/riverqueue/river v0.18.0 github.com/riverqueue/river/riverdriver v0.18.0 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.18.0 @@ -24,6 +23,8 @@ require ( github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.25.0 // indirect + github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/go.sum b/go.sum index c42a94b7..bbc2ee8f 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f h1:Pr3+ERes1GkCJnw/gpi9yVx9tN7VCcPjuZCFN/OBaZg= +github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f/go.mod h1:ko/9b4SeomWrHTr4WU0i21peq90Qk2Mm8MgOqPrTcHA= github.com/riverqueue/river v0.18.0 h1:sGHeTOL9MR8+pMIVHRm59fzet8Ron/xjF3Yq/PSGb78= github.com/riverqueue/river v0.18.0/go.mod h1:oapX5xb/L2YnkE801QubDZ0COHxVxEGVY37icPzghhU= github.com/riverqueue/river/riverdriver v0.18.0 h1:a2haR5I0MQLHjLCSVFpUEeJALCLemRl5zCztucysm1E= @@ -49,8 +51,8 @@ github.com/riverqueue/river/rivertype v0.18.0 h1:YsXR5NbLAzniurGO0+zcISWMKq7Y71x github.com/riverqueue/river/rivertype v0.18.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/samber/slog-http v1.5.1 h1:z5Ty/u5LKJbWjmjLDr6OgtwHXrPPrH2ZPoQE/47n9sU= diff --git a/handler.go b/handler.go index 53c22275..55d6357b 100644 --- a/handler.go +++ b/handler.go @@ -17,13 +17,12 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/riverqueue/apiframe/apiendpoint" + "github.com/riverqueue/apiframe/apimiddleware" "github.com/riverqueue/river" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/valutil" - - "riverqueue.com/riverui/internal/apiendpoint" - "riverqueue.com/riverui/internal/apimiddleware" ) // DB is the interface for a pgx database connection. diff --git a/handler_api_endpoint.go b/handler_api_endpoint.go index 313e3515..a4358c0d 100644 --- a/handler_api_endpoint.go +++ b/handler_api_endpoint.go @@ -13,6 +13,8 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/riverqueue/apiframe/apiendpoint" + "github.com/riverqueue/apiframe/apierror" "github.com/riverqueue/river" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/startstop" @@ -20,8 +22,6 @@ import ( "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" - "riverqueue.com/riverui/internal/apiendpoint" - "riverqueue.com/riverui/internal/apierror" "riverqueue.com/riverui/internal/dbsqlc" "riverqueue.com/riverui/internal/querycacher" "riverqueue.com/riverui/internal/util/pgxutil" @@ -104,7 +104,7 @@ func (a *healthCheckGetEndpoint) Execute(ctx context.Context, req *healthCheckGe // fall through to OK status response below default: - return nil, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", req.Name) + return nil, apierror.NewNotFoundf("Health check %q not found. Use either `complete` or `minimal`.", req.Name) } return statusResponseOK, nil @@ -142,7 +142,7 @@ func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) job, err := a.client.JobCancelTx(ctx, tx, jobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -185,10 +185,10 @@ func (a *jobDeleteEndpoint) Execute(ctx context.Context, req *jobDeleteRequest) _, err := a.client.JobDeleteTx(ctx, tx, jobID) if err != nil { if errors.Is(err, rivertype.ErrJobRunning) { - return nil, apierror.NewBadRequest("Job %d is running and can't be deleted until it finishes.", jobID) + return nil, apierror.NewBadRequestf("Job %d is running and can't be deleted until it finishes.", jobID) } if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -227,7 +227,7 @@ func (req *jobGetRequest) ExtractRaw(r *http.Request) error { jobID, err := strconv.ParseInt(idString, 10, 64) if err != nil { - return apierror.NewBadRequest("Couldn't convert job ID to int64: %s.", err) + return apierror.NewBadRequestf("Couldn't convert job ID to int64: %s.", err) } req.JobID = jobID @@ -239,7 +239,7 @@ func (a *jobGetEndpoint) Execute(ctx context.Context, req *jobGetRequest) (*Rive job, err := a.client.JobGetTx(ctx, tx, req.JobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(req.JobID) + return nil, NewNotFoundJob(req.JobID) } return nil, fmt.Errorf("error getting job: %w", err) } @@ -276,7 +276,7 @@ func (req *jobListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -344,7 +344,7 @@ func (a *jobRetryEndpoint) Execute(ctx context.Context, req *jobRetryRequest) (* _, err := a.client.JobRetryTx(ctx, tx, jobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -388,7 +388,7 @@ func (a *queueGetEndpoint) Execute(ctx context.Context, req *queueGetRequest) (* queue, err := a.client.QueueGetTx(ctx, tx, req.Name) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error getting queue: %w", err) } @@ -430,7 +430,7 @@ func (req *queueListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -490,7 +490,7 @@ func (a *queuePauseEndpoint) Execute(ctx context.Context, req *queuePauseRequest return pgxutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*statusResponse, error) { if err := a.client.QueuePauseTx(ctx, tx, req.Name, nil); err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error pausing queue: %w", err) } @@ -532,7 +532,7 @@ func (a *queueResumeEndpoint) Execute(ctx context.Context, req *queueResumeReque return pgxutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*statusResponse, error) { if err := a.client.QueueResumeTx(ctx, tx, req.Name, nil); err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error resuming queue: %w", err) } @@ -670,7 +670,7 @@ func (a *workflowGetEndpoint) Execute(ctx context.Context, req *workflowGetReque } if len(jobs) < 1 { - return nil, apierror.NewNotFoundWorkflow(req.ID) + return nil, NewNotFoundWorkflow(req.ID) } return &workflowGetResponse{ @@ -712,7 +712,7 @@ func (req *workflowListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -758,6 +758,18 @@ func (a *workflowListEndpoint) Execute(ctx context.Context, req *workflowListReq } } +func NewNotFoundJob(jobID int64) *apierror.NotFound { + return apierror.NewNotFoundf("Job not found: %d.", jobID) +} + +func NewNotFoundQueue(name string) *apierror.NotFound { + return apierror.NewNotFoundf("Queue not found: %s.", name) +} + +func NewNotFoundWorkflow(id string) *apierror.NotFound { + return apierror.NewNotFoundf("Workflow not found: %s.", id) +} + type RiverJob struct { ID int64 `json:"id"` Args json.RawMessage `json:"args"` diff --git a/handler_api_endpoint_test.go b/handler_api_endpoint_test.go index ebc70d60..5649cf99 100644 --- a/handler_api_endpoint_test.go +++ b/handler_api_endpoint_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" + "github.com/riverqueue/apiframe/apierror" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -17,7 +18,6 @@ import ( "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivertype" - "riverqueue.com/riverui/internal/apierror" "riverqueue.com/riverui/internal/riverinternaltest" "riverqueue.com/riverui/internal/riverinternaltest/testfactory" ) @@ -104,7 +104,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) _, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: "other"}) - requireAPIError(t, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) + requireAPIError(t, apierror.NewNotFoundf("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) }) } @@ -140,7 +140,7 @@ func TestJobCancelEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobCancelEndpoint) _, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -174,7 +174,7 @@ func TestJobDeleteEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobDeleteEndpoint) _, err := endpoint.Execute(ctx, &jobDeleteRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -201,7 +201,7 @@ func TestJobGetEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobGetEndpoint) _, err := endpoint.Execute(ctx, &jobGetRequest{JobID: 123}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -322,7 +322,7 @@ func TestJobRetryEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobRetryEndpoint) _, err := endpoint.Execute(ctx, &jobRetryRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -353,7 +353,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueueGetEndpoint) _, err := endpoint.Execute(ctx, &queueGetRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -420,7 +420,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueuePauseEndpoint) _, err := endpoint.Execute(ctx, &queuePauseRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -449,7 +449,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueueResumeEndpoint) _, err := endpoint.Execute(ctx, &queueResumeRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -577,7 +577,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { workflowID := uuid.New() _, err := endpoint.Execute(ctx, &workflowGetRequest{ID: workflowID.String()}) - requireAPIError(t, apierror.NewNotFoundWorkflow(workflowID.String()), err) + requireAPIError(t, NewNotFoundWorkflow(workflowID.String()), err) }) } diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go deleted file mode 100644 index ea3b81e9..00000000 --- a/internal/apiendpoint/api_endpoint.go +++ /dev/null @@ -1,234 +0,0 @@ -// Package apiendpoint provides a lightweight API framework for use with River -// UI. It lets API endpoints be defined, then mounted into an http.ServeMux. -package apiendpoint - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "time" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" - - "riverqueue.com/riverui/internal/apierror" - "riverqueue.com/riverui/internal/validate" -) - -// Endpoint is a struct that should be embedded on an API endpoint, and which -// provides a partial implementation for EndpointInterface. -type Endpoint[TReq any, TResp any] struct { - // Logger used to log information about endpoint execution. - logger *slog.Logger - - // Metadata about the endpoint. This is not available until SetMeta is - // invoked on the endpoint, which is usually done in Mount. - meta *EndpointMeta -} - -func (e *Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } -func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } - -type EndpointInterface interface { - // Meta returns metadata about an API endpoint, like the path it should be - // mounted at, and the status code it returns on success. - // - // This should be implemented by each specific API endpoint. - Meta() *EndpointMeta - - // SetLogger sets a logger on the endpoint. - // - // Implementation inherited from an embedded Endpoint struct. - SetLogger(logger *slog.Logger) - - // SetMeta sets metadata on an Endpoint struct after it's extracted from a - // call to an endpoint's Meta function. - // - // Implementation inherited from an embedded Endpoint struct. - SetMeta(meta *EndpointMeta) -} - -// EndpointExecuteInterface is an interface to an API endpoint. Some of it is -// implemented by an embedded Endpoint struct, and some of it should be -// implemented by the endpoint itself. -type EndpointExecuteInterface[TReq any, TResp any] interface { - EndpointInterface - - // Execute executes the API endpoint. - // - // This should be implemented by each specific API endpoint. - Execute(ctx context.Context, req *TReq) (*TResp, error) -} - -// EndpointMeta is metadata about an API endpoint. -type EndpointMeta struct { - // Pattern is the API endpoint's HTTP method and path where it should be - // mounted, which is passed to http.ServeMux by Mount. It should start with - // a verb like `GET` or `POST`, and may contain Go 1.22 path variables like - // `{name}`, whose values should be extracted by an endpoint request - // struct's custom ExtractRaw implementation. - Pattern string - - // StatusCode is the status code to be set on a successful response. - StatusCode int -} - -func (m *EndpointMeta) validate() { - if m.Pattern == "" { - panic("Endpoint.Path is required") - } - if m.StatusCode == 0 { - panic("Endpoint.StatusCode is required") - } -} - -// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log -// information about endpoint execution. -func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { - apiEndpoint.SetLogger(logger) - - meta := apiEndpoint.Meta() - meta.validate() // panic on problem - apiEndpoint.SetMeta(meta) - - mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { - executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) - }) - - return apiEndpoint -} - -func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - - // Run as much code as we can in a sub-function that can return an error. - // This is more convenient to write, but is also safer because unlike when - // writing errors to ResponseWriter, there's no danger of a missing return. - err := func() error { - var req TReq - if r.Method != http.MethodGet { - reqData, err := io.ReadAll(r.Body) - if err != nil { - return fmt.Errorf("error reading request body: %w", err) - } - - if len(reqData) > 0 { - if err := json.Unmarshal(reqData, &req); err != nil { - return apierror.NewBadRequest("Error unmarshaling request body: %s.", err) - } - } - } - - if rawExtractor, ok := any(&req).(RawExtractor); ok { - if err := rawExtractor.ExtractRaw(r); err != nil { - return err - } - } - - if err := validate.StructCtx(ctx, &req); err != nil { - return apierror.NewBadRequest(validate.PublicFacingMessage(err)) //nolint:govet - } - - resp, err := execute(ctx, &req) - if err != nil { - return err - } - - if rawExtractor, ok := any(resp).(RawResponder); ok { - return rawExtractor.RespondRaw(w) - } - - respData, err := json.Marshal(resp) - if err != nil { - return fmt.Errorf("error marshaling response JSON: %w", err) - } - - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(meta.StatusCode) - - if _, err := w.Write(respData); err != nil { - return fmt.Errorf("error writing response: %w", err) - } - - return nil - }() - if err != nil { - // Convert certain types of Postgres errors into something more - // user-friendly than an internal server error. - err = maybeInterpretInternalError(err) - - var apiErr apierror.Interface - if errors.As(err, &apiErr) { - logAttrs := []any{ - slog.String("error", apiErr.Error()), - } - - if internalErr := apiErr.GetInternalError(); internalErr != nil { - logAttrs = append(logAttrs, slog.String("internal_error", internalErr.Error())) - } - - // Logged at info level because API errors are normal. - logger.InfoContext(ctx, "API error response", logAttrs...) - - apiErr.Write(ctx, logger, w) - return - } - - if errors.Is(err, context.DeadlineExceeded) { - logger.ErrorContext(ctx, "request timeout", slog.String("error", err.Error())) - apierror.NewServiceUnavailable("Request timed out. Retrying the request might work.").Write(ctx, logger, w) - return - } - - // Internal server error. The error goes to logs but should not be - // included in the response in case there's something sensitive in - // the error string. - logger.ErrorContext(ctx, "error running API route", slog.String("error", err.Error())) - apierror.NewInternalServerError("Internal server error. Check logs for more information.").Write(ctx, logger, w) - } -} - -// RawExtractor is an interface that can be implemented by request structs that -// allows them to extract information from a raw request, like path values. -type RawExtractor interface { - ExtractRaw(r *http.Request) error -} - -// RawResponder is an interface that can be implemented by response structs that -// allow them to respond directly to a ResponseWriter instead of emitting the -// normal JSON format. -type RawResponder interface { - RespondRaw(w http.ResponseWriter) error -} - -// Make some broad categories of internal error back into something public -// facing because in some cases they can be a vast help for debugging. -func maybeInterpretInternalError(err error) error { - var ( - apiErr apierror.Interface - connectErr *pgconn.ConnectError - pgErr *pgconn.PgError - ) - - switch { - case errors.As(err, &connectErr): - apiErr = apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details.") - - case errors.As(err, &pgErr): - if pgErr.Code == pgerrcode.InsufficientPrivilege { - apiErr = apierror.NewBadRequest("Insufficient database privilege to perform this operation.") - } else { - return err - } - - default: - return err - } - - return apierror.WithInternalError(apiErr, err) -} diff --git a/internal/apiendpoint/api_endpoint_test.go b/internal/apiendpoint/api_endpoint_test.go deleted file mode 100644 index d7a5fce8..00000000 --- a/internal/apiendpoint/api_endpoint_test.go +++ /dev/null @@ -1,316 +0,0 @@ -package apiendpoint - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/require" - - "riverqueue.com/riverui/internal/apierror" - "riverqueue.com/riverui/internal/riverinternaltest" -) - -func TestMountAndServe(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - type testBundle struct { - recorder *httptest.ResponseRecorder - } - - setup := func(t *testing.T) (*http.ServeMux, *testBundle) { - t.Helper() - - var ( - logger = riverinternaltest.Logger(t) - mux = http.NewServeMux() - ) - - Mount(mux, logger, &getEndpoint{}) - Mount(mux, logger, &postEndpoint{}) - - return mux, &testBundle{ - recorder: httptest.NewRecorder(), - } - } - - t.Run("GetEndpointAndExtractRaw", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("BodyIgnoredOnGet", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", - bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("MethodNotAllowed", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil) - mux.ServeHTTP(bundle.recorder, req) - - // This error comes from net/http. - requireStatusAndResponse(t, http.StatusMethodNotAllowed, "Method Not Allowed\n", bundle.recorder) - }) - - t.Run("PostEndpoint", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("ValidationError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) - }) - - t.Run("APIError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Bad request."}, bundle.recorder) - }) - - t.Run("InterpretedError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Insufficient database privilege to perform this operation."}, bundle.recorder) - }) - - t.Run("Timeout", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - ctx, cancel := context.WithDeadline(ctx, time.Now()) - t.Cleanup(cancel) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) - require.NoError(t, err) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusServiceUnavailable, &apierror.APIError{Message: "Request timed out. Retrying the request might work."}, bundle.recorder) - }) - - t.Run("InternalServerError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) - }) -} - -func TestMaybeInterpretInternalError(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - _, err := pgconn.Connect(ctx, "postgres://user@127.0.0.1:37283/does_not_exist") - - require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details."), err), maybeInterpretInternalError(err)) - }) - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - err := &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege} - - require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("Insufficient database privilege to perform this operation."), err), maybeInterpretInternalError(err)) - }) - - t.Run("OtherPGError", func(t *testing.T) { - t.Parallel() - - err := &pgconn.PgError{Code: pgerrcode.CardinalityViolation} - - require.Equal(t, err, maybeInterpretInternalError(err)) - }) - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - err := errors.New("other error") - - require.Equal(t, err, maybeInterpretInternalError(err)) - }) -} - -func mustMarshalJSON(t *testing.T, v any) []byte { - t.Helper() - - data, err := json.Marshal(v) - require.NoError(t, err) - return data -} - -func mustUnmarshalJSON[T any](t *testing.T, data []byte) *T { - t.Helper() - - var val T - err := json.Unmarshal(data, &val) - require.NoError(t, err) - return &val -} - -// Shortcut for requiring an HTTP status code and a JSON-marshaled response -// equivalent to expectedResp. The important thing that is does is that in the -// event of a failure on status code, it prints the response body as additional -// context to help debug the problem. -func requireStatusAndJSONResponse[T any](t *testing.T, expectedStatusCode int, expectedResp *T, recorder *httptest.ResponseRecorder) { - t.Helper() - - require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) - require.Equal(t, expectedResp, mustUnmarshalJSON[T](t, recorder.Body.Bytes())) - require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) -} - -// Same as the above, but for a non-JSON response. -func requireStatusAndResponse(t *testing.T, expectedStatusCode int, expectedResp string, recorder *httptest.ResponseRecorder) { - t.Helper() - - require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) - require.Equal(t, expectedResp, recorder.Body.String()) -} - -// -// getEndpoint -// - -type getEndpoint struct { - Endpoint[getRequest, getResponse] -} - -func (*getEndpoint) Meta() *EndpointMeta { - return &EndpointMeta{ - Pattern: "GET /api/get-endpoint/{message}", - StatusCode: http.StatusOK, - } -} - -type getRequest struct { - IgnoredJSONMessage string `json:"ignored_json" validate:"-"` - Message string `json:"-" validate:"required"` -} - -func (req *getRequest) ExtractRaw(r *http.Request) error { - req.Message = r.PathValue("message") - return nil -} - -type getResponse struct { - Message string `json:"message" validate:"required"` -} - -func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { - // This branch never gets taken because request bodies are ignored on GET. - if req.IgnoredJSONMessage != "" { - return &getResponse{Message: req.IgnoredJSONMessage}, nil - } - - return &getResponse{Message: req.Message}, nil -} - -// -// postEndpoint -// - -type postEndpoint struct { - Endpoint[postRequest, postResponse] -} - -func (*postEndpoint) Meta() *EndpointMeta { - return &EndpointMeta{ - Pattern: "POST /api/post-endpoint", - StatusCode: http.StatusCreated, - } -} - -type postRequest struct { - MakeAPIError bool `json:"make_api_error" validate:"-"` - MakeInternalError bool `json:"make_internal_error" validate:"-"` - MakePostgresError bool `json:"make_postgres_error" validate:"-"` - Message string `json:"message" validate:"required"` -} - -type postResponse struct { - Message string `json:"message"` -} - -func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - - if req.MakeAPIError { - return nil, apierror.NewBadRequest("Bad request.") - } - - if req.MakeInternalError { - return nil, errors.New("an internal error occurred") - } - - if req.MakePostgresError { - // Wrap the error to make it more realistic. - return nil, fmt.Errorf("error running Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) - } - - return &postResponse{Message: req.Message}, nil -} diff --git a/internal/apierror/api_error.go b/internal/apierror/api_error.go deleted file mode 100644 index e553d6cf..00000000 --- a/internal/apierror/api_error.go +++ /dev/null @@ -1,141 +0,0 @@ -// Package apierror contains a variety of marshalable API errors that adhere to -// a unified error response convention. -package apierror - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "net/http" -) - -// APIError is a struct that's embedded on a more specific API error struct (as -// seen below), and which provides a JSON serialization and a wait to -// conveniently write itself to an HTTP response. -// -// APIErrorInterface should be used with errors.As instead of this struct. -type APIError struct { - // InternalError is an additional error that might be associated with the - // API error. It's not returned in the API error response, but is logged in - // API endpoint execution to provide extra information for operators. - InternalError error `json:"-"` - - // Message is a descriptive, human-friendly message indicating what went - // wrong. Try to make error messages as actionable as possible to help the - // caller easily fix what went wrong. - Message string `json:"message"` - - // StatusCode is the API error's HTTP status code. It's not marshaled to - // JSON, but determines how the error is written to a response. - StatusCode int `json:"-"` -} - -func (e *APIError) Error() string { return e.Message } -func (e *APIError) GetInternalError() error { return e.InternalError } -func (e *APIError) SetInternalError(internalErr error) { e.InternalError = internalErr } - -// Write writes the API error to an HTTP response, writing to the given logger -// in case of a problem. -func (e *APIError) Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(e.StatusCode) - - respData, err := json.Marshal(e) - if err != nil { - logger.ErrorContext(ctx, "error marshaling API error", slog.String("error", err.Error())) - } - - if _, err := w.Write(respData); err != nil { - logger.ErrorContext(ctx, "error writing API error", slog.String("error", err.Error())) - } -} - -// Interface is an interface to an API error. This is needed for use with -// errors.As because APIError itself is embedded on another error struct, and -// won't be usable as an errors.As target. -type Interface interface { - Error() string - GetInternalError() error - SetInternalError(internalErr error) - Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) -} - -// WithInternalError is a convenience function for assigning an internal error -// to the given API error and returning it. -func WithInternalError[TAPIError Interface](apiErr TAPIError, internalErr error) TAPIError { - apiErr.SetInternalError(internalErr) - return apiErr -} - -// -// BadRequest -// - -type BadRequest struct { //nolint:errname - APIError -} - -func NewBadRequest(format string, a ...any) *BadRequest { - return &BadRequest{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusBadRequest, - }, - } -} - -// -// InternalServerError -// - -type InternalServerError struct { - APIError -} - -func NewInternalServerError(format string, a ...any) *InternalServerError { - return &InternalServerError{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusInternalServerError, - }, - } -} - -// -// NotFound -// - -type NotFound struct { //nolint:errname - APIError -} - -func NewNotFound(format string, a ...any) *NotFound { - return &NotFound{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusNotFound, - }, - } -} - -func NewNotFoundJob(jobID int64) *NotFound { return NewNotFound("Job not found: %d.", jobID) } -func NewNotFoundQueue(name string) *NotFound { return NewNotFound("Queue not found: %s.", name) } -func NewNotFoundWorkflow(id string) *NotFound { return NewNotFound("Workflow not found: %s.", id) } - -// -// ServiceUnavailable -// - -type ServiceUnavailable struct { //nolint:errname - APIError -} - -func NewServiceUnavailable(format string, a ...any) *ServiceUnavailable { - return &ServiceUnavailable{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusServiceUnavailable, - }, - } -} diff --git a/internal/apierror/api_error_test.go b/internal/apierror/api_error_test.go deleted file mode 100644 index 0a6178cf..00000000 --- a/internal/apierror/api_error_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package apierror - -import ( - "context" - "encoding/json" - "errors" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - - "riverqueue.com/riverui/internal/riverinternaltest" -) - -func TestAPIError(t *testing.T) { - t.Parallel() - - var ( - anErr = errors.New("an error") - apiErr = NewBadRequest("Bad request.") - ) - - apiErr.SetInternalError(anErr) - require.Equal(t, anErr, apiErr.GetInternalError()) -} - -func TestAPIErrorJSON(t *testing.T) { - t.Parallel() - - require.JSONEq(t, - `{"message":"Bad request. Try sending JSON next time."}`, - string(mustMarshalJSON( - t, NewBadRequest("Bad request. Try sending JSON next time."))), - ) -} - -func TestAPIErrorWrite(t *testing.T) { - t.Parallel() - - var ( - ctx = context.Background() - logger = riverinternaltest.Logger(t) - recorder = httptest.NewRecorder() - ) - - NewBadRequest("Bad request. Try sending JSON next time.").Write(ctx, logger, recorder) - - require.Equal(t, 400, recorder.Result().StatusCode) - require.JSONEq(t, - `{"message":"Bad request. Try sending JSON next time."}`, - recorder.Body.String(), - ) - require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) -} - -func TestWithInternalError(t *testing.T) { - t.Parallel() - - var ( - anErr = errors.New("an error") - apiErr = NewBadRequest("Bad request.") - ) - - apiErr = WithInternalError(apiErr, anErr) - require.Equal(t, anErr, apiErr.InternalError) -} - -func mustMarshalJSON(t *testing.T, v any) []byte { - t.Helper() - - data, err := json.Marshal(v) - require.NoError(t, err) - return data -} diff --git a/internal/apimiddleware/api_middleware.go b/internal/apimiddleware/api_middleware.go deleted file mode 100644 index 9149d8f4..00000000 --- a/internal/apimiddleware/api_middleware.go +++ /dev/null @@ -1,71 +0,0 @@ -package apimiddleware - -import ( - "net/http" -) - -// middlewareInterface is an interface to be implemented by middleware. -type middlewareInterface interface { - Middleware(next http.Handler) http.Handler -} - -// MiddlewareFunc allows a simple middleware to be defined as only a function. -type MiddlewareFunc func(next http.Handler) http.Handler - -// Middleware allows MiddlewareFunc to implement middlewareInterface. -func (f MiddlewareFunc) Middleware(next http.Handler) http.Handler { - return f(next) -} - -// MiddlewareStack builds a stack of middleware that a request will be -// dispatched through before sending it to the underlying handler. Middlewares -// are added to it before getting a handler with a call to Mount on another -// handler like a ServeMux. -// used like: -// -// middlewares := &MiddlewareStack -// middlewares.Use(middleware1) -// middlewares.Use(middleware2) -// ... -// handler := middlewares.Mount(mux) -// -// Besides some slight syntactic nicety, the entire reason this type exists is -// because it will mount middlewares in a more human-friendly/intuitive order. -// When mounting middlewares (not using MiddlewareStack) like: -// -// handler := mux -// handler = middleware1.Wrapper(handler) -// handler = middleware2.Wrapper(handler) -// ... -// -// One must be very careful because the middlewares will be "backwards" -// according to the list in that when a request enters the stack, the middleware -// that was mounted first will be called _last_ because it's nested the deepest -// down. -// -// MiddlewareStack fixes this problem by enabling any number of middlewares to -// be specified, and then mounting them in inverted order when Mount is called. -type MiddlewareStack struct { - middlewares []middlewareInterface -} - -// NewMiddlewareStack is a helper that can act as a shortcut to initialize a -// middleware stack by passing a series of middlewares as variadic args. -func NewMiddlewareStack(middlewares ...middlewareInterface) *MiddlewareStack { - stack := &MiddlewareStack{} - for _, mw := range middlewares { - stack.Use(mw) - } - return stack -} - -func (s *MiddlewareStack) Mount(handler http.Handler) http.Handler { - for i := len(s.middlewares) - 1; i >= 0; i-- { - handler = s.middlewares[i].Middleware(handler) - } - return handler -} - -func (s *MiddlewareStack) Use(middleware middlewareInterface) { - s.middlewares = append(s.middlewares, middleware) -} diff --git a/internal/apimiddleware/api_middleware_test.go b/internal/apimiddleware/api_middleware_test.go deleted file mode 100644 index a0c323e7..00000000 --- a/internal/apimiddleware/api_middleware_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package apimiddleware - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" -) - -// Verify MiddlewareFunc complies with middlewareInterface. -var _ middlewareInterface = MiddlewareFunc(func(next http.Handler) http.Handler { - return next -}) - -type contextTrailContextKey struct{} - -// Adds the configured segment to trail in context. -type contextTrailMiddleware struct { - segment string -} - -func newContextTrailMiddleware(segment string) *contextTrailMiddleware { - return &contextTrailMiddleware{segment: segment} -} - -func (m *contextTrailMiddleware) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Extract - var contextTrail []string - if existingTrail, ok := ctx.Value(contextTrailContextKey{}).([]string); ok { - contextTrail = existingTrail - } - contextTrail = append(contextTrail, m.segment) - - next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, contextTrailContextKey{}, contextTrail))) - }) -} - -func TestMiddlewareStack(t *testing.T) { - t.Parallel() - - makeRequestAndExtractTrail := func(stack *MiddlewareStack) []string { - var contextTrail []string - - handler := stack.Mount(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - contextTrail = r.Context().Value(contextTrailContextKey{}).([]string) //nolint:forcetypeassert - })) - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) - handler.ServeHTTP(recorder, req) - - return contextTrail - } - - t.Run("NewMiddlewareStack", func(t *testing.T) { - t.Parallel() - - stack := NewMiddlewareStack( - newContextTrailMiddleware("1st"), - newContextTrailMiddleware("2nd"), - newContextTrailMiddleware("3rd"), - ) - - contextTrail := makeRequestAndExtractTrail(stack) - require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) - }) - - t.Run("Use", func(t *testing.T) { - t.Parallel() - - stack := &MiddlewareStack{} - stack.Use(newContextTrailMiddleware("1st")) - stack.Use(newContextTrailMiddleware("2nd")) - stack.Use(newContextTrailMiddleware("3rd")) - - contextTrail := makeRequestAndExtractTrail(stack) - require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) - }) -} diff --git a/internal/validate/validate.go b/internal/validate/validate.go deleted file mode 100644 index 184318d2..00000000 --- a/internal/validate/validate.go +++ /dev/null @@ -1,118 +0,0 @@ -// Package validate internalizes Go Playground's Validator framework, setting -// some common options that we use everywhere, providing some useful helpers, -// and exporting a simplified API. -package validate - -import ( - "context" - "fmt" - "reflect" - "strings" - - "github.com/go-playground/validator/v10" -) - -// WithRequiredStructEnabled can be removed once validator/v11 is released. -var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals - -func init() { //nolint:gochecknoinits - validate.RegisterTagNameFunc(preferPublicName) -} - -// PublicFacingMessage builds a complete error message from a validator error -// that's suitable for public-facing consumption. -// -// I only added a few possible validations to start. We'll probably need to add -// more as we go and expand our usage. -func PublicFacingMessage(validatorErr error) string { - var message string - - //nolint:errorlint - if validationErrs, ok := validatorErr.(validator.ValidationErrors); ok { - for _, fieldErr := range validationErrs { - switch fieldErr.Tag() { - case "lte": - fallthrough // lte and max are synonyms - case "max": - kind := fieldErr.Kind() - if kind == reflect.Ptr { - kind = fieldErr.Type().Elem().Kind() - } - - switch kind { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: - message += fmt.Sprintf(" Field `%s` must be less than or equal to %s.", - fieldErr.Field(), fieldErr.Param()) - - case reflect.Slice, reflect.Map: - message += fmt.Sprintf(" Field `%s` must contain at most %s element(s).", - fieldErr.Field(), fieldErr.Param()) - - case reflect.String: - message += fmt.Sprintf(" Field `%s` must be at most %s character(s) long.", - fieldErr.Field(), fieldErr.Param()) - - default: - message += fieldErr.Error() - } - - case "gte": - fallthrough // gte and min are synonyms - case "min": - kind := fieldErr.Kind() - if kind == reflect.Ptr { - kind = fieldErr.Type().Elem().Kind() - } - - switch kind { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: - message += fmt.Sprintf(" Field `%s` must be greater or equal to %s.", - fieldErr.Field(), fieldErr.Param()) - - case reflect.Slice, reflect.Map: - message += fmt.Sprintf(" Field `%s` must contain at least %s element(s).", - fieldErr.Field(), fieldErr.Param()) - - case reflect.String: - message += fmt.Sprintf(" Field `%s` must be at least %s character(s) long.", - fieldErr.Field(), fieldErr.Param()) - - default: - message += fieldErr.Error() - } - - case "oneof": - message += fmt.Sprintf(" Field `%s` should be one of the following values: %s.", - fieldErr.Field(), fieldErr.Param()) - - case "required": - message += fmt.Sprintf(" Field `%s` is required.", fieldErr.Field()) - - default: - message += fmt.Sprintf(" Validation on field `%s` failed on the `%s` tag.", fieldErr.Field(), fieldErr.Tag()) - } - } - } - - return strings.TrimSpace(message) -} - -// StructCtx validates a structs exposed fields, and automatically validates -// nested structs, unless otherwise specified and also allows passing of -// context.Context for contextual validation information. -func StructCtx(ctx context.Context, s any) error { - return validate.StructCtx(ctx, s) -} - -// preferPublicName is a validator tag naming function that uses public names -// like a field's JSON tag instead of actual field names in structs. -// This is important because we sent these back as user-facing errors (and the -// users submitted them as JSON/path parameters). -func preferPublicName(fld reflect.StructField) string { - name, _, _ := strings.Cut(fld.Tag.Get("json"), ",") - if name != "" && name != "-" { - return name - } - - return fld.Name -} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go deleted file mode 100644 index 81e2d69c..00000000 --- a/internal/validate/validate_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package validate - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestFromValidator(t *testing.T) { - t.Parallel() - - // Fields have JSON tags so we can verify those are used over the - // property name. - type TestStruct struct { - MinInt int `json:"min_int" validate:"min=1"` - MinSlice []string `json:"min_slice" validate:"min=1"` - MinString string `json:"min_string" validate:"min=1"` - MaxInt int `json:"max_int" validate:"max=0"` - MaxSlice []string `json:"max_slice" validate:"max=0"` - MaxString string `json:"max_string" validate:"max=0"` - OneOf string `json:"one_of" validate:"oneof=blue green"` - Required string `json:"required" validate:"required"` - Unsupported string `json:"unsupported" validate:"e164"` - } - - validTestStruct := func() *TestStruct { - return &TestStruct{ - MinInt: 1, - MinSlice: []string{"1"}, - MinString: "value", - MaxInt: 0, - MaxSlice: []string{}, - MaxString: "", - OneOf: "blue", - Required: "value", - Unsupported: "+1123456789", - } - } - - t.Run("MaxInt", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxInt = 1 - require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MaxSlice", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxSlice = []string{"1"} - require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MaxString", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxString = "value" - require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinInt", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinInt = 0 - require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinSlice", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinSlice = nil - require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinString", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinString = "" - require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("OneOf", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.OneOf = "red" - require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("Required", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.Required = "" - require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("Unsupported", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.Unsupported = "abc" - require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MultipleErrors", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinInt = 0 - testStruct.Required = "" - require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) - }) -} - -func TestPreferPublicNames(t *testing.T) { - t.Parallel() - - type testStruct struct { - JSONNameField string `json:"json_name"` - StructNameField string `apiquery:"-"` - } - - require.Equal(t, "json_name", - preferPublicName(reflect.TypeOf(testStruct{}).Field(0))) - require.Equal(t, "StructNameField", - preferPublicName(reflect.TypeOf(testStruct{}).Field(1))) -}