diff --git a/pkg/rain/prepared_exec.go b/pkg/rain/prepared_exec.go new file mode 100644 index 0000000..ff35f02 --- /dev/null +++ b/pkg/rain/prepared_exec.go @@ -0,0 +1,144 @@ +package rain + +import ( + "context" + "database/sql" + "sync" + + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +// PreparedInsertQuery is a prepared INSERT query with reusable named argument binding. +type PreparedInsertQuery struct { + table *schema.TableDef + compiled compiledQuery + stmt *sql.Stmt + closeOnce sync.Once + closeErr error +} + +// Exec executes the prepared INSERT query. +func (p *PreparedInsertQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) { + bound, err := p.compiled.bind(args) + if err != nil { + return nil, err + } + + return p.stmt.ExecContext(ctx, bound...) +} + +// Scan executes the prepared INSERT ... RETURNING query and scans results into dest. +func (p *PreparedInsertQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) { + bound, err := p.compiled.bind(args) + if err != nil { + return err + } + + rows, err := p.stmt.QueryContext(ctx, bound...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + return scanRowsAgainstTable(rows, dest, p.table) +} + +// Close closes the prepared statement. +func (p *PreparedInsertQuery) Close() error { + p.closeOnce.Do(func() { + if p.stmt != nil { + p.closeErr = p.stmt.Close() + } + }) + return p.closeErr +} + +// PreparedUpdateQuery is a prepared UPDATE query with reusable named argument binding. +type PreparedUpdateQuery struct { + table *schema.TableDef + compiled compiledQuery + stmt *sql.Stmt + closeOnce sync.Once + closeErr error +} + +// Exec executes the prepared UPDATE query. +func (p *PreparedUpdateQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) { + bound, err := p.compiled.bind(args) + if err != nil { + return nil, err + } + + return p.stmt.ExecContext(ctx, bound...) +} + +// Scan executes the prepared UPDATE ... RETURNING query and scans results into dest. +func (p *PreparedUpdateQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) { + bound, err := p.compiled.bind(args) + if err != nil { + return err + } + + rows, err := p.stmt.QueryContext(ctx, bound...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + return scanRowsAgainstTable(rows, dest, p.table) +} + +// Close closes the prepared statement. +func (p *PreparedUpdateQuery) Close() error { + p.closeOnce.Do(func() { + if p.stmt != nil { + p.closeErr = p.stmt.Close() + } + }) + return p.closeErr +} + +// PreparedDeleteQuery is a prepared DELETE query with reusable named argument binding. +type PreparedDeleteQuery struct { + table *schema.TableDef + compiled compiledQuery + stmt *sql.Stmt + closeOnce sync.Once + closeErr error +} + +// Exec executes the prepared DELETE query. +func (p *PreparedDeleteQuery) Exec(ctx context.Context, args PreparedArgs) (sql.Result, error) { + bound, err := p.compiled.bind(args) + if err != nil { + return nil, err + } + + return p.stmt.ExecContext(ctx, bound...) +} + +// Scan executes the prepared DELETE ... RETURNING query and scans results into dest. +func (p *PreparedDeleteQuery) Scan(ctx context.Context, args PreparedArgs, dest any) (err error) { + bound, err := p.compiled.bind(args) + if err != nil { + return err + } + + rows, err := p.stmt.QueryContext(ctx, bound...) + if err != nil { + return err + } + defer closeRows(rows, &err) + + return scanRowsAgainstTable(rows, dest, p.table) +} + +// Close closes the prepared statement. +func (p *PreparedDeleteQuery) Close() error { + p.closeOnce.Do(func() { + if p.stmt != nil { + p.closeErr = p.stmt.Close() + } + }) + return p.closeErr +} diff --git a/pkg/rain/prepared_exec_internal_test.go b/pkg/rain/prepared_exec_internal_test.go new file mode 100644 index 0000000..c56eebb --- /dev/null +++ b/pkg/rain/prepared_exec_internal_test.go @@ -0,0 +1,120 @@ +package rain + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestPreparedInsertCompile(t *testing.T) { + t.Parallel() + + db, _ := OpenDialect("postgres") + users := schema.Define("users", func(t *struct { + schema.TableModel + ID *schema.Column[int64] + Email *schema.Column[string] + Name *schema.Column[string] + }, + ) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Email = t.VarChar("email", 255) + t.Name = t.Text("name") + }) + + q := db.Insert(). + Table(users). + Set(users.Email, schema.Placeholder("email")). + Set(users.Name, schema.Placeholder("name")) + + compiled, err := q.compile() + if err != nil { + t.Fatalf("compile failed: %v", err) + } + + wantSQL := `INSERT INTO "users" ("email", "name") VALUES ($1, $2)` + if compiled.sql != wantSQL { + t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql) + } + if !compiled.hasNames { + t.Errorf("expected compiled query to have names") + } +} + +func TestPreparedUpdateCompile(t *testing.T) { + t.Parallel() + + db, _ := OpenDialect("postgres") + users := schema.Define("users", func(t *struct { + schema.TableModel + ID *schema.Column[int64] + Name *schema.Column[string] + }, + ) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Name = t.Text("name") + }) + + q := db.Update(). + Table(users). + Set(users.Name, schema.Placeholder("new_name")). + Where(users.ID.EqExpr(schema.Placeholder("id"))) + + compiled, err := q.compile() + if err != nil { + t.Fatalf("compile failed: %v", err) + } + + wantSQL := `UPDATE "users" SET "name" = $1 WHERE "users"."id" = $2` + if compiled.sql != wantSQL { + t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql) + } +} + +func TestPreparedDeleteCompile(t *testing.T) { + t.Parallel() + + db, _ := OpenDialect("postgres") + users := schema.Define("users", func(t *struct { + schema.TableModel + ID *schema.Column[int64] + }, + ) { + t.ID = t.BigSerial("id").PrimaryKey() + }) + + q := db.Delete(). + Table(users). + Where(users.ID.EqExpr(schema.Placeholder("id"))) + + compiled, err := q.compile() + if err != nil { + t.Fatalf("compile failed: %v", err) + } + + wantSQL := `DELETE FROM "users" WHERE "users"."id" = $1` + if compiled.sql != wantSQL { + t.Errorf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, compiled.sql) + } +} + +func TestPreparedInsertScanInternal(t *testing.T) { + // This test ensures that PreparedInsertQuery has access to the correct table metadata. + users := schema.Define("users", func(t *struct { + schema.TableModel + ID *schema.Column[int64] + Email *schema.Column[string] + }, + ) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Email = t.VarChar("email", 255) + }) + + prepared := &PreparedInsertQuery{ + table: users.TableDef(), + } + + if prepared.table.Name != "users" { + t.Errorf("expected table name users, got %s", prepared.table.Name) + } +} diff --git a/pkg/rain/prepared_exec_test.go b/pkg/rain/prepared_exec_test.go new file mode 100644 index 0000000..481d9c6 --- /dev/null +++ b/pkg/rain/prepared_exec_test.go @@ -0,0 +1,49 @@ +package rain_test + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/rain" + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestPreparedToSQLReturnsErrorWithPlaceholders(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + t.Run("insert", func(t *testing.T) { + _, _, err := db.Insert(). + Table(users). + Set(users.Email, schema.Placeholder("email")). + ToSQL() + if err != rain.ErrPreparedArgsRequired { + t.Errorf("expected ErrPreparedArgsRequired, got %v", err) + } + }) + + t.Run("update", func(t *testing.T) { + _, _, err := db.Update(). + Table(users). + Set(users.Name, schema.Placeholder("name")). + Where(users.ID.Eq(int64(1))). + ToSQL() + if err != rain.ErrPreparedArgsRequired { + t.Errorf("expected ErrPreparedArgsRequired, got %v", err) + } + }) + + t.Run("delete", func(t *testing.T) { + _, _, err := db.Delete(). + Table(users). + Where(users.ID.EqExpr(schema.Placeholder("id"))). + ToSQL() + if err != rain.ErrPreparedArgsRequired { + t.Errorf("expected ErrPreparedArgsRequired, got %v", err) + } + }) +} diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 91ab74d..3754de3 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -67,23 +67,71 @@ func (q *DeleteQuery) Unbounded() *DeleteQuery { return q } +// Prepare compiles and prepares the DELETE query. +func (q *DeleteQuery) Prepare(ctx context.Context) (*PreparedDeleteQuery, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + runner, ok := q.runner.(preparingQueryRunner) + if !ok { + return nil, ErrPrepareNotSupported + } + + compiled, err := q.compile() + if err != nil { + return nil, err + } + + stmt, err := runner.prepareContext(ctx, compiled.sql) + if err != nil { + return nil, err + } + + return &PreparedDeleteQuery{ + table: q.table, + compiled: compiled, + stmt: stmt, + }, nil +} + // ToSQL compiles the delete into SQL and args. func (q *DeleteQuery) ToSQL() (string, []any, error) { + compiled, err := q.compile() + if err != nil { + return "", nil, err + } + args, err := compiled.literalArgs() + if err != nil { + return "", nil, err + } + return compiled.sql, args, nil +} + +func (q *DeleteQuery) compile() (compiledQuery, error) { if q.table == nil { - return "", nil, errors.New("rain: delete query requires a table") + return compiledQuery{}, errors.New("rain: delete query requires a table") } if q.table.IsView { - return "", nil, fmt.Errorf("rain: cannot delete from view %q", q.table.Name) + return compiledQuery{}, fmt.Errorf("rain: cannot delete from view %q", q.table.Name) } if len(q.where) == 0 && !q.unbounded { - return "", nil, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows") + return compiledQuery{}, errors.New("rain: delete query requires at least one WHERE predicate; call Unbounded() to allow all rows") } ctx := newCompileContext(q.dialect) defer releaseCompileContext(ctx) + if err := q.writeSQL(ctx); err != nil { + return compiledQuery{}, err + } + + return ctx.compiledQuery(), ctx.err +} + +func (q *DeleteQuery) writeSQL(ctx *compileContext) error { if err := writeCTEs(ctx, q.ctes, "delete"); err != nil { - return "", nil, err + return err } ctx.writeString("DELETE FROM ") @@ -91,24 +139,15 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { if len(q.where) > 0 { ctx.writeString(" WHERE ") if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return "", nil, err + return err } } if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { - return "", nil, err - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err + return err } - compiled := ctx.compiledQuery() - args, err := compiled.literalArgs() - if err != nil { - return "", nil, err - } - return compiled.sql, args, ctx.err + return ctx.writeReturning(q.returning, q.returningClause()) } func (q *DeleteQuery) returningClause() returningClause { diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index ee0facd..d3ff0e0 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -122,19 +122,70 @@ func (q *InsertQuery) Returning(exprs ...schema.Expression) *InsertQuery { return q } +// Prepare compiles and prepares the INSERT query. +func (q *InsertQuery) Prepare(ctx context.Context) (*PreparedInsertQuery, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + runner, ok := q.runner.(preparingQueryRunner) + if !ok { + return nil, ErrPrepareNotSupported + } + + compiled, err := q.compile() + if err != nil { + return nil, err + } + + stmt, err := runner.prepareContext(ctx, compiled.sql) + if err != nil { + return nil, err + } + + return &PreparedInsertQuery{ + table: q.table, + compiled: compiled, + stmt: stmt, + }, nil +} + // ToSQL compiles the insert into SQL and args. func (q *InsertQuery) ToSQL() (string, []any, error) { - if q.selectQuery != nil { - return q.toSelectSQL() + compiled, err := q.compile() + if err != nil { + return "", nil, err } - - rows, err := q.insertAssignments() + args, err := compiled.literalArgs() if err != nil { return "", nil, err } + return compiled.sql, args, nil +} +func (q *InsertQuery) compile() (compiledQuery, error) { ctx := newCompileContext(q.dialect) defer releaseCompileContext(ctx) + + if q.selectQuery != nil { + if err := q.writeSelectSQL(ctx); err != nil { + return compiledQuery{}, err + } + } else { + if err := q.writeValuesSQL(ctx); err != nil { + return compiledQuery{}, err + } + } + + return ctx.compiledQuery(), ctx.err +} + +func (q *InsertQuery) writeValuesSQL(ctx *compileContext) error { + rows, err := q.insertAssignments() + if err != nil { + return err + } + ctx.writeString("INSERT INTO ") ctx.writeTableName(q.table) ctx.writeString(" (") @@ -155,34 +206,25 @@ func (q *InsertQuery) ToSQL() (string, []any, error) { ctx.writeString(", ") } if err := ctx.writeExpression(item.value); err != nil { - return "", nil, err + return err } } ctx.writeByte(')') } if err := q.writeConflictClause(ctx); err != nil { - return "", nil, err - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err + return err } - compiled := ctx.compiledQuery() - args, err := compiled.literalArgs() - if err != nil { - return "", nil, err - } - return compiled.sql, args, ctx.err + return ctx.writeReturning(q.returning, q.returningClause()) } -func (q *InsertQuery) toSelectSQL() (string, []any, error) { +func (q *InsertQuery) writeSelectSQL(ctx *compileContext) error { if err := q.validateSources(); err != nil { - return "", nil, err + return err } if err := q.validateInsertSelectColumns(); err != nil { - return "", nil, err + return err } selectQuery := q.selectQuery @@ -190,8 +232,6 @@ func (q *InsertQuery) toSelectSQL() (string, []any, error) { selectQuery = selectQuery.withSQLiteInsertSelectConflictWhere() } - ctx := newCompileContext(q.dialect) - defer releaseCompileContext(ctx) ctx.writeString("INSERT INTO ") ctx.writeTableName(q.table) @@ -208,23 +248,14 @@ func (q *InsertQuery) toSelectSQL() (string, []any, error) { ctx.writeByte(' ') if err := selectQuery.writeSQL(ctx); err != nil { - return "", nil, err + return err } if err := q.writeConflictClause(ctx); err != nil { - return "", nil, err - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err + return err } - compiled := ctx.compiledQuery() - args, err := compiled.literalArgs() - if err != nil { - return "", nil, err - } - return compiled.sql, args, ctx.err + return ctx.writeReturning(q.returning, q.returningClause()) } func (q *InsertQuery) validateInsertSelectColumns() error { diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index d24ccb2..aea64f8 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -81,26 +81,74 @@ func (q *UpdateQuery) Unbounded() *UpdateQuery { return q } +// Prepare compiles and prepares the UPDATE query. +func (q *UpdateQuery) Prepare(ctx context.Context) (*PreparedUpdateQuery, error) { + if q.runner == nil { + return nil, ErrNoConnection + } + + runner, ok := q.runner.(preparingQueryRunner) + if !ok { + return nil, ErrPrepareNotSupported + } + + compiled, err := q.compile() + if err != nil { + return nil, err + } + + stmt, err := runner.prepareContext(ctx, compiled.sql) + if err != nil { + return nil, err + } + + return &PreparedUpdateQuery{ + table: q.table, + compiled: compiled, + stmt: stmt, + }, nil +} + // ToSQL compiles the update into SQL and args. func (q *UpdateQuery) ToSQL() (string, []any, error) { + compiled, err := q.compile() + if err != nil { + return "", nil, err + } + args, err := compiled.literalArgs() + if err != nil { + return "", nil, err + } + return compiled.sql, args, nil +} + +func (q *UpdateQuery) compile() (compiledQuery, error) { if q.table == nil { - return "", nil, errors.New("rain: update query requires a table") + return compiledQuery{}, errors.New("rain: update query requires a table") } if q.table.IsView { - return "", nil, fmt.Errorf("rain: cannot update view %q", q.table.Name) + return compiledQuery{}, fmt.Errorf("rain: cannot update view %q", q.table.Name) } if len(q.values) == 0 { - return "", nil, errors.New("rain: update query requires at least one assignment") + return compiledQuery{}, errors.New("rain: update query requires at least one assignment") } if len(q.where) == 0 && !q.unbounded { - return "", nil, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows") + return compiledQuery{}, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows") } ctx := newCompileContext(q.dialect) defer releaseCompileContext(ctx) + if err := q.writeSQL(ctx); err != nil { + return compiledQuery{}, err + } + + return ctx.compiledQuery(), ctx.err +} + +func (q *UpdateQuery) writeSQL(ctx *compileContext) error { if err := writeCTEs(ctx, q.ctes, "update"); err != nil { - return "", nil, err + return err } ctx.writeString("UPDATE ") @@ -108,7 +156,7 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { ctx.writeString(" SET ") for idx, item := range q.values { if err := validateAssignmentTarget(q.table, item); err != nil { - return "", nil, err + return err } if idx > 0 { ctx.writeString(", ") @@ -116,31 +164,22 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { ctx.writeQuotedIdentifier(item.column.ColumnDef().Name) ctx.writeString(" = ") if err := ctx.writeExpression(item.value); err != nil { - return "", nil, err + return err } } if len(q.where) > 0 { ctx.writeString(" WHERE ") if err := ctx.writePredicate(joinPredicates(q.where)); err != nil { - return "", nil, err + return err } } if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { - return "", nil, err - } - - if err := ctx.writeReturning(q.returning, q.returningClause()); err != nil { - return "", nil, err + return err } - compiled := ctx.compiledQuery() - args, err := compiled.literalArgs() - if err != nil { - return "", nil, err - } - return compiled.sql, args, ctx.err + return ctx.writeReturning(q.returning, q.returningClause()) } func (q *UpdateQuery) returningClause() returningClause { diff --git a/pkg/rain/sqlite_integration_test.go b/pkg/rain/sqlite_integration_test.go index a981cd7..33f7261 100644 --- a/pkg/rain/sqlite_integration_test.go +++ b/pkg/rain/sqlite_integration_test.go @@ -818,6 +818,106 @@ func TestSQLiteIntegrationRichAdvancedSelectsAndPreparedQueries(t *testing.T) { } } +func TestSQLiteIntegrationPreparedExecQueries(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openSQLiteTestDB(t) + users, _, _ := defineSQLiteTables() + createSQLiteSchema(t, ctx, db) + + // Prepare INSERT + prepInsert, err := db.Insert(). + Table(users). + Set(users.Email, schema.Placeholder("email")). + Set(users.Name, schema.Placeholder("name")). + Prepare(ctx) + if err != nil { + t.Fatalf("prepare insert failed: %v", err) + } + defer func() { _ = prepInsert.Close() }() + + res, err := prepInsert.Exec(ctx, rain.PreparedArgs{ + "email": "prep@example.com", + "name": "Prepared", + }) + if err != nil { + t.Fatalf("exec prepared insert failed: %v", err) + } + id, _ := res.LastInsertId() + + // Prepare UPDATE + prepUpdate, err := db.Update(). + Table(users). + Set(users.Name, schema.Placeholder("name")). + Where(users.ID.EqExpr(schema.Placeholder("id"))). + Prepare(ctx) + if err != nil { + t.Fatalf("prepare update failed: %v", err) + } + defer func() { _ = prepUpdate.Close() }() + + _, err = prepUpdate.Exec(ctx, rain.PreparedArgs{ + "id": id, + "name": "Updated Prepared", + }) + if err != nil { + t.Fatalf("exec prepared update failed: %v", err) + } + + // Verify update + var row sqliteUserRow + if err := db.Select().Table(users).Where(users.ID.Eq(id)).Scan(ctx, &row); err != nil { + t.Fatalf("select row failed: %v", err) + } + if row.Name != "Updated Prepared" { + t.Fatalf("expected Updated Prepared, got %q", row.Name) + } + + // Prepare DELETE + prepDelete, err := db.Delete(). + Table(users). + Where(users.ID.EqExpr(schema.Placeholder("id"))). + Prepare(ctx) + if err != nil { + t.Fatalf("prepare delete failed: %v", err) + } + defer func() { _ = prepDelete.Close() }() + + _, err = prepDelete.Exec(ctx, rain.PreparedArgs{"id": id}) + if err != nil { + t.Fatalf("exec prepared delete failed: %v", err) + } + + // Verify delete + exists, err := db.Select().Table(users).Where(users.ID.Eq(id)).Exists(ctx) + if err != nil { + t.Fatalf("check exists failed: %v", err) + } + if exists { + t.Fatalf("expected row to be deleted") + } + + // Test SCAN (RETURNING) + prepInsertScan, err := db.Insert(). + Table(users). + Set(users.Email, schema.Placeholder("email")). + Returning(users.ID, users.Email). + Prepare(ctx) + if err != nil { + t.Fatalf("prepare insert scan failed: %v", err) + } + defer func() { _ = prepInsertScan.Close() }() + + var inserted sqliteUserRow + if err := prepInsertScan.Scan(ctx, rain.PreparedArgs{"email": "scan@example.com"}, &inserted); err != nil { + t.Fatalf("scan prepared insert failed: %v", err) + } + if inserted.Email != "scan@example.com" || inserted.ID == 0 { + t.Fatalf("unexpected scan result: %+v", inserted) + } +} + func TestSQLiteIntegrationHasOneRelation(t *testing.T) { t.Parallel()