From 56a9a7f1ce2101a7f42f71cd426c5cba0f7584cc Mon Sep 17 00:00:00 2001 From: kubrickcode Date: Sat, 29 Nov 2025 06:15:35 +0000 Subject: [PATCH] refactor: prevent internal error exposure in API responses Resolved security issue where internal implementation details (JWT parsing, Redis connection, etc.) were exposed to clients via direct err.Error() return - Separate internal error logging from external responses with WriteErrorWithLog - Map 20 sentinel errors to user-friendly messages - Apply to all 3 API handlers (callback, refresh, verify) - Add error response tests for each API handler fix #57 --- server/api/callback/index.go | 18 +- server/api/callback/index_test.go | 143 +++++++++++++++ server/api/refresh/index.go | 24 +-- server/api/refresh/index_test.go | 155 ++++++++++++++++ server/api/verify/index.go | 8 +- server/api/verify/index_test.go | 155 ++++++++++++++++ server/pkg/httputil/errors.go | 67 +++++++ server/pkg/httputil/errors_test.go | 273 +++++++++++++++++++++++++++++ 8 files changed, 818 insertions(+), 25 deletions(-) create mode 100644 server/api/callback/index_test.go create mode 100644 server/api/refresh/index_test.go create mode 100644 server/api/verify/index_test.go create mode 100644 server/pkg/httputil/errors.go create mode 100644 server/pkg/httputil/errors_test.go diff --git a/server/api/callback/index.go b/server/api/callback/index.go index fca5d6c..62d54ca 100644 --- a/server/api/callback/index.go +++ b/server/api/callback/index.go @@ -42,53 +42,53 @@ func Handler(w http.ResponseWriter, r *http.Request) { oauthClient, err := oauth.GetClient() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "OAuth configuration missing") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "OAuth service unavailable") return } token, err := oauthClient.ExchangeCode(code) if err != nil { - httputil.WriteError(w, http.StatusBadRequest, "exchange_failed", "Failed to exchange authorization code") + httputil.WriteErrorWithLog(w, err, http.StatusBadRequest, "exchange_failed", "Failed to exchange authorization code") return } redisClient, err := redis.GetClient() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Redis connection failed") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Storage service unavailable") return } sessionID, err := crypto.GenerateSessionID() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate session ID") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create session") return } if err := redisClient.Set(redis.SessionKeyPrefix+sessionID, token.AccessToken, redis.SessionTTL); err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to store session") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to store session") return } refreshTokenID, err := crypto.GenerateRefreshTokenID() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate refresh token ID") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create session") return } if err := redisClient.Set(redis.RefreshTokenKeyPrefix+refreshTokenID, sessionID, redis.RefreshTokenTTL); err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to store refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to store session") return } accessToken, err := jwt.GenerateAccessToken(sessionID) if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate access token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create access token") return } refreshToken, err := jwt.GenerateRefreshToken(refreshTokenID, sessionID) if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create refresh token") return } diff --git a/server/api/callback/index_test.go b/server/api/callback/index_test.go new file mode 100644 index 0000000..573bfb4 --- /dev/null +++ b/server/api/callback/index_test.go @@ -0,0 +1,143 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github-project-status-viewer-server/pkg/httputil" +) + +func TestHandler_ErrorResponseSanitization(t *testing.T) { + tests := []struct { + expectedCode string + expectedDescription string + name string + queryParams string + shouldNotContain []string + }{ + { + name: "missing code parameter should return sanitized error", + queryParams: "?state=abc123", + expectedCode: "missing_code", + expectedDescription: "Authorization code is required", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + { + name: "missing state parameter should return sanitized error", + queryParams: "?code=test_code", + expectedCode: "missing_state", + expectedDescription: "State parameter is required for CSRF protection", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/api/callback"+tt.queryParams, nil) + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status code = %v, want %v", w.Code, http.StatusBadRequest) + } + + var apiError httputil.APIError + if err := json.NewDecoder(w.Body).Decode(&apiError); err != nil { + t.Fatalf("Failed to decode error response: %v", err) + } + + if apiError.Code != tt.expectedCode { + t.Errorf("Error code = %v, want %v", apiError.Code, tt.expectedCode) + } + + if apiError.Description != tt.expectedDescription { + t.Errorf("Error description = %v, want %v", apiError.Description, tt.expectedDescription) + } + + responseBody := w.Body.String() + for _, forbidden := range tt.shouldNotContain { + if containsString(responseBody, forbidden) { + t.Errorf("Response should not contain '%s' but body contains: %s", forbidden, responseBody) + } + } + }) + } +} + +func TestHandler_MethodValidation(t *testing.T) { + tests := []struct { + method string + name string + wantStatus int + }{ + { + name: "GET method should be accepted", + method: http.MethodGet, + wantStatus: http.StatusBadRequest, + }, + { + name: "POST method should be rejected", + method: http.MethodPost, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "PUT method should be rejected", + method: http.MethodPut, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "DELETE method should be rejected", + method: http.MethodDelete, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "OPTIONS method should be accepted for CORS", + method: http.MethodOptions, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/api/callback?code=test&state=test", nil) + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != tt.wantStatus { + t.Errorf("Status code = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} + +func TestHandler_CORSHeaders(t *testing.T) { + t.Setenv("CHROME_EXTENSION_ID", "test-extension-id") + + req := httptest.NewRequest(http.MethodGet, "/api/callback?code=test&state=test", nil) + w := httptest.NewRecorder() + + Handler(w, req) + + corsHeader := w.Header().Get("Access-Control-Allow-Origin") + expectedOrigin := "chrome-extension://test-extension-id" + if corsHeader != expectedOrigin { + t.Errorf("Expected CORS header to be %s, got %s", expectedOrigin, corsHeader) + } +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && stringContains(s, substr) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/server/api/refresh/index.go b/server/api/refresh/index.go index 0a8634f..b17d9a8 100644 --- a/server/api/refresh/index.go +++ b/server/api/refresh/index.go @@ -33,67 +33,67 @@ func Handler(w http.ResponseWriter, r *http.Request) { refreshToken := tokenString[7:] claims, err := jwt.ValidateRefreshToken(refreshToken) if err != nil { - httputil.WriteError(w, http.StatusUnauthorized, "invalid_refresh_token", err.Error()) + httputil.WriteErrorWithLog(w, err, http.StatusUnauthorized, "invalid_refresh_token", "Invalid or expired refresh token") return } redisClient, err := redis.GetClient() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Redis connection failed") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Storage service unavailable") return } storedSessionID, err := redisClient.Get(redis.RefreshTokenKeyPrefix + claims.RefreshTokenID) if err != nil { if errors.Is(err, pkgerrors.ErrKeyNotFound) { - httputil.WriteError(w, http.StatusUnauthorized, "refresh_token_revoked", "Refresh token has been revoked or expired") + httputil.WriteErrorWithLog(w, pkgerrors.ErrRefreshTokenRevoked, http.StatusUnauthorized, "refresh_token_revoked", "Refresh token has been revoked or expired") } else { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to verify refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to verify refresh token") } return } if storedSessionID != claims.SessionID { - httputil.WriteError(w, http.StatusUnauthorized, "session_mismatch", "Session mismatch detected") + httputil.WriteErrorWithLog(w, pkgerrors.ErrSessionMismatch, http.StatusUnauthorized, "session_mismatch", "Session mismatch detected") return } exists, err := redisClient.Exists(redis.SessionKeyPrefix + claims.SessionID) if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to check session") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to verify session") return } if !exists { - httputil.WriteError(w, http.StatusUnauthorized, "session_not_found", "Session expired or invalid") + httputil.WriteErrorWithLog(w, pkgerrors.ErrSessionNotFound, http.StatusUnauthorized, "session_not_found", "Session expired or invalid") return } if err := redisClient.Delete(redis.RefreshTokenKeyPrefix + claims.RefreshTokenID); err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to revoke old refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to revoke old refresh token") return } newRefreshTokenID, err := crypto.GenerateRefreshTokenID() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate refresh token ID") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create refresh token") return } if err := redisClient.Set(redis.RefreshTokenKeyPrefix+newRefreshTokenID, claims.SessionID, redis.RefreshTokenTTL); err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to store new refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to store refresh token") return } newAccessToken, err := jwt.GenerateAccessToken(claims.SessionID) if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate access token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create access token") return } newRefreshToken, err := jwt.GenerateRefreshToken(newRefreshTokenID, claims.SessionID) if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Failed to generate refresh token") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to create refresh token") return } diff --git a/server/api/refresh/index_test.go b/server/api/refresh/index_test.go new file mode 100644 index 0000000..ee17dbf --- /dev/null +++ b/server/api/refresh/index_test.go @@ -0,0 +1,155 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github-project-status-viewer-server/pkg/httputil" +) + +func TestHandler_ErrorResponseSanitization(t *testing.T) { + tests := []struct { + authHeader string + expectedCode string + expectedDescription string + name string + shouldNotContain []string + }{ + { + name: "missing bearer token should return sanitized error", + authHeader: "", + expectedCode: "invalid_token", + expectedDescription: "Bearer token required", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + { + name: "invalid bearer format should return sanitized error", + authHeader: "InvalidFormat", + expectedCode: "invalid_token", + expectedDescription: "Bearer token required", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + { + name: "invalid token should not expose internal error details", + authHeader: "Bearer invalid_token_format", + expectedCode: "invalid_refresh_token", + expectedDescription: "Invalid or expired refresh token", + shouldNotContain: []string{"jwt", "parse", "crypto", "signature", "algorithm"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/refresh", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Status code = %v, want %v", w.Code, http.StatusUnauthorized) + } + + var apiError httputil.APIError + if err := json.NewDecoder(w.Body).Decode(&apiError); err != nil { + t.Fatalf("Failed to decode error response: %v", err) + } + + if apiError.Code != tt.expectedCode { + t.Errorf("Error code = %v, want %v", apiError.Code, tt.expectedCode) + } + + if apiError.Description != tt.expectedDescription { + t.Errorf("Error description = %v, want %v", apiError.Description, tt.expectedDescription) + } + + responseBody := w.Body.String() + for _, forbidden := range tt.shouldNotContain { + if containsString(responseBody, forbidden) { + t.Errorf("Response should not contain '%s' but body contains: %s", forbidden, responseBody) + } + } + }) + } +} + +func TestHandler_MethodValidation(t *testing.T) { + tests := []struct { + method string + name string + wantStatus int + }{ + { + name: "POST method should be accepted", + method: http.MethodPost, + wantStatus: http.StatusUnauthorized, + }, + { + name: "GET method should be rejected", + method: http.MethodGet, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "PUT method should be rejected", + method: http.MethodPut, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "DELETE method should be rejected", + method: http.MethodDelete, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "OPTIONS method should be accepted for CORS", + method: http.MethodOptions, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/api/refresh", nil) + req.Header.Set("Authorization", "Bearer test_token") + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != tt.wantStatus { + t.Errorf("Status code = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} + +func TestHandler_CORSHeaders(t *testing.T) { + t.Setenv("CHROME_EXTENSION_ID", "test-extension-id") + + req := httptest.NewRequest(http.MethodPost, "/api/refresh", nil) + req.Header.Set("Authorization", "Bearer test_token") + w := httptest.NewRecorder() + + Handler(w, req) + + corsHeader := w.Header().Get("Access-Control-Allow-Origin") + expectedOrigin := "chrome-extension://test-extension-id" + if corsHeader != expectedOrigin { + t.Errorf("Expected CORS header to be %s, got %s", expectedOrigin, corsHeader) + } +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && stringContains(s, substr) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/server/api/verify/index.go b/server/api/verify/index.go index 405db11..444ec54 100644 --- a/server/api/verify/index.go +++ b/server/api/verify/index.go @@ -31,22 +31,22 @@ func Handler(w http.ResponseWriter, r *http.Request) { accessToken := tokenString[7:] claims, err := jwt.ValidateAccessToken(accessToken) if err != nil { - httputil.WriteError(w, http.StatusUnauthorized, "invalid_access_token", err.Error()) + httputil.WriteErrorWithLog(w, err, http.StatusUnauthorized, "invalid_access_token", "Invalid or expired access token") return } redisClient, err := redis.GetClient() if err != nil { - httputil.WriteError(w, http.StatusInternalServerError, "server_error", "Redis connection failed") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Storage service unavailable") return } githubAccessToken, err := redisClient.Get(redis.SessionKeyPrefix + claims.SessionID) if err != nil { if errors.Is(err, pkgerrors.ErrKeyNotFound) { - httputil.WriteError(w, http.StatusUnauthorized, "session_not_found", "Session expired or invalid") + httputil.WriteErrorWithLog(w, pkgerrors.ErrSessionNotFound, http.StatusUnauthorized, "session_not_found", "Session expired or invalid") } else { - httputil.WriteError(w, http.StatusInternalServerError, "redis_error", "Failed to retrieve session") + httputil.WriteErrorWithLog(w, err, http.StatusInternalServerError, "server_error", "Failed to retrieve session") } return } diff --git a/server/api/verify/index_test.go b/server/api/verify/index_test.go new file mode 100644 index 0000000..0ec8ee8 --- /dev/null +++ b/server/api/verify/index_test.go @@ -0,0 +1,155 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github-project-status-viewer-server/pkg/httputil" +) + +func TestHandler_ErrorResponseSanitization(t *testing.T) { + tests := []struct { + authHeader string + expectedCode string + expectedDescription string + name string + shouldNotContain []string + }{ + { + name: "missing bearer token should return sanitized error", + authHeader: "", + expectedCode: "invalid_token", + expectedDescription: "Bearer token required", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + { + name: "invalid bearer format should return sanitized error", + authHeader: "InvalidFormat", + expectedCode: "invalid_token", + expectedDescription: "Bearer token required", + shouldNotContain: []string{"internal", "stack", "nil"}, + }, + { + name: "invalid token should not expose internal error details", + authHeader: "Bearer invalid_token_format", + expectedCode: "invalid_access_token", + expectedDescription: "Invalid or expired access token", + shouldNotContain: []string{"jwt", "parse", "crypto", "signature", "algorithm"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/api/verify", nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) + } + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Status code = %v, want %v", w.Code, http.StatusUnauthorized) + } + + var apiError httputil.APIError + if err := json.NewDecoder(w.Body).Decode(&apiError); err != nil { + t.Fatalf("Failed to decode error response: %v", err) + } + + if apiError.Code != tt.expectedCode { + t.Errorf("Error code = %v, want %v", apiError.Code, tt.expectedCode) + } + + if apiError.Description != tt.expectedDescription { + t.Errorf("Error description = %v, want %v", apiError.Description, tt.expectedDescription) + } + + responseBody := w.Body.String() + for _, forbidden := range tt.shouldNotContain { + if containsString(responseBody, forbidden) { + t.Errorf("Response should not contain '%s' but body contains: %s", forbidden, responseBody) + } + } + }) + } +} + +func TestHandler_MethodValidation(t *testing.T) { + tests := []struct { + method string + name string + wantStatus int + }{ + { + name: "POST method should be accepted", + method: http.MethodPost, + wantStatus: http.StatusUnauthorized, + }, + { + name: "GET method should be rejected", + method: http.MethodGet, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "PUT method should be rejected", + method: http.MethodPut, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "DELETE method should be rejected", + method: http.MethodDelete, + wantStatus: http.StatusMethodNotAllowed, + }, + { + name: "OPTIONS method should be accepted for CORS", + method: http.MethodOptions, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, "/api/verify", nil) + req.Header.Set("Authorization", "Bearer test_token") + w := httptest.NewRecorder() + + Handler(w, req) + + if w.Code != tt.wantStatus { + t.Errorf("Status code = %v, want %v", w.Code, tt.wantStatus) + } + }) + } +} + +func TestHandler_CORSHeaders(t *testing.T) { + t.Setenv("CHROME_EXTENSION_ID", "test-extension-id") + + req := httptest.NewRequest(http.MethodPost, "/api/verify", nil) + req.Header.Set("Authorization", "Bearer test_token") + w := httptest.NewRecorder() + + Handler(w, req) + + corsHeader := w.Header().Get("Access-Control-Allow-Origin") + expectedOrigin := "chrome-extension://test-extension-id" + if corsHeader != expectedOrigin { + t.Errorf("Expected CORS header to be %s, got %s", expectedOrigin, corsHeader) + } +} + +func containsString(s, substr string) bool { + return len(s) >= len(substr) && stringContains(s, substr) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/server/pkg/httputil/errors.go b/server/pkg/httputil/errors.go new file mode 100644 index 0000000..9ca8587 --- /dev/null +++ b/server/pkg/httputil/errors.go @@ -0,0 +1,67 @@ +package httputil + +import ( + "errors" + "log/slog" + "net/http" + + pkgerrors "github-project-status-viewer-server/pkg/errors" +) + +type ErrorResponse struct { + Code string + Description string + StatusCode int +} + +var errorResponses = map[error]ErrorResponse{ + pkgerrors.ErrBearerTokenRequired: {StatusCode: http.StatusUnauthorized, Code: "invalid_token", Description: "Bearer token required"}, + pkgerrors.ErrInvalidAccessTokenClaims: {StatusCode: http.StatusUnauthorized, Code: "invalid_access_token", Description: "Invalid authentication token"}, + pkgerrors.ErrInvalidAuthHeader: {StatusCode: http.StatusUnauthorized, Code: "invalid_token", Description: "Invalid authorization header format"}, + pkgerrors.ErrInvalidRefreshTokenClaims: {StatusCode: http.StatusUnauthorized, Code: "invalid_refresh_token", Description: "Invalid refresh token"}, + pkgerrors.ErrInvalidSigningMethod: {StatusCode: http.StatusUnauthorized, Code: "invalid_token", Description: "Invalid token signature"}, + pkgerrors.ErrInvalidTokenFormat: {StatusCode: http.StatusUnauthorized, Code: "invalid_token", Description: "Invalid token format"}, + pkgerrors.ErrJWTSecretMissing: {StatusCode: http.StatusInternalServerError, Code: "server_error", Description: "Service configuration error"}, + pkgerrors.ErrKeyNotFound: {StatusCode: http.StatusUnauthorized, Code: "session_not_found", Description: "Session expired or invalid"}, + pkgerrors.ErrMethodNotAllowed: {StatusCode: http.StatusMethodNotAllowed, Code: "method_not_allowed", Description: "HTTP method not allowed"}, + pkgerrors.ErrMissingAuthCode: {StatusCode: http.StatusBadRequest, Code: "missing_code", Description: "Authorization code is required"}, + pkgerrors.ErrMissingStateParam: {StatusCode: http.StatusBadRequest, Code: "missing_state", Description: "State parameter is required for CSRF protection"}, + pkgerrors.ErrOAuthConfigMissing: {StatusCode: http.StatusInternalServerError, Code: "server_error", Description: "OAuth configuration error"}, + pkgerrors.ErrOAuthExchangeFailed: {StatusCode: http.StatusBadRequest, Code: "exchange_failed", Description: "Failed to exchange authorization code"}, + pkgerrors.ErrOAuthRequestFailed: {StatusCode: http.StatusBadGateway, Code: "oauth_error", Description: "OAuth service unavailable"}, + pkgerrors.ErrRedisConfigMissing: {StatusCode: http.StatusInternalServerError, Code: "server_error", Description: "Storage configuration error"}, + pkgerrors.ErrRedisRequestFailed: {StatusCode: http.StatusInternalServerError, Code: "server_error", Description: "Storage service error"}, + pkgerrors.ErrRefreshTokenRevoked: {StatusCode: http.StatusUnauthorized, Code: "refresh_token_revoked", Description: "Refresh token has been revoked or expired"}, + pkgerrors.ErrSessionExpired: {StatusCode: http.StatusUnauthorized, Code: "session_expired", Description: "Session expired or invalid"}, + pkgerrors.ErrSessionMismatch: {StatusCode: http.StatusUnauthorized, Code: "session_mismatch", Description: "Session mismatch detected"}, + pkgerrors.ErrSessionNotFound: {StatusCode: http.StatusUnauthorized, Code: "session_not_found", Description: "Session not found"}, + pkgerrors.ErrTokenExpired: {StatusCode: http.StatusUnauthorized, Code: "token_expired", Description: "Token has expired"}, + pkgerrors.ErrUnexpectedResponse: {StatusCode: http.StatusInternalServerError, Code: "server_error", Description: "Unexpected response from storage"}, +} + +func WriteErrorWithLog(w http.ResponseWriter, internalErr error, fallbackStatus int, fallbackCode, fallbackDescription string) { + response := getErrorResponse(internalErr, fallbackStatus, fallbackCode, fallbackDescription) + + slog.Error("API error occurred", + "status", response.StatusCode, + "code", response.Code, + "description", response.Description, + "internal_error", internalErr, + ) + + WriteError(w, response.StatusCode, response.Code, response.Description) +} + +func getErrorResponse(err error, fallbackStatus int, fallbackCode, fallbackDescription string) ErrorResponse { + for sentinelErr, response := range errorResponses { + if errors.Is(err, sentinelErr) { + return response + } + } + + return ErrorResponse{ + Code: fallbackCode, + Description: fallbackDescription, + StatusCode: fallbackStatus, + } +} diff --git a/server/pkg/httputil/errors_test.go b/server/pkg/httputil/errors_test.go new file mode 100644 index 0000000..b6735d7 --- /dev/null +++ b/server/pkg/httputil/errors_test.go @@ -0,0 +1,273 @@ +package httputil + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + pkgerrors "github-project-status-viewer-server/pkg/errors" +) + +func TestWriteErrorWithLog(t *testing.T) { + tests := []struct { + err error + fallbackCode string + fallbackDescription string + fallbackStatus int + name string + wantCode string + wantDescription string + wantStatus int + }{ + { + name: "should sanitize JWT secret missing error", + err: pkgerrors.ErrJWTSecretMissing, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusInternalServerError, + wantCode: "server_error", + wantDescription: "Service configuration error", + }, + { + name: "should sanitize invalid signing method error", + err: pkgerrors.ErrInvalidSigningMethod, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "invalid_token", + wantDescription: "Invalid token signature", + }, + { + name: "should sanitize token expired error", + err: pkgerrors.ErrTokenExpired, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "token_expired", + wantDescription: "Token has expired", + }, + { + name: "should sanitize session not found error", + err: pkgerrors.ErrSessionNotFound, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "session_not_found", + wantDescription: "Session not found", + }, + { + name: "should sanitize Redis config missing error", + err: pkgerrors.ErrRedisConfigMissing, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusInternalServerError, + wantCode: "server_error", + wantDescription: "Storage configuration error", + }, + { + name: "should sanitize OAuth exchange failed error", + err: pkgerrors.ErrOAuthExchangeFailed, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusBadRequest, + wantCode: "exchange_failed", + wantDescription: "Failed to exchange authorization code", + }, + { + name: "should use fallback for unknown error", + err: errors.New("internal database constraint violation"), + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "server_error", + fallbackDescription: "Internal server error", + wantStatus: http.StatusInternalServerError, + wantCode: "server_error", + wantDescription: "Internal server error", + }, + { + name: "should sanitize wrapped known error", + err: errors.Join(pkgerrors.ErrInvalidTokenFormat, errors.New("jwt: malformed token")), + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "invalid_token", + wantDescription: "Invalid token format", + }, + { + name: "should sanitize refresh token revoked error", + err: pkgerrors.ErrRefreshTokenRevoked, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "refresh_token_revoked", + wantDescription: "Refresh token has been revoked or expired", + }, + { + name: "should sanitize session mismatch error", + err: pkgerrors.ErrSessionMismatch, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "unknown_error", + fallbackDescription: "An error occurred", + wantStatus: http.StatusUnauthorized, + wantCode: "session_mismatch", + wantDescription: "Session mismatch detected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + WriteErrorWithLog(w, tt.err, tt.fallbackStatus, tt.fallbackCode, tt.fallbackDescription) + + if w.Code != tt.wantStatus { + t.Errorf("Status code = %v, want %v", w.Code, tt.wantStatus) + } + + var apiError APIError + if err := json.NewDecoder(w.Body).Decode(&apiError); err != nil { + t.Fatalf("Failed to decode error response: %v", err) + } + + if apiError.Code != tt.wantCode { + t.Errorf("Error code = %v, want %v", apiError.Code, tt.wantCode) + } + + if apiError.Description != tt.wantDescription { + t.Errorf("Error description = %v, want %v", apiError.Description, tt.wantDescription) + } + }) + } +} + +func TestWriteErrorWithLog_DoesNotExposeInternalDetails(t *testing.T) { + tests := []struct { + err error + name string + shouldNotContain []string + }{ + { + name: "JWT validation error should not expose stack trace", + err: errors.New("failed to parse token: jwt: token signature is invalid: crypto/rsa: verification error"), + shouldNotContain: []string{"crypto/rsa", "verification error", "stack trace", "parse"}, + }, + { + name: "Redis error should not expose connection details", + err: errors.New("redis connection failed: dial tcp 127.0.0.1:6379: connect: connection refused"), + shouldNotContain: []string{"127.0.0.1", "6379", "dial tcp", "connection refused"}, + }, + { + name: "OAuth error should not expose API keys", + err: errors.New("oauth exchange failed: invalid_client: client_id abc123xyz does not match"), + shouldNotContain: []string{"abc123xyz", "client_id", "invalid_client"}, + }, + { + name: "SQL error should not expose query details", + err: errors.New("database error: duplicate key value violates unique constraint \"users_email_key\""), + shouldNotContain: []string{"duplicate key", "unique constraint", "users_email_key"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + + WriteErrorWithLog(w, tt.err, http.StatusInternalServerError, "server_error", "An unexpected error occurred") + + var apiError APIError + if err := json.NewDecoder(w.Body).Decode(&apiError); err != nil { + t.Fatalf("Failed to decode error response: %v", err) + } + + responseText := apiError.Code + " " + apiError.Description + + for _, forbidden := range tt.shouldNotContain { + if containsSubstring(responseText, forbidden) { + t.Errorf("Response contains internal detail '%s': %s", forbidden, responseText) + } + } + + if apiError.Code != "server_error" { + t.Errorf("Expected generic error code 'server_error', got '%s'", apiError.Code) + } + + if apiError.Description != "An unexpected error occurred" { + t.Errorf("Expected generic description, got '%s'", apiError.Description) + } + }) + } +} + +func TestGetErrorResponse(t *testing.T) { + tests := []struct { + err error + fallbackCode string + fallbackDescription string + fallbackStatus int + name string + wantCode string + wantDescription string + wantStatus int + }{ + { + name: "known error returns mapped response", + err: pkgerrors.ErrInvalidAccessTokenClaims, + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "fallback", + fallbackDescription: "Fallback message", + wantStatus: http.StatusUnauthorized, + wantCode: "invalid_access_token", + wantDescription: "Invalid authentication token", + }, + { + name: "unknown error returns fallback response", + err: errors.New("completely unknown error"), + fallbackStatus: http.StatusInternalServerError, + fallbackCode: "server_error", + fallbackDescription: "Something went wrong", + wantStatus: http.StatusInternalServerError, + wantCode: "server_error", + wantDescription: "Something went wrong", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + response := getErrorResponse(tt.err, tt.fallbackStatus, tt.fallbackCode, tt.fallbackDescription) + + if response.StatusCode != tt.wantStatus { + t.Errorf("StatusCode = %v, want %v", response.StatusCode, tt.wantStatus) + } + + if response.Code != tt.wantCode { + t.Errorf("Code = %v, want %v", response.Code, tt.wantCode) + } + + if response.Description != tt.wantDescription { + t.Errorf("Description = %v, want %v", response.Description, tt.wantDescription) + } + }) + } +} + +func containsSubstring(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && stringContains(s, substr)) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}