Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ TINYAUTH_LDAP_ADDRESS=
TINYAUTH_LDAP_BINDDN=
# Bind password for LDAP authentication.
TINYAUTH_LDAP_BINDPASSWORD=
# Path to the Bind password.
TINYAUTH_LDAP_BINDPASSWORDFILE=
# Base DN for LDAP searches.
TINYAUTH_LDAP_BASEDN=
# Allow insecure LDAP connections.
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/lib/hooks/redirect-uri.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export const useRedirectUri = (
let isAllowedProto = false;
let isHttpsDowngrade = false;

if (!redirect_uri) {
if (redirect_uri === undefined) {
return {
valid: isValid,
trusted: isTrusted,
Expand Down
53 changes: 31 additions & 22 deletions gen/sqlc-wrapper/sqlc_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,24 @@ func run() error {
Overlay: map[string][]byte{outPath: stub},
}

driverTypePkg, err := loadOnePkg(cfg, *driverPkg)
repoPkgPath := parentPkg(*driverPkg)

pkgs, err := loadMultiplePkgs(cfg, *driverPkg, repoPkgPath)

if err != nil {
return fmt.Errorf("load driver package: %w", err)
return fmt.Errorf("load packages: %w", err)
}

repoPkgPath := parentPkg(*driverPkg)
repoTypePkg, err := loadOnePkg(cfg, repoPkgPath)
if err != nil {
return fmt.Errorf("load repo package: %w", err)
driverTypePkg, ok := pkgs[*driverPkg]

if !ok {
return fmt.Errorf("driver package %s not found in loaded packages", *driverPkg)
}

repoTypePkg, ok := pkgs[repoPkgPath]

if !ok {
return fmt.Errorf("repository package %s not found in loaded packages", repoPkgPath)
}

if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil {
Expand Down Expand Up @@ -106,25 +115,25 @@ func run() error {
return nil
}

// loadOnePkg loads a single package via cfg and returns its *types.Package,
// or an error if the package fails to load or has type errors.
func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) {
pkgs, err := packages.Load(cfg, importPath)
// loadMultiplePkgs loads multiple packages via cfg and returns a map of import path → *types.Package,
// or an error if any package fails to load or has type errors.
func loadMultiplePkgs(cfg *packages.Config, importPaths ...string) (map[string]*types.Package, error) {
pkgs, err := packages.Load(cfg, importPaths...)
if err != nil {
return nil, fmt.Errorf("load %s: %w", importPath, err)
}
if len(pkgs) != 1 {
return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs))
}
pkg := pkgs[0]
if len(pkg.Errors) > 0 {
msgs := make([]string, len(pkg.Errors))
for i, e := range pkg.Errors {
msgs[i] = e.Error()
return nil, fmt.Errorf("load %v: %w", importPaths, err)
}
out := make(map[string]*types.Package)
for _, pkg := range pkgs {
if len(pkg.Errors) > 0 {
msgs := make([]string, len(pkg.Errors))
for i, e := range pkg.Errors {
msgs[i] = e.Error()
}
return nil, fmt.Errorf("package %s has errors:\n %s", pkg.PkgPath, strings.Join(msgs, "\n "))
}
return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n "))
out[pkg.PkgPath] = pkg.Types
}
return pkg.Types, nil
return out, nil
}

// parentPkg returns the parent import path (everything before the last /).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS "oidc_consent";
7 changes: 7 additions & 0 deletions internal/assets/migrations/sqlite/000011_oidc_consent.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
CREATE TABLE IF NOT EXISTS "oidc_consent" (
"uuid" TEXT NOT NULL UNIQUE PRIMARY KEY,
"client_id" TEXT NOT NULL,
"scopes" TEXT NOT NULL,
"created_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
15 changes: 13 additions & 2 deletions internal/bootstrap/app_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Services struct {
type BootstrapApp struct {
config model.Config
runtime model.RuntimeConfig
helpers model.RuntimeHelpers
services Services
log *logger.Logger
ctx context.Context
Expand Down Expand Up @@ -185,9 +186,8 @@ func (app *BootstrapApp) Setup() error {
cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough

app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId)
app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId)
app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId)
app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId)
app.runtime.ConsentCookieName = fmt.Sprintf("%s-%s", model.ConsentCookieName, cookieId)

// database
store, err := app.SetupStore()
Expand Down Expand Up @@ -291,6 +291,17 @@ func (app *BootstrapApp) Setup() error {
app.runtime.TrustedDomains = append(app.runtime.TrustedDomains, "https://"+app.services.tailscaleService.GetHostname())
}

// runtime helpers
app.helpers.GetCookieDomain = app.getCookieDomain

err = app.dig.Provide(func() *model.RuntimeHelpers {
return &app.helpers
})

if err != nil {
return fmt.Errorf("failed to provide runtime helpers to container: %w", err)
}

// setup router
err = app.setupRouter()

Expand Down
55 changes: 55 additions & 0 deletions internal/bootstrap/app_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package bootstrap

import (
"context"
"errors"
"fmt"

"github.com/tinyauthapp/tinyauth/internal/utils"
)

// Not really the best place for the helpers to be but it works because bootstrap app provides
// them with everything they need

func (app *BootstrapApp) getCookieDomain(ctx context.Context, ip string) (string, error) {
cookieDomain := app.runtime.CookieDomain

if app.isTailscaleRequest(ctx, ip) {
if app.services.tailscaleService == nil {
return "", errors.New("tailscale service is not configured")
}

tsCookieDomain, err := utils.GetCookieDomain(fmt.Sprintf("https://%s", app.services.tailscaleService.GetHostname()))

if err != nil {
return "", fmt.Errorf("failed to get cookie domain for tailscale user: %w", err)
}

cookieDomain = tsCookieDomain
}

if app.config.Auth.SubdomainsEnabled {
cookieDomain = "." + cookieDomain
}

return cookieDomain, nil
}

func (app *BootstrapApp) isTailscaleRequest(ctx context.Context, ip string) bool {
if app.services.tailscaleService == nil {
return false
}

whois, err := app.services.tailscaleService.Whois(ctx, ip)

if err != nil {
app.log.App.Error().Err(err).Msgf("Error performing Tailscale whois for IP %s: %v", ip, err)
return false
}

if whois == nil {
return false
}

return true
}
28 changes: 25 additions & 3 deletions internal/controller/oauth_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type OAuthController struct {
config *model.Config
runtime *model.RuntimeConfig
auth *service.AuthService
helpers *model.RuntimeHelpers
}

type OAuthControllerInput struct {
Expand All @@ -36,6 +37,7 @@ type OAuthControllerInput struct {
Log *logger.Logger
Config *model.Config
RuntimeConfig *model.RuntimeConfig
Helpers *model.RuntimeHelpers
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
AuthService *service.AuthService
}
Expand All @@ -46,6 +48,7 @@ func NewOAuthController(i OAuthControllerInput) *OAuthController {
config: i.Config,
runtime: i.RuntimeConfig,
auth: i.AuthService,
helpers: i.Helpers,
}

oauthGroup := i.RouterGroup.Group("/oauth")
Expand Down Expand Up @@ -110,7 +113,18 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
return
}

c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}

c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", cookieDomain, controller.config.Auth.SecureCookie, true)

c.JSON(200, gin.H{
"status": 200,
Expand Down Expand Up @@ -140,7 +154,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}

c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true)
cookieDomain, err := controller.helpers.GetCookieDomain(c, c.RemoteIP())

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}

c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", cookieDomain, controller.config.Auth.SecureCookie, true)

oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)

Expand Down Expand Up @@ -257,7 +279,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {

controller.log.App.Debug().Msg("Creating session cookie for user")

cookie, err := controller.auth.CreateSession(c, sessionCookie)
cookie, err := controller.auth.CreateSession(c, sessionCookie, c.RemoteIP())

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
Expand Down
53 changes: 53 additions & 0 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package controller

import (
"database/sql"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -34,6 +35,8 @@ type OIDCController struct {
log *logger.Logger
oidc *service.OIDCService
runtime *model.RuntimeConfig
helpers *model.RuntimeHelpers
config *model.Config
}

type AuthorizeCallback struct {
Expand Down Expand Up @@ -90,13 +93,17 @@ type OIDCControllerInput struct {
RuntimeConfig *model.RuntimeConfig
RouterGroup *gin.RouterGroup `name:"apiRouterGroup"`
MainRouter *gin.RouterGroup `name:"mainRouterGroup"`
Helpers *model.RuntimeHelpers
Config *model.Config
}

func NewOIDCController(i OIDCControllerInput) *OIDCController {
controller := &OIDCController{
log: i.Log,
oidc: i.OIDCService,
runtime: i.RuntimeConfig,
helpers: i.Helpers,
config: i.Config,
}

i.MainRouter.POST("/authorize", controller.authorize)
Expand Down Expand Up @@ -219,6 +226,25 @@ func (controller *OIDCController) authorize(c *gin.Context) {
values.OIDCPrompt = service.OIDCPromptNone
}

// If no prompt is already set, we can check if we can/should skip it based on the cookie
if values.OIDCPrompt == "" {
consnetCookie, err := c.Cookie(controller.runtime.ConsentCookieName)

if err == nil {
consentEntry, err := controller.oidc.GetConsentEntry(c, consnetCookie)

if err == nil && consentEntry != nil {
if consentEntry.ClientID == req.ClientID && consentEntry.Scopes == req.Scope {
values.OIDCPrompt = service.OIDCPromptNone
}
} else {
if !errors.Is(err, sql.ErrNoRows) {
controller.log.App.Error().Err(err).Msg("Failed to get consent entry for consent cookie")
}
}
}
}

if req.MaxAge != "" && userContext != nil {
maxAge, err := strconv.Atoi(req.MaxAge)
if err != nil {
Expand Down Expand Up @@ -361,6 +387,33 @@ func (controller *OIDCController) authorizeComplete(c *gin.Context) {
return
}

// Just before returning let's set the consent cookie
consnetUUID, err := controller.oidc.CreateConsentEntry(c, authorizeReq.ClientID, authorizeReq.Scope)

// If we fail to create the consent entry, we don't want to block the authorization flow,
// but we log the error and move on without setting the cookie
if err == nil {
cookieDomain, err := controller.helpers.GetCookieDomain(c.Request.Context(), c.RemoteIP())

if err == nil {
cookie := &http.Cookie{
Name: controller.runtime.ConsentCookieName,
Value: consnetUUID,
Path: "/",
Domain: cookieDomain,
Expires: time.Now().Add(365 * 24 * time.Hour), // set consent cookie for 1 year
Secure: controller.config.Auth.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(c.Writer, cookie)
} else {
controller.log.App.Error().Err(err).Msg("Failed to determine cookie domain for consent cookie")
}
} else {
controller.log.App.Error().Err(err).Msg("Failed to create consent entry")
}

c.JSON(200, gin.H{
"status": 200,
"redirect_uri": fmt.Sprintf("%s?%s", authorizeReq.RedirectURI, queries.Encode()),
Expand Down
4 changes: 4 additions & 0 deletions internal/controller/oidc_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func TestOIDCController(t *testing.T) {

cfg, runtime := test.CreateTestConfigs(t)

helpers := test.CreateTestHelpers()

ctx := context.TODO()
dg := ding.New(ctx)

Expand Down Expand Up @@ -862,6 +864,8 @@ func TestOIDCController(t *testing.T) {
RuntimeConfig: &runtime,
RouterGroup: group,
MainRouter: &router.RouterGroup,
Helpers: helpers,
Config: &cfg,
})

recorder := httptest.NewRecorder()
Expand Down
Loading