diff --git a/apiendpoint/api_endpoint.go b/apiendpoint/api_endpoint.go index d952cf0..fa3b33f 100644 --- a/apiendpoint/api_endpoint.go +++ b/apiendpoint/api_endpoint.go @@ -17,6 +17,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/riverqueue/apiframe/apierror" + "github.com/riverqueue/apiframe/apimiddleware" "github.com/riverqueue/apiframe/internal/validate" ) @@ -87,18 +88,40 @@ func (m *EndpointMeta) validate() { } } +type MountOpts struct { + Logger *slog.Logger + // MiddlewareStack is a stack of middleware that will be mounted in front of + // the API endpoint handler. If not specified, no middleware will be used. + MiddlewareStack *apimiddleware.MiddlewareStack +} + // Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log // information about endpoint execution. -func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface { +func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteInterface[TReq, TResp], opts *MountOpts) EndpointInterface { + if opts == nil { + opts = &MountOpts{} + } + + logger := opts.Logger + if logger == nil { + logger = slog.Default() + } + apiEndpoint.SetLogger(logger) meta := apiEndpoint.Meta() meta.validate() // panic on problem apiEndpoint.SetMeta(meta) - mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) { - executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute) - }) + innerHandler := func(w http.ResponseWriter, r *http.Request) { + executeAPIEndpoint(w, r, opts.Logger, meta, apiEndpoint.Execute) + } + + if opts.MiddlewareStack != nil { + mux.Handle(meta.Pattern, opts.MiddlewareStack.Mount(http.HandlerFunc(innerHandler))) + } else { + mux.HandleFunc(meta.Pattern, innerHandler) + } return apiEndpoint } diff --git a/apiendpoint/api_endpoint_test.go b/apiendpoint/api_endpoint_test.go index 793a405..b9b667c 100644 --- a/apiendpoint/api_endpoint_test.go +++ b/apiendpoint/api_endpoint_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "net/http" "net/http/httptest" "testing" @@ -25,6 +26,7 @@ func TestMountAndServe(t *testing.T) { ctx := context.Background() type testBundle struct { + logger *slog.Logger recorder *httptest.ResponseRecorder } @@ -34,12 +36,14 @@ func TestMountAndServe(t *testing.T) { var ( logger = riversharedtest.Logger(t) mux = http.NewServeMux() + opts = &MountOpts{Logger: logger} ) - Mount(mux, logger, &getEndpoint{}) - Mount(mux, logger, &postEndpoint{}) + Mount(mux, &getEndpoint{}, opts) + Mount(mux, &postEndpoint{}, opts) return mux, &testBundle{ + logger: logger, recorder: httptest.NewRecorder(), } } @@ -79,6 +83,36 @@ func TestMountAndServe(t *testing.T) { requireStatusAndResponse(t, http.StatusMethodNotAllowed, "Method Not Allowed\n", bundle.recorder) }) + t.Run("NilOptions", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + mux := http.NewServeMux() + Mount(mux, &postEndpoint{}, nil) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + }) + + t.Run("OptionsWithCustomLogger", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + mux := http.NewServeMux() + Mount(mux, &postEndpoint{}, &MountOpts{Logger: bundle.logger}) + + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + mux.ServeHTTP(bundle.recorder, req) + + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + }) + t.Run("PostEndpoint", func(t *testing.T) { t.Parallel()