Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions pkg/rain/prepared_exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package rain

import (
"context"
"database/sql"
"sync"

"github.com/hyperlocalise/rain-orm/pkg/schema"
)

// PreparedInsertQuery is a prepared INSERT query with reusable named argument binding.
type PreparedInsertQuery struct {
table *schema.TableDef
compiled compiledQuery
stmt *sql.Stmt
closeOnce sync.Once
closeErr error
}

// Exec executes the prepared INSERT query.
func (p *PreparedInsertQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) {
bound, err := p.compiled.bind(args)
if err != nil {
return nil, err
}

return p.stmt.ExecContext(ctx, bound...)
}

// Scan executes the prepared INSERT ... RETURNING query and scans results into dest.
func (p *PreparedInsertQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) {
bound, err := p.compiled.bind(args)
if err != nil {
return err
}

rows, err := p.stmt.QueryContext(ctx, bound...)
if err != nil {
return err
}
defer closeRows(rows, &err)

return scanRowsAgainstTable(rows, dest, p.table)
}
Comment on lines +31 to +44
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing RETURNING guard in prepared Scan methods

The non-prepared InsertQuery.Scan(), UpdateQuery.Scan(), and DeleteQuery.Scan() all guard against being called without a RETURNING clause by checking len(q.returning) == 0 and returning a clear "rain: insert scan requires RETURNING" error. The three prepared Scan methods have no such check. If a user prepares an INSERT/UPDATE/DELETE without .Returning(...) and then calls Scan(), the statement executes but the database returns no columns; scanRowsAgainstTable then fails with a confusing low-level error (e.g., a scan plan mismatch) rather than the clear sentinel the caller would expect.

Fix: add a hasReturning bool field to each prepared type, set it to len(q.returning) > 0 in the respective Prepare() method, and check it at the top of each Scan(). The same gap exists on lines 76–89 (PreparedUpdateQuery.Scan) and lines 121–134 (PreparedDeleteQuery.Scan).

Prompt To Fix With AI
This is a comment left during a code review.
Path: pkg/rain/prepared_exec.go
Line: 31-44

Comment:
**Missing RETURNING guard in prepared Scan methods**

The non-prepared `InsertQuery.Scan()`, `UpdateQuery.Scan()`, and `DeleteQuery.Scan()` all guard against being called without a `RETURNING` clause by checking `len(q.returning) == 0` and returning a clear `"rain: insert scan requires RETURNING"` error. The three prepared `Scan` methods have no such check. If a user prepares an `INSERT`/`UPDATE`/`DELETE` without `.Returning(...)` and then calls `Scan()`, the statement executes but the database returns no columns; `scanRowsAgainstTable` then fails with a confusing low-level error (e.g., a scan plan mismatch) rather than the clear sentinel the caller would expect.

Fix: add a `hasReturning bool` field to each prepared type, set it to `len(q.returning) > 0` in the respective `Prepare()` method, and check it at the top of each `Scan()`. The same gap exists on lines 76–89 (`PreparedUpdateQuery.Scan`) and lines 121–134 (`PreparedDeleteQuery.Scan`).

How can I resolve this? If you propose a fix, please make it concise.

Fix in Codex


// Close closes the prepared statement.
func (p *PreparedInsertQuery) Close() error {
p.closeOnce.Do(func() {
if p.stmt != nil {
p.closeErr = p.stmt.Close()
}
})
return p.closeErr
}

// PreparedUpdateQuery is a prepared UPDATE query with reusable named argument binding.
type PreparedUpdateQuery struct {
table *schema.TableDef
compiled compiledQuery
stmt *sql.Stmt
closeOnce sync.Once
closeErr error
}

// Exec executes the prepared UPDATE query.
func (p *PreparedUpdateQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) {
bound, err := p.compiled.bind(args)
if err != nil {
return nil, err
}

return p.stmt.ExecContext(ctx, bound...)
}

// Scan executes the prepared UPDATE ... RETURNING query and scans results into dest.
func (p *PreparedUpdateQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) {
bound, err := p.compiled.bind(args)
if err != nil {
return err
}

rows, err := p.stmt.QueryContext(ctx, bound...)
if err != nil {
return err
}
defer closeRows(rows, &err)

return scanRowsAgainstTable(rows, dest, p.table)
}

// Close closes the prepared statement.
func (p *PreparedUpdateQuery) Close() error {
p.closeOnce.Do(func() {
if p.stmt != nil {
p.closeErr = p.stmt.Close()
}
})
return p.closeErr
}

// PreparedDeleteQuery is a prepared DELETE query with reusable named argument binding.
type PreparedDeleteQuery struct {
table *schema.TableDef
compiled compiledQuery
stmt *sql.Stmt
closeOnce sync.Once
closeErr error
}

// Exec executes the prepared DELETE query.
func (p *PreparedDeleteQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) {
bound, err := p.compiled.bind(args)
if err != nil {
return nil, err
}

return p.stmt.ExecContext(ctx, bound...)
}

// Scan executes the prepared DELETE ... RETURNING query and scans results into dest.
func (p *PreparedDeleteQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) {
bound, err := p.compiled.bind(args)
if err != nil {
return err
}

rows, err := p.stmt.QueryContext(ctx, bound...)
if err != nil {
return err
}
defer closeRows(rows, &err)

return scanRowsAgainstTable(rows, dest, p.table)
}

// Close closes the prepared statement.
func (p *PreparedDeleteQuery) Close() error {
p.closeOnce.Do(func() {
if p.stmt != nil {
p.closeErr = p.stmt.Close()
}
})
return p.closeErr
}
120 changes: 120 additions & 0 deletions pkg/rain/prepared_exec_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package rain

import (
"testing"

"github.com/hyperlocalise/rain-orm/pkg/schema"
)

func TestPreparedInsertCompile(t *testing.T) {
t.Parallel()

db, _ := OpenDialect("postgres")
users := schema.Define("users", func(t *struct {
schema.TableModel
ID *schema.Column[int64]
Email *schema.Column[string]
Name *schema.Column[string]
},
) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Email = t.VarChar("email", 255)
t.Name = t.Text("name")
})

q := db.Insert().
Table(users).
Set(users.Email, schema.Placeholder("email")).
Set(users.Name, schema.Placeholder("name"))

compiled, err := q.compile()
if err != nil {
t.Fatalf("compile failed: %v", err)
}

wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2)`
if compiled.sql != wantSQL {
t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql)
}
if !compiled.hasNames {
t.Errorf("expected compiled query to have names")
}
}

func TestPreparedUpdateCompile(t *testing.T) {
t.Parallel()

db, _ := OpenDialect("postgres")
users := schema.Define("users", func(t *struct {
schema.TableModel
ID *schema.Column[int64]
Name *schema.Column[string]
},
) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Name = t.Text("name")
})

q := db.Update().
Table(users).
Set(users.Name, schema.Placeholder("new_name")).
Where(users.ID.EqExpr(schema.Placeholder("id")))

compiled, err := q.compile()
if err != nil {
t.Fatalf("compile failed: %v", err)
}

wantSQL := `UPDATE "users" SET "name" = $1 WHERE "users"."id" = $2`
if compiled.sql != wantSQL {
t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql)
}
}

func TestPreparedDeleteCompile(t *testing.T) {
t.Parallel()

db, _ := OpenDialect("postgres")
users := schema.Define("users", func(t *struct {
schema.TableModel
ID *schema.Column[int64]
},
) {
t.ID = t.BigSerial("id").PrimaryKey()
})

q := db.Delete().
Table(users).
Where(users.ID.EqExpr(schema.Placeholder("id")))

compiled, err := q.compile()
if err != nil {
t.Fatalf("compile failed: %v", err)
}

wantSQL := `DELETE FROM "users" WHERE "users"."id" = $1`
if compiled.sql != wantSQL {
t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql)
}
}

func TestPreparedInsertScanInternal(t *testing.T) {
// This test ensures that PreparedInsertQuery has access to the correct table metadata.
users := schema.Define("users", func(t *struct {
schema.TableModel
ID *schema.Column[int64]
Email *schema.Column[string]
},
) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Email = t.VarChar("email", 255)
})

prepared := &PreparedInsertQuery{
table: users.TableDef(),
}

if prepared.table.Name != "users" {
t.Errorf("expected table name users, got %s", prepared.table.Name)
}
}
49 changes: 49 additions & 0 deletions pkg/rain/prepared_exec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package rain_test

import (
"testing"

"github.com/hyperlocalise/rain-orm/pkg/rain"
"github.com/hyperlocalise/rain-orm/pkg/schema"
)

func TestPreparedToSQLReturnsErrorWithPlaceholders(t *testing.T) {
t.Parallel()

db, err := rain.OpenDialect("postgres")
if err != nil {
t.Fatalf("OpenDialect returned error: %v", err)
}
users, _ := defineTables()

t.Run("insert", func(t *testing.T) {
_, _, err := db.Insert().
Table(users).
Set(users.Email, schema.Placeholder("email")).
ToSQL()
if err != rain.ErrPreparedArgsRequired {
t.Errorf("expected ErrPreparedArgsRequired, got %v", err)
}
})

t.Run("update", func(t *testing.T) {
_, _, err := db.Update().
Table(users).
Set(users.Name, schema.Placeholder("name")).
Where(users.ID.Eq(int64(1))).
ToSQL()
if err != rain.ErrPreparedArgsRequired {
t.Errorf("expected ErrPreparedArgsRequired, got %v", err)
}
})

t.Run("delete", func(t *testing.T) {
_, _, err := db.Delete().
Table(users).
Where(users.ID.EqExpr(schema.Placeholder("id"))).
ToSQL()
if err != rain.ErrPreparedArgsRequired {
t.Errorf("expected ErrPreparedArgsRequired, got %v", err)
}
})
}
Loading