diff --git a/frontend/src/components/layout/layout.tsx b/frontend/src/components/layout/layout.tsx index e129092e..f9a01d88 100644 --- a/frontend/src/components/layout/layout.tsx +++ b/frontend/src/components/layout/layout.tsx @@ -40,11 +40,7 @@ export const Layout = () => { setIgnoreDomainWarning(true); }, [setIgnoreDomainWarning]); - if ( - !ignoreDomainWarning && - ui.warningsEnabled && - !app.trustedDomains.includes(currentUrl) - ) { + if (!ignoreDomainWarning && ui.warningsEnabled && currentUrl !== app.appUrl) { return ( { let isValid = false; let isTrusted = false; let isAllowedProto = false; let isHttpsDowngrade = false; + let appUrlObj: URL; + + try { + appUrlObj = new URL(appUrl); + } catch { + return { + valid: isValid, + trusted: isTrusted, + allowedProto: isAllowedProto, + httpsDowngrade: isHttpsDowngrade, + }; + } + if (!redirect_uri) { return { valid: isValid, @@ -39,10 +54,7 @@ export const useRedirectUri = ( isValid = true; - if ( - url.hostname == cookieDomain || - url.hostname.endsWith(`.${cookieDomain}`) - ) { + if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) { isTrusted = true; } @@ -62,3 +74,45 @@ export const useRedirectUri = ( httpsDowngrade: isHttpsDowngrade, }; }; + +// ported from internal/controller/oauth_controller.go +const getEffectivePort = (url: URL): string => { + if (url.port) { + return url.port; + } + + if (url.protocol == "https:") { + return "443"; + } + + return "80"; +}; + +const isTrustedDomain = ( + url: URL, + appUrl: URL, + cookieDomain: string, + subdomainsEnabled: boolean, +): boolean => { + if (url.protocol != appUrl.protocol) { + return false; + } + + if (getEffectivePort(url) != getEffectivePort(appUrl)) { + return false; + } + + if (url.hostname == appUrl.hostname) { + return true; + } + + if (!subdomainsEnabled) { + return false; + } + + if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) { + return true; + } + + return false; +}; diff --git a/frontend/src/pages/continue-page.tsx b/frontend/src/pages/continue-page.tsx index 3220ac99..a4e34fc5 100644 --- a/frontend/src/pages/continue-page.tsx +++ b/frontend/src/pages/continue-page.tsx @@ -37,6 +37,8 @@ export const ContinuePage = () => { const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri( redirectUri, app.cookieDomain, + app.appUrl, + app.subdomainsEnabled, ); const urlHref = url?.href; @@ -108,7 +110,11 @@ export const ContinuePage = () => { components={{ code: , }} - values={{ cookieDomain: app.cookieDomain }} + values={{ + cookieDomain: app.subdomainsEnabled + ? `.${app.cookieDomain}` + : app.cookieDomain, + }} shouldUnescape={true} /> diff --git a/frontend/src/schemas/app-context-schema.ts b/frontend/src/schemas/app-context-schema.ts index a91dda77..f8740a70 100644 --- a/frontend/src/schemas/app-context-schema.ts +++ b/frontend/src/schemas/app-context-schema.ts @@ -24,7 +24,7 @@ const uiSchema = z.object({ const appSchema = z.object({ appUrl: z.string(), cookieDomain: z.string(), - trustedDomains: z.array(z.string()), + subdomainsEnabled: z.boolean(), }); export const appContextSchema = z.object({ diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 50aa49ef..c24638f5 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -46,18 +46,17 @@ type Services struct { } type BootstrapApp struct { - config model.Config - runtime model.RuntimeConfig - services Services - log *logger.Logger - ctx context.Context - cancel context.CancelFunc - queries repository.Store - router *gin.Engine - db *sql.DB - ding *ding.Ding - listeners []Listener - dig *dig.Container + config model.Config + runtime model.RuntimeConfig + services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries repository.Store + router *gin.Engine + db *sql.DB + ding *ding.Ding + dig *dig.Container } func NewBootstrapApp(config model.Config) *BootstrapApp { @@ -98,8 +97,7 @@ func (app *BootstrapApp) Setup() error { return fmt.Errorf("failed to parse app url: %w", err) } - app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host - app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL) + app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host) // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { @@ -144,15 +142,6 @@ func (app *BootstrapApp) Setup() error { provider.ClientSecret = secret provider.ClientSecretFile = "" - if provider.RedirectURL == "" { - provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id - } - - app.runtime.OAuthProviders[id] = provider - } - - // set presets for built-in providers - for id, provider := range app.runtime.OAuthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -160,18 +149,16 @@ func (app *BootstrapApp) Setup() error { provider.Name = utils.Capitalize(id) } } + app.runtime.OAuthProviders[id] = provider } // cookie domain - cookieDomainResolver := utils.GetCookieDomain - if !app.config.Auth.SubdomainsEnabled { - app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") - cookieDomainResolver = utils.GetStandaloneCookieDomain + app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only") } - cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) + cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled) if err != nil { return fmt.Errorf("failed to get cookie domain: %w", err) @@ -286,9 +273,43 @@ func (app *BootstrapApp) Setup() error { app.runtime.ConfiguredProviders = configuredProviders - // throw in tailscale if it's configured just before setting up the controllers - if app.services.tailscaleService != nil { - app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname()) + // if tailscale is enabled and listening, replace the app url with the tailscale hostname + if app.services.tailscaleService != nil && app.config.Tailscale.Listen { + tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname() + + // if the tailscale url is different from the app url, replace it + if tailscaleUrl != app.runtime.AppURL { + app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname") + + app.runtime.AppURL = tailscaleUrl + + // also update cookie domain + cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled) + + if err != nil { + return fmt.Errorf("failed to get cookie domain: %w", err) + } + + app.runtime.CookieDomain = cookieDomain + } + } + + // force an update of the redirect urls for all oauth providers, if they are empty + services := app.services.oauthBrokerService.GetConfiguredServices() + + for _, service := range services { + oauthService, ok := app.services.oauthBrokerService.GetService(service) + + if !ok { + return fmt.Errorf("failed to get oauth service for provider %s", service) + } + + providerConfig := oauthService.GetConfig() + + if providerConfig.RedirectURL == "" { + providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service + oauthService.UpdateConfig(providerConfig) + } } // setup router @@ -308,19 +329,19 @@ func (app *BootstrapApp) Setup() error { app.ding.Go(app.heartbeatRoutine, ding.RingMinor) } - // setup listeners - app.listeners = app.calculateListenerPolicy() + // get listener + listenerFunc, err := app.getListenerFunc() - if app.config.Server.ConcurrentListenersEnabled { - app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") + if err != nil { + return fmt.Errorf("failed to get listener function: %w", err) } - // run listeners - lec, err := app.runListeners() + // run listener + lec := make(chan error, 1) - if err != nil { - return fmt.Errorf("failed to run listeners: %w", err) - } + app.ding.Go(func(ctx context.Context) { + lec <- listenerFunc(ctx) + }, ding.RingNormal) // monitor cancellation and server errors for { diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 636840d6..703d0442 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -9,7 +9,6 @@ import ( "os" "time" - "github.com/steveiliop56/ding" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" @@ -18,14 +17,6 @@ import ( "github.com/gin-gonic/gin" ) -type Listener int - -const ( - ListenerHTTP Listener = iota - ListenerUnix - ListenerTailscale -) - func (app *BootstrapApp) setupRouter() error { // we don't want gin debug mode gin.SetMode(gin.ReleaseMode) @@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error { return nil } -func (app *BootstrapApp) runListeners() (chan error, error) { - // lec -> listener error channel - lec := make(chan error, len(app.listeners)) - - for _, listenerType := range app.listeners { - listenerFunc, err := app.listenerFromType(listenerType) - - if err != nil { - return nil, fmt.Errorf("failed to get listener function: %w", err) +// Top down +// 1. Tailscale (if tailscale.listen) +// 2. Unix socket (if server.socketPath) +// 3. HTTP - default +func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) { + if app.config.Tailscale.Listen { + if app.services.tailscaleService == nil { + return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized") } - - app.ding.Go(func(ctx context.Context) { - lec <- listenerFunc(ctx) - }, ding.RingNormal) - } - - return lec, nil -} - -// The way we calculate listeners is as follows: -// If concurrent listeners are disabled, we pick the first available listener, so: -// 1. If tailscale is enabled, we use tailscale -// 2. If socket path is configured, we use unix socket -// 3. Finally if none is configured we use http -// If concurrent listeners are enabled, we add all available listeners in the following order -func (app *BootstrapApp) calculateListenerPolicy() []Listener { - l := []Listener{} - - if !app.config.Server.ConcurrentListenersEnabled { - if app.services.tailscaleService != nil { - l = append(l, ListenerTailscale) - return l - } - - if app.config.Server.SocketPath != "" { - l = append(l, ListenerUnix) - return l - } - - l = append(l, ListenerHTTP) - return l + return app.serveTailscale, nil } if app.config.Server.SocketPath != "" { - l = append(l, ListenerUnix) - } - - if app.services.tailscaleService != nil { - l = append(l, ListenerTailscale) - } - - l = append(l, ListenerHTTP) - - return l -} - -func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) { - switch listenerType { - case ListenerHTTP: - return app.serveHTTP, nil - case ListenerUnix: return app.serveUnix, nil - case ListenerTailscale: - return app.serveTailscale, nil - default: - return nil, fmt.Errorf("invalid listener type: %d", listenerType) } + + return app.serveHTTP, nil } func (app *BootstrapApp) serveHTTP(ctx context.Context) error { address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - app.log.App.Info().Msgf("Starting server on %s", address) + app.log.App.Info().Msgf("Starting server on http://%s", address) listener, err := net.Listen("tcp", address) diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 32574c99..abfabaad 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -1,6 +1,8 @@ package controller import ( + "errors" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" "go.uber.org/dig" @@ -58,9 +60,9 @@ type ACRUI struct { } type ACRApp struct { - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` - TrustedDomains []string `json:"trustedDomains"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + SubdomainsEnabled bool `json:"subdomainsEnabled"` } type AppContextResponse struct { @@ -109,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to create user context from request") + if !errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Error().Err(err).Msg("Failed to create user context from request") + } c.JSON(200, UserContextResponse{ Status: 401, Message: "Unauthorized", @@ -160,9 +164,9 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { WarningsEnabled: controller.config.UI.WarningsEnabled, }, App: ACRApp{ - AppURL: controller.runtime.AppURL, - CookieDomain: controller.runtime.CookieDomain, - TrustedDomains: controller.runtime.TrustedDomains, + AppURL: controller.runtime.AppURL, + CookieDomain: controller.runtime.CookieDomain, + SubdomainsEnabled: controller.config.Auth.SubdomainsEnabled, }, }) } diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 708824c1..2a3bc545 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -48,9 +48,9 @@ func TestContextController(t *testing.T) { WarningsEnabled: cfg.UI.WarningsEnabled, }, App: ACRApp{ - AppURL: runtime.AppURL, - CookieDomain: runtime.CookieDomain, - TrustedDomains: runtime.TrustedDomains, + AppURL: runtime.AppURL, + CookieDomain: runtime.CookieDomain, + SubdomainsEnabled: cfg.Auth.SubdomainsEnabled, }, } bytes, err := json.Marshal(expectedAppContextResponse) diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index e01bf480..27fca206 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -12,7 +12,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" - "github.com/weppos/publicsuffix-go/publicsuffix" "go.uber.org/dig" "github.com/gin-gonic/gin" @@ -305,8 +304,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthCallbackPar } func (controller *OAuthController) getCookieDomain() string { - if controller.config.Auth.SubdomainsEnabled { - return "." + controller.runtime.CookieDomain + if !controller.config.Auth.SubdomainsEnabled { + return "" } return controller.runtime.CookieDomain } @@ -314,51 +313,53 @@ func (controller *OAuthController) getCookieDomain() string { func (controller *OAuthController) isRedirectSafe(redirectURI string) bool { u, err := url.Parse(redirectURI) - if err != nil || u.Host == "" || u.Scheme == "" { + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to parse redirect URI") return false } - for _, allowed := range controller.runtime.TrustedDomains { - tu, err := url.Parse(allowed) - if err != nil { - controller.log.App.Error().Err(err).Str("allowed", allowed).Msg("Failed to parse trusted domain") - continue - } + if u.Scheme == "" || u.Host == "" { + controller.log.App.Warn().Msg("Redirect URI has invalid scheme or host") + return false + } - if tu.Scheme != u.Scheme { - continue - } + au, err := url.Parse(controller.runtime.AppURL) - // exact match - if strings.EqualFold(u.Host, tu.Host) { - return true - } + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to parse app URL") + return false + } - // if subdomains are disabled, end here - if !controller.config.Auth.SubdomainsEnabled { - continue - } + if u.Scheme != au.Scheme { + controller.log.App.Warn().Msg("Redirect URI scheme does not match app URL scheme") + return false + } - // get the root domain (e.g. tinyauth.example.com -> example.com or - // tinyauth.sub.example.com -> sub.example.com) - _, root, ok := strings.Cut(tu.Host, ".") - if !ok { - continue + getEffectivePort := func(u *url.URL) string { + if u.Port() != "" { + return u.Port() + } + if u.Scheme == "https" { + return "443" } + return "80" + } - root = strings.ToLower(root) + if getEffectivePort(u) != getEffectivePort(au) { + controller.log.App.Warn().Msg("Redirect URI port does not match app URL port") + return false + } - // check if the root domain is in the psl - _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, root, nil) + if strings.EqualFold(u.Hostname(), au.Hostname()) { + return true + } - if err != nil { - continue - } + if !controller.config.Auth.SubdomainsEnabled { + return false + } - // subdomain match - if strings.HasSuffix(strings.ToLower(u.Host), "."+root) { - return true - } + if strings.HasSuffix(strings.ToLower(u.Hostname()), "."+strings.ToLower(controller.runtime.CookieDomain)) { + return true } return false diff --git a/internal/controller/oauth_controller_test.go b/internal/controller/oauth_controller_test.go index fb9db6db..1e3b8aec 100644 --- a/internal/controller/oauth_controller_test.go +++ b/internal/controller/oauth_controller_test.go @@ -9,7 +9,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -func TestOAuthController(t *testing.T) { +func TestOAuthControllerIsRedirectSafe(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() @@ -17,145 +17,171 @@ func TestOAuthController(t *testing.T) { type testCase struct { description string - run func(ctrl *OAuthController) - trustedDomains []string + appURL string + cookieDomain string subdomainsEnabled bool + redirectURI string + expected bool } tests := []testCase{ { - description: "Test exact match of redirect URI", - trustedDomains: []string{"https://tinyauth.example.com"}, + description: "Exact host match returns true", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://tinyauth.example.com" - assert.True(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://tinyauth.example.com", + expected: true, }, { - description: "Test subdomain match of redirect URI", - trustedDomains: []string{"https://tinyauth.example.com"}, + description: "Exact host match is case insensitive", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://sub.example.com" - assert.True(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://TinyAuth.Example.com", + expected: true, }, { - description: "Test different trusted domain", - trustedDomains: []string{"https://tinyauth.example.com", "https://tinyauth.foo.com"}, + description: "Exact host match with subdomains disabled returns true", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: false, + redirectURI: "https://tinyauth.example.com", + expected: true, + }, + { + description: "Subdomain of cookie domain returns true when subdomains enabled", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://app.foo.com" - assert.True(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://sub.example.com", + expected: true, }, { - description: "Test invalid redirect URI", - run: func(ctrl *OAuthController) { - redirectUri := "https:/malicious" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + description: "Subdomain of cookie domain is case insensitive", + appURL: "https://tinyauth.example.com", + cookieDomain: "Example.COM", + subdomainsEnabled: true, + redirectURI: "https://SUB.example.com", + expected: true, }, { - description: "Test empty redirect URI", - run: func(ctrl *OAuthController) { - redirectUri := "" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + description: "Subdomain not matching cookie domain returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: true, + redirectURI: "https://sub.evil.com", + expected: false, }, { - description: "Test redirect URI with different scheme", - trustedDomains: []string{"https://tinyauth.example.com"}, + description: "Subdomain returns false when subdomains disabled", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: false, + redirectURI: "https://sub.example.com", + expected: false, + }, + { + description: "Cookie domain itself is not a subdomain match", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "http://tinyauth.example.com" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://example.com", + expected: false, }, { - description: "Test redirect URI with different port", - trustedDomains: []string{"https://tinyauth.example.com"}, + description: "Different scheme returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://tinyauth.example.com:8080" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "http://tinyauth.example.com", + expected: false, }, { - // weird case, subdomains enabled and domain without subdomain can't happen - description: "Test with trusted domain that's in PSL when split", - trustedDomains: []string{"https://example.com"}, // will become .com which we - // obviously don't want to allow + description: "Different port returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://sub.example.com" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://tinyauth.example.com:8080", + expected: false, }, { - description: "Test subdomain redirect URI when subdomains are disabled", - trustedDomains: []string{"https://tinyauth.example.com"}, - subdomainsEnabled: false, - run: func(ctrl *OAuthController) { - redirectUri := "https://sub.tinyauth.example.com" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + description: "Empty redirect URI returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: true, + redirectURI: "", + expected: false, }, { - description: "Test domain like the .co.uk", - trustedDomains: []string{"https://example.co.uk"}, + description: "Redirect URI without host returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://sub.example.co.uk" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https:/malicious", + expected: false, }, { - description: "Test domain like the .co.uk with subdomains disabled", - trustedDomains: []string{"https://example.co.uk"}, - subdomainsEnabled: false, - run: func(ctrl *OAuthController) { - redirectUri := "https://example.co.uk" - assert.True(t, ctrl.isRedirectSafe(redirectUri)) - }, + description: "Redirect URI without scheme returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: true, + redirectURI: "tinyauth.example.com", + expected: false, + }, + { + description: "Relative redirect URI returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: true, + redirectURI: "/some/path", + expected: false, + }, + { + description: "Userinfo trick with malicious host returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", + subdomainsEnabled: true, + redirectURI: "https://malicious.example.com@evil.com", + expected: false, }, { - description: "Test caps domain", - trustedDomains: []string{"https://TINYAUTH.ExAmpLe.com"}, + description: "Unparseable redirect URI returns false", + appURL: "https://tinyauth.example.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://sUb.ExAmPle.com" - assert.True(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://exa\x7fmple.com", + expected: false, }, { - description: "Test edge case with @", - trustedDomains: []string{"https://tinyauth.example.com"}, + description: "Unparseable app URL returns false", + appURL: "https://tinyauth.\x7fexample.com", + cookieDomain: "example.com", subdomainsEnabled: true, - run: func(ctrl *OAuthController) { - redirectUri := "https://malicious.example.com@evil.com" - assert.False(t, ctrl.isRedirectSafe(redirectUri)) - }, + redirectURI: "https://tinyauth.example.com", + expected: false, }, } - // TODO: add auth service for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { router := gin.Default() group := router.Group("/api") gin.SetMode(gin.TestMode) - // overwrite the trusted domains and subdomain setting for each test case - runtime.TrustedDomains = tc.trustedDomains + + // Overwrite the app URL, cookie domain and subdomain setting for each test case + runtime.AppURL = tc.appURL + runtime.CookieDomain = tc.cookieDomain cfg.Auth.SubdomainsEnabled = tc.subdomainsEnabled + ctrl := NewOAuthController(OAuthControllerInput{ Log: log, Config: &cfg, RuntimeConfig: &runtime, RouterGroup: group, }) - tc.run(ctrl) + + assert.Equal(t, tc.expected, ctrl.isRedirectSafe(tc.redirectURI)) }) } } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index f17b7d79..ae6c23bf 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -295,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { + if errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Warn().Msg("TOTP verification attempt without user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") c.JSON(500, gin.H{ "status": 500, @@ -405,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { + if errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Warn().Msg("Tailscale login attempt without user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } controller.log.App.Error().Err(err).Msg("Failed to create user context from request") c.JSON(401, gin.H{ "status": 401, diff --git a/internal/model/config.go b/internal/model/config.go index 2de389a0..23648794 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -15,9 +15,8 @@ func NewDefaultConfiguration() *Config { Path: "./resources", }, Server: ServerConfig{ - Port: 3000, - Address: "0.0.0.0", - ConcurrentListenersEnabled: false, + Port: 3000, + Address: "0.0.0.0", }, Auth: AuthConfig{ SubdomainsEnabled: true, @@ -104,10 +103,9 @@ type ResourcesConfig struct { } type ServerConfig struct { - Port int `description:"The port on which the server listens." yaml:"port"` - Address string `description:"The address on which the server listens." yaml:"address"` - SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` - ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"` + Port int `description:"The port on which the server listens." yaml:"port"` + Address string `description:"The address on which the server listens." yaml:"address"` + SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` } type AuthConfig struct { @@ -218,6 +216,8 @@ type TailscaleConfig struct { Hostname string `description:"Tailscale hostname." yaml:"hostname"` AuthKey string `description:"Tailscale auth key." yaml:"authKey"` Ephemeral bool `description:"Use ephemeral Tailscale node." yaml:"ephemeral"` + Funnel bool `description:"Enable Tailscale Funnel." yaml:"funnel"` + Listen bool `description:"Listen on the Tailscale address instead of standard address." yaml:"listen"` } // OAuth/OIDC config diff --git a/internal/model/runtime.go b/internal/model/runtime.go index 0df99901..e1c034d3 100644 --- a/internal/model/runtime.go +++ b/internal/model/runtime.go @@ -12,7 +12,6 @@ type RuntimeConfig struct { OAuthProviders map[string]OAuthServiceConfig OAuthWhitelist []string ConfiguredProviders []Provider - TrustedDomains []string } type Provider struct { diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 5e79ff75..eeb5c8e1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -46,7 +46,7 @@ type OAuthPendingSession struct { State string Verifier string Token *oauth2.Token - Service *OAuthServiceImpl + Service IOAuthService ExpiresAt time.Time CallbackParams OAuthCallbackParams } @@ -380,33 +380,11 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess return nil, fmt.Errorf("failed to create session entry: %w", err) } - if data.Provider == "tailscale" { - auth.log.App.Trace().Str("url", fmt.Sprintf("https://%s", auth.tailscale.GetHostname())).Msg("Extracting root domain from Tailscale hostname") - - tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", auth.tailscale.GetHostname())) - - if err != nil { - return nil, fmt.Errorf("failed to get cookie domain for tailscale user: %w", err) - } - - return &http.Cookie{ - Name: auth.runtime.SessionCookieName, - Value: session.UUID, - Path: "/", - Domain: fmt.Sprintf(".%s", tsCookieDomain), - Expires: expiresAt, - MaxAge: int(time.Until(expiresAt).Seconds()), - Secure: auth.config.Auth.SecureCookie, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - }, nil - } - return &http.Cookie{ Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: auth.getCookieDomain(), Expires: expiresAt, MaxAge: int(time.Until(expiresAt).Seconds()), Secure: auth.config.Auth.SecureCookie, @@ -459,7 +437,7 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: auth.getCookieDomain(), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), MaxAge: int(newExpiry - currentTime), Secure: auth.config.Auth.SecureCookie, @@ -480,7 +458,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. Name: auth.runtime.SessionCookieName, Value: "", Path: "/", - Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), + Domain: auth.getCookieDomain(), Expires: time.Now(), MaxAge: -1, Secure: auth.config.Auth.SecureCookie, @@ -549,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac session := OAuthPendingSession{ State: state, Verifier: verifier, - Service: &service, + Service: service, ExpiresAt: time.Now().Add(1 * time.Hour), CallbackParams: params, } @@ -566,7 +544,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { return "", err } - return (*session.Service).GetAuthURL(session.State, session.Verifier), nil + return session.Service.GetAuthURL(session.State, session.Verifier), nil } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { @@ -576,7 +554,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return nil, fmt.Errorf("oauth session not found: %s", sessionId) } - token, err := (*session.Service).GetToken(code, session.Verifier) + token, err := session.Service.GetToken(code, session.Verifier) if err != nil { return nil, fmt.Errorf("failed to exchange code for token: %w", err) @@ -605,7 +583,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) } - userinfo, err := (*session.Service).GetUserinfo(session.Token) + userinfo, err := session.Service.GetUserinfo(session.Token) if err != nil { return nil, fmt.Errorf("failed to get userinfo: %w", err) @@ -614,14 +592,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro return userinfo, nil } -func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { +func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) { session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return nil, err } - return *session.Service, nil + return session.Service, nil } func (auth *AuthService) EndOAuthSession(sessionId string) { @@ -726,3 +704,10 @@ func (auth *AuthService) calculateLockdownLimit() int { return limit } + +func (auth *AuthService) getCookieDomain() string { + if !auth.config.Auth.SubdomainsEnabled { + return "" + } + return auth.runtime.CookieDomain +} diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 63503abc..4df0e825 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -12,19 +12,21 @@ import ( "golang.org/x/oauth2" ) -type OAuthServiceImpl interface { +type IOAuthService interface { Name() string ID() string NewRandom() string - GetAuthURL(state string, verifier string) string - GetToken(code string, verifier string) (*oauth2.Token, error) + GetAuthURL(state, verifier string) string + GetToken(code, verifier string) (*oauth2.Token, error) GetUserinfo(token *oauth2.Token) (*model.Claims, error) + GetConfig() model.OAuthServiceConfig + UpdateConfig(config model.OAuthServiceConfig) } type OAuthBrokerService struct { log *logger.Logger - services map[string]OAuthServiceImpl + services map[string]IOAuthService configs map[string]model.OAuthServiceConfig } @@ -44,7 +46,7 @@ type OAuthBrokerServiceInput struct { func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService { service := &OAuthBrokerService{ log: i.Log, - services: make(map[string]OAuthServiceImpl), + services: make(map[string]IOAuthService), configs: i.Runtime.OAuthProviders, } @@ -70,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string { return services } -func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { +func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) { service, exists := broker.services[name] return service, exists } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 07d0e1cc..888614ec 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string { return random } -func (s *OAuthService) GetAuthURL(state string, verifier string) string { +func (s *OAuthService) GetAuthURL(state, verifier string) string { return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) } @@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) } + +func (s *OAuthService) GetConfig() model.OAuthServiceConfig { + return s.serviceCfg +} + +func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) { + s.serviceCfg = config + s.config.ClientID = config.ClientID + s.config.ClientSecret = config.ClientSecret + s.config.Scopes = config.Scopes + s.config.Endpoint.AuthURL = config.AuthURL + s.config.Endpoint.TokenURL = config.TokenURL + s.config.RedirectURL = config.RedirectURL +} diff --git a/internal/service/tailscale_service.go b/internal/service/tailscale_service.go index 7a1be1e0..183f6f27 100644 --- a/internal/service/tailscale_service.go +++ b/internal/service/tailscale_service.go @@ -94,6 +94,10 @@ func NewTailscaleService(i TailscaleServiceInput) (*TailscaleService, error) { i.Ding.Go(service.watchAndClose, ding.RingMajor) + if i.Config.Tailscale.Funnel && !i.Config.Tailscale.Listen { + service.log.App.Warn().Msg("Tailscale Funnel is enabled but listen is disabled. Funnel will not work without listen enabled.") + } + return service, nil } @@ -148,6 +152,16 @@ func (ts *TailscaleService) CreateListener() (net.Listener, error) { if ts.ln != nil { return *ts.ln, nil } + + if ts.config.Tailscale.Funnel { + ln, err := ts.srv.ListenFunnel("tcp", ":443") + if err != nil { + return nil, err + } + ts.ln = &ln + return ln, nil + } + ln, err := ts.srv.ListenTLS("tcp", ":443") if err != nil { return nil, err diff --git a/internal/test/test.go b/internal/test/test.go index 676501a4..a3f07ca0 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -43,6 +43,7 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { ACLs: model.ACLsConfig{ Policy: "allow", }, + SubdomainsEnabled: true, }, Database: model.DatabaseConfig{ Path: filepath.Join(tempDir, "test.db"), @@ -165,10 +166,6 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { CookieDomain: "example.com", AppURL: "https://tinyauth.example.com", SessionCookieName: "tinyauth-session", - TrustedDomains: []string{ - "https://tinyauth.example.com", - "https://tinyauth.foo.com", - }, } return config, runtime diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index 777e380d..00adf246 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -1,7 +1,7 @@ package utils import ( - "errors" + "fmt" "net" "net/url" "strings" @@ -10,26 +10,33 @@ import ( ) // Get cookie domain parses a hostname and returns the upper domain (e.g. sub1.sub2.domain.com -> sub2.domain.com) -func GetCookieDomain(u string) (string, error) { - parsed, err := url.Parse(u) +func GetCookieDomain(appUrl string, subdomainsEnabled bool) (string, error) { + u, err := url.Parse(appUrl) + if err != nil { - return "", err + return "", fmt.Errorf("invalid app url: %w", err) } - host := parsed.Hostname() + hostname := strings.ToLower(u.Hostname()) - if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("ip addresses not allowed") + if netIP := net.ParseIP(hostname); netIP != nil { + return "", fmt.Errorf("ip addresses not allowed") } - parts := strings.Split(host, ".") + parts := strings.Split(hostname, ".") - if len(parts) == 2 { - return host, nil + if len(parts) < 2 { + return "", fmt.Errorf("invalid app url, must be in format subdomain.domain.tld or domain.tld") } - if len(parts) < 3 { - return "", errors.New("invalid app url, must be at least second level domain") + if !subdomainsEnabled || len(parts) == 2 { + _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, hostname, nil) + + if err != nil { + return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err) + } + + return hostname, nil } domain := strings.Join(parts[1:], ".") @@ -37,33 +44,12 @@ func GetCookieDomain(u string) (string, error) { _, err = publicsuffix.DomainFromListWithOptions(publicsuffix.DefaultList, domain, nil) if err != nil { - return "", errors.New("domain in public suffix list, cannot set cookies") + return "", fmt.Errorf("domain in public suffix list, cannot set cookies: %w", err) } return domain, nil } -func GetStandaloneCookieDomain(u string) (string, error) { - parsed, err := url.Parse(u) - if err != nil { - return "", err - } - - host := parsed.Hostname() - - if netIP := net.ParseIP(host); netIP != nil { - return "", errors.New("ip addresses not allowed") - } - - parts := strings.Split(host, ".") - - if len(parts) < 2 { - return "", errors.New("invalid app url") - } - - return host, nil -} - func ParseFileToLine(content string) string { lines := strings.Split(content, "\n") users := make([]string, 0) diff --git a/internal/utils/app_utils_test.go b/internal/utils/app_utils_test.go index f0c3625c..e4525335 100644 --- a/internal/utils/app_utils_test.go +++ b/internal/utils/app_utils_test.go @@ -11,50 +11,71 @@ func TestGetRootDomain(t *testing.T) { // Normal case domain := "http://sub.tinyauth.app" expected := "tinyauth.app" - result, err := utils.GetCookieDomain(domain) + result, err := utils.GetCookieDomain(domain, true) assert.NoError(t, err) assert.Equal(t, expected, result) // Domain with multiple subdomains domain = "http://b.c.tinyauth.app" expected = "c.tinyauth.app" - result, err = utils.GetCookieDomain(domain) + result, err = utils.GetCookieDomain(domain, true) assert.NoError(t, err) assert.Equal(t, expected, result) // Invalid domain (only TLD) domain = "com" - _, err = utils.GetCookieDomain(domain) - assert.ErrorContains(t, err, "invalid app url, must be at least second level domain") + _, err = utils.GetCookieDomain(domain, true) + assert.EqualError(t, err, "invalid app url, must be in format subdomain.domain.tld or domain.tld") // IP address domain = "http://10.10.10.10" - _, err = utils.GetCookieDomain(domain) + _, err = utils.GetCookieDomain(domain, true) assert.ErrorContains(t, err, "ip addresses not allowed") // Invalid URL domain = "http://[::1]:namedport" - _, err = utils.GetCookieDomain(domain) + _, err = utils.GetCookieDomain(domain, true) assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") // URL with scheme and path domain = "https://sub.tinyauth.app/path" expected = "tinyauth.app" - result, err = utils.GetCookieDomain(domain) + result, err = utils.GetCookieDomain(domain, true) assert.NoError(t, err) assert.Equal(t, expected, result) // URL with port domain = "http://sub.tinyauth.app:8080" expected = "tinyauth.app" - result, err = utils.GetCookieDomain(domain) + result, err = utils.GetCookieDomain(domain, true) assert.NoError(t, err) assert.Equal(t, expected, result) // Domain managed by ICANN domain = "http://example.co.uk" - _, err = utils.GetCookieDomain(domain) + _, err = utils.GetCookieDomain(domain, true) assert.Error(t, err, "domain in public suffix list, cannot set cookies") + + // Domain without subdomain + domain = "http://tinyauth.app" + expected = "tinyauth.app" + result, err = utils.GetCookieDomain(domain, true) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // Case insensitivity + domain = "http://Sub.Tinyauth.App" + expected = "tinyauth.app" + result, err = utils.GetCookieDomain(domain, true) + assert.NoError(t, err) + assert.Equal(t, expected, result) + + // Subdomains disabled + domain = "http://sub.tinyauth.app" + expected = "sub.tinyauth.app" + result, err = utils.GetCookieDomain(domain, false) + assert.NoError(t, err) + assert.Equal(t, expected, result) } func TestParseFileToLine(t *testing.T) { @@ -125,48 +146,3 @@ func TestFilter(t *testing.T) { resultStr := utils.Filter(sliceStr, testFuncStr) assert.Equal(t, expectedStr, resultStr) } - -func TestGetStandaloneCookieDomain(t *testing.T) { - // Normal case - domain := "http://tinyauth.app" - expected := "tinyauth.app" - result, err := utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with subdomain (full hostname is returned, no subdomain stripping) - domain = "http://sub.tinyauth.app" - expected = "sub.tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with port (port should be stripped) - domain = "http://tinyauth.app:8080" - expected = "tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // URL with path - domain = "https://tinyauth.app/some/path" - expected = "tinyauth.app" - result, err = utils.GetStandaloneCookieDomain(domain) - assert.NoError(t, err) - assert.Equal(t, expected, result) - - // IP address - domain = "http://10.10.10.10" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "ip addresses not allowed") - - // Invalid domain (only TLD) - domain = "com" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "invalid app url") - - // Invalid URL - domain = "http://[::1]:namedport" - _, err = utils.GetStandaloneCookieDomain(domain) - assert.ErrorContains(t, err, "parse \"http://[::1]:namedport\": invalid port \":namedport\" after host") -}