diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 270b5d1..ed5d75a 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -201,8 +201,9 @@ func (c *compileContext) writePredicate(predicate schema.Predicate) error { } type expressionContext struct { - allowAlias bool - noParens bool + allowAlias bool + noParens bool + unqualified bool } func (c *compileContext) writeExpression(expr schema.Expression) error { @@ -215,8 +216,21 @@ func (c *compileContext) writeSelectExpression(expr schema.Expression) error { func (c *compileContext) writeExpressionInContext(expr schema.Expression, context expressionContext) error { switch value := expr.(type) { + case excludedColumn: + if c.dialect.Name() == "mysql" { + c.writeString("VALUES(") + c.writeQuotedIdentifier(value.column.ColumnDef().Name) + c.writeByte(')') + } else { + c.writeString("EXCLUDED.") + c.writeQuotedIdentifier(value.column.ColumnDef().Name) + } case schema.ColumnReference: - c.writeColumn(value) + if context.unqualified { + c.writeColumn(value) + } else { + c.writeQualifiedColumn(value) + } case schema.ValueExpr: if c.useLiterals { sql, err := literalDDLSQL(c.dialect, value.Value) @@ -237,20 +251,20 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgNamedPlaceholder, name: value.Name}) c.writeString(c.dialect.Placeholder(index)) case schema.ComparisonExpr: - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } c.writeByte(' ') c.writeString(value.Operator) c.writeByte(' ') - if err := c.writeExpression(value.Right); err != nil { + if err := c.writeExpressionInContext(value.Right, context); err != nil { return err } case schema.InExpr: if len(value.Values) == 0 { return errors.New("rain: IN predicate requires at least one value") } - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } if value.Negated { @@ -262,7 +276,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if idx > 0 { c.writeString(", ") } - ctx := expressionContext{} + ctx := expressionContext{unqualified: context.unqualified} if len(value.Values) == 1 { ctx.noParens = true } @@ -272,7 +286,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex } c.writeByte(')') case schema.BetweenExpr: - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } if value.Negated { @@ -280,16 +294,16 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex } else { c.writeString(" BETWEEN ") } - if err := c.writeExpression(value.Start); err != nil { + if err := c.writeExpressionInContext(value.Start, context); err != nil { return err } c.writeString(" AND ") - if err := c.writeExpression(value.End); err != nil { + if err := c.writeExpressionInContext(value.End, context); err != nil { return err } case schema.NotExpr: c.writeString("NOT (") - if err := c.writePredicate(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } c.writeByte(')') @@ -313,7 +327,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeByte(')') } case schema.NullCheckExpr: - if err := c.writeExpression(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } if value.Negated { @@ -329,7 +343,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeString(value.Operator) c.writeByte(' ') } - if err := c.writePredicate(part); err != nil { + if err := c.writeExpressionInContext(part, context); err != nil { return err } } @@ -341,23 +355,23 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeString("CASE") if value.ValueExpression != nil { c.writeByte(' ') - if err := c.writeExpression(value.ValueExpression); err != nil { + if err := c.writeExpressionInContext(value.ValueExpression, context); err != nil { return err } } for _, pair := range value.WhenThenPairs { c.writeString(" WHEN ") - if err := c.writeExpression(pair.When); err != nil { + if err := c.writeExpressionInContext(pair.When, context); err != nil { return err } c.writeString(" THEN ") - if err := c.writeExpression(pair.Then); err != nil { + if err := c.writeExpressionInContext(pair.Then, context); err != nil { return err } } if value.ElseExpression != nil { c.writeString(" ELSE ") - if err := c.writeExpression(value.ElseExpression); err != nil { + if err := c.writeExpressionInContext(value.ElseExpression, context); err != nil { return err } } @@ -378,7 +392,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex case value.Star: c.writeByte('*') case value.Expr != nil: - if err := c.writeExpression(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } default: @@ -397,7 +411,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if idx > 0 { c.writeString(", ") } - if err := c.writeExpression(part); err != nil { + if err := c.writeExpressionInContext(part, context); err != nil { return err } } @@ -445,6 +459,11 @@ func (c *compileContext) writeRaw(raw schema.RawExpr) error { } func (c *compileContext) writeColumn(column schema.ColumnReference) { + def := column.ColumnDef() + c.writeQuotedIdentifier(def.Name) +} + +func (c *compileContext) writeQualifiedColumn(column schema.ColumnReference) { def := column.ColumnDef() table := def.Table qualifier := table.Name diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index d3ff0e0..d264016 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -35,9 +35,17 @@ const ( ) type insertConflictClause struct { - columns []schema.ColumnReference - action insertConflictAction - updates []schema.ColumnReference + targets []schema.Expression + targetWhere schema.Predicate + constraint string + action insertConflictAction + updates []assignment + where schema.Predicate +} + +type excludedColumn struct { + schema.ExpressionMarker + column schema.ColumnReference } // InsertConflictBuilder configures conflict behavior for INSERT statements. @@ -98,21 +106,63 @@ func (q *InsertQuery) Values(rows ...map[schema.ColumnReference]any) *InsertQuer } // OnConflict starts an upsert clause for PostgreSQL and SQLite dialects. -func (q *InsertQuery) OnConflict(columns ...schema.ColumnReference) *InsertConflictBuilder { - q.conflict = &insertConflictClause{columns: columns} +func (q *InsertQuery) OnConflict(targets ...schema.Expression) *InsertConflictBuilder { + q.conflict = &insertConflictClause{targets: targets} return &InsertConflictBuilder{query: q} } +// OnConstraint configures ON CONFLICT ON CONSTRAINT. +func (b *InsertConflictBuilder) OnConstraint(name string) *InsertConflictBuilder { + b.query.conflict.constraint = name + return b +} + +// TargetWhere configures the filter condition for the conflict target (e.g. for partial indexes). +func (b *InsertConflictBuilder) TargetWhere(p schema.Predicate) *InsertConflictBuilder { + b.query.conflict.targetWhere = p + return b +} + // DoNothing configures ON CONFLICT ... DO NOTHING. func (b *InsertConflictBuilder) DoNothing() *InsertQuery { b.query.conflict.action = insertConflictActionDoNothing return b.query } -// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values (PostgreSQL/SQLite) or VALUES() references for MySQL VALUES inserts. +// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values (PostgreSQL/SQLite) or VALUES() references for MySQL. func (b *InsertConflictBuilder) DoUpdateSet(columns ...schema.ColumnReference) *InsertQuery { b.query.conflict.action = insertConflictActionDoUpdateSet - b.query.conflict.updates = columns + for _, col := range columns { + b.query.conflict.updates = append(b.query.conflict.updates, assignment{ + column: col, + value: excludedColumn{column: col}, + }) + } + return b.query +} + +// Set adds an explicit assignment to the DO UPDATE SET clause. +func (b *InsertConflictBuilder) Set(column schema.ColumnReference, value any) *InsertConflictBuilder { + var expr schema.Expression + if e, ok := value.(schema.Expression); ok { + expr = e + } else { + expr = schema.ValueExpr{Value: value} + } + b.query.conflict.updates = append(b.query.conflict.updates, assignment{column: column, value: expr}) + return b +} + +// Where adds a filter condition to the DO UPDATE SET clause. +func (b *InsertConflictBuilder) Where(p schema.Predicate) *InsertConflictBuilder { + b.query.conflict.where = p + return b +} + +// DoUpdate configures ON CONFLICT ... DO UPDATE SET. +// Use this after calling Set() to finish the builder. +func (b *InsertConflictBuilder) DoUpdate() *InsertQuery { + b.query.conflict.action = insertConflictActionDoUpdateSet return b.query } @@ -495,7 +545,7 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return nil } if q.conflict.action == insertConflictActionNone { - return errors.New("rain: conflict action is required; call DoNothing() or DoUpdateSet(...)") + return errors.New("rain: conflict action is required; call DoNothing(), DoUpdateSet(...), or DoUpdate()") } if q.dialect.Name() != "postgres" && q.dialect.Name() != "sqlite" && q.dialect.Name() != "mysql" { @@ -503,9 +553,19 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { } if q.dialect.Name() == "mysql" { - if len(q.conflict.columns) > 0 { + if len(q.conflict.targets) > 0 { return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE cannot target specific conflict columns; call OnConflict() without columns") } + if q.conflict.constraint != "" { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support ON CONSTRAINT") + } + if q.conflict.targetWhere != nil { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support target WHERE clause") + } + if q.conflict.where != nil { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support DO UPDATE WHERE clause") + } + if q.conflict.action == insertConflictActionDoNothing { noopColumn, err := mysqlConflictNoopColumn(q.table) if err != nil { @@ -525,37 +585,69 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: MySQL conflict DO UPDATE is not supported for INSERT ... SELECT") } ctx.writeString(" ON DUPLICATE KEY UPDATE ") - for idx, col := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { - return err - } + updates, err := mergeAssignments(q.table, nil, q.conflict.updates) + if err != nil { + return err + } + for idx, item := range updates { if idx > 0 { ctx.writeString(", ") } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeString(" = VALUES(") - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeByte(')') + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + ctx.writeString(" = ") + if err := ctx.writeExpression(item.value); err != nil { + return err + } } } return nil } - if len(q.conflict.columns) == 0 { - return errors.New("rain: conflict clause requires at least one target column") - } - - ctx.writeString(" ON CONFLICT (") - for idx, col := range q.conflict.columns { - if err := validateColumnBelongsToTable(q.table, col.ColumnDef()); err != nil { - return err + if q.conflict.constraint != "" { + if q.dialect.Name() != "postgres" { + return fmt.Errorf("rain: ON CONSTRAINT is not supported by %s dialect", q.dialect.Name()) } - if idx > 0 { - ctx.writeString(", ") + if len(q.conflict.targets) > 0 { + return errors.New("rain: ON CONFLICT cannot specify both targets and ON CONSTRAINT") + } + ctx.writeString(" ON CONFLICT ON CONSTRAINT ") + ctx.writeQuotedIdentifier(q.conflict.constraint) + } else { + ctx.writeString(" ON CONFLICT") + if len(q.conflict.targets) > 0 { + ctx.writeString(" (") + for idx, target := range q.conflict.targets { + if idx > 0 { + ctx.writeString(", ") + } + if col, ok := target.(schema.ColumnReference); ok { + if err := validateColumnBelongsToTable(q.table, col.ColumnDef()); err != nil { + return err + } + } + if err := ctx.writeExpressionInContext(target, expressionContext{unqualified: true}); err != nil { + return err + } + } + ctx.writeByte(')') + } else if q.conflict.action == insertConflictActionDoUpdateSet { + return errors.New("rain: conflict DO UPDATE requires at least one target column or constraint") + } + + if q.conflict.targetWhere != nil { + if len(q.conflict.targets) == 0 { + return errors.New("rain: conflict targetWhere requires at least one conflict target column") + } + ctx.writeString(" WHERE ") + oldUseLiterals := ctx.useLiterals + ctx.useLiterals = true + err := ctx.writeExpressionInContext(q.conflict.targetWhere, expressionContext{unqualified: true}) + ctx.useLiterals = oldUseLiterals + if err != nil { + return err + } } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) } - ctx.writeByte(')') switch q.conflict.action { case insertConflictActionDoNothing: @@ -565,16 +657,25 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: conflict DO UPDATE requires at least one update column") } ctx.writeString(" DO UPDATE SET ") - for idx, col := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { - return err - } + updates, err := mergeAssignments(q.table, nil, q.conflict.updates) + if err != nil { + return err + } + for idx, item := range updates { if idx > 0 { ctx.writeString(", ") } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeString(" = EXCLUDED.") - ctx.writeQuotedIdentifier(col.ColumnDef().Name) + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + ctx.writeString(" = ") + if err := ctx.writeExpression(item.value); err != nil { + return err + } + } + if q.conflict.where != nil { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(q.conflict.where); err != nil { + return err + } } } diff --git a/pkg/rain/upsert_advanced_test.go b/pkg/rain/upsert_advanced_test.go new file mode 100644 index 0000000..2d1248e --- /dev/null +++ b/pkg/rain/upsert_advanced_test.go @@ -0,0 +1,254 @@ +package rain_test + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" +) + +func TestInsertOnConflictAdvancedPostgres(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("do nothing without targets", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) + + t.Run("on constraint", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(). + OnConstraint("users_email_key"). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ON CONSTRAINT "users_email_key" DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) + + t.Run("target where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "active" = TRUE DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 1 { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("do update set with custom values and where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(users.Email). + Set(users.Active, true). + Where(users.Active.Eq(false)). + DoUpdateSet(users.Name). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = $3 WHERE "users"."active" = $4` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 4 || args[2] != true || args[3] != false { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("do update with only custom sets", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + Set(users.Name, "Conflicted"). + DoUpdate(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "name" = $2` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 || args[1] != "Conflicted" { + t.Fatalf("unexpected args: %#v", args) + } + }) +} + +func TestInsertOnConflictValidation(t *testing.T) { + t.Parallel() + + db, _ := rain.OpenDialect("postgres") + users, _ := defineTables() + + t.Run("targetWhere without targets returns error", func(t *testing.T) { + _, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + + if err == nil || err.Error() != "rain: conflict targetWhere requires at least one conflict target column" { + t.Fatalf("expected targetWhere validation error, got %v", err) + } + }) + + t.Run("NullCheckExpr in targetWhere is unqualified", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(users.Email). + TargetWhere(users.Name.IsNull()). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "name" IS NULL DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) +} + +func TestInsertOnConflictAdvancedSQLite(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("target where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES (?) ON CONFLICT ("email") WHERE "active" = 1 DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 1 { + t.Fatalf("unexpected args: %#v", args) + } + }) +} + +func TestInsertOnConflictAdvancedMySQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("mysql") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("do update set with custom values", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(). + Set(users.Active, true). + DoUpdateSet(users.Name). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := "INSERT INTO `users` (`email`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `active` = ?" + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 3 || args[2] != true { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("rejects postgres-only features", func(t *testing.T) { + _, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + OnConstraint("foo"). + DoNothing(). + ToSQL() + if err == nil { + t.Fatal("expected error for ON CONSTRAINT on MySQL") + } + + _, _, err = db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err == nil { + t.Fatal("expected error for TargetWhere on MySQL") + } + + _, _, err = db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + Where(users.Active.Eq(true)). + DoUpdateSet(users.Name). + ToSQL() + if err == nil { + t.Fatal("expected error for Where on MySQL") + } + }) +}