Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
name: CI

env:
# A suitable URL for the test database.
DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_dev?sslmode=disable
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't needed anywhere so I took it out so that the env is cleaner for tests where we want to work with it (i.e. we don't have to unset DATABASE_URL before testing something).


# Test database.
TEST_DATABASE_URL: postgres://postgres:postgres@127.0.0.1:5432/river_test?sslmode=disable

Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Allow `PG*` env vars as an alternative to `DATABASE_URL`. [PR #256](https://github.com/riverqueue/riverui/pull/256).

## [0.7.0] - 2024-12-16

### Added
Expand Down
34 changes: 0 additions & 34 deletions cmd/riverui/logger.go

This file was deleted.

147 changes: 85 additions & 62 deletions cmd/riverui/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"flag"
"fmt"
"log"
"log/slog"
"net/http"
"os"
Expand All @@ -26,69 +25,103 @@ import (

func main() {
ctx := context.Background()
initLogger()
os.Exit(initAndServe(ctx))
}

func initAndServe(ctx context.Context) int {
var (
devMode bool
liveFS bool
pathPrefix string
)
_, liveFS = os.LookupEnv("LIVE_FS")
_, devMode = os.LookupEnv("DEV")
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: getLogLevel(),
}))

var pathPrefix string
flag.StringVar(&pathPrefix, "prefix", "/", "path prefix to use for the API and UI HTTP requests")
flag.Parse()

initRes, err := initServer(ctx, logger, pathPrefix)
if err != nil {
logger.ErrorContext(ctx, "Error initializing server", slog.String("error", err.Error()))
os.Exit(1)
}

if err := startAndListen(ctx, logger, initRes); err != nil {
logger.ErrorContext(ctx, "Error starting server", slog.String("error", err.Error()))
os.Exit(1)
}
}

// Translates either a "1" or "true" from env to a Go boolean.
func envBooleanTrue(val string) bool {
return val == "1" || val == "true"
}

func getLogLevel() slog.Level {
if envBooleanTrue(os.Getenv("RIVER_DEBUG")) {
return slog.LevelDebug
}

switch strings.ToLower(os.Getenv("RIVER_LOG_LEVEL")) {
case "debug":
return slog.LevelDebug
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
return slog.LevelInfo
}
}

type initServerResult struct {
dbPool *pgxpool.Pool // database pool; close must be deferred by caller!
httpServer *http.Server // HTTP server wrapping the UI server
logger *slog.Logger // application logger (also internalized in UI server)
uiServer *riverui.Server // River UI server
}

func initServer(ctx context.Context, logger *slog.Logger, pathPrefix string) (*initServerResult, error) {
if !strings.HasPrefix(pathPrefix, "/") || pathPrefix == "" {
logger.ErrorContext(ctx, "invalid path prefix", slog.String("prefix", pathPrefix))
return 1
return nil, fmt.Errorf("invalid path prefix: %s", pathPrefix)
}
pathPrefix = riverui.NormalizePathPrefix(pathPrefix)

var (
basicAuthUsername = os.Getenv("RIVER_BASIC_AUTH_USER")
basicAuthPassword = os.Getenv("RIVER_BASIC_AUTH_PASS")
corsOrigins = strings.Split(os.Getenv("CORS_ORIGINS"), ",")
dbURL = mustEnv("DATABASE_URL")
databaseURL = os.Getenv("DATABASE_URL")
devMode = envBooleanTrue(os.Getenv("DEV"))
host = os.Getenv("RIVER_HOST") // may be left empty to bind to all local interfaces
otelEnabled = os.Getenv("OTEL_ENABLED") == "true"
liveFS = envBooleanTrue(os.Getenv("LIVE_FS"))
otelEnabled = envBooleanTrue(os.Getenv("OTEL_ENABLED"))
port = cmp.Or(os.Getenv("PORT"), "8080")
)

dbPool, err := getDBPool(ctx, dbURL)
if databaseURL == "" && os.Getenv("PGDATABASE") == "" {
return nil, errors.New("expect to have DATABASE_URL or database configuration in standard PG* env vars like PGDATABASE/PGHOST/PGPORT/PGUSER/PGPASSWORD")
}

poolConfig, err := pgxpool.ParseConfig(databaseURL)
if err != nil {
return nil, fmt.Errorf("error parsing db config: %w", err)
}

dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
logger.ErrorContext(ctx, "error connecting to db", slog.String("error", err.Error()))
return 1
return nil, fmt.Errorf("error connecting to db: %w", err)
}
defer dbPool.Close()

client, err := river.NewClient(riverpgxv5.New(dbPool), &river.Config{})
if err != nil {
logger.ErrorContext(ctx, "error creating river client", slog.String("error", err.Error()))
return 1
return nil, err
}

handlerOpts := &riverui.ServerOpts{
uiServer, err := riverui.NewServer(&riverui.ServerOpts{
Client: client,
DB: dbPool,
DevMode: devMode,
LiveFS: liveFS,
Logger: logger,
Prefix: pathPrefix,
}

server, err := riverui.NewServer(handlerOpts)
})
if err != nil {
logger.ErrorContext(ctx, "error creating handler", slog.String("error", err.Error()))
return 1
}

if err = server.Start(ctx); err != nil {
logger.ErrorContext(ctx, "error starting UI server", slog.String("error", err.Error()))
return 1
return nil, err
}

corsHandler := cors.New(cors.Options{
Expand All @@ -109,40 +142,30 @@ func initAndServe(ctx context.Context) int {
middlewareStack.Use(&authMiddleware{username: basicAuthUsername, password: basicAuthPassword})
}

srv := &http.Server{
Addr: host + ":" + port,
Handler: middlewareStack.Mount(server),
ReadHeaderTimeout: 5 * time.Second,
}
return &initServerResult{
dbPool: dbPool,
httpServer: &http.Server{
Addr: host + ":" + port,
Handler: middlewareStack.Mount(uiServer),
ReadHeaderTimeout: 5 * time.Second,
},
logger: logger,
uiServer: uiServer,
}, nil
}

log.Printf("starting server on %s", srv.Addr)
func startAndListen(ctx context.Context, logger *slog.Logger, initRes *initServerResult) error {
defer initRes.dbPool.Close()

if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.ErrorContext(ctx, "error from ListenAndServe", slog.String("error", err.Error()))
return 1
if err := initRes.uiServer.Start(ctx); err != nil {
return err
}

return 0
}
logger.InfoContext(ctx, "Starting server", slog.String("addr", initRes.httpServer.Addr))

func getDBPool(ctx context.Context, dbURL string) (*pgxpool.Pool, error) {
poolConfig, err := pgxpool.ParseConfig(dbURL)
if err != nil {
return nil, fmt.Errorf("error parsing db config: %w", err)
if err := initRes.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}

dbPool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("error connecting to db: %w", err)
}
return dbPool, nil
}

func mustEnv(name string) string {
val := os.Getenv(name)
if val == "" {
logger.Error("missing required env var", slog.String("name", name))
os.Exit(1)
}
return val
return nil
}
65 changes: 65 additions & 0 deletions cmd/riverui/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package main

import (
"cmp"
"context"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/require"

"github.com/riverqueue/river/rivershared/riversharedtest"
)

func TestInitServer(t *testing.T) {
var (
ctx = context.Background()
databaseURL = cmp.Or(os.Getenv("TEST_DATABASE_URL"), "postgres://localhost/river_test")
)

t.Setenv("DEV", "true")

type testBundle struct{}

setup := func(t *testing.T) (*initServerResult, *testBundle) {
t.Helper()

initRes, err := initServer(ctx, riversharedtest.Logger(t), "/")
require.NoError(t, err)
t.Cleanup(initRes.dbPool.Close)

return initRes, &testBundle{}
}

t.Run("WithDatabaseURL", func(t *testing.T) {
t.Setenv("DATABASE_URL", databaseURL)

initRes, _ := setup(t)

_, err := initRes.dbPool.Exec(ctx, "SELECT 1")
require.NoError(t, err)
})

t.Run("WithPGEnvVars", func(t *testing.T) {
// Verify that DATABASE_URL is indeed not set to be sure we're taking
// the configuration branch we expect to be taking.
require.Empty(t, os.Getenv("DATABASE_URL"))

parsedURL, err := url.Parse(databaseURL)
require.NoError(t, err)

t.Setenv("PGDATABASE", parsedURL.Path[1:])
t.Setenv("PGHOST", parsedURL.Hostname())
pass, _ := parsedURL.User.Password()
t.Setenv("PGPASSWORD", pass)
t.Setenv("PGPORT", cmp.Or(parsedURL.Port(), "5432"))
t.Setenv("PGSSLMODE", parsedURL.Query().Get("sslmode"))
t.Setenv("PGUSER", parsedURL.User.Username())

initRes, _ := setup(t)

_, err = initRes.dbPool.Exec(ctx, "SELECT 1")
require.NoError(t, err)
})
}
Loading