Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,46 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if err != nil {
return err
}
if allow {
return next(c)
}

// Fallback to legacy token based CSRF protection

// Get or generate token (needed even for "safe" requests to render forms)
token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
token = randomString(config.TokenLength)
} else {
token = k.Value // Reuse token
}

// If request is deemed safe by Sec-Fetch-Site, skip token validation
// but still set token in context and cookie for form rendering
if allow {
cookie := new(http.Cookie)
cookie.Name = config.CookieName
cookie.Value = token
if config.CookiePath != "" {
cookie.Path = config.CookiePath
}
if config.CookieDomain != "" {
cookie.Domain = config.CookieDomain
}
if config.CookieSameSite != http.SameSiteDefaultMode {
cookie.SameSite = config.CookieSameSite
}
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.Secure = config.CookieSecure
cookie.HttpOnly = config.CookieHTTPOnly
c.SetCookie(cookie)

// Store token in context for handlers
c.Set(config.ContextKey, token)

// Protect clients from caching the response
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)

return next(c)
}

// Fallback to legacy token based CSRF protection

switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
Expand Down
90 changes: 90 additions & 0 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,93 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
})
}
}

func TestCSRF_SecFetchSite_SetsTokenInContext(t *testing.T) {
// Test for issue #2874: CSRF middleware should set token in context
// even when Sec-Fetch-Site validation passes
var testCases = []struct {
name string
whenMethod string
whenSecFetchSite string
expectTokenInCtx bool
expectCookie bool
}{
{
name: "ok, GET with Sec-Fetch-Site: none sets token in context",
whenMethod: http.MethodGet,
whenSecFetchSite: "none",
expectTokenInCtx: true,
expectCookie: true,
},
{
name: "ok, GET with Sec-Fetch-Site: same-origin sets token in context",
whenMethod: http.MethodGet,
whenSecFetchSite: "same-origin",
expectTokenInCtx: true,
expectCookie: true,
},
{
name: "ok, POST with Sec-Fetch-Site: none sets token in context",
whenMethod: http.MethodPost,
whenSecFetchSite: "none",
expectTokenInCtx: true,
expectCookie: true,
},
{
name: "ok, POST with Sec-Fetch-Site: same-origin sets token in context",
whenMethod: http.MethodPost,
whenSecFetchSite: "same-origin",
expectTokenInCtx: true,
expectCookie: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(tc.whenMethod, "/", nil)
req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

var tokenInContext string
csrf := CSRFWithConfig(CSRFConfig{
TokenLookup: "form:csrf",
})

h := csrf(func(c echo.Context) error {
// Handler expects CSRF token in context for form rendering
token, ok := c.Get("csrf").(string)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, "CSRF token not found")
}
tokenInContext = token
return c.String(http.StatusOK, "test")
})

err := h(c)
assert.NoError(t, err)

if tc.expectTokenInCtx {
assert.NotEmpty(t, tokenInContext, "token should be set in context")
} else {
assert.Empty(t, tokenInContext, "token should not be set in context")
}

if tc.expectCookie {
cookies := rec.Result().Cookies()
assert.NotEmpty(t, cookies, "CSRF cookie should be set")
var csrfCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == "_csrf" {
csrfCookie = cookie
break
}
}
assert.NotNil(t, csrfCookie, "CSRF cookie should exist")
assert.NotEmpty(t, csrfCookie.Value, "CSRF cookie value should not be empty")
assert.Equal(t, tokenInContext, csrfCookie.Value, "token in context should match cookie value")
}
})
}
}