Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pkg/migrator/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ func planCreateAll(snapshot Snapshot) Plan {
}

func diffTable(previous, current TableSnapshot, dialectName string) ([]string, error) {
previousView := isViewSnapshot(previous)
currentView := isViewSnapshot(current)
if previousView || currentView {
if previousView != currentView {
return nil, fmt.Errorf("migrator: changing %q between view and table is not supported", current.Name)
}
if normalizeSQL(previous.CreateTableSQL) == normalizeSQL(current.CreateTableSQL) {
return nil, nil
}
return []string{
dropViewSQL(dialectName, current.Name),
current.CreateTableSQL,
}, nil
}

var statements []string

previousColumns := make(map[string]ColumnSnapshot, len(previous.Columns))
Expand Down Expand Up @@ -243,6 +258,17 @@ func normalizeSQL(sql string) string {
return strings.Join(strings.Fields(sql), " ")
}

func isViewSnapshot(table TableSnapshot) bool {
if table.IsView {
return true
}
return strings.HasPrefix(strings.ToUpper(strings.TrimSpace(table.CreateTableSQL)), "CREATE VIEW ")
}

func dropViewSQL(dialectName, name string) string {
return "DROP VIEW " + quoteIdentifier(dialectName, name)
}

func constraintSupportError(dialectName, kind, tableName, name string) error {
switch dialectName {
case "postgres", "postgresql", "mysql":
Expand Down
90 changes: 90 additions & 0 deletions pkg/migrator/migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

exampleregistry "github.com/hyperlocalise/rain-orm/examples/schema/registry"
"github.com/hyperlocalise/rain-orm/pkg/rain"
"github.com/hyperlocalise/rain-orm/pkg/schema"
_ "modernc.org/sqlite"
)
Expand Down Expand Up @@ -174,6 +175,81 @@ func TestDiffSnapshotsRejectAddConstraintOnSQLite(t *testing.T) {
}
}

func TestDiffSnapshotsRecreateChangedView(t *testing.T) {
t.Parallel()

ddl, err := rain.OpenDialect("postgres")
if err != nil {
t.Fatalf("OpenDialect(postgres): %v", err)
}
users := schema.Define("users", func(t *usersTable) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Email = t.Text("email").NotNull()
})
baseQuery := ddl.Select().Table(users).Column(users.Email)
viewBefore := schema.DefineView("user_emails", baseQuery, func(v *userEmailsView) {
v.Email = v.VarChar("email", 255)
})
filteredQuery := ddl.Select().Table(users).Column(users.Email).Where(schema.Raw("? = ?", users.Email, "alice@example.com"))
viewAfter := schema.DefineView("user_emails", filteredQuery, func(v *userEmailsView) {
v.Email = v.VarChar("email", 255)
})

before := mustBuildSnapshot(t, "postgres", []schema.TableReference{users, viewBefore})
after := mustBuildSnapshot(t, "postgres", []schema.TableReference{users, viewAfter})
beforeView, ok := tableSnapshotByName(before, "user_emails")
if !ok || !beforeView.IsView {
t.Fatalf("expected view snapshot to set IsView, got %#v", beforeView)
}
afterView, ok := tableSnapshotByName(after, "user_emails")
if !ok || !afterView.IsView {
t.Fatalf("expected view snapshot to set IsView, got %#v", afterView)
}

plan, err := DiffSnapshots(&before, after)
if err != nil {
t.Fatalf("DiffSnapshots returned error: %v", err)
}
if len(plan.Statements) != 2 {
t.Fatalf("expected drop and create view statements, got %d: %v", len(plan.Statements), plan.Statements)
}
if plan.Statements[0] != `DROP VIEW "user_emails"` {
t.Fatalf("expected DROP VIEW statement, got %q", plan.Statements[0])
}
if !strings.HasPrefix(plan.Statements[1], `CREATE VIEW "user_emails" AS `) {
t.Fatalf("expected CREATE VIEW statement, got %q", plan.Statements[1])
}
if !strings.Contains(plan.Statements[1], `'alice@example.com'`) {
t.Fatalf("expected updated view definition, got %q", plan.Statements[1])
}
}

func TestDiffSnapshotsUnchangedView(t *testing.T) {
t.Parallel()

ddl, err := rain.OpenDialect("sqlite")
if err != nil {
t.Fatalf("OpenDialect(sqlite): %v", err)
}
users := schema.Define("users", func(t *usersTable) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Email = t.Text("email").NotNull()
})
query := ddl.Select().Table(users).Column(users.Email)
view := schema.DefineView("user_emails", query, func(v *userEmailsView) {
v.Email = v.Text("email")
})

snapshot := mustBuildSnapshot(t, "sqlite", []schema.TableReference{users, view})
plan, err := DiffSnapshots(&snapshot, snapshot)
if err != nil {
t.Fatalf("DiffSnapshots returned error: %v", err)
}
if !plan.Empty() {
t.Fatalf("expected no statements for unchanged view, got %v", plan.Statements)
}
}

func TestDiffSnapshotsRejectAddForeignKeyOnSQLite(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -597,6 +673,15 @@ func TestLockNameColumnDDL(t *testing.T) {
}
}

func tableSnapshotByName(snapshot Snapshot, name string) (TableSnapshot, bool) {
for _, table := range snapshot.Tables {
if table.Name == name {
return table, true
}
}
return TableSnapshot{}, false
}

func mustBuildSnapshot(t *testing.T, dialectName string, tables []schema.TableReference) Snapshot {
t.Helper()

Expand All @@ -622,6 +707,11 @@ type postsTable struct {
UserID *schema.Column[int64]
}

type userEmailsView struct {
schema.TableModel
Email *schema.Column[string]
}

func usersTableWithoutNickname() schema.TableReference {
return schema.Define("users", func(t *usersTable) {
t.ID = t.BigSerial("id").PrimaryKey()
Expand Down
2 changes: 2 additions & 0 deletions pkg/migrator/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Snapshot struct {
// TableSnapshot stores a portable, deterministic representation of one table.
type TableSnapshot struct {
Name string `json:"name"`
IsView bool `json:"is_view,omitempty"`
CreateTableSQL string `json:"create_table_sql"`
Columns []ColumnSnapshot `json:"columns"`
Constraints []ConstraintSnapshot `json:"constraints"`
Expand Down Expand Up @@ -167,6 +168,7 @@ func BuildSnapshot(dialectName string, tables []schema.TableReference) (Snapshot

tableSnapshots = append(tableSnapshots, TableSnapshot{
Name: tableDef.Name,
IsView: tableDef.IsView,
CreateTableSQL: createTableSQL,
Columns: columnSnapshots,
Constraints: constraintSnapshots,
Expand Down
2 changes: 1 addition & 1 deletion pkg/rain/coverage_target_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ func TestCoverageDDLMethodsAndHelpers(t *testing.T) {
if _, err := columnDefinitionSQL(pg, users.TableDef(), &schema.ColumnDef{Name: "broken_default", Type: schema.ColumnType{DataType: schema.TypeText}, HasDefault: true, Default: struct{}{}}, false); err == nil {
t.Fatalf("expected columnDefinitionSQL default error")
}
if got := columnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" {
if got := ddlColumnTypeSQL(sqlite, users.CreatedAt.ColumnDef()); got != "TEXT" {
t.Fatalf("unexpected sqlite timestamp type: %q", got)
}
if shouldEmitAutoIncrementKeyword(pg, &schema.ColumnDef{Name: "id", Type: schema.ColumnType{DataType: schema.TypeBigSerial}}, true) {
Expand Down
68 changes: 61 additions & 7 deletions pkg/rain/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ func (db *DB) CreateTableSQL(table schema.TableReference) (string, error) {
return "", errors.New("rain: create table requires a non-nil table")
}

if table.TableDef().IsView {
return createViewSQL(db.dialect, table.TableDef())
}

return createTableSQL(db.dialect, table.TableDef())
}

Expand All @@ -32,6 +36,10 @@ func (db *DB) CreateIndexesSQL(table schema.TableReference) ([]string, error) {
return nil, errors.New("rain: create indexes requires a non-nil table")
}

if table.TableDef().IsView {
return nil, nil
}

return createIndexesSQL(db.dialect, table.TableDef())
}

Expand All @@ -50,6 +58,10 @@ func (db *DB) ColumnDefinitionSQL(table schema.TableReference, columnName string
return "", fmt.Errorf("rain: table %q has no column %q", tableDef.Name, columnName)
}

if tableDef.IsView {
return db.dialect.QuoteIdentifier(column.Name) + " " + ddlColumnTypeSQL(db.dialect, column), nil
}

inlinePrimaryKey := false
tablePrimaryKey, err := tablePrimaryKeyConstraint(tableDef)
if err != nil {
Expand All @@ -73,6 +85,10 @@ func (db *DB) AddConstraintSQL(table schema.TableReference, constraintName strin
}

tableDef := table.TableDef()
if tableDef.IsView {
return "", fmt.Errorf("rain: view %q does not support constraints", tableDef.Name)
}

for _, constraint := range tableDef.Constraints {
if constraint.Name != constraintName {
continue
Expand All @@ -97,6 +113,10 @@ func (db *DB) AddForeignKeySQL(table schema.TableReference, foreignKeyName strin
}

tableDef := table.TableDef()
if tableDef.IsView {
return "", fmt.Errorf("rain: view %q does not support foreign keys", tableDef.Name)
}

for _, foreignKey := range tableDef.ForeignKeys {
if foreignKey.Name != foreignKeyName {
continue
Expand Down Expand Up @@ -132,6 +152,35 @@ func (db *DB) ColumnDefaultSQL(table schema.TableReference, columnName string) (
return columnDefaultSQL(db.dialect, column)
}

func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
if d == nil {
return "", errors.New("rain: create view requires a configured dialect")
}
if table == nil {
return "", errors.New("rain: create view requires a non-nil table")
}
if !table.IsView {
return "", fmt.Errorf("rain: table %q is not a view", table.Name)
}
if table.ViewQuery == nil {
return "", fmt.Errorf("rain: view %q requires a defining query", table.Name)
}

ctx := newCompileContext(d)
ctx.useLiterals = true
if err := ctx.writeExpressionInContext(table.ViewQuery, expressionContext{noParens: true}); err != nil {
return "", err
}

var builder strings.Builder
builder.WriteString("CREATE VIEW ")
builder.WriteString(d.QuoteIdentifier(table.Name))
builder.WriteString(" AS ")
builder.WriteString(ctx.String())

return builder.String(), nil
}

func createTableSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
if d == nil {
return "", errors.New("rain: create table requires a configured dialect")
Expand Down Expand Up @@ -297,7 +346,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche
var parts []string
parts = append(parts, d.QuoteIdentifier(column.Name))

typeSQL := columnTypeSQL(d, column)
typeSQL := ddlColumnTypeSQL(d, column)
parts = append(parts, typeSQL)

if inlinePrimaryKey {
Expand Down Expand Up @@ -338,7 +387,7 @@ func columnDefinitionSQL(d dialect.Dialect, table *schema.TableDef, column *sche
return strings.Join(parts, " "), nil
}

func columnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string {
func ddlColumnTypeSQL(d dialect.Dialect, column *schema.ColumnDef) string {
typeSQL := d.DataType(column.Type)

if column.Type.DataType == schema.TypeVarChar && column.Type.Size > 0 && strings.EqualFold(typeSQL, "VARCHAR") {
Expand All @@ -363,20 +412,25 @@ func shouldEmitAutoIncrementKeyword(d dialect.Dialect, column *schema.ColumnDef,
if !inlinePrimaryKey {
return false
}
if column.Type.DataType != schema.TypeBigSerial {
return true
}

switch d.Name() {
case "postgres":
return false
return !isPostgresSerialType(column.Type.DataType)
case "sqlite":
return true
default:
return true
}
}

func isPostgresSerialType(dataType schema.DataType) bool {
switch dataType {
case schema.TypeBigSerial, schema.TypeSerial, schema.TypeSmallSerial:
return true
default:
return false
}
}

func columnDefaultSQL(d dialect.Dialect, column *schema.ColumnDef) (string, error) {
if column.DefaultSQL != "" {
return column.DefaultSQL, nil
Expand Down
Loading