diff --git a/.env.example b/.env.example index 100b0e9df..4f1c3c31b 100644 --- a/.env.example +++ b/.env.example @@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS= TINYAUTH_LDAP_BINDDN= # Bind password for LDAP authentication. TINYAUTH_LDAP_BINDPASSWORD= +# Path to the Bind password. +TINYAUTH_LDAP_BINDPASSWORDFILE= # Base DN for LDAP searches. TINYAUTH_LDAP_BASEDN= # Allow insecure LDAP connections. diff --git a/frontend/src/lib/hooks/redirect-uri.ts b/frontend/src/lib/hooks/redirect-uri.ts index 38e8b5c53..aeeae0c54 100644 --- a/frontend/src/lib/hooks/redirect-uri.ts +++ b/frontend/src/lib/hooks/redirect-uri.ts @@ -15,7 +15,7 @@ export const useRedirectUri = ( let isAllowedProto = false; let isHttpsDowngrade = false; - if (!redirect_uri) { + if (redirect_uri === undefined) { return { valid: isValid, trusted: isTrusted, diff --git a/gen/sqlc-wrapper/sqlc_wrapper.go b/gen/sqlc-wrapper/sqlc_wrapper.go index a7a75eb46..79576ec61 100644 --- a/gen/sqlc-wrapper/sqlc_wrapper.go +++ b/gen/sqlc-wrapper/sqlc_wrapper.go @@ -67,15 +67,24 @@ func run() error { Overlay: map[string][]byte{outPath: stub}, } - driverTypePkg, err := loadOnePkg(cfg, *driverPkg) + repoPkgPath := parentPkg(*driverPkg) + + pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath) + if err != nil { - return fmt.Errorf("load driver package: %w", err) + return fmt.Errorf("load packages: %w", err) } - repoPkgPath := parentPkg(*driverPkg) - repoTypePkg, err := loadOnePkg(cfg, repoPkgPath) - if err != nil { - return fmt.Errorf("load repo package: %w", err) + driverTypePkg, ok := pkgs[*driverPkg] + + if !ok { + return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg) + } + + repoTypePkg, ok := pkgs[repoPkgPath] + + if !ok { + return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath) } if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { @@ -106,25 +115,25 @@ func run() error { return nil } -// loadOnePkg loads a single package via cfg and returns its *types.Package, -// or an error if the package fails to load or has type errors. -func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { - pkgs, err := packages.Load(cfg, importPath) +// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package, +// or an error if any package fails to load or has type errors. +func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) { + pkgs, err := packages.Load(cfg, importPaths...) if err != nil { - return nil, fmt.Errorf("load %s: %w", importPath, err) - } - if len(pkgs) != 1 { - return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) - } - pkg := pkgs[0] - if len(pkg.Errors) > 0 { - msgs := make([]string, len(pkg.Errors)) - for i, e := range pkg.Errors { - msgs[i] = e.Error() + return nil, fmt.Errorf("load %v: %w", importPaths, err) + } + out := make(map[string]*types.Package) + for _, pkg := range pkgs { + if len(pkg.Errors) > 0 { + msgs := make([]string, len(pkg.Errors)) + for i, e := range pkg.Errors { + msgs[i] = e.Error() + } + return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n ")) } - return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) + out[pkg.PkgPath] = pkg.Types } - return pkg.Types, nil + return out, nil } // parentPkg returns the parent import path (everything before the last /). diff --git a/internal/assets/migrations/postgres/000003_oidc_consent.up.sql b/internal/assets/migrations/postgres/000003_oidc_consent.up.sql new file mode 100644 index 000000000..6aa84ed58 --- /dev/null +++ b/internal/assets/migrations/postgres/000003_oidc_consent.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "oidc_consent" ( + "uuid" TEXT NOT NULL UNIQUE PRIMARY KEY, + "client_id" TEXT NOT NULL, + "scopes" TEXT NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/internal/assets/migrations/postgres/000003_oidc_consnet.down.sql b/internal/assets/migrations/postgres/000003_oidc_consnet.down.sql new file mode 100644 index 000000000..2dae02be6 --- /dev/null +++ b/internal/assets/migrations/postgres/000003_oidc_consnet.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "oidc_consent"; diff --git a/internal/assets/migrations/sqlite/000011_oidc_consent.down.sql b/internal/assets/migrations/sqlite/000011_oidc_consent.down.sql new file mode 100644 index 000000000..2dae02be6 --- /dev/null +++ b/internal/assets/migrations/sqlite/000011_oidc_consent.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS "oidc_consent"; diff --git a/internal/assets/migrations/sqlite/000011_oidc_consent.up.sql b/internal/assets/migrations/sqlite/000011_oidc_consent.up.sql new file mode 100644 index 000000000..0fc41cf88 --- /dev/null +++ b/internal/assets/migrations/sqlite/000011_oidc_consent.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS "oidc_consent" ( + "uuid" TEXT NOT NULL UNIQUE PRIMARY KEY, + "client_id" TEXT NOT NULL, + "scopes" TEXT NOT NULL, + "created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 50aa49ef2..2d102af36 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -48,6 +48,7 @@ type Services struct { type BootstrapApp struct { config model.Config runtime model.RuntimeConfig + helpers model.RuntimeHelpers services Services log *logger.Logger ctx context.Context @@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error { cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) - app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) - app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId) // database store, err := app.SetupStore() @@ -291,6 +291,17 @@ func (app *BootstrapApp) Setup() error { app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) } + // runtime helpers + app.helpers.GetCookieDomain = app.getCookieDomain + + err = app.dig.Provide(func() *model.RuntimeHelpers { + return &app.helpers + }) + + if err != nil { + return fmt.Errorf("failed to provide runtime helpers to container: %w", err) + } + // setup router err = app.setupRouter() diff --git a/internal/bootstrap/app_helpers.go b/internal/bootstrap/app_helpers.go new file mode 100644 index 000000000..4be947530 --- /dev/null +++ b/internal/bootstrap/app_helpers.go @@ -0,0 +1,55 @@ +package bootstrap + +import ( + "context" + "errors" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/utils" +) + +// Not really the best place for the helpers to be but it works because bootstrap app provides +// them with everything they need + +func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) { + cookieDomain := app.runtime.CookieDomain + + if app.isTailscaleRequest(ctx, ip) { + if app.services.tailscaleService == nil { + return "", errors.New("tailscale service is not configured") + } + + tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname())) + + if err != nil { + return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err) + } + + cookieDomain = tsCookieDomain + } + + if app.config.Auth.SubdomainsEnabled { + cookieDomain = "." + cookieDomain + } + + return cookieDomain, nil +} + +func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool { + if app.services.tailscaleService == nil { + return false + } + + whois, err := app.services.tailscaleService.Whois(ctx, ip) + + if err != nil { + app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err) + return false + } + + if whois == nil { + return false + } + + return true +} diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index e01bf480f..941fca118 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -28,6 +28,7 @@ type OAuthController struct { config *model.Config runtime *model.RuntimeConfig auth *service.AuthService + helpers *model.RuntimeHelpers } type OAuthControllerInput struct { @@ -36,6 +37,7 @@ type OAuthControllerInput struct { Log *logger.Logger Config *model.Config RuntimeConfig *model.RuntimeConfig + Helpers *model.RuntimeHelpers RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` AuthService *service.AuthService } @@ -46,6 +48,7 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController { config: i.Config, runtime: i.RuntimeConfig, auth: i.AuthService, + helpers: i.Helpers, } oauthGroup := i.RouterGroup.Group("/oauth") @@ -110,7 +113,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) + cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) + + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain") + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true) c.JSON(200, gin.H{ "status": 200, @@ -140,7 +154,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) + cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP()) + + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) @@ -257,7 +279,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { controller.log.App.Debug().Msg("Creating session cookie for user") - cookie, err := controller.auth.CreateSession(c, sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to create session cookie") diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index c40499532..0da8d66da 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -1,6 +1,7 @@ package controller import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -34,6 +35,8 @@ type OIDCController struct { log *logger.Logger oidc *service.OIDCService runtime *model.RuntimeConfig + helpers *model.RuntimeHelpers + config *model.Config } type AuthorizeCallback struct { @@ -90,6 +93,8 @@ type OIDCControllerInput struct { RuntimeConfig *model.RuntimeConfig RouterGroup *gin.RouterGroup `name:"apiRouterGroup"` MainRouter *gin.RouterGroup `name:"mainRouterGroup"` + Helpers *model.RuntimeHelpers + Config *model.Config } func NewOIDCController(i OIDCControllerInput) *OIDCController { @@ -97,6 +102,8 @@ func NewOIDCController(i OIDCControllerInput) *OIDCController { log: i.Log, oidc: i.OIDCService, runtime: i.RuntimeConfig, + helpers: i.Helpers, + config: i.Config, } i.MainRouter.POST("/authorize", controller.authorize) @@ -219,6 +226,25 @@ func (controller *OIDCController) authorize(c *gin.Context) { values.OIDCPrompt = service.OIDCPromptNone } + // If no prompt is already set, we can check if we can/should skip it based on the cookie + if values.OIDCPrompt == "" { + consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName) + + if err == nil { + consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie) + + if err == nil && consentEntry != nil { + if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope { + values.OIDCPrompt = service.OIDCPromptNone + } + } else { + if !errors.Is(err, sql.ErrNoRows) { + controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie") + } + } + } + } + if req.MaxAge != "" && userContext != nil { maxAge, err := strconv.Atoi(req.MaxAge) if err != nil { @@ -361,6 +387,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) { return } + // Just before returning let's set the consent cookie + consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope) + + // If we fail to create the consent entry, we don't want to block the authorization flow, + // but we log the error and move on without setting the cookie + if err == nil { + cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP()) + + if err == nil { + cookie := &http.Cookie{ + Name: controller.runtime.ConsentCookieName, + Value: consnetUUID, + Path: "/", + Domain: cookieDomain, + Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year + Secure: controller.config.Auth.SecureCookie, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + http.SetCookie(c.Writer, cookie) + } else { + controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie") + } + } else { + controller.log.App.Error().Err(err).Msg("Failed to create consent entry") + } + c.JSON(200, gin.H{ "status": 200, "redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()), diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index b22ddc547..b7e8370a7 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -29,6 +29,8 @@ func TestOIDCController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) + helpers := test.CreateTestHelpers() + ctx := context.TODO() dg := ding.New(ctx) @@ -862,6 +864,8 @@ func TestOIDCController(t *testing.T) { RuntimeConfig: &runtime, RouterGroup: group, MainRouter: &router.RouterGroup, + Helpers: helpers, + Config: &cfg, }) recorder := httptest.NewRecorder() diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index faa9934b4..9dc5a8e41 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -26,6 +26,8 @@ func TestProxyController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) + helpers := test.CreateTestHelpers() + const browserUserAgent = ` Mozilla/5.0 (Linux; Android 8.0.0; SM-G955U Build/R16NW) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Mobile Safari/537.36` @@ -719,6 +721,7 @@ func TestProxyController(t *testing.T) { OAuthBroker: broker, Tailscale: nil, PolicyEngine: policyEngine, + Helpers: helpers, }) for _, test := range tests { diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index f17b7d795..670dc8f34 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -155,7 +155,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { Email: email, Provider: "local", TotpPending: true, - }) + }, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") @@ -200,7 +200,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } } - cookie, err := controller.auth.CreateSession(c, sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") @@ -251,7 +251,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { return } - cookie, err := controller.auth.DeleteSession(c, uuid) + cookie, err := controller.auth.DeleteSession(c, uuid, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Msg("Error deleting session on logout") @@ -355,7 +355,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { uuid, err := c.Cookie(controller.runtime.SessionCookieName) if err == nil { - _, err = controller.auth.DeleteSession(c, uuid) + _, err = controller.auth.DeleteSession(c, uuid, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") } @@ -379,7 +379,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie.Email = user.Attributes.Email } - cookie, err := controller.auth.CreateSession(c, sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") @@ -429,7 +429,7 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) { Provider: "tailscale", } - cookie, err := controller.auth.CreateSession(c, sessionCookie) + cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP()) if err != nil { controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful Tailscale login") diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4f081b9b0..82f779ae3 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -28,6 +28,8 @@ func TestUserController(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) + helpers := test.CreateTestHelpers() + totpCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ Authenticated: false, @@ -553,6 +555,7 @@ func TestUserController(t *testing.T) { OAuthBroker: broker, Tailscale: nil, PolicyEngine: policyEngine, + Helpers: helpers, }) beforeEach := func() { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 62c64b90d..cb0b78af6 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -211,12 +211,12 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri } if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) { - m.auth.DeleteSession(ctx, uuid) + m.auth.DeleteSession(ctx, uuid, ip) return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) } } - cookie, err := m.auth.RefreshSession(ctx, uuid) + cookie, err := m.auth.RefreshSession(ctx, uuid, ip) if err != nil { return nil, nil, fmt.Errorf("error refreshing session: %w", err) diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 8a89e4190..2f841b489 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -26,6 +26,8 @@ func TestContextMiddleware(t *testing.T) { cfg, runtime := test.CreateTestConfigs(t) + helpers := test.CreateTestHelpers() + basicAuthHeader := func(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } @@ -275,6 +277,7 @@ func TestContextMiddleware(t *testing.T) { OAuthBroker: broker, Tailscale: nil, PolicyEngine: policyEngine, + Helpers: helpers, }) contextMiddleware := NewContextMiddleware(ContextMiddlewareInput{ diff --git a/internal/model/constants.go b/internal/model/constants.go index d5885dcf7..35ce2813e 100644 --- a/internal/model/constants.go +++ b/internal/model/constants.go @@ -18,8 +18,7 @@ var OverrideProviders = map[string]string{ } const SessionCookieName = "tinyauth-session" -const CSRFCookieName = "tinyauth-csrf" -const RedirectCookieName = "tinyauth-redirect" const OAuthSessionCookieName = "tinyauth-oauth" +const ConsentCookieName = "tinyauth-consent" const GracefulShutdownTimeout = 5 // seconds diff --git a/internal/model/runtime.go b/internal/model/runtime.go index 0df999015..ca717b0ee 100644 --- a/internal/model/runtime.go +++ b/internal/model/runtime.go @@ -1,13 +1,14 @@ package model +import "context" + type RuntimeConfig struct { AppURL string UUID string CookieDomain string SessionCookieName string - CSRFCookieName string - RedirectCookieName string OAuthSessionCookieName string + ConsentCookieName string LocalUsers []LocalUser OAuthProviders map[string]OAuthServiceConfig OAuthWhitelist []string @@ -15,6 +16,10 @@ type RuntimeConfig struct { TrustedDomains []string } +type RuntimeHelpers struct { + GetCookieDomain func(ctx context.Context, ip string) (string, error) +} + type Provider struct { Name string `json:"name"` ID string `json:"id"` diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go index 558ed234f..373d68f5e 100644 --- a/internal/repository/memory/memory_test.go +++ b/internal/repository/memory/memory_test.go @@ -277,6 +277,78 @@ func TestMemoryStore(t *testing.T) { assert.NoError(t, err) }, }, + { + description: "Create and get OIDC consent", + run: func(t *testing.T, s repository.Store) { + consent, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{ + UUID: "uuid-1", + ClientID: "client-1", + Scopes: "openid profile", + }) + require.NoError(t, err) + assert.Equal(t, "uuid-1", consent.UUID) + assert.Equal(t, "client-1", consent.ClientID) + assert.Equal(t, "openid profile", consent.Scopes) + + got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, consent, got) + }, + }, + { + description: "Get OIDC consent by UUID not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.GetOIDCConsentByUUID(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Create OIDC consent unique UUID constraint", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"}) + require.NoError(t, err) + + _, err = s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-2", Scopes: "profile"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_consent.uuid") + }, + }, + { + description: "Update OIDC consent", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"}) + require.NoError(t, err) + + updated, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{ + UUID: "uuid-1", + Scopes: "profile email", + }) + require.NoError(t, err) + assert.Equal(t, "profile email", updated.Scopes) + + got, err := s.GetOIDCConsentByUUID(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, updated, got) + }, + }, + { + description: "Update OIDC consent not found", + run: func(t *testing.T, s repository.Store) { + _, err := s.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{UUID: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, + { + description: "Delete OIDC consent by UUID", + run: func(t *testing.T, s repository.Store) { + _, err := s.CreateOIDCConsent(ctx, repository.CreateOIDCConsentParams{UUID: "uuid-1", ClientID: "client-1", Scopes: "openid"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOIDCConsentByUUID(ctx, "uuid-1")) + + _, err = s.GetOIDCConsentByUUID(ctx, "uuid-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + }, + }, } for _, test := range tests { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go index 1ee81c8bf..70728978e 100644 --- a/internal/repository/memory/oidc_queries.go +++ b/internal/repository/memory/oidc_queries.go @@ -94,3 +94,47 @@ func (s *Store) DeleteExpiredOIDCSessions(_ context.Context, arg repository.Dele } return nil } + +func (s *Store) CreateOIDCConsent(_ context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.oidcConsent[arg.UUID]; ok { + return repository.OidcConsent{}, fmt.Errorf("UNIQUE constraint failed: oidc_consent.uuid") + } + consent := repository.OidcConsent{ + UUID: arg.UUID, + ClientID: arg.ClientID, + Scopes: arg.Scopes, + } + s.oidcConsent[arg.UUID] = consent + return consent, nil +} + +func (s *Store) GetOIDCConsentByUUID(_ context.Context, uuid string) (repository.OidcConsent, error) { + s.mu.RLock() + defer s.mu.RUnlock() + consent, ok := s.oidcConsent[uuid] + if !ok { + return repository.OidcConsent{}, repository.ErrNotFound + } + return consent, nil +} + +func (s *Store) UpdateOIDCConsent(_ context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) { + s.mu.Lock() + defer s.mu.Unlock() + consent, ok := s.oidcConsent[arg.UUID] + if !ok { + return repository.OidcConsent{}, repository.ErrNotFound + } + consent.Scopes = arg.Scopes + s.oidcConsent[arg.UUID] = consent + return consent, nil +} + +func (s *Store) DeleteOIDCConsentByUUID(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcConsent, uuid) + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go index 684ddeb3e..ec7fd8db8 100644 --- a/internal/repository/memory/store.go +++ b/internal/repository/memory/store.go @@ -12,6 +12,7 @@ type Store struct { mu sync.RWMutex sessions map[string]repository.Session oidcSessions map[string]repository.OidcSession + oidcConsent map[string]repository.OidcConsent } // New returns a new empty in-memory Store. @@ -19,5 +20,6 @@ func New() repository.Store { return &Store{ sessions: make(map[string]repository.Session), oidcSessions: make(map[string]repository.OidcSession), + oidcConsent: make(map[string]repository.OidcConsent), } } diff --git a/internal/repository/models.go b/internal/repository/models.go index 39538a000..1d77a5fe8 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,8 +1,18 @@ package repository +import "time" + // Shared model and parameter types for all storage drivers. // sqlc-generated driver packages use these via the conversion layer in their store.go. +type OidcConsent struct { + UUID string + ClientID string + Scopes string + CreatedAt time.Time + UpdatedAt time.Time +} + type Session struct { UUID string Username string @@ -84,3 +94,14 @@ type DeleteExpiredOIDCSessionsParams struct { TokenExpiresAt int64 RefreshTokenExpiresAt int64 } + +type CreateOIDCConsentParams struct { + UUID string + ClientID string + Scopes string +} + +type UpdateOIDCConsentParams struct { + Scopes string + UUID string +} diff --git a/internal/repository/postgres/models.go b/internal/repository/postgres/models.go index f957e1fde..a214908d0 100644 --- a/internal/repository/postgres/models.go +++ b/internal/repository/postgres/models.go @@ -4,6 +4,18 @@ package postgres +import ( + "time" +) + +type OidcConsent struct { + UUID string + ClientID string + Scopes string + CreatedAt time.Time + UpdatedAt time.Time +} + type OidcSession struct { Sub string AccessTokenHash string diff --git a/internal/repository/postgres/oidc_queries.sql.go b/internal/repository/postgres/oidc_queries.sql.go index b5b9789c9..363dacb25 100644 --- a/internal/repository/postgres/oidc_queries.sql.go +++ b/internal/repository/postgres/oidc_queries.sql.go @@ -9,6 +9,36 @@ import ( "context" ) +const createOIDCConsent = `-- name: CreateOIDCConsent :one +INSERT INTO "oidc_consent" ( + "uuid", + "client_id", + "scopes" +) VALUES ( + $1, $2, $3 +) +RETURNING uuid, client_id, scopes, created_at, updated_at +` + +type CreateOIDCConsentParams struct { + UUID string + ClientID string + Scopes string +} + +func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const createOIDCSession = `-- name: CreateOIDCSession :one INSERT INTO "oidc_sessions" ( "sub", @@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir return err } +const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec +DELETE FROM "oidc_consent" +WHERE "uuid" = $1 +` + +func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid) + return err +} + const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec DELETE FROM "oidc_sessions" WHERE "sub" = $1 @@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error return err } +const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one +SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent" +WHERE "uuid" = $1 +` + +func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = $1 @@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess return i, err } +const updateOIDCConsent = `-- name: UpdateOIDCConsent :one +UPDATE "oidc_consent" SET + "scopes" = $1, + "updated_at" = CURRENT_TIMESTAMP +WHERE "uuid" = $2 +RETURNING uuid, client_id, scopes, created_at, updated_at +` + +type UpdateOIDCConsentParams struct { + Scopes string + UUID string +} + +func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updateOIDCSession = `-- name: UpdateOIDCSession :one UPDATE "oidc_sessions" SET "access_token_hash" = $1, diff --git a/internal/repository/postgres/store.go b/internal/repository/postgres/store.go index b3e79c803..719b3118d 100644 --- a/internal/repository/postgres/store.go +++ b/internal/repository/postgres/store.go @@ -32,6 +32,14 @@ func mapErr(err error) error { return err } +func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) { + r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg)) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { @@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } +func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error { + return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid)) +} + func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } @@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } +func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) { + r, err := s.q.GetOIDCConsentByUUID(ctx, uuid) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { @@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } +func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) { + r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg)) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index 2ced8a2b3..ca4a524cb 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -4,6 +4,18 @@ package sqlite +import ( + "time" +) + +type OidcConsent struct { + UUID string + ClientID string + Scopes string + CreatedAt time.Time + UpdatedAt time.Time +} + type OidcSession struct { Sub string AccessTokenHash string diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index a5aa08a8f..7f7c267bb 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -9,6 +9,36 @@ import ( "context" ) +const createOIDCConsent = `-- name: CreateOIDCConsent :one +INSERT INTO "oidc_consent" ( + "uuid", + "client_id", + "scopes" +) VALUES ( + ?, ?, ? +) +RETURNING uuid, client_id, scopes, created_at, updated_at +` + +type CreateOIDCConsentParams struct { + UUID string + ClientID string + Scopes string +} + +func (q *Queries) CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, createOIDCConsent, arg.UUID, arg.ClientID, arg.Scopes) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const createOIDCSession = `-- name: CreateOIDCSession :one INSERT INTO "oidc_sessions" ( "sub", @@ -80,6 +110,16 @@ func (q *Queries) DeleteExpiredOIDCSessions(ctx context.Context, arg DeleteExpir return err } +const deleteOIDCConsentByUUID = `-- name: DeleteOIDCConsentByUUID :exec +DELETE FROM "oidc_consent" +WHERE "uuid" = ? +` + +func (q *Queries) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error { + _, err := q.db.ExecContext(ctx, deleteOIDCConsentByUUID, uuid) + return err +} + const deleteOIDCSessionBySub = `-- name: DeleteOIDCSessionBySub :exec DELETE FROM "oidc_sessions" WHERE "sub" = ? @@ -90,6 +130,24 @@ func (q *Queries) DeleteOIDCSessionBySub(ctx context.Context, sub string) error return err } +const getOIDCConsentByUUID = `-- name: GetOIDCConsentByUUID :one +SELECT uuid, client_id, scopes, created_at, updated_at FROM "oidc_consent" +WHERE "uuid" = ? +` + +func (q *Queries) GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, getOIDCConsentByUUID, uuid) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getOIDCSessionByAccessTokenHash = `-- name: GetOIDCSessionByAccessTokenHash :one SELECT sub, access_token_hash, refresh_token_hash, scope, client_id, token_expires_at, refresh_token_expires_at, nonce, userinfo_json FROM "oidc_sessions" WHERE "access_token_hash" = ? @@ -156,6 +214,32 @@ func (q *Queries) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSess return i, err } +const updateOIDCConsent = `-- name: UpdateOIDCConsent :one +UPDATE "oidc_consent" SET + "scopes" = ?, + "updated_at" = CURRENT_TIMESTAMP +WHERE "uuid" = ? +RETURNING uuid, client_id, scopes, created_at, updated_at +` + +type UpdateOIDCConsentParams struct { + Scopes string + UUID string +} + +func (q *Queries) UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) { + row := q.db.QueryRowContext(ctx, updateOIDCConsent, arg.Scopes, arg.UUID) + var i OidcConsent + err := row.Scan( + &i.UUID, + &i.ClientID, + &i.Scopes, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updateOIDCSession = `-- name: UpdateOIDCSession :one UPDATE "oidc_sessions" SET "access_token_hash" = ?, diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index a567c8718..a5ba21421 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -32,6 +32,14 @@ func mapErr(err error) error { return err } +func (s *Store) CreateOIDCConsent(ctx context.Context, arg repository.CreateOIDCConsentParams) (repository.OidcConsent, error) { + r, err := s.q.CreateOIDCConsent(ctx, CreateOIDCConsentParams(arg)) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) CreateOIDCSession(ctx context.Context, arg repository.CreateOIDCSessionParams) (repository.OidcSession, error) { r, err := s.q.CreateOIDCSession(ctx, CreateOIDCSessionParams(arg)) if err != nil { @@ -56,6 +64,10 @@ func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } +func (s *Store) DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error { + return mapErr(s.q.DeleteOIDCConsentByUUID(ctx, uuid)) +} + func (s *Store) DeleteOIDCSessionBySub(ctx context.Context, sub string) error { return mapErr(s.q.DeleteOIDCSessionBySub(ctx, sub)) } @@ -64,6 +76,14 @@ func (s *Store) DeleteSession(ctx context.Context, uuid string) error { return mapErr(s.q.DeleteSession(ctx, uuid)) } +func (s *Store) GetOIDCConsentByUUID(ctx context.Context, uuid string) (repository.OidcConsent, error) { + r, err := s.q.GetOIDCConsentByUUID(ctx, uuid) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) GetOIDCSessionByAccessTokenHash(ctx context.Context, accessTokenHash string) (repository.OidcSession, error) { r, err := s.q.GetOIDCSessionByAccessTokenHash(ctx, accessTokenHash) if err != nil { @@ -96,6 +116,14 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session return repository.Session(r), nil } +func (s *Store) UpdateOIDCConsent(ctx context.Context, arg repository.UpdateOIDCConsentParams) (repository.OidcConsent, error) { + r, err := s.q.UpdateOIDCConsent(ctx, UpdateOIDCConsentParams(arg)) + if err != nil { + return repository.OidcConsent{}, mapErr(err) + } + return repository.OidcConsent(r), nil +} + func (s *Store) UpdateOIDCSession(ctx context.Context, arg repository.UpdateOIDCSessionParams) (repository.OidcSession, error) { r, err := s.q.UpdateOIDCSession(ctx, UpdateOIDCSessionParams(arg)) if err != nil { diff --git a/internal/repository/store.go b/internal/repository/store.go index abd70bd34..a36f12ee9 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -27,4 +27,10 @@ type Store interface { GetOIDCSessionByRefreshTokenHash(ctx context.Context, refreshTokenHash string) (OidcSession, error) GetOIDCSessionBySub(ctx context.Context, sub string) (OidcSession, error) UpdateOIDCSession(ctx context.Context, arg UpdateOIDCSessionParams) (OidcSession, error) + + // OIDC consents + CreateOIDCConsent(ctx context.Context, arg CreateOIDCConsentParams) (OidcConsent, error) + DeleteOIDCConsentByUUID(ctx context.Context, uuid string) error + GetOIDCConsentByUUID(ctx context.Context, uuid string) (OidcConsent, error) + UpdateOIDCConsent(ctx context.Context, arg UpdateOIDCConsentParams) (OidcConsent, error) } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 5e79ff752..c6bb94373 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -62,6 +62,7 @@ type AuthService struct { config *model.Config runtime *model.RuntimeConfig ctx context.Context + helpers *model.RuntimeHelpers ldap *LdapService queries repository.Store @@ -99,6 +100,7 @@ type AuthServiceInput struct { OAuthBroker *OAuthBrokerService Tailscale *TailscaleService `optional:"true"` PolicyEngine *PolicyEngine + Helpers *model.RuntimeHelpers } func NewAuthService(i AuthServiceInput) *AuthService { @@ -112,6 +114,7 @@ func NewAuthService(i AuthServiceInput) *AuthService { oauthBroker: i.OAuthBroker, tailscale: i.Tailscale, policyEngine: i.PolicyEngine, + helpers: i.Helpers, } // get the max login limits based on the number of users and the configured max retries @@ -339,7 +342,7 @@ func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool }) } -func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { +func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session, ip string) (*http.Cookie, error) { if data.Provider == "tailscale" && auth.tailscale == nil { return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") } @@ -380,33 +383,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess return nil, fmt.Errorf("failed to create session entry: %w", err) } - if data.Provider == "tailscale" { - auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname") + cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip) - tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname())) - - if err != nil { - return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err) - } - - return &http.Cookie{ - Name: auth.runtime.SessionCookieName, - Value: session.UUID, - Path: "/", - Domain: fmt.Sprintf(".%s", tsCookieDomain), - Expires: expiresAt, - MaxAge: int(time.Until(expiresAt).Seconds()), - Secure: auth.config.Auth.SecureCookie, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }, nil + if err != nil { + return nil, fmt.Errorf("failed to determine cookie domain: %w", err) } return &http.Cookie{ Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: cookieDomain, Expires: expiresAt, MaxAge: int(time.Until(expiresAt).Seconds()), Secure: auth.config.Auth.SecureCookie, @@ -415,13 +402,17 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess }, nil } -func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { +func (auth *AuthService) RefreshSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { session, err := auth.queries.GetSession(ctx, uuid) if err != nil { return nil, fmt.Errorf("failed to retrieve session: %w", err) } + if session.Provider == "tailscale" && auth.tailscale == nil { + return nil, fmt.Errorf("tailscale service not configured, cannot create session for tailscale user") + } + currentTime := time.Now().Unix() var refreshThreshold int64 @@ -455,11 +446,17 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http return nil, fmt.Errorf("failed to update session expiry: %w", err) } + cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip) + + if err != nil { + return nil, fmt.Errorf("failed to determine cookie domain: %w", err) + } + return &http.Cookie{ Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: cookieDomain, Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), MaxAge: int(newExpiry - currentTime), Secure: auth.config.Auth.SecureCookie, @@ -469,18 +466,24 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http } -func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { +func (auth *AuthService) DeleteSession(ctx context.Context, uuid string, ip string) (*http.Cookie, error) { err := auth.queries.DeleteSession(ctx, uuid) if err != nil { auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } + cookieDomain, err := auth.helpers.GetCookieDomain(ctx, ip) + + if err != nil { + return nil, fmt.Errorf("failed to determine cookie domain: %w", err) + } + return &http.Cookie{ Name: auth.runtime.SessionCookieName, Value: "", Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: cookieDomain, Expires: time.Now(), MaxAge: -1, Secure: auth.config.Auth.SecureCookie, diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index a3a02400b..9c88c7ec4 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -22,6 +22,7 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" @@ -969,3 +970,47 @@ func (service *OIDCService) GetPrompt(prompt string) []OIDCPrompt { return parsedPromps } + +func (service *OIDCService) CreateConsentEntry(ctx context.Context, clientId string, scope string) (string, error) { + u := uuid.New() + + entry := repository.CreateOIDCConsentParams{ + UUID: u.String(), + ClientID: clientId, + Scopes: scope, + } + + _, err := service.queries.CreateOIDCConsent(ctx, entry) + + if err != nil { + return "", err + } + + return entry.UUID, nil +} + +func (service *OIDCService) GetConsentEntry(ctx context.Context, uuid string) (*repository.OidcConsent, error) { + entry, err := service.queries.GetOIDCConsentByUUID(ctx, uuid) + + if err != nil { + if errors.Is(err, repository.ErrNotFound) { + return nil, nil + } + return nil, err + } + + return &entry, nil +} + +func (service *OIDCService) DeleteConsentEntry(ctx context.Context, uuid string) error { + return service.queries.DeleteOIDCConsentByUUID(ctx, uuid) +} + +func (service *OIDCService) UpdateConsentEntry(ctx context.Context, uuid string, scopes string) error { + _, err := service.queries.UpdateOIDCConsent(ctx, repository.UpdateOIDCConsentParams{ + UUID: uuid, + Scopes: scopes, + }) + + return err +} diff --git a/internal/test/test.go b/internal/test/test.go index 676501a49..d7ee6d6e3 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,6 +1,7 @@ package test import ( + "context" "path/filepath" "testing" @@ -173,3 +174,11 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { return config, runtime } + +func CreateTestHelpers() *model.RuntimeHelpers { + return &model.RuntimeHelpers{ + GetCookieDomain: func(ctx context.Context, ip string) (string, error) { + return "example.com", nil + }, + } +} diff --git a/sql/postgres/oidc_queries.sql b/sql/postgres/oidc_queries.sql index 3cd5ff99f..4442ef337 100644 --- a/sql/postgres/oidc_queries.sql +++ b/sql/postgres/oidc_queries.sql @@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET "userinfo_json" = $8 WHERE "sub" = $9 RETURNING *; + +-- name: CreateOIDCConsent :one +INSERT INTO "oidc_consent" ( + "uuid", + "client_id", + "scopes" +) VALUES ( + $1, $2, $3 +) +RETURNING *; + +-- name: GetOIDCConsentByUUID :one +SELECT * FROM "oidc_consent" +WHERE "uuid" = $1; + +-- name: UpdateOIDCConsent :one +UPDATE "oidc_consent" SET + "scopes" = $1, + "updated_at" = CURRENT_TIMESTAMP +WHERE "uuid" = $2 +RETURNING *; + +-- name: DeleteOIDCConsentByUUID :exec +DELETE FROM "oidc_consent" +WHERE "uuid" = $1; diff --git a/sql/postgres/oidc_schemas.sql b/sql/postgres/oidc_schemas.sql index 2376c1d4a..622650235 100644 --- a/sql/postgres/oidc_schemas.sql +++ b/sql/postgres/oidc_schemas.sql @@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" ( "nonce" TEXT NOT NULL DEFAULT '', "userinfo_json" TEXT NOT NULL ); + +CREATE TABLE IF NOT EXISTS "oidc_consent" ( + "uuid" TEXT NOT NULL UNIQUE PRIMARY KEY, + "client_id" TEXT NOT NULL, + "scopes" TEXT NOT NULL, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/sql/sqlite/oidc_queries.sql b/sql/sqlite/oidc_queries.sql index 49b33cff7..5ec9ea449 100644 --- a/sql/sqlite/oidc_queries.sql +++ b/sql/sqlite/oidc_queries.sql @@ -46,3 +46,28 @@ UPDATE "oidc_sessions" SET "userinfo_json" = ? WHERE "sub" = ? RETURNING *; + +-- name: CreateOIDCConsent :one +INSERT INTO "oidc_consent" ( + "uuid", + "client_id", + "scopes" +) VALUES ( + ?, ?, ? +) +RETURNING *; + +-- name: GetOIDCConsentByUUID :one +SELECT * FROM "oidc_consent" +WHERE "uuid" = ?; + +-- name: UpdateOIDCConsent :one +UPDATE "oidc_consent" SET + "scopes" = ?, + "updated_at" = CURRENT_TIMESTAMP +WHERE "uuid" = ? +RETURNING *; + +-- name: DeleteOIDCConsentByUUID :exec +DELETE FROM "oidc_consent" +WHERE "uuid" = ?; diff --git a/sql/sqlite/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql index 5a851033a..e5d3a0d34 100644 --- a/sql/sqlite/oidc_schemas.sql +++ b/sql/sqlite/oidc_schemas.sql @@ -9,3 +9,11 @@ CREATE TABLE IF NOT EXISTS "oidc_sessions" ( "nonce" TEXT NOT NULL DEFAULT "", "userinfo_json" TEXT NOT NULL ); + +CREATE TABLE IF NOT EXISTS "oidc_consent" ( + "uuid" TEXT NOT NULL UNIQUE PRIMARY KEY, + "client_id" TEXT NOT NULL, + "scopes" TEXT NOT NULL, + "created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +);