From cc3c26c8b73912ff76874187253307354a3bbb5b Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sat, 30 May 2026 22:39:21 +0000 Subject: [PATCH 1/2] feat(rain): support CTEs, ORDER BY, and LIMIT in UPDATE/DELETE This change enhances UpdateQuery and DeleteQuery to support advanced SQL clauses, bringing them closer to parity with Drizzle ORM. - Added new dialect features for Update/Delete Order and Limit. - Enabled CTE support for MySQL and SQLite. - Refactored shared SQL rendering logic into pkg/rain/query_common.go. - Implemented With(), OrderBy(), and Limit() on UpdateQuery and DeleteQuery. - Added comprehensive unit tests for all new functionality. Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/dialect/dialect_test.go | 4 +- pkg/dialect/feature.go | 4 ++ pkg/dialect/mysql.go | 10 ++- pkg/dialect/sqlite.go | 7 +- pkg/rain/query_common.go | 69 +++++++++++++++++++ pkg/rain/query_delete.go | 32 +++++++++ pkg/rain/query_delete_test.go | 117 ++++++++++++++++++++++++++++++++ pkg/rain/query_select.go | 62 ++--------------- pkg/rain/query_select_test.go | 4 +- pkg/rain/query_update.go | 32 +++++++++ pkg/rain/query_update_test.go | 123 ++++++++++++++++++++++++++++++++++ pkg/rain/query_write_test.go | 23 +++++-- 12 files changed, 417 insertions(+), 70 deletions(-) create mode 100644 pkg/rain/query_delete_test.go create mode 100644 pkg/rain/query_update_test.go diff --git a/pkg/dialect/dialect_test.go b/pkg/dialect/dialect_test.go index 843d96e..8e7997b 100644 --- a/pkg/dialect/dialect_test.go +++ b/pkg/dialect/dialect_test.go @@ -224,7 +224,7 @@ func TestMySQLDialect(t *testing.T) { if got := d.Name(); got != "mysql" { t.Fatalf("unexpected name: %q", got) } - if got := d.Features(); got != FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureSelectLocking { + if got := d.Features(); got != FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureSelectLocking|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit { t.Fatalf("unexpected features: %b", got) } if got := d.QuoteIdentifier("user`name"); got != "`user``name`" { @@ -314,7 +314,7 @@ func TestSQLiteDialect(t *testing.T) { if got := d.Name(); got != "sqlite" { t.Fatalf("unexpected name: %q", got) } - if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureNullsOrder { + if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureNullsOrder|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit { t.Fatalf("unexpected features: %b", got) } if got := d.QuoteIdentifier(`user"name`); got != `"user""name"` { diff --git a/pkg/dialect/feature.go b/pkg/dialect/feature.go index d11cdb8..b9dea8f 100644 --- a/pkg/dialect/feature.go +++ b/pkg/dialect/feature.go @@ -15,6 +15,10 @@ const ( FeatureSelectLocking FeatureNullsOrder FeatureSelectDistinctOn + FeatureUpdateOrder + FeatureUpdateLimit + FeatureDeleteOrder + FeatureDeleteLimit ) // HasFeature reports whether a feature set includes the requested capability. diff --git a/pkg/dialect/mysql.go b/pkg/dialect/mysql.go index 6a9bb83..5ab4294 100644 --- a/pkg/dialect/mysql.go +++ b/pkg/dialect/mysql.go @@ -19,7 +19,15 @@ func (d *MySQLDialect) Name() string { // Features returns MySQL capabilities supported by Rain. func (d *MySQLDialect) Features() Feature { - return FeatureOffset | FeatureUpsert | FeatureSavepoint | FeatureSelectLocking + return FeatureOffset | + FeatureUpsert | + FeatureSavepoint | + FeatureSelectLocking | + FeatureCTE | + FeatureUpdateOrder | + FeatureUpdateLimit | + FeatureDeleteOrder | + FeatureDeleteLimit } // QuoteIdentifier quotes identifiers with backticks. diff --git a/pkg/dialect/sqlite.go b/pkg/dialect/sqlite.go index 612440e..e99729c 100644 --- a/pkg/dialect/sqlite.go +++ b/pkg/dialect/sqlite.go @@ -25,7 +25,12 @@ func (d *SQLiteDialect) Features() Feature { FeatureOffset | FeatureUpsert | FeatureSavepoint | - FeatureNullsOrder + FeatureNullsOrder | + FeatureCTE | + FeatureUpdateOrder | + FeatureUpdateLimit | + FeatureDeleteOrder | + FeatureDeleteLimit } // QuoteIdentifier quotes identifiers with double quotes. diff --git a/pkg/rain/query_common.go b/pkg/rain/query_common.go index 9759773..d95ea1d 100644 --- a/pkg/rain/query_common.go +++ b/pkg/rain/query_common.go @@ -88,3 +88,72 @@ func closeRows(rows *sql.Rows, errp *error) { *errp = err } } + +func writeCTEs(ctx *compileContext, ctes []cteDefinition, label string) error { + if len(ctes) == 0 || ctx.skipCTEs { + return nil + } + if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureCTE) { + return fmt.Errorf("rain: %s queries do not support CTEs for %s dialect", label, ctx.dialect.Name()) + } + ctx.writeString("WITH ") + for idx, cte := range ctes { + if idx > 0 { + ctx.writeString(", ") + } + if strings.TrimSpace(cte.name) == "" { + return errors.New("rain: CTE name cannot be empty") + } + if cte.query == nil { + return fmt.Errorf("rain: CTE %q requires a query", cte.name) + } + if len(cte.query.ctes) > 0 { + return fmt.Errorf("rain: CTE %q body cannot itself contain CTEs", cte.name) + } + ctx.writeQuotedIdentifier(cte.name) + ctx.writeString(" AS (") + if err := cte.query.writeSQL(ctx); err != nil { + return err + } + ctx.writeByte(')') + } + ctx.writeByte(' ') + return nil +} + +func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit int, offset int, featureOrder, featureLimit dialect.Feature) error { + if len(order) > 0 { + if featureOrder != 0 && !dialect.HasFeature(ctx.dialect.Features(), featureOrder) { + return fmt.Errorf("rain: ORDER BY is not supported for this query type in %s dialect", ctx.dialect.Name()) + } + ctx.writeString(" ORDER BY ") + for idx, item := range order { + if idx > 0 { + ctx.writeString(", ") + } + if err := ctx.writeExpression(item.Expr); err != nil { + return err + } + ctx.writeByte(' ') + ctx.writeString(string(item.Direction)) + if item.NullsOrder != "" { + if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureNullsOrder) { + return fmt.Errorf("rain: NULLS FIRST/LAST is not supported by %s dialect", ctx.dialect.Name()) + } + ctx.writeByte(' ') + ctx.writeString(string(item.NullsOrder)) + } + } + } + + if limit > 0 || offset > 0 { + if featureLimit != 0 && !dialect.HasFeature(ctx.dialect.Features(), featureLimit) { + return fmt.Errorf("rain: LIMIT/OFFSET is not supported for this query type in %s dialect", ctx.dialect.Name()) + } + if clause := ctx.dialect.LimitOffset(limit, offset); clause != "" { + ctx.writeByte(' ') + ctx.writeString(clause) + } + } + return nil +} diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 67ddf10..90cdada 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -16,6 +16,9 @@ type DeleteQuery struct { dialect dialect.Dialect table *schema.TableDef where []schema.Predicate + order []schema.OrderExpr + limit int + ctes []cteDefinition returning []schema.Expression unbounded bool } @@ -38,6 +41,26 @@ func (q *DeleteQuery) Returning(exprs ...schema.Expression) *DeleteQuery { return q } +// With appends a common table expression definition. +func (q *DeleteQuery) With(name string, query *SelectQuery) *DeleteQuery { + q.ctes = append(q.ctes, cteDefinition{name: name, query: query}) + return q +} + +// OrderBy appends ORDER BY expressions. +// Supported by MySQL and SQLite. +func (q *DeleteQuery) OrderBy(order ...schema.OrderExpr) *DeleteQuery { + q.order = append(q.order, order...) + return q +} + +// Limit sets the LIMIT clause. +// Supported by MySQL and SQLite. +func (q *DeleteQuery) Limit(limit int) *DeleteQuery { + q.limit = limit + return q +} + // Unbounded allows DELETE without a WHERE clause. func (q *DeleteQuery) Unbounded() *DeleteQuery { q.unbounded = true @@ -58,6 +81,11 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { ctx := newCompileContext(q.dialect) defer releaseCompileContext(ctx) + + if err := writeCTEs(ctx, q.ctes, "delete"); err != nil { + return "", nil, err + } + ctx.writeString("DELETE FROM ") ctx.writeTableName(q.table) if len(q.where) > 0 { @@ -67,6 +95,10 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { } } + if err := writeOrderLimit(ctx, q.order, q.limit, 0, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { + return "", nil, err + } + if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { return "", nil, err } diff --git a/pkg/rain/query_delete_test.go b/pkg/rain/query_delete_test.go new file mode 100644 index 0000000..305dff6 --- /dev/null +++ b/pkg/rain/query_delete_test.go @@ -0,0 +1,117 @@ +package rain_test + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestDeleteOrderLimitToSQL(t *testing.T) { + users, _ := defineTables() + + tests := []struct { + name string + dialect string + setup func(q *rain.DeleteQuery) + wantSQL string + wantErr string + }{ + { + name: "sqlite order and limit", + dialect: "sqlite", + setup: func(q *rain.DeleteQuery) { + q.Where(users.Active.Eq(false)). + OrderBy(users.ID.Asc()). + Limit(10) + }, + wantSQL: `DELETE FROM "users" WHERE "users"."active" = ? ORDER BY "users"."id" ASC LIMIT 10`, + }, + { + name: "mysql order and limit", + dialect: "mysql", + setup: func(q *rain.DeleteQuery) { + q.Where(users.Active.Eq(false)). + OrderBy(users.ID.Asc()). + Limit(10) + }, + wantSQL: "DELETE FROM `users` WHERE `users`.`active` = ? ORDER BY `users`.`id` ASC LIMIT 10", + }, + { + name: "postgres order error", + dialect: "postgres", + setup: func(q *rain.DeleteQuery) { + q.Where(users.Active.Eq(false)). + OrderBy(users.ID.Asc()) + }, + wantErr: "rain: ORDER BY is not supported for this query type in postgres dialect", + }, + { + name: "postgres limit error", + dialect: "postgres", + setup: func(q *rain.DeleteQuery) { + q.Where(users.Active.Eq(false)). + Limit(10) + }, + wantErr: "rain: LIMIT/OFFSET is not supported for this query type in postgres dialect", + }, + { + name: "sqlite with cte", + dialect: "sqlite", + setup: func(q *rain.DeleteQuery) { + db, _ := rain.OpenDialect("sqlite") + sub := db.Select(). + Table(users). + Column(users.ID). + Where(users.Active.Eq(false)) + + q.With("inactive_users", sub). + Where(users.ID.InSubquery(schema.Raw(`SELECT id FROM inactive_users`))) + }, + wantSQL: `WITH "inactive_users" AS (SELECT "users"."id" FROM "users" WHERE "users"."active" = ?) DELETE FROM "users" WHERE "users"."id" IN (SELECT id FROM inactive_users)`, + }, + { + name: "mysql with cte", + dialect: "mysql", + setup: func(q *rain.DeleteQuery) { + db, _ := rain.OpenDialect("mysql") + sub := db.Select(). + Table(users). + Column(users.ID). + Where(users.Active.Eq(false)) + + q.With("inactive_users", sub). + Where(users.ID.InSubquery(schema.Raw(`SELECT id FROM inactive_users`))) + }, + wantSQL: "WITH `inactive_users` AS (SELECT `users`.`id` FROM `users` WHERE `users`.`active` = ?) DELETE FROM `users` WHERE `users`.`id` IN (SELECT id FROM inactive_users)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := rain.OpenDialect(tt.dialect) + if err != nil { + t.Fatal(err) + } + + q := db.Delete().Table(users) + tt.setup(q) + + gotSQL, _, err := q.ToSQL() + if tt.wantErr != "" { + if err == nil || err.Error() != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Errorf("ToSQL() unexpected error: %v", err) + return + } + + if gotSQL != tt.wantSQL { + t.Errorf("ToSQL() gotSQL = %q, want %q", gotSQL, tt.wantSQL) + } + }) + } +} diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 23ab5a7..2e4d5cb 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "strings" "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -385,32 +384,8 @@ func (q *SelectQuery) ToSQL() (string, []any, error) { } func (q *SelectQuery) writeSQL(ctx *compileContext) error { - if len(q.ctes) > 0 && !ctx.skipCTEs { - if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureCTE) { - return fmt.Errorf("rain: select queries do not support CTEs for %s dialect", ctx.dialect.Name()) - } - ctx.writeString("WITH ") - for idx, cte := range q.ctes { - if idx > 0 { - ctx.writeString(", ") - } - if strings.TrimSpace(cte.name) == "" { - return errors.New("rain: CTE name cannot be empty") - } - if cte.query == nil { - return fmt.Errorf("rain: CTE %q requires a query", cte.name) - } - if len(cte.query.ctes) > 0 { - return fmt.Errorf("rain: CTE %q body cannot itself contain CTEs", cte.name) - } - ctx.writeQuotedIdentifier(cte.name) - ctx.writeString(" AS (") - if err := cte.query.writeSQL(ctx); err != nil { - return err - } - ctx.writeByte(')') - } - ctx.writeByte(' ') + if err := writeCTEs(ctx, q.ctes, "select"); err != nil { + return err } if q.firstOperand != nil { @@ -428,7 +403,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { return err } } - if err := q.writeOrderLimit(ctx); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, 0, dialect.FeatureOffset); err != nil { return err } return q.writeLocking(ctx) @@ -504,7 +479,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { } } - if err := q.writeOrderLimit(ctx); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, 0, dialect.FeatureOffset); err != nil { return err } @@ -606,35 +581,6 @@ func (q *SelectQuery) writeJoins(ctx *compileContext) error { return nil } -func (q *SelectQuery) writeOrderLimit(ctx *compileContext) error { - if len(q.order) > 0 { - ctx.writeString(" ORDER BY ") - for idx, item := range q.order { - if idx > 0 { - ctx.writeString(", ") - } - if err := ctx.writeExpression(item.Expr); err != nil { - return err - } - ctx.writeByte(' ') - ctx.writeString(string(item.Direction)) - if item.NullsOrder != "" { - if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureNullsOrder) { - return fmt.Errorf("rain: NULLS FIRST/LAST is not supported by %s dialect", ctx.dialect.Name()) - } - ctx.writeByte(' ') - ctx.writeString(string(item.NullsOrder)) - } - } - } - - if clause := q.dialect.LimitOffset(q.limit, q.offset); clause != "" { - ctx.writeByte(' ') - ctx.writeString(clause) - } - return nil -} - // Scan executes the SELECT query and scans results into dest. func (q *SelectQuery) Scan(ctx context.Context, dest any) error { if q.runner == nil { diff --git a/pkg/rain/query_select_test.go b/pkg/rain/query_select_test.go index eec76a1..4b1b86e 100644 --- a/pkg/rain/query_select_test.go +++ b/pkg/rain/query_select_test.go @@ -842,13 +842,13 @@ func TestSelectAdvancedComposition(t *testing.T) { wantErr: "requires a non-nil query", }, { - name: "cte unsupported on mysql", + name: "cte supported on mysql", dialect: "mysql", build: func(db *rain.DB) *rain.SelectQuery { base := db.Select().Table(users) return db.Select().With("u", base).Table(users) }, - wantErr: "do not support CTEs", + wantSQL: "WITH `u` AS (SELECT * FROM `users`) SELECT * FROM `users`", }, { name: "nested cte body is invalid", diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index 358ad6f..c138177 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -17,6 +17,9 @@ type UpdateQuery struct { table *schema.TableDef values []assignment where []schema.Predicate + order []schema.OrderExpr + limit int + ctes []cteDefinition returning []schema.Expression unbounded bool } @@ -52,6 +55,26 @@ func (q *UpdateQuery) Returning(exprs ...schema.Expression) *UpdateQuery { return q } +// With appends a common table expression definition. +func (q *UpdateQuery) With(name string, query *SelectQuery) *UpdateQuery { + q.ctes = append(q.ctes, cteDefinition{name: name, query: query}) + return q +} + +// OrderBy appends ORDER BY expressions. +// Supported by MySQL and SQLite. +func (q *UpdateQuery) OrderBy(order ...schema.OrderExpr) *UpdateQuery { + q.order = append(q.order, order...) + return q +} + +// Limit sets the LIMIT clause. +// Supported by MySQL and SQLite. +func (q *UpdateQuery) Limit(limit int) *UpdateQuery { + q.limit = limit + return q +} + // Unbounded allows UPDATE without a WHERE clause. func (q *UpdateQuery) Unbounded() *UpdateQuery { q.unbounded = true @@ -75,6 +98,11 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { ctx := newCompileContext(q.dialect) defer releaseCompileContext(ctx) + + if err := writeCTEs(ctx, q.ctes, "update"); err != nil { + return "", nil, err + } + ctx.writeString("UPDATE ") ctx.writeTableName(q.table) ctx.writeString(" SET ") @@ -99,6 +127,10 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { } } + if err := writeOrderLimit(ctx, q.order, q.limit, 0, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { + return "", nil, err + } + if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { return "", nil, err } diff --git a/pkg/rain/query_update_test.go b/pkg/rain/query_update_test.go new file mode 100644 index 0000000..ea6bf5d --- /dev/null +++ b/pkg/rain/query_update_test.go @@ -0,0 +1,123 @@ +package rain_test + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestUpdateOrderLimitToSQL(t *testing.T) { + users, _ := defineTables() + + tests := []struct { + name string + dialect string + setup func(q *rain.UpdateQuery) + wantSQL string + wantErr string + }{ + { + name: "sqlite order and limit", + dialect: "sqlite", + setup: func(q *rain.UpdateQuery) { + q.Set(users.Name, "Alice"). + Where(users.Active.Eq(true)). + OrderBy(users.ID.Asc()). + Limit(10) + }, + wantSQL: `UPDATE "users" SET "name" = ? WHERE "users"."active" = ? ORDER BY "users"."id" ASC LIMIT 10`, + }, + { + name: "mysql order and limit", + dialect: "mysql", + setup: func(q *rain.UpdateQuery) { + q.Set(users.Name, "Alice"). + Where(users.Active.Eq(true)). + OrderBy(users.ID.Asc()). + Limit(10) + }, + wantSQL: "UPDATE `users` SET `name` = ? WHERE `users`.`active` = ? ORDER BY `users`.`id` ASC LIMIT 10", + }, + { + name: "postgres order error", + dialect: "postgres", + setup: func(q *rain.UpdateQuery) { + q.Set(users.Name, "Alice"). + Where(users.Active.Eq(true)). + OrderBy(users.ID.Asc()) + }, + wantErr: "rain: ORDER BY is not supported for this query type in postgres dialect", + }, + { + name: "postgres limit error", + dialect: "postgres", + setup: func(q *rain.UpdateQuery) { + q.Set(users.Name, "Alice"). + Where(users.Active.Eq(true)). + Limit(10) + }, + wantErr: "rain: LIMIT/OFFSET is not supported for this query type in postgres dialect", + }, + { + name: "sqlite with cte", + dialect: "sqlite", + setup: func(q *rain.UpdateQuery) { + db, _ := rain.OpenDialect("sqlite") + sub := db.Select(). + Table(users). + Column(users.ID). + Where(users.Active.Eq(false)) + + q.With("inactive_users", sub). + Set(users.Active, true). + Where(users.ID.InSubquery(schema.Raw(`SELECT id FROM inactive_users`))) + }, + wantSQL: `WITH "inactive_users" AS (SELECT "users"."id" FROM "users" WHERE "users"."active" = ?) UPDATE "users" SET "active" = ? WHERE "users"."id" IN (SELECT id FROM inactive_users)`, + }, + { + name: "mysql with cte", + dialect: "mysql", + setup: func(q *rain.UpdateQuery) { + db, _ := rain.OpenDialect("mysql") + sub := db.Select(). + Table(users). + Column(users.ID). + Where(users.Active.Eq(false)) + + q.With("inactive_users", sub). + Set(users.Active, true). + Where(users.ID.InSubquery(schema.Raw(`SELECT id FROM inactive_users`))) + }, + wantSQL: "WITH `inactive_users` AS (SELECT `users`.`id` FROM `users` WHERE `users`.`active` = ?) UPDATE `users` SET `active` = ? WHERE `users`.`id` IN (SELECT id FROM inactive_users)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := rain.OpenDialect(tt.dialect) + if err != nil { + t.Fatal(err) + } + + q := db.Update().Table(users) + tt.setup(q) + + gotSQL, _, err := q.ToSQL() + if tt.wantErr != "" { + if err == nil || err.Error() != tt.wantErr { + t.Errorf("ToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Errorf("ToSQL() unexpected error: %v", err) + return + } + + if gotSQL != tt.wantSQL { + t.Errorf("ToSQL() gotSQL = %q, want %q", gotSQL, tt.wantSQL) + } + }) + } +} diff --git a/pkg/rain/query_write_test.go b/pkg/rain/query_write_test.go index 6b4f28c..bb0c1c0 100644 --- a/pkg/rain/query_write_test.go +++ b/pkg/rain/query_write_test.go @@ -166,14 +166,21 @@ func TestDialectFeatures(t *testing.T) { dialect.FeatureSelectDistinctOn, }, { - name: "mysql", - dialect: "mysql", - features: dialect.FeatureOffset | dialect.FeatureUpsert | dialect.FeatureSavepoint | dialect.FeatureSelectLocking, + name: "mysql", + dialect: "mysql", + features: dialect.FeatureOffset | + dialect.FeatureUpsert | + dialect.FeatureSavepoint | + dialect.FeatureSelectLocking | + dialect.FeatureCTE | + dialect.FeatureUpdateOrder | + dialect.FeatureUpdateLimit | + dialect.FeatureDeleteOrder | + dialect.FeatureDeleteLimit, missing: []dialect.Feature{ dialect.FeatureInsertReturning, dialect.FeatureUpdateReturning, dialect.FeatureDeleteReturning, - dialect.FeatureCTE, dialect.FeatureDefaultPlaceholder, }, }, @@ -186,9 +193,13 @@ func TestDialectFeatures(t *testing.T) { dialect.FeatureOffset | dialect.FeatureUpsert | dialect.FeatureSavepoint | - dialect.FeatureNullsOrder, + dialect.FeatureNullsOrder | + dialect.FeatureCTE | + dialect.FeatureUpdateOrder | + dialect.FeatureUpdateLimit | + dialect.FeatureDeleteOrder | + dialect.FeatureDeleteLimit, missing: []dialect.Feature{ - dialect.FeatureCTE, dialect.FeatureDefaultPlaceholder, }, }, From f6c8355f6c44b8ef445df0d1c935d92988058f4e Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 31 May 2026 06:26:30 +0000 Subject: [PATCH 2/2] refactor(rain): improve ORDER BY/LIMIT safety and handle LIMIT 0 - Treat zero-value dialect features as "deny" in writeOrderLimit for safer defaults. - Introduce FeatureUnlimited for queries that always allow Order/Limit (SELECT). - Change limit and offset to *int to support explicit Limit(0). - Return error for negative limit/offset values. - Update PostgreSQL, MySQL, and SQLite dialects to enable FeatureUnlimited. - Ensure consistent LIMIT syntax across dialects for offset-only and zero-limit cases. Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/dialect/dialect_test.go | 18 +++++++++--------- pkg/dialect/feature.go | 1 + pkg/dialect/mysql.go | 5 +++-- pkg/dialect/postgres.go | 11 ++++++----- pkg/dialect/sqlite.go | 5 +++-- pkg/rain/query_common.go | 24 +++++++++++++++++++----- pkg/rain/query_delete.go | 6 +++--- pkg/rain/query_select.go | 24 ++++++++++++++++-------- pkg/rain/query_update.go | 6 +++--- pkg/rain/query_write_test.go | 9 ++++++--- 10 files changed, 69 insertions(+), 40 deletions(-) diff --git a/pkg/dialect/dialect_test.go b/pkg/dialect/dialect_test.go index 8e7997b..88613c3 100644 --- a/pkg/dialect/dialect_test.go +++ b/pkg/dialect/dialect_test.go @@ -134,7 +134,7 @@ func TestPostgresDialect(t *testing.T) { if got := d.Name(); got != "postgres" { t.Fatalf("unexpected name: %q", got) } - if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureCTE|FeatureDefaultPlaceholder|FeatureSavepoint|FeatureSelectLocking|FeatureNullsOrder|FeatureSelectDistinctOn { + if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureCTE|FeatureDefaultPlaceholder|FeatureSavepoint|FeatureSelectLocking|FeatureNullsOrder|FeatureSelectDistinctOn|FeatureUnlimited { t.Fatalf("unexpected features: %b", got) } if got := d.QuoteIdentifier(`user"name`); got != `"user""name"` { @@ -193,10 +193,10 @@ func TestPostgresDialect(t *testing.T) { if got := d.LimitOffset(10, 0); got != "LIMIT 10" { t.Fatalf("unexpected limit only clause: %q", got) } - if got := d.LimitOffset(0, 20); got != "OFFSET 20" { + if got := d.LimitOffset(0, 20); got != "LIMIT 0 OFFSET 20" { t.Fatalf("unexpected offset only clause: %q", got) } - if got := d.LimitOffset(0, 0); got != "" { + if got := d.LimitOffset(0, 0); got != "LIMIT 0" { t.Fatalf("unexpected empty limit clause: %q", got) } if got := d.UpsertClause("users", []string{"email"}, []string{"name"}); got != "ON CONFLICT DO UPDATE" { @@ -224,7 +224,7 @@ func TestMySQLDialect(t *testing.T) { if got := d.Name(); got != "mysql" { t.Fatalf("unexpected name: %q", got) } - if got := d.Features(); got != FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureSelectLocking|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit { + if got := d.Features(); got != FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureSelectLocking|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit|FeatureUnlimited { t.Fatalf("unexpected features: %b", got) } if got := d.QuoteIdentifier("user`name"); got != "`user``name`" { @@ -283,10 +283,10 @@ func TestMySQLDialect(t *testing.T) { if got := d.LimitOffset(10, 0); got != "LIMIT 10" { t.Fatalf("unexpected limit only clause: %q", got) } - if got := d.LimitOffset(0, 20); got != "LIMIT 18446744073709551615 OFFSET 20" { + if got := d.LimitOffset(0, 20); got != "LIMIT 20, 0" { t.Fatalf("unexpected offset only clause: %q", got) } - if got := d.LimitOffset(0, 0); got != "" { + if got := d.LimitOffset(0, 0); got != "LIMIT 0" { t.Fatalf("unexpected empty limit clause: %q", got) } if got := d.UpsertClause("users", []string{"email"}, []string{"name"}); got != "ON DUPLICATE KEY UPDATE" { @@ -314,7 +314,7 @@ func TestSQLiteDialect(t *testing.T) { if got := d.Name(); got != "sqlite" { t.Fatalf("unexpected name: %q", got) } - if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureNullsOrder|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit { + if got := d.Features(); got != FeatureInsertReturning|FeatureUpdateReturning|FeatureDeleteReturning|FeatureOffset|FeatureUpsert|FeatureSavepoint|FeatureNullsOrder|FeatureCTE|FeatureUpdateOrder|FeatureUpdateLimit|FeatureDeleteOrder|FeatureDeleteLimit|FeatureUnlimited { t.Fatalf("unexpected features: %b", got) } if got := d.QuoteIdentifier(`user"name`); got != `"user""name"` { @@ -368,10 +368,10 @@ func TestSQLiteDialect(t *testing.T) { if got := d.LimitOffset(10, 0); got != "LIMIT 10" { t.Fatalf("unexpected limit only clause: %q", got) } - if got := d.LimitOffset(0, 20); got != "LIMIT -1 OFFSET 20" { + if got := d.LimitOffset(0, 20); got != "LIMIT 0 OFFSET 20" { t.Fatalf("unexpected offset only clause: %q", got) } - if got := d.LimitOffset(0, 0); got != "" { + if got := d.LimitOffset(0, 0); got != "LIMIT 0" { t.Fatalf("unexpected empty limit clause: %q", got) } if got := d.UpsertClause("users", []string{"email"}, []string{"name"}); got != "ON CONFLICT DO UPDATE" { diff --git a/pkg/dialect/feature.go b/pkg/dialect/feature.go index b9dea8f..e9973b1 100644 --- a/pkg/dialect/feature.go +++ b/pkg/dialect/feature.go @@ -19,6 +19,7 @@ const ( FeatureUpdateLimit FeatureDeleteOrder FeatureDeleteLimit + FeatureUnlimited ) // HasFeature reports whether a feature set includes the requested capability. diff --git a/pkg/dialect/mysql.go b/pkg/dialect/mysql.go index 5ab4294..9cc1cc9 100644 --- a/pkg/dialect/mysql.go +++ b/pkg/dialect/mysql.go @@ -27,7 +27,8 @@ func (d *MySQLDialect) Features() Feature { FeatureUpdateOrder | FeatureUpdateLimit | FeatureDeleteOrder | - FeatureDeleteLimit + FeatureDeleteLimit | + FeatureUnlimited } // QuoteIdentifier quotes identifiers with backticks. @@ -101,7 +102,7 @@ func (d *MySQLDialect) AutoIncrementKeyword() string { // LimitOffset returns MySQL LIMIT/OFFSET syntax. func (d *MySQLDialect) LimitOffset(limit, offset int) string { - if limit > 0 { + if limit >= 0 { if offset > 0 { return "LIMIT " + strconv.Itoa(offset) + ", " + strconv.Itoa(limit) } diff --git a/pkg/dialect/postgres.go b/pkg/dialect/postgres.go index e84129a..4371b26 100644 --- a/pkg/dialect/postgres.go +++ b/pkg/dialect/postgres.go @@ -30,7 +30,8 @@ func (d *PostgresDialect) Features() Feature { FeatureSavepoint | FeatureSelectLocking | FeatureNullsOrder | - FeatureSelectDistinctOn + FeatureSelectDistinctOn | + FeatureUnlimited } // QuoteIdentifier quotes identifiers with double quotes. @@ -104,10 +105,10 @@ func (d *PostgresDialect) AutoIncrementKeyword() string { // LimitOffset returns PostgreSQL LIMIT/OFFSET syntax. func (d *PostgresDialect) LimitOffset(limit, offset int) string { - if limit > 0 && offset > 0 { - return "LIMIT " + strconv.Itoa(limit) + " OFFSET " + strconv.Itoa(offset) - } - if limit > 0 { + if limit >= 0 { + if offset > 0 { + return "LIMIT " + strconv.Itoa(limit) + " OFFSET " + strconv.Itoa(offset) + } return "LIMIT " + strconv.Itoa(limit) } if offset > 0 { diff --git a/pkg/dialect/sqlite.go b/pkg/dialect/sqlite.go index e99729c..7e452d4 100644 --- a/pkg/dialect/sqlite.go +++ b/pkg/dialect/sqlite.go @@ -30,7 +30,8 @@ func (d *SQLiteDialect) Features() Feature { FeatureUpdateOrder | FeatureUpdateLimit | FeatureDeleteOrder | - FeatureDeleteLimit + FeatureDeleteLimit | + FeatureUnlimited } // QuoteIdentifier quotes identifiers with double quotes. @@ -81,7 +82,7 @@ func (d *SQLiteDialect) AutoIncrementKeyword() string { // LimitOffset returns SQLite LIMIT/OFFSET syntax. func (d *SQLiteDialect) LimitOffset(limit, offset int) string { - if limit > 0 { + if limit >= 0 { if offset > 0 { return "LIMIT " + strconv.Itoa(limit) + " OFFSET " + strconv.Itoa(offset) } diff --git a/pkg/rain/query_common.go b/pkg/rain/query_common.go index d95ea1d..4dec51b 100644 --- a/pkg/rain/query_common.go +++ b/pkg/rain/query_common.go @@ -121,9 +121,9 @@ func writeCTEs(ctx *compileContext, ctes []cteDefinition, label string) error { return nil } -func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit int, offset int, featureOrder, featureLimit dialect.Feature) error { +func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit *int, offset *int, featureOrder, featureLimit dialect.Feature) error { if len(order) > 0 { - if featureOrder != 0 && !dialect.HasFeature(ctx.dialect.Features(), featureOrder) { + if featureOrder != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureOrder) { return fmt.Errorf("rain: ORDER BY is not supported for this query type in %s dialect", ctx.dialect.Name()) } ctx.writeString(" ORDER BY ") @@ -146,11 +146,25 @@ func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit int, o } } - if limit > 0 || offset > 0 { - if featureLimit != 0 && !dialect.HasFeature(ctx.dialect.Features(), featureLimit) { + if limit != nil || (offset != nil && *offset > 0) { + if featureLimit != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureLimit) { return fmt.Errorf("rain: LIMIT/OFFSET is not supported for this query type in %s dialect", ctx.dialect.Name()) } - if clause := ctx.dialect.LimitOffset(limit, offset); clause != "" { + l := -1 + if limit != nil { + l = *limit + if l < 0 { + return errors.New("rain: LIMIT must be non-negative") + } + } + o := 0 + if offset != nil { + o = *offset + if o < 0 { + return errors.New("rain: OFFSET must be non-negative") + } + } + if clause := ctx.dialect.LimitOffset(l, o); clause != "" { ctx.writeByte(' ') ctx.writeString(clause) } diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 90cdada..91ab74d 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -17,7 +17,7 @@ type DeleteQuery struct { table *schema.TableDef where []schema.Predicate order []schema.OrderExpr - limit int + limit *int ctes []cteDefinition returning []schema.Expression unbounded bool @@ -57,7 +57,7 @@ func (q *DeleteQuery) OrderBy(order ...schema.OrderExpr) *DeleteQuery { // Limit sets the LIMIT clause. // Supported by MySQL and SQLite. func (q *DeleteQuery) Limit(limit int) *DeleteQuery { - q.limit = limit + q.limit = &limit return q } @@ -95,7 +95,7 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { } } - if err := writeOrderLimit(ctx, q.order, q.limit, 0, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { return "", nil, err } diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 2e4d5cb..34be220 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -28,8 +28,8 @@ type SelectQuery struct { setOps []setOperation distinct bool distinctOn []schema.Expression - limit int - offset int + limit *int + offset *int relationNames []string cacheOptions *queryCacheOptions locking *selectLocking @@ -158,13 +158,13 @@ func (q *SelectQuery) OrderBy(order ...schema.OrderExpr) *SelectQuery { // Limit sets the LIMIT clause. func (q *SelectQuery) Limit(limit int) *SelectQuery { - q.limit = limit + q.limit = &limit return q } // Offset sets the OFFSET clause. func (q *SelectQuery) Offset(offset int) *SelectQuery { - q.offset = offset + q.offset = &offset return q } @@ -285,6 +285,14 @@ func (q *SelectQuery) clone() *SelectQuery { newQ.setOps = append([]setOperation(nil), q.setOps...) newQ.distinctOn = append([]schema.Expression(nil), q.distinctOn...) newQ.relationNames = append([]string(nil), q.relationNames...) + if q.limit != nil { + l := *q.limit + newQ.limit = &l + } + if q.offset != nil { + o := *q.offset + newQ.offset = &o + } if q.locking != nil { copyLocking := *q.locking copyLocking.of = append([]schema.TableReference(nil), q.locking.of...) @@ -336,7 +344,7 @@ func (q *SelectQuery) withSQLiteInsertSelectConflictWhereChanged() (*SelectQuery func (q *SelectQuery) isBareCompound() bool { return q.firstOperand != nil && - len(q.order) == 0 && q.limit == 0 && q.offset == 0 && + len(q.order) == 0 && q.limit == nil && q.offset == nil && !q.distinct && len(q.distinctOn) == 0 && len(q.cols) == 0 && q.table == nil && len(q.where) == 0 && len(q.joins) == 0 && len(q.groupBy) == 0 && len(q.having) == 0 && @@ -403,7 +411,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { return err } } - if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, 0, dialect.FeatureOffset); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { return err } return q.writeLocking(ctx) @@ -479,7 +487,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { } } - if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, 0, dialect.FeatureOffset); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { return err } @@ -539,7 +547,7 @@ func (q *SelectQuery) writeCompoundOperandSQL(ctx *compileContext) error { } // Use parentheses if the operand has its own ORDER BY, LIMIT, locking, or is itself a compound query. // Flattening is handled during builder chaining in wrapSetOp. - useParens := len(q.order) > 0 || q.limit > 0 || q.offset > 0 || q.locking != nil || q.firstOperand != nil + useParens := len(q.order) > 0 || q.limit != nil || q.offset != nil || q.locking != nil || q.firstOperand != nil if useParens { ctx.writeByte('(') } diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index c138177..d24ccb2 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -18,7 +18,7 @@ type UpdateQuery struct { values []assignment where []schema.Predicate order []schema.OrderExpr - limit int + limit *int ctes []cteDefinition returning []schema.Expression unbounded bool @@ -71,7 +71,7 @@ func (q *UpdateQuery) OrderBy(order ...schema.OrderExpr) *UpdateQuery { // Limit sets the LIMIT clause. // Supported by MySQL and SQLite. func (q *UpdateQuery) Limit(limit int) *UpdateQuery { - q.limit = limit + q.limit = &limit return q } @@ -127,7 +127,7 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { } } - if err := writeOrderLimit(ctx, q.order, q.limit, 0, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { return "", nil, err } diff --git a/pkg/rain/query_write_test.go b/pkg/rain/query_write_test.go index bb0c1c0..99ce552 100644 --- a/pkg/rain/query_write_test.go +++ b/pkg/rain/query_write_test.go @@ -163,7 +163,8 @@ func TestDialectFeatures(t *testing.T) { dialect.FeatureSavepoint | dialect.FeatureSelectLocking | dialect.FeatureNullsOrder | - dialect.FeatureSelectDistinctOn, + dialect.FeatureSelectDistinctOn | + dialect.FeatureUnlimited, }, { name: "mysql", @@ -176,7 +177,8 @@ func TestDialectFeatures(t *testing.T) { dialect.FeatureUpdateOrder | dialect.FeatureUpdateLimit | dialect.FeatureDeleteOrder | - dialect.FeatureDeleteLimit, + dialect.FeatureDeleteLimit | + dialect.FeatureUnlimited, missing: []dialect.Feature{ dialect.FeatureInsertReturning, dialect.FeatureUpdateReturning, @@ -198,7 +200,8 @@ func TestDialectFeatures(t *testing.T) { dialect.FeatureUpdateOrder | dialect.FeatureUpdateLimit | dialect.FeatureDeleteOrder | - dialect.FeatureDeleteLimit, + dialect.FeatureDeleteLimit | + dialect.FeatureUnlimited, missing: []dialect.Feature{ dialect.FeatureDefaultPlaceholder, },