diff --git a/pkg/rain/coverage_target_internal_test.go b/pkg/rain/coverage_target_internal_test.go index 0aaf275..ac990f1 100644 --- a/pkg/rain/coverage_target_internal_test.go +++ b/pkg/rain/coverage_target_internal_test.go @@ -748,17 +748,17 @@ func TestCoverageDDLMethodsAndHelpers(t *testing.T) { {value: time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC), want: "'2026-01-02T03:04:05Z'"}, {value: []byte("abc"), want: "'abc'"}, } { - if got, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", Default: tc.value}); err != nil || got != tc.want { + if got, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", Default: tc.value}); err != nil || got != tc.want { t.Fatalf("unexpected columnDefaultSQL for %#v: %q err=%v", tc.value, got, err) } if got, err := literalDDLSQL(pg, tc.value); err != nil || got != tc.want { t.Fatalf("unexpected literalDDLSQL for %#v: %q err=%v", tc.value, got, err) } } - if got, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", DefaultSQL: "NOW()"}); err != nil || got != "NOW()" { + if got, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", DefaultSQL: "NOW()"}); err != nil || got != "NOW()" { t.Fatalf("unexpected DefaultSQL passthrough: %q err=%v", got, err) } - if _, err := columnDefaultSQL(pg, &schema.ColumnDef{Name: "x", Default: struct{}{}}); err == nil { + if _, err := columnDefaultSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "x", Default: struct{}{}}); err == nil { t.Fatalf("expected unsupported default type to fail") } if _, err := literalDDLSQL(pg, struct{}{}); err == nil { diff --git a/pkg/rain/ddl.go b/pkg/rain/ddl.go index 3062f76..b0057a4 100644 --- a/pkg/rain/ddl.go +++ b/pkg/rain/ddl.go @@ -140,11 +140,11 @@ func (db *DB) ColumnDefaultSQL(table schema.TableReference, columnName string) ( if !ok { return "", fmt.Errorf("rain: table %q has no column %q", tableDef.Name, columnName) } - if !column.HasDefault && column.DefaultSQL == "" { + if !column.HasDefault && column.DefaultSQL == "" && column.DefaultExpr == nil { return "", nil } - return columnDefaultSQL(db.dialect, column) + return columnDefaultSQL(db.dialect, tableDef, column) } func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { @@ -342,8 +342,8 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche if column.Unique { parts = append(parts, "UNIQUE") } - if column.HasDefault || column.DefaultSQL != "" { - defaultSQL, err := columnDefaultSQL(d, column) + if column.HasDefault || column.DefaultSQL != "" || column.DefaultExpr != nil { + defaultSQL, err := columnDefaultSQL(d, table, column) if err != nil { return "", err } @@ -396,7 +396,11 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef, } } -func columnDefaultSQL(d dialect.Dialect, column *schema.ColumnDef) (string, error) { +func columnDefaultSQL(d dialect.Dialect, table *schema.TableDef, column *schema.ColumnDef) (string, error) { + if column.DefaultExpr != nil { + return expressionDDLSQL(d, table, column.DefaultExpr) + } + if column.DefaultSQL != "" { return column.DefaultSQL, nil } diff --git a/pkg/rain/ddl_test.go b/pkg/rain/ddl_test.go index 144ddb6..1be8c22 100644 --- a/pkg/rain/ddl_test.go +++ b/pkg/rain/ddl_test.go @@ -78,6 +78,77 @@ func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) { return users, posts, memberships } +type ddlDefaultRawTable struct { + schema.TableModel + ID *schema.Column[int64] + CreatedAt *schema.Column[time.Time] + Random *schema.Column[float64] +} + +func TestCreateTableSQLWithDefaultRaw(t *testing.T) { + t.Parallel() + + table := schema.Define("default_raw_test", func(t *ddlDefaultRawTable) { + t.ID = t.BigSerial("id").PrimaryKey() + t.CreatedAt = t.TimestampTZ("created_at").NotNull().DefaultRaw(schema.Raw("now()")) + t.Random = t.Double("random").NotNull().DefaultRaw(schema.Raw("random()")) + }) + + cases := []struct { + name string + dialect string + fragments []string + }{ + { + name: "postgres default raw", + dialect: "postgres", + fragments: []string{ + `"created_at" TIMESTAMPTZ NOT NULL DEFAULT now()`, + `"random" DOUBLE PRECISION NOT NULL DEFAULT random()`, + }, + }, + { + name: "mysql default raw", + dialect: "mysql", + fragments: []string{ + "`created_at` DATETIME NOT NULL DEFAULT now()", + "`random` DOUBLE NOT NULL DEFAULT random()", + }, + }, + { + name: "sqlite default raw", + dialect: "sqlite", + fragments: []string{ + `"created_at" TEXT NOT NULL DEFAULT now()`, + `"random" REAL NOT NULL DEFAULT random()`, + }, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect(tc.dialect) + if err != nil { + t.Fatalf("OpenDialect(%q): %v", tc.dialect, err) + } + + sql, err := db.CreateTableSQL(table) + if err != nil { + t.Fatalf("CreateTableSQL: %v", err) + } + + for _, fragment := range tc.fragments { + if !strings.Contains(sql, fragment) { + t.Fatalf("expected SQL to contain %q, got:\n%s", fragment, sql) + } + } + }) + } +} + func TestCreateTableSQLAcrossDialects(t *testing.T) { t.Parallel() diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index c5da9dc..831c506 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -134,6 +134,7 @@ type ColumnDef struct { Default any HasDefault bool DefaultSQL string + DefaultExpr Expression PrimaryKey bool AutoIncrement bool Unique bool @@ -628,6 +629,8 @@ func (c *Column[T]) Nullable() *Column[T] { func (c *Column[T]) Default(value T) *Column[T] { c.def.HasDefault = true c.def.Default = value + c.def.DefaultSQL = "" + c.def.DefaultExpr = nil return c } @@ -635,6 +638,20 @@ func (c *Column[T]) Default(value T) *Column[T] { func (c *Column[T]) DefaultNow() *Column[T] { c.def.HasDefault = true c.def.DefaultSQL = "CURRENT_TIMESTAMP" + c.def.Default = nil + c.def.DefaultExpr = nil + return c +} + +// DefaultRaw sets a raw SQL expression as the default value. +func (c *Column[T]) DefaultRaw(expr Expression) *Column[T] { + if expr == nil { + panic("schema: DefaultRaw requires a non-nil expression") + } + c.def.HasDefault = true + c.def.DefaultExpr = expr + c.def.Default = nil + c.def.DefaultSQL = "" return c }