diff --git a/README.md b/README.md index c9f2531..d10bd8d 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,29 @@ sql: emit_async_querier: true ``` +### Configuration Options + +These are the supported `options` for the `py` plugin. Add them under the `codegen[].options` section of your `sqlc.yaml`. + +- package: Module path used for imports in generated query files (e.g., `from import models`). +- emit_sync_querier: Emit a synchronous `Querier` class using `sqlalchemy.engine.Connection`. +- emit_async_querier: Emit an asynchronous `AsyncQuerier` class using `sqlalchemy.ext.asyncio.AsyncConnection`. +- emit_pydantic_models: Emit Pydantic models instead of `dataclasses` for models.py. See the section below. +- emit_str_enum: Emit enums as `enum.StrEnum` (Python >=3.11). When false, emit `(str, enum.Enum)`. See the section below. +- emit_schema_name_prefix: When true, prefix non-default schema to generated types to avoid name collisions. Examples: + - false (default): `Book`, `BookStatus` + - true: `MySchemaBook`, `MySchemaBookStatus` when the objects live in schema `my_schema`. +- emit_exact_table_names: When true, do not singularize table names for model class names. +- query_parameter_limit: Integer controlling when query params are grouped into a single struct argument. + - If the number of parameters exceeds this value, a single `Params` struct is emitted. + - Set to 0 to always emit a struct; omit or set to a large value to keep separate parameters. +- inflection_exclude_table_names: A list of table names to exclude from singularization when `emit_exact_table_names` is false. +- overrides: Column type overrides; see the section below. + +Notes +- out: Controlled by `codegen[].out` at the sqlc level. The plugin’s `out` option is not used; prefer the top-level `out` value. + + ### Emit Pydantic Models instead of `dataclasses` Option: `emit_pydantic_models` diff --git a/internal/config.go b/internal/config.go index 1a8a565..8fc6a0a 100644 --- a/internal/config.go +++ b/internal/config.go @@ -1,13 +1,21 @@ package python +type OverrideColumn struct { + Column string `json:"column"` + PyType string `json:"py_type"` + PyImport string `json:"py_import"` +} + type Config struct { - EmitExactTableNames bool `json:"emit_exact_table_names"` - EmitSyncQuerier bool `json:"emit_sync_querier"` - EmitAsyncQuerier bool `json:"emit_async_querier"` - Package string `json:"package"` - Out string `json:"out"` - EmitPydanticModels bool `json:"emit_pydantic_models"` - EmitStrEnum bool `json:"emit_str_enum"` - QueryParameterLimit *int32 `json:"query_parameter_limit"` - InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + EmitExactTableNames bool `json:"emit_exact_table_names"` + EmitSyncQuerier bool `json:"emit_sync_querier"` + EmitAsyncQuerier bool `json:"emit_async_querier"` + Package string `json:"package"` + Out string `json:"out"` + EmitPydanticModels bool `json:"emit_pydantic_models"` + EmitStrEnum bool `json:"emit_str_enum"` + EmitSchemaNamePrefix bool `json:"emit_schema_name_prefix"` + QueryParameterLimit *int32 `json:"query_parameter_limit"` + InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"` + Overrides []OverrideColumn `json:"overrides"` } diff --git a/internal/gen.go b/internal/gen.go index ca18c21..0c145ea 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -181,7 +181,41 @@ func (q Query) ArgDictNode() *pyast.Node { } func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { - typ := pyInnerType(req, col) + // Parse the configuration + var conf Config + if len(req.PluginOptions) > 0 { + if err := json.Unmarshal(req.PluginOptions, &conf); err != nil { + log.Printf("failed to parse plugin options: %s", err) + } + } + + // Check for overrides + if len(conf.Overrides) > 0 && col.Table != nil { + tableName := col.Table.Name + if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema { + tableName = col.Table.Schema + "." + tableName + } + + // Look for a matching override + for _, override := range conf.Overrides { + overrideKey := tableName + "." + col.Name + if override.Column == overrideKey { + // Found a match, use the override + typeStr := override.PyType + if override.PyImport != "" && !strings.Contains(typeStr, ".") { + typeStr = override.PyImport + "." + override.PyType + } + return pyType{ + InnerType: typeStr, + IsArray: col.IsArray, + IsNull: !col.NotNull, + } + } + } + } + + // No override found, use the standard type mapping + typ := pyInnerType(conf, req, col) return pyType{ InnerType: typ, IsArray: col.IsArray, @@ -189,10 +223,10 @@ func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { } } -func pyInnerType(req *plugin.GenerateRequest, col *plugin.Column) string { +func pyInnerType(conf Config, req *plugin.GenerateRequest, col *plugin.Column) string { switch req.Settings.Engine { case "postgresql": - return postgresType(req, col) + return postgresType(conf, req, col) default: log.Println("unsupported engine type") return "Any" @@ -226,7 +260,7 @@ func pyEnumValueName(value string) string { return strings.ToUpper(id) } -func buildEnums(req *plugin.GenerateRequest) []Enum { +func buildEnums(conf Config, req *plugin.GenerateRequest) []Enum { var enums []Enum for _, schema := range req.Catalog.Schemas { if schema.Name == "pg_catalog" || schema.Name == "information_schema" { @@ -234,10 +268,10 @@ func buildEnums(req *plugin.GenerateRequest) []Enum { } for _, enum := range schema.Enums { var enumName string - if schema.Name == req.Catalog.DefaultSchema { - enumName = enum.Name - } else { + if conf.EmitSchemaNamePrefix && schema.Name != req.Catalog.DefaultSchema { enumName = schema.Name + "_" + enum.Name + } else { + enumName = enum.Name } e := Enum{ Name: modelName(enumName, req.Settings), @@ -267,10 +301,10 @@ func buildModels(conf Config, req *plugin.GenerateRequest) []Struct { } for _, table := range schema.Tables { var tableName string - if schema.Name == req.Catalog.DefaultSchema { - tableName = table.Rel.Name - } else { + if conf.EmitSchemaNamePrefix && schema.Name != req.Catalog.DefaultSchema { tableName = schema.Name + "_" + table.Rel.Name + } else { + tableName = table.Rel.Name } structName := tableName if !conf.EmitExactTableNames { @@ -1185,7 +1219,7 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR } } - enums := buildEnums(req) + enums := buildEnums(conf, req) models := buildModels(conf, req) queries, err := buildQueries(conf, req, models) if err != nil { diff --git a/internal/postgresql_type.go b/internal/postgresql_type.go index 3d0891b..dfc0800 100644 --- a/internal/postgresql_type.go +++ b/internal/postgresql_type.go @@ -7,8 +7,8 @@ import ( "github.com/sqlc-dev/plugin-sdk-go/sdk" ) -func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string { - columnType := sdk.DataType(col.Type) +func postgresType(conf Config, req *plugin.GenerateRequest, col *plugin.Column) string { + columnType := sdk.DataType(col.Type) switch columnType { case "serial", "serial4", "pg_catalog.serial4", "bigserial", "serial8", "pg_catalog.serial8", "smallserial", "serial2", "pg_catalog.serial2", "integer", "int", "int4", "pg_catalog.int4", "bigint", "int8", "pg_catalog.int8", "smallint", "int2", "pg_catalog.int2": @@ -42,21 +42,23 @@ func postgresType(req *plugin.GenerateRequest, col *plugin.Column) string { return "str" case "ltree", "lquery", "ltxtquery": return "str" - default: - for _, schema := range req.Catalog.Schemas { - if schema.Name == "pg_catalog" || schema.Name == "information_schema" { - continue - } - for _, enum := range schema.Enums { - if columnType == enum.Name { - if schema.Name == req.Catalog.DefaultSchema { - return "models." + modelName(enum.Name, req.Settings) - } - return "models." + modelName(schema.Name+"_"+enum.Name, req.Settings) - } - } - } - log.Printf("unknown PostgreSQL type: %s\n", columnType) - return "Any" - } + default: + for _, schema := range req.Catalog.Schemas { + if schema.Name == "pg_catalog" || schema.Name == "information_schema" { + continue + } + for _, enum := range schema.Enums { + // Match both unqualified and schema-qualified enum type names + if columnType == enum.Name || columnType == schema.Name+"."+enum.Name { + name := enum.Name + if conf.EmitSchemaNamePrefix && schema.Name != req.Catalog.DefaultSchema { + name = schema.Name + "_" + enum.Name + } + return "models." + modelName(name, req.Settings) + } + } + } + log.Printf("unknown PostgreSQL type: %s\n", columnType) + return "Any" + } }