Skip to content
Open
6 changes: 1 addition & 5 deletions frontend/src/components/layout/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@ export const Layout = () => {
setIgnoreDomainWarning(true);
}, [setIgnoreDomainWarning]);

if (
!ignoreDomainWarning &&
ui.warningsEnabled &&
!app.trustedDomains.includes(currentUrl)
) {
if (!ignoreDomainWarning && ui.warningsEnabled && currentUrl !== app.appUrl) {
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return (
<BaseLayout>
<DomainWarning
Expand Down
62 changes: 58 additions & 4 deletions frontend/src/lib/hooks/redirect-uri.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,27 @@ type IuseRedirectUri = {
export const useRedirectUri = (
redirect_uri: string | undefined,
cookieDomain: string,
appUrl: string,
subdomainsEnabled: boolean,
): IuseRedirectUri => {
let isValid = false;
let isTrusted = false;
let isAllowedProto = false;
let isHttpsDowngrade = false;

let appUrlObj: URL;

try {
appUrlObj = new URL(appUrl);
} catch {
return {
valid: isValid,
trusted: isTrusted,
allowedProto: isAllowedProto,
httpsDowngrade: isHttpsDowngrade,
};
}

if (!redirect_uri) {
return {
valid: isValid,
Expand All @@ -39,10 +54,7 @@ export const useRedirectUri = (

isValid = true;

if (
url.hostname == cookieDomain ||
url.hostname.endsWith(`.${cookieDomain}`)
) {
if (isTrustedDomain(url, appUrlObj, cookieDomain, subdomainsEnabled)) {
isTrusted = true;
}

Expand All @@ -62,3 +74,45 @@ export const useRedirectUri = (
httpsDowngrade: isHttpsDowngrade,
};
};

// ported from internal/controller/oauth_controller.go
const getEffectivePort = (url: URL): string => {
if (url.port) {
return url.port;
}

if (url.protocol == "https:") {
return "443";
}

return "80";
};

const isTrustedDomain = (
url: URL,
appUrl: URL,
cookieDomain: string,
subdomainsEnabled: boolean,
): boolean => {
if (url.protocol != appUrl.protocol) {
return false;
}

if (getEffectivePort(url) != getEffectivePort(appUrl)) {
return false;
}

if (url.hostname == appUrl.hostname) {
return true;
}

if (!subdomainsEnabled) {
return false;
}

if (url.hostname.endsWith("." + cookieDomain.toLowerCase())) {
return true;
}

return false;
};
8 changes: 7 additions & 1 deletion frontend/src/pages/continue-page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export const ContinuePage = () => {
const { url, valid, trusted, allowedProto, httpsDowngrade } = useRedirectUri(
redirectUri,
app.cookieDomain,
app.appUrl,
app.subdomainsEnabled,
);

const urlHref = url?.href;
Expand Down Expand Up @@ -108,7 +110,11 @@ export const ContinuePage = () => {
components={{
code: <code />,
}}
values={{ cookieDomain: app.cookieDomain }}
values={{
cookieDomain: app.subdomainsEnabled
? `.${app.cookieDomain}`
: app.cookieDomain,
}}
shouldUnescape={true}
/>
</CardDescription>
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/schemas/app-context-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const uiSchema = z.object({
const appSchema = z.object({
appUrl: z.string(),
cookieDomain: z.string(),
trustedDomains: z.array(z.string()),
subdomainsEnabled: z.boolean(),
});

export const appContextSchema = z.object({
Expand Down
101 changes: 61 additions & 40 deletions internal/bootstrap/app_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,17 @@ type Services struct {
}

type BootstrapApp struct {
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
listeners []Listener
dig *dig.Container
config model.Config
runtime model.RuntimeConfig
services Services
log *logger.Logger
ctx context.Context
cancel context.CancelFunc
queries repository.Store
router *gin.Engine
db *sql.DB
ding *ding.Ding
dig *dig.Container
}

func NewBootstrapApp(config model.Config) *BootstrapApp {
Expand Down Expand Up @@ -98,8 +97,7 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to parse app url: %w", err)
}

app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, app.runtime.AppURL)
app.runtime.AppURL = strings.ToLower(appUrl.Scheme + "://" + appUrl.Host)

// validate session config
if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry {
Expand Down Expand Up @@ -144,34 +142,23 @@ func (app *BootstrapApp) Setup() error {
provider.ClientSecret = secret
provider.ClientSecretFile = ""

if provider.RedirectURL == "" {
provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id
}

app.runtime.OAuthProviders[id] = provider
}

// set presets for built-in providers
for id, provider := range app.runtime.OAuthProviders {
if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok {
provider.Name = name
} else {
provider.Name = utils.Capitalize(id)
}
}

app.runtime.OAuthProviders[id] = provider
}

// cookie domain
cookieDomainResolver := utils.GetCookieDomain

if !app.config.Auth.SubdomainsEnabled {
app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains")
cookieDomainResolver = utils.GetStandaloneCookieDomain
app.log.App.Warn().Msg("Subdomains are disabled, cookies will be set for the current domain only")
}

cookieDomain, err := cookieDomainResolver(app.runtime.AppURL)
cookieDomain, err := utils.GetCookieDomain(app.runtime.AppURL, app.config.Auth.SubdomainsEnabled)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
Expand Down Expand Up @@ -286,9 +273,43 @@ func (app *BootstrapApp) Setup() error {

app.runtime.ConfiguredProviders = configuredProviders

// throw in tailscale if it's configured just before setting up the controllers
if app.services.tailscaleService != nil {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
// if tailscale is enabled and listening, replace the app url with the tailscale hostname
if app.services.tailscaleService != nil && app.config.Tailscale.Listen {
tailscaleUrl := "https://" + app.services.tailscaleService.GetHostname()

// if the tailscale url is different from the app url, replace it
if tailscaleUrl != app.runtime.AppURL {
app.log.App.Info().Msg("Listening on tailscale, replacing app url with tailscale hostname")

app.runtime.AppURL = tailscaleUrl

// also update cookie domain
cookieDomain, err := utils.GetCookieDomain(tailscaleUrl, app.config.Auth.SubdomainsEnabled)

if err != nil {
return fmt.Errorf("failed to get cookie domain: %w", err)
}

app.runtime.CookieDomain = cookieDomain
}
}

// force an update of the redirect urls for all oauth providers, if they are empty
services := app.services.oauthBrokerService.GetConfiguredServices()

for _, service := range services {
oauthService, ok := app.services.oauthBrokerService.GetService(service)

if !ok {
return fmt.Errorf("failed to get oauth service for provider %s", service)
}

providerConfig := oauthService.GetConfig()

if providerConfig.RedirectURL == "" {
providerConfig.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + service
oauthService.UpdateConfig(providerConfig)
}
}

// setup router
Expand All @@ -308,19 +329,19 @@ func (app *BootstrapApp) Setup() error {
app.ding.Go(app.heartbeatRoutine, ding.RingMinor)
}

// setup listeners
app.listeners = app.calculateListenerPolicy()
// get listener
listenerFunc, err := app.getListenerFunc()

if app.config.Server.ConcurrentListenersEnabled {
app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners")
if err != nil {
return fmt.Errorf("failed to get listener function: %w", err)
}

// run listeners
lec, err := app.runListeners()
// run listener
lec := make(chan error, 1)

if err != nil {
return fmt.Errorf("failed to run listeners: %w", err)
}
app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)

// monitor cancellation and server errors
for {
Expand Down
83 changes: 12 additions & 71 deletions internal/bootstrap/router_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"os"
"time"

"github.com/steveiliop56/ding"
"github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
Expand All @@ -18,14 +17,6 @@ import (
"github.com/gin-gonic/gin"
)

type Listener int

const (
ListenerHTTP Listener = iota
ListenerUnix
ListenerTailscale
)

func (app *BootstrapApp) setupRouter() error {
// we don't want gin debug mode
gin.SetMode(gin.ReleaseMode)
Expand Down Expand Up @@ -134,79 +125,29 @@ func (app *BootstrapApp) setupRouter() error {
return nil
}

func (app *BootstrapApp) runListeners() (chan error, error) {
// lec -> listener error channel
lec := make(chan error, len(app.listeners))

for _, listenerType := range app.listeners {
listenerFunc, err := app.listenerFromType(listenerType)

if err != nil {
return nil, fmt.Errorf("failed to get listener function: %w", err)
// Top down
// 1. Tailscale (if tailscale.listen)
// 2. Unix socket (if server.socketPath)
// 3. HTTP - default
func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, error) {
if app.config.Tailscale.Listen {
if app.services.tailscaleService == nil {
return nil, fmt.Errorf("tailscale.listen is enabled but tailscale service is not initialized")
}

app.ding.Go(func(ctx context.Context) {
lec <- listenerFunc(ctx)
}, ding.RingNormal)
}

return lec, nil
}

// The way we calculate listeners is as follows:
// If concurrent listeners are disabled, we pick the first available listener, so:
// 1. If tailscale is enabled, we use tailscale
// 2. If socket path is configured, we use unix socket
// 3. Finally if none is configured we use http
// If concurrent listeners are enabled, we add all available listeners in the following order
func (app *BootstrapApp) calculateListenerPolicy() []Listener {
l := []Listener{}

if !app.config.Server.ConcurrentListenersEnabled {
if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
return l
}

if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
return l
}

l = append(l, ListenerHTTP)
return l
return app.serveTailscale, nil
}

if app.config.Server.SocketPath != "" {
l = append(l, ListenerUnix)
}

if app.services.tailscaleService != nil {
l = append(l, ListenerTailscale)
}

l = append(l, ListenerHTTP)

return l
}

func (app *BootstrapApp) listenerFromType(listenerType Listener) (func(ctx context.Context) error, error) {
switch listenerType {
case ListenerHTTP:
return app.serveHTTP, nil
case ListenerUnix:
return app.serveUnix, nil
case ListenerTailscale:
return app.serveTailscale, nil
default:
return nil, fmt.Errorf("invalid listener type: %d", listenerType)
}

return app.serveHTTP, nil
}

func (app *BootstrapApp) serveHTTP(ctx context.Context) error {
address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port)

app.log.App.Info().Msgf("Starting server on %s", address)
app.log.App.Info().Msgf("Starting server on http://%s", address)

listener, err := net.Listen("tcp", address)

Expand Down
Loading