diff --git a/pkg/migrator/diff.go b/pkg/migrator/diff.go index 3d7aa62..83c4ebc 100644 --- a/pkg/migrator/diff.go +++ b/pkg/migrator/diff.go @@ -82,6 +82,21 @@ func planCreateAll(snapshot Snapshot) Plan { } func diffTable(previous, current TableSnapshot, dialectName string) ([]string, error) { + previousView := isViewSnapshot(previous) + currentView := isViewSnapshot(current) + if previousView || currentView { + if previousView != currentView { + return nil, fmt.Errorf("migrator: changing %q between view and table is not supported", current.Name) + } + if normalizeSQL(previous.CreateTableSQL) == normalizeSQL(current.CreateTableSQL) { + return nil, nil + } + return []string{ + dropViewSQL(dialectName, current.Name), + current.CreateTableSQL, + }, nil + } + var statements []string previousColumns := make(map[string]ColumnSnapshot, len(previous.Columns)) @@ -243,6 +258,17 @@ func normalizeSQL(sql string) string { return strings.Join(strings.Fields(sql), " ") } +func isViewSnapshot(table TableSnapshot) bool { + if table.IsView { + return true + } + return strings.HasPrefix(strings.ToUpper(strings.TrimSpace(table.CreateTableSQL)), "CREATE VIEW ") +} + +func dropViewSQL(dialectName, name string) string { + return "DROP VIEW " + quoteIdentifier(dialectName, name) +} + func constraintSupportError(dialectName, kind, tableName, name string) error { switch dialectName { case "postgres", "postgresql", "mysql": diff --git a/pkg/migrator/migrator_test.go b/pkg/migrator/migrator_test.go index 0e9f45b..9bce2b6 100644 --- a/pkg/migrator/migrator_test.go +++ b/pkg/migrator/migrator_test.go @@ -11,6 +11,7 @@ import ( "time" exampleregistry "github.com/hyperlocalise/rain-orm/examples/schema/registry" + "github.com/hyperlocalise/rain-orm/pkg/rain" "github.com/hyperlocalise/rain-orm/pkg/schema" _ "modernc.org/sqlite" ) @@ -174,6 +175,81 @@ func TestDiffSnapshotsRejectAddConstraintOnSQLite(t *testing.T) { } } +func TestDiffSnapshotsRecreateChangedView(t *testing.T) { + t.Parallel() + + ddl, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect(postgres): %v", err) + } + users := schema.Define("users", func(t *usersTable) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Email = t.Text("email").NotNull() + }) + baseQuery := ddl.Select().Table(users).Column(users.Email) + viewBefore := schema.DefineView("user_emails", baseQuery, func(v *userEmailsView) { + v.Email = v.VarChar("email", 255) + }) + filteredQuery := ddl.Select().Table(users).Column(users.Email).Where(schema.Raw("? = ?", users.Email, "alice@example.com")) + viewAfter := schema.DefineView("user_emails", filteredQuery, func(v *userEmailsView) { + v.Email = v.VarChar("email", 255) + }) + + before := mustBuildSnapshot(t, "postgres", []schema.TableReference{users, viewBefore}) + after := mustBuildSnapshot(t, "postgres", []schema.TableReference{users, viewAfter}) + beforeView, ok := tableSnapshotByName(before, "user_emails") + if !ok || !beforeView.IsView { + t.Fatalf("expected view snapshot to set IsView, got %#v", beforeView) + } + afterView, ok := tableSnapshotByName(after, "user_emails") + if !ok || !afterView.IsView { + t.Fatalf("expected view snapshot to set IsView, got %#v", afterView) + } + + plan, err := DiffSnapshots(&before, after) + if err != nil { + t.Fatalf("DiffSnapshots returned error: %v", err) + } + if len(plan.Statements) != 2 { + t.Fatalf("expected drop and create view statements, got %d: %v", len(plan.Statements), plan.Statements) + } + if plan.Statements[0] != `DROP VIEW "user_emails"` { + t.Fatalf("expected DROP VIEW statement, got %q", plan.Statements[0]) + } + if !strings.HasPrefix(plan.Statements[1], `CREATE VIEW "user_emails" AS `) { + t.Fatalf("expected CREATE VIEW statement, got %q", plan.Statements[1]) + } + if !strings.Contains(plan.Statements[1], `'alice@example.com'`) { + t.Fatalf("expected updated view definition, got %q", plan.Statements[1]) + } +} + +func TestDiffSnapshotsUnchangedView(t *testing.T) { + t.Parallel() + + ddl, err := rain.OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect(sqlite): %v", err) + } + users := schema.Define("users", func(t *usersTable) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Email = t.Text("email").NotNull() + }) + query := ddl.Select().Table(users).Column(users.Email) + view := schema.DefineView("user_emails", query, func(v *userEmailsView) { + v.Email = v.Text("email") + }) + + snapshot := mustBuildSnapshot(t, "sqlite", []schema.TableReference{users, view}) + plan, err := DiffSnapshots(&snapshot, snapshot) + if err != nil { + t.Fatalf("DiffSnapshots returned error: %v", err) + } + if !plan.Empty() { + t.Fatalf("expected no statements for unchanged view, got %v", plan.Statements) + } +} + func TestDiffSnapshotsRejectAddForeignKeyOnSQLite(t *testing.T) { t.Parallel() @@ -597,6 +673,15 @@ func TestLockNameColumnDDL(t *testing.T) { } } +func tableSnapshotByName(snapshot Snapshot, name string) (TableSnapshot, bool) { + for _, table := range snapshot.Tables { + if table.Name == name { + return table, true + } + } + return TableSnapshot{}, false +} + func mustBuildSnapshot(t *testing.T, dialectName string, tables []schema.TableReference) Snapshot { t.Helper() @@ -622,6 +707,11 @@ type postsTable struct { UserID *schema.Column[int64] } +type userEmailsView struct { + schema.TableModel + Email *schema.Column[string] +} + func usersTableWithoutNickname() schema.TableReference { return schema.Define("users", func(t *usersTable) { t.ID = t.BigSerial("id").PrimaryKey() diff --git a/pkg/migrator/snapshot.go b/pkg/migrator/snapshot.go index 12a69e8..f75dbc8 100644 --- a/pkg/migrator/snapshot.go +++ b/pkg/migrator/snapshot.go @@ -21,6 +21,7 @@ type Snapshot struct { // TableSnapshot stores a portable, deterministic representation of one table. type TableSnapshot struct { Name string `json:"name"` + IsView bool `json:"is_view,omitempty"` CreateTableSQL string `json:"create_table_sql"` Columns []ColumnSnapshot `json:"columns"` Constraints []ConstraintSnapshot `json:"constraints"` @@ -167,6 +168,7 @@ func BuildSnapshot(dialectName string, tables []schema.TableReference) (Snapshot tableSnapshots = append(tableSnapshots, TableSnapshot{ Name: tableDef.Name, + IsView: tableDef.IsView, CreateTableSQL: createTableSQL, Columns: columnSnapshots, Constraints: constraintSnapshots, diff --git a/pkg/rain/coverage_target_internal_test.go b/pkg/rain/coverage_target_internal_test.go index 0aaf275..dc47a9b 100644 --- a/pkg/rain/coverage_target_internal_test.go +++ b/pkg/rain/coverage_target_internal_test.go @@ -707,7 +707,7 @@ func TestCoverageDDLMethodsAndHelpers(t *testing.T) { if _, err := columnDefinitionSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "broken_default", Type: schema.ColumnType{DataType: schema.TypeText}, HasDefault: true, Default: struct{}{}}, false); err == nil { t.Fatalf("expected columnDefinitionSQL default error") } - if got := columnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" { + if got := ddlColumnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" { t.Fatalf("unexpected sqlite timestamp type: %q", got) } if shouldEmitAutoIncrementKeyword(pg, &schema.ColumnDef{Name: "id", Type: schema.ColumnType{DataType: schema.TypeBigSerial}}, true) { diff --git a/pkg/rain/ddl.go b/pkg/rain/ddl.go index b17b0fb..a3c13e2 100644 --- a/pkg/rain/ddl.go +++ b/pkg/rain/ddl.go @@ -20,6 +20,10 @@ func (db *DB) CreateTableSQL(table schema.TableReference) (string, error) { return "", errors.New("rain: create table requires a non-nil table") } + if table.TableDef().IsView { + return createViewSQL(db.dialect, table.TableDef()) + } + return createTableSQL(db.dialect, table.TableDef()) } @@ -32,6 +36,10 @@ func (db *DB) CreateIndexesSQL(table schema.TableReference) ([]string, error) { return nil, errors.New("rain: create indexes requires a non-nil table") } + if table.TableDef().IsView { + return nil, nil + } + return createIndexesSQL(db.dialect, table.TableDef()) } @@ -50,6 +58,10 @@ func (db *DB) ColumnDefinitionSQL(table schema.TableReference, columnName string return "", fmt.Errorf("rain: table %q has no column %q", tableDef.Name, columnName) } + if tableDef.IsView { + return db.dialect.QuoteIdentifier(column.Name) + " " + ddlColumnTypeSQL(db.dialect, column), nil + } + inlinePrimaryKey := false tablePrimaryKey, err := tablePrimaryKeyConstraint(tableDef) if err != nil { @@ -73,6 +85,10 @@ func (db *DB) AddConstraintSQL(table schema.TableReference, constraintName strin } tableDef := table.TableDef() + if tableDef.IsView { + return "", fmt.Errorf("rain: view %q does not support constraints", tableDef.Name) + } + for _, constraint := range tableDef.Constraints { if constraint.Name != constraintName { continue @@ -97,6 +113,10 @@ func (db *DB) AddForeignKeySQL(table schema.TableReference, foreignKeyName strin } tableDef := table.TableDef() + if tableDef.IsView { + return "", fmt.Errorf("rain: view %q does not support foreign keys", tableDef.Name) + } + for _, foreignKey := range tableDef.ForeignKeys { if foreignKey.Name != foreignKeyName { continue @@ -132,6 +152,35 @@ func (db *DB) ColumnDefaultSQL(table schema.TableReference, columnName string) ( return columnDefaultSQL(db.dialect, column) } +func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { + if d == nil { + return "", errors.New("rain: create view requires a configured dialect") + } + if table == nil { + return "", errors.New("rain: create view requires a non-nil table") + } + if !table.IsView { + return "", fmt.Errorf("rain: table %q is not a view", table.Name) + } + if table.ViewQuery == nil { + return "", fmt.Errorf("rain: view %q requires a defining query", table.Name) + } + + ctx := newCompileContext(d) + ctx.useLiterals = true + if err := ctx.writeExpressionInContext(table.ViewQuery, expressionContext{noParens: true}); err != nil { + return "", err + } + + var builder strings.Builder + builder.WriteString("CREATE VIEW ") + builder.WriteString(d.QuoteIdentifier(table.Name)) + builder.WriteString(" AS ") + builder.WriteString(ctx.String()) + + return builder.String(), nil +} + func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { if d == nil { return "", errors.New("rain: create table requires a configured dialect") @@ -297,7 +346,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche var parts []string parts = append(parts, d.QuoteIdentifier(column.Name)) - typeSQL := columnTypeSQL(d, column) + typeSQL := ddlColumnTypeSQL(d, column) parts = append(parts, typeSQL) if inlinePrimaryKey { @@ -338,7 +387,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche return strings.Join(parts, " "), nil } -func columnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string { +func ddlColumnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string { typeSQL := d.DataType(column.Type) if column.Type.DataType == schema.TypeVarChar && column.Type.Size > 0 && strings.EqualFold(typeSQL, "VARCHAR") { @@ -363,13 +412,9 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef, if !inlinePrimaryKey { return false } - if column.Type.DataType != schema.TypeBigSerial { - return true - } - switch d.Name() { case "postgres": - return false + return !isPostgresSerialType(column.Type.DataType) case "sqlite": return true default: @@ -377,6 +422,15 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef, } } +func isPostgresSerialType(dataType schema.DataType) bool { + switch dataType { + case schema.TypeBigSerial, schema.TypeSerial, schema.TypeSmallSerial: + return true + default: + return false + } +} + func columnDefaultSQL(d dialect.Dialect, column *schema.ColumnDef) (string, error) { if column.DefaultSQL != "" { return column.DefaultSQL, nil diff --git a/pkg/rain/ddl_test.go b/pkg/rain/ddl_test.go index 144ddb6..51cd2c3 100644 --- a/pkg/rain/ddl_test.go +++ b/pkg/rain/ddl_test.go @@ -40,6 +40,21 @@ type ddlMembershipsTable struct { Active *schema.Column[bool] } +type ddlSerialTable struct { + schema.TableModel + ID *schema.Column[int32] +} + +type ddlSmallSerialTable struct { + schema.TableModel + ID *schema.Column[int16] +} + +type ddlUserEmailView struct { + schema.TableModel + Email *schema.Column[string] +} + func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) { users := schema.Define("users", func(t *ddlUsersTable) { t.ID = t.BigSerial("id").PrimaryKey() @@ -78,6 +93,97 @@ func defineDDLTables() (*ddlUsersTable, *ddlPostsTable, *ddlMembershipsTable) { return users, posts, memberships } +func TestCreateViewSQLRawExprUsesLiterals(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect(postgres): %v", err) + } + users, _, _ := defineDDLTables() + query := db.Select(). + Table(users). + Column(users.Email). + Where(schema.Raw("? = ?", users.Email, "alice@example.com")) + view := schema.DefineView("user_email_view", query, func(v *ddlUserEmailView) { + v.Email = v.VarChar("email", 255) + }) + + sql, err := db.CreateTableSQL(view) + if err != nil { + t.Fatalf("CreateTableSQL(view): %v", err) + } + if strings.Contains(sql, "$1") || strings.Contains(sql, "$2") { + t.Fatalf("expected view DDL to inline raw args, got:\n%s", sql) + } + if !strings.Contains(sql, `"users"."email" = 'alice@example.com'`) { + t.Fatalf("expected view DDL to include literalized raw predicate, got:\n%s", sql) + } +} + +func TestAliasViewWithSelectQueryDoesNotPanic(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect(postgres): %v", err) + } + users, _, _ := defineDDLTables() + query := db.Select().Table(users).Column(users.Email) + view := schema.DefineView("user_email_view_alias_source", query, func(v *ddlUserEmailView) { + v.Email = v.VarChar("email", 255) + }) + + aliased := schema.Alias(view, "uev") + sql, args, err := db.Select().Table(aliased).Column(aliased.Email).ToSQL() + if err != nil { + t.Fatalf("Select aliased view: %v", err) + } + if len(args) != 0 { + t.Fatalf("expected no args, got %#v", args) + } + if !strings.Contains(sql, `FROM "user_email_view_alias_source" AS "uev"`) { + t.Fatalf("expected aliased view table source, got:\n%s", sql) + } +} + +func TestCreateTableSQLPostgresSerialPrimaryKeysDoNotRepeatSerialKeyword(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect(postgres): %v", err) + } + serialTable := schema.Define("serial_ids", func(t *ddlSerialTable) { + t.ID = t.Serial("id").PrimaryKey() + }) + smallSerialTable := schema.Define("small_serial_ids", func(t *ddlSmallSerialTable) { + t.ID = t.SmallSerial("id").PrimaryKey() + }) + + for _, tc := range []struct { + name string + table schema.TableReference + want string + }{ + {name: "serial", table: serialTable, want: `"id" SERIAL PRIMARY KEY`}, + {name: "smallserial", table: smallSerialTable, want: `"id" SMALLSERIAL PRIMARY KEY`}, + } { + t.Run(tc.name, func(t *testing.T) { + sql, err := db.CreateTableSQL(tc.table) + if err != nil { + t.Fatalf("CreateTableSQL: %v", err) + } + if !strings.Contains(sql, tc.want) { + t.Fatalf("expected SQL to contain %q, got:\n%s", tc.want, sql) + } + if strings.Contains(sql, "PRIMARY KEY SERIAL") || strings.Contains(sql, "PRIMARY KEY SMALLSERIAL") { + t.Fatalf("expected SQL not to repeat serial keyword, got:\n%s", sql) + } + }) + } +} + func TestCreateTableSQLAcrossDialects(t *testing.T) { t.Parallel() diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 173f7f3..ccbab3b 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -76,11 +76,12 @@ func (q compiledQuery) bind(args PreparedArgs) ([]any, error) { } type compileContext struct { - builder strings.Builder - dialect dialect.Dialect - argPlan []compiledArg - err error - skipCTEs bool + builder strings.Builder + dialect dialect.Dialect + argPlan []compiledArg + err error + skipCTEs bool + useLiterals bool } func newCompileContext(d dialect.Dialect) *compileContext { @@ -180,6 +181,14 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex case schema.ColumnReference: c.writeColumn(value) case schema.ValueExpr: + if c.useLiterals { + literal, err := literalDDLSQL(c.dialect, value.Value) + if err != nil { + return err + } + c.writeString(literal) + return nil + } index := c.nextPlaceholderIndex() c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: value.Value}) c.writeString(c.dialect.Placeholder(index)) @@ -386,6 +395,21 @@ func (c *compileContext) writeRaw(raw schema.RawExpr) error { if argIndex >= len(raw.Args) { return errors.New("rain: raw SQL placeholder count does not match args") } + if c.useLiterals { + if expr, ok := raw.Args[argIndex].(schema.Expression); ok { + if err := c.writeExpression(expr); err != nil { + return err + } + } else { + literal, err := literalDDLSQL(c.dialect, raw.Args[argIndex]) + if err != nil { + return err + } + c.writeString(literal) + } + argIndex++ + continue + } index := c.nextPlaceholderIndex() c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgLiteral, value: raw.Args[argIndex]}) c.writeString(c.dialect.Placeholder(index)) diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 3b31b5f..e53dad6 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -252,6 +252,12 @@ func (q *SelectQuery) clone() *SelectQuery { return &newQ } +// CloneExpressionForTable preserves SELECT subqueries when schema metadata is +// cloned for an alias. The query's own table sources remain unchanged. +func (q *SelectQuery) CloneExpressionForTable(*schema.TableDef) schema.Expression { + return q +} + func (q *SelectQuery) withSQLiteInsertSelectConflictWhere() *SelectQuery { rewritten, _ := q.withSQLiteInsertSelectConflictWhereChanged() return rewritten diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index b6637c2..fc2dc70 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -20,6 +20,8 @@ type TimestampKind string // Supported schema data types. const ( TypeBigSerial DataType = "BIGSERIAL" + TypeSerial DataType = "SERIAL" + TypeSmallSerial DataType = "SMALLSERIAL" TypeSmallInt DataType = "SMALLINT" TypeInteger DataType = "INTEGER" TypeBigInt DataType = "BIGINT" @@ -112,6 +114,8 @@ type TableDef struct { Name string Alias string Columns []*ColumnDef + IsView bool + ViewQuery Expression Indexes []IndexDef Constraints []ConstraintDef ForeignKeys []ForeignKeyDef @@ -263,6 +267,16 @@ func (t *TableModel) BigSerial(name string) *Column[int64] { return addColumn[int64](t.def, name, ColumnType{DataType: TypeBigSerial}, false, true) } +// Serial adds a SERIAL column intended for 32-bit auto-incrementing integers. +func (t *TableModel) Serial(name string) *Column[int32] { + return addColumn[int32](t.def, name, ColumnType{DataType: TypeSerial}, false, true) +} + +// SmallSerial adds a SMALLSERIAL column intended for 16-bit auto-incrementing integers. +func (t *TableModel) SmallSerial(name string) *Column[int16] { + return addColumn[int16](t.def, name, ColumnType{DataType: TypeSmallSerial}, false, true) +} + // BigInt adds a BIGINT column. func (t *TableModel) BigInt(name string) *Column[int64] { return addColumn[int64](t.def, name, ColumnType{DataType: TypeBigInt}, true, false) @@ -483,6 +497,27 @@ func Define[T any](name string, fn func(*T)) *T { return handle } +// DefineView creates a typed view handle backed by schema metadata and a defining query. +func DefineView[T any](name string, query Expression, fn func(*T)) *T { + if query == nil { + panic("schema: DefineView requires a non-nil query") + } + + handle := new(T) + def := &TableDef{ + Name: name, + IsView: true, + ViewQuery: query, + Columns: make([]*ColumnDef, 0, 8), + columnsByName: make(map[string]*ColumnDef, 8), + relationsByName: make(map[string]RelationDef, 4), + } + bindTableModel(handle, def) + fn(handle) + + return handle +} + // Alias clones a typed table handle with a SQL alias. func Alias[T any](src *T, alias string) *T { clone := new(T) @@ -1313,6 +1348,10 @@ type tableCloner interface { cloneForTable(*TableDef) any } +type expressionCloner interface { + CloneExpressionForTable(*TableDef) Expression +} + func (c *AnyColumn) cloneForTable(table *TableDef) any { clonedMeta, ok := table.columnsByName[c.def.Name] if !ok { @@ -1375,6 +1414,7 @@ func cloneTableDef(src *TableDef, alias string) *TableDef { cloned := &TableDef{ Name: src.Name, Alias: alias, + IsView: src.IsView, Columns: make([]*ColumnDef, 0, len(src.Columns)), Indexes: make([]IndexDef, len(src.Indexes)), Constraints: make([]ConstraintDef, len(src.Constraints)), @@ -1384,6 +1424,10 @@ func cloneTableDef(src *TableDef, alias string) *TableDef { relationsByName: make(map[string]RelationDef, len(src.Relations)), } + if src.ViewQuery != nil { + cloned.ViewQuery = cloneExpressionForTable(src.ViewQuery, cloned) + } + for _, column := range src.Columns { copyColumn := *column copyColumn.Type.EnumValues = append([]string(nil), column.Type.EnumValues...) @@ -1477,6 +1521,8 @@ func cloneExpressionForTable(expr Expression, table *TableDef) Expression { panic(fmt.Sprintf("schema: cloned expression %T is not an expression", value)) } return cloned + case expressionCloner: + return value.CloneExpressionForTable(table) case ValueExpr: return value case PlaceholderExpr: