Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -100,6 +101,7 @@ type Echo struct {

// formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm)
formParseMaxMemory int64
AutoHead bool
}

// JSONSerializer is the interface that encodes and decodes JSON to and from interfaces.
Expand Down Expand Up @@ -288,6 +290,11 @@ type Config struct {
// FormParseMaxMemory is default value for memory limit that is used
// when parsing multipart forms (See (*http.Request).ParseMultipartForm)
FormParseMaxMemory int64

// AutoHead enables automatic registration of HEAD routes for GET routes.
// When enabled, a HEAD request to a GET-only path will be handled automatically
// using the same handler as GET, with the response body suppressed.
AutoHead bool
}

// NewWithConfig creates an instance of Echo with given configuration.
Expand Down Expand Up @@ -326,6 +333,9 @@ func NewWithConfig(config Config) *Echo {
if config.FormParseMaxMemory > 0 {
e.formParseMaxMemory = config.FormParseMaxMemory
}
if config.AutoHead {
e.AutoHead = config.AutoHead
}
return e
}

Expand Down Expand Up @@ -421,6 +431,67 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler {
}
}

// headResponseWriter wraps an http.ResponseWriter and suppresses the response body
// while preserving headers and status code. Used for automatic HEAD route handling.
// It counts the bytes that would have been written so we can set Content-Length accurately.
type headResponseWriter struct {
http.ResponseWriter
bytesWritten int64
statusCode int
wroteHeader bool
}

// Write intercepts writes to the response body and counts bytes without actually writing them.
func (hw *headResponseWriter) Write(b []byte) (int, error) {
if !hw.wroteHeader {
hw.statusCode = http.StatusOK
hw.wroteHeader = true
}
hw.bytesWritten += int64(len(b))
// Return success without actually writing the body for HEAD requests
return len(b), nil
}

// WriteHeader intercepts the status code but still writes it to the underlying ResponseWriter.
func (hw *headResponseWriter) WriteHeader(statusCode int) {
if !hw.wroteHeader {
hw.statusCode = statusCode
hw.wroteHeader = true
hw.ResponseWriter.WriteHeader(statusCode)
}
}

// Unwrap returns the underlying http.ResponseWriter for compatibility with echo.Response unwrapping.
func (hw *headResponseWriter) Unwrap() http.ResponseWriter {
return hw.ResponseWriter
}

func wrapHeadHandler(handler HandlerFunc) HandlerFunc {
return func(c *Context) error {
if c.Request().Method != http.MethodHead {
return handler(c)
}
originalWriter := c.Response()
headWriter := &headResponseWriter{ResponseWriter: originalWriter}

c.SetResponse(headWriter)
defer func() {
c.SetResponse(originalWriter)
}()
err := handler(c)

if headWriter.bytesWritten > 0 {
originalWriter.Header().Set("Content-Length", strconv.FormatInt(headWriter.bytesWritten, 10))
}

if !headWriter.wroteHeader && headWriter.statusCode > 0 {
originalWriter.WriteHeader(headWriter.statusCode)
}

return err
}
}

// Pre adds middleware to the chain which is run before router tries to find matching route.
// Meaning middleware is executed even for 404 (not found) cases.
func (e *Echo) Pre(middleware ...MiddlewareFunc) {
Expand Down Expand Up @@ -634,6 +705,20 @@ func (e *Echo) add(route Route) (RouteInfo, error) {
if paramsCount > e.contextPathParamAllocSize.Load() {
e.contextPathParamAllocSize.Store(paramsCount)
}

// Auto-register HEAD route for GET if AutoHead is enabled
if e.AutoHead && route.Method == http.MethodGet {
headRoute := Route{
Method: http.MethodHead,
Path: route.Path,
Handler: wrapHeadHandler(route.Handler),
Middlewares: route.Middlewares,
Name: route.Name,
}
// Attempt to add HEAD route, but ignore errors if an explicit HEAD route already exists
_, _ = e.router.Add(headRoute)
}

return ri, nil
}

Expand All @@ -642,6 +727,7 @@ func (e *Echo) add(route Route) (RouteInfo, error) {
func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo {
ri, err := e.add(
Route{

Method: method,
Path: path,
Handler: handler,
Expand Down
173 changes: 173 additions & 0 deletions echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,159 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code)
}

func TestAutoHeadRoute(t *testing.T) {
tests := []struct {
name string
autoHead bool
method string
wantBody bool
wantCode int
wantCLen bool // expect Content-Length header
}{
{
name: "AutoHead disabled - HEAD returns 405",
autoHead: false,
method: http.MethodHead,
wantCode: http.StatusMethodNotAllowed,
wantBody: false,
},
{
name: "AutoHead enabled - HEAD returns 200 with Content-Length",
autoHead: true,
method: http.MethodHead,
wantCode: http.StatusOK,
wantBody: false,
wantCLen: true,
},
{
name: "GET request works normally with AutoHead enabled",
autoHead: true,
method: http.MethodGet,
wantCode: http.StatusOK,
wantBody: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create Echo instance with AutoHead configuration
e := New()
e.AutoHead = tt.autoHead

// Register a simple GET route
testBody := "Hello, World!"
e.GET("/hello", func(c *Context) error {
return c.String(http.StatusOK, testBody)
})

// Create request and response
req := httptest.NewRequest(tt.method, "/hello", nil)
rec := httptest.NewRecorder()

// Serve the request
e.ServeHTTP(rec, req)

// Verify status code
if rec.Code != tt.wantCode {
t.Errorf("expected status %d, got %d", tt.wantCode, rec.Code)
}

// Verify response body
if tt.wantBody {
if rec.Body.String() != testBody {
t.Errorf("expected body %q, got %q", testBody, rec.Body.String())
}
} else {
if rec.Body.String() != "" {
t.Errorf("expected empty body for HEAD, got %q", rec.Body.String())
}
}

// Verify Content-Length header for HEAD
if tt.wantCLen && tt.method == http.MethodHead {
clen := rec.Header().Get("Content-Length")
if clen == "" {
t.Error("expected Content-Length header for HEAD request")
}
}
})
}
}

func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) {
e := New()
e.AutoHead = true

// Register explicit HEAD route FIRST with custom behavior
e.HEAD("/api/users", func(c *Context) error {
c.Response().Header().Set("X-Custom-Header", "explicit-head")
return c.NoContent(http.StatusOK)
})

// Then register GET route - AutoHead will try to add a HEAD route but fail silently
// since one already exists
e.GET("/api/users", func(c *Context) error {
return c.JSON(http.StatusOK, map[string]string{"name": "John"})
})

// Test that the explicit HEAD route behavior is preserved
req := httptest.NewRequest(http.MethodHead, "/api/users", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rec.Code)
}

if rec.Header().Get("X-Custom-Header") != "explicit-head" {
t.Error("expected explicit HEAD route to be used")
}

// Verify body is empty
if rec.Body.String() != "" {
t.Errorf("expected empty body for HEAD, got %q", rec.Body.String())
}
}

func TestAutoHeadWithMiddleware(t *testing.T) {
e := New()
e.AutoHead = true

// Add request logger middleware
middlewareExecuted := false
e.Use(func(next HandlerFunc) HandlerFunc {
return func(c *Context) error {
middlewareExecuted = true
c.Response().Header().Set("X-Middleware", "executed")
return next(c)
}
})

// Register GET route
e.GET("/test", func(c *Context) error {
return c.String(http.StatusOK, "test response")
})

// Test HEAD request goes through middleware
req := httptest.NewRequest(http.MethodHead, "/test", nil)
rec := httptest.NewRecorder()

middlewareExecuted = false
e.ServeHTTP(rec, req)

if !middlewareExecuted {
t.Error("middleware should execute for automatic HEAD route")
}

if rec.Header().Get("X-Middleware") != "executed" {
t.Error("middleware header not set")
}

if rec.Body.String() != "" {
t.Errorf("expected empty body for HEAD, got %q", rec.Body.String())
}
}

func benchmarkEchoRoutes(b *testing.B, routes []testRoute) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
Expand Down Expand Up @@ -1278,3 +1431,23 @@ func BenchmarkEchoGitHubAPIMisses(b *testing.B) {
func BenchmarkEchoParseAPI(b *testing.B) {
benchmarkEchoRoutes(b, parseAPI)
}

func BenchmarkAutoHeadRoute(b *testing.B) {
e := New()
e.AutoHead = true

e.GET("/bench", func(c *Context) error {
return c.String(http.StatusOK, "benchmark response body")
})

req := httptest.NewRequest(http.MethodHead, "/bench", nil)
rec := httptest.NewRecorder()

b.ReportAllocs()
b.ResetTimer()

for i := 0; i < b.N; i++ {
rec.Body.Reset()
e.ServeHTTP(rec, req)
}
}