diff --git a/pkg/rain/ddl.go b/pkg/rain/ddl.go index b0057a4..8554a25 100644 --- a/pkg/rain/ddl.go +++ b/pkg/rain/ddl.go @@ -552,6 +552,7 @@ func createViewSQL(d dialect.Dialect, table *schema.TableDef) (string, error) { } ctx := newCompileContext(d) + defer releaseCompileContext(ctx) ctx.useLiterals = true ctx.writeString("CREATE VIEW ") ctx.writeQuotedIdentifier(table.Name) diff --git a/pkg/rain/query_common_internal_test.go b/pkg/rain/query_common_internal_test.go index cdec8f6..c50ae72 100644 --- a/pkg/rain/query_common_internal_test.go +++ b/pkg/rain/query_common_internal_test.go @@ -26,39 +26,51 @@ func TestQueryCommonHelpers(t *testing.T) { t.Fatalf("expected non-table select source to return nil, got %#v", got) } - ctx := newCompileContext(dialectForTest(t, "postgres")) - if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { - t.Fatalf("expected empty alias error, got %v", err) - } - - ctx = newCompileContext(dialectForTest(t, "postgres")) - if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") { - t.Fatalf("expected nil query error, got %v", err) - } - - ctx = newCompileContext(dialectForTest(t, "postgres")) - err := (subqueryTableSource{ - alias: "u", - query: &SelectQuery{ - dialect: ctx.dialect, - table: tableDefSource{table: users.TableDef()}, - cols: []schema.Expression{users.ID}, - }, - }).writeSQL(ctx) - if err != nil { - t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err) - } - if !strings.Contains(ctx.String(), `AS "u"`) { - t.Fatalf("expected compiled subquery alias, got %q", ctx.String()) - } - - ctx = newCompileContext(dialectForTest(t, "postgres")) - if err := (subqueryTableSource{ - alias: "broken", - query: &SelectQuery{dialect: ctx.dialect}, - }).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") { - t.Fatalf("expected nested query error, got %v", err) - } + t.Run("SubqueryAliasValidation", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { + t.Fatalf("expected empty alias error, got %v", err) + } + }) + + t.Run("SubqueryNilQueryValidation", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") { + t.Fatalf("expected nil query error, got %v", err) + } + }) + + t.Run("SubqueryWriteSQL", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + err := (subqueryTableSource{ + alias: "u", + query: &SelectQuery{ + dialect: ctx.dialect, + table: tableDefSource{table: users.TableDef()}, + cols: []schema.Expression{users.ID}, + }, + }).writeSQL(ctx) + if err != nil { + t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err) + } + if !strings.Contains(ctx.String(), `AS "u"`) { + t.Fatalf("expected compiled subquery alias, got %q", ctx.String()) + } + }) + + t.Run("NestedQueryError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := (subqueryTableSource{ + alias: "broken", + query: &SelectQuery{dialect: ctx.dialect}, + }).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") { + t.Fatalf("expected nested query error, got %v", err) + } + }) } func TestCloseRows(t *testing.T) { diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 5409d60..5047f5a 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strings" + "sync" "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -84,11 +85,34 @@ type compileContext struct { useLiterals bool } +var compileContextPool = sync.Pool{ + New: func() any { + return &compileContext{ + argPlan: make([]compiledArg, 0, 8), + } + }, +} + func newCompileContext(d dialect.Dialect) *compileContext { - return &compileContext{ - dialect: d, - argPlan: make([]compiledArg, 0, 8), - } + ctx := compileContextPool.Get().(*compileContext) + ctx.reset(d) + return ctx +} + +func releaseCompileContext(ctx *compileContext) { + compileContextPool.Put(ctx) +} + +func (c *compileContext) reset(d dialect.Dialect) { + c.builder.Reset() + c.dialect = d + // Clear the argPlan slice to ensure any values it contains can be + // garbage collected before we reset its length for reuse. + clear(c.argPlan) + c.argPlan = c.argPlan[:0] + c.err = nil + c.skipCTEs = false + c.useLiterals = false } func (c *compileContext) String() string { @@ -96,11 +120,15 @@ func (c *compileContext) String() string { } func (c *compileContext) compiledQuery() compiledQuery { + // OPTIMIZATION: Explicitly copy the argPlan slice so that the compileContext + // and its underlying array can be safely returned to the sync.Pool without + // causing data corruption for the caller of compiledQuery. + argPlan := append([]compiledArg(nil), c.argPlan...) + compiled := compiledQuery{ sql: c.String(), - argPlan: make([]compiledArg, len(c.argPlan)), + argPlan: argPlan, } - copy(compiled.argPlan, c.argPlan) for _, arg := range compiled.argPlan { if arg.kind == compiledArgNamedPlaceholder { compiled.hasNames = true diff --git a/pkg/rain/query_compile_internal_test.go b/pkg/rain/query_compile_internal_test.go index a56e34a..f7e08f2 100644 --- a/pkg/rain/query_compile_internal_test.go +++ b/pkg/rain/query_compile_internal_test.go @@ -119,21 +119,27 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) { users, posts := defineInternalQueryTables() - ctx := newCompileContext(dialectForTest(t, "postgres")) - if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil { - t.Fatalf("writeRaw without args failed: %v", err) - } - if ctx.String() != "NOW()" { - t.Fatalf("unexpected raw SQL: %s", ctx.String()) - } + t.Run("WriteRawWithoutArgs", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeRaw(schema.Raw("NOW()")); err != nil { + t.Fatalf("writeRaw without args failed: %v", err) + } + if ctx.String() != "NOW()" { + t.Fatalf("unexpected raw SQL: %s", ctx.String()) + } + }) - ctx = newCompileContext(dialectForTest(t, "postgres")) - if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil { - t.Fatalf("writeRaw placeholders failed: %v", err) - } - if ctx.String() != "$1 + $2" { - t.Fatalf("unexpected placeholder SQL: %s", ctx.String()) - } + t.Run("WriteRawWithPlaceholders", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeRaw(schema.Raw("? + ?", 1, 2)); err != nil { + t.Fatalf("writeRaw placeholders failed: %v", err) + } + if ctx.String() != "$1 + $2" { + t.Fatalf("unexpected placeholder SQL: %s", ctx.String()) + } + }) for _, tc := range []struct { name string @@ -146,6 +152,7 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) { } { t.Run("named placeholder "+tc.name, func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, tc.dialect)) + defer releaseCompileContext(ctx) expr := schema.And( users.Email.EqExpr(schema.Placeholder("email")), users.Active.EqExpr(schema.Placeholder("active")), @@ -177,24 +184,53 @@ func TestCompileContextAndAssignmentsHelpers(t *testing.T) { }) } - if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") { - t.Fatalf("expected raw unused args error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") { - t.Fatalf("expected raw placeholder mismatch error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(schema.CoalesceExpr{Exprs: []schema.Expression{users.Email}}); err == nil || !strings.Contains(err.Error(), "at least two expressions") { - t.Fatalf("expected COALESCE arity error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(schema.CaseExpr{}); err == nil || !strings.Contains(err.Error(), "CASE expression requires at least one WHEN clause") { - t.Fatalf("expected CASE arity error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") { - t.Fatalf("expected empty IN error, got %v", err) - } - if err := newCompileContext(dialectForTest(t, "postgres")).writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") { - t.Fatalf("expected unsupported expression error, got %v", err) - } + t.Run("RawUnusedArgsError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeRaw(schema.Raw("?", 1, 2)); err == nil || !strings.Contains(err.Error(), "unused args") { + t.Fatalf("expected raw unused args error, got %v", err) + } + }) + + t.Run("RawPlaceholderMismatchError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeRaw(schema.Raw("? ?", 1)); err == nil || !strings.Contains(err.Error(), "placeholder count") { + t.Fatalf("expected raw placeholder mismatch error, got %v", err) + } + }) + + t.Run("CoalesceArityError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeExpression(schema.CoalesceExpr{Exprs: []schema.Expression{users.Email}}); err == nil || !strings.Contains(err.Error(), "at least two expressions") { + t.Fatalf("expected COALESCE arity error, got %v", err) + } + }) + + t.Run("CaseArityError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeExpression(schema.CaseExpr{}); err == nil || !strings.Contains(err.Error(), "CASE expression requires at least one WHEN clause") { + t.Fatalf("expected CASE arity error, got %v", err) + } + }) + + t.Run("EmptyInError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeExpression(users.ID.In()); err == nil || !strings.Contains(err.Error(), "requires at least one value") { + t.Fatalf("expected empty IN error, got %v", err) + } + }) + + t.Run("UnsupportedExpressionError", func(t *testing.T) { + ctx := newCompileContext(dialectForTest(t, "postgres")) + defer releaseCompileContext(ctx) + if err := ctx.writeExpression(nil); err == nil || !strings.Contains(err.Error(), "unsupported expression type") { + t.Fatalf("expected unsupported expression error, got %v", err) + } + }) merged, err := mergeAssignments(users.TableDef(), []assignment{ @@ -426,6 +462,7 @@ func TestNewOperatorsSQL(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { ctx := newCompileContext(d) + defer releaseCompileContext(ctx) if err := ctx.writeExpression(tc.expr); err != nil { t.Fatalf("writeExpression failed: %v", err) } diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 75681b3..67ddf10 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -57,6 +57,7 @@ func (q *DeleteQuery) ToSQL() (string, []any, error) { } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) ctx.writeString("DELETE FROM ") ctx.writeTableName(q.table) if len(q.where) > 0 { diff --git a/pkg/rain/query_insert.go b/pkg/rain/query_insert.go index 3b00254..ee0facd 100644 --- a/pkg/rain/query_insert.go +++ b/pkg/rain/query_insert.go @@ -134,6 +134,7 @@ func (q *InsertQuery) ToSQL() (string, []any, error) { } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) ctx.writeString("INSERT INTO ") ctx.writeTableName(q.table) ctx.writeString(" (") @@ -190,6 +191,7 @@ func (q *InsertQuery) toSelectSQL() (string, []any, error) { } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) ctx.writeString("INSERT INTO ") ctx.writeTableName(q.table) diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 6fbcb86..23ab5a7 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -916,6 +916,7 @@ func (q *SelectQuery) compile() (compiledQuery, error) { } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) if err := q.writeSQL(ctx); err != nil { return compiledQuery{}, err } @@ -940,6 +941,7 @@ func (q *SelectQuery) compileAggregate(selection string) (compiledQuery, error) } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) ctx.writeString("SELECT ") ctx.writeString(selection) ctx.writeString(" FROM ") diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index f1b260a..358ad6f 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -74,6 +74,7 @@ func (q *UpdateQuery) ToSQL() (string, []any, error) { } ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) ctx.writeString("UPDATE ") ctx.writeTableName(q.table) ctx.writeString(" SET ")