From 50f8a4e14c431af0f433b36399ee2c9056971de1 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:51:24 +0000 Subject: [PATCH 1/4] feat(rain): support advanced UPSERT features (ON CONSTRAINT, WHERE clauses) - Updated InsertQuery.OnConflict to accept schema.Expression targets. - Added OnConstraint, TargetWhere, Where, and Set methods to InsertConflictBuilder. - Implemented dialect-aware rendering for EXCLUDED (Postgres/SQLite) and VALUES (MySQL). - Supported optional targets for ON CONFLICT DO NOTHING on Postgres/SQLite. - Added comprehensive tests for new UPSERT capabilities. Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/query_compile.go | 16 ++- pkg/rain/query_insert.go | 150 ++++++++++++++++----- pkg/rain/upsert_advanced_test.go | 215 +++++++++++++++++++++++++++++++ 3 files changed, 348 insertions(+), 33 deletions(-) create mode 100644 pkg/rain/upsert_advanced_test.go diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 270b5d1..c0d8b11 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -215,8 +215,17 @@ func (c *compileContext) writeSelectExpression(expr schema.Expression) error { func (c *compileContext) writeExpressionInContext(expr schema.Expression, context expressionContext) error { switch value := expr.(type) { + case excludedColumn: + if c.dialect.Name() == "mysql" { + c.writeString("VALUES(") + c.writeQuotedIdentifier(value.column.ColumnDef().Name) + c.writeByte(')') + } else { + c.writeString("EXCLUDED.") + c.writeQuotedIdentifier(value.column.ColumnDef().Name) + } case schema.ColumnReference: - c.writeColumn(value) + c.writeQualifiedColumn(value) case schema.ValueExpr: if c.useLiterals { sql, err := literalDDLSQL(c.dialect, value.Value) @@ -445,6 +454,11 @@ func (c *compileContext) writeRaw(raw schema.RawExpr) error { } func (c *compileContext) writeColumn(column schema.ColumnReference) { + def := column.ColumnDef() + c.writeQuotedIdentifier(def.Name) +} + +func (c *compileContext) writeQualifiedColumn(column schema.ColumnReference) { def := column.ColumnDef() table := def.Table qualifier := table.Name diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index d3ff0e0..9bd1fa5 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -35,9 +35,17 @@ const ( ) type insertConflictClause struct { - columns []schema.ColumnReference - action insertConflictAction - updates []schema.ColumnReference + targets []schema.Expression + targetWhere schema.Predicate + constraint string + action insertConflictAction + updates []assignment + where schema.Predicate +} + +type excludedColumn struct { + schema.ExpressionMarker + column schema.ColumnReference } // InsertConflictBuilder configures conflict behavior for INSERT statements. @@ -98,21 +106,63 @@ func (q *InsertQuery) Values(rows ...map[schema.ColumnReference]any) *InsertQuer } // OnConflict starts an upsert clause for PostgreSQL and SQLite dialects. -func (q *InsertQuery) OnConflict(columns ...schema.ColumnReference) *InsertConflictBuilder { - q.conflict = &insertConflictClause{columns: columns} +func (q *InsertQuery) OnConflict(targets ...schema.Expression) *InsertConflictBuilder { + q.conflict = &insertConflictClause{targets: targets} return &InsertConflictBuilder{query: q} } +// OnConstraint configures ON CONFLICT ON CONSTRAINT. +func (b *InsertConflictBuilder) OnConstraint(name string) *InsertConflictBuilder { + b.query.conflict.constraint = name + return b +} + +// TargetWhere configures the filter condition for the conflict target (e.g. for partial indexes). +func (b *InsertConflictBuilder) TargetWhere(p schema.Predicate) *InsertConflictBuilder { + b.query.conflict.targetWhere = p + return b +} + // DoNothing configures ON CONFLICT ... DO NOTHING. func (b *InsertConflictBuilder) DoNothing() *InsertQuery { b.query.conflict.action = insertConflictActionDoNothing return b.query } -// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values (PostgreSQL/SQLite) or VALUES() references for MySQL VALUES inserts. +// DoUpdateSet configures ON CONFLICT ... DO UPDATE SET using EXCLUDED values (PostgreSQL/SQLite) or VALUES() references for MySQL. func (b *InsertConflictBuilder) DoUpdateSet(columns ...schema.ColumnReference) *InsertQuery { b.query.conflict.action = insertConflictActionDoUpdateSet - b.query.conflict.updates = columns + for _, col := range columns { + b.query.conflict.updates = append(b.query.conflict.updates, assignment{ + column: col, + value: excludedColumn{column: col}, + }) + } + return b.query +} + +// Set adds an explicit assignment to the DO UPDATE SET clause. +func (b *InsertConflictBuilder) Set(column schema.ColumnReference, value any) *InsertConflictBuilder { + var expr schema.Expression + if e, ok := value.(schema.Expression); ok { + expr = e + } else { + expr = schema.ValueExpr{Value: value} + } + b.query.conflict.updates = append(b.query.conflict.updates, assignment{column: column, value: expr}) + return b +} + +// Where adds a filter condition to the DO UPDATE SET clause. +func (b *InsertConflictBuilder) Where(p schema.Predicate) *InsertConflictBuilder { + b.query.conflict.where = p + return b +} + +// DoUpdate configures ON CONFLICT ... DO UPDATE SET. +// Use this after calling Set() to finish the builder. +func (b *InsertConflictBuilder) DoUpdate() *InsertQuery { + b.query.conflict.action = insertConflictActionDoUpdateSet return b.query } @@ -495,7 +545,7 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return nil } if q.conflict.action == insertConflictActionNone { - return errors.New("rain: conflict action is required; call DoNothing() or DoUpdateSet(...)") + return errors.New("rain: conflict action is required; call DoNothing(), DoUpdateSet(...), or DoUpdate()") } if q.dialect.Name() != "postgres" && q.dialect.Name() != "sqlite" && q.dialect.Name() != "mysql" { @@ -503,9 +553,19 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { } if q.dialect.Name() == "mysql" { - if len(q.conflict.columns) > 0 { + if len(q.conflict.targets) > 0 { return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE cannot target specific conflict columns; call OnConflict() without columns") } + if q.conflict.constraint != "" { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support ON CONSTRAINT") + } + if q.conflict.targetWhere != nil { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support target WHERE clause") + } + if q.conflict.where != nil { + return errors.New("rain: MySQL ON DUPLICATE KEY UPDATE does not support DO UPDATE WHERE clause") + } + if q.conflict.action == insertConflictActionDoNothing { noopColumn, err := mysqlConflictNoopColumn(q.table) if err != nil { @@ -525,37 +585,55 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: MySQL conflict DO UPDATE is not supported for INSERT ... SELECT") } ctx.writeString(" ON DUPLICATE KEY UPDATE ") - for idx, col := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { + for idx, item := range q.conflict.updates { + if err := validateAssignmentTarget(q.table, item); err != nil { return err } if idx > 0 { ctx.writeString(", ") } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeString(" = VALUES(") - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeByte(')') + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + ctx.writeString(" = ") + if err := ctx.writeExpression(item.value); err != nil { + return err + } } } return nil } - if len(q.conflict.columns) == 0 { - return errors.New("rain: conflict clause requires at least one target column") - } - - ctx.writeString(" ON CONFLICT (") - for idx, col := range q.conflict.columns { - if err := validateColumnBelongsToTable(q.table, col.ColumnDef()); err != nil { - return err + if q.conflict.constraint != "" { + if len(q.conflict.targets) > 0 { + return errors.New("rain: ON CONFLICT cannot specify both targets and ON CONSTRAINT") } - if idx > 0 { - ctx.writeString(", ") + ctx.writeString(" ON CONFLICT ON CONSTRAINT ") + ctx.writeQuotedIdentifier(q.conflict.constraint) + } else { + ctx.writeString(" ON CONFLICT") + if len(q.conflict.targets) > 0 { + ctx.writeString(" (") + for idx, target := range q.conflict.targets { + if idx > 0 { + ctx.writeString(", ") + } + if col, ok := target.(schema.ColumnReference); ok { + ctx.writeColumn(col) + } else if err := ctx.writeExpression(target); err != nil { + return err + } + } + ctx.writeByte(')') + } else if q.conflict.action == insertConflictActionDoUpdateSet { + return errors.New("rain: conflict DO UPDATE requires at least one target column or constraint") + } + + if q.conflict.targetWhere != nil { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(q.conflict.targetWhere); err != nil { + return err + } } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) } - ctx.writeByte(')') switch q.conflict.action { case insertConflictActionDoNothing: @@ -565,16 +643,24 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: conflict DO UPDATE requires at least one update column") } ctx.writeString(" DO UPDATE SET ") - for idx, col := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, assignment{column: col}); err != nil { + for idx, item := range q.conflict.updates { + if err := validateAssignmentTarget(q.table, item); err != nil { return err } if idx > 0 { ctx.writeString(", ") } - ctx.writeQuotedIdentifier(col.ColumnDef().Name) - ctx.writeString(" = EXCLUDED.") - ctx.writeQuotedIdentifier(col.ColumnDef().Name) + ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) + ctx.writeString(" = ") + if err := ctx.writeExpression(item.value); err != nil { + return err + } + } + if q.conflict.where != nil { + ctx.writeString(" WHERE ") + if err := ctx.writePredicate(q.conflict.where); err != nil { + return err + } } } diff --git a/pkg/rain/upsert_advanced_test.go b/pkg/rain/upsert_advanced_test.go new file mode 100644 index 0000000..ebc9912 --- /dev/null +++ b/pkg/rain/upsert_advanced_test.go @@ -0,0 +1,215 @@ +package rain_test + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" +) + +func TestInsertOnConflictAdvancedPostgres(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("do nothing without targets", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) + + t.Run("on constraint", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(). + OnConstraint("users_email_key"). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ON CONSTRAINT "users_email_key" DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) + + t.Run("target where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "users"."active" = $2 DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 || args[1] != true { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("do update set with custom values and where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(users.Email). + Set(users.Active, true). + Where(users.Active.Eq(false)). + DoUpdateSet(users.Name). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO UPDATE SET "active" = $3, "name" = EXCLUDED."name" WHERE "users"."active" = $4` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 4 || args[2] != true || args[3] != false { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("do update with only custom sets", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + Set(users.Name, "Conflicted"). + DoUpdate(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") DO UPDATE SET "name" = $2` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 || args[1] != "Conflicted" { + t.Fatalf("unexpected args: %#v", args) + } + }) +} + +func TestInsertOnConflictAdvancedSQLite(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("sqlite") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("target where", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + OnConflict(users.Email). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES (?) ON CONFLICT ("email") WHERE "users"."active" = ? DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 2 || args[1] != true { + t.Fatalf("unexpected args: %#v", args) + } + }) +} + +func TestInsertOnConflictAdvancedMySQL(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("mysql") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("do update set with custom values", func(t *testing.T) { + sqlText, args, err := db.Insert(). + Table(users). + Set(users.Email, "alice@example.com"). + Set(users.Name, "Alice"). + OnConflict(). + Set(users.Active, true). + DoUpdateSet(users.Name). + ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := "INSERT INTO `users` (`email`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `active` = ?, `name` = VALUES(`name`)" + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + if len(args) != 3 || args[2] != true { + t.Fatalf("unexpected args: %#v", args) + } + }) + + t.Run("rejects postgres-only features", func(t *testing.T) { + _, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + OnConstraint("foo"). + DoNothing(). + ToSQL() + if err == nil { + t.Fatal("expected error for ON CONSTRAINT on MySQL") + } + + _, _, err = db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + if err == nil { + t.Fatal("expected error for TargetWhere on MySQL") + } + + _, _, err = db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + Where(users.Active.Eq(true)). + DoUpdateSet(users.Name). + ToSQL() + if err == nil { + t.Fatal("expected error for Where on MySQL") + } + }) +} From 8069556d080f9a898e7b2e4cc401474848591516 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 03:37:14 +0000 Subject: [PATCH 2/4] feat(rain): enhance UPSERT support (ON CONSTRAINT, WHERE clauses, custom assignments) - Added support for named constraints via `OnConstraint` - Added support for partial index targets via `TargetWhere` - Added support for conditional updates via `Where` in conflict builder - Added `Set` and `DoUpdate` for explicit assignments in the update clause - Implemented dialect-aware rendering for excluded values (`EXCLUDED.` vs `VALUES()`) - Fixed table qualification in conflict target lists for PostgreSQL - Added comprehensive tests for all new capabilities across dialects Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/query_compile.go | 41 ++++++++++++++++++-------------- pkg/rain/query_insert.go | 34 +++++++++++++++++--------- pkg/rain/upsert_advanced_test.go | 12 +++++----- 3 files changed, 52 insertions(+), 35 deletions(-) diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index c0d8b11..c4cf65d 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -201,8 +201,9 @@ func (c *compileContext) writePredicate(predicate schema.Predicate) error { } type expressionContext struct { - allowAlias bool - noParens bool + allowAlias bool + noParens bool + unqualified bool } func (c *compileContext) writeExpression(expr schema.Expression) error { @@ -225,7 +226,11 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeQuotedIdentifier(value.column.ColumnDef().Name) } case schema.ColumnReference: - c.writeQualifiedColumn(value) + if context.unqualified { + c.writeColumn(value) + } else { + c.writeQualifiedColumn(value) + } case schema.ValueExpr: if c.useLiterals { sql, err := literalDDLSQL(c.dialect, value.Value) @@ -246,20 +251,20 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.argPlan = append(c.argPlan, compiledArg{kind: compiledArgNamedPlaceholder, name: value.Name}) c.writeString(c.dialect.Placeholder(index)) case schema.ComparisonExpr: - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } c.writeByte(' ') c.writeString(value.Operator) c.writeByte(' ') - if err := c.writeExpression(value.Right); err != nil { + if err := c.writeExpressionInContext(value.Right, context); err != nil { return err } case schema.InExpr: if len(value.Values) == 0 { return errors.New("rain: IN predicate requires at least one value") } - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } if value.Negated { @@ -271,7 +276,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if idx > 0 { c.writeString(", ") } - ctx := expressionContext{} + ctx := expressionContext{unqualified: context.unqualified} if len(value.Values) == 1 { ctx.noParens = true } @@ -281,7 +286,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex } c.writeByte(')') case schema.BetweenExpr: - if err := c.writeExpression(value.Left); err != nil { + if err := c.writeExpressionInContext(value.Left, context); err != nil { return err } if value.Negated { @@ -289,16 +294,16 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex } else { c.writeString(" BETWEEN ") } - if err := c.writeExpression(value.Start); err != nil { + if err := c.writeExpressionInContext(value.Start, context); err != nil { return err } c.writeString(" AND ") - if err := c.writeExpression(value.End); err != nil { + if err := c.writeExpressionInContext(value.End, context); err != nil { return err } case schema.NotExpr: c.writeString("NOT (") - if err := c.writePredicate(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } c.writeByte(')') @@ -338,7 +343,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeString(value.Operator) c.writeByte(' ') } - if err := c.writePredicate(part); err != nil { + if err := c.writeExpressionInContext(part, context); err != nil { return err } } @@ -350,23 +355,23 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeString("CASE") if value.ValueExpression != nil { c.writeByte(' ') - if err := c.writeExpression(value.ValueExpression); err != nil { + if err := c.writeExpressionInContext(value.ValueExpression, context); err != nil { return err } } for _, pair := range value.WhenThenPairs { c.writeString(" WHEN ") - if err := c.writeExpression(pair.When); err != nil { + if err := c.writeExpressionInContext(pair.When, context); err != nil { return err } c.writeString(" THEN ") - if err := c.writeExpression(pair.Then); err != nil { + if err := c.writeExpressionInContext(pair.Then, context); err != nil { return err } } if value.ElseExpression != nil { c.writeString(" ELSE ") - if err := c.writeExpression(value.ElseExpression); err != nil { + if err := c.writeExpressionInContext(value.ElseExpression, context); err != nil { return err } } @@ -387,7 +392,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex case value.Star: c.writeByte('*') case value.Expr != nil: - if err := c.writeExpression(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } default: @@ -406,7 +411,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if idx > 0 { c.writeString(", ") } - if err := c.writeExpression(part); err != nil { + if err := c.writeExpressionInContext(part, context); err != nil { return err } } diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index 9bd1fa5..2ca308d 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -585,10 +585,11 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: MySQL conflict DO UPDATE is not supported for INSERT ... SELECT") } ctx.writeString(" ON DUPLICATE KEY UPDATE ") - for idx, item := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, item); err != nil { - return err - } + updates, err := mergeAssignments(q.table, nil, q.conflict.updates) + if err != nil { + return err + } + for idx, item := range updates { if idx > 0 { ctx.writeString(", ") } @@ -603,6 +604,9 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { } if q.conflict.constraint != "" { + if q.dialect.Name() != "postgres" { + return fmt.Errorf("rain: ON CONSTRAINT is not supported by %s dialect", q.dialect.Name()) + } if len(q.conflict.targets) > 0 { return errors.New("rain: ON CONFLICT cannot specify both targets and ON CONSTRAINT") } @@ -617,8 +621,11 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { ctx.writeString(", ") } if col, ok := target.(schema.ColumnReference); ok { - ctx.writeColumn(col) - } else if err := ctx.writeExpression(target); err != nil { + if err := validateColumnBelongsToTable(q.table, col.ColumnDef()); err != nil { + return err + } + } + if err := ctx.writeExpressionInContext(target, expressionContext{unqualified: true}); err != nil { return err } } @@ -629,7 +636,11 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { if q.conflict.targetWhere != nil { ctx.writeString(" WHERE ") - if err := ctx.writePredicate(q.conflict.targetWhere); err != nil { + oldUseLiterals := ctx.useLiterals + ctx.useLiterals = true + err := ctx.writeExpressionInContext(q.conflict.targetWhere, expressionContext{unqualified: true}) + ctx.useLiterals = oldUseLiterals + if err != nil { return err } } @@ -643,10 +654,11 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { return errors.New("rain: conflict DO UPDATE requires at least one update column") } ctx.writeString(" DO UPDATE SET ") - for idx, item := range q.conflict.updates { - if err := validateAssignmentTarget(q.table, item); err != nil { - return err - } + updates, err := mergeAssignments(q.table, nil, q.conflict.updates) + if err != nil { + return err + } + for idx, item := range updates { if idx > 0 { ctx.writeString(", ") } diff --git a/pkg/rain/upsert_advanced_test.go b/pkg/rain/upsert_advanced_test.go index ebc9912..f96bec3 100644 --- a/pkg/rain/upsert_advanced_test.go +++ b/pkg/rain/upsert_advanced_test.go @@ -62,11 +62,11 @@ func TestInsertOnConflictAdvancedPostgres(t *testing.T) { t.Fatalf("ToSQL returned error: %v", err) } - wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "users"."active" = $2 DO NOTHING` + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "active" = TRUE DO NOTHING` if sqlText != wantSQL { t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) } - if len(args) != 2 || args[1] != true { + if len(args) != 1 { t.Fatalf("unexpected args: %#v", args) } }) @@ -85,7 +85,7 @@ func TestInsertOnConflictAdvancedPostgres(t *testing.T) { t.Fatalf("ToSQL returned error: %v", err) } - wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO UPDATE SET "active" = $3, "name" = EXCLUDED."name" WHERE "users"."active" = $4` + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2) ON CONFLICT ("email") DO UPDATE SET "name" = EXCLUDED."name", "active" = $3 WHERE "users"."active" = $4` if sqlText != wantSQL { t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) } @@ -137,11 +137,11 @@ func TestInsertOnConflictAdvancedSQLite(t *testing.T) { t.Fatalf("ToSQL returned error: %v", err) } - wantSQL := `INSERT INTO "users" ("email") VALUES (?) ON CONFLICT ("email") WHERE "users"."active" = ? DO NOTHING` + wantSQL := `INSERT INTO "users" ("email") VALUES (?) ON CONFLICT ("email") WHERE "active" = 1 DO NOTHING` if sqlText != wantSQL { t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) } - if len(args) != 2 || args[1] != true { + if len(args) != 1 { t.Fatalf("unexpected args: %#v", args) } }) @@ -169,7 +169,7 @@ func TestInsertOnConflictAdvancedMySQL(t *testing.T) { t.Fatalf("ToSQL returned error: %v", err) } - wantSQL := "INSERT INTO `users` (`email`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `active` = ?, `name` = VALUES(`name`)" + wantSQL := "INSERT INTO `users` (`email`, `name`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `name` = VALUES(`name`), `active` = ?" if sqlText != wantSQL { t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) } From 19c02ec817efb9164263474827ac8b4472fb6aeb Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 23:57:34 +0000 Subject: [PATCH 3/4] fix(rain): address PR feedback on advanced UPSERT - Propagate `unqualified` context in `NullCheckExpr` for `TargetWhere` clauses. - Add validation to reject `TargetWhere` without conflict target columns. - Update tests to cover these edge cases. Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/query_compile.go | 2 +- pkg/rain/query_insert.go | 3 +++ pkg/rain/upsert_advanced_test.go | 40 ++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index c4cf65d..ed5d75a 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -327,7 +327,7 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.writeByte(')') } case schema.NullCheckExpr: - if err := c.writeExpression(value.Expr); err != nil { + if err := c.writeExpressionInContext(value.Expr, context); err != nil { return err } if value.Negated { diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index 2ca308d..d264016 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -635,6 +635,9 @@ func (q *InsertQuery) writeConflictClause(ctx *compileContext) error { } if q.conflict.targetWhere != nil { + if len(q.conflict.targets) == 0 { + return errors.New("rain: conflict targetWhere requires at least one conflict target column") + } ctx.writeString(" WHERE ") oldUseLiterals := ctx.useLiterals ctx.useLiterals = true diff --git a/pkg/rain/upsert_advanced_test.go b/pkg/rain/upsert_advanced_test.go index f96bec3..d114a98 100644 --- a/pkg/rain/upsert_advanced_test.go +++ b/pkg/rain/upsert_advanced_test.go @@ -116,6 +116,46 @@ func TestInsertOnConflictAdvancedPostgres(t *testing.T) { }) } +func TestInsertOnConflictValidation(t *testing.T) { + t.Parallel() + + db, _ := rain.OpenDialect("postgres") + users, _ := defineTables() + + t.Run("targetWhere without targets returns error", func(t *testing.T) { + _, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(). + TargetWhere(users.Active.Eq(true)). + DoNothing(). + ToSQL() + + if err == nil || err.Error() != "rain: conflict targetWhere requires at least one conflict target column" { + t.Fatalf("expected targetWhere validation error, got %v", err) + } + }) + + t.Run("NullCheckExpr in targetWhere is unqualified", func(t *testing.T) { + sqlText, _, err := db.Insert(). + Table(users). + Set(users.Email, "a"). + OnConflict(users.Email). + TargetWhere(users.Name.IsNull()). + DoNothing(). + ToSQL() + + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email") VALUES ($1) ON CONFLICT ("email") WHERE "name" IS NULL DO NOTHING` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } + }) +} + func TestInsertOnConflictAdvancedSQLite(t *testing.T) { t.Parallel() From 9846b417c7f8b8b0525561388bc912d93a4fef4c Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:03:21 +0000 Subject: [PATCH 4/4] chore: run make fmt to fix CI failure Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/upsert_advanced_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/rain/upsert_advanced_test.go b/pkg/rain/upsert_advanced_test.go index d114a98..2d1248e 100644 --- a/pkg/rain/upsert_advanced_test.go +++ b/pkg/rain/upsert_advanced_test.go @@ -144,7 +144,6 @@ func TestInsertOnConflictValidation(t *testing.T) { TargetWhere(users.Name.IsNull()). DoNothing(). ToSQL() - if err != nil { t.Fatalf("ToSQL returned error: %v", err) }