From 460a28f6c0f9fbe8e8b3a8459788e6306d330377 Mon Sep 17 00:00:00 2001 From: Brandur Date: Sun, 9 Mar 2025 11:03:26 -0700 Subject: [PATCH] Move API framework to its own repo for reuse Here, move the API framework, which is currently duplicated in two projects, over to its own repository so that we can reuse it between them. It's new home will be `riverapiframe` [1]. [1] https://github.com/riverqueue/riverapiframe/pull/1 --- .github/workflows/ci.yaml | 2 +- cmd/riverui/main.go | 2 +- dist/.gitkeep | 0 docs/development.md | 12 +- go.mod | 9 +- go.sum | 6 +- handler.go | 5 +- handler_api_endpoint.go | 44 ++- handler_api_endpoint_test.go | 20 +- internal/apiendpoint/api_endpoint.go | 234 ------------- internal/apiendpoint/api_endpoint_test.go | 316 ------------------ internal/apierror/api_error.go | 141 -------- internal/apierror/api_error_test.go | 74 ---- internal/apimiddleware/api_middleware.go | 71 ---- internal/apimiddleware/api_middleware_test.go | 84 ----- internal/validate/validate.go | 118 ------- internal/validate/validate_test.go | 135 -------- 17 files changed, 55 insertions(+), 1218 deletions(-) delete mode 100644 dist/.gitkeep delete mode 100644 internal/apiendpoint/api_endpoint.go delete mode 100644 internal/apiendpoint/api_endpoint_test.go delete mode 100644 internal/apierror/api_error.go delete mode 100644 internal/apierror/api_error_test.go delete mode 100644 internal/apimiddleware/api_middleware.go delete mode 100644 internal/apimiddleware/api_middleware_test.go delete mode 100644 internal/validate/validate.go delete mode 100644 internal/validate/validate_test.go diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7bfaf979..0e7824ef 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -81,7 +81,7 @@ jobs: name: Go lint runs-on: ubuntu-latest env: - GOLANGCI_LINT_VERSION: v1.63.4 + GOLANGCI_LINT_VERSION: v1.64.6 permissions: contents: read # allow read access to pull request. Use with `only-new-issues` option. diff --git a/cmd/riverui/main.go b/cmd/riverui/main.go index bf0cc51a..0ef10a01 100644 --- a/cmd/riverui/main.go +++ b/cmd/riverui/main.go @@ -16,11 +16,11 @@ import ( "github.com/rs/cors" sloghttp "github.com/samber/slog-http" + "github.com/riverqueue/apiframe/apimiddleware" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver/riverpgxv5" "riverqueue.com/riverui" - "riverqueue.com/riverui/internal/apimiddleware" ) func main() { diff --git a/dist/.gitkeep b/dist/.gitkeep deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/development.md b/docs/development.md index 2c80864e..5ec1213e 100644 --- a/docs/development.md +++ b/docs/development.md @@ -4,12 +4,6 @@ River UI consists of two apps: a Go backend API, and a TypeScript UI frontend. ## Environment -The project uses a combination of direnv and a `.env` file (to suit Vite conventions). Copy the example and edit as necessary: - -```sh -cp .env.example .env.local -``` - ## Install dependencies ```sh @@ -17,10 +11,12 @@ go get ./... npm install ``` -## Install Reflex - This project uses [Reflex](https://github.com/cespare/reflex) for local dev. Install it. +``` sh +go install github.com/cespare/reflex@latest +``` + ## Running the UI and API together ```sh diff --git a/go.mod b/go.mod index 3eb02f07..bb2dffa5 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,13 @@ module riverqueue.com/riverui -go 1.22.0 +go 1.23.0 -toolchain go1.23.5 +toolchain go1.24.1 require ( - github.com/go-playground/validator/v10 v10.25.0 github.com/google/uuid v1.6.0 - github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx/v5 v5.7.2 + github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f github.com/riverqueue/river v0.18.0 github.com/riverqueue/river/riverdriver v0.18.0 github.com/riverqueue/river/riverdriver/riverpgxv5 v0.18.0 @@ -24,6 +23,8 @@ require ( github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.25.0 // indirect + github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/go.sum b/go.sum index c42a94b7..bbc2ee8f 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f h1:Pr3+ERes1GkCJnw/gpi9yVx9tN7VCcPjuZCFN/OBaZg= +github.com/riverqueue/apiframe v0.0.0-20250310051455-203f3fd8260f/go.mod h1:ko/9b4SeomWrHTr4WU0i21peq90Qk2Mm8MgOqPrTcHA= github.com/riverqueue/river v0.18.0 h1:sGHeTOL9MR8+pMIVHRm59fzet8Ron/xjF3Yq/PSGb78= github.com/riverqueue/river v0.18.0/go.mod h1:oapX5xb/L2YnkE801QubDZ0COHxVxEGVY37icPzghhU= github.com/riverqueue/river/riverdriver v0.18.0 h1:a2haR5I0MQLHjLCSVFpUEeJALCLemRl5zCztucysm1E= @@ -49,8 +51,8 @@ github.com/riverqueue/river/rivertype v0.18.0 h1:YsXR5NbLAzniurGO0+zcISWMKq7Y71x github.com/riverqueue/river/rivertype v0.18.0/go.mod h1:DETcejveWlq6bAb8tHkbgJqmXWVLiFhTiEm8j7co1bE= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/samber/slog-http v1.5.1 h1:z5Ty/u5LKJbWjmjLDr6OgtwHXrPPrH2ZPoQE/47n9sU= diff --git a/handler.go b/handler.go index 53c22275..55d6357b 100644 --- a/handler.go +++ b/handler.go @@ -17,13 +17,12 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/riverqueue/apiframe/apiendpoint" + "github.com/riverqueue/apiframe/apimiddleware" "github.com/riverqueue/river" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/util/valutil" - - "riverqueue.com/riverui/internal/apiendpoint" - "riverqueue.com/riverui/internal/apimiddleware" ) // DB is the interface for a pgx database connection. diff --git a/handler_api_endpoint.go b/handler_api_endpoint.go index 313e3515..a4358c0d 100644 --- a/handler_api_endpoint.go +++ b/handler_api_endpoint.go @@ -13,6 +13,8 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/riverqueue/apiframe/apiendpoint" + "github.com/riverqueue/apiframe/apierror" "github.com/riverqueue/river" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/startstop" @@ -20,8 +22,6 @@ import ( "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" - "riverqueue.com/riverui/internal/apiendpoint" - "riverqueue.com/riverui/internal/apierror" "riverqueue.com/riverui/internal/dbsqlc" "riverqueue.com/riverui/internal/querycacher" "riverqueue.com/riverui/internal/util/pgxutil" @@ -104,7 +104,7 @@ func (a *healthCheckGetEndpoint) Execute(ctx context.Context, req *healthCheckGe // fall through to OK status response below default: - return nil, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", req.Name) + return nil, apierror.NewNotFoundf("Health check %q not found. Use either `complete` or `minimal`.", req.Name) } return statusResponseOK, nil @@ -142,7 +142,7 @@ func (a *jobCancelEndpoint) Execute(ctx context.Context, req *jobCancelRequest) job, err := a.client.JobCancelTx(ctx, tx, jobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -185,10 +185,10 @@ func (a *jobDeleteEndpoint) Execute(ctx context.Context, req *jobDeleteRequest) _, err := a.client.JobDeleteTx(ctx, tx, jobID) if err != nil { if errors.Is(err, rivertype.ErrJobRunning) { - return nil, apierror.NewBadRequest("Job %d is running and can't be deleted until it finishes.", jobID) + return nil, apierror.NewBadRequestf("Job %d is running and can't be deleted until it finishes.", jobID) } if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -227,7 +227,7 @@ func (req *jobGetRequest) ExtractRaw(r *http.Request) error { jobID, err := strconv.ParseInt(idString, 10, 64) if err != nil { - return apierror.NewBadRequest("Couldn't convert job ID to int64: %s.", err) + return apierror.NewBadRequestf("Couldn't convert job ID to int64: %s.", err) } req.JobID = jobID @@ -239,7 +239,7 @@ func (a *jobGetEndpoint) Execute(ctx context.Context, req *jobGetRequest) (*Rive job, err := a.client.JobGetTx(ctx, tx, req.JobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(req.JobID) + return nil, NewNotFoundJob(req.JobID) } return nil, fmt.Errorf("error getting job: %w", err) } @@ -276,7 +276,7 @@ func (req *jobListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -344,7 +344,7 @@ func (a *jobRetryEndpoint) Execute(ctx context.Context, req *jobRetryRequest) (* _, err := a.client.JobRetryTx(ctx, tx, jobID) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundJob(jobID) + return nil, NewNotFoundJob(jobID) } return nil, err } @@ -388,7 +388,7 @@ func (a *queueGetEndpoint) Execute(ctx context.Context, req *queueGetRequest) (* queue, err := a.client.QueueGetTx(ctx, tx, req.Name) if err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error getting queue: %w", err) } @@ -430,7 +430,7 @@ func (req *queueListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -490,7 +490,7 @@ func (a *queuePauseEndpoint) Execute(ctx context.Context, req *queuePauseRequest return pgxutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*statusResponse, error) { if err := a.client.QueuePauseTx(ctx, tx, req.Name, nil); err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error pausing queue: %w", err) } @@ -532,7 +532,7 @@ func (a *queueResumeEndpoint) Execute(ctx context.Context, req *queueResumeReque return pgxutil.WithTxV(ctx, a.dbPool, func(ctx context.Context, tx pgx.Tx) (*statusResponse, error) { if err := a.client.QueueResumeTx(ctx, tx, req.Name, nil); err != nil { if errors.Is(err, river.ErrNotFound) { - return nil, apierror.NewNotFoundQueue(req.Name) + return nil, NewNotFoundQueue(req.Name) } return nil, fmt.Errorf("error resuming queue: %w", err) } @@ -670,7 +670,7 @@ func (a *workflowGetEndpoint) Execute(ctx context.Context, req *workflowGetReque } if len(jobs) < 1 { - return nil, apierror.NewNotFoundWorkflow(req.ID) + return nil, NewNotFoundWorkflow(req.ID) } return &workflowGetResponse{ @@ -712,7 +712,7 @@ func (req *workflowListRequest) ExtractRaw(r *http.Request) error { if limitStr := r.URL.Query().Get("limit"); limitStr != "" { limit, err := strconv.Atoi(limitStr) if err != nil { - return apierror.NewBadRequest("Couldn't convert `limit` to integer: %s.", err) + return apierror.NewBadRequestf("Couldn't convert `limit` to integer: %s.", err) } req.Limit = &limit @@ -758,6 +758,18 @@ func (a *workflowListEndpoint) Execute(ctx context.Context, req *workflowListReq } } +func NewNotFoundJob(jobID int64) *apierror.NotFound { + return apierror.NewNotFoundf("Job not found: %d.", jobID) +} + +func NewNotFoundQueue(name string) *apierror.NotFound { + return apierror.NewNotFoundf("Queue not found: %s.", name) +} + +func NewNotFoundWorkflow(id string) *apierror.NotFound { + return apierror.NewNotFoundf("Workflow not found: %s.", id) +} + type RiverJob struct { ID int64 `json:"id"` Args json.RawMessage `json:"args"` diff --git a/handler_api_endpoint_test.go b/handler_api_endpoint_test.go index ebc70d60..5649cf99 100644 --- a/handler_api_endpoint_test.go +++ b/handler_api_endpoint_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" + "github.com/riverqueue/apiframe/apierror" "github.com/riverqueue/river" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/riversharedtest" @@ -17,7 +18,6 @@ import ( "github.com/riverqueue/river/rivershared/util/ptrutil" "github.com/riverqueue/river/rivertype" - "riverqueue.com/riverui/internal/apierror" "riverqueue.com/riverui/internal/riverinternaltest" "riverqueue.com/riverui/internal/riverinternaltest/testfactory" ) @@ -104,7 +104,7 @@ func TestHandlerHealthCheckGetEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newHealthCheckGetEndpoint) _, err := endpoint.Execute(ctx, &healthCheckGetRequest{Name: "other"}) - requireAPIError(t, apierror.NewNotFound("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) + requireAPIError(t, apierror.NewNotFoundf("Health check %q not found. Use either `complete` or `minimal`.", "other"), err) }) } @@ -140,7 +140,7 @@ func TestJobCancelEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobCancelEndpoint) _, err := endpoint.Execute(ctx, &jobCancelRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -174,7 +174,7 @@ func TestJobDeleteEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobDeleteEndpoint) _, err := endpoint.Execute(ctx, &jobDeleteRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -201,7 +201,7 @@ func TestJobGetEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobGetEndpoint) _, err := endpoint.Execute(ctx, &jobGetRequest{JobID: 123}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -322,7 +322,7 @@ func TestJobRetryEndpoint(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newJobRetryEndpoint) _, err := endpoint.Execute(ctx, &jobRetryRequest{JobIDs: []int64String{123}}) - requireAPIError(t, apierror.NewNotFoundJob(123), err) + requireAPIError(t, NewNotFoundJob(123), err) }) } @@ -353,7 +353,7 @@ func TestAPIHandlerQueueGet(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueueGetEndpoint) _, err := endpoint.Execute(ctx, &queueGetRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -420,7 +420,7 @@ func TestAPIHandlerQueuePause(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueuePauseEndpoint) _, err := endpoint.Execute(ctx, &queuePauseRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -449,7 +449,7 @@ func TestAPIHandlerQueueResume(t *testing.T) { endpoint, _ := setupEndpoint(ctx, t, newQueueResumeEndpoint) _, err := endpoint.Execute(ctx, &queueResumeRequest{Name: "does_not_exist"}) - requireAPIError(t, apierror.NewNotFoundQueue("does_not_exist"), err) + requireAPIError(t, NewNotFoundQueue("does_not_exist"), err) }) } @@ -577,7 +577,7 @@ func TestAPIHandlerWorkflowGet(t *testing.T) { workflowID := uuid.New() _, err := endpoint.Execute(ctx, &workflowGetRequest{ID: workflowID.String()}) - requireAPIError(t, apierror.NewNotFoundWorkflow(workflowID.String()), err) + requireAPIError(t, NewNotFoundWorkflow(workflowID.String()), err) }) } diff --git a/internal/apiendpoint/api_endpoint.go b/internal/apiendpoint/api_endpoint.go deleted file mode 100644 index ea3b81e9..00000000 --- a/internal/apiendpoint/api_endpoint.go +++ /dev/null @@ -1,234 +0,0 @@ -// Package apiendpoint provides a lightweight API framework for use with River -// UI. It lets API endpoints be defined, then mounted into an http.ServeMux. -package apiendpoint - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "time" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" - - "riverqueue.com/riverui/internal/apierror" - "riverqueue.com/riverui/internal/validate" -) - -// Endpoint is a struct that should be embedded on an API endpoint, and which -// provides a partial implementation for EndpointInterface. -type Endpoint[TReq any, TResp any] struct { - // Logger used to log information about endpoint execution. - logger *slog.Logger - - // Metadata about the endpoint. This is not available until SetMeta is - // invoked on the endpoint, which is usually done in Mount. - meta *EndpointMeta -} - -func (e *Endpoint[TReq, TResp]) SetLogger(logger *slog.Logger) { e.logger = logger } -func (e *Endpoint[TReq, TResp]) SetMeta(meta *EndpointMeta) { e.meta = meta } - -type EndpointInterface interface { - // Meta returns metadata about an API endpoint, like the path it should be - // mounted at, and the status code it returns on success. - // - // This should be implemented by each specific API endpoint. - Meta() *EndpointMeta - - // SetLogger sets a logger on the endpoint. - // - // Implementation inherited from an embedded Endpoint struct. - SetLogger(logger *slog.Logger) - - // SetMeta sets metadata on an Endpoint struct after it's extracted from a - // call to an endpoint's Meta function. - // - // Implementation inherited from an embedded Endpoint struct. - SetMeta(meta *EndpointMeta) -} - -// EndpointExecuteInterface is an interface to an API endpoint. Some of it is -// implemented by an embedded Endpoint struct, and some of it should be -// implemented by the endpoint itself. -type EndpointExecuteInterface[TReq any, TResp any] interface { - EndpointInterface - - // Execute executes the API endpoint. - // - // This should be implemented by each specific API endpoint. - Execute(ctx context.Context, req *TReq) (*TResp, error) -} - -// EndpointMeta is metadata about an API endpoint. -type EndpointMeta struct { - // Pattern is the API endpoint's HTTP method and path where it should be - // mounted, which is passed to http.ServeMux by Mount. It should start with - // a verb like `GET` or `POST`, and may contain Go 1.22 path variables like - // `{name}`, whose values should be extracted by an endpoint request - // struct's custom ExtractRaw implementation. - Pattern string - - // StatusCode is the status code to be set on a successful response. - StatusCode int -} - -func (m *EndpointMeta) validate() { - if m.Pattern == "" { - panic("Endpoint.Path is required") - } - if m.StatusCode == 0 { - panic("Endpoint.StatusCode is required") - } -} - -// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log -// information about endpoint execution. -func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { - apiEndpoint.SetLogger(logger) - - meta := apiEndpoint.Meta() - meta.validate() // panic on problem - apiEndpoint.SetMeta(meta) - - mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { - executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) - }) - - return apiEndpoint -} - -func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) { - ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) - defer cancel() - - // Run as much code as we can in a sub-function that can return an error. - // This is more convenient to write, but is also safer because unlike when - // writing errors to ResponseWriter, there's no danger of a missing return. - err := func() error { - var req TReq - if r.Method != http.MethodGet { - reqData, err := io.ReadAll(r.Body) - if err != nil { - return fmt.Errorf("error reading request body: %w", err) - } - - if len(reqData) > 0 { - if err := json.Unmarshal(reqData, &req); err != nil { - return apierror.NewBadRequest("Error unmarshaling request body: %s.", err) - } - } - } - - if rawExtractor, ok := any(&req).(RawExtractor); ok { - if err := rawExtractor.ExtractRaw(r); err != nil { - return err - } - } - - if err := validate.StructCtx(ctx, &req); err != nil { - return apierror.NewBadRequest(validate.PublicFacingMessage(err)) //nolint:govet - } - - resp, err := execute(ctx, &req) - if err != nil { - return err - } - - if rawExtractor, ok := any(resp).(RawResponder); ok { - return rawExtractor.RespondRaw(w) - } - - respData, err := json.Marshal(resp) - if err != nil { - return fmt.Errorf("error marshaling response JSON: %w", err) - } - - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(meta.StatusCode) - - if _, err := w.Write(respData); err != nil { - return fmt.Errorf("error writing response: %w", err) - } - - return nil - }() - if err != nil { - // Convert certain types of Postgres errors into something more - // user-friendly than an internal server error. - err = maybeInterpretInternalError(err) - - var apiErr apierror.Interface - if errors.As(err, &apiErr) { - logAttrs := []any{ - slog.String("error", apiErr.Error()), - } - - if internalErr := apiErr.GetInternalError(); internalErr != nil { - logAttrs = append(logAttrs, slog.String("internal_error", internalErr.Error())) - } - - // Logged at info level because API errors are normal. - logger.InfoContext(ctx, "API error response", logAttrs...) - - apiErr.Write(ctx, logger, w) - return - } - - if errors.Is(err, context.DeadlineExceeded) { - logger.ErrorContext(ctx, "request timeout", slog.String("error", err.Error())) - apierror.NewServiceUnavailable("Request timed out. Retrying the request might work.").Write(ctx, logger, w) - return - } - - // Internal server error. The error goes to logs but should not be - // included in the response in case there's something sensitive in - // the error string. - logger.ErrorContext(ctx, "error running API route", slog.String("error", err.Error())) - apierror.NewInternalServerError("Internal server error. Check logs for more information.").Write(ctx, logger, w) - } -} - -// RawExtractor is an interface that can be implemented by request structs that -// allows them to extract information from a raw request, like path values. -type RawExtractor interface { - ExtractRaw(r *http.Request) error -} - -// RawResponder is an interface that can be implemented by response structs that -// allow them to respond directly to a ResponseWriter instead of emitting the -// normal JSON format. -type RawResponder interface { - RespondRaw(w http.ResponseWriter) error -} - -// Make some broad categories of internal error back into something public -// facing because in some cases they can be a vast help for debugging. -func maybeInterpretInternalError(err error) error { - var ( - apiErr apierror.Interface - connectErr *pgconn.ConnectError - pgErr *pgconn.PgError - ) - - switch { - case errors.As(err, &connectErr): - apiErr = apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details.") - - case errors.As(err, &pgErr): - if pgErr.Code == pgerrcode.InsufficientPrivilege { - apiErr = apierror.NewBadRequest("Insufficient database privilege to perform this operation.") - } else { - return err - } - - default: - return err - } - - return apierror.WithInternalError(apiErr, err) -} diff --git a/internal/apiendpoint/api_endpoint_test.go b/internal/apiendpoint/api_endpoint_test.go deleted file mode 100644 index d7a5fce8..00000000 --- a/internal/apiendpoint/api_endpoint_test.go +++ /dev/null @@ -1,316 +0,0 @@ -package apiendpoint - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/require" - - "riverqueue.com/riverui/internal/apierror" - "riverqueue.com/riverui/internal/riverinternaltest" -) - -func TestMountAndServe(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - type testBundle struct { - recorder *httptest.ResponseRecorder - } - - setup := func(t *testing.T) (*http.ServeMux, *testBundle) { - t.Helper() - - var ( - logger = riverinternaltest.Logger(t) - mux = http.NewServeMux() - ) - - Mount(mux, logger, &getEndpoint{}) - Mount(mux, logger, &postEndpoint{}) - - return mux, &testBundle{ - recorder: httptest.NewRecorder(), - } - } - - t.Run("GetEndpointAndExtractRaw", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("BodyIgnoredOnGet", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", - bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("MethodNotAllowed", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil) - mux.ServeHTTP(bundle.recorder, req) - - // This error comes from net/http. - requireStatusAndResponse(t, http.StatusMethodNotAllowed, "Method Not Allowed\n", bundle.recorder) - }) - - t.Run("PostEndpoint", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) - }) - - t.Run("ValidationError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) - }) - - t.Run("APIError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Bad request."}, bundle.recorder) - }) - - t.Run("InterpretedError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Insufficient database privilege to perform this operation."}, bundle.recorder) - }) - - t.Run("Timeout", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - ctx, cancel := context.WithDeadline(ctx, time.Now()) - t.Cleanup(cancel) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) - require.NoError(t, err) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusServiceUnavailable, &apierror.APIError{Message: "Request timed out. Retrying the request might work."}, bundle.recorder) - }) - - t.Run("InternalServerError", func(t *testing.T) { - t.Parallel() - - mux, bundle := setup(t) - - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) - mux.ServeHTTP(bundle.recorder, req) - - requireStatusAndJSONResponse(t, http.StatusInternalServerError, &apierror.APIError{Message: "Internal server error. Check logs for more information."}, bundle.recorder) - }) -} - -func TestMaybeInterpretInternalError(t *testing.T) { - t.Parallel() - - ctx := context.Background() - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - _, err := pgconn.Connect(ctx, "postgres://user@127.0.0.1:37283/does_not_exist") - - require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("There was a problem connecting to the configured database. Check logs for details."), err), maybeInterpretInternalError(err)) - }) - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - err := &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege} - - require.Equal(t, apierror.WithInternalError(apierror.NewBadRequest("Insufficient database privilege to perform this operation."), err), maybeInterpretInternalError(err)) - }) - - t.Run("OtherPGError", func(t *testing.T) { - t.Parallel() - - err := &pgconn.PgError{Code: pgerrcode.CardinalityViolation} - - require.Equal(t, err, maybeInterpretInternalError(err)) - }) - - t.Run("ConnectError", func(t *testing.T) { - t.Parallel() - - err := errors.New("other error") - - require.Equal(t, err, maybeInterpretInternalError(err)) - }) -} - -func mustMarshalJSON(t *testing.T, v any) []byte { - t.Helper() - - data, err := json.Marshal(v) - require.NoError(t, err) - return data -} - -func mustUnmarshalJSON[T any](t *testing.T, data []byte) *T { - t.Helper() - - var val T - err := json.Unmarshal(data, &val) - require.NoError(t, err) - return &val -} - -// Shortcut for requiring an HTTP status code and a JSON-marshaled response -// equivalent to expectedResp. The important thing that is does is that in the -// event of a failure on status code, it prints the response body as additional -// context to help debug the problem. -func requireStatusAndJSONResponse[T any](t *testing.T, expectedStatusCode int, expectedResp *T, recorder *httptest.ResponseRecorder) { - t.Helper() - - require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) - require.Equal(t, expectedResp, mustUnmarshalJSON[T](t, recorder.Body.Bytes())) - require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) -} - -// Same as the above, but for a non-JSON response. -func requireStatusAndResponse(t *testing.T, expectedStatusCode int, expectedResp string, recorder *httptest.ResponseRecorder) { - t.Helper() - - require.Equal(t, expectedStatusCode, recorder.Result().StatusCode, "Unexpected status code; response body: %s", recorder.Body.String()) - require.Equal(t, expectedResp, recorder.Body.String()) -} - -// -// getEndpoint -// - -type getEndpoint struct { - Endpoint[getRequest, getResponse] -} - -func (*getEndpoint) Meta() *EndpointMeta { - return &EndpointMeta{ - Pattern: "GET /api/get-endpoint/{message}", - StatusCode: http.StatusOK, - } -} - -type getRequest struct { - IgnoredJSONMessage string `json:"ignored_json" validate:"-"` - Message string `json:"-" validate:"required"` -} - -func (req *getRequest) ExtractRaw(r *http.Request) error { - req.Message = r.PathValue("message") - return nil -} - -type getResponse struct { - Message string `json:"message" validate:"required"` -} - -func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, error) { - // This branch never gets taken because request bodies are ignored on GET. - if req.IgnoredJSONMessage != "" { - return &getResponse{Message: req.IgnoredJSONMessage}, nil - } - - return &getResponse{Message: req.Message}, nil -} - -// -// postEndpoint -// - -type postEndpoint struct { - Endpoint[postRequest, postResponse] -} - -func (*postEndpoint) Meta() *EndpointMeta { - return &EndpointMeta{ - Pattern: "POST /api/post-endpoint", - StatusCode: http.StatusCreated, - } -} - -type postRequest struct { - MakeAPIError bool `json:"make_api_error" validate:"-"` - MakeInternalError bool `json:"make_internal_error" validate:"-"` - MakePostgresError bool `json:"make_postgres_error" validate:"-"` - Message string `json:"message" validate:"required"` -} - -type postResponse struct { - Message string `json:"message"` -} - -func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - - if req.MakeAPIError { - return nil, apierror.NewBadRequest("Bad request.") - } - - if req.MakeInternalError { - return nil, errors.New("an internal error occurred") - } - - if req.MakePostgresError { - // Wrap the error to make it more realistic. - return nil, fmt.Errorf("error running Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) - } - - return &postResponse{Message: req.Message}, nil -} diff --git a/internal/apierror/api_error.go b/internal/apierror/api_error.go deleted file mode 100644 index e553d6cf..00000000 --- a/internal/apierror/api_error.go +++ /dev/null @@ -1,141 +0,0 @@ -// Package apierror contains a variety of marshalable API errors that adhere to -// a unified error response convention. -package apierror - -import ( - "context" - "encoding/json" - "fmt" - "log/slog" - "net/http" -) - -// APIError is a struct that's embedded on a more specific API error struct (as -// seen below), and which provides a JSON serialization and a wait to -// conveniently write itself to an HTTP response. -// -// APIErrorInterface should be used with errors.As instead of this struct. -type APIError struct { - // InternalError is an additional error that might be associated with the - // API error. It's not returned in the API error response, but is logged in - // API endpoint execution to provide extra information for operators. - InternalError error `json:"-"` - - // Message is a descriptive, human-friendly message indicating what went - // wrong. Try to make error messages as actionable as possible to help the - // caller easily fix what went wrong. - Message string `json:"message"` - - // StatusCode is the API error's HTTP status code. It's not marshaled to - // JSON, but determines how the error is written to a response. - StatusCode int `json:"-"` -} - -func (e *APIError) Error() string { return e.Message } -func (e *APIError) GetInternalError() error { return e.InternalError } -func (e *APIError) SetInternalError(internalErr error) { e.InternalError = internalErr } - -// Write writes the API error to an HTTP response, writing to the given logger -// in case of a problem. -func (e *APIError) Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.WriteHeader(e.StatusCode) - - respData, err := json.Marshal(e) - if err != nil { - logger.ErrorContext(ctx, "error marshaling API error", slog.String("error", err.Error())) - } - - if _, err := w.Write(respData); err != nil { - logger.ErrorContext(ctx, "error writing API error", slog.String("error", err.Error())) - } -} - -// Interface is an interface to an API error. This is needed for use with -// errors.As because APIError itself is embedded on another error struct, and -// won't be usable as an errors.As target. -type Interface interface { - Error() string - GetInternalError() error - SetInternalError(internalErr error) - Write(ctx context.Context, logger *slog.Logger, w http.ResponseWriter) -} - -// WithInternalError is a convenience function for assigning an internal error -// to the given API error and returning it. -func WithInternalError[TAPIError Interface](apiErr TAPIError, internalErr error) TAPIError { - apiErr.SetInternalError(internalErr) - return apiErr -} - -// -// BadRequest -// - -type BadRequest struct { //nolint:errname - APIError -} - -func NewBadRequest(format string, a ...any) *BadRequest { - return &BadRequest{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusBadRequest, - }, - } -} - -// -// InternalServerError -// - -type InternalServerError struct { - APIError -} - -func NewInternalServerError(format string, a ...any) *InternalServerError { - return &InternalServerError{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusInternalServerError, - }, - } -} - -// -// NotFound -// - -type NotFound struct { //nolint:errname - APIError -} - -func NewNotFound(format string, a ...any) *NotFound { - return &NotFound{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusNotFound, - }, - } -} - -func NewNotFoundJob(jobID int64) *NotFound { return NewNotFound("Job not found: %d.", jobID) } -func NewNotFoundQueue(name string) *NotFound { return NewNotFound("Queue not found: %s.", name) } -func NewNotFoundWorkflow(id string) *NotFound { return NewNotFound("Workflow not found: %s.", id) } - -// -// ServiceUnavailable -// - -type ServiceUnavailable struct { //nolint:errname - APIError -} - -func NewServiceUnavailable(format string, a ...any) *ServiceUnavailable { - return &ServiceUnavailable{ - APIError: APIError{ - Message: fmt.Sprintf(format, a...), - StatusCode: http.StatusServiceUnavailable, - }, - } -} diff --git a/internal/apierror/api_error_test.go b/internal/apierror/api_error_test.go deleted file mode 100644 index 0a6178cf..00000000 --- a/internal/apierror/api_error_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package apierror - -import ( - "context" - "encoding/json" - "errors" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" - - "riverqueue.com/riverui/internal/riverinternaltest" -) - -func TestAPIError(t *testing.T) { - t.Parallel() - - var ( - anErr = errors.New("an error") - apiErr = NewBadRequest("Bad request.") - ) - - apiErr.SetInternalError(anErr) - require.Equal(t, anErr, apiErr.GetInternalError()) -} - -func TestAPIErrorJSON(t *testing.T) { - t.Parallel() - - require.JSONEq(t, - `{"message":"Bad request. Try sending JSON next time."}`, - string(mustMarshalJSON( - t, NewBadRequest("Bad request. Try sending JSON next time."))), - ) -} - -func TestAPIErrorWrite(t *testing.T) { - t.Parallel() - - var ( - ctx = context.Background() - logger = riverinternaltest.Logger(t) - recorder = httptest.NewRecorder() - ) - - NewBadRequest("Bad request. Try sending JSON next time.").Write(ctx, logger, recorder) - - require.Equal(t, 400, recorder.Result().StatusCode) - require.JSONEq(t, - `{"message":"Bad request. Try sending JSON next time."}`, - recorder.Body.String(), - ) - require.Equal(t, "application/json; charset=utf-8", recorder.Header().Get("Content-Type")) -} - -func TestWithInternalError(t *testing.T) { - t.Parallel() - - var ( - anErr = errors.New("an error") - apiErr = NewBadRequest("Bad request.") - ) - - apiErr = WithInternalError(apiErr, anErr) - require.Equal(t, anErr, apiErr.InternalError) -} - -func mustMarshalJSON(t *testing.T, v any) []byte { - t.Helper() - - data, err := json.Marshal(v) - require.NoError(t, err) - return data -} diff --git a/internal/apimiddleware/api_middleware.go b/internal/apimiddleware/api_middleware.go deleted file mode 100644 index 9149d8f4..00000000 --- a/internal/apimiddleware/api_middleware.go +++ /dev/null @@ -1,71 +0,0 @@ -package apimiddleware - -import ( - "net/http" -) - -// middlewareInterface is an interface to be implemented by middleware. -type middlewareInterface interface { - Middleware(next http.Handler) http.Handler -} - -// MiddlewareFunc allows a simple middleware to be defined as only a function. -type MiddlewareFunc func(next http.Handler) http.Handler - -// Middleware allows MiddlewareFunc to implement middlewareInterface. -func (f MiddlewareFunc) Middleware(next http.Handler) http.Handler { - return f(next) -} - -// MiddlewareStack builds a stack of middleware that a request will be -// dispatched through before sending it to the underlying handler. Middlewares -// are added to it before getting a handler with a call to Mount on another -// handler like a ServeMux. -// used like: -// -// middlewares := &MiddlewareStack -// middlewares.Use(middleware1) -// middlewares.Use(middleware2) -// ... -// handler := middlewares.Mount(mux) -// -// Besides some slight syntactic nicety, the entire reason this type exists is -// because it will mount middlewares in a more human-friendly/intuitive order. -// When mounting middlewares (not using MiddlewareStack) like: -// -// handler := mux -// handler = middleware1.Wrapper(handler) -// handler = middleware2.Wrapper(handler) -// ... -// -// One must be very careful because the middlewares will be "backwards" -// according to the list in that when a request enters the stack, the middleware -// that was mounted first will be called _last_ because it's nested the deepest -// down. -// -// MiddlewareStack fixes this problem by enabling any number of middlewares to -// be specified, and then mounting them in inverted order when Mount is called. -type MiddlewareStack struct { - middlewares []middlewareInterface -} - -// NewMiddlewareStack is a helper that can act as a shortcut to initialize a -// middleware stack by passing a series of middlewares as variadic args. -func NewMiddlewareStack(middlewares ...middlewareInterface) *MiddlewareStack { - stack := &MiddlewareStack{} - for _, mw := range middlewares { - stack.Use(mw) - } - return stack -} - -func (s *MiddlewareStack) Mount(handler http.Handler) http.Handler { - for i := len(s.middlewares) - 1; i >= 0; i-- { - handler = s.middlewares[i].Middleware(handler) - } - return handler -} - -func (s *MiddlewareStack) Use(middleware middlewareInterface) { - s.middlewares = append(s.middlewares, middleware) -} diff --git a/internal/apimiddleware/api_middleware_test.go b/internal/apimiddleware/api_middleware_test.go deleted file mode 100644 index a0c323e7..00000000 --- a/internal/apimiddleware/api_middleware_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package apimiddleware - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/require" -) - -// Verify MiddlewareFunc complies with middlewareInterface. -var _ middlewareInterface = MiddlewareFunc(func(next http.Handler) http.Handler { - return next -}) - -type contextTrailContextKey struct{} - -// Adds the configured segment to trail in context. -type contextTrailMiddleware struct { - segment string -} - -func newContextTrailMiddleware(segment string) *contextTrailMiddleware { - return &contextTrailMiddleware{segment: segment} -} - -func (m *contextTrailMiddleware) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Extract - var contextTrail []string - if existingTrail, ok := ctx.Value(contextTrailContextKey{}).([]string); ok { - contextTrail = existingTrail - } - contextTrail = append(contextTrail, m.segment) - - next.ServeHTTP(w, r.WithContext(context.WithValue(ctx, contextTrailContextKey{}, contextTrail))) - }) -} - -func TestMiddlewareStack(t *testing.T) { - t.Parallel() - - makeRequestAndExtractTrail := func(stack *MiddlewareStack) []string { - var contextTrail []string - - handler := stack.Mount(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - contextTrail = r.Context().Value(contextTrailContextKey{}).([]string) //nolint:forcetypeassert - })) - - recorder := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "https://example.com", nil) - handler.ServeHTTP(recorder, req) - - return contextTrail - } - - t.Run("NewMiddlewareStack", func(t *testing.T) { - t.Parallel() - - stack := NewMiddlewareStack( - newContextTrailMiddleware("1st"), - newContextTrailMiddleware("2nd"), - newContextTrailMiddleware("3rd"), - ) - - contextTrail := makeRequestAndExtractTrail(stack) - require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) - }) - - t.Run("Use", func(t *testing.T) { - t.Parallel() - - stack := &MiddlewareStack{} - stack.Use(newContextTrailMiddleware("1st")) - stack.Use(newContextTrailMiddleware("2nd")) - stack.Use(newContextTrailMiddleware("3rd")) - - contextTrail := makeRequestAndExtractTrail(stack) - require.Equal(t, []string{"1st", "2nd", "3rd"}, contextTrail) - }) -} diff --git a/internal/validate/validate.go b/internal/validate/validate.go deleted file mode 100644 index 184318d2..00000000 --- a/internal/validate/validate.go +++ /dev/null @@ -1,118 +0,0 @@ -// Package validate internalizes Go Playground's Validator framework, setting -// some common options that we use everywhere, providing some useful helpers, -// and exporting a simplified API. -package validate - -import ( - "context" - "fmt" - "reflect" - "strings" - - "github.com/go-playground/validator/v10" -) - -// WithRequiredStructEnabled can be removed once validator/v11 is released. -var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals - -func init() { //nolint:gochecknoinits - validate.RegisterTagNameFunc(preferPublicName) -} - -// PublicFacingMessage builds a complete error message from a validator error -// that's suitable for public-facing consumption. -// -// I only added a few possible validations to start. We'll probably need to add -// more as we go and expand our usage. -func PublicFacingMessage(validatorErr error) string { - var message string - - //nolint:errorlint - if validationErrs, ok := validatorErr.(validator.ValidationErrors); ok { - for _, fieldErr := range validationErrs { - switch fieldErr.Tag() { - case "lte": - fallthrough // lte and max are synonyms - case "max": - kind := fieldErr.Kind() - if kind == reflect.Ptr { - kind = fieldErr.Type().Elem().Kind() - } - - switch kind { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: - message += fmt.Sprintf(" Field `%s` must be less than or equal to %s.", - fieldErr.Field(), fieldErr.Param()) - - case reflect.Slice, reflect.Map: - message += fmt.Sprintf(" Field `%s` must contain at most %s element(s).", - fieldErr.Field(), fieldErr.Param()) - - case reflect.String: - message += fmt.Sprintf(" Field `%s` must be at most %s character(s) long.", - fieldErr.Field(), fieldErr.Param()) - - default: - message += fieldErr.Error() - } - - case "gte": - fallthrough // gte and min are synonyms - case "min": - kind := fieldErr.Kind() - if kind == reflect.Ptr { - kind = fieldErr.Type().Elem().Kind() - } - - switch kind { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int32, reflect.Int64: - message += fmt.Sprintf(" Field `%s` must be greater or equal to %s.", - fieldErr.Field(), fieldErr.Param()) - - case reflect.Slice, reflect.Map: - message += fmt.Sprintf(" Field `%s` must contain at least %s element(s).", - fieldErr.Field(), fieldErr.Param()) - - case reflect.String: - message += fmt.Sprintf(" Field `%s` must be at least %s character(s) long.", - fieldErr.Field(), fieldErr.Param()) - - default: - message += fieldErr.Error() - } - - case "oneof": - message += fmt.Sprintf(" Field `%s` should be one of the following values: %s.", - fieldErr.Field(), fieldErr.Param()) - - case "required": - message += fmt.Sprintf(" Field `%s` is required.", fieldErr.Field()) - - default: - message += fmt.Sprintf(" Validation on field `%s` failed on the `%s` tag.", fieldErr.Field(), fieldErr.Tag()) - } - } - } - - return strings.TrimSpace(message) -} - -// StructCtx validates a structs exposed fields, and automatically validates -// nested structs, unless otherwise specified and also allows passing of -// context.Context for contextual validation information. -func StructCtx(ctx context.Context, s any) error { - return validate.StructCtx(ctx, s) -} - -// preferPublicName is a validator tag naming function that uses public names -// like a field's JSON tag instead of actual field names in structs. -// This is important because we sent these back as user-facing errors (and the -// users submitted them as JSON/path parameters). -func preferPublicName(fld reflect.StructField) string { - name, _, _ := strings.Cut(fld.Tag.Get("json"), ",") - if name != "" && name != "-" { - return name - } - - return fld.Name -} diff --git a/internal/validate/validate_test.go b/internal/validate/validate_test.go deleted file mode 100644 index 81e2d69c..00000000 --- a/internal/validate/validate_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package validate - -import ( - "reflect" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestFromValidator(t *testing.T) { - t.Parallel() - - // Fields have JSON tags so we can verify those are used over the - // property name. - type TestStruct struct { - MinInt int `json:"min_int" validate:"min=1"` - MinSlice []string `json:"min_slice" validate:"min=1"` - MinString string `json:"min_string" validate:"min=1"` - MaxInt int `json:"max_int" validate:"max=0"` - MaxSlice []string `json:"max_slice" validate:"max=0"` - MaxString string `json:"max_string" validate:"max=0"` - OneOf string `json:"one_of" validate:"oneof=blue green"` - Required string `json:"required" validate:"required"` - Unsupported string `json:"unsupported" validate:"e164"` - } - - validTestStruct := func() *TestStruct { - return &TestStruct{ - MinInt: 1, - MinSlice: []string{"1"}, - MinString: "value", - MaxInt: 0, - MaxSlice: []string{}, - MaxString: "", - OneOf: "blue", - Required: "value", - Unsupported: "+1123456789", - } - } - - t.Run("MaxInt", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxInt = 1 - require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MaxSlice", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxSlice = []string{"1"} - require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MaxString", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MaxString = "value" - require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinInt", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinInt = 0 - require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinSlice", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinSlice = nil - require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MinString", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinString = "" - require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("OneOf", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.OneOf = "red" - require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("Required", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.Required = "" - require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("Unsupported", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.Unsupported = "abc" - require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct))) - }) - - t.Run("MultipleErrors", func(t *testing.T) { - t.Parallel() - - testStruct := validTestStruct() - testStruct.MinInt = 0 - testStruct.Required = "" - require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct))) - }) -} - -func TestPreferPublicNames(t *testing.T) { - t.Parallel() - - type testStruct struct { - JSONNameField string `json:"json_name"` - StructNameField string `apiquery:"-"` - } - - require.Equal(t, "json_name", - preferPublicName(reflect.TypeOf(testStruct{}).Field(0))) - require.Equal(t, "StructNameField", - preferPublicName(reflect.TypeOf(testStruct{}).Field(1))) -}