Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/rain/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) {
}

ctx := newCompileContext(d)
defer releaseCompileContext(ctx)
ctx.useLiterals = true
ctx.writeString("CREATE VIEW ")
ctx.writeQuotedIdentifier(table.Name)
Expand Down
78 changes: 45 additions & 33 deletions pkg/rain/query_common_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,51 @@ func TestQueryCommonHelpers(t *testing.T) {
t.Fatalf("expected non-table select source to return nil, got %#v", got)
}

ctx := newCompileContext(dialectForTest(t, "postgres"))
if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
t.Fatalf("expected empty alias error, got %v", err)
}

ctx = newCompileContext(dialectForTest(t, "postgres"))
if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") {
t.Fatalf("expected nil query error, got %v", err)
}

ctx = newCompileContext(dialectForTest(t, "postgres"))
err := (subqueryTableSource{
alias: "u",
query: &SelectQuery{
dialect: ctx.dialect,
table: tableDefSource{table: users.TableDef()},
cols: []schema.Expression{users.ID},
},
}).writeSQL(ctx)
if err != nil {
t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err)
}
if !strings.Contains(ctx.String(), `AS "u"`) {
t.Fatalf("expected compiled subquery alias, got %q", ctx.String())
}

ctx = newCompileContext(dialectForTest(t, "postgres"))
if err := (subqueryTableSource{
alias: "broken",
query: &SelectQuery{dialect: ctx.dialect},
}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") {
t.Fatalf("expected nested query error, got %v", err)
}
t.Run("SubqueryAliasValidation", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
t.Fatalf("expected empty alias error, got %v", err)
}
})

t.Run("SubqueryNilQueryValidation", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") {
t.Fatalf("expected nil query error, got %v", err)
}
})

t.Run("SubqueryWriteSQL", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
err := (subqueryTableSource{
alias: "u",
query: &SelectQuery{
dialect: ctx.dialect,
table: tableDefSource{table: users.TableDef()},
cols: []schema.Expression{users.ID},
},
}).writeSQL(ctx)
if err != nil {
t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err)
}
if !strings.Contains(ctx.String(), `AS "u"`) {
t.Fatalf("expected compiled subquery alias, got %q", ctx.String())
}
})

t.Run("NestedQueryError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{
alias: "broken",
query: &SelectQuery{dialect: ctx.dialect},
}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") {
t.Fatalf("expected nested query error, got %v", err)
}
})
}

func TestCloseRows(t *testing.T) {
Expand Down
40 changes: 34 additions & 6 deletions pkg/rain/query_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"strings"
"sync"

"github.com/hyperlocalise/rain-orm/pkg/dialect"
"github.com/hyperlocalise/rain-orm/pkg/schema"
Expand Down Expand Up @@ -84,23 +85,50 @@ type compileContext struct {
useLiterals bool
}

var compileContextPool = sync.Pool{
New: func() any {
return &compileContext{
argPlan: make([]compiledArg, 0, 8),
}
},
}

func newCompileContext(d dialect.Dialect) *compileContext {
return &compileContext{
dialect: d,
argPlan: make([]compiledArg, 0, 8),
}
ctx := compileContextPool.Get().(*compileContext)
ctx.reset(d)
return ctx
}

func releaseCompileContext(ctx *compileContext) {
compileContextPool.Put(ctx)
}

func (c *compileContext) reset(d dialect.Dialect) {
c.builder.Reset()
c.dialect = d
// Clear the argPlan slice to ensure any values it contains can be
// garbage collected before we reset its length for reuse.
clear(c.argPlan)
c.argPlan = c.argPlan[:0]
c.err = nil
c.skipCTEs = false
c.useLiterals = false
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.

func (c *compileContext) String() string {
return c.builder.String()
}

func (c *compileContext) compiledQuery() compiledQuery {
// OPTIMIZATION: Explicitly copy the argPlan slice so that the compileContext
// and its underlying array can be safely returned to the sync.Pool without
// causing data corruption for the caller of compiledQuery.
argPlan := append([]compiledArg(nil), c.argPlan...)

compiled := compiledQuery{
sql: c.String(),
argPlan: make([]compiledArg, len(c.argPlan)),
argPlan: argPlan,
}
copy(compiled.argPlan, c.argPlan)
for _, arg := range compiled.argPlan {
if arg.kind == compiledArgNamedPlaceholder {
compiled.hasNames = true
Expand Down
101 changes: 69 additions & 32 deletions pkg/rain/query_compile_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,27 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) {

users, posts := defineInternalQueryTables()

ctx := newCompileContext(dialectForTest(t, "postgres"))
if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil {
t.Fatalf("writeRaw without args failed: %v", err)
}
if ctx.String() != "NOW()" {
t.Fatalf("unexpected raw SQL: %s", ctx.String())
}
t.Run("WriteRawWithoutArgs", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil {
t.Fatalf("writeRaw without args failed: %v", err)
}
if ctx.String() != "NOW()" {
t.Fatalf("unexpected raw SQL: %s", ctx.String())
}
})

ctx = newCompileContext(dialectForTest(t, "postgres"))
if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil {
t.Fatalf("writeRaw placeholders failed: %v", err)
}
if ctx.String() != "$1 + $2" {
t.Fatalf("unexpected placeholder SQL: %s", ctx.String())
}
t.Run("WriteRawWithPlaceholders", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil {
t.Fatalf("writeRaw placeholders failed: %v", err)
}
if ctx.String() != "$1 + $2" {
t.Fatalf("unexpected placeholder SQL: %s", ctx.String())
}
})

for _, tc := range []struct {
name string
Expand All @@ -146,6 +152,7 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) {
} {
t.Run("named placeholder "+tc.name, func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, tc.dialect))
defer releaseCompileContext(ctx)
expr := schema.And(
users.Email.EqExpr(schema.Placeholder("email")),
users.Active.EqExpr(schema.Placeholder("active")),
Expand Down Expand Up @@ -177,24 +184,53 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) {
})
}

if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") {
t.Fatalf("expected raw unused args error, got %v", err)
}
if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") {
t.Fatalf("expected raw placeholder mismatch error, got %v", err)
}
if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(schema.CoalesceExpr{Exprs: []schema.Expression{users.Email}}); err == nil || !strings.Contains(err.Error(), "at least two expressions") {
t.Fatalf("expected COALESCE arity error, got %v", err)
}
if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(schema.CaseExpr{}); err == nil || !strings.Contains(err.Error(), "CASE expression requires at least one WHEN clause") {
t.Fatalf("expected CASE arity error, got %v", err)
}
if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") {
t.Fatalf("expected empty IN error, got %v", err)
}
if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") {
t.Fatalf("expected unsupported expression error, got %v", err)
}
t.Run("RawUnusedArgsError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") {
t.Fatalf("expected raw unused args error, got %v", err)
}
})

t.Run("RawPlaceholderMismatchError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") {
t.Fatalf("expected raw placeholder mismatch error, got %v", err)
}
})

t.Run("CoalesceArityError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeExpression(schema.CoalesceExpr{Exprs: []schema.Expression{users.Email}}); err == nil || !strings.Contains(err.Error(), "at least two expressions") {
t.Fatalf("expected COALESCE arity error, got %v", err)
}
})

t.Run("CaseArityError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeExpression(schema.CaseExpr{}); err == nil || !strings.Contains(err.Error(), "CASE expression requires at least one WHEN clause") {
t.Fatalf("expected CASE arity error, got %v", err)
}
})

t.Run("EmptyInError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") {
t.Fatalf("expected empty IN error, got %v", err)
}
})

t.Run("UnsupportedExpressionError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := ctx.writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") {
t.Fatalf("expected unsupported expression error, got %v", err)
}
})

merged, err := mergeAssignments(users.TableDef(),
[]assignment{
Expand Down Expand Up @@ -426,6 +462,7 @@ func TestNewOperatorsSQL(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
ctx := newCompileContext(d)
defer releaseCompileContext(ctx)
if err := ctx.writeExpression(tc.expr); err != nil {
t.Fatalf("writeExpression failed: %v", err)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/rain/query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) {
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
ctx.writeString("DELETE FROM ")
ctx.writeTableName(q.table)
if len(q.where) > 0 {
Expand Down
2 changes: 2 additions & 0 deletions pkg/rain/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (q *InsertQuery) ToSQL() (string, []any, error) {
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
ctx.writeString("INSERT INTO ")
ctx.writeTableName(q.table)
ctx.writeString(" (")
Expand Down Expand Up @@ -190,6 +191,7 @@ func (q *InsertQuery) toSelectSQL() (string, []any, error) {
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
ctx.writeString("INSERT INTO ")
ctx.writeTableName(q.table)

Expand Down
2 changes: 2 additions & 0 deletions pkg/rain/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,7 @@ func (q *SelectQuery) compile() (compiledQuery, error) {
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
if err := q.writeSQL(ctx); err != nil {
return compiledQuery{}, err
}
Expand All @@ -940,6 +941,7 @@ func (q *SelectQuery) compileAggregate(selection string) (compiledQuery, error)
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
ctx.writeString("SELECT ")
ctx.writeString(selection)
ctx.writeString(" FROM ")
Expand Down
1 change: 1 addition & 0 deletions pkg/rain/query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) {
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)
ctx.writeString("UPDATE ")
ctx.writeTableName(q.table)
ctx.writeString(" SET ")
Expand Down