diff --git a/apiendpoint/api_endpoint.go b/apiendpoint/api_endpoint.go index fa3b33f..3787c32 100644 --- a/apiendpoint/api_endpoint.go +++ b/apiendpoint/api_endpoint.go @@ -4,6 +4,7 @@ package apiendpoint import ( + "bytes" "context" "encoding/json" "errors" @@ -146,6 +147,8 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ return apierror.NewBadRequestf("Error unmarshaling request body: %s.", err) } } + + r.Body = io.NopCloser(bytes.NewReader(reqData)) } if rawExtractor, ok := any(&req).(RawExtractor); ok { diff --git a/apiendpoint/api_endpoint_test.go b/apiendpoint/api_endpoint_test.go index b9b667c..036896a 100644 --- a/apiendpoint/api_endpoint_test.go +++ b/apiendpoint/api_endpoint_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" "net/http" "net/http/httptest" @@ -48,15 +49,15 @@ func TestMountAndServe(t *testing.T) { } } - t.Run("GetEndpointAndExtractRaw", func(t *testing.T) { + t.Run("GetEndpoint", func(t *testing.T) { t.Parallel() mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", nil) + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint", nil) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder) }) t.Run("BodyIgnoredOnGet", func(t *testing.T) { @@ -64,11 +65,11 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint/Hello.", + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint", bytes.NewBuffer(mustMarshalJSON(t, &getRequest{IgnoredJSONMessage: "Ignored hello."}))) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusOK, &postResponse{Message: "Hello."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder) }) t.Run("MethodNotAllowed", func(t *testing.T) { @@ -76,7 +77,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint/Hello.", nil) + req := httptest.NewRequest(http.MethodPost, "/api/get-endpoint", nil) mux.ServeHTTP(bundle.recorder, req) // This error comes from net/http. @@ -91,11 +92,11 @@ func TestMountAndServe(t *testing.T) { mux := http.NewServeMux() Mount(mux, &postEndpoint{}, nil) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + reqPayload := mustMarshalJSON(t, &postRequest{Message: "Hello."}) + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(reqPayload)) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{ID: "123", Message: "Hello.", RawPayload: reqPayload}, bundle.recorder) }) t.Run("OptionsWithCustomLogger", func(t *testing.T) { @@ -104,25 +105,24 @@ func TestMountAndServe(t *testing.T) { _, bundle := setup(t) mux := http.NewServeMux() - Mount(mux, &postEndpoint{}, &MountOpts{Logger: bundle.logger}) + Mount(mux, &getEndpoint{}, &MountOpts{Logger: bundle.logger}) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + req := httptest.NewRequest(http.MethodGet, "/api/get-endpoint", nil) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusOK, &getResponse{Message: "Hello."}, bundle.recorder) }) - t.Run("PostEndpoint", func(t *testing.T) { + t.Run("PostEndpointAndExtractRaw", func(t *testing.T) { t.Parallel() mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", - bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) + reqPayload := mustMarshalJSON(t, &postRequest{Message: "Hello."}) + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(reqPayload)) mux.ServeHTTP(bundle.recorder, req) - requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{Message: "Hello."}, bundle.recorder) + requireStatusAndJSONResponse(t, http.StatusCreated, &postResponse{ID: "123", Message: "Hello.", RawPayload: reqPayload}, bundle.recorder) }) t.Run("ValidationError", func(t *testing.T) { @@ -130,7 +130,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", nil) + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", nil) mux.ServeHTTP(bundle.recorder, req) requireStatusAndJSONResponse(t, http.StatusBadRequest, &apierror.APIError{Message: "Field `message` is required."}, bundle.recorder) @@ -141,7 +141,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeAPIError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) @@ -153,7 +153,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakePostgresError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) @@ -168,7 +168,7 @@ func TestMountAndServe(t *testing.T) { ctx, cancel := context.WithDeadline(ctx, time.Now()) t.Cleanup(cancel) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint", + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(mustMarshalJSON(t, &postRequest{Message: "Hello."}))) require.NoError(t, err) mux.ServeHTTP(bundle.recorder, req) @@ -181,7 +181,7 @@ func TestMountAndServe(t *testing.T) { mux, bundle := setup(t) - req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint", + req := httptest.NewRequest(http.MethodPost, "/api/post-endpoint/123", bytes.NewBuffer(mustMarshalJSON(t, &postRequest{MakeInternalError: true, Message: "Hello."}))) mux.ServeHTTP(bundle.recorder, req) @@ -274,19 +274,13 @@ type getEndpoint struct { func (*getEndpoint) Meta() *EndpointMeta { return &EndpointMeta{ - Pattern: "GET /api/get-endpoint/{message}", + Pattern: "GET /api/get-endpoint", StatusCode: http.StatusOK, } } type getRequest struct { IgnoredJSONMessage string `json:"ignored_json" validate:"-"` - Message string `json:"-" validate:"required"` -} - -func (req *getRequest) ExtractRaw(r *http.Request) error { - req.Message = r.PathValue("message") - return nil } type getResponse struct { @@ -299,7 +293,7 @@ func (a *getEndpoint) Execute(_ context.Context, req *getRequest) (*getResponse, return &getResponse{Message: req.IgnoredJSONMessage}, nil } - return &getResponse{Message: req.Message}, nil + return &getResponse{Message: "Hello."}, nil } // @@ -312,20 +306,34 @@ type postEndpoint struct { func (*postEndpoint) Meta() *EndpointMeta { return &EndpointMeta{ - Pattern: "POST /api/post-endpoint", + Pattern: "POST /api/post-endpoint/{id}", StatusCode: http.StatusCreated, } } type postRequest struct { + ID string `json:"-" validate:"-"` MakeAPIError bool `json:"make_api_error" validate:"-"` MakeInternalError bool `json:"make_internal_error" validate:"-"` MakePostgresError bool `json:"make_postgres_error" validate:"-"` Message string `json:"message" validate:"required"` + RawPayload []byte `json:"-" validate:"-"` +} + +func (req *postRequest) ExtractRaw(r *http.Request) error { + var err error + if req.RawPayload, err = io.ReadAll(r.Body); err != nil { + return err + } + + req.ID = r.PathValue("id") + return nil } type postResponse struct { - Message string `json:"message"` + ID string `json:"id"` + Message string `json:"message"` + RawPayload json.RawMessage `json:"raw_payload"` } func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResponse, error) { @@ -346,5 +354,5 @@ func (a *postEndpoint) Execute(ctx context.Context, req *postRequest) (*postResp return nil, fmt.Errorf("error running Postgres query: %w", &pgconn.PgError{Code: pgerrcode.InsufficientPrivilege}) } - return &postResponse{Message: req.Message}, nil + return &postResponse{ID: req.ID, Message: req.Message, RawPayload: req.RawPayload}, nil }