diff --git a/middleware/csrf.go b/middleware/csrf.go index f9d3293b0..63a2c2900 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -162,12 +162,8 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if err != nil { return err } - if allow { - return next(c) - } - - // Fallback to legacy token based CSRF protection + // Get or generate token (needed even for "safe" requests to render forms) token := "" if k, err := c.Cookie(config.CookieName); err != nil { token = randomString(config.TokenLength) @@ -175,6 +171,37 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { token = k.Value // Reuse token } + // If request is deemed safe by Sec-Fetch-Site, skip token validation + // but still set token in context and cookie for form rendering + if allow { + cookie := new(http.Cookie) + cookie.Name = config.CookieName + cookie.Value = token + if config.CookiePath != "" { + cookie.Path = config.CookiePath + } + if config.CookieDomain != "" { + cookie.Domain = config.CookieDomain + } + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } + cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) + cookie.Secure = config.CookieSecure + cookie.HttpOnly = config.CookieHTTPOnly + c.SetCookie(cookie) + + // Store token in context for handlers + c.Set(config.ContextKey, token) + + // Protect clients from caching the response + c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) + + return next(c) + } + + // Fallback to legacy token based CSRF protection + switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 85b7f1077..f79e0d8ba 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -850,3 +850,93 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { }) } } + +func TestCSRF_SecFetchSite_SetsTokenInContext(t *testing.T) { + // Test for issue #2874: CSRF middleware should set token in context + // even when Sec-Fetch-Site validation passes + var testCases = []struct { + name string + whenMethod string + whenSecFetchSite string + expectTokenInCtx bool + expectCookie bool + }{ + { + name: "ok, GET with Sec-Fetch-Site: none sets token in context", + whenMethod: http.MethodGet, + whenSecFetchSite: "none", + expectTokenInCtx: true, + expectCookie: true, + }, + { + name: "ok, GET with Sec-Fetch-Site: same-origin sets token in context", + whenMethod: http.MethodGet, + whenSecFetchSite: "same-origin", + expectTokenInCtx: true, + expectCookie: true, + }, + { + name: "ok, POST with Sec-Fetch-Site: none sets token in context", + whenMethod: http.MethodPost, + whenSecFetchSite: "none", + expectTokenInCtx: true, + expectCookie: true, + }, + { + name: "ok, POST with Sec-Fetch-Site: same-origin sets token in context", + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + expectTokenInCtx: true, + expectCookie: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(tc.whenMethod, "/", nil) + req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var tokenInContext string + csrf := CSRFWithConfig(CSRFConfig{ + TokenLookup: "form:csrf", + }) + + h := csrf(func(c echo.Context) error { + // Handler expects CSRF token in context for form rendering + token, ok := c.Get("csrf").(string) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, "CSRF token not found") + } + tokenInContext = token + return c.String(http.StatusOK, "test") + }) + + err := h(c) + assert.NoError(t, err) + + if tc.expectTokenInCtx { + assert.NotEmpty(t, tokenInContext, "token should be set in context") + } else { + assert.Empty(t, tokenInContext, "token should not be set in context") + } + + if tc.expectCookie { + cookies := rec.Result().Cookies() + assert.NotEmpty(t, cookies, "CSRF cookie should be set") + var csrfCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == "_csrf" { + csrfCookie = cookie + break + } + } + assert.NotNil(t, csrfCookie, "CSRF cookie should exist") + assert.NotEmpty(t, csrfCookie.Value, "CSRF cookie value should not be empty") + assert.Equal(t, tokenInContext, csrfCookie.Value, "token in context should match cookie value") + } + }) + } +}