From 7cce9513516b9901b83f55d338715c36b2deaa54 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Feb 2026 14:04:44 +0200 Subject: [PATCH 01/57] basic prepared parametrized queryset implementation --- tortoise/expressions.py | 24 +++++++++++---- tortoise/parameter.py | 20 ++++++++++++ tortoise/queryset.py | 68 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 tortoise/parameter.py diff --git a/tortoise/expressions.py b/tortoise/expressions.py index ff51e2cb9..8d833229e 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -24,6 +24,7 @@ from tortoise.fields.data import JSONField from tortoise.fields.relational import RelationalField from tortoise.filters import FilterInfoDict +from tortoise.parameter import Parameter from tortoise.query_utils import ( QueryModifier, TableCriterionTuple, @@ -386,6 +387,9 @@ def _process_filter_kwarg( else: filter_info = model._meta.get_filter(key) + if isinstance(value, Parameter): + value.model = model + field_object = None if "table" in filter_info: # join the table @@ -395,15 +399,23 @@ def _process_filter_kwarg( == filter_info["table"][filter_info["backward_key"]], ) if "value_encoder" in filter_info: - value = filter_info["value_encoder"](value, model) + if isinstance(value, Parameter): + value.value_encoder = filter_info["value_encoder"] + else: + value = filter_info["value_encoder"](value, model) table = filter_info["table"] elif not isinstance(value, Term): field_object = model._meta.fields_map[filter_info["field"]] - value = ( - filter_info["value_encoder"](value, model, field_object) - if "value_encoder" in filter_info - else field_object.to_db_value(value, model) - ) + value_encoder = filter_info["value_encoder"] if "value_encoder" in filter_info else None + if isinstance(value, Parameter): + value.field_object = field_object + value.value_encoder = value_encoder + else: + value = ( + value_encoder(value, model, field_object) + if value_encoder is not None + else field_object.to_db_value(value, model) + ) op = filter_info["operator"] term: Term = table[filter_info.get("source_field", filter_info["field"])] if field_object is not None: diff --git a/tortoise/parameter.py b/tortoise/parameter.py new file mode 100644 index 000000000..dcda2f5e4 --- /dev/null +++ b/tortoise/parameter.py @@ -0,0 +1,20 @@ +from tortoise.fields import Field + + +class Parameter: + __slots__ = ("name", "model", "value_encoder", "field_object",) + + def __init__(self, name: str) -> None: + self.name = name + self.model = None + self.value_encoder = None + self.field_object: Field | None = None + + def encode_value(self, value: ...) -> ...: + if self.value_encoder: + if self.field_object is not None: + return self.value_encoder(value, self.model, self.field_object) + return self.value_encoder(value, self.model) + if self.field_object is not None: + return self.field_object.to_db_value(value, self.model) + return self.value_encoder \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index bb7eb69bc..5dd66536c 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -27,6 +27,7 @@ RelationalField, ) from tortoise.filters import FilterInfoDict +from tortoise.parameter import Parameter from tortoise.query_utils import ( Prefetch, QueryModifier, @@ -1251,6 +1252,73 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list + def prepare(self) -> PreparedQuery[MODEL]: + if self._db is None: + self._db = self._choose_db(self._select_for_update) + self._make_query() + sql, params = self.query.get_parameterized_sql() + return PreparedQuery( + sql=sql, + params=params, + db=self._db, + model=self.model, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, + custom_fields=list(self._annotations.keys()), + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + ) + + +class PreparedQuery(AwaitableQuery[MODEL]): + def __init__( + self, sql: str, params: list[Any], db: BaseDBAsyncClient, model: MODEL, prefetch_map: ..., prefetch_queries: ..., + select_related_idx: ..., custom_fields: list[...], single: bool, raise_does_not_exist: bool, + ) -> None: + super().__init__(model) + + self._sql = sql + self._params = params + self._need_params = { + param.name: (param, idx) + for idx, param in enumerate(params) + if isinstance(param, Parameter) + } + self._executor = db.executor_class( + model=model, + db=db, + prefetch_map=prefetch_map, + prefetch_queries=prefetch_queries, + select_related_idx=select_related_idx, + ) + self._model = model + self._custom_fields = custom_fields + self._single = single + self._raise_does_not_exist = raise_does_not_exist + + async def execute(self, **params) -> list[MODEL]: + if self._need_params.keys() != params.keys(): + raise ValueError("One of more parameters does not match prepared parameters") + + filled_params = self._params.copy() + for name, (param, idx) in self._need_params.items(): + filled_params[idx] = param.encode_value(params[name]) + + instance_list = await self._executor.execute_select( + self._sql, filled_params, + custom_fields=self._custom_fields, + ) + if self._single: + if len(instance_list) == 1: + return instance_list[0] + if not instance_list: + if self._raise_does_not_exist: + raise DoesNotExist(self.model) + return None # type: ignore + raise MultipleObjectsReturned(self.model) + return instance_list + class UpdateQuery(AwaitableQuery): __slots__ = ( From bcef40d4fa2eaf3010a1834f6480f74175c40892 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Feb 2026 14:29:07 +0200 Subject: [PATCH 02/57] add support for string filters --- tortoise/filters.py | 62 ++++++++++++++++++++++++++++++++----------- tortoise/parameter.py | 20 +++++++++----- 2 files changed, 61 insertions(+), 21 deletions(-) diff --git a/tortoise/filters.py b/tortoise/filters.py index 9065c6bc8..49da27b86 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -16,11 +16,13 @@ Equality, Term, ValueWrapper, + Function, ) from tortoise.contrib.postgres.fields import ArrayField, TSVectorField from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance +from tortoise.parameter import Parameter if sys.version_info >= (3, 11): # pragma:nocoverage from typing import NotRequired @@ -142,8 +144,11 @@ def not_null(field: Term, value: Any) -> Criterion: return field.isnull() -def contains(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}%")) +def contains(field: Term, value: str | Parameter) -> Criterion: + return Like( + Cast(field, SqlTypes.VARCHAR), + field.wrap_constant(_format_str_or_parameter(field, value, True, True)) + ) def search(field: Term, value: str) -> Any: @@ -165,33 +170,60 @@ def insensitive_posix_regex(field: Term, value: str): ) -def starts_with(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"{escape_like(value)}%")) +def _format_str_or_parameter( + field: Term, value: str | Parameter, like_start: bool = False, like_end: bool = False, + escape_func: Callable[[Any], str] = escape_like, +) -> Term: + if isinstance(value, Parameter): + value.encode = escape_func + wrapped = ValueWrapper(value) + if not like_start and not like_end: + return wrapped + args = [] + if like_start: + args.append("%") + args.append(wrapped) + if like_end: + args.append("%") + return Function("Concat", *args) + else: + return field.wrap_constant( + f"{'%' if like_start else ''}" + f"{escape_func(value)}" + f"{'%' if like_end else ''}" + ) + +def starts_with(field: Term, value: str | Parameter) -> Criterion: + return Like(Cast(field, SqlTypes.VARCHAR), _format_str_or_parameter(field, value, False, True)) -def ends_with(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}")) +def ends_with(field: Term, value: str | Parameter) -> Criterion: + return Like(Cast(field, SqlTypes.VARCHAR), _format_str_or_parameter(field, value, True, False)) -def insensitive_exact(field: Term, value: str) -> Criterion: - return Upper(Cast(field, SqlTypes.VARCHAR)).eq(Upper(str(value))) +def insensitive_exact(field: Term, value: str | Parameter) -> Criterion: + return Upper(Cast(field, SqlTypes.VARCHAR)).eq(Upper(_format_str_or_parameter(field, value, escape_func=str))) -def insensitive_contains(field: Term, value: str) -> Criterion: + +def insensitive_contains(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"%{escape_like(value)}%")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, True))) ) -def insensitive_starts_with(field: Term, value: str) -> Criterion: +def insensitive_starts_with(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"{escape_like(value)}%")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, False, True))) ) -def insensitive_ends_with(field: Term, value: str) -> Criterion: +def insensitive_ends_with(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"%{escape_like(value)}")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, False))) ) @@ -240,7 +272,7 @@ def json_contained_by(field: Term, value: str) -> Criterion: def json_filter(field: Term, value: dict) -> Criterion: - raise NotImplementedError("must be overridden in each xecutor") + raise NotImplementedError("must be overridden in each executor") def array_contains(field: Term, value: Any | Sequence[Any]) -> Criterion: diff --git a/tortoise/parameter.py b/tortoise/parameter.py index dcda2f5e4..fab6eec7d 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -2,19 +2,27 @@ class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object",) + __slots__ = ("name", "model", "value_encoder", "field_object", "encode",) def __init__(self, name: str) -> None: self.name = name self.model = None self.value_encoder = None self.field_object: Field | None = None + self.encode = None def encode_value(self, value: ...) -> ...: + encoded = value + if self.value_encoder: if self.field_object is not None: - return self.value_encoder(value, self.model, self.field_object) - return self.value_encoder(value, self.model) - if self.field_object is not None: - return self.field_object.to_db_value(value, self.model) - return self.value_encoder \ No newline at end of file + encoded = self.value_encoder(value, self.model, self.field_object) + else: + encoded = self.value_encoder(value, self.model) + elif self.field_object is not None: + encoded = self.field_object.to_db_value(value, self.model) + + if self.encode: + encoded = self.encode(encoded) + + return encoded \ No newline at end of file From b27f5859da9547944adaa44c5c770aa661aecb6b Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 8 Feb 2026 18:01:23 +0200 Subject: [PATCH 03/57] trying to cache prepared queries for params-collections (i.e. tuple/list/etc.) with different sizes --- tortoise/parameter.py | 16 ++++++++-- tortoise/queryset.py | 71 +++++++++++++++++++++++++++++++------------ 2 files changed, 66 insertions(+), 21 deletions(-) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index fab6eec7d..f44bd2401 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,8 +1,9 @@ from tortoise.fields import Field +from pypika_tortoise import SqlContext class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object", "encode",) + __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "container_size",) def __init__(self, name: str) -> None: self.name = name @@ -10,6 +11,7 @@ def __init__(self, name: str) -> None: self.value_encoder = None self.field_object: Field | None = None self.encode = None + self.container_size = None def encode_value(self, value: ...) -> ...: encoded = value @@ -25,4 +27,14 @@ def encode_value(self, value: ...) -> ...: if self.encode: encoded = self.encode(encoded) - return encoded \ No newline at end of file + return encoded + + def get_sql(self, ctx: SqlContext) -> str: + if self.container_size is None: + if ctx.parameterizer is not None: + ctx.parameterizer.create_param(self) + return "?" + else: + if ctx.parameterizer is not None: + ctx.parameterizer.create_param(self) + return f"({','.join(['?' for _ in range(self.container_size)])})" \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 5dd66536c..81e46e95b 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1256,10 +1256,8 @@ def prepare(self) -> PreparedQuery[MODEL]: if self._db is None: self._db = self._choose_db(self._select_for_update) self._make_query() - sql, params = self.query.get_parameterized_sql() return PreparedQuery( - sql=sql, - params=params, + query=self.query, db=self._db, model=self.model, prefetch_map=self._prefetch_map, @@ -1271,20 +1269,37 @@ def prepare(self) -> PreparedQuery[MODEL]: ) +class CachedSql: + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.need_params = { + param.name: (param, idx) + for idx, param in enumerate(params) + if isinstance(param, Parameter) + } + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + if self.need_params.keys() != params.keys(): + raise ValueError("One of more parameters does not match prepared parameters") + + filled_params = self.params.copy() + for name, (param, idx) in self.need_params.items(): + filled_params[idx] = param.encode_value(params[name]) + + return filled_params + + class PreparedQuery(AwaitableQuery[MODEL]): def __init__( - self, sql: str, params: list[Any], db: BaseDBAsyncClient, model: MODEL, prefetch_map: ..., prefetch_queries: ..., - select_related_idx: ..., custom_fields: list[...], single: bool, raise_does_not_exist: bool, + self, query: QueryBuilder, db: BaseDBAsyncClient, model: type[MODEL], prefetch_map: ..., + prefetch_queries: ..., select_related_idx: ..., custom_fields: list[...], single: bool, + raise_does_not_exist: bool, ) -> None: super().__init__(model) - self._sql = sql - self._params = params - self._need_params = { - param.name: (param, idx) - for idx, param in enumerate(params) - if isinstance(param, Parameter) - } + self._query = query + self._cached_sql = {} self._executor = db.executor_class( model=model, db=db, @@ -1297,16 +1312,34 @@ def __init__( self._single = single self._raise_does_not_exist = raise_does_not_exist - async def execute(self, **params) -> list[MODEL]: - if self._need_params.keys() != params.keys(): - raise ValueError("One of more parameters does not match prepared parameters") + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + _, sql_params = self._query.get_parameterized_sql() + need_params = { + param.name: param + for param in sql_params + if isinstance(param, Parameter) + } - filled_params = self._params.copy() - for name, (param, idx) in self._need_params.items(): - filled_params[idx] = param.encode_value(params[name]) + cache_key = "query" + for param, value in params.items(): + if param not in need_params: + continue + if isinstance(value, (tuple, list, set)): + cache_key += f"-{param}{len(value)}" + need_params[param].container_size = len(value) + + if cache_key not in self._cached_sql: + sql, params = self._query.get_parameterized_sql() + self._cached_sql[cache_key] = CachedSql(sql, params) + + return self._cached_sql[cache_key] + + async def execute(self, **params) -> list[MODEL]: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) instance_list = await self._executor.execute_select( - self._sql, filled_params, + cached_query.sql, filled_params, custom_fields=self._custom_fields, ) if self._single: From 543020c116da6cdbd561446d92bbd58a10316925 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 9 Feb 2026 01:18:33 +0200 Subject: [PATCH 04/57] cache prepared queries for params-collections with different sizes --- tortoise/parameter.py | 30 ++++++++++++++++++++++++++++-- tortoise/queryset.py | 20 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index f44bd2401..4848a6316 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,18 +1,33 @@ +from typing import Self + from tortoise.fields import Field from pypika_tortoise import SqlContext class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "container_size",) + __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "container_size", "container_encoder",) def __init__(self, name: str) -> None: self.name = name self.model = None self.value_encoder = None + self.container_encoder = None self.field_object: Field | None = None self.encode = None self.container_size = None + def clone(self) -> Self: + new = self.__new__(self.__class__) + new.name = self.name + new.model = self.model + new.container_encoder = self.container_encoder + new.value_encoder = self.value_encoder + new.field_object = self.field_object + new.encode = self.encode + new.container_size = self.container_size + + return new + def encode_value(self, value: ...) -> ...: encoded = value @@ -29,6 +44,12 @@ def encode_value(self, value: ...) -> ...: return encoded + def encode_container(self, value: ...) -> ...: + if self.field_object is not None: + return self.container_encoder(value, self.model, self.field_object) + else: + return self.container_encoder(value, self.model) + def get_sql(self, ctx: SqlContext) -> str: if self.container_size is None: if ctx.parameterizer is not None: @@ -36,5 +57,10 @@ def get_sql(self, ctx: SqlContext) -> str: return "?" else: if ctx.parameterizer is not None: - ctx.parameterizer.create_param(self) + for idx in range(self.container_size): + new_param = self.clone() + new_param.name += f"[{idx}]" + new_param.container_encoder = new_param.value_encoder + new_param.value_encoder = None + ctx.parameterizer.create_param(new_param) return f"({','.join(['?' for _ in range(self.container_size)])})" \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 81e46e95b..1548bc09a 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1280,6 +1280,20 @@ def __init__(self, sql: str, params: list[Parameter | Any]) -> None: } def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + add_params = [] + del_params = [] + for param_name, value in params.items(): + if not isinstance(value, (list, tuple, set)): + continue + for idx, item in enumerate(self.need_params[f"{param_name}[0]"][0].encode_container(value)): + add_params.append((f"{param_name}[{idx}]", item)) + del_params.append(param_name) + + for param_name, value in add_params: + params[param_name] = value + for param_name in del_params: + del params[param_name] + if self.need_params.keys() != params.keys(): raise ValueError("One of more parameters does not match prepared parameters") @@ -1320,6 +1334,8 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if isinstance(param, Parameter) } + reset_params = [] + cache_key = "query" for param, value in params.items(): if param not in need_params: @@ -1327,11 +1343,15 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if isinstance(value, (tuple, list, set)): cache_key += f"-{param}{len(value)}" need_params[param].container_size = len(value) + reset_params.append(need_params[param]) if cache_key not in self._cached_sql: sql, params = self._query.get_parameterized_sql() self._cached_sql[cache_key] = CachedSql(sql, params) + for param in reset_params: + param.container_size = None + return self._cached_sql[cache_key] async def execute(self, **params) -> list[MODEL]: From 175f0da2fd2db448f816fe392f0b5c8343c66f41 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 10 Feb 2026 11:38:42 +0200 Subject: [PATCH 05/57] rewrite collections params processing in CachedSql --- tortoise/filters.py | 6 +++++- tortoise/parameter.py | 39 +++++++++++++++++++++++------------ tortoise/queryset.py | 47 ++++++++++++++++++++++--------------------- 3 files changed, 55 insertions(+), 37 deletions(-) diff --git a/tortoise/filters.py b/tortoise/filters.py index 49da27b86..8e38845c8 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -22,7 +22,7 @@ from tortoise.contrib.postgres.fields import ArrayField, TSVectorField from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance -from tortoise.parameter import Parameter +from tortoise.parameter import Parameter, CollectionParameter if sys.version_info >= (3, 11): # pragma:nocoverage from typing import NotRequired @@ -104,6 +104,8 @@ def array_encoder(value: Any | Sequence[Any], instance: Model, field: Field) -> def is_in(field: Term, value: Any) -> Criterion: if value: + if isinstance(value, Parameter): + value = CollectionParameter.from_simple_param(value) return field.isin(value) # SQL has no False, so we return 1=0 return BasicCriterion( @@ -115,6 +117,8 @@ def is_in(field: Term, value: Any) -> Criterion: def not_in(field: Term, value: Any) -> Criterion: if value: + if isinstance(value, Parameter): + value = CollectionParameter.from_simple_param(value) return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 return BasicCriterion( diff --git a/tortoise/parameter.py b/tortoise/parameter.py index 4848a6316..9e27193a5 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -5,26 +5,22 @@ class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "container_size", "container_encoder",) + __slots__ = ("name", "model", "value_encoder", "field_object", "encode",) def __init__(self, name: str) -> None: self.name = name self.model = None self.value_encoder = None - self.container_encoder = None self.field_object: Field | None = None self.encode = None - self.container_size = None def clone(self) -> Self: new = self.__new__(self.__class__) new.name = self.name new.model = self.model - new.container_encoder = self.container_encoder new.value_encoder = self.value_encoder new.field_object = self.field_object new.encode = self.encode - new.container_size = self.container_size return new @@ -44,23 +40,40 @@ def encode_value(self, value: ...) -> ...: return encoded - def encode_container(self, value: ...) -> ...: + +class CollectionParameter(Parameter): + __slots__ = ("collection_size", "collection_encoder",) + + def __init__(self, name: str) -> None: + super().__init__(name) + self.collection_size = None + self.collection_encoder = None + + @classmethod + def from_simple_param(cls, param: Parameter) -> Self: + new_param = cls(param.name) + new_param.model = param.model + new_param.value_encoder = param.value_encoder + new_param.field_object = param.field_object + new_param.encode = param.encode + return new_param + + def encode_collection(self, value: ...) -> ...: if self.field_object is not None: - return self.container_encoder(value, self.model, self.field_object) + return self.collection_encoder(value, self.model, self.field_object) else: - return self.container_encoder(value, self.model) + return self.collection_encoder(value, self.model) def get_sql(self, ctx: SqlContext) -> str: - if self.container_size is None: + if self.collection_size is None: if ctx.parameterizer is not None: ctx.parameterizer.create_param(self) return "?" else: if ctx.parameterizer is not None: - for idx in range(self.container_size): + for idx in range(self.collection_size): new_param = self.clone() - new_param.name += f"[{idx}]" - new_param.container_encoder = new_param.value_encoder + new_param.collection_encoder = new_param.value_encoder new_param.value_encoder = None ctx.parameterizer.create_param(new_param) - return f"({','.join(['?' for _ in range(self.container_size)])})" \ No newline at end of file + return f"({','.join(['?' for _ in range(self.collection_size)])})" \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 1548bc09a..c4e7e2bc2 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -27,7 +27,7 @@ RelationalField, ) from tortoise.filters import FilterInfoDict -from tortoise.parameter import Parameter +from tortoise.parameter import Parameter, CollectionParameter from tortoise.query_utils import ( Prefetch, QueryModifier, @@ -1273,34 +1273,35 @@ class CachedSql: def __init__(self, sql: str, params: list[Parameter | Any]) -> None: self.sql = sql self.params = params - self.need_params = { - param.name: (param, idx) - for idx, param in enumerate(params) - if isinstance(param, Parameter) - } - - def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - add_params = [] - del_params = [] - for param_name, value in params.items(): - if not isinstance(value, (list, tuple, set)): + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): continue - for idx, item in enumerate(self.need_params[f"{param_name}[0]"][0].encode_container(value)): - add_params.append((f"{param_name}[{idx}]", item)) - del_params.append(param_name) - for param_name, value in add_params: - params[param_name] = value - for param_name in del_params: - del params[param_name] + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param - if self.need_params.keys() != params.keys(): - raise ValueError("One of more parameters does not match prepared parameters") + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + # TODO: check for parameters mismatch filled_params = self.params.copy() for name, (param, idx) in self.need_params.items(): filled_params[idx] = param.encode_value(params[name]) + for name, indexes in self.need_collection_params.items(): + param = self.param_by_name[name] + collection = param.encode_collection(params[name]) + # TODO: check that len(value) mathes len(indexes) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + return filled_params @@ -1342,7 +1343,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: continue if isinstance(value, (tuple, list, set)): cache_key += f"-{param}{len(value)}" - need_params[param].container_size = len(value) + need_params[param].collection_size = len(value) reset_params.append(need_params[param]) if cache_key not in self._cached_sql: @@ -1350,7 +1351,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: self._cached_sql[cache_key] = CachedSql(sql, params) for param in reset_params: - param.container_size = None + param.collection_size = None return self._cached_sql[cache_key] From 39f2c3cf86ce73121bcc210e39710f6f100890d4 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 10 Feb 2026 21:20:54 +0200 Subject: [PATCH 06/57] check collection length in make_filled_params --- idk_test.py | 115 +++++++++++++++++++++++++++++++++++++++++++ tortoise/queryset.py | 8 ++- 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 idk_test.py diff --git a/idk_test.py b/idk_test.py new file mode 100644 index 000000000..1c07ac176 --- /dev/null +++ b/idk_test.py @@ -0,0 +1,115 @@ +from tortoise import fields, run_async +from tortoise.contrib.test import init_memory_sqlite +from tortoise.models import Model +from tortoise.parameter import Parameter + + +CHECK_ACTUAL = True + + +class SomeModel(Model): + id: int = fields.BigIntField(pk=True) + name: str = fields.TextField() + + +async def t0_sanity_check(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + idk = await SomeModel.filter(id=some2.id) + print(idk) + + +async def t1_simple_gte(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(id__gte=Parameter("idk")) + prepared = query.prepare() + actual = await prepared.execute(idk=some2.id) + print(actual) + + if CHECK_ACTUAL: + expected = await SomeModel.filter(id__gte=some2.id) + print(expected) + print(actual == expected) + + +async def t2_simple_string_param(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(name=Parameter("idk")) + prepared = query.prepare() + actual1 = await prepared.execute(idk=some2.id) + print(actual1) + actual2 = await prepared.execute(idk=some2.name) + print(actual2) + + if CHECK_ACTUAL: + expected1 = await SomeModel.filter(name=some2.id) + expected2 = await SomeModel.filter(name=some2.name) + print(expected1) + print(expected2) + print(actual1 == expected1) + print(actual2 == expected2) + + +async def t3_startswith(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(name__startswith=Parameter("idk")) + prepared = query.prepare() + actual1 = await prepared.execute(idk=some2.id) + print(actual1) + actual2 = await prepared.execute(idk=some2.name) + print(actual2) + actual3 = await prepared.execute(idk="asd") + print(actual3) + actual4 = await prepared.execute(idk="qwe") + print(actual4) + + if CHECK_ACTUAL: + expected1 = await SomeModel.filter(name__startswith=some2.id) + expected2 = await SomeModel.filter(name__startswith=some2.name) + expected3 = await SomeModel.filter(name__startswith="asd") + expected4 = await SomeModel.filter(name__startswith="qwe") + print(expected1) + print(expected2) + print(expected3) + print(expected4) + print(actual1 == expected1) + print(actual2 == expected2) + print(actual3 == expected3) + print(actual4 == expected4) + + +async def t4_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(id__in=Parameter("idk")) + prepared = query.prepare() + actual1 = await prepared.execute(idk=[some2.id, some1.id]) + print(actual1) + actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) + print(actual2) + + if CHECK_ACTUAL: + expected1 = await SomeModel.filter(id__in=[some2.id, some1.id]) + expected2 = await SomeModel.filter(id__in=[some3.id, some3.id * 2, some3.id * 10]) + print(expected1) + print(expected2) + print(actual1 == expected1) + print(actual2 == expected2) + + +TESTS = [ + # t0_sanity_check, + # t1_simple_gte, + # t2_simple_string_param, + # t3_startswith, + t4_in, +] + + +@init_memory_sqlite +async def run() -> None: + some1 = await SomeModel.create(name="asdqwe") + some2 = await SomeModel.create(name="asdqweasd") + some3 = await SomeModel.create(name="asdqweasd123") + + for test_func in TESTS: + print(f"Running {test_func.__name__} ...") + await test_func(some1, some2, some3) + print("=" * 32) + + +if __name__ == "__main__": + run_async(run()) \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index c4e7e2bc2..aa70da69b 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1298,7 +1298,12 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: for name, indexes in self.need_collection_params.items(): param = self.param_by_name[name] collection = param.encode_collection(params[name]) - # TODO: check that len(value) mathes len(indexes) + if len(collection) != len(indexes): + raise ValueError( + f"Provided value length (len(collection)) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({len(indexes)})" + ) for idx, value in zip(indexes, collection): filled_params[idx] = param.encode_value(value) @@ -1328,6 +1333,7 @@ def __init__( self._raise_does_not_exist = raise_does_not_exist def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + # TODO: cache this _, sql_params = self._query.get_parameterized_sql() need_params = { param.name: param From 39ae78bb34fb59669fcaeb6c55b74877cc342581 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Fri, 13 Feb 2026 22:47:08 +0200 Subject: [PATCH 07/57] add comparison of prepared/not-prepared queries --- idk_test.py | 51 +++++++++++++++++++++++++++++++++++++++++++- tortoise/queryset.py | 32 +++++++++++++-------------- 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/idk_test.py b/idk_test.py index 1c07ac176..d243e473b 100644 --- a/idk_test.py +++ b/idk_test.py @@ -1,5 +1,10 @@ +import random +import time + from tortoise import fields, run_async from tortoise.contrib.test import init_memory_sqlite +from tortoise.expressions import Q +from tortoise.functions import Min, Max from tortoise.models import Model from tortoise.parameter import Parameter @@ -90,12 +95,56 @@ async def t4_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: print(actual2 == expected2) +async def t5_compare_prepared_non_prepared(*_) -> None: + ITERS = 1000 + + prefix = f"{time.time()}-" + await SomeModel.bulk_create([ + SomeModel(name=f"{prefix}{num}") + for num in range(1000) + ]) + + min_id, max_id = await SomeModel.filter(name__startswith=prefix).annotate(max_id=Max("id"), min_id=Min("id")).first().values_list("min_id", "max_id") + random_id = random.randint(min_id, max_id) + + random_ids = await SomeModel.filter(name__startswith=prefix).values_list("id", flat=True) + random.shuffle(random_ids) + random_ids = random_ids[:2] + + start_time = time.perf_counter() + for _ in range(ITERS): + await SomeModel.filter(Q(id__lte=random_id * 2, id__in=random_ids, join_type=Q.OR), id__gte=random_id) + end_time = time.perf_counter() + non_prepared_millis = (end_time - start_time) * 1000 + print(f"Non-prepared: {non_prepared_millis:.2f}ms") + + start_time = time.perf_counter() + query = SomeModel.filter(Q(id__lte=Parameter("id_lte"), id__in=Parameter("id_in"), join_type=Q.OR), id__gte=Parameter("id_gte")).prepare() + for _ in range(ITERS): + await query.execute(id_lte=random_id * 2, id_gte=random_id, id_in=random_ids) + end_time = time.perf_counter() + prepared_millis = (end_time - start_time) * 1000 + print(f"Prepared: {prepared_millis:.2f}ms") + + if non_prepared_millis > prepared_millis: + ratio = non_prepared_millis / prepared_millis + result = "faster" + else: + ratio = prepared_millis / non_prepared_millis + result = "slower" + + print(f"Prepared is {(ratio - 1) * 100:.2f}% {result} than non-prepared") + + await SomeModel.filter(name__startswith=prefix).delete() + + TESTS = [ # t0_sanity_check, # t1_simple_gte, # t2_simple_string_param, # t3_startswith, - t4_in, + # t4_in, + t5_compare_prepared_non_prepared, ] diff --git a/tortoise/queryset.py b/tortoise/queryset.py index aa70da69b..438c845ca 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1292,15 +1292,16 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: # TODO: check for parameters mismatch filled_params = self.params.copy() - for name, (param, idx) in self.need_params.items(): + for name, idx in self.need_params.items(): + param = self.param_by_name[name] filled_params[idx] = param.encode_value(params[name]) for name, indexes in self.need_collection_params.items(): - param = self.param_by_name[name] + param = cast(CollectionParameter, self.param_by_name[name]) collection = param.encode_collection(params[name]) if len(collection) != len(indexes): raise ValueError( - f"Provided value length (len(collection)) " + f"Provided value length ({len(collection)}) " f"for parameter {name!r} does not match " f"parameter indexes length ({len(indexes)})" ) @@ -1331,26 +1332,25 @@ def __init__( self._custom_fields = custom_fields self._single = single self._raise_does_not_exist = raise_does_not_exist - - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - # TODO: cache this - _, sql_params = self._query.get_parameterized_sql() - need_params = { + _, params = self._query.get_parameterized_sql() + self._dynamic_params = { param.name: param - for param in sql_params - if isinstance(param, Parameter) + for param in params + if isinstance(param, CollectionParameter) } + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: reset_params = [] cache_key = "query" - for param, value in params.items(): - if param not in need_params: + for name, value in params.items(): + if name not in self._dynamic_params or not isinstance(value, (tuple, list, set)): continue - if isinstance(value, (tuple, list, set)): - cache_key += f"-{param}{len(value)}" - need_params[param].collection_size = len(value) - reset_params.append(need_params[param]) + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) if cache_key not in self._cached_sql: sql, params = self._query.get_parameterized_sql() From 6f85c4e321d95a5453b79cb37596499ae586e9c4 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sat, 14 Feb 2026 15:29:40 +0200 Subject: [PATCH 08/57] fix params in subqueries --- idk_test.py | 54 ++++++++++++++++++++++++++++++++++++++----- tortoise/parameter.py | 40 ++++++++++++++++++++++++++++---- tortoise/queryset.py | 6 +++-- 3 files changed, 87 insertions(+), 13 deletions(-) diff --git a/idk_test.py b/idk_test.py index d243e473b..8233a82c0 100644 --- a/idk_test.py +++ b/idk_test.py @@ -3,7 +3,7 @@ from tortoise import fields, run_async from tortoise.contrib.test import init_memory_sqlite -from tortoise.expressions import Q +from tortoise.expressions import Q, Subquery from tortoise.functions import Min, Max from tortoise.models import Model from tortoise.parameter import Parameter @@ -17,6 +17,12 @@ class SomeModel(Model): name: str = fields.TextField() +class SomeForeignKeyModel(Model): + id: int = fields.BigIntField(pk=True) + info: str = fields.CharField(max_length=128, default="") + some: SomeModel = fields.ForeignKeyField("models.SomeModel") + + async def t0_sanity_check(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: idk = await SomeModel.filter(id=some2.id) print(idk) @@ -138,13 +144,49 @@ async def t5_compare_prepared_non_prepared(*_) -> None: await SomeModel.filter(name__startswith=prefix).delete() +async def t6_subqueries(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=Parameter("idk1")) | Q(id=Parameter("idk2"))).values("id"))) + prepared = query.prepare() + actual1 = await prepared.execute(idk1=some2.id, idk2=some1.id) + print(actual1) + actual2 = await prepared.execute(idk1=some3.id, idk2=some3.id * 2) + print(actual2) + + if CHECK_ACTUAL: + expected1 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=some2.id) | Q(id=some1.id)).values("id"))) + expected2 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=some3.id) | Q(id=some3.id * 2)).values("id"))) + print(expected1) + print(expected2) + print(actual1 == expected1) + print(actual2 == expected2) + + +async def t7_subqueries_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + query = SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=Parameter("idk")).values("id"))) + prepared = query.prepare() + actual1 = await prepared.execute(idk=[some2.id, some1.id]) + print(actual1) + actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) + print(actual2) + + if CHECK_ACTUAL: + expected1 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=[some2.id, some1.id]).values("id"))) + expected2 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=[some3.id, some3.id * 2, some3.id * 10]).values("id"))) + print(expected1) + print(expected2) + print(actual1 == expected1) + print(actual2 == expected2) + + TESTS = [ - # t0_sanity_check, - # t1_simple_gte, - # t2_simple_string_param, - # t3_startswith, - # t4_in, + t0_sanity_check, + t1_simple_gte, + t2_simple_string_param, + t3_startswith, + t4_in, t5_compare_prepared_non_prepared, + t6_subqueries, + t7_subqueries_in, ] diff --git a/tortoise/parameter.py b/tortoise/parameter.py index 9e27193a5..e78c85ee0 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,9 +1,35 @@ +from __future__ import annotations + +from dataclasses import dataclass from typing import Self from tortoise.fields import Field from pypika_tortoise import SqlContext +@dataclass(frozen=True) +class TortoiseSqlContext(SqlContext): + dynamic_params: dict[str, CollectionParameter] | None = None + + def copy(self: SqlContext, **kwargs) -> SqlContext: + existing_dynamic_params = self.dynamic_params if isinstance(self, TortoiseSqlContext) else None + return TortoiseSqlContext( + quote_char=kwargs.get("quote_char", self.quote_char), + secondary_quote_char=kwargs.get("secondary_quote_char", self.secondary_quote_char), + alias_quote_char=kwargs.get("alias_quote_char", self.alias_quote_char), + dialect=kwargs.get("dialect", self.dialect), + as_keyword=kwargs.get("as_keyword", self.as_keyword), + subquery=kwargs.get("subquery", self.subquery), + with_alias=kwargs.get("with_alias", self.with_alias), + with_namespace=kwargs.get("with_namespace", self.with_namespace), + subcriterion=kwargs.get("subcriterion", self.subcriterion), + parameterizer=kwargs.get("parameterizer", self.parameterizer), + groupby_alias=kwargs.get("groupby_alias", self.groupby_alias), + orderby_alias=kwargs.get("orderby_alias", self.orderby_alias), + dynamic_params=kwargs.get("dynamic_params", existing_dynamic_params), + ) + + class Parameter: __slots__ = ("name", "model", "value_encoder", "field_object", "encode",) @@ -65,15 +91,19 @@ def encode_collection(self, value: ...) -> ...: return self.collection_encoder(value, self.model) def get_sql(self, ctx: SqlContext) -> str: - if self.collection_size is None: + param = self + if isinstance(ctx, TortoiseSqlContext): + param = ctx.dynamic_params.get(self.name, self) + + if param.collection_size is None: if ctx.parameterizer is not None: - ctx.parameterizer.create_param(self) + ctx.parameterizer.create_param(param) return "?" else: if ctx.parameterizer is not None: - for idx in range(self.collection_size): - new_param = self.clone() + for idx in range(param.collection_size): + new_param = param.clone() new_param.collection_encoder = new_param.value_encoder new_param.value_encoder = None ctx.parameterizer.create_param(new_param) - return f"({','.join(['?' for _ in range(self.collection_size)])})" \ No newline at end of file + return f"({','.join(['?' for _ in range(param.collection_size)])})" \ No newline at end of file diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 438c845ca..fbb8598bc 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -27,7 +27,7 @@ RelationalField, ) from tortoise.filters import FilterInfoDict -from tortoise.parameter import Parameter, CollectionParameter +from tortoise.parameter import Parameter, CollectionParameter, TortoiseSqlContext from tortoise.query_utils import ( Prefetch, QueryModifier, @@ -1353,7 +1353,9 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: reset_params.append(param) if cache_key not in self._cached_sql: - sql, params = self._query.get_parameterized_sql() + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self._query.get_parameterized_sql(ctx) self._cached_sql[cache_key] = CachedSql(sql, params) for param in reset_params: From 99650c73b9eda8d3a0d7a4f42b6dc4cfe8e43588 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 14:19:41 +0200 Subject: [PATCH 09/57] add global prepared queryset cache --- idk_test.py | 20 +-- tortoise/models.py | 7 + tortoise/queryset.py | 395 ++++++++++++++++++++++++++++++++++--------- 3 files changed, 332 insertions(+), 90 deletions(-) diff --git a/idk_test.py b/idk_test.py index 8233a82c0..655d1e3e8 100644 --- a/idk_test.py +++ b/idk_test.py @@ -29,8 +29,7 @@ async def t0_sanity_check(some1: SomeModel, some2: SomeModel, some3: SomeModel) async def t1_simple_gte(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(id__gte=Parameter("idk")) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query1").filter(id__gte=Parameter("idk")).prepared() actual = await prepared.execute(idk=some2.id) print(actual) @@ -41,8 +40,7 @@ async def t1_simple_gte(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> async def t2_simple_string_param(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(name=Parameter("idk")) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query2").filter(name=Parameter("idk")).prepared() actual1 = await prepared.execute(idk=some2.id) print(actual1) actual2 = await prepared.execute(idk=some2.name) @@ -58,8 +56,7 @@ async def t2_simple_string_param(some1: SomeModel, some2: SomeModel, some3: Some async def t3_startswith(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(name__startswith=Parameter("idk")) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query3").filter(name__startswith=Parameter("idk")).prepared() actual1 = await prepared.execute(idk=some2.id) print(actual1) actual2 = await prepared.execute(idk=some2.name) @@ -85,8 +82,7 @@ async def t3_startswith(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> async def t4_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(id__in=Parameter("idk")) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query4").filter(id__in=Parameter("idk")).prepared() actual1 = await prepared.execute(idk=[some2.id, some1.id]) print(actual1) actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) @@ -125,7 +121,7 @@ async def t5_compare_prepared_non_prepared(*_) -> None: print(f"Non-prepared: {non_prepared_millis:.2f}ms") start_time = time.perf_counter() - query = SomeModel.filter(Q(id__lte=Parameter("id_lte"), id__in=Parameter("id_in"), join_type=Q.OR), id__gte=Parameter("id_gte")).prepare() + query = SomeModel.prepare_sql("some_query5").filter(Q(id__lte=Parameter("id_lte"), id__in=Parameter("id_in"), join_type=Q.OR), id__gte=Parameter("id_gte")).prepared() for _ in range(ITERS): await query.execute(id_lte=random_id * 2, id_gte=random_id, id_in=random_ids) end_time = time.perf_counter() @@ -145,8 +141,7 @@ async def t5_compare_prepared_non_prepared(*_) -> None: async def t6_subqueries(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=Parameter("idk1")) | Q(id=Parameter("idk2"))).values("id"))) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query6").filter(id__in=Subquery(SomeModel.filter(Q(id=Parameter("idk1")) | Q(id=Parameter("idk2"))).values("id"))).prepared() actual1 = await prepared.execute(idk1=some2.id, idk2=some1.id) print(actual1) actual2 = await prepared.execute(idk1=some3.id, idk2=some3.id * 2) @@ -162,8 +157,7 @@ async def t6_subqueries(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> async def t7_subqueries_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - query = SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=Parameter("idk")).values("id"))) - prepared = query.prepare() + prepared = SomeModel.prepare_sql("some_query7").filter(id__in=Subquery(SomeModel.filter(id__in=Parameter("idk")).values("id"))).prepared() actual1 = await prepared.execute(idk=[some2.id, some1.id]) print(actual1) actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) diff --git a/tortoise/models.py b/tortoise/models.py index c536a4fda..8d8b96caa 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -49,6 +49,7 @@ QuerySet, QuerySetSingle, RawSQLQuery, + PreparedQuerySet, ) from tortoise.router import router from tortoise.signals import Signals @@ -213,6 +214,7 @@ class MetaInfo: "db_complex_fields", "_default_ordering", "_ordering_validated", + "query_cache", ) def __init__(self, meta: Model.Meta) -> None: @@ -252,6 +254,7 @@ def __init__(self, meta: Model.Meta) -> None: self.db_native_fields: list[tuple[str, str, Field]] = [] self.db_default_fields: list[tuple[str, str, Field]] = [] self.db_complex_fields: list[tuple[str, str, Field]] = [] + self.query_cache: dict[str, PreparedQuerySet] = {} @property def full_name(self) -> str: @@ -1486,6 +1489,10 @@ async def fetch_for_list( db = using_db or cls._choose_db() await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) + @classmethod + def prepare_sql(cls, key: str) -> PreparedQuerySet[MODEL]: + return cls._meta.manager.get_queryset().prepare_sql(key) + @classmethod def _check(cls) -> None: """ diff --git a/tortoise/queryset.py b/tortoise/queryset.py index fbb8598bc..ffb19a8c3 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1,10 +1,11 @@ from __future__ import annotations +import functools import types from collections import defaultdict from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable from copy import copy -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -1252,99 +1253,127 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list - def prepare(self) -> PreparedQuery[MODEL]: - if self._db is None: - self._db = self._choose_db(self._select_for_update) - self._make_query() - return PreparedQuery( - query=self.query, - db=self._db, - model=self.model, - prefetch_map=self._prefetch_map, - prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, - custom_fields=list(self._annotations.keys()), - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - ) + def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: + """ + Cache generated sql of this query set. + If query set is already in cache, return cached version with already generated sql. + """ + if key in self.model._meta.query_cache: + return self.model._meta.query_cache[key] + # TODO: add some arg to _clone to override class? + # to be able to to something like self._clone(PreparedQuerySet) + queryset = PreparedQuerySet(self.model) + queryset.fields = self.fields + queryset.model = self.model + queryset.query = self.query + queryset.capabilities = self.capabilities + queryset._prefetch_map = copy(self._prefetch_map) + queryset._prefetch_queries = copy(self._prefetch_queries) + queryset._single = self._single + queryset._raise_does_not_exist = self._raise_does_not_exist + queryset._db = self._db + queryset._limit = self._limit + queryset._offset = self._offset + queryset._fields_for_select = self._fields_for_select + queryset._filter_kwargs = copy(self._filter_kwargs) + queryset._orderings = copy(self._orderings) + queryset._joined_tables = copy(self._joined_tables) + queryset._q_objects = copy(self._q_objects) + queryset._distinct = self._distinct + queryset._annotations = copy(self._annotations) + queryset._having = copy(self._having) + queryset._custom_filters = copy(self._custom_filters) + queryset._group_bys = copy(self._group_bys) + queryset._select_for_update = self._select_for_update + queryset._select_for_update_nowait = self._select_for_update_nowait + queryset._select_for_update_skip_locked = self._select_for_update_skip_locked + queryset._select_for_update_of = self._select_for_update_of + queryset._select_for_update_no_key = self._select_for_update_no_key + queryset._select_related = self._select_related + queryset._select_related_idx = self._select_related_idx + queryset._force_indexes = self._force_indexes + queryset._use_indexes = self._use_indexes + queryset._cache_key = key -class CachedSql: - def __init__(self, sql: str, params: list[Parameter | Any]) -> None: - self.sql = sql - self.params = params - self.param_by_name: dict[str, Parameter] = {} - self.need_params: dict[str, int] = {} - self.need_collection_params: dict[str, list[int]] = defaultdict(list) - for idx, param in enumerate(params): - if not isinstance(param, Parameter): - continue + return queryset - if param.name not in self.param_by_name: - self.param_by_name[param.name] = param - if isinstance(param, CollectionParameter): - self.need_collection_params[param.name].append(idx) - else: - self.need_params[param.name] = idx +class PreparedQuerySet(QuerySet[MODEL]): + __slots__ = ( + "_cache_key", + "_prepared", + "_custom_fields", + "_sql_cache", + "_executor", + "_dynamic_params", + "_dynamic_params_names", + ) - def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - # TODO: check for parameters mismatch + def __init__(self, model: type[MODEL]) -> None: + super().__init__(model) + self._cache_key: str | None = None + self._prepared: bool = False - filled_params = self.params.copy() - for name, idx in self.need_params.items(): - param = self.param_by_name[name] - filled_params[idx] = param.encode_value(params[name]) + self._custom_fields = None + self._sql_cache = None + self._executor = None + self._dynamic_params = None + self._dynamic_params_names = None - for name, indexes in self.need_collection_params.items(): - param = cast(CollectionParameter, self.param_by_name[name]) - collection = param.encode_collection(params[name]) - if len(collection) != len(indexes): - raise ValueError( - f"Provided value length ({len(collection)}) " - f"for parameter {name!r} does not match " - f"parameter indexes length ({len(indexes)})" - ) - for idx, value in zip(indexes, collection): - filled_params[idx] = param.encode_value(value) + def _clone(self) -> PreparedQuerySet[MODEL]: + queryset = super()._clone() + queryset._cache_key = self._cache_key + queryset._prepared = self._prepared + return cast(PreparedQuerySet, queryset) - return filled_params + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + def prepared(self) -> PreparedQuerySet[MODEL]: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") -class PreparedQuery(AwaitableQuery[MODEL]): - def __init__( - self, query: QueryBuilder, db: BaseDBAsyncClient, model: type[MODEL], prefetch_map: ..., - prefetch_queries: ..., select_related_idx: ..., custom_fields: list[...], single: bool, - raise_does_not_exist: bool, - ) -> None: - super().__init__(model) + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] - self._query = query - self._cached_sql = {} - self._executor = db.executor_class( - model=model, - db=db, - prefetch_map=prefetch_map, - prefetch_queries=prefetch_queries, - select_related_idx=select_related_idx, + queryset = self._clone() + + queryset._choose_db_if_not_chosen(queryset._select_for_update) + queryset._make_query() + + queryset._custom_fields = list(self._annotations.keys()) + queryset._sql_cache = {} + queryset._executor = queryset._db.executor_class( + model=queryset.model, + db=queryset._db, + prefetch_map=queryset._prefetch_map, + prefetch_queries=queryset._prefetch_queries, + select_related_idx=queryset._select_related_idx, ) - self._model = model - self._custom_fields = custom_fields - self._single = single - self._raise_does_not_exist = raise_does_not_exist - _, params = self._query.get_parameterized_sql() - self._dynamic_params = { + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { param.name: param for param in params if isinstance(param, CollectionParameter) } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: reset_params = [] + # TODO: cache also by database dialect cache_key = "query" - for name, value in params.items(): - if name not in self._dynamic_params or not isinstance(value, (tuple, list, set)): + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? continue param = self._dynamic_params[name] @@ -1352,21 +1381,22 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: param.collection_size = len(value) reset_params.append(param) - if cache_key not in self._cached_sql: + if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self._query.get_parameterized_sql(ctx) - self._cached_sql[cache_key] = CachedSql(sql, params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) for param in reset_params: param.collection_size = None - return self._cached_sql[cache_key] + return self._sql_cache[cache_key] async def execute(self, **params) -> list[MODEL]: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) + # TODO: re-create executor when database changes instance_list = await self._executor.execute_select( cached_query.sql, filled_params, custom_fields=self._custom_fields, @@ -1381,6 +1411,217 @@ async def execute(self, **params) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list + def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call filter on already prepared queryset.") + return cast(PreparedQuerySet, super().filter(*args, **kwargs)) + + def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call exclude on already prepared queryset.") + return cast(PreparedQuerySet, super().exclude(*args, **kwargs)) + + def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call order_by on already prepared queryset.") + return cast(PreparedQuerySet, super().order_by(*orderings)) + + def latest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: + if self._prepared: + raise ValueError("Cannot call latest on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().latest(*orderings)) + + def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: + if self._prepared: + raise ValueError("Cannot call earliest on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().earliest(*orderings)) + + def limit(self, limit: int) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call limit on already prepared queryset.") + return cast(PreparedQuerySet, super().limit(limit)) + + def offset(self, offset: int) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call offset on already prepared queryset.") + return cast(PreparedQuerySet, super().offset(offset)) + + def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call __getitem__ on already prepared queryset.") + return cast(PreparedQuerySet, super().__getitem__(key)) + + def distinct(self) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call distinct on already prepared queryset.") + return cast(PreparedQuerySet, super().distinct()) + + def select_for_update( + self, + nowait: bool = False, + skip_locked: bool = False, + of: tuple[str, ...] = (), + no_key: bool = False, + ) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call select_for_update on already prepared queryset.") + return cast(PreparedQuerySet, super().select_for_update( + nowait, skip_locked, of, no_key + )) + + def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call annotate on already prepared queryset.") + return cast(PreparedQuerySet, super().annotate(*kwargs)) + + def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call group_by on already prepared queryset.") + return cast(PreparedQuerySet, super().group_by(*fields)) + + def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: + # TODO: implementation for PreparedQuerySet.delete() + raise NotImplementedError + + def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: + # TODO: implementation for PreparedQuerySet.delete() + raise NotImplementedError + + def delete(self) -> DeleteQuery: + # TODO: implementation for PreparedQuerySet.delete() + raise NotImplementedError + + def update(self, **kwargs: Any) -> UpdateQuery: + # TODO: implementation for PreparedQuerySet.update() + raise NotImplementedError + + def count(self) -> CountQuery: + # TODO: implementation for PreparedQuerySet.count() + raise NotImplementedError + + def exists(self) -> ExistsQuery: + # TODO: implementation for PreparedQuerySet.exists() + raise NotImplementedError + + def all(self) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call all on already prepared queryset.") + return cast(PreparedQuerySet, super().all()) + + def first(self) -> QuerySetSingle[MODEL | None]: + if self._prepared: + raise ValueError("Cannot call first on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().first()) + + def last(self) -> QuerySetSingle[MODEL | None]: + if self._prepared: + raise ValueError("Cannot call last on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().last()) + + def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]: + if self._prepared: + raise ValueError("Cannot call get on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().get(*args, **kwargs)) + + async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]: + raise NotImplementedError("Prepared queries don't support in_bulk.") + + def bulk_create( + self, + objects: Iterable[MODEL], + batch_size: int | None = None, + ignore_conflicts: bool = False, + update_fields: Iterable[str] | None = None, + on_conflict: Iterable[str] | None = None, + ) -> BulkCreateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_create.") + + def bulk_update( + self, + objects: Iterable[MODEL], + fields: Iterable[str], + batch_size: int | None = None, + ) -> BulkUpdateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_update.") + + def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL | None]: + if self._prepared: + raise ValueError("Cannot call get_or_none on already prepared queryset.") + # TODO: fix typing + return cast(PreparedQuerySet, super().get_or_none(*args, **kwargs)) + + def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call only on already prepared queryset.") + return cast(PreparedQuerySet, super().only(*fields_for_select)) + + def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call select_related on already prepared queryset.") + return cast(PreparedQuerySet, super().select_related(*fields)) + + def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call force_index on already prepared queryset.") + return cast(PreparedQuerySet, super().force_index(*index_names)) + + def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call use_index on already prepared queryset.") + return cast(PreparedQuerySet, super().use_index(*index_names)) + + def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: + if self._prepared: + raise ValueError("Cannot call prefetch_related on already prepared queryset.") + return cast(PreparedQuerySet, super().prefetch_related(*args)) + + +class CachedSql: + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): + continue + + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param + + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + # TODO: check for parameters mismatch + + filled_params = self.params.copy() + for name, idx in self.need_params.items(): + param = self.param_by_name[name] + filled_params[idx] = param.encode_value(params[name]) + + for name, indexes in self.need_collection_params.items(): + param = cast(CollectionParameter, self.param_by_name[name]) + collection = param.encode_collection(params[name]) + if len(collection) != len(indexes): + raise ValueError( + f"Provided value length ({len(collection)}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({len(indexes)})" + ) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + + return filled_params + class UpdateQuery(AwaitableQuery): __slots__ = ( From f67ee856a66a90b1050c13667a5463aeb2737589 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 14:27:51 +0200 Subject: [PATCH 10/57] use decorator to disallow queryset methods (.filter, .get, etc.) on PreparedQuerySet that is prepared (._prepared = True) --- idk_test.py | 10 +++++ tortoise/queryset.py | 88 ++++++++++++++++++++++---------------------- 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/idk_test.py b/idk_test.py index 655d1e3e8..ac495d399 100644 --- a/idk_test.py +++ b/idk_test.py @@ -27,6 +27,16 @@ async def t0_sanity_check(some1: SomeModel, some2: SomeModel, some3: SomeModel) idk = await SomeModel.filter(id=some2.id) print(idk) + cache_key = "_some_query" + prepared = SomeModel.prepare_sql(cache_key).filter(id=Parameter("idk")).prepared() + assert SomeModel.prepare_sql(cache_key) is prepared + try: + prepared.filter(id=1) + except ValueError: + ... + else: + raise RuntimeError("PreparedQuerySet.filter on prepared query should raise") + async def t1_simple_gte(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: prepared = SomeModel.prepare_sql("some_query1").filter(id__gte=Parameter("idk")).prepared() diff --git a/tortoise/queryset.py b/tortoise/queryset.py index ffb19a8c3..de6022986 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable from copy import copy -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -1299,6 +1299,20 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: return queryset +P = ParamSpec("P") +T = TypeVar("T") + + +def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: + if self._prepared: + raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") + return func(self, *args, **kwargs) + + return decorated + + class PreparedQuerySet(QuerySet[MODEL]): __slots__ = ( "_cache_key", @@ -1411,53 +1425,45 @@ async def execute(self, **params) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list + @_disallow_queryset_methods_on_prepared_query def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call filter on already prepared queryset.") return cast(PreparedQuerySet, super().filter(*args, **kwargs)) + @_disallow_queryset_methods_on_prepared_query def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call exclude on already prepared queryset.") return cast(PreparedQuerySet, super().exclude(*args, **kwargs)) + @_disallow_queryset_methods_on_prepared_query def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call order_by on already prepared queryset.") return cast(PreparedQuerySet, super().order_by(*orderings)) + @_disallow_queryset_methods_on_prepared_query def latest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: - if self._prepared: - raise ValueError("Cannot call latest on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().latest(*orderings)) + @_disallow_queryset_methods_on_prepared_query def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: - if self._prepared: - raise ValueError("Cannot call earliest on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().earliest(*orderings)) + @_disallow_queryset_methods_on_prepared_query def limit(self, limit: int) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call limit on already prepared queryset.") return cast(PreparedQuerySet, super().limit(limit)) + @_disallow_queryset_methods_on_prepared_query def offset(self, offset: int) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call offset on already prepared queryset.") return cast(PreparedQuerySet, super().offset(offset)) + @_disallow_queryset_methods_on_prepared_query def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call __getitem__ on already prepared queryset.") return cast(PreparedQuerySet, super().__getitem__(key)) + @_disallow_queryset_methods_on_prepared_query def distinct(self) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call distinct on already prepared queryset.") return cast(PreparedQuerySet, super().distinct()) + @_disallow_queryset_methods_on_prepared_query def select_for_update( self, nowait: bool = False, @@ -1465,66 +1471,64 @@ def select_for_update( of: tuple[str, ...] = (), no_key: bool = False, ) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call select_for_update on already prepared queryset.") return cast(PreparedQuerySet, super().select_for_update( nowait, skip_locked, of, no_key )) + @_disallow_queryset_methods_on_prepared_query def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call annotate on already prepared queryset.") return cast(PreparedQuerySet, super().annotate(*kwargs)) + @_disallow_queryset_methods_on_prepared_query def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call group_by on already prepared queryset.") return cast(PreparedQuerySet, super().group_by(*fields)) + @_disallow_queryset_methods_on_prepared_query def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: # TODO: implementation for PreparedQuerySet.delete() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: # TODO: implementation for PreparedQuerySet.delete() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def delete(self) -> DeleteQuery: # TODO: implementation for PreparedQuerySet.delete() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def update(self, **kwargs: Any) -> UpdateQuery: # TODO: implementation for PreparedQuerySet.update() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def count(self) -> CountQuery: # TODO: implementation for PreparedQuerySet.count() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def exists(self) -> ExistsQuery: # TODO: implementation for PreparedQuerySet.exists() raise NotImplementedError + @_disallow_queryset_methods_on_prepared_query def all(self) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call all on already prepared queryset.") return cast(PreparedQuerySet, super().all()) + @_disallow_queryset_methods_on_prepared_query def first(self) -> QuerySetSingle[MODEL | None]: - if self._prepared: - raise ValueError("Cannot call first on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().first()) + @_disallow_queryset_methods_on_prepared_query def last(self) -> QuerySetSingle[MODEL | None]: - if self._prepared: - raise ValueError("Cannot call last on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().last()) + @_disallow_queryset_methods_on_prepared_query def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]: - if self._prepared: - raise ValueError("Cannot call get on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().get(*args, **kwargs)) @@ -1549,35 +1553,29 @@ def bulk_update( ) -> BulkUpdateQuery[MODEL]: raise NotImplementedError("Prepared queries don't support bulk_update.") + @_disallow_queryset_methods_on_prepared_query def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL | None]: - if self._prepared: - raise ValueError("Cannot call get_or_none on already prepared queryset.") # TODO: fix typing return cast(PreparedQuerySet, super().get_or_none(*args, **kwargs)) + @_disallow_queryset_methods_on_prepared_query def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call only on already prepared queryset.") return cast(PreparedQuerySet, super().only(*fields_for_select)) + @_disallow_queryset_methods_on_prepared_query def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call select_related on already prepared queryset.") return cast(PreparedQuerySet, super().select_related(*fields)) + @_disallow_queryset_methods_on_prepared_query def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call force_index on already prepared queryset.") return cast(PreparedQuerySet, super().force_index(*index_names)) + @_disallow_queryset_methods_on_prepared_query def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call use_index on already prepared queryset.") return cast(PreparedQuerySet, super().use_index(*index_names)) + @_disallow_queryset_methods_on_prepared_query def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: - if self._prepared: - raise ValueError("Cannot call prefetch_related on already prepared queryset.") return cast(PreparedQuerySet, super().prefetch_related(*args)) From 4511f988ce609f8c3d9d259a62c7b7ad9dea78f7 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 20:54:09 +0200 Subject: [PATCH 11/57] add prepared UpdateQuery implementation --- idk_test.py | 21 +++++++ tortoise/queryset.py | 137 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 154 insertions(+), 4 deletions(-) diff --git a/idk_test.py b/idk_test.py index ac495d399..2697b1615 100644 --- a/idk_test.py +++ b/idk_test.py @@ -182,6 +182,26 @@ async def t7_subqueries_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) print(actual2 == expected2) +async def t8_update(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: + original_name = some1.name + + prepared = SomeModel.prepare_sql("some_query8").filter(id=Parameter("search_id")).update(name=Parameter("replace_name")).prepared() + await prepared.execute(search_id=some1.id, replace_name=some1.name + "_test") + await some1.refresh_from_db(["name"]) + print(f"{original_name!r} -> {some1.name!r}") + await prepared.execute(search_id=some1.id, replace_name=original_name) + await some1.refresh_from_db(["name"]) + print(f"back to {original_name!r}: {some1.name!r}") + + if CHECK_ACTUAL: + await SomeModel.filter(id=some1.id).update(name=some1.name + "_test") + await some1.refresh_from_db(["name"]) + print(f"{original_name!r} -> {some1.name!r}") + await SomeModel.filter(id=some1.id).update(name=original_name) + await some1.refresh_from_db(["name"]) + print(f"back to {original_name!r}: {some1.name!r}") + + TESTS = [ t0_sanity_check, t1_simple_gte, @@ -191,6 +211,7 @@ async def t7_subqueries_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) t5_compare_prepared_non_prepared, t6_subqueries, t7_subqueries_in, + t8_update, ] diff --git a/tortoise/queryset.py b/tortoise/queryset.py index de6022986..aba26cd46 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1500,8 +1500,17 @@ def delete(self) -> DeleteQuery: @_disallow_queryset_methods_on_prepared_query def update(self, **kwargs: Any) -> UpdateQuery: - # TODO: implementation for PreparedQuerySet.update() - raise NotImplementedError + return PreparedUpdateQuery( + model=self.model, + update_kwargs=kwargs, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query def count(self) -> CountQuery: @@ -1626,7 +1635,7 @@ class UpdateQuery(AwaitableQuery): "update_kwargs", "_orderings", "_limit", - "values", + "values", # TODO: unused? ) def __init__( @@ -1669,6 +1678,7 @@ def _make_query(self) -> None: self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field + # TODO: support Parameters in here value = self.model._meta.fields_map[fk_field].to_db_value( getattr(value, field_object.to_field_instance.model_field_name), None, @@ -1689,7 +1699,11 @@ def _make_query(self) -> None: ) ).term else: - value = self.model._meta.fields_map[key].to_db_value(value, None) + field_object = self.model._meta.fields_map[key] + if isinstance(value, Parameter): + value.field_object = field_object + else: + value = field_object.to_db_value(value, None) self.query = self.query.set(db_field, value) @@ -1702,6 +1716,121 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] +class PreparedUpdateQuery(UpdateQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_executor", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + update_kwargs: dict[str, Any], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, + ) -> None: + super().__init__( + model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, + ) + + self._cache_key: str | None = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._executor = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedUpdateQuery[MODEL]: + query = self.__class__( + model=self.model, + update_kwargs=self.update_kwargs, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet, de-duplicate it + def prepared(self) -> PreparedUpdateQuery[MODEL]: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(True) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + class DeleteQuery(AwaitableQuery): __slots__ = ( "_annotations", From 5c0ca0fb61abe787ef72a355589b9ec5151be7c3 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 21:28:19 +0200 Subject: [PATCH 12/57] create test file for prepared querysets --- idk_test.py | 231 -------------------------------- tests/test_queryset_prepared.py | 135 +++++++++++++++++++ 2 files changed, 135 insertions(+), 231 deletions(-) delete mode 100644 idk_test.py create mode 100644 tests/test_queryset_prepared.py diff --git a/idk_test.py b/idk_test.py deleted file mode 100644 index 2697b1615..000000000 --- a/idk_test.py +++ /dev/null @@ -1,231 +0,0 @@ -import random -import time - -from tortoise import fields, run_async -from tortoise.contrib.test import init_memory_sqlite -from tortoise.expressions import Q, Subquery -from tortoise.functions import Min, Max -from tortoise.models import Model -from tortoise.parameter import Parameter - - -CHECK_ACTUAL = True - - -class SomeModel(Model): - id: int = fields.BigIntField(pk=True) - name: str = fields.TextField() - - -class SomeForeignKeyModel(Model): - id: int = fields.BigIntField(pk=True) - info: str = fields.CharField(max_length=128, default="") - some: SomeModel = fields.ForeignKeyField("models.SomeModel") - - -async def t0_sanity_check(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - idk = await SomeModel.filter(id=some2.id) - print(idk) - - cache_key = "_some_query" - prepared = SomeModel.prepare_sql(cache_key).filter(id=Parameter("idk")).prepared() - assert SomeModel.prepare_sql(cache_key) is prepared - try: - prepared.filter(id=1) - except ValueError: - ... - else: - raise RuntimeError("PreparedQuerySet.filter on prepared query should raise") - - -async def t1_simple_gte(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query1").filter(id__gte=Parameter("idk")).prepared() - actual = await prepared.execute(idk=some2.id) - print(actual) - - if CHECK_ACTUAL: - expected = await SomeModel.filter(id__gte=some2.id) - print(expected) - print(actual == expected) - - -async def t2_simple_string_param(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query2").filter(name=Parameter("idk")).prepared() - actual1 = await prepared.execute(idk=some2.id) - print(actual1) - actual2 = await prepared.execute(idk=some2.name) - print(actual2) - - if CHECK_ACTUAL: - expected1 = await SomeModel.filter(name=some2.id) - expected2 = await SomeModel.filter(name=some2.name) - print(expected1) - print(expected2) - print(actual1 == expected1) - print(actual2 == expected2) - - -async def t3_startswith(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query3").filter(name__startswith=Parameter("idk")).prepared() - actual1 = await prepared.execute(idk=some2.id) - print(actual1) - actual2 = await prepared.execute(idk=some2.name) - print(actual2) - actual3 = await prepared.execute(idk="asd") - print(actual3) - actual4 = await prepared.execute(idk="qwe") - print(actual4) - - if CHECK_ACTUAL: - expected1 = await SomeModel.filter(name__startswith=some2.id) - expected2 = await SomeModel.filter(name__startswith=some2.name) - expected3 = await SomeModel.filter(name__startswith="asd") - expected4 = await SomeModel.filter(name__startswith="qwe") - print(expected1) - print(expected2) - print(expected3) - print(expected4) - print(actual1 == expected1) - print(actual2 == expected2) - print(actual3 == expected3) - print(actual4 == expected4) - - -async def t4_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query4").filter(id__in=Parameter("idk")).prepared() - actual1 = await prepared.execute(idk=[some2.id, some1.id]) - print(actual1) - actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) - print(actual2) - - if CHECK_ACTUAL: - expected1 = await SomeModel.filter(id__in=[some2.id, some1.id]) - expected2 = await SomeModel.filter(id__in=[some3.id, some3.id * 2, some3.id * 10]) - print(expected1) - print(expected2) - print(actual1 == expected1) - print(actual2 == expected2) - - -async def t5_compare_prepared_non_prepared(*_) -> None: - ITERS = 1000 - - prefix = f"{time.time()}-" - await SomeModel.bulk_create([ - SomeModel(name=f"{prefix}{num}") - for num in range(1000) - ]) - - min_id, max_id = await SomeModel.filter(name__startswith=prefix).annotate(max_id=Max("id"), min_id=Min("id")).first().values_list("min_id", "max_id") - random_id = random.randint(min_id, max_id) - - random_ids = await SomeModel.filter(name__startswith=prefix).values_list("id", flat=True) - random.shuffle(random_ids) - random_ids = random_ids[:2] - - start_time = time.perf_counter() - for _ in range(ITERS): - await SomeModel.filter(Q(id__lte=random_id * 2, id__in=random_ids, join_type=Q.OR), id__gte=random_id) - end_time = time.perf_counter() - non_prepared_millis = (end_time - start_time) * 1000 - print(f"Non-prepared: {non_prepared_millis:.2f}ms") - - start_time = time.perf_counter() - query = SomeModel.prepare_sql("some_query5").filter(Q(id__lte=Parameter("id_lte"), id__in=Parameter("id_in"), join_type=Q.OR), id__gte=Parameter("id_gte")).prepared() - for _ in range(ITERS): - await query.execute(id_lte=random_id * 2, id_gte=random_id, id_in=random_ids) - end_time = time.perf_counter() - prepared_millis = (end_time - start_time) * 1000 - print(f"Prepared: {prepared_millis:.2f}ms") - - if non_prepared_millis > prepared_millis: - ratio = non_prepared_millis / prepared_millis - result = "faster" - else: - ratio = prepared_millis / non_prepared_millis - result = "slower" - - print(f"Prepared is {(ratio - 1) * 100:.2f}% {result} than non-prepared") - - await SomeModel.filter(name__startswith=prefix).delete() - - -async def t6_subqueries(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query6").filter(id__in=Subquery(SomeModel.filter(Q(id=Parameter("idk1")) | Q(id=Parameter("idk2"))).values("id"))).prepared() - actual1 = await prepared.execute(idk1=some2.id, idk2=some1.id) - print(actual1) - actual2 = await prepared.execute(idk1=some3.id, idk2=some3.id * 2) - print(actual2) - - if CHECK_ACTUAL: - expected1 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=some2.id) | Q(id=some1.id)).values("id"))) - expected2 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(Q(id=some3.id) | Q(id=some3.id * 2)).values("id"))) - print(expected1) - print(expected2) - print(actual1 == expected1) - print(actual2 == expected2) - - -async def t7_subqueries_in(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - prepared = SomeModel.prepare_sql("some_query7").filter(id__in=Subquery(SomeModel.filter(id__in=Parameter("idk")).values("id"))).prepared() - actual1 = await prepared.execute(idk=[some2.id, some1.id]) - print(actual1) - actual2 = await prepared.execute(idk=[some3.id, some3.id * 2, some3.id * 10]) - print(actual2) - - if CHECK_ACTUAL: - expected1 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=[some2.id, some1.id]).values("id"))) - expected2 = await SomeModel.filter(id__in=Subquery(SomeModel.filter(id__in=[some3.id, some3.id * 2, some3.id * 10]).values("id"))) - print(expected1) - print(expected2) - print(actual1 == expected1) - print(actual2 == expected2) - - -async def t8_update(some1: SomeModel, some2: SomeModel, some3: SomeModel) -> None: - original_name = some1.name - - prepared = SomeModel.prepare_sql("some_query8").filter(id=Parameter("search_id")).update(name=Parameter("replace_name")).prepared() - await prepared.execute(search_id=some1.id, replace_name=some1.name + "_test") - await some1.refresh_from_db(["name"]) - print(f"{original_name!r} -> {some1.name!r}") - await prepared.execute(search_id=some1.id, replace_name=original_name) - await some1.refresh_from_db(["name"]) - print(f"back to {original_name!r}: {some1.name!r}") - - if CHECK_ACTUAL: - await SomeModel.filter(id=some1.id).update(name=some1.name + "_test") - await some1.refresh_from_db(["name"]) - print(f"{original_name!r} -> {some1.name!r}") - await SomeModel.filter(id=some1.id).update(name=original_name) - await some1.refresh_from_db(["name"]) - print(f"back to {original_name!r}: {some1.name!r}") - - -TESTS = [ - t0_sanity_check, - t1_simple_gte, - t2_simple_string_param, - t3_startswith, - t4_in, - t5_compare_prepared_non_prepared, - t6_subqueries, - t7_subqueries_in, - t8_update, -] - - -@init_memory_sqlite -async def run() -> None: - some1 = await SomeModel.create(name="asdqwe") - some2 = await SomeModel.create(name="asdqweasd") - some3 = await SomeModel.create(name="asdqweasd123") - - for test_func in TESTS: - print(f"Running {test_func.__name__} ...") - await test_func(some1, some2, some3) - print("=" * 32) - - -if __name__ == "__main__": - run_async(run()) \ No newline at end of file diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py new file mode 100644 index 000000000..e781328d3 --- /dev/null +++ b/tests/test_queryset_prepared.py @@ -0,0 +1,135 @@ +from tests.testmodels import ( + Author, +) +from tortoise.contrib import test +from tortoise.expressions import Subquery, Q +from tortoise.parameter import Parameter + + +class TestQuerysetPrepared(test.TestCase): + def test_prepared_queryset_always_same(self): + cache_key = "test_prepared_queryset_always_same" + prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() + assert Author.prepare_sql(cache_key) is prepared + + def test_disallow_filtering_on_prepared_queryset(self): + cache_key = "test_disallow_filtering_on_prepared_queryset" + prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() + + with self.assertRaises(ValueError): + prepared.filter(id=1) + + async def test_gte_filter(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + expected = await Author.filter(id__gte=author2.pk).order_by("id") + + prepared = Author.prepare_sql("test_gte_filter").filter(id__gte=Parameter("idgte")).order_by("id").prepared() + actual = await prepared.execute(idgte=author2.pk) + self.assertEqual(len(actual), 2) + self.assertEqual(actual[0].id, author2.pk) + self.assertEqual(actual[1].id, author3.pk) + self.assertEqual(expected, actual) + + async def test_string_param(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + expected = await Author.filter(name=author2.name) + + prepared = Author.prepare_sql("test_string_param").filter(name=Parameter("name")).prepared() + actual = await prepared.execute(name=author2.name) + self.assertEqual(len(actual), 1) + self.assertEqual(actual[0].id, author2.pk) + self.assertEqual(expected, actual) + + async def test_startswith_filter(self): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") + + prepared = Author.prepare_sql("test_startswith_filter").filter(name__startswith=Parameter("name")).prepared() + + for test_name in (author2.pk, author1.name, author3.name, "asd"): + expected = await Author.filter(name__startswith=test_name) + actual = await prepared.execute(name=test_name) + self.assertEqual(expected, actual) + + async def test_in_filter(self): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") + + prepared = Author.prepare_sql("test_in_filter").filter(id__in=Parameter("ids")).prepared() + + for test_ids in ( + [author2.pk, author1.pk], + [author3.pk, author3.pk * 2, author3.pk * 10] + ): + expected = await Author.filter(id__in=test_ids) + actual = await prepared.execute(ids=test_ids) + self.assertEqual(expected, actual) + + async def test_subqueries(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_subqueries").filter(id__in=Subquery( + Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") + )).prepared() + + for id1, id2 in ( + (author2.pk, author1.pk), + (author3.pk, author3.pk * 2), + ): + expected = await Author.filter(id__in=Subquery( + Author.filter(Q(id=id1) | Q(id=id2)).values("id") + )) + actual = await prepared.execute(id1=id1, id2=id2) + self.assertEqual(expected, actual) + + async def test_subqueries_in_filter(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_subqueries_in_filter").filter(id__in=Subquery( + Author.filter(id__in=Parameter("ids")).values("id") + )).prepared() + + for test_ids in ( + [author2.pk, author1.pk], + [author3.pk, author3.pk * 2, author3.pk * 10] + ): + expected = await Author.filter(id__in=Subquery( + Author.filter(id__in=test_ids).values("id") + )) + actual = await prepared.execute(ids=test_ids) + self.assertEqual(expected, actual) + + async def test_update(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + original_name1 = author1.name + original_name2 = author2.name + new_name1 = f"{author1.name}_test" + + prepared = Author.prepare_sql("test_update").filter( + id=Parameter("search_id") + ).update(name=Parameter("replace_name")).prepared() + + await prepared.execute(search_id=author1.pk, replace_name=new_name1) + await author1.refresh_from_db(["name"]) + await author2.refresh_from_db(["name"]) + self.assertEqual(author1.name, new_name1) + self.assertEqual(author2.name, original_name2) + + await prepared.execute(search_id=author1.pk, replace_name=original_name1) + await author1.refresh_from_db(["name"]) + self.assertEqual(author1.name, original_name1) From ffdde8a9fd24d2686bda1b1d544483d6ca4cfa2e Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 22:14:39 +0200 Subject: [PATCH 13/57] implement prepared DeleteQuery class --- tortoise/queryset.py | 130 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 6 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index aba26cd46..3f2709218 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1485,18 +1485,26 @@ def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: @_disallow_queryset_methods_on_prepared_query def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: - # TODO: implementation for PreparedQuerySet.delete() + # TODO: implementation for PreparedQuerySet.values_list() raise NotImplementedError @_disallow_queryset_methods_on_prepared_query def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: - # TODO: implementation for PreparedQuerySet.delete() + # TODO: implementation for PreparedQuerySet.values() raise NotImplementedError @_disallow_queryset_methods_on_prepared_query def delete(self) -> DeleteQuery: - # TODO: implementation for PreparedQuerySet.delete() - raise NotImplementedError + return PreparedDeleteQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query def update(self, **kwargs: Any) -> UpdateQuery: @@ -1721,7 +1729,6 @@ class PreparedUpdateQuery(UpdateQuery): "_cache_key", "_prepared", "_sql_cache", - "_executor", "_dynamic_params", "_dynamic_params_names", ) @@ -1746,7 +1753,6 @@ def __init__( self._prepared: bool = False self._sql_cache = None - self._executor = None self._dynamic_params = None self._dynamic_params_names = None @@ -1880,6 +1886,118 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] +class PreparedDeleteQuery(DeleteQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, + ) -> None: + super().__init__( + model, db, q_objects, annotations, custom_filters, limit, orderings, + ) + + self._cache_key: str | None = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedDeleteQuery[MODEL]: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet + # (and literally 1:1 with PreparedUpdateQuery), de-duplicate it + def prepared(self) -> PreparedDeleteQuery[MODEL]: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(True) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + class ExistsQuery(AwaitableQuery): __slots__ = ( "_force_indexes", From 73f1431bba4a4c84f43c41491cf98ff80666de28 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 22:46:10 +0200 Subject: [PATCH 14/57] implement prepared ExistsQuery class --- tests/test_queryset_prepared.py | 27 +++++++ tortoise/queryset.py | 129 +++++++++++++++++++++++++++++++- 2 files changed, 152 insertions(+), 4 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index e781328d3..44bf8fbaa 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -133,3 +133,30 @@ async def test_update(self): await prepared.execute(search_id=author1.pk, replace_name=original_name1) await author1.refresh_from_db(["name"]) self.assertEqual(author1.name, original_name1) + + async def test_delete(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_delete").filter( + id__in=Parameter("ids"), + ).delete().prepared() + + affected = await prepared.execute(ids=[author1.pk]) + self.assertEqual(affected, 1) + self.assertEqual(await Author.all().count(), 2) + existing = await Author.all().values_list("id", flat=True) + self.assertEqual(set(existing), {author2.pk, author3.pk}) + + async def test_exists(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_exists").filter( + id__in=Parameter("ids"), + ).exists().prepared() + + self.assertTrue(await prepared.execute(ids=[author1.pk])) + self.assertFalse(await prepared.execute(ids=[author3.pk * 2])) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 3f2709218..ce8a173da 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1507,7 +1507,7 @@ def delete(self) -> DeleteQuery: ) @_disallow_queryset_methods_on_prepared_query - def update(self, **kwargs: Any) -> UpdateQuery: + def update(self, **kwargs: Any) -> PreparedUpdateQuery: return PreparedUpdateQuery( model=self.model, update_kwargs=kwargs, @@ -1526,9 +1526,17 @@ def count(self) -> CountQuery: raise NotImplementedError @_disallow_queryset_methods_on_prepared_query - def exists(self) -> ExistsQuery: - # TODO: implementation for PreparedQuerySet.exists() - raise NotImplementedError + def exists(self) -> PreparedExistsQuery: + return PreparedExistsQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query def all(self) -> PreparedQuerySet[MODEL]: @@ -2047,6 +2055,119 @@ async def _execute( return bool(result) +class PreparedExistsQuery(ExistsQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super().__init__( + model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, + ) + + self._cache_key: str = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedExistsQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet + # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it + def prepared(self) -> PreparedExistsQuery: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(False) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + result, _ = await self._db.execute_query(cached_query.sql, filled_params) + return bool(result) + + class CountQuery(AwaitableQuery): __slots__ = ( "_limit", From 7c353c33a8e7b473a8e1251a73eca9ffe1d6d0b6 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 23:01:04 +0200 Subject: [PATCH 15/57] implement prepared CountQuery class; add support for parameters in PreparedQuerySet.limit and PreparedQuerySet.offset --- tests/test_queryset_prepared.py | 54 +++++++++- tortoise/queryset.py | 173 ++++++++++++++++++++++++++++++-- 2 files changed, 216 insertions(+), 11 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 44bf8fbaa..478995f3b 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -2,6 +2,7 @@ Author, ) from tortoise.contrib import test +from tortoise.exceptions import ParamsError from tortoise.expressions import Subquery, Q from tortoise.parameter import Parameter @@ -150,13 +151,56 @@ async def test_delete(self): self.assertEqual(set(existing), {author2.pk, author3.pk}) async def test_exists(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + author = await Author.create(name="1") prepared = Author.prepare_sql("test_exists").filter( id__in=Parameter("ids"), ).exists().prepared() - self.assertTrue(await prepared.execute(ids=[author1.pk])) - self.assertFalse(await prepared.execute(ids=[author3.pk * 2])) + self.assertTrue(await prepared.execute(ids=[author.pk])) + self.assertFalse(await prepared.execute(ids=[author.pk * 2])) + + async def test_count(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_count").filter( + id__gte=Parameter("idgte"), + ).count().prepared() + + self.assertEqual(await prepared.execute(idgte=author1.pk), 3) + self.assertEqual(await prepared.execute(idgte=author2.pk), 2) + self.assertEqual(await prepared.execute(idgte=author3.pk), 1) + self.assertEqual(await prepared.execute(idgte=author3.pk * 2), 0) + + async def test_parameter_in_limit(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_parameter_in_limit").all().limit(Parameter("lim")).order_by("id").prepared() + + self.assertEqual(len(await prepared.execute(lim=1)), 1) + self.assertEqual(len(await prepared.execute(lim=2)), 2) + self.assertEqual(len(await prepared.execute(lim=3)), 3) + self.assertEqual(len(await prepared.execute(lim=4)), 3) + + with self.assertRaises(ParamsError): + await prepared.execute(lim=-1) + + async def test_parameter_in_offset(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.prepare_sql("test_parameter_in_offset").all().offset(Parameter("off")).order_by("id").prepared() + + self.assertEqual(len(await prepared.execute(off=1)), 2) + self.assertEqual(len(await prepared.execute(off=2)), 1) + self.assertEqual(len(await prepared.execute(off=3)), 0) + self.assertEqual(len(await prepared.execute(off=4)), 0) + + with self.assertRaises(ParamsError): + await prepared.execute(off=-1) + diff --git a/tortoise/queryset.py b/tortoise/queryset.py index ce8a173da..98886e31c 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1447,13 +1447,41 @@ def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: # TODO: fix typing return cast(PreparedQuerySet, super().earliest(*orderings)) + @staticmethod + def _validate_limit(value: int) -> int: + if value < 0: + raise ParamsError("Limit should be non-negative number") + return value + @_disallow_queryset_methods_on_prepared_query - def limit(self, limit: int) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().limit(limit)) + def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: + if isinstance(limit, int) and limit < 0: + raise ParamsError("Limit should be non-negative number") + elif isinstance(limit, Parameter): + limit.encode = self._validate_limit + + queryset = self._clone() + queryset._limit = limit # type: ignore + return queryset + + @staticmethod + def _validate_offset(value: int) -> int: + if value < 0: + raise ParamsError("Offset should be non-negative number") + return value @_disallow_queryset_methods_on_prepared_query def offset(self, offset: int) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().offset(offset)) + if isinstance(offset, int) and offset < 0: + raise ParamsError("Offset should be non-negative number") + elif isinstance(offset, Parameter): + offset.encode = self._validate_offset + + queryset = self._clone() + queryset._offset = offset + if self.capabilities.requires_limit and queryset._limit is None: + queryset._limit = 1000000 + return queryset @_disallow_queryset_methods_on_prepared_query def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: @@ -1521,9 +1549,19 @@ def update(self, **kwargs: Any) -> PreparedUpdateQuery: ) @_disallow_queryset_methods_on_prepared_query - def count(self) -> CountQuery: - # TODO: implementation for PreparedQuerySet.count() - raise NotImplementedError + def count(self) -> PreparedCountQuery: + return PreparedCountQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + offset=self._offset, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query def exists(self) -> PreparedExistsQuery: @@ -2231,6 +2269,129 @@ async def _execute(self) -> int: return count +class PreparedCountQuery(CountQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + offset: int | None, + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super().__init__( + model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, + ) + + self._cache_key: str = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedCountQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + offset=self._offset, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet + # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it + def prepared(self) -> PreparedCountQuery: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(False) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + if not result: + return 0 + count = list(dict(result[0]).values())[0] - self._offset + if self._limit and count > self._limit: + return self._limit + return count + + class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 From 1824428ed300bf721449c9f094f974c4f42cb3c2 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 16 Feb 2026 23:31:31 +0200 Subject: [PATCH 16/57] implement prepared ValuesQuery class; implement prepared ValuesListQuery class; --- tests/test_queryset_prepared.py | 60 +++++ tortoise/queryset.py | 381 +++++++++++++++++++++++++++++--- 2 files changed, 410 insertions(+), 31 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 478995f3b..b71eaf9dd 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -204,3 +204,63 @@ async def test_parameter_in_offset(self): with self.assertRaises(ParamsError): await prepared.execute(off=-1) + async def test_values(self): + author = await Author.create(name="1") + + prepared = Author.prepare_sql("test_values").filter( + id=Parameter("id"), + ).values().prepared() + + self.assertEqual( + await prepared.execute(id=author.pk), + [{"id": author.pk, "name": author.name}], + ) + self.assertEqual( + await prepared.execute(id=author.pk * 2), + [], + ) + + async def test_values_list_all_fields(self): + author = await Author.create(name="1") + + prepared_all = Author.prepare_sql("test_values_list_all_fields").filter( + id=Parameter("id"), + ).values_list().prepared() + self.assertEqual( + await prepared_all.execute(id=author.pk), + [(author.pk, author.name)], + ) + self.assertEqual( + await prepared_all.execute(id=author.pk * 2), + [], + ) + + async def test_values_list_only_id_field(self): + author = await Author.create(name="1") + + prepared_ids = Author.prepare_sql("test_values_list_only_id_field").filter( + id=Parameter("id"), + ).values_list("id").prepared() + self.assertEqual( + await prepared_ids.execute(id=author.pk), + [(author.pk,)], + ) + self.assertEqual( + await prepared_ids.execute(id=author.pk * 2), + [], + ) + + async def test_values_list_only_id_field_flat(self): + author = await Author.create(name="1") + + prepared_ids_flat = Author.prepare_sql("test_values_list_only_id_field_flat").filter( + id=Parameter("id"), + ).values_list("id", flat=True).prepared() + self.assertEqual( + await prepared_ids_flat.execute(id=author.pk), + [author.pk], + ) + self.assertEqual( + await prepared_ids_flat.execute(id=author.pk * 2), + [], + ) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 98886e31c..22b61feb5 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -5,7 +5,8 @@ from collections import defaultdict from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable from copy import copy -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec, \ + Sequence from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -639,6 +640,14 @@ def group_by(self, *fields: str) -> QuerySet[MODEL]: queryset._group_bys = fields return queryset + def _get_fields_list_for_select(self, *fields_: str) -> list[str]: + if self._fields_for_select: + raise ValueError(".values_list() cannot be used with .only()") + + return fields_ or [ + field for field in self.model._meta.fields_map if field in self.model._meta.db_fields + ] + list(self._annotations.keys()) + def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: """ Make QuerySet returns list of tuples for given args instead of objects. @@ -650,12 +659,8 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite If no arguments are passed it will default to a tuple containing all fields in order of declaration. """ - if self._fields_for_select: - raise ValueError(".values_list() cannot be used with .only()") + fields_for_select_list = self._get_fields_list_for_select(*fields_) - fields_for_select_list = fields_ or [ - field for field in self.model._meta.fields_map if field in self.model._meta.db_fields - ] + list(self._annotations.keys()) return ValuesListQuery( db=self._db, model=self.model, @@ -675,20 +680,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite use_indexes=self._use_indexes, ) - def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: - """ - Make QuerySet return dicts instead of objects. - - If called after `.get()`, `.get_or_none()` or `.first()`, returns a dict instead of an object. - - You can specify which fields to include by: - - Passing field names as positional arguments - - Using kwargs in the format `field_name='name_in_dict'` to customize the keys in the resulting dict - - If no arguments are passed, it will default to a dict containing all fields. - - :raises FieldError: If duplicate key has been provided. - """ + def _get_fields_for_select(self, *args: str, **kwargs: str) -> dict[str, str]: if self._fields_for_select: raise ValueError(".values() cannot be used with .only()") @@ -712,6 +704,24 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: fields_for_select = {field: field for field in _fields} + return fields_for_select + + def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: + """ + Make QuerySet return dicts instead of objects. + + If called after `.get()`, `.get_or_none()` or `.first()`, returns a dict instead of an object. + + You can specify which fields to include by: + - Passing field names as positional arguments + - Using kwargs in the format `field_name='name_in_dict'` to customize the keys in the resulting dict + + If no arguments are passed, it will default to a dict containing all fields. + + :raises FieldError: If duplicate key has been provided. + """ + fields_for_select = self._get_fields_for_select(*args, **kwargs) + return ValuesQuery( db=self._db, model=self.model, @@ -1478,7 +1488,7 @@ def offset(self, offset: int) -> PreparedQuerySet[MODEL]: offset.encode = self._validate_offset queryset = self._clone() - queryset._offset = offset + queryset._offset = offset # type: ignore if self.capabilities.requires_limit and queryset._limit is None: queryset._limit = 1000000 return queryset @@ -1512,14 +1522,51 @@ def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: return cast(PreparedQuerySet, super().group_by(*fields)) @_disallow_queryset_methods_on_prepared_query - def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: - # TODO: implementation for PreparedQuerySet.values_list() - raise NotImplementedError + def values_list(self, *fields_: str, flat: bool = False) -> PreparedValuesListQuery[Literal[False]]: + fields_for_select_list = self._get_fields_list_for_select(*fields_) + + return PreparedValuesListQuery( + db=self._db, + model=self.model, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + flat=flat, + fields_for_select_list=fields_for_select_list, + distinct=self._distinct, + limit=self._limit, + offset=self._offset, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query - def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: - # TODO: implementation for PreparedQuerySet.values() - raise NotImplementedError + def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: + fields_for_select = self._get_fields_for_select(*args, **kwargs) + + return PreparedValuesQuery( + db=self._db, + model=self.model, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=fields_for_select, + distinct=self._distinct, + limit=self._limit, + offset=self._offset, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) @_disallow_queryset_methods_on_prepared_query def delete(self) -> DeleteQuery: @@ -2615,8 +2662,7 @@ async def __aiter__(self: ValuesListQuery[Any]) -> AsyncIterator[Any]: for val in await self: yield val - async def _execute(self) -> list[Any] | tuple: - _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) + def _process_results(self, result: Sequence[dict]) -> list[Any] | tuple: columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -2639,6 +2685,144 @@ async def _execute(self) -> list[Any] | tuple: raise MultipleObjectsReturned(self.model) return lst_values + async def _execute(self) -> list[Any] | tuple: + _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) + return self._process_results(result) + + +class PreparedValuesListQuery(ValuesListQuery[SINGLE]): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + flat: bool, + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super().__init__( + model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, + offset, distinct, orderings, flat, annotations, custom_filters, group_bys, + force_indexes, use_indexes + ) + + self._cache_key: str = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedValuesListQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select_list=self._fields_for_select_list, + limit=self._limit, + offset=self._offset, + distinct=self._distinct, + orderings=self._orderings, + flat=self._flat, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet + # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it + def prepared(self) -> PreparedValuesListQuery: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(False) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> list[Any] | tuple: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + return self._process_results(result) + class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( @@ -2745,8 +2929,7 @@ async def __aiter__(self: ValuesQuery[Any]) -> AsyncIterator[dict[str, Any]]: for val in await self: yield val - async def _execute(self) -> list[dict] | dict: - result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) + def _process_results(self, result: list[dict]) -> list[dict] | dict: columns = [ val for val in [ @@ -2771,6 +2954,142 @@ async def _execute(self) -> list[dict] | dict: raise MultipleObjectsReturned(self.model) return result + async def _execute(self) -> list[dict] | dict: + result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) + return self._process_results(result) + + +class PreparedValuesQuery(ValuesQuery[SINGLE]): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select: dict[str, str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super().__init__( + model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, + offset, distinct, orderings, annotations, custom_filters, group_bys, + force_indexes, use_indexes, + ) + + self._cache_key: str = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + + def _clone(self) -> PreparedValuesQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=self._fields_for_select, + limit=self._limit, + offset=self._offset, + distinct=self._distinct, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError + + # TODO: big part of this method is duplicated with PreparedQuerySet + # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it + def prepared(self) -> PreparedValuesQuery: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(False) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, + # move into separate class maybe + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> list[Any] | tuple: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + result = await self._db.execute_query_dict(cached_query.sql, filled_params) + return self._process_results(result) + class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") From f70be998b00668a8ff50186fafb4199889aef7f2 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 17 Feb 2026 14:27:11 +0200 Subject: [PATCH 17/57] fix typing of PreparedQuerySet methods that return QuerySetSingle --- tortoise/queryset.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 22b61feb5..9a27c5e1a 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -81,6 +81,14 @@ def values( ) -> ValuesQuery[Literal[True]]: ... # pragma: nocoverage +class PreparedQuerySetSingle(QuerySetSingle[T_co]): + def prepared(self) -> PreparedQuerySet[MODEL]: + ... + + async def execute(self, **params) -> list[MODEL]: + ... + + class AwaitableQuery(Generic[MODEL]): __slots__ = ( "query", @@ -1448,14 +1456,12 @@ def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: return cast(PreparedQuerySet, super().order_by(*orderings)) @_disallow_queryset_methods_on_prepared_query - def latest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: - # TODO: fix typing - return cast(PreparedQuerySet, super().latest(*orderings)) + def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().latest(*orderings)) @_disallow_queryset_methods_on_prepared_query - def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: - # TODO: fix typing - return cast(PreparedQuerySet, super().earliest(*orderings)) + def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().earliest(*orderings)) @staticmethod def _validate_limit(value: int) -> int: @@ -1628,19 +1634,16 @@ def all(self) -> PreparedQuerySet[MODEL]: return cast(PreparedQuerySet, super().all()) @_disallow_queryset_methods_on_prepared_query - def first(self) -> QuerySetSingle[MODEL | None]: - # TODO: fix typing - return cast(PreparedQuerySet, super().first()) + def first(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().first()) @_disallow_queryset_methods_on_prepared_query - def last(self) -> QuerySetSingle[MODEL | None]: - # TODO: fix typing - return cast(PreparedQuerySet, super().last()) + def last(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().last()) @_disallow_queryset_methods_on_prepared_query - def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]: - # TODO: fix typing - return cast(PreparedQuerySet, super().get(*args, **kwargs)) + def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: + return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]: raise NotImplementedError("Prepared queries don't support in_bulk.") @@ -1664,9 +1667,8 @@ def bulk_update( raise NotImplementedError("Prepared queries don't support bulk_update.") @_disallow_queryset_methods_on_prepared_query - def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL | None]: - # TODO: fix typing - return cast(PreparedQuerySet, super().get_or_none(*args, **kwargs)) + def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().get_or_none(*args, **kwargs)) @_disallow_queryset_methods_on_prepared_query def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: From 85543391e5e7ec0d413d05aca0363724a8acadb8 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 19 Feb 2026 12:49:02 +0200 Subject: [PATCH 18/57] move all prepared-query duplicated code into separate class --- tortoise/queryset.py | 564 ++++++++----------------------------------- 1 file changed, 102 insertions(+), 462 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 9a27c5e1a..113421ed8 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -6,7 +6,7 @@ from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable from copy import copy from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec, \ - Sequence + Sequence, Self from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -1281,7 +1281,7 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: # TODO: add some arg to _clone to override class? # to be able to to something like self._clone(PreparedQuerySet) - queryset = PreparedQuerySet(self.model) + queryset = PreparedQuerySet(self.model, key) queryset.fields = self.fields queryset.model = self.model queryset.query = self.query @@ -1317,52 +1317,32 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: return queryset -P = ParamSpec("P") -T = TypeVar("T") - - -def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: - @functools.wraps(func) - def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: - if self._prepared: - raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") - return func(self, *args, **kwargs) - - return decorated - - -class PreparedQuerySet(QuerySet[MODEL]): - __slots__ = ( - "_cache_key", - "_prepared", - "_custom_fields", - "_sql_cache", - "_executor", - "_dynamic_params", - "_dynamic_params_names", - ) +class _PreparedQuery: + # __slots__ = ( + # "_cache_key", + # "_prepared", + # "_sql_cache", + # "_dynamic_params", + # "_dynamic_params_names", + # "_db_for_write" + # ) - def __init__(self, model: type[MODEL]) -> None: - super().__init__(model) - self._cache_key: str | None = None + def __init__(self, cache_key: str) -> None: + self._cache_key: str = cache_key self._prepared: bool = False - self._custom_fields = None self._sql_cache = None - self._executor = None self._dynamic_params = None self._dynamic_params_names = None + self._db_for_write = False - def _clone(self) -> PreparedQuerySet[MODEL]: - queryset = super()._clone() - queryset._cache_key = self._cache_key - queryset._prepared = self._prepared - return cast(PreparedQuerySet, queryset) + def _clone(self) -> Self: + raise NotImplementedError def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError + raise NotImplementedError("Querysets must only be prepared once") - def prepared(self) -> PreparedQuerySet[MODEL]: + def prepared(self) -> Self: if self._cache_key is None: raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") @@ -1371,18 +1351,10 @@ def prepared(self) -> PreparedQuerySet[MODEL]: queryset = self._clone() - queryset._choose_db_if_not_chosen(queryset._select_for_update) + queryset._choose_db_if_not_chosen(self._db_for_write) queryset._make_query() - queryset._custom_fields = list(self._annotations.keys()) queryset._sql_cache = {} - queryset._executor = queryset._db.executor_class( - model=queryset.model, - db=queryset._db, - prefetch_map=queryset._prefetch_map, - prefetch_queries=queryset._prefetch_queries, - select_related_idx=queryset._select_related_idx, - ) _, params = queryset.query.get_parameterized_sql() queryset._dynamic_params = { param.name: param @@ -1424,6 +1396,65 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: return self._sql_cache[cache_key] + async def execute(self, **params) -> int: + raise NotImplementedError + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: + if self._prepared: + raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") + return func(self, *args, **kwargs) + + return decorated + + +class PreparedQuerySet(QuerySet[MODEL], _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_custom_fields", + "_sql_cache", + "_executor", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__(self, model: type[MODEL], cache_key: str) -> None: + super(PreparedQuerySet, self).__init__(model) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = self._select_for_update + + self._custom_fields = None + self._executor = None + + def _clone(self) -> PreparedQuerySet[MODEL]: + queryset = super()._clone() + queryset._cache_key = self._cache_key + queryset._prepared = self._prepared + queryset._db_for_write = self._select_for_update + return cast(PreparedQuerySet, queryset) + + def prepared(self) -> PreparedQuerySet[MODEL]: + queryset = cast(PreparedQuerySet[MODEL], super().prepared()) + + queryset._custom_fields = list(self._annotations.keys()) + queryset._executor = queryset._db.executor_class( + model=queryset.model, + db=queryset._db, + prefetch_map=queryset._prefetch_map, + prefetch_queries=queryset._prefetch_queries, + select_related_idx=queryset._select_related_idx, + ) + + return queryset + async def execute(self, **params) -> list[MODEL]: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -1819,7 +1850,7 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] -class PreparedUpdateQuery(UpdateQuery): +class PreparedUpdateQuery(UpdateQuery, _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -1840,16 +1871,12 @@ def __init__( orderings: list[tuple[str, str]], cache_key: str, ) -> None: - super().__init__( + super(PreparedUpdateQuery, self).__init__( model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, ) + super(AwaitableQuery, self).__init__(cache_key) - self._cache_key: str | None = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + self._db_for_write = True def _clone(self) -> PreparedUpdateQuery[MODEL]: query = self.__class__( @@ -1866,66 +1893,6 @@ def _clone(self) -> PreparedUpdateQuery[MODEL]: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet, de-duplicate it - def prepared(self) -> PreparedUpdateQuery[MODEL]: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(True) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> int: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -1981,7 +1948,7 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] -class PreparedDeleteQuery(DeleteQuery): +class PreparedDeleteQuery(DeleteQuery, _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -2001,16 +1968,12 @@ def __init__( orderings: list[tuple[str, str]], cache_key: str, ) -> None: - super().__init__( + super(PreparedDeleteQuery, self).__init__( model, db, q_objects, annotations, custom_filters, limit, orderings, ) + super(AwaitableQuery, self).__init__(cache_key) - self._cache_key: str | None = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + self._db_for_write = True def _clone(self) -> PreparedDeleteQuery[MODEL]: query = self.__class__( @@ -2026,67 +1989,6 @@ def _clone(self) -> PreparedDeleteQuery[MODEL]: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet - # (and literally 1:1 with PreparedUpdateQuery), de-duplicate it - def prepared(self) -> PreparedDeleteQuery[MODEL]: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(True) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> int: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -2142,7 +2044,7 @@ async def _execute( return bool(result) -class PreparedExistsQuery(ExistsQuery): +class PreparedExistsQuery(ExistsQuery, _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -2162,16 +2064,12 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super().__init__( + super(PreparedExistsQuery, self).__init__( model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, ) + super(AwaitableQuery, self).__init__(cache_key) - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + self._db_for_write = False def _clone(self) -> PreparedExistsQuery: query = self.__class__( @@ -2187,67 +2085,6 @@ def _clone(self) -> PreparedExistsQuery: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet - # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it - def prepared(self) -> PreparedExistsQuery: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(False) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> int: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -2318,7 +2155,7 @@ async def _execute(self) -> int: return count -class PreparedCountQuery(CountQuery): +class PreparedCountQuery(CountQuery, _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -2340,16 +2177,12 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super().__init__( + super(PreparedCountQuery, self).__init__( model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, ) + super(AwaitableQuery, self).__init__(cache_key) - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + self._db_for_write = False def _clone(self) -> PreparedCountQuery: query = self.__class__( @@ -2367,67 +2200,6 @@ def _clone(self) -> PreparedCountQuery: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet - # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it - def prepared(self) -> PreparedCountQuery: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(False) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> int: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -2692,7 +2464,7 @@ async def _execute(self) -> list[Any] | tuple: return self._process_results(result) -class PreparedValuesListQuery(ValuesListQuery[SINGLE]): +class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -2721,18 +2493,13 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super().__init__( + super(PreparedValuesListQuery, self).__init__( model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, offset, distinct, orderings, flat, annotations, custom_filters, group_bys, force_indexes, use_indexes ) - - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + super(AwaitableQuery, self).__init__(cache_key) + self._db_for_write = False def _clone(self) -> PreparedValuesListQuery: query = self.__class__( @@ -2757,67 +2524,6 @@ def _clone(self) -> PreparedValuesListQuery: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet - # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it - def prepared(self) -> PreparedValuesListQuery: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(False) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> list[Any] | tuple: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) @@ -2961,7 +2667,7 @@ async def _execute(self) -> list[dict] | dict: return self._process_results(result) -class PreparedValuesQuery(ValuesQuery[SINGLE]): +class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQuery): __slots__ = ( "_cache_key", "_prepared", @@ -2989,18 +2695,13 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super().__init__( + super(PreparedValuesQuery, self).__init__( model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, offset, distinct, orderings, annotations, custom_filters, group_bys, force_indexes, use_indexes, ) - - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + super(AwaitableQuery, self).__init__(cache_key) + self._db_for_write = False def _clone(self) -> PreparedValuesQuery: query = self.__class__( @@ -3024,67 +2725,6 @@ def _clone(self) -> PreparedValuesQuery: query._prepared = self._prepared return query - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError - - # TODO: big part of this method is duplicated with PreparedQuerySet - # (and almost 1:1 with PreparedUpdateQuery), de-duplicate it - def prepared(self) -> PreparedValuesQuery: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(False) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - # TODO: this is a copy of PreparedQuerySet._get_or_create_cached_sql, - # move into separate class maybe - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - async def execute(self, **params) -> list[Any] | tuple: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) From e4d142ba9ce4ee74d3888a9f3b55dacc40acbf26 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 11:57:55 +0200 Subject: [PATCH 19/57] support updating foreign keys in PreparedUpdateQuery --- tests/test_queryset_prepared.py | 35 +- tortoise/models.py | 2 +- tortoise/parameter.py | 14 +- tortoise/queryset.py | 818 +------------------------------ tortoise/queryset_prepared.py | 834 ++++++++++++++++++++++++++++++++ 5 files changed, 897 insertions(+), 806 deletions(-) create mode 100644 tortoise/queryset_prepared.py diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index b71eaf9dd..2a9e80ccf 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -1,8 +1,8 @@ from tests.testmodels import ( - Author, + Author, Book, ) from tortoise.contrib import test -from tortoise.exceptions import ParamsError +from tortoise.exceptions import ParamsError, ValidationError from tortoise.expressions import Subquery, Q from tortoise.parameter import Parameter @@ -264,3 +264,34 @@ async def test_values_list_only_id_field_flat(self): await prepared_ids_flat.execute(id=author.pk * 2), [], ) + + async def test_update_fk(self): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + book = await Book.create(name="test", author=author1, rating=5) + + prepared = Book.prepare_sql("test_update_fk").filter( + id=Parameter("search_id") + ).update(author=Parameter("replace_author")).prepared() + + await prepared.execute(search_id=book.pk, replace_author=author2) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + self.assertEqual(book.author, author2) + + await prepared.execute(search_id=book.pk, replace_author=author1) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + self.assertEqual(book.author, author1) + + async def test_update_pk_invalid_obj(self): + author = await Author.create(name="1") + book = await Book.create(name="test", author=author, rating=5) + + prepared = Book.prepare_sql("test_update_pk_invalid_obj").filter( + id=Parameter("search_id") + ).update(author=Parameter("replace_author")).prepared() + + with self.assertRaises(ValidationError): + await prepared.execute(search_id=book.pk, replace_author="not an Author object") diff --git a/tortoise/models.py b/tortoise/models.py index 8d8b96caa..dd7080767 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -49,8 +49,8 @@ QuerySet, QuerySetSingle, RawSQLQuery, - PreparedQuerySet, ) +from tortoise.queryset_prepared import PreparedQuerySet from tortoise.router import router from tortoise.signals import Signals from tortoise.transactions import in_transaction diff --git a/tortoise/parameter.py b/tortoise/parameter.py index e78c85ee0..872b5aac2 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Self +from typing import Self, Callable, Any from tortoise.fields import Field from pypika_tortoise import SqlContext @@ -31,7 +31,7 @@ def copy(self: SqlContext, **kwargs) -> SqlContext: class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object", "encode",) + __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "value_getter", "value_validator",) def __init__(self, name: str) -> None: self.name = name @@ -39,6 +39,8 @@ def __init__(self, name: str) -> None: self.value_encoder = None self.field_object: Field | None = None self.encode = None + self.value_getter: Callable[[Any], Any] | None = None + self.value_validator: Callable[[Any], Any] | None = None def clone(self) -> Self: new = self.__new__(self.__class__) @@ -47,10 +49,18 @@ def clone(self) -> Self: new.value_encoder = self.value_encoder new.field_object = self.field_object new.encode = self.encode + new.value_getter = self.value_getter + new.value_validator = self.value_validator return new def encode_value(self, value: ...) -> ...: + if self.value_validator is not None: + self.value_validator(value) + + if self.value_getter is not None: + value = self.value_getter(value) + encoded = value if self.value_encoder: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 113421ed8..41bb89563 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1,12 +1,11 @@ from __future__ import annotations -import functools import types from collections import defaultdict from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable from copy import copy -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec, \ - Sequence, Self +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, Sequence from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -29,7 +28,7 @@ RelationalField, ) from tortoise.filters import FilterInfoDict -from tortoise.parameter import Parameter, CollectionParameter, TortoiseSqlContext +from tortoise.parameter import Parameter from tortoise.query_utils import ( Prefetch, QueryModifier, @@ -46,6 +45,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model + from tortoise.queryset_prepared import PreparedQuerySet MODEL = TypeVar("MODEL", bound="Model") T_co = TypeVar("T_co", covariant=True) @@ -81,14 +81,6 @@ def values( ) -> ValuesQuery[Literal[True]]: ... # pragma: nocoverage -class PreparedQuerySetSingle(QuerySetSingle[T_co]): - def prepared(self) -> PreparedQuerySet[MODEL]: - ... - - async def execute(self, **params) -> list[MODEL]: - ... - - class AwaitableQuery(Generic[MODEL]): __slots__ = ( "query", @@ -1279,6 +1271,8 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: if key in self.model._meta.query_cache: return self.model._meta.query_cache[key] + from tortoise.queryset_prepared import PreparedQuerySet + # TODO: add some arg to _clone to override class? # to be able to to something like self._clone(PreparedQuerySet) queryset = PreparedQuerySet(self.model, key) @@ -1317,453 +1311,6 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: return queryset -class _PreparedQuery: - # __slots__ = ( - # "_cache_key", - # "_prepared", - # "_sql_cache", - # "_dynamic_params", - # "_dynamic_params_names", - # "_db_for_write" - # ) - - def __init__(self, cache_key: str) -> None: - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False - - def _clone(self) -> Self: - raise NotImplementedError - - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("Querysets must only be prepared once") - - def prepared(self) -> Self: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") - - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone() - - queryset._choose_db_if_not_chosen(self._db_for_write) - queryset._make_query() - - queryset._sql_cache = {} - _, params = queryset.query.get_parameterized_sql() - queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - - queryset._prepared = True - - self.model._meta.query_cache[self._cache_key] = queryset - - return queryset - - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - # TODO: cache also by database dialect - cache_key = "query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - - async def execute(self, **params) -> int: - raise NotImplementedError - - -P = ParamSpec("P") -T = TypeVar("T") - - -def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: - @functools.wraps(func) - def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: - if self._prepared: - raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") - return func(self, *args, **kwargs) - - return decorated - - -class PreparedQuerySet(QuerySet[MODEL], _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_custom_fields", - "_sql_cache", - "_executor", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__(self, model: type[MODEL], cache_key: str) -> None: - super(PreparedQuerySet, self).__init__(model) - super(AwaitableQuery, self).__init__(cache_key) - - self._db_for_write = self._select_for_update - - self._custom_fields = None - self._executor = None - - def _clone(self) -> PreparedQuerySet[MODEL]: - queryset = super()._clone() - queryset._cache_key = self._cache_key - queryset._prepared = self._prepared - queryset._db_for_write = self._select_for_update - return cast(PreparedQuerySet, queryset) - - def prepared(self) -> PreparedQuerySet[MODEL]: - queryset = cast(PreparedQuerySet[MODEL], super().prepared()) - - queryset._custom_fields = list(self._annotations.keys()) - queryset._executor = queryset._db.executor_class( - model=queryset.model, - db=queryset._db, - prefetch_map=queryset._prefetch_map, - prefetch_queries=queryset._prefetch_queries, - select_related_idx=queryset._select_related_idx, - ) - - return queryset - - async def execute(self, **params) -> list[MODEL]: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - # TODO: re-create executor when database changes - instance_list = await self._executor.execute_select( - cached_query.sql, filled_params, - custom_fields=self._custom_fields, - ) - if self._single: - if len(instance_list) == 1: - return instance_list[0] - if not instance_list: - if self._raise_does_not_exist: - raise DoesNotExist(self.model) - return None # type: ignore - raise MultipleObjectsReturned(self.model) - return instance_list - - @_disallow_queryset_methods_on_prepared_query - def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().filter(*args, **kwargs)) - - @_disallow_queryset_methods_on_prepared_query - def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().exclude(*args, **kwargs)) - - @_disallow_queryset_methods_on_prepared_query - def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().order_by(*orderings)) - - @_disallow_queryset_methods_on_prepared_query - def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().latest(*orderings)) - - @_disallow_queryset_methods_on_prepared_query - def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().earliest(*orderings)) - - @staticmethod - def _validate_limit(value: int) -> int: - if value < 0: - raise ParamsError("Limit should be non-negative number") - return value - - @_disallow_queryset_methods_on_prepared_query - def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: - if isinstance(limit, int) and limit < 0: - raise ParamsError("Limit should be non-negative number") - elif isinstance(limit, Parameter): - limit.encode = self._validate_limit - - queryset = self._clone() - queryset._limit = limit # type: ignore - return queryset - - @staticmethod - def _validate_offset(value: int) -> int: - if value < 0: - raise ParamsError("Offset should be non-negative number") - return value - - @_disallow_queryset_methods_on_prepared_query - def offset(self, offset: int) -> PreparedQuerySet[MODEL]: - if isinstance(offset, int) and offset < 0: - raise ParamsError("Offset should be non-negative number") - elif isinstance(offset, Parameter): - offset.encode = self._validate_offset - - queryset = self._clone() - queryset._offset = offset # type: ignore - if self.capabilities.requires_limit and queryset._limit is None: - queryset._limit = 1000000 - return queryset - - @_disallow_queryset_methods_on_prepared_query - def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().__getitem__(key)) - - @_disallow_queryset_methods_on_prepared_query - def distinct(self) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().distinct()) - - @_disallow_queryset_methods_on_prepared_query - def select_for_update( - self, - nowait: bool = False, - skip_locked: bool = False, - of: tuple[str, ...] = (), - no_key: bool = False, - ) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().select_for_update( - nowait, skip_locked, of, no_key - )) - - @_disallow_queryset_methods_on_prepared_query - def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().annotate(*kwargs)) - - @_disallow_queryset_methods_on_prepared_query - def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().group_by(*fields)) - - @_disallow_queryset_methods_on_prepared_query - def values_list(self, *fields_: str, flat: bool = False) -> PreparedValuesListQuery[Literal[False]]: - fields_for_select_list = self._get_fields_list_for_select(*fields_) - - return PreparedValuesListQuery( - db=self._db, - model=self.model, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - flat=flat, - fields_for_select_list=fields_for_select_list, - distinct=self._distinct, - limit=self._limit, - offset=self._offset, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: - fields_for_select = self._get_fields_for_select(*args, **kwargs) - - return PreparedValuesQuery( - db=self._db, - model=self.model, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select=fields_for_select, - distinct=self._distinct, - limit=self._limit, - offset=self._offset, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def delete(self) -> DeleteQuery: - return PreparedDeleteQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def update(self, **kwargs: Any) -> PreparedUpdateQuery: - return PreparedUpdateQuery( - model=self.model, - update_kwargs=kwargs, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def count(self) -> PreparedCountQuery: - return PreparedCountQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - offset=self._offset, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def exists(self) -> PreparedExistsQuery: - return PreparedExistsQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - - @_disallow_queryset_methods_on_prepared_query - def all(self) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().all()) - - @_disallow_queryset_methods_on_prepared_query - def first(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().first()) - - @_disallow_queryset_methods_on_prepared_query - def last(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().last()) - - @_disallow_queryset_methods_on_prepared_query - def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: - return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) - - async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]: - raise NotImplementedError("Prepared queries don't support in_bulk.") - - def bulk_create( - self, - objects: Iterable[MODEL], - batch_size: int | None = None, - ignore_conflicts: bool = False, - update_fields: Iterable[str] | None = None, - on_conflict: Iterable[str] | None = None, - ) -> BulkCreateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_create.") - - def bulk_update( - self, - objects: Iterable[MODEL], - fields: Iterable[str], - batch_size: int | None = None, - ) -> BulkUpdateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_update.") - - @_disallow_queryset_methods_on_prepared_query - def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().get_or_none(*args, **kwargs)) - - @_disallow_queryset_methods_on_prepared_query - def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().only(*fields_for_select)) - - @_disallow_queryset_methods_on_prepared_query - def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().select_related(*fields)) - - @_disallow_queryset_methods_on_prepared_query - def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().force_index(*index_names)) - - @_disallow_queryset_methods_on_prepared_query - def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().use_index(*index_names)) - - @_disallow_queryset_methods_on_prepared_query - def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().prefetch_related(*args)) - - -class CachedSql: - def __init__(self, sql: str, params: list[Parameter | Any]) -> None: - self.sql = sql - self.params = params - self.param_by_name: dict[str, Parameter] = {} - self.need_params: dict[str, int] = {} - self.need_collection_params: dict[str, list[int]] = defaultdict(list) - for idx, param in enumerate(params): - if not isinstance(param, Parameter): - continue - - if param.name not in self.param_by_name: - self.param_by_name[param.name] = param - - if isinstance(param, CollectionParameter): - self.need_collection_params[param.name].append(idx) - else: - self.need_params[param.name] = idx - - def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - # TODO: check for parameters mismatch - - filled_params = self.params.copy() - for name, idx in self.need_params.items(): - param = self.param_by_name[name] - filled_params[idx] = param.encode_value(params[name]) - - for name, indexes in self.need_collection_params.items(): - param = cast(CollectionParameter, self.param_by_name[name]) - collection = param.encode_collection(params[name]) - if len(collection) != len(indexes): - raise ValueError( - f"Provided value length ({len(collection)}) " - f"for parameter {name!r} does not match " - f"parameter indexes length ({len(indexes)})" - ) - for idx, value in zip(indexes, collection): - filled_params[idx] = param.encode_value(value) - - return filled_params - - class UpdateQuery(AwaitableQuery): __slots__ = ( "update_kwargs", @@ -1809,14 +1356,19 @@ def _make_query(self) -> None: if field_object.generated: raise IntegrityError(f"Field {key} is generated and can not be updated") if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)): - self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field - # TODO: support Parameters in here - value = self.model._meta.fields_map[fk_field].to_db_value( - getattr(value, field_object.to_field_instance.model_field_name), - None, - ) + + if isinstance(value, Parameter): + value.field_object = self.model._meta.fields_map[fk_field] + value.value_getter = attrgetter(field_object.to_field_instance.model_field_name) + value.value_validator = lambda val: self.model._validate_relation_type(key, val) + else: + self.model._validate_relation_type(key, value) + value = self.model._meta.fields_map[fk_field].to_db_value( + getattr(value, field_object.to_field_instance.model_field_name), + None, + ) else: try: db_field = self.model._meta.fields_db_projection[key] @@ -1850,55 +1402,6 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] -class PreparedUpdateQuery(UpdateQuery, _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - update_kwargs: dict[str, Any], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], - cache_key: str, - ) -> None: - super(PreparedUpdateQuery, self).__init__( - model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, - ) - super(AwaitableQuery, self).__init__(cache_key) - - self._db_for_write = True - - def _clone(self) -> PreparedUpdateQuery[MODEL]: - query = self.__class__( - model=self.model, - update_kwargs=self.update_kwargs, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - return (await self._db.execute_query(cached_query.sql, filled_params))[0] - - class DeleteQuery(AwaitableQuery): __slots__ = ( "_annotations", @@ -1948,53 +1451,6 @@ async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] -class PreparedDeleteQuery(DeleteQuery, _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], - cache_key: str, - ) -> None: - super(PreparedDeleteQuery, self).__init__( - model, db, q_objects, annotations, custom_filters, limit, orderings, - ) - super(AwaitableQuery, self).__init__(cache_key) - - self._db_for_write = True - - def _clone(self) -> PreparedDeleteQuery[MODEL]: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - return (await self._db.execute_query(cached_query.sql, filled_params))[0] - - class ExistsQuery(AwaitableQuery): __slots__ = ( "_force_indexes", @@ -2044,54 +1500,6 @@ async def _execute( return bool(result) -class PreparedExistsQuery(ExistsQuery, _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, - ) -> None: - super(PreparedExistsQuery, self).__init__( - model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, - ) - super(AwaitableQuery, self).__init__(cache_key) - - self._db_for_write = False - - def _clone(self) -> PreparedExistsQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - result, _ = await self._db.execute_query(cached_query.sql, filled_params) - return bool(result) - - class CountQuery(AwaitableQuery): __slots__ = ( "_limit", @@ -2155,64 +1563,6 @@ async def _execute(self) -> int: return count -class PreparedCountQuery(CountQuery, _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - offset: int | None, - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, - ) -> None: - super(PreparedCountQuery, self).__init__( - model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, - ) - super(AwaitableQuery, self).__init__(cache_key) - - self._db_for_write = False - - def _clone(self) -> PreparedCountQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - offset=self._offset, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - _, result = await self._db.execute_query(cached_query.sql, filled_params) - if not result: - return 0 - count = list(dict(result[0]).values())[0] - self._offset - if self._limit and count > self._limit: - return self._limit - return count - - class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 @@ -2464,74 +1814,6 @@ async def _execute(self) -> list[Any] | tuple: return self._process_results(result) -class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - single: bool, - raise_does_not_exist: bool, - fields_for_select_list: tuple[str, ...] | list[str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], - flat: bool, - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, - ) -> None: - super(PreparedValuesListQuery, self).__init__( - model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, - offset, distinct, orderings, flat, annotations, custom_filters, group_bys, - force_indexes, use_indexes - ) - super(AwaitableQuery, self).__init__(cache_key) - self._db_for_write = False - - def _clone(self) -> PreparedValuesListQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select_list=self._fields_for_select_list, - limit=self._limit, - offset=self._offset, - distinct=self._distinct, - orderings=self._orderings, - flat=self._flat, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> list[Any] | tuple: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - _, result = await self._db.execute_query(cached_query.sql, filled_params) - return self._process_results(result) - - class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( "_fields_for_select", @@ -2667,72 +1949,6 @@ async def _execute(self) -> list[dict] | dict: return self._process_results(result) -class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQuery): - __slots__ = ( - "_cache_key", - "_prepared", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - single: bool, - raise_does_not_exist: bool, - fields_for_select: dict[str, str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, - ) -> None: - super(PreparedValuesQuery, self).__init__( - model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, - offset, distinct, orderings, annotations, custom_filters, group_bys, - force_indexes, use_indexes, - ) - super(AwaitableQuery, self).__init__(cache_key) - self._db_for_write = False - - def _clone(self) -> PreparedValuesQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select=self._fields_for_select, - limit=self._limit, - offset=self._offset, - distinct=self._distinct, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared - return query - - async def execute(self, **params) -> list[Any] | tuple: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - result = await self._db.execute_query_dict(cached_query.sql, filled_params) - return self._process_results(result) - - class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py new file mode 100644 index 000000000..cca7dab77 --- /dev/null +++ b/tortoise/queryset_prepared.py @@ -0,0 +1,834 @@ +from __future__ import annotations + +import functools +import types +from collections import defaultdict +from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable +from copy import copy +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec, \ + Sequence, Self + +from pypika_tortoise import JoinType, Order, Table +from pypika_tortoise.analytics import Count +from pypika_tortoise.functions import Cast +from pypika_tortoise.queries import QueryBuilder +from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper + +from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities +from tortoise.exceptions import ( + DoesNotExist, + FieldError, + IntegrityError, + MultipleObjectsReturned, + ParamsError, +) +from tortoise.expressions import Expression, Q, RawSQL, ResolveContext, ResolveResult +from tortoise.fields.relational import ( + ForeignKeyFieldInstance, + OneToOneFieldInstance, + RelationalField, +) +from tortoise.filters import FilterInfoDict +from tortoise.parameter import Parameter, CollectionParameter, TortoiseSqlContext +from tortoise.query_utils import ( + Prefetch, + QueryModifier, + TableCriterionTuple, + expand_lookup_expression, + get_joins_for_related_field, +) +from tortoise.queryset import QuerySet, MODEL, AwaitableQuery, QuerySetSingle, T_co, DeleteQuery, BulkCreateQuery, \ + BulkUpdateQuery, UpdateQuery, ExistsQuery, CountQuery, ValuesListQuery, ValuesQuery, SINGLE +from tortoise.router import router +from tortoise.utils import chunk + + +class PreparedQuerySetSingle(QuerySetSingle[T_co]): + def prepared(self) -> PreparedQuerySet[MODEL]: + ... + + async def execute(self, **params) -> list[MODEL]: + ... + + +class _PreparedQuery: + # __slots__ = ( + # "_cache_key", + # "_prepared", + # "_sql_cache", + # "_dynamic_params", + # "_dynamic_params_names", + # "_db_for_write" + # ) + + def __init__(self, cache_key: str) -> None: + self._cache_key: str = cache_key + self._prepared: bool = False + + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None + self._db_for_write = False + + def _clone(self) -> Self: + raise NotImplementedError + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError("Querysets must only be prepared once") + + def prepared(self) -> Self: + if self._cache_key is None: + raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + + if self._cache_key in self.model._meta.query_cache: + return self.model._meta.query_cache[self._cache_key] + + queryset = self._clone() + + queryset._choose_db_if_not_chosen(self._db_for_write) + queryset._make_query() + + queryset._sql_cache = {} + _, params = queryset.query.get_parameterized_sql() + queryset._dynamic_params = { + param.name: param + for param in params + if isinstance(param, CollectionParameter) + } + queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) + + queryset._prepared = True + + self.model._meta.query_cache[self._cache_key] = queryset + + return queryset + + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + reset_params = [] + + # TODO: cache also by database dialect + cache_key = "query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + sql, params = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + async def execute(self, **params) -> int: + raise NotImplementedError + + +P = ParamSpec("P") +T = TypeVar("T") + + +def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: + if self._prepared: + raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") + return func(self, *args, **kwargs) + + return decorated + + +class PreparedQuerySet(QuerySet[MODEL], _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_custom_fields", + "_sql_cache", + "_executor", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__(self, model: type[MODEL], cache_key: str) -> None: + super(PreparedQuerySet, self).__init__(model) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = self._select_for_update + + self._custom_fields = None + self._executor = None + + def _clone(self) -> PreparedQuerySet[MODEL]: + queryset = super()._clone() + queryset._cache_key = self._cache_key + queryset._prepared = self._prepared + queryset._db_for_write = self._select_for_update + return cast(PreparedQuerySet, queryset) + + def prepared(self) -> PreparedQuerySet[MODEL]: + queryset = cast(PreparedQuerySet[MODEL], super().prepared()) + + queryset._custom_fields = list(self._annotations.keys()) + queryset._executor = queryset._db.executor_class( + model=queryset.model, + db=queryset._db, + prefetch_map=queryset._prefetch_map, + prefetch_queries=queryset._prefetch_queries, + select_related_idx=queryset._select_related_idx, + ) + + return queryset + + async def execute(self, **params) -> list[MODEL]: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + # TODO: re-create executor when database changes + instance_list = await self._executor.execute_select( + cached_query.sql, filled_params, + custom_fields=self._custom_fields, + ) + if self._single: + if len(instance_list) == 1: + return instance_list[0] + if not instance_list: + if self._raise_does_not_exist: + raise DoesNotExist(self.model) + return None # type: ignore + raise MultipleObjectsReturned(self.model) + return instance_list + + @_disallow_queryset_methods_on_prepared_query + def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().filter(*args, **kwargs)) + + @_disallow_queryset_methods_on_prepared_query + def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().exclude(*args, **kwargs)) + + @_disallow_queryset_methods_on_prepared_query + def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().order_by(*orderings)) + + @_disallow_queryset_methods_on_prepared_query + def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().latest(*orderings)) + + @_disallow_queryset_methods_on_prepared_query + def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().earliest(*orderings)) + + @staticmethod + def _validate_limit(value: int) -> int: + if value < 0: + raise ParamsError("Limit should be non-negative number") + return value + + @_disallow_queryset_methods_on_prepared_query + def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: + if isinstance(limit, int) and limit < 0: + raise ParamsError("Limit should be non-negative number") + elif isinstance(limit, Parameter): + limit.encode = self._validate_limit + + queryset = self._clone() + queryset._limit = limit # type: ignore + return queryset + + @staticmethod + def _validate_offset(value: int) -> int: + if value < 0: + raise ParamsError("Offset should be non-negative number") + return value + + @_disallow_queryset_methods_on_prepared_query + def offset(self, offset: int) -> PreparedQuerySet[MODEL]: + if isinstance(offset, int) and offset < 0: + raise ParamsError("Offset should be non-negative number") + elif isinstance(offset, Parameter): + offset.encode = self._validate_offset + + queryset = self._clone() + queryset._offset = offset # type: ignore + if self.capabilities.requires_limit and queryset._limit is None: + queryset._limit = 1000000 + return queryset + + @_disallow_queryset_methods_on_prepared_query + def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().__getitem__(key)) + + @_disallow_queryset_methods_on_prepared_query + def distinct(self) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().distinct()) + + @_disallow_queryset_methods_on_prepared_query + def select_for_update( + self, + nowait: bool = False, + skip_locked: bool = False, + of: tuple[str, ...] = (), + no_key: bool = False, + ) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().select_for_update( + nowait, skip_locked, of, no_key + )) + + @_disallow_queryset_methods_on_prepared_query + def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().annotate(**kwargs)) + + @_disallow_queryset_methods_on_prepared_query + def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().group_by(*fields)) + + @_disallow_queryset_methods_on_prepared_query + def values_list(self, *fields_: str, flat: bool = False) -> PreparedValuesListQuery[Literal[False]]: + fields_for_select_list = self._get_fields_list_for_select(*fields_) + + return PreparedValuesListQuery( + db=self._db, + model=self.model, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + flat=flat, + fields_for_select_list=fields_for_select_list, + distinct=self._distinct, + limit=self._limit, + offset=self._offset, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: + fields_for_select = self._get_fields_for_select(*args, **kwargs) + + return PreparedValuesQuery( + db=self._db, + model=self.model, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=fields_for_select, + distinct=self._distinct, + limit=self._limit, + offset=self._offset, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def delete(self) -> DeleteQuery: + return PreparedDeleteQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def update(self, **kwargs: Any) -> PreparedUpdateQuery: + return PreparedUpdateQuery( + model=self.model, + update_kwargs=kwargs, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def count(self) -> PreparedCountQuery: + return PreparedCountQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + offset=self._offset, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def exists(self) -> PreparedExistsQuery: + return PreparedExistsQuery( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + + @_disallow_queryset_methods_on_prepared_query + def all(self) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().all()) + + @_disallow_queryset_methods_on_prepared_query + def first(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().first()) + + @_disallow_queryset_methods_on_prepared_query + def last(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().last()) + + @_disallow_queryset_methods_on_prepared_query + def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: + return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) + + async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]: + raise NotImplementedError("Prepared queries don't support in_bulk.") + + def bulk_create( + self, + objects: Iterable[MODEL], + batch_size: int | None = None, + ignore_conflicts: bool = False, + update_fields: Iterable[str] | None = None, + on_conflict: Iterable[str] | None = None, + ) -> BulkCreateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_create.") + + def bulk_update( + self, + objects: Iterable[MODEL], + fields: Iterable[str], + batch_size: int | None = None, + ) -> BulkUpdateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_update.") + + @_disallow_queryset_methods_on_prepared_query + def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, super().get_or_none(*args, **kwargs)) + + @_disallow_queryset_methods_on_prepared_query + def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().only(*fields_for_select)) + + @_disallow_queryset_methods_on_prepared_query + def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().select_related(*fields)) + + @_disallow_queryset_methods_on_prepared_query + def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().force_index(*index_names)) + + @_disallow_queryset_methods_on_prepared_query + def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().use_index(*index_names)) + + @_disallow_queryset_methods_on_prepared_query + def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, super().prefetch_related(*args)) + + +class CachedSql: + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): + continue + + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param + + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + # TODO: check for parameters mismatch + + filled_params = self.params.copy() + for name, idx in self.need_params.items(): + param = self.param_by_name[name] + filled_params[idx] = param.encode_value(params[name]) + + for name, indexes in self.need_collection_params.items(): + param = cast(CollectionParameter, self.param_by_name[name]) + collection = param.encode_collection(params[name]) + if len(collection) != len(indexes): + raise ValueError( + f"Provided value length ({len(collection)}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({len(indexes)})" + ) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + + return filled_params + + +class PreparedUpdateQuery(UpdateQuery, _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + update_kwargs: dict[str, Any], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, + ) -> None: + super(PreparedUpdateQuery, self).__init__( + model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, + ) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = True + + def _clone(self) -> PreparedUpdateQuery[MODEL]: + query = self.__class__( + model=self.model, + update_kwargs=self.update_kwargs, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class PreparedDeleteQuery(DeleteQuery, _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, + ) -> None: + super(PreparedDeleteQuery, self).__init__( + model, db, q_objects, annotations, custom_filters, limit, orderings, + ) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = True + + def _clone(self) -> PreparedDeleteQuery[MODEL]: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + orderings=self._orderings, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class PreparedExistsQuery(ExistsQuery, _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super(PreparedExistsQuery, self).__init__( + model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, + ) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = False + + def _clone(self) -> PreparedExistsQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + result, _ = await self._db.execute_query(cached_query.sql, filled_params) + return bool(result) + + +class PreparedCountQuery(CountQuery, _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + offset: int | None, + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super(PreparedCountQuery, self).__init__( + model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, + ) + super(AwaitableQuery, self).__init__(cache_key) + + self._db_for_write = False + + def _clone(self) -> PreparedCountQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + annotations=self._annotations, + custom_filters=self._custom_filters, + limit=self._limit, + offset=self._offset, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> int: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + if not result: + return 0 + count = list(dict(result[0]).values())[0] - self._offset + if self._limit and count > self._limit: + return self._limit + return count + + +class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + flat: bool, + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super(PreparedValuesListQuery, self).__init__( + model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, + offset, distinct, orderings, flat, annotations, custom_filters, group_bys, + force_indexes, use_indexes + ) + super(AwaitableQuery, self).__init__(cache_key) + self._db_for_write = False + + def _clone(self) -> PreparedValuesListQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select_list=self._fields_for_select_list, + limit=self._limit, + offset=self._offset, + distinct=self._distinct, + orderings=self._orderings, + flat=self._flat, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> list[Any] | tuple: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + return self._process_results(result) + + +class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQuery): + __slots__ = ( + "_cache_key", + "_prepared", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select: dict[str, str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, + ) -> None: + super(PreparedValuesQuery, self).__init__( + model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, + offset, distinct, orderings, annotations, custom_filters, group_bys, + force_indexes, use_indexes, + ) + super(AwaitableQuery, self).__init__(cache_key) + self._db_for_write = False + + def _clone(self) -> PreparedValuesQuery: + query = self.__class__( + model=self.model, + db=self._db, + q_objects=self._q_objects, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=self._fields_for_select, + limit=self._limit, + offset=self._offset, + distinct=self._distinct, + orderings=self._orderings, + annotations=self._annotations, + custom_filters=self._custom_filters, + group_bys=self._group_bys, + force_indexes=self._force_indexes, + use_indexes=self._use_indexes, + cache_key=self._cache_key, + ) + query._prepared = self._prepared + return query + + async def execute(self, **params) -> list[Any] | tuple: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + result = await self._db.execute_query_dict(cached_query.sql, filled_params) + return self._process_results(result) From f192c7a227aaf094f39ecff3de7124b98e146228 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 12:48:09 +0200 Subject: [PATCH 20/57] move prepared queryset fields initialisation from mixin to prepared-query classes --- tortoise/queryset.py | 47 +++------ tortoise/queryset_prepared.py | 174 +++++++++++++++++----------------- 2 files changed, 97 insertions(+), 124 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 41bb89563..3c79c6395 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -362,8 +362,10 @@ def __init__(self, model: type[MODEL]) -> None: self._force_indexes: set[str] = set() self._use_indexes: set[str] = set() - def _clone(self) -> QuerySet[MODEL]: - queryset = self.__class__.__new__(self.__class__) + def _clone(self, _new_cls: type[QuerySet] | None = None) -> QuerySet[MODEL]: + if _new_cls is None: + _new_cls = self.__class__ + queryset = _new_cls.__new__(_new_cls) queryset.fields = self.fields queryset.model = self.model queryset.query = self.query @@ -1273,42 +1275,15 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: from tortoise.queryset_prepared import PreparedQuerySet - # TODO: add some arg to _clone to override class? - # to be able to to something like self._clone(PreparedQuerySet) - queryset = PreparedQuerySet(self.model, key) - queryset.fields = self.fields - queryset.model = self.model - queryset.query = self.query - queryset.capabilities = self.capabilities - queryset._prefetch_map = copy(self._prefetch_map) - queryset._prefetch_queries = copy(self._prefetch_queries) - queryset._single = self._single - queryset._raise_does_not_exist = self._raise_does_not_exist - queryset._db = self._db - queryset._limit = self._limit - queryset._offset = self._offset - queryset._fields_for_select = self._fields_for_select - queryset._filter_kwargs = copy(self._filter_kwargs) - queryset._orderings = copy(self._orderings) - queryset._joined_tables = copy(self._joined_tables) - queryset._q_objects = copy(self._q_objects) - queryset._distinct = self._distinct - queryset._annotations = copy(self._annotations) - queryset._having = copy(self._having) - queryset._custom_filters = copy(self._custom_filters) - queryset._group_bys = copy(self._group_bys) - queryset._select_for_update = self._select_for_update - queryset._select_for_update_nowait = self._select_for_update_nowait - queryset._select_for_update_skip_locked = self._select_for_update_skip_locked - queryset._select_for_update_of = self._select_for_update_of - queryset._select_for_update_no_key = self._select_for_update_no_key - queryset._select_related = self._select_related - queryset._select_related_idx = self._select_related_idx - queryset._force_indexes = self._force_indexes - queryset._use_indexes = self._use_indexes + queryset = self._clone(PreparedQuerySet) queryset._cache_key = key + queryset._prepared = False + queryset._sql_cache = None + queryset._dynamic_params = None + queryset._dynamic_params_names = None + queryset._db_for_write = self._select_for_update - return queryset + return cast(PreparedQuerySet[MODEL], queryset) class UpdateQuery(AwaitableQuery): diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index cca7dab77..ff18d0f48 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -1,49 +1,24 @@ from __future__ import annotations import functools -import types +from abc import ABC from collections import defaultdict -from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable -from copy import copy -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, NoReturn, ParamSpec, \ - Sequence, Self - -from pypika_tortoise import JoinType, Order, Table -from pypika_tortoise.analytics import Count -from pypika_tortoise.functions import Cast -from pypika_tortoise.queries import QueryBuilder -from pypika_tortoise.terms import Case, Field, Star, Term, ValueWrapper - -from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities -from tortoise.exceptions import ( - DoesNotExist, - FieldError, - IntegrityError, - MultipleObjectsReturned, - ParamsError, -) -from tortoise.expressions import Expression, Q, RawSQL, ResolveContext, ResolveResult -from tortoise.fields.relational import ( - ForeignKeyFieldInstance, - OneToOneFieldInstance, - RelationalField, -) +from collections.abc import Callable, Iterable +from typing import Any, Literal, TypeVar, cast, NoReturn, ParamSpec, Self, Protocol + +from pypika_tortoise.terms import Term + +from tortoise.backends.base.client import BaseDBAsyncClient +from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError +from tortoise.expressions import Expression, Q from tortoise.filters import FilterInfoDict from tortoise.parameter import Parameter, CollectionParameter, TortoiseSqlContext -from tortoise.query_utils import ( - Prefetch, - QueryModifier, - TableCriterionTuple, - expand_lookup_expression, - get_joins_for_related_field, -) -from tortoise.queryset import QuerySet, MODEL, AwaitableQuery, QuerySetSingle, T_co, DeleteQuery, BulkCreateQuery, \ - BulkUpdateQuery, UpdateQuery, ExistsQuery, CountQuery, ValuesListQuery, ValuesQuery, SINGLE -from tortoise.router import router -from tortoise.utils import chunk - - -class PreparedQuerySetSingle(QuerySetSingle[T_co]): +from tortoise.query_utils import Prefetch +from tortoise.queryset import QuerySet, MODEL, QuerySetSingle, T_co, DeleteQuery, BulkCreateQuery, \ + BulkUpdateQuery, UpdateQuery, ExistsQuery, CountQuery, ValuesListQuery, ValuesQuery, SINGLE, AwaitableQuery + + +class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): def prepared(self) -> PreparedQuerySet[MODEL]: ... @@ -51,30 +26,21 @@ async def execute(self, **params) -> list[MODEL]: ... -class _PreparedQuery: - # __slots__ = ( - # "_cache_key", - # "_prepared", - # "_sql_cache", - # "_dynamic_params", - # "_dynamic_params_names", - # "_db_for_write" - # ) +class _PreparedQueryMixin(AwaitableQuery, ABC): + _cache_key: str + _prepared: bool + _sql_cache: dict[str, CachedSql] | None + _dynamic_params: dict[str, CollectionParameter] | None + _dynamic_params_names: list[str] | None + _db_for_write: bool - def __init__(self, cache_key: str) -> None: - self._cache_key: str = cache_key - self._prepared: bool = False - - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False + __slots__ = () def _clone(self) -> Self: raise NotImplementedError def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("Querysets must only be prepared once") + raise NotImplementedError("QuerySets must only be prepared once") def prepared(self) -> Self: if self._cache_key is None: @@ -148,7 +114,7 @@ def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: return decorated -class PreparedQuerySet(QuerySet[MODEL], _PreparedQuery): +class PreparedQuerySet(QuerySet[MODEL], _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", @@ -157,22 +123,28 @@ class PreparedQuerySet(QuerySet[MODEL], _PreparedQuery): "_executor", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__(self, model: type[MODEL], cache_key: str) -> None: - super(PreparedQuerySet, self).__init__(model) - super(AwaitableQuery, self).__init__(cache_key) - + super().__init__(model) + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = self._select_for_update - self._custom_fields = None self._executor = None - def _clone(self) -> PreparedQuerySet[MODEL]: - queryset = super()._clone() + def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: + queryset = super()._clone(_new_cls) queryset._cache_key = self._cache_key queryset._prepared = self._prepared - queryset._db_for_write = self._select_for_update + queryset._sql_cache = self._sql_cache + queryset._dynamic_params = self._dynamic_params + queryset._dynamic_params_names = self._dynamic_params_names + queryset._db_for_write = self._db_for_write return cast(PreparedQuerySet, queryset) def prepared(self) -> PreparedQuerySet[MODEL]: @@ -498,13 +470,14 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: return filled_params -class PreparedUpdateQuery(UpdateQuery, _PreparedQuery): +class PreparedUpdateQuery(UpdateQuery, _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__( @@ -519,11 +492,14 @@ def __init__( orderings: list[tuple[str, str]], cache_key: str, ) -> None: - super(PreparedUpdateQuery, self).__init__( + super().__init__( model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, ) - super(AwaitableQuery, self).__init__(cache_key) - + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = True def _clone(self) -> PreparedUpdateQuery[MODEL]: @@ -547,13 +523,14 @@ async def execute(self, **params) -> int: return (await self._db.execute_query(cached_query.sql, filled_params))[0] -class PreparedDeleteQuery(DeleteQuery, _PreparedQuery): +class PreparedDeleteQuery(DeleteQuery, _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__( @@ -567,11 +544,14 @@ def __init__( orderings: list[tuple[str, str]], cache_key: str, ) -> None: - super(PreparedDeleteQuery, self).__init__( + super().__init__( model, db, q_objects, annotations, custom_filters, limit, orderings, ) - super(AwaitableQuery, self).__init__(cache_key) - + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = True def _clone(self) -> PreparedDeleteQuery[MODEL]: @@ -594,13 +574,14 @@ async def execute(self, **params) -> int: return (await self._db.execute_query(cached_query.sql, filled_params))[0] -class PreparedExistsQuery(ExistsQuery, _PreparedQuery): +class PreparedExistsQuery(ExistsQuery, _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__( @@ -614,11 +595,14 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super(PreparedExistsQuery, self).__init__( + super().__init__( model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, ) - super(AwaitableQuery, self).__init__(cache_key) - + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = False def _clone(self) -> PreparedExistsQuery: @@ -642,13 +626,14 @@ async def execute(self, **params) -> int: return bool(result) -class PreparedCountQuery(CountQuery, _PreparedQuery): +class PreparedCountQuery(CountQuery, _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__( @@ -664,11 +649,14 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super(PreparedCountQuery, self).__init__( + super().__init__( model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, ) - super(AwaitableQuery, self).__init__(cache_key) - + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = False def _clone(self) -> PreparedCountQuery: @@ -700,13 +688,14 @@ async def execute(self, **params) -> int: return count -class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQuery): +class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write", ) def __init__( @@ -729,12 +718,16 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super(PreparedValuesListQuery, self).__init__( + super().__init__( model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, offset, distinct, orderings, flat, annotations, custom_filters, group_bys, force_indexes, use_indexes ) - super(AwaitableQuery, self).__init__(cache_key) + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = False def _clone(self) -> PreparedValuesListQuery: @@ -768,13 +761,14 @@ async def execute(self, **params) -> list[Any] | tuple: return self._process_results(result) -class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQuery): +class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): __slots__ = ( "_cache_key", "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", + "_db_for_write" ) def __init__( @@ -796,12 +790,16 @@ def __init__( use_indexes: set[str], cache_key: str, ) -> None: - super(PreparedValuesQuery, self).__init__( + super().__init__( model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, offset, distinct, orderings, annotations, custom_filters, group_bys, force_indexes, use_indexes, ) - super(AwaitableQuery, self).__init__(cache_key) + self._cache_key: str = cache_key + self._prepared: bool = False + self._sql_cache = None + self._dynamic_params = None + self._dynamic_params_names = None self._db_for_write = False def _clone(self) -> PreparedValuesQuery: From 2a515478e73ca91fa1cb9af2d4a386153c2cfbd0 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:14:17 +0200 Subject: [PATCH 21/57] fix most of the typing errors --- tortoise/parameter.py | 22 ++-- tortoise/queryset_prepared.py | 184 ++++++++++++++++++---------------- 2 files changed, 109 insertions(+), 97 deletions(-) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index 872b5aac2..27bad60c6 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,11 +1,17 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Self, Callable, Any +from typing import Self, Callable, Any, TYPE_CHECKING, Sequence, TypeVar from tortoise.fields import Field from pypika_tortoise import SqlContext +if TYPE_CHECKING: + from tortoise import Model + +T_out = TypeVar("T_out") +FieldEncoder = Callable[[Any, "Model"], T_out] | Callable[[Any, "Model", Field | None], T_out] + @dataclass(frozen=True) class TortoiseSqlContext(SqlContext): @@ -35,10 +41,10 @@ class Parameter: def __init__(self, name: str) -> None: self.name = name - self.model = None - self.value_encoder = None + self.model: Model | None = None + self.value_encoder: FieldEncoder[Any] | None = None self.field_object: Field | None = None - self.encode = None + self.encode: Callable[[Any], Any] | None = None self.value_getter: Callable[[Any], Any] | None = None self.value_validator: Callable[[Any], Any] | None = None @@ -54,7 +60,7 @@ def clone(self) -> Self: return new - def encode_value(self, value: ...) -> ...: + def encode_value(self, value: Any) -> Any: if self.value_validator is not None: self.value_validator(value) @@ -82,8 +88,8 @@ class CollectionParameter(Parameter): def __init__(self, name: str) -> None: super().__init__(name) - self.collection_size = None - self.collection_encoder = None + self.collection_size: int | None = None + self.collection_encoder: FieldEncoder[Sequence[Any]] | None = None @classmethod def from_simple_param(cls, param: Parameter) -> Self: @@ -94,7 +100,7 @@ def from_simple_param(cls, param: Parameter) -> Self: new_param.encode = param.encode return new_param - def encode_collection(self, value: ...) -> ...: + def encode_collection(self, value: Any) -> Sequence[Any]: if self.field_object is not None: return self.collection_encoder(value, self.model, self.field_object) else: diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index ff18d0f48..4a8ad9e6b 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -1,14 +1,15 @@ -from __future__ import annotations +from __future__ import annotations as _ import functools -from abc import ABC +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Iterable -from typing import Any, Literal, TypeVar, cast, NoReturn, ParamSpec, Self, Protocol +from typing import Any, Literal, TypeVar, cast, NoReturn, ParamSpec, Self, Protocol, Concatenate from pypika_tortoise.terms import Term from tortoise.backends.base.client import BaseDBAsyncClient +from tortoise.backends.base.executor import BaseExecutor from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError from tortoise.expressions import Expression, Q from tortoise.filters import FilterInfoDict @@ -26,6 +27,48 @@ async def execute(self, **params) -> list[MODEL]: ... +class CachedSql: + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): + continue + + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param + + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + # TODO: check for parameters mismatch + + filled_params = self.params.copy() + for name, idx in self.need_params.items(): + param = self.param_by_name[name] + filled_params[idx] = param.encode_value(params[name]) + + for name, indexes in self.need_collection_params.items(): + param = cast(CollectionParameter, self.param_by_name[name]) + collection = param.encode_collection(params[name]) + if len(collection) != len(indexes): + raise ValueError( + f"Provided value length ({len(collection)}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({len(indexes)})" + ) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + + return filled_params + + class _PreparedQueryMixin(AwaitableQuery, ABC): _cache_key: str _prepared: bool @@ -36,11 +79,9 @@ class _PreparedQueryMixin(AwaitableQuery, ABC): __slots__ = () + @abstractmethod def _clone(self) -> Self: - raise NotImplementedError - - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("QuerySets must only be prepared once") + ... def prepared(self) -> Self: if self._cache_key is None: @@ -88,23 +129,26 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) - sql, params = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params) + sql, params_ = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params_) for param in reset_params: param.collection_size = None return self._sql_cache[cache_key] - async def execute(self, **params) -> int: - raise NotImplementedError + @abstractmethod + async def execute(self, **params) -> Any: + ... P = ParamSpec("P") T = TypeVar("T") -def _disallow_queryset_methods_on_prepared_query(func: Callable[P, T]) -> Callable[P, T]: +def _disallow_queryset_methods_on_prepared_query( + func: Callable[Concatenate[PreparedQuerySet, P], T], +) -> Callable[Concatenate[PreparedQuerySet, P], T]: @functools.wraps(func) def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: if self._prepared: @@ -130,15 +174,15 @@ def __init__(self, model: type[MODEL], cache_key: str) -> None: super().__init__(model) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None self._db_for_write = self._select_for_update - self._custom_fields = None - self._executor = None + self._custom_fields: list[str] | None = None + self._executor: BaseExecutor | None = None def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: - queryset = super()._clone(_new_cls) + queryset = cast(Self, super()._clone(_new_cls)) queryset._cache_key = self._cache_key queryset._prepared = self._prepared queryset._sql_cache = self._sql_cache @@ -147,8 +191,11 @@ def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuery queryset._db_for_write = self._db_for_write return cast(PreparedQuerySet, queryset) + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError("QuerySets must only be prepared once") + def prepared(self) -> PreparedQuerySet[MODEL]: - queryset = cast(PreparedQuerySet[MODEL], super().prepared()) + queryset = cast(Self, super().prepared()) queryset._custom_fields = list(self._annotations.keys()) queryset._executor = queryset._db.executor_class( @@ -224,7 +271,7 @@ def _validate_offset(value: int) -> int: return value @_disallow_queryset_methods_on_prepared_query - def offset(self, offset: int) -> PreparedQuerySet[MODEL]: + def offset(self, offset: int | Parameter) -> PreparedQuerySet[MODEL]: if isinstance(offset, int) and offset < 0: raise ParamsError("Offset should be non-negative number") elif isinstance(offset, Parameter): @@ -428,48 +475,6 @@ def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: return cast(PreparedQuerySet, super().prefetch_related(*args)) -class CachedSql: - def __init__(self, sql: str, params: list[Parameter | Any]) -> None: - self.sql = sql - self.params = params - self.param_by_name: dict[str, Parameter] = {} - self.need_params: dict[str, int] = {} - self.need_collection_params: dict[str, list[int]] = defaultdict(list) - for idx, param in enumerate(params): - if not isinstance(param, Parameter): - continue - - if param.name not in self.param_by_name: - self.param_by_name[param.name] = param - - if isinstance(param, CollectionParameter): - self.need_collection_params[param.name].append(idx) - else: - self.need_params[param.name] = idx - - def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - # TODO: check for parameters mismatch - - filled_params = self.params.copy() - for name, idx in self.need_params.items(): - param = self.param_by_name[name] - filled_params[idx] = param.encode_value(params[name]) - - for name, indexes in self.need_collection_params.items(): - param = cast(CollectionParameter, self.param_by_name[name]) - collection = param.encode_collection(params[name]) - if len(collection) != len(indexes): - raise ValueError( - f"Provided value length ({len(collection)}) " - f"for parameter {name!r} does not match " - f"parameter indexes length ({len(indexes)})" - ) - for idx, value in zip(indexes, collection): - filled_params[idx] = param.encode_value(value) - - return filled_params - - class PreparedUpdateQuery(UpdateQuery, _PreparedQueryMixin): __slots__ = ( "_cache_key", @@ -495,14 +500,15 @@ def __init__( super().__init__( model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, ) + self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = True + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = True - def _clone(self) -> PreparedUpdateQuery[MODEL]: + def _clone(self) -> PreparedUpdateQuery: query = self.__class__( model=self.model, update_kwargs=self.update_kwargs, @@ -549,12 +555,12 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = True + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = True - def _clone(self) -> PreparedDeleteQuery[MODEL]: + def _clone(self) -> PreparedDeleteQuery: query = self.__class__( model=self.model, db=self._db, @@ -600,10 +606,10 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = False def _clone(self) -> PreparedExistsQuery: query = self.__class__( @@ -654,10 +660,10 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = False def _clone(self) -> PreparedCountQuery: query = self.__class__( @@ -725,10 +731,10 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = False def _clone(self) -> PreparedValuesListQuery: query = self.__class__( @@ -797,10 +803,10 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache = None - self._dynamic_params = None - self._dynamic_params_names = None - self._db_for_write = False + self._sql_cache: dict[str, CachedSql] | None = None + self._dynamic_params: dict[str, CollectionParameter] | None = None + self._dynamic_params_names: list[str] | None = None + self._db_for_write: bool = False def _clone(self) -> PreparedValuesQuery: query = self.__class__( @@ -824,7 +830,7 @@ def _clone(self) -> PreparedValuesQuery: query._prepared = self._prepared return query - async def execute(self, **params) -> list[Any] | tuple: + async def execute(self, **params) -> list[dict] | dict: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) From 03397c21deff57ee67d901f50e87bee8a12a9875 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:18:16 +0200 Subject: [PATCH 22/57] dont cache database executor in PreparedQuerySet because it is anyway cached in EXECUTOR_CACHE --- tortoise/queryset_prepared.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 4a8ad9e6b..96fd93b01 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -164,7 +164,6 @@ class PreparedQuerySet(QuerySet[MODEL], _PreparedQueryMixin): "_prepared", "_custom_fields", "_sql_cache", - "_executor", "_dynamic_params", "_dynamic_params_names", "_db_for_write", @@ -179,7 +178,6 @@ def __init__(self, model: type[MODEL], cache_key: str) -> None: self._dynamic_params_names: list[str] | None = None self._db_for_write = self._select_for_update self._custom_fields: list[str] | None = None - self._executor: BaseExecutor | None = None def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: queryset = cast(Self, super()._clone(_new_cls)) @@ -196,24 +194,20 @@ def prepare_sql(self, key: str) -> NoReturn: def prepared(self) -> PreparedQuerySet[MODEL]: queryset = cast(Self, super().prepared()) - queryset._custom_fields = list(self._annotations.keys()) - queryset._executor = queryset._db.executor_class( - model=queryset.model, - db=queryset._db, - prefetch_map=queryset._prefetch_map, - prefetch_queries=queryset._prefetch_queries, - select_related_idx=queryset._select_related_idx, - ) - return queryset async def execute(self, **params) -> list[MODEL]: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) - # TODO: re-create executor when database changes - instance_list = await self._executor.execute_select( + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, + ).execute_select( cached_query.sql, filled_params, custom_fields=self._custom_fields, ) From bf981ca882a14238b169c1606b8920c3c2241e1e Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:20:41 +0200 Subject: [PATCH 23/57] cache prepared sql queries by dialect --- tortoise/queryset_prepared.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 96fd93b01..feb44649d 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -9,7 +9,6 @@ from pypika_tortoise.terms import Term from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.backends.base.executor import BaseExecutor from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError from tortoise.expressions import Expression, Q from tortoise.filters import FilterInfoDict @@ -113,8 +112,7 @@ def prepared(self) -> Self: def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: reset_params = [] - # TODO: cache also by database dialect - cache_key = "query" + cache_key = f"{self._db.capabilities.dialect}-query" for name in self._dynamic_params_names: value = params[name] if not isinstance(value, (tuple, list, set)): @@ -126,6 +124,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: param.collection_size = len(value) reset_params.append(param) + # TODO: add ability to limit cache, use lru? if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) From 6fdb888bcd12327a39c75db8ec98a90795001cf8 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:21:23 +0200 Subject: [PATCH 24/57] add __slots__ to CachedSql --- tortoise/queryset_prepared.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index feb44649d..bb87104e4 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -27,6 +27,8 @@ async def execute(self, **params) -> list[MODEL]: class CachedSql: + __slots__ = ("sql", "params", "param_by_name", "need_params", "need_collection_params",) + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: self.sql = sql self.params = params From 5b7ff632ffb5c14cc5c6d581bc8b4e29ad38247a Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:25:54 +0200 Subject: [PATCH 25/57] fix style issues --- tests/test_queryset_prepared.py | 212 +++++++++++++++-------- tortoise/filters.py | 27 +-- tortoise/parameter.py | 27 ++- tortoise/queryset.py | 4 +- tortoise/queryset_prepared.py | 289 ++++++++++++++++++++------------ 5 files changed, 360 insertions(+), 199 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 2a9e80ccf..157ad8fae 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -1,9 +1,10 @@ from tests.testmodels import ( - Author, Book, + Author, + Book, ) from tortoise.contrib import test from tortoise.exceptions import ParamsError, ValidationError -from tortoise.expressions import Subquery, Q +from tortoise.expressions import Q, Subquery from tortoise.parameter import Parameter @@ -21,13 +22,17 @@ def test_disallow_filtering_on_prepared_queryset(self): prepared.filter(id=1) async def test_gte_filter(self): - author1 = await Author.create(name="1") author2 = await Author.create(name="2") author3 = await Author.create(name="3") expected = await Author.filter(id__gte=author2.pk).order_by("id") - prepared = Author.prepare_sql("test_gte_filter").filter(id__gte=Parameter("idgte")).order_by("id").prepared() + prepared = ( + Author.prepare_sql("test_gte_filter") + .filter(id__gte=Parameter("idgte")) + .order_by("id") + .prepared() + ) actual = await prepared.execute(idgte=author2.pk) self.assertEqual(len(actual), 2) self.assertEqual(actual[0].id, author2.pk) @@ -35,9 +40,9 @@ async def test_gte_filter(self): self.assertEqual(expected, actual) async def test_string_param(self): - author1 = await Author.create(name="1") + await Author.create(name="1") author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + await Author.create(name="3") expected = await Author.filter(name=author2.name) @@ -52,7 +57,11 @@ async def test_startswith_filter(self): author2 = await Author.create(name="testqwe") author3 = await Author.create(name="qwetest") - prepared = Author.prepare_sql("test_startswith_filter").filter(name__startswith=Parameter("name")).prepared() + prepared = ( + Author.prepare_sql("test_startswith_filter") + .filter(name__startswith=Parameter("name")) + .prepared() + ) for test_name in (author2.pk, author1.name, author3.name, "asd"): expected = await Author.filter(name__startswith=test_name) @@ -66,10 +75,7 @@ async def test_in_filter(self): prepared = Author.prepare_sql("test_in_filter").filter(id__in=Parameter("ids")).prepared() - for test_ids in ( - [author2.pk, author1.pk], - [author3.pk, author3.pk * 2, author3.pk * 10] - ): + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): expected = await Author.filter(id__in=test_ids) actual = await prepared.execute(ids=test_ids) self.assertEqual(expected, actual) @@ -79,17 +85,23 @@ async def test_subqueries(self): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = Author.prepare_sql("test_subqueries").filter(id__in=Subquery( - Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") - )).prepared() + prepared = ( + Author.prepare_sql("test_subqueries") + .filter( + id__in=Subquery( + Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") + ) + ) + .prepared() + ) for id1, id2 in ( - (author2.pk, author1.pk), - (author3.pk, author3.pk * 2), + (author2.pk, author1.pk), + (author3.pk, author3.pk * 2), ): - expected = await Author.filter(id__in=Subquery( - Author.filter(Q(id=id1) | Q(id=id2)).values("id") - )) + expected = await Author.filter( + id__in=Subquery(Author.filter(Q(id=id1) | Q(id=id2)).values("id")) + ) actual = await prepared.execute(id1=id1, id2=id2) self.assertEqual(expected, actual) @@ -98,32 +110,33 @@ async def test_subqueries_in_filter(self): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = Author.prepare_sql("test_subqueries_in_filter").filter(id__in=Subquery( - Author.filter(id__in=Parameter("ids")).values("id") - )).prepared() + prepared = ( + Author.prepare_sql("test_subqueries_in_filter") + .filter(id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id"))) + .prepared() + ) - for test_ids in ( - [author2.pk, author1.pk], - [author3.pk, author3.pk * 2, author3.pk * 10] - ): - expected = await Author.filter(id__in=Subquery( - Author.filter(id__in=test_ids).values("id") - )) + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): + expected = await Author.filter( + id__in=Subquery(Author.filter(id__in=test_ids).values("id")) + ) actual = await prepared.execute(ids=test_ids) self.assertEqual(expected, actual) async def test_update(self): author1 = await Author.create(name="1") author2 = await Author.create(name="2") - author3 = await Author.create(name="3") original_name1 = author1.name original_name2 = author2.name - new_name1 = f"{author1.name}_test" + new_name1 = f"{author1.name}_test" - prepared = Author.prepare_sql("test_update").filter( - id=Parameter("search_id") - ).update(name=Parameter("replace_name")).prepared() + prepared = ( + Author.prepare_sql("test_update") + .filter(id=Parameter("search_id")) + .update(name=Parameter("replace_name")) + .prepared() + ) await prepared.execute(search_id=author1.pk, replace_name=new_name1) await author1.refresh_from_db(["name"]) @@ -140,9 +153,14 @@ async def test_delete(self): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = Author.prepare_sql("test_delete").filter( - id__in=Parameter("ids"), - ).delete().prepared() + prepared = ( + Author.prepare_sql("test_delete") + .filter( + id__in=Parameter("ids"), + ) + .delete() + .prepared() + ) affected = await prepared.execute(ids=[author1.pk]) self.assertEqual(affected, 1) @@ -153,9 +171,14 @@ async def test_delete(self): async def test_exists(self): author = await Author.create(name="1") - prepared = Author.prepare_sql("test_exists").filter( - id__in=Parameter("ids"), - ).exists().prepared() + prepared = ( + Author.prepare_sql("test_exists") + .filter( + id__in=Parameter("ids"), + ) + .exists() + .prepared() + ) self.assertTrue(await prepared.execute(ids=[author.pk])) self.assertFalse(await prepared.execute(ids=[author.pk * 2])) @@ -165,9 +188,14 @@ async def test_count(self): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = Author.prepare_sql("test_count").filter( - id__gte=Parameter("idgte"), - ).count().prepared() + prepared = ( + Author.prepare_sql("test_count") + .filter( + id__gte=Parameter("idgte"), + ) + .count() + .prepared() + ) self.assertEqual(await prepared.execute(idgte=author1.pk), 3) self.assertEqual(await prepared.execute(idgte=author2.pk), 2) @@ -175,11 +203,21 @@ async def test_count(self): self.assertEqual(await prepared.execute(idgte=author3.pk * 2), 0) async def test_parameter_in_limit(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) - prepared = Author.prepare_sql("test_parameter_in_limit").all().limit(Parameter("lim")).order_by("id").prepared() + prepared = ( + Author.prepare_sql("test_parameter_in_limit") + .all() + .limit(Parameter("lim")) + .order_by("id") + .prepared() + ) self.assertEqual(len(await prepared.execute(lim=1)), 1) self.assertEqual(len(await prepared.execute(lim=2)), 2) @@ -190,11 +228,21 @@ async def test_parameter_in_limit(self): await prepared.execute(lim=-1) async def test_parameter_in_offset(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) - prepared = Author.prepare_sql("test_parameter_in_offset").all().offset(Parameter("off")).order_by("id").prepared() + prepared = ( + Author.prepare_sql("test_parameter_in_offset") + .all() + .offset(Parameter("off")) + .order_by("id") + .prepared() + ) self.assertEqual(len(await prepared.execute(off=1)), 2) self.assertEqual(len(await prepared.execute(off=2)), 1) @@ -207,9 +255,14 @@ async def test_parameter_in_offset(self): async def test_values(self): author = await Author.create(name="1") - prepared = Author.prepare_sql("test_values").filter( - id=Parameter("id"), - ).values().prepared() + prepared = ( + Author.prepare_sql("test_values") + .filter( + id=Parameter("id"), + ) + .values() + .prepared() + ) self.assertEqual( await prepared.execute(id=author.pk), @@ -223,9 +276,14 @@ async def test_values(self): async def test_values_list_all_fields(self): author = await Author.create(name="1") - prepared_all = Author.prepare_sql("test_values_list_all_fields").filter( - id=Parameter("id"), - ).values_list().prepared() + prepared_all = ( + Author.prepare_sql("test_values_list_all_fields") + .filter( + id=Parameter("id"), + ) + .values_list() + .prepared() + ) self.assertEqual( await prepared_all.execute(id=author.pk), [(author.pk, author.name)], @@ -238,9 +296,14 @@ async def test_values_list_all_fields(self): async def test_values_list_only_id_field(self): author = await Author.create(name="1") - prepared_ids = Author.prepare_sql("test_values_list_only_id_field").filter( - id=Parameter("id"), - ).values_list("id").prepared() + prepared_ids = ( + Author.prepare_sql("test_values_list_only_id_field") + .filter( + id=Parameter("id"), + ) + .values_list("id") + .prepared() + ) self.assertEqual( await prepared_ids.execute(id=author.pk), [(author.pk,)], @@ -253,9 +316,14 @@ async def test_values_list_only_id_field(self): async def test_values_list_only_id_field_flat(self): author = await Author.create(name="1") - prepared_ids_flat = Author.prepare_sql("test_values_list_only_id_field_flat").filter( - id=Parameter("id"), - ).values_list("id", flat=True).prepared() + prepared_ids_flat = ( + Author.prepare_sql("test_values_list_only_id_field_flat") + .filter( + id=Parameter("id"), + ) + .values_list("id", flat=True) + .prepared() + ) self.assertEqual( await prepared_ids_flat.execute(id=author.pk), [author.pk], @@ -271,9 +339,12 @@ async def test_update_fk(self): book = await Book.create(name="test", author=author1, rating=5) - prepared = Book.prepare_sql("test_update_fk").filter( - id=Parameter("search_id") - ).update(author=Parameter("replace_author")).prepared() + prepared = ( + Book.prepare_sql("test_update_fk") + .filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .prepared() + ) await prepared.execute(search_id=book.pk, replace_author=author2) book = await Book.get(id=book.pk).select_related("author") @@ -289,9 +360,12 @@ async def test_update_pk_invalid_obj(self): author = await Author.create(name="1") book = await Book.create(name="test", author=author, rating=5) - prepared = Book.prepare_sql("test_update_pk_invalid_obj").filter( - id=Parameter("search_id") - ).update(author=Parameter("replace_author")).prepared() + prepared = ( + Book.prepare_sql("test_update_pk_invalid_obj") + .filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .prepared() + ) with self.assertRaises(ValidationError): await prepared.execute(search_id=book.pk, replace_author="not an Author object") diff --git a/tortoise/filters.py b/tortoise/filters.py index 8e38845c8..a9329f78d 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -14,15 +14,15 @@ BasicCriterion, Criterion, Equality, + Function, Term, ValueWrapper, - Function, ) from tortoise.contrib.postgres.fields import ArrayField, TSVectorField from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance -from tortoise.parameter import Parameter, CollectionParameter +from tortoise.parameter import CollectionParameter, Parameter if sys.version_info >= (3, 11): # pragma:nocoverage from typing import NotRequired @@ -151,7 +151,7 @@ def not_null(field: Term, value: Any) -> Criterion: def contains(field: Term, value: str | Parameter) -> Criterion: return Like( Cast(field, SqlTypes.VARCHAR), - field.wrap_constant(_format_str_or_parameter(field, value, True, True)) + field.wrap_constant(_format_str_or_parameter(field, value, True, True)), ) @@ -175,8 +175,11 @@ def insensitive_posix_regex(field: Term, value: str): def _format_str_or_parameter( - field: Term, value: str | Parameter, like_start: bool = False, like_end: bool = False, - escape_func: Callable[[Any], str] = escape_like, + field: Term, + value: str | Parameter, + like_start: bool = False, + like_end: bool = False, + escape_func: Callable[[Any], str] = escape_like, ) -> Term: if isinstance(value, Parameter): value.encode = escape_func @@ -192,9 +195,7 @@ def _format_str_or_parameter( return Function("Concat", *args) else: return field.wrap_constant( - f"{'%' if like_start else ''}" - f"{escape_func(value)}" - f"{'%' if like_end else ''}" + f"{'%' if like_start else ''}{escape_func(value)}{'%' if like_end else ''}" ) @@ -207,27 +208,29 @@ def ends_with(field: Term, value: str | Parameter) -> Criterion: def insensitive_exact(field: Term, value: str | Parameter) -> Criterion: - return Upper(Cast(field, SqlTypes.VARCHAR)).eq(Upper(_format_str_or_parameter(field, value, escape_func=str))) + return Upper(Cast(field, SqlTypes.VARCHAR)).eq( + Upper(_format_str_or_parameter(field, value, escape_func=str)) + ) def insensitive_contains(field: Term, value: str | Parameter) -> Criterion: return Like( Upper(Cast(field, SqlTypes.VARCHAR)), - field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, True))) + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, True))), ) def insensitive_starts_with(field: Term, value: str | Parameter) -> Criterion: return Like( Upper(Cast(field, SqlTypes.VARCHAR)), - field.wrap_constant(Upper(_format_str_or_parameter(field, value, False, True))) + field.wrap_constant(Upper(_format_str_or_parameter(field, value, False, True))), ) def insensitive_ends_with(field: Term, value: str | Parameter) -> Criterion: return Like( Upper(Cast(field, SqlTypes.VARCHAR)), - field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, False))) + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, False))), ) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index 27bad60c6..e3ddc32f6 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,11 +1,13 @@ from __future__ import annotations +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Self, Callable, Any, TYPE_CHECKING, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, Self, TypeVar -from tortoise.fields import Field from pypika_tortoise import SqlContext +from tortoise.fields import Field + if TYPE_CHECKING: from tortoise import Model @@ -18,7 +20,9 @@ class TortoiseSqlContext(SqlContext): dynamic_params: dict[str, CollectionParameter] | None = None def copy(self: SqlContext, **kwargs) -> SqlContext: - existing_dynamic_params = self.dynamic_params if isinstance(self, TortoiseSqlContext) else None + existing_dynamic_params = ( + self.dynamic_params if isinstance(self, TortoiseSqlContext) else None + ) return TortoiseSqlContext( quote_char=kwargs.get("quote_char", self.quote_char), secondary_quote_char=kwargs.get("secondary_quote_char", self.secondary_quote_char), @@ -37,7 +41,15 @@ def copy(self: SqlContext, **kwargs) -> SqlContext: class Parameter: - __slots__ = ("name", "model", "value_encoder", "field_object", "encode", "value_getter", "value_validator",) + __slots__ = ( + "name", + "model", + "value_encoder", + "field_object", + "encode", + "value_getter", + "value_validator", + ) def __init__(self, name: str) -> None: self.name = name @@ -84,7 +96,10 @@ def encode_value(self, value: Any) -> Any: class CollectionParameter(Parameter): - __slots__ = ("collection_size", "collection_encoder",) + __slots__ = ( + "collection_size", + "collection_encoder", + ) def __init__(self, name: str) -> None: super().__init__(name) @@ -122,4 +137,4 @@ def get_sql(self, ctx: SqlContext) -> str: new_param.collection_encoder = new_param.value_encoder new_param.value_encoder = None ctx.parameterizer.create_param(new_param) - return f"({','.join(['?' for _ in range(param.collection_size)])})" \ No newline at end of file + return f"({','.join(['?' for _ in range(param.collection_size)])})" diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 3c79c6395..61f0e65f2 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -2,10 +2,10 @@ import types from collections import defaultdict -from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable +from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable, Sequence from copy import copy from operator import attrgetter -from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload, Sequence +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index bb87104e4..f1f736bdd 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Iterable -from typing import Any, Literal, TypeVar, cast, NoReturn, ParamSpec, Self, Protocol, Concatenate +from typing import Any, Concatenate, Literal, NoReturn, ParamSpec, Protocol, Self, TypeVar, cast from pypika_tortoise.terms import Term @@ -12,22 +12,40 @@ from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError from tortoise.expressions import Expression, Q from tortoise.filters import FilterInfoDict -from tortoise.parameter import Parameter, CollectionParameter, TortoiseSqlContext +from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext from tortoise.query_utils import Prefetch -from tortoise.queryset import QuerySet, MODEL, QuerySetSingle, T_co, DeleteQuery, BulkCreateQuery, \ - BulkUpdateQuery, UpdateQuery, ExistsQuery, CountQuery, ValuesListQuery, ValuesQuery, SINGLE, AwaitableQuery +from tortoise.queryset import ( + MODEL, + SINGLE, + AwaitableQuery, + BulkCreateQuery, + BulkUpdateQuery, + CountQuery, + DeleteQuery, + ExistsQuery, + QuerySet, + QuerySetSingle, + T_co, + UpdateQuery, + ValuesListQuery, + ValuesQuery, +) class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): - def prepared(self) -> PreparedQuerySet[MODEL]: - ... + def prepared(self) -> PreparedQuerySet[MODEL]: ... - async def execute(self, **params) -> list[MODEL]: - ... + async def execute(self, **params) -> list[MODEL]: ... class CachedSql: - __slots__ = ("sql", "params", "param_by_name", "need_params", "need_collection_params",) + __slots__ = ( + "sql", + "params", + "param_by_name", + "need_params", + "need_collection_params", + ) def __init__(self, sql: str, params: list[Parameter | Any]) -> None: self.sql = sql @@ -81,8 +99,7 @@ class _PreparedQueryMixin(AwaitableQuery, ABC): __slots__ = () @abstractmethod - def _clone(self) -> Self: - ... + def _clone(self) -> Self: ... def prepared(self) -> Self: if self._cache_key is None: @@ -99,9 +116,7 @@ def prepared(self) -> Self: queryset._sql_cache = {} _, params = queryset.query.get_parameterized_sql() queryset._dynamic_params = { - param.name: param - for param in params - if isinstance(param, CollectionParameter) + param.name: param for param in params if isinstance(param, CollectionParameter) } queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) @@ -129,7 +144,9 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: # TODO: add ability to limit cache, use lru? if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy(self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params) + ctx = TortoiseSqlContext.copy( + self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params + ) sql, params_ = self.query.get_parameterized_sql(ctx) self._sql_cache[cache_key] = CachedSql(sql, params_) @@ -139,8 +156,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: return self._sql_cache[cache_key] @abstractmethod - async def execute(self, **params) -> Any: - ... + async def execute(self, **params) -> Any: ... P = ParamSpec("P") @@ -148,12 +164,12 @@ async def execute(self, **params) -> Any: def _disallow_queryset_methods_on_prepared_query( - func: Callable[Concatenate[PreparedQuerySet, P], T], + func: Callable[Concatenate[PreparedQuerySet, P], T], ) -> Callable[Concatenate[PreparedQuerySet, P], T]: @functools.wraps(func) def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: if self._prepared: - raise ValueError(f"Cannot call \"{func.__name__}\" on already prepared queryset.") + raise ValueError(f'Cannot call "{func.__name__}" on already prepared queryset.') return func(self, *args, **kwargs) return decorated @@ -209,7 +225,8 @@ async def execute(self, **params) -> list[MODEL]: prefetch_queries=self._prefetch_queries, select_related_idx=self._select_related_idx, ).execute_select( - cached_query.sql, filled_params, + cached_query.sql, + filled_params, custom_fields=self._custom_fields, ) if self._single: @@ -256,7 +273,7 @@ def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: limit.encode = self._validate_limit queryset = self._clone() - queryset._limit = limit # type: ignore + queryset._limit = limit # type: ignore return queryset @staticmethod @@ -294,9 +311,7 @@ def select_for_update( of: tuple[str, ...] = (), no_key: bool = False, ) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().select_for_update( - nowait, skip_locked, of, no_key - )) + return cast(PreparedQuerySet, super().select_for_update(nowait, skip_locked, of, no_key)) @_disallow_queryset_methods_on_prepared_query def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: @@ -307,7 +322,9 @@ def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: return cast(PreparedQuerySet, super().group_by(*fields)) @_disallow_queryset_methods_on_prepared_query - def values_list(self, *fields_: str, flat: bool = False) -> PreparedValuesListQuery[Literal[False]]: + def values_list( + self, *fields_: str, flat: bool = False + ) -> PreparedValuesListQuery[Literal[False]]: fields_for_select_list = self._get_fields_list_for_select(*fields_) return PreparedValuesListQuery( @@ -481,19 +498,26 @@ class PreparedUpdateQuery(UpdateQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - update_kwargs: dict[str, Any], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], - cache_key: str, + self, + model: type[MODEL], + update_kwargs: dict[str, Any], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, ) -> None: super().__init__( - model, update_kwargs, db, q_objects, annotations, custom_filters, limit, orderings, + model, + update_kwargs, + db, + q_objects, + annotations, + custom_filters, + limit, + orderings, ) self._cache_key: str = cache_key @@ -535,18 +559,24 @@ class PreparedDeleteQuery(DeleteQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + orderings: list[tuple[str, str]], + cache_key: str, ) -> None: super().__init__( - model, db, q_objects, annotations, custom_filters, limit, orderings, + model, + db, + q_objects, + annotations, + custom_filters, + limit, + orderings, ) self._cache_key: str = cache_key self._prepared: bool = False @@ -586,18 +616,24 @@ class PreparedExistsQuery(ExistsQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, ) -> None: super().__init__( - model, db, q_objects, annotations, custom_filters, force_indexes, use_indexes, + model, + db, + q_objects, + annotations, + custom_filters, + force_indexes, + use_indexes, ) self._cache_key: str = cache_key self._prepared: bool = False @@ -638,20 +674,28 @@ class PreparedCountQuery(CountQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - offset: int | None, - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + limit: int | None, + offset: int | None, + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, ) -> None: super().__init__( - model, db, q_objects, annotations, custom_filters, limit, offset, force_indexes, use_indexes, + model, + db, + q_objects, + annotations, + custom_filters, + limit, + offset, + force_indexes, + use_indexes, ) self._cache_key: str = cache_key self._prepared: bool = False @@ -700,29 +744,42 @@ class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - single: bool, - raise_does_not_exist: bool, - fields_for_select_list: tuple[str, ...] | list[str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], - flat: bool, - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + flat: bool, + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, ) -> None: super().__init__( - model, db, q_objects, single, raise_does_not_exist, fields_for_select_list, limit, - offset, distinct, orderings, flat, annotations, custom_filters, group_bys, - force_indexes, use_indexes + model, + db, + q_objects, + single, + raise_does_not_exist, + fields_for_select_list, + limit, + offset, + distinct, + orderings, + flat, + annotations, + custom_filters, + group_bys, + force_indexes, + use_indexes, ) self._cache_key: str = cache_key self._prepared: bool = False @@ -769,32 +826,44 @@ class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): "_sql_cache", "_dynamic_params", "_dynamic_params_names", - "_db_for_write" + "_db_for_write", ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - single: bool, - raise_does_not_exist: bool, - fields_for_select: dict[str, str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + q_objects: list[Q], + single: bool, + raise_does_not_exist: bool, + fields_for_select: dict[str, str], + limit: int | None, + offset: int | None, + distinct: bool, + orderings: list[tuple[str, str]], + annotations: dict[str, Any], + custom_filters: dict[str, FilterInfoDict], + group_bys: tuple[str, ...], + force_indexes: set[str], + use_indexes: set[str], + cache_key: str, ) -> None: super().__init__( - model, db, q_objects, single, raise_does_not_exist, fields_for_select, limit, - offset, distinct, orderings, annotations, custom_filters, group_bys, - force_indexes, use_indexes, + model, + db, + q_objects, + single, + raise_does_not_exist, + fields_for_select, + limit, + offset, + distinct, + orderings, + annotations, + custom_filters, + group_bys, + force_indexes, + use_indexes, ) self._cache_key: str = cache_key self._prepared: bool = False From 6ea62f4c1bd6b7c3632cba926f9c18e31d100cae Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:47:49 +0200 Subject: [PATCH 26/57] fix rest of the typing errors --- tortoise/filters.py | 2 +- tortoise/parameter.py | 23 ++++++++++---- tortoise/queryset.py | 12 ++++---- tortoise/queryset_prepared.py | 58 +++++++++++++++++------------------ 4 files changed, 52 insertions(+), 43 deletions(-) diff --git a/tortoise/filters.py b/tortoise/filters.py index a9329f78d..7a7e0b818 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -186,7 +186,7 @@ def _format_str_or_parameter( wrapped = ValueWrapper(value) if not like_start and not like_end: return wrapped - args = [] + args: list[str | ValueWrapper] = [] if like_start: args.append("%") args.append(wrapped) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index e3ddc32f6..cc83b5e10 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Self, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, Self, TypeVar, cast from pypika_tortoise import SqlContext @@ -11,8 +11,16 @@ if TYPE_CHECKING: from tortoise import Model -T_out = TypeVar("T_out") -FieldEncoder = Callable[[Any, "Model"], T_out] | Callable[[Any, "Model", Field | None], T_out] +T_out = TypeVar("T_out", covariant=True) + + +class FieldEncoder(Protocol[T_out]): + def __call__( + self, + value: Any, + model: type[Model] | None, + field: Field | None = None, + ) -> T_out: ... @dataclass(frozen=True) @@ -53,7 +61,7 @@ class Parameter: def __init__(self, name: str) -> None: self.name = name - self.model: Model | None = None + self.model: type[Model] | None = None self.value_encoder: FieldEncoder[Any] | None = None self.field_object: Field | None = None self.encode: Callable[[Any], Any] | None = None @@ -87,7 +95,7 @@ def encode_value(self, value: Any) -> Any: else: encoded = self.value_encoder(value, self.model) elif self.field_object is not None: - encoded = self.field_object.to_db_value(value, self.model) + encoded = self.field_object.to_db_value(value, cast(type[Model], self.model)) if self.encode: encoded = self.encode(encoded) @@ -116,6 +124,9 @@ def from_simple_param(cls, param: Parameter) -> Self: return new_param def encode_collection(self, value: Any) -> Sequence[Any]: + if self.collection_encoder is None: + return value # TODO: probably raise exception + if self.field_object is not None: return self.collection_encoder(value, self.model, self.field_object) else: @@ -123,7 +134,7 @@ def encode_collection(self, value: Any) -> Sequence[Any]: def get_sql(self, ctx: SqlContext) -> str: param = self - if isinstance(ctx, TortoiseSqlContext): + if isinstance(ctx, TortoiseSqlContext) and ctx.dynamic_params is not None: param = ctx.dynamic_params.get(self.name, self) if param.collection_size is None: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 61f0e65f2..b81ee8a5b 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -642,7 +642,7 @@ def group_by(self, *fields: str) -> QuerySet[MODEL]: queryset._group_bys = fields return queryset - def _get_fields_list_for_select(self, *fields_: str) -> list[str]: + def _get_fields_list_for_select(self, *fields_: str) -> tuple[str, ...] | list[str]: if self._fields_for_select: raise ValueError(".values_list() cannot be used with .only()") @@ -1275,15 +1275,15 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: from tortoise.queryset_prepared import PreparedQuerySet - queryset = self._clone(PreparedQuerySet) + queryset = cast(PreparedQuerySet[MODEL], self._clone(PreparedQuerySet)) queryset._cache_key = key queryset._prepared = False - queryset._sql_cache = None - queryset._dynamic_params = None - queryset._dynamic_params_names = None + queryset._sql_cache = {} + queryset._dynamic_params = {} + queryset._dynamic_params_names = [] queryset._db_for_write = self._select_for_update - return cast(PreparedQuerySet[MODEL], queryset) + return queryset class UpdateQuery(AwaitableQuery): diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index f1f736bdd..b64e5e6f7 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -91,13 +91,11 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: class _PreparedQueryMixin(AwaitableQuery, ABC): _cache_key: str _prepared: bool - _sql_cache: dict[str, CachedSql] | None - _dynamic_params: dict[str, CollectionParameter] | None - _dynamic_params_names: list[str] | None + _sql_cache: dict[str, CachedSql] + _dynamic_params: dict[str, CollectionParameter] + _dynamic_params_names: list[str] _db_for_write: bool - __slots__ = () - @abstractmethod def _clone(self) -> Self: ... @@ -113,13 +111,12 @@ def prepared(self) -> Self: queryset._choose_db_if_not_chosen(self._db_for_write) queryset._make_query() - queryset._sql_cache = {} _, params = queryset.query.get_parameterized_sql() + queryset._sql_cache = {} queryset._dynamic_params = { param.name: param for param in params if isinstance(param, CollectionParameter) } queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - queryset._prepared = True self.model._meta.query_cache[self._cache_key] = queryset @@ -145,7 +142,8 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy( - self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params + self.query.QUERY_CLS.SQL_CONTEXT, + dynamic_params=self._dynamic_params, ) sql, params_ = self.query.get_parameterized_sql(ctx) self._sql_cache[cache_key] = CachedSql(sql, params_) @@ -190,9 +188,9 @@ def __init__(self, model: type[MODEL], cache_key: str) -> None: super().__init__(model) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write = self._select_for_update self._custom_fields: list[str] | None = None @@ -223,7 +221,7 @@ async def execute(self, **params) -> list[MODEL]: db=self._db, prefetch_map=self._prefetch_map, prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, + select_related_idx=self._select_related_idx, # type: ignore ).execute_select( cached_query.sql, filled_params, @@ -522,9 +520,9 @@ def __init__( self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True def _clone(self) -> PreparedUpdateQuery: @@ -580,9 +578,9 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True def _clone(self) -> PreparedDeleteQuery: @@ -637,9 +635,9 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False def _clone(self) -> PreparedExistsQuery: @@ -699,9 +697,9 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False def _clone(self) -> PreparedCountQuery: @@ -783,9 +781,9 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False def _clone(self) -> PreparedValuesListQuery: @@ -867,9 +865,9 @@ def __init__( ) self._cache_key: str = cache_key self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] | None = None - self._dynamic_params: dict[str, CollectionParameter] | None = None - self._dynamic_params_names: list[str] | None = None + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False def _clone(self) -> PreparedValuesQuery: From 3c56e3fbae10f86a3f2c7b92e525e3cd9e568767 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 13:54:46 +0200 Subject: [PATCH 27/57] fix typing.Self import failing on python3.10 --- tortoise/parameter.py | 10 ++++++++-- tortoise/queryset_prepared.py | 8 +++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index cc83b5e10..dbbf2ddf7 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -1,13 +1,19 @@ from __future__ import annotations +import sys from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol, Self, TypeVar, cast +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from pypika_tortoise import SqlContext from tortoise.fields import Field +if sys.version_info >= (3, 11): # pragma: nocoverage + from typing import Self +else: + from typing_extensions import Self + if TYPE_CHECKING: from tortoise import Model @@ -95,7 +101,7 @@ def encode_value(self, value: Any) -> Any: else: encoded = self.value_encoder(value, self.model) elif self.field_object is not None: - encoded = self.field_object.to_db_value(value, cast(type[Model], self.model)) + encoded = self.field_object.to_db_value(value, cast(type["Model"], self.model)) if self.encode: encoded = self.encode(encoded) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index b64e5e6f7..54a2f70c8 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -1,10 +1,11 @@ from __future__ import annotations as _ import functools +import sys from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Iterable -from typing import Any, Concatenate, Literal, NoReturn, ParamSpec, Protocol, Self, TypeVar, cast +from typing import Any, Concatenate, Literal, NoReturn, ParamSpec, Protocol, TypeVar, cast from pypika_tortoise.terms import Term @@ -31,6 +32,11 @@ ValuesQuery, ) +if sys.version_info >= (3, 11): # pragma: nocoverage + from typing import Self +else: + from typing_extensions import Self + class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): def prepared(self) -> PreparedQuerySet[MODEL]: ... From 576d87115dce6b40e059330f7b8837200d5bfccb Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 14:23:14 +0200 Subject: [PATCH 28/57] dont use concat in startswith/endswith/contains/etc. --- tortoise/filters.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tortoise/filters.py b/tortoise/filters.py index 7a7e0b818..c6e0972a8 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -14,7 +14,6 @@ BasicCriterion, Criterion, Equality, - Function, Term, ValueWrapper, ) @@ -181,22 +180,17 @@ def _format_str_or_parameter( like_end: bool = False, escape_func: Callable[[Any], str] = escape_like, ) -> Term: + like_at_start = "%" if like_start else "" + like_at_end = "%" if like_end else "" + if isinstance(value, Parameter): value.encode = escape_func wrapped = ValueWrapper(value) - if not like_start and not like_end: - return wrapped - args: list[str | ValueWrapper] = [] - if like_start: - args.append("%") - args.append(wrapped) - if like_end: - args.append("%") - return Function("Concat", *args) + if like_start or like_end: + value.encode = lambda val: f"{like_at_start}{escape_func(val)}{like_at_end}" + return wrapped else: - return field.wrap_constant( - f"{'%' if like_start else ''}{escape_func(value)}{'%' if like_end else ''}" - ) + return field.wrap_constant(f"{like_at_start}{escape_func(value)}{like_at_end}") def starts_with(field: Term, value: str | Parameter) -> Criterion: From bed1c7ec2bd834bd2b9b11e98e0f2d676d86c2ef Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 14:53:33 +0200 Subject: [PATCH 29/57] return correct parameter placeholders in CollectionParameter.get_sql --- tortoise/parameter.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index dbbf2ddf7..a3cd38288 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -144,14 +144,17 @@ def get_sql(self, ctx: SqlContext) -> str: param = ctx.dynamic_params.get(self.name, self) if param.collection_size is None: - if ctx.parameterizer is not None: - ctx.parameterizer.create_param(param) - return "?" + if ctx.parameterizer is None: + raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") + return ctx.parameterizer.create_param(param).get_sql(ctx) else: - if ctx.parameterizer is not None: - for idx in range(param.collection_size): - new_param = param.clone() - new_param.collection_encoder = new_param.value_encoder - new_param.value_encoder = None - ctx.parameterizer.create_param(new_param) - return f"({','.join(['?' for _ in range(param.collection_size)])})" + if ctx.parameterizer is None: + raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") + placeholders = [] + for idx in range(param.collection_size): + new_param = param.clone() + new_param.collection_encoder = new_param.value_encoder + new_param.value_encoder = None + pypika_param = ctx.parameterizer.create_param(new_param) + placeholders.append(pypika_param.get_sql(ctx)) + return f"({','.join(placeholders)})" From 4d0ae8901abf142a8e0abce4656008fa689c7c42 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 22 Feb 2026 16:14:04 +0200 Subject: [PATCH 30/57] add Parameter support in mysql string filters --- tests/test_queryset_prepared.py | 3 ++ tortoise/backends/mysql/executor.py | 64 +++++++++++++++++++++++++---- tortoise/parameter.py | 7 ++-- 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 157ad8fae..29fe81c18 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -63,6 +63,9 @@ async def test_startswith_filter(self): .prepared() ) + # print(Author.filter(name__startswith="asd").sql()) + # print(prepared.sql()) + for test_name in (author2.pk, author1.name, author3.name, "asd"): expected = await Author.filter(name__startswith=test_name) actual = await prepared.execute(name=test_name) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 6a19be073..373f0f97e 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,4 +1,6 @@ import enum +from collections.abc import Callable +from typing import Any from pypika_tortoise import SqlContext, functions from pypika_tortoise.enums import SqlTypes @@ -32,6 +34,7 @@ search, starts_with, ) +from tortoise.parameter import Parameter class MySQLRegexpComparators(enum.Enum): @@ -49,10 +52,47 @@ def get_value_sql(self, ctx: SqlContext) -> str: return format_quotes(value, quote_char) +# TODO: maybe there is better way to do this? +class StrParamWrapper(ValueWrapper): + value: Parameter + + def get_value_sql(self, ctx: SqlContext) -> str: + quote_char = ctx.secondary_quote_char or "" + real_encoder = self.value.encode + + def _encoder(val: str) -> str: + if real_encoder is not None: + val = real_encoder(val) + val = val.replace(quote_char, quote_char * 2) + return format_quotes(val, quote_char) + + new_param = self.value.clone() + new_param.encode = _encoder + return new_param # type: ignore[return-value] + + def escape_like(val: str) -> str: return val.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") +def _format_str_or_parameter( + value: str | Parameter, + like_start: bool = False, + like_end: bool = False, + escape_func: Callable[[Any], str] = escape_like, +) -> Term: + like_at_start = "%" if like_start else "" + like_at_end = "%" if like_end else "" + + if isinstance(value, Parameter): + value.encode = escape_func + if like_start or like_end: + value.encode = lambda val: f"{like_at_start}{escape_func(val)}{like_at_end}" + return StrParamWrapper(value) + else: + return StrWrapper(f"{like_at_start}{escape_func(value)}{like_at_end}") + + def mysql_contains(field: Term, value: str) -> Criterion: return Like( functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"%{escape_like(value)}%"), escape="" @@ -61,24 +101,30 @@ def mysql_contains(field: Term, value: str) -> Criterion: def mysql_starts_with(field: Term, value: str) -> Criterion: return Like( - functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"{escape_like(value)}%"), escape="" + functions.Cast(field, SqlTypes.CHAR), + _format_str_or_parameter(value, False, True, escape_like), + escape="", ) def mysql_ends_with(field: Term, value: str) -> Criterion: return Like( - functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"%{escape_like(value)}"), escape="" + functions.Cast(field, SqlTypes.CHAR), + _format_str_or_parameter(value, True, False, escape_like), + escape="", ) def mysql_insensitive_exact(field: Term, value: str) -> Criterion: - return functions.Upper(functions.Cast(field, SqlTypes.CHAR)).eq(functions.Upper(str(value))) + return functions.Upper(functions.Cast(field, SqlTypes.CHAR)).eq( + functions.Upper(_format_str_or_parameter(value, escape_func=str)) + ) def mysql_insensitive_contains(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"%{escape_like(value)}%")), + functions.Upper(_format_str_or_parameter(value, True, True, escape_like)), escape="", ) @@ -86,7 +132,7 @@ def mysql_insensitive_contains(field: Term, value: str) -> Criterion: def mysql_insensitive_starts_with(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"{escape_like(value)}%")), + functions.Upper(_format_str_or_parameter(value, False, True, escape_like)), escape="", ) @@ -94,18 +140,20 @@ def mysql_insensitive_starts_with(field: Term, value: str) -> Criterion: def mysql_insensitive_ends_with(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"%{escape_like(value)}")), + functions.Upper(_format_str_or_parameter(value, True, False, escape_like)), escape="", ) def mysql_search(field: Term, value: str) -> SearchCriterion: - return SearchCriterion(field, expr=StrWrapper(value)) + return SearchCriterion(field, expr=_format_str_or_parameter(value, escape_func=lambda x: x)) def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: return BasicCriterion( - MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.CHAR)), StrWrapper(value) + MySQLRegexpComparators.REGEXP, + Coalesce(Cast(field, SqlTypes.CHAR)), + _format_str_or_parameter(value, escape_func=lambda x: x), ) diff --git a/tortoise/parameter.py b/tortoise/parameter.py index a3cd38288..405460123 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -139,17 +139,16 @@ def encode_collection(self, value: Any) -> Sequence[Any]: return self.collection_encoder(value, self.model) def get_sql(self, ctx: SqlContext) -> str: + if ctx.parameterizer is None: + raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") + param = self if isinstance(ctx, TortoiseSqlContext) and ctx.dynamic_params is not None: param = ctx.dynamic_params.get(self.name, self) if param.collection_size is None: - if ctx.parameterizer is None: - raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") return ctx.parameterizer.create_param(param).get_sql(ctx) else: - if ctx.parameterizer is None: - raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") placeholders = [] for idx in range(param.collection_size): new_param = param.clone() From 88870e6841fd873a7c4eedfa14ae2a5850578b9c Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 23 Feb 2026 18:18:49 +0200 Subject: [PATCH 31/57] fix tests after upgrading to upstream --- tests/test_queryset_prepared.py | 660 ++++++++++++++++---------------- tortoise/queryset_prepared.py | 5 +- 2 files changed, 338 insertions(+), 327 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 29fe81c18..9834941ea 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -1,374 +1,382 @@ -from tests.testmodels import ( - Author, - Book, -) -from tortoise.contrib import test +import pytest + +from tests.testmodels import Author, Book from tortoise.exceptions import ParamsError, ValidationError from tortoise.expressions import Q, Subquery from tortoise.parameter import Parameter -class TestQuerysetPrepared(test.TestCase): - def test_prepared_queryset_always_same(self): - cache_key = "test_prepared_queryset_always_same" - prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - assert Author.prepare_sql(cache_key) is prepared +def test_prepared_queryset_always_same(db): + cache_key = "test_prepared_queryset_always_same" + prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() + assert Author.prepare_sql(cache_key) is prepared - def test_disallow_filtering_on_prepared_queryset(self): - cache_key = "test_disallow_filtering_on_prepared_queryset" - prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - with self.assertRaises(ValueError): - prepared.filter(id=1) +def test_disallow_filtering_on_prepared_queryset(db): + cache_key = "test_disallow_filtering_on_prepared_queryset" + prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - async def test_gte_filter(self): - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + with pytest.raises(ValueError): + prepared.filter(id=1) - expected = await Author.filter(id__gte=author2.pk).order_by("id") - prepared = ( - Author.prepare_sql("test_gte_filter") - .filter(id__gte=Parameter("idgte")) - .order_by("id") - .prepared() - ) - actual = await prepared.execute(idgte=author2.pk) - self.assertEqual(len(actual), 2) - self.assertEqual(actual[0].id, author2.pk) - self.assertEqual(actual[1].id, author3.pk) - self.assertEqual(expected, actual) - - async def test_string_param(self): - await Author.create(name="1") - author2 = await Author.create(name="2") - await Author.create(name="3") - - expected = await Author.filter(name=author2.name) - - prepared = Author.prepare_sql("test_string_param").filter(name=Parameter("name")).prepared() - actual = await prepared.execute(name=author2.name) - self.assertEqual(len(actual), 1) - self.assertEqual(actual[0].id, author2.pk) - self.assertEqual(expected, actual) - - async def test_startswith_filter(self): - author1 = await Author.create(name="test") - author2 = await Author.create(name="testqwe") - author3 = await Author.create(name="qwetest") - - prepared = ( - Author.prepare_sql("test_startswith_filter") - .filter(name__startswith=Parameter("name")) - .prepared() - ) +@pytest.mark.asyncio +async def test_gte_filter(db): + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") - # print(Author.filter(name__startswith="asd").sql()) - # print(prepared.sql()) - - for test_name in (author2.pk, author1.name, author3.name, "asd"): - expected = await Author.filter(name__startswith=test_name) - actual = await prepared.execute(name=test_name) - self.assertEqual(expected, actual) - - async def test_in_filter(self): - author1 = await Author.create(name="test") - author2 = await Author.create(name="testqwe") - author3 = await Author.create(name="qwetest") - - prepared = Author.prepare_sql("test_in_filter").filter(id__in=Parameter("ids")).prepared() - - for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): - expected = await Author.filter(id__in=test_ids) - actual = await prepared.execute(ids=test_ids) - self.assertEqual(expected, actual) - - async def test_subqueries(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") - - prepared = ( - Author.prepare_sql("test_subqueries") - .filter( - id__in=Subquery( - Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") - ) - ) - .prepared() - ) + expected = await Author.filter(id__gte=author2.pk).order_by("id") - for id1, id2 in ( - (author2.pk, author1.pk), - (author3.pk, author3.pk * 2), - ): - expected = await Author.filter( - id__in=Subquery(Author.filter(Q(id=id1) | Q(id=id2)).values("id")) - ) - actual = await prepared.execute(id1=id1, id2=id2) - self.assertEqual(expected, actual) - - async def test_subqueries_in_filter(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") - - prepared = ( - Author.prepare_sql("test_subqueries_in_filter") - .filter(id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id"))) - .prepared() - ) + prepared = ( + Author.prepare_sql("test_gte_filter") + .filter(id__gte=Parameter("idgte")) + .order_by("id") + .prepared() + ) + actual = await prepared.execute(idgte=author2.pk) + assert len(actual) == 2 + assert actual[0].id == author2.pk + assert actual[1].id == author3.pk + assert expected == actual - for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): - expected = await Author.filter( - id__in=Subquery(Author.filter(id__in=test_ids).values("id")) - ) - actual = await prepared.execute(ids=test_ids) - self.assertEqual(expected, actual) - - async def test_update(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - - original_name1 = author1.name - original_name2 = author2.name - new_name1 = f"{author1.name}_test" - - prepared = ( - Author.prepare_sql("test_update") - .filter(id=Parameter("search_id")) - .update(name=Parameter("replace_name")) - .prepared() - ) - await prepared.execute(search_id=author1.pk, replace_name=new_name1) - await author1.refresh_from_db(["name"]) - await author2.refresh_from_db(["name"]) - self.assertEqual(author1.name, new_name1) - self.assertEqual(author2.name, original_name2) - - await prepared.execute(search_id=author1.pk, replace_name=original_name1) - await author1.refresh_from_db(["name"]) - self.assertEqual(author1.name, original_name1) - - async def test_delete(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") - - prepared = ( - Author.prepare_sql("test_delete") - .filter( - id__in=Parameter("ids"), - ) - .delete() - .prepared() - ) +@pytest.mark.asyncio +async def test_string_param(db): + await Author.create(name="1") + author2 = await Author.create(name="2") + await Author.create(name="3") - affected = await prepared.execute(ids=[author1.pk]) - self.assertEqual(affected, 1) - self.assertEqual(await Author.all().count(), 2) - existing = await Author.all().values_list("id", flat=True) - self.assertEqual(set(existing), {author2.pk, author3.pk}) + expected = await Author.filter(name=author2.name) - async def test_exists(self): - author = await Author.create(name="1") + prepared = Author.prepare_sql("test_string_param").filter(name=Parameter("name")).prepared() + actual = await prepared.execute(name=author2.name) + assert len(actual) == 1 + assert actual[0].id == author2.pk + assert expected == actual - prepared = ( - Author.prepare_sql("test_exists") - .filter( - id__in=Parameter("ids"), - ) - .exists() - .prepared() - ) - self.assertTrue(await prepared.execute(ids=[author.pk])) - self.assertFalse(await prepared.execute(ids=[author.pk * 2])) +@pytest.mark.asyncio +async def test_startswith_filter(db): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") - async def test_count(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") - author3 = await Author.create(name="3") + prepared = ( + Author.prepare_sql("test_startswith_filter") + .filter(name__startswith=Parameter("name")) + .prepared() + ) - prepared = ( - Author.prepare_sql("test_count") - .filter( - id__gte=Parameter("idgte"), - ) - .count() - .prepared() - ) + # print(Author.filter(name__startswith="asd").sql()) + # print(prepared.sql()) - self.assertEqual(await prepared.execute(idgte=author1.pk), 3) - self.assertEqual(await prepared.execute(idgte=author2.pk), 2) - self.assertEqual(await prepared.execute(idgte=author3.pk), 1) - self.assertEqual(await prepared.execute(idgte=author3.pk * 2), 0) - - async def test_parameter_in_limit(self): - await Author.bulk_create( - [ - Author(name="1"), - Author(name="2"), - Author(name="3"), - ] - ) + for test_name in (author2.pk, author1.name, author3.name, "asd"): + expected = await Author.filter(name__startswith=test_name) + actual = await prepared.execute(name=test_name) + assert expected == actual - prepared = ( - Author.prepare_sql("test_parameter_in_limit") - .all() - .limit(Parameter("lim")) - .order_by("id") - .prepared() - ) - self.assertEqual(len(await prepared.execute(lim=1)), 1) - self.assertEqual(len(await prepared.execute(lim=2)), 2) - self.assertEqual(len(await prepared.execute(lim=3)), 3) - self.assertEqual(len(await prepared.execute(lim=4)), 3) - - with self.assertRaises(ParamsError): - await prepared.execute(lim=-1) - - async def test_parameter_in_offset(self): - await Author.bulk_create( - [ - Author(name="1"), - Author(name="2"), - Author(name="3"), - ] - ) +@pytest.mark.asyncio +async def test_in_filter(db): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") - prepared = ( - Author.prepare_sql("test_parameter_in_offset") - .all() - .offset(Parameter("off")) - .order_by("id") - .prepared() - ) + prepared = Author.prepare_sql("test_in_filter").filter(id__in=Parameter("ids")).prepared() - self.assertEqual(len(await prepared.execute(off=1)), 2) - self.assertEqual(len(await prepared.execute(off=2)), 1) - self.assertEqual(len(await prepared.execute(off=3)), 0) - self.assertEqual(len(await prepared.execute(off=4)), 0) + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): + expected = await Author.filter(id__in=test_ids) + actual = await prepared.execute(ids=test_ids) + assert expected == actual - with self.assertRaises(ParamsError): - await prepared.execute(off=-1) - async def test_values(self): - author = await Author.create(name="1") +@pytest.mark.asyncio +async def test_subqueries(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") - prepared = ( - Author.prepare_sql("test_values") - .filter( - id=Parameter("id"), + prepared = ( + Author.prepare_sql("test_subqueries") + .filter( + id__in=Subquery( + Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") ) - .values() - .prepared() ) - - self.assertEqual( - await prepared.execute(id=author.pk), - [{"id": author.pk, "name": author.name}], + .prepared() + ) + + for id1, id2 in ( + (author2.pk, author1.pk), + (author3.pk, author3.pk * 2), + ): + expected = await Author.filter( + id__in=Subquery(Author.filter(Q(id=id1) | Q(id=id2)).values("id")) ) - self.assertEqual( - await prepared.execute(id=author.pk * 2), - [], + actual = await prepared.execute(id1=id1, id2=id2) + assert expected == actual + + +@pytest.mark.asyncio +async def test_subqueries_in_filter(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = ( + Author.prepare_sql("test_subqueries_in_filter") + .filter(id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id"))) + .prepared() + ) + + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): + expected = await Author.filter(id__in=Subquery(Author.filter(id__in=test_ids).values("id"))) + actual = await prepared.execute(ids=test_ids) + assert expected == actual + + +@pytest.mark.asyncio +async def test_update(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + original_name1 = author1.name + original_name2 = author2.name + new_name1 = f"{author1.name}_test" + + prepared = ( + Author.prepare_sql("test_update") + .filter(id=Parameter("search_id")) + .update(name=Parameter("replace_name")) + .prepared() + ) + + await prepared.execute(search_id=author1.pk, replace_name=new_name1) + await author1.refresh_from_db(["name"]) + await author2.refresh_from_db(["name"]) + assert author1.name == new_name1 + assert author2.name == original_name2 + + await prepared.execute(search_id=author1.pk, replace_name=original_name1) + await author1.refresh_from_db(["name"]) + assert author1.name == original_name1 + + +@pytest.mark.asyncio +async def test_delete(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = ( + Author.prepare_sql("test_delete") + .filter( + id__in=Parameter("ids"), ) + .delete() + .prepared() + ) - async def test_values_list_all_fields(self): - author = await Author.create(name="1") + affected = await prepared.execute(ids=[author1.pk]) + assert affected == 1 + assert await Author.all().count() == 2 + existing = await Author.all().values_list("id", flat=True) + assert set(existing) == {author2.pk, author3.pk} - prepared_all = ( - Author.prepare_sql("test_values_list_all_fields") - .filter( - id=Parameter("id"), - ) - .values_list() - .prepared() - ) - self.assertEqual( - await prepared_all.execute(id=author.pk), - [(author.pk, author.name)], - ) - self.assertEqual( - await prepared_all.execute(id=author.pk * 2), - [], - ) - async def test_values_list_only_id_field(self): - author = await Author.create(name="1") +@pytest.mark.asyncio +async def test_exists(db): + author = await Author.create(name="1") - prepared_ids = ( - Author.prepare_sql("test_values_list_only_id_field") - .filter( - id=Parameter("id"), - ) - .values_list("id") - .prepared() - ) - self.assertEqual( - await prepared_ids.execute(id=author.pk), - [(author.pk,)], - ) - self.assertEqual( - await prepared_ids.execute(id=author.pk * 2), - [], + prepared = ( + Author.prepare_sql("test_exists") + .filter( + id__in=Parameter("ids"), ) + .exists() + .prepared() + ) - async def test_values_list_only_id_field_flat(self): - author = await Author.create(name="1") + assert await prepared.execute(ids=[author.pk]) + assert not await prepared.execute(ids=[author.pk * 2]) - prepared_ids_flat = ( - Author.prepare_sql("test_values_list_only_id_field_flat") - .filter( - id=Parameter("id"), - ) - .values_list("id", flat=True) - .prepared() - ) - self.assertEqual( - await prepared_ids_flat.execute(id=author.pk), - [author.pk], + +@pytest.mark.asyncio +async def test_count(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = ( + Author.prepare_sql("test_count") + .filter( + id__gte=Parameter("idgte"), ) - self.assertEqual( - await prepared_ids_flat.execute(id=author.pk * 2), - [], + .count() + .prepared() + ) + + assert await prepared.execute(idgte=author1.pk) == 3 + assert await prepared.execute(idgte=author2.pk) == 2 + assert await prepared.execute(idgte=author3.pk) == 1 + assert await prepared.execute(idgte=author3.pk * 2) == 0 + + +@pytest.mark.asyncio +async def test_parameter_in_limit(db): + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) + + prepared = ( + Author.prepare_sql("test_parameter_in_limit") + .all() + .limit(Parameter("lim")) + .order_by("id") + .prepared() + ) + + assert len(await prepared.execute(lim=1)) == 1 + assert len(await prepared.execute(lim=2)) == 2 + assert len(await prepared.execute(lim=3)) == 3 + assert len(await prepared.execute(lim=4)) == 3 + + with pytest.raises(ParamsError): + await prepared.execute(lim=-1) + + +@pytest.mark.asyncio +async def test_parameter_in_offset(db): + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) + + prepared = ( + Author.prepare_sql("test_parameter_in_offset") + .all() + .offset(Parameter("off")) + .order_by("id") + .prepared() + ) + + assert len(await prepared.execute(off=1)) == 2 + assert len(await prepared.execute(off=2)) == 1 + assert len(await prepared.execute(off=3)) == 0 + assert len(await prepared.execute(off=4)) == 0 + + with pytest.raises(ParamsError): + await prepared.execute(off=-1) + + +@pytest.mark.asyncio +async def test_values(db): + author = await Author.create(name="1") + + prepared = ( + Author.prepare_sql("test_values") + .filter( + id=Parameter("id"), ) + .values() + .prepared() + ) - async def test_update_fk(self): - author1 = await Author.create(name="1") - author2 = await Author.create(name="2") + assert await prepared.execute(id=author.pk) == [{"id": author.pk, "name": author.name}] + assert await prepared.execute(id=author.pk * 2) == [] - book = await Book.create(name="test", author=author1, rating=5) - prepared = ( - Book.prepare_sql("test_update_fk") - .filter(id=Parameter("search_id")) - .update(author=Parameter("replace_author")) - .prepared() +@pytest.mark.asyncio +async def test_values_list_all_fields(db): + author = await Author.create(name="1") + + prepared_all = ( + Author.prepare_sql("test_values_list_all_fields") + .filter( + id=Parameter("id"), ) + .values_list() + .prepared() + ) + assert await prepared_all.execute(id=author.pk) == [(author.pk, author.name)] + assert await prepared_all.execute(id=author.pk * 2) == [] + - await prepared.execute(search_id=book.pk, replace_author=author2) - book = await Book.get(id=book.pk).select_related("author") - # await book.refresh_from_db(["author_id"]) - self.assertEqual(book.author, author2) - - await prepared.execute(search_id=book.pk, replace_author=author1) - book = await Book.get(id=book.pk).select_related("author") - # await book.refresh_from_db(["author_id"]) - self.assertEqual(book.author, author1) - - async def test_update_pk_invalid_obj(self): - author = await Author.create(name="1") - book = await Book.create(name="test", author=author, rating=5) - - prepared = ( - Book.prepare_sql("test_update_pk_invalid_obj") - .filter(id=Parameter("search_id")) - .update(author=Parameter("replace_author")) - .prepared() +@pytest.mark.asyncio +async def test_values_list_only_id_field(db): + author = await Author.create(name="1") + + prepared_ids = ( + Author.prepare_sql("test_values_list_only_id_field") + .filter( + id=Parameter("id"), ) + .values_list("id") + .prepared() + ) + assert await prepared_ids.execute(id=author.pk) == [(author.pk,)] + assert await prepared_ids.execute(id=author.pk * 2) == [] + - with self.assertRaises(ValidationError): - await prepared.execute(search_id=book.pk, replace_author="not an Author object") +@pytest.mark.asyncio +async def test_values_list_only_id_field_flat(db): + author = await Author.create(name="1") + + prepared_ids_flat = ( + Author.prepare_sql("test_values_list_only_id_field_flat") + .filter( + id=Parameter("id"), + ) + .values_list("id", flat=True) + .prepared() + ) + assert await prepared_ids_flat.execute(id=author.pk) == [author.pk] + assert await prepared_ids_flat.execute(id=author.pk * 2) == [] + + +@pytest.mark.asyncio +async def test_update_fk(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + book = await Book.create(name="test", author=author1, rating=5) + + prepared = ( + Book.prepare_sql("test_update_fk") + .filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .prepared() + ) + + await prepared.execute(search_id=book.pk, replace_author=author2) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + assert book.author == author2 + + await prepared.execute(search_id=book.pk, replace_author=author1) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + assert book.author == author1 + + +@pytest.mark.asyncio +async def test_update_pk_invalid_obj(db): + author = await Author.create(name="1") + book = await Book.create(name="test", author=author, rating=5) + + prepared = ( + Book.prepare_sql("test_update_pk_invalid_obj") + .filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .prepared() + ) + + with pytest.raises(ValidationError): + await prepared.execute(search_id=book.pk, replace_author="not an Author object") diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 54a2f70c8..91eb37c46 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -17,6 +17,7 @@ from tortoise.query_utils import Prefetch from tortoise.queryset import ( MODEL, + PRIMARY_KEY, SINGLE, AwaitableQuery, BulkCreateQuery, @@ -445,7 +446,9 @@ def last(self) -> PreparedQuerySetSingle[MODEL | None]: def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) - async def in_bulk(self, id_list: Iterable[str | int], field_name: str) -> dict[str, MODEL]: + async def in_bulk( + self, id_list: Iterable[PRIMARY_KEY], field_name: str + ) -> dict[PRIMARY_KEY, MODEL]: raise NotImplementedError("Prepared queries don't support in_bulk.") def bulk_create( From 62fccf04d0e8c294b0487d492b0b38d4fd8dd619 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 24 Feb 2026 12:06:22 +0200 Subject: [PATCH 32/57] reset queryset prepared db when pulling query from cache --- tortoise/queryset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 5097b6e2a..a04e58efc 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1284,7 +1284,10 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: If query set is already in cache, return cached version with already generated sql. """ if key in self.model._meta.query_cache: - return self.model._meta.query_cache[key] + queryset = self.model._meta.query_cache[key]._clone() + queryset._db = None + queryset._db = queryset._choose_db(queryset._db_for_write) + return queryset from tortoise.queryset_prepared import PreparedQuerySet @@ -1295,6 +1298,8 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: queryset._dynamic_params = {} queryset._dynamic_params_names = [] queryset._db_for_write = self._select_for_update + queryset._db = None + queryset._db = queryset._choose_db(queryset._db_for_write) return queryset From 7f3440a40ca36b044142633325b7029deed017bb Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 24 Feb 2026 18:41:43 +0200 Subject: [PATCH 33/57] rewrite prepared query classes --- tests/test_queryset_prepared.py | 8 +- tortoise/models.py | 4 +- tortoise/queryset.py | 8 +- tortoise/queryset_prepared.py | 872 ++++++++++++++++++-------------- 4 files changed, 496 insertions(+), 396 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 9834941ea..43120c751 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -6,18 +6,16 @@ from tortoise.parameter import Parameter -def test_prepared_queryset_always_same(db): +def test_prepared_queryset_query_always_same(db): cache_key = "test_prepared_queryset_always_same" prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - assert Author.prepare_sql(cache_key) is prepared + assert Author.prepare_sql(cache_key).query is prepared.query def test_disallow_filtering_on_prepared_queryset(db): cache_key = "test_disallow_filtering_on_prepared_queryset" prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - - with pytest.raises(ValueError): - prepared.filter(id=1) + assert prepared is prepared.filter(id=1) @pytest.mark.asyncio diff --git a/tortoise/models.py b/tortoise/models.py index 74022ea2c..2e5159ef3 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -50,7 +50,7 @@ QuerySetSingle, RawSQLQuery, ) -from tortoise.queryset_prepared import PreparedQuerySet +from tortoise.queryset_prepared import PreparingQuerySet from tortoise.router import router from tortoise.signals import Signals from tortoise.transactions import in_transaction @@ -1614,7 +1614,7 @@ async def fetch_for_list( await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) @classmethod - def prepare_sql(cls, key: str) -> PreparedQuerySet[MODEL]: + def prepare_sql(cls, key: str) -> PreparingQuerySet[MODEL]: return cls._meta.manager.get_queryset().prepare_sql(key) @classmethod diff --git a/tortoise/queryset.py b/tortoise/queryset.py index a04e58efc..82b7d63d0 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model - from tortoise.queryset_prepared import PreparedQuerySet + from tortoise.queryset_prepared import PreparingQuerySet MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -1278,7 +1278,7 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list - def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: + def prepare_sql(self, key: str) -> PreparingQuerySet[MODEL]: """ Cache generated sql of this query set. If query set is already in cache, return cached version with already generated sql. @@ -1289,9 +1289,9 @@ def prepare_sql(self, key: str) -> PreparedQuerySet[MODEL]: queryset._db = queryset._choose_db(queryset._db_for_write) return queryset - from tortoise.queryset_prepared import PreparedQuerySet + from tortoise.queryset_prepared import PreparingQuerySet - queryset = cast(PreparedQuerySet[MODEL], self._clone(PreparedQuerySet)) + queryset = cast(PreparingQuerySet[MODEL], self._clone(PreparingQuerySet)) queryset._cache_key = key queryset._prepared = False queryset._sql_cache = {} diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 91eb37c46..5b63c27d5 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -1,18 +1,17 @@ from __future__ import annotations as _ -import functools import sys from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Callable, Iterable -from typing import Any, Concatenate, Literal, NoReturn, ParamSpec, Protocol, TypeVar, cast +from collections.abc import Iterable +from typing import Any, Literal, NoReturn, Protocol, cast +from pypika_tortoise.queries import QueryBuilder from pypika_tortoise.terms import Term from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError from tortoise.expressions import Expression, Q -from tortoise.filters import FilterInfoDict from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext from tortoise.query_utils import Prefetch from tortoise.queryset import ( @@ -95,40 +94,51 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: return filled_params -class _PreparedQueryMixin(AwaitableQuery, ABC): +class _PreparingQueryMixin(AwaitableQuery[MODEL], ABC): _cache_key: str - _prepared: bool - _sql_cache: dict[str, CachedSql] - _dynamic_params: dict[str, CollectionParameter] - _dynamic_params_names: list[str] _db_for_write: bool @abstractmethod - def _clone(self) -> Self: ... + def _clone(self, new_cls: type[AwaitableQuery[MODEL]] | None = None) -> AwaitableQuery[MODEL]: ... - def prepared(self) -> Self: - if self._cache_key is None: - raise ValueError("QuerySet.prepare_sql() must be called before QuerySet.prepared()") + @abstractmethod + def _clone_prepared(self) -> _PreparedQueryMixin[MODEL]: + ... + def prepared(self) -> Self: if self._cache_key in self.model._meta.query_cache: return self.model._meta.query_cache[self._cache_key] - queryset = self._clone() + queryset = self._clone_prepared() queryset._choose_db_if_not_chosen(self._db_for_write) queryset._make_query() + queryset._init_prepared() - _, params = queryset.query.get_parameterized_sql() - queryset._sql_cache = {} - queryset._dynamic_params = { - param.name: param for param in params if isinstance(param, CollectionParameter) - } - queryset._dynamic_params_names = sorted(queryset._dynamic_params.keys()) - queryset._prepared = True + return queryset - self.model._meta.query_cache[self._cache_key] = queryset - return queryset +class _PreparedQueryMixin(AwaitableQuery[MODEL], ABC): + _cache_key: str + _sql_cache: dict[str, CachedSql] + _dynamic_params: dict[str, CollectionParameter] + _dynamic_params_names: list[str] + _db_for_write: bool + + def prepare_sql(self, key: str) -> NoReturn: + raise NotImplementedError("QuerySets must be prepared only once") + + def prepared(self) -> PreparedQuerySet[MODEL]: + return self + + def _init_prepared(self) -> None: + _, params = self.query.get_parameterized_sql() + self._sql_cache = {} + self._dynamic_params = { + param.name: param for param in params if isinstance(param, CollectionParameter) + } + self._dynamic_params_names = sorted(self._dynamic_params.keys()) + self.model._meta.query_cache[self._cache_key] = self def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: reset_params = [] @@ -163,104 +173,162 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: @abstractmethod async def execute(self, **params) -> Any: ... + def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, self) + + def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, self) + + def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def offset(self, offset: int | Parameter) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def distinct(self) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def select_for_update( + self, + nowait: bool = False, + skip_locked: bool = False, + of: tuple[str, ...] = (), + no_key: bool = False, + ) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def values_list( + self, *fields_: str, flat: bool = False + ) -> PreparedValuesListQuery[Literal[False]]: + return cast(PreparedValuesListQuery, self) + + def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: + return cast(PreparedValuesQuery, self) + + def delete(self) -> DeleteQuery: + return cast(DeleteQuery, self) + + def update(self, **kwargs: Any) -> PreparedUpdateQuery: + return cast(PreparedUpdateQuery, self) + + def count(self) -> PreparedCountQuery: + return cast(PreparedCountQuery, self) + + def exists(self) -> PreparedExistsQuery: + return cast(PreparedExistsQuery, self) -P = ParamSpec("P") -T = TypeVar("T") + def all(self) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + def first(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, self) -def _disallow_queryset_methods_on_prepared_query( - func: Callable[Concatenate[PreparedQuerySet, P], T], -) -> Callable[Concatenate[PreparedQuerySet, P], T]: - @functools.wraps(func) - def decorated(self: PreparedQuerySet, *args: P.args, **kwargs: P.kwargs) -> T: - if self._prepared: - raise ValueError(f'Cannot call "{func.__name__}" on already prepared queryset.') - return func(self, *args, **kwargs) + def last(self) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, self) - return decorated + def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: + return cast(PreparedQuerySetSingle, self) + async def in_bulk( + self, id_list: Iterable[PRIMARY_KEY], field_name: str + ) -> dict[PRIMARY_KEY, MODEL]: + raise NotImplementedError("Prepared queries don't support in_bulk.") + + def bulk_create( + self, + objects: Iterable[MODEL], + batch_size: int | None = None, + ignore_conflicts: bool = False, + update_fields: Iterable[str] | None = None, + on_conflict: Iterable[str] | None = None, + ) -> BulkCreateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_create.") + + def bulk_update( + self, + objects: Iterable[MODEL], + fields: Iterable[str], + batch_size: int | None = None, + ) -> BulkUpdateQuery[MODEL]: + raise NotImplementedError("Prepared queries don't support bulk_update.") + + def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: + return cast(PreparedQuerySetSingle, self) -class PreparedQuerySet(QuerySet[MODEL], _PreparedQueryMixin): + def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: + return cast(PreparedQuerySet, self) + + +class PreparingQuerySet(QuerySet[MODEL], _PreparingQueryMixin): __slots__ = ( "_cache_key", - "_prepared", - "_custom_fields", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", "_db_for_write", ) def __init__(self, model: type[MODEL], cache_key: str) -> None: super().__init__(model) self._cache_key: str = cache_key - self._prepared: bool = False - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] self._db_for_write = self._select_for_update - self._custom_fields: list[str] | None = None - def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: + def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparingQuerySet[MODEL]: queryset = cast(Self, super()._clone(_new_cls)) queryset._cache_key = self._cache_key - queryset._prepared = self._prepared - queryset._sql_cache = self._sql_cache - queryset._dynamic_params = self._dynamic_params - queryset._dynamic_params_names = self._dynamic_params_names - queryset._db_for_write = self._db_for_write - return cast(PreparedQuerySet, queryset) + queryset._db_for_write = self._select_for_update + return cast(PreparingQuerySet, queryset) + + def _clone_prepared(self) -> _PreparedQueryMixin[MODEL]: + return self._clone(PreparedQuerySet) def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("QuerySets must only be prepared once") + raise NotImplementedError("QuerySets must be prepared only once") def prepared(self) -> PreparedQuerySet[MODEL]: queryset = cast(Self, super().prepared()) queryset._custom_fields = list(self._annotations.keys()) return queryset - async def execute(self, **params) -> list[MODEL]: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) + def filter(self, *args: Q, **kwargs: Any) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().filter(*args, **kwargs)) - instance_list = await self._db.executor_class( - model=self.model, - db=self._db, - prefetch_map=self._prefetch_map, - prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, # type: ignore - ).execute_select( - cached_query.sql, - filled_params, - custom_fields=self._custom_fields, - ) - if self._single: - if len(instance_list) == 1: - return instance_list[0] - if not instance_list: - if self._raise_does_not_exist: - raise DoesNotExist(self.model) - return None # type: ignore - raise MultipleObjectsReturned(self.model) - return instance_list - - @_disallow_queryset_methods_on_prepared_query - def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().filter(*args, **kwargs)) - - @_disallow_queryset_methods_on_prepared_query - def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().exclude(*args, **kwargs)) + def exclude(self, *args: Q, **kwargs: Any) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().exclude(*args, **kwargs)) - @_disallow_queryset_methods_on_prepared_query - def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().order_by(*orderings)) + def order_by(self, *orderings: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().order_by(*orderings)) - @_disallow_queryset_methods_on_prepared_query def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, super().latest(*orderings)) - @_disallow_queryset_methods_on_prepared_query def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, super().earliest(*orderings)) @@ -270,8 +338,7 @@ def _validate_limit(value: int) -> int: raise ParamsError("Limit should be non-negative number") return value - @_disallow_queryset_methods_on_prepared_query - def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: + def limit(self, limit: int | Parameter) -> PreparingQuerySet[MODEL]: if isinstance(limit, int) and limit < 0: raise ParamsError("Limit should be non-negative number") elif isinstance(limit, Parameter): @@ -287,8 +354,7 @@ def _validate_offset(value: int) -> int: raise ParamsError("Offset should be non-negative number") return value - @_disallow_queryset_methods_on_prepared_query - def offset(self, offset: int | Parameter) -> PreparedQuerySet[MODEL]: + def offset(self, offset: int | Parameter) -> PreparingQuerySet[MODEL]: if isinstance(offset, int) and offset < 0: raise ParamsError("Offset should be non-negative number") elif isinstance(offset, Parameter): @@ -300,39 +366,32 @@ def offset(self, offset: int | Parameter) -> PreparedQuerySet[MODEL]: queryset._limit = 1000000 return queryset - @_disallow_queryset_methods_on_prepared_query - def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().__getitem__(key)) + def __getitem__(self, key: slice) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().__getitem__(key)) - @_disallow_queryset_methods_on_prepared_query - def distinct(self) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().distinct()) + def distinct(self) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().distinct()) - @_disallow_queryset_methods_on_prepared_query def select_for_update( self, nowait: bool = False, skip_locked: bool = False, of: tuple[str, ...] = (), no_key: bool = False, - ) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().select_for_update(nowait, skip_locked, of, no_key)) + ) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().select_for_update(nowait, skip_locked, of, no_key)) - @_disallow_queryset_methods_on_prepared_query - def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().annotate(**kwargs)) + def annotate(self, **kwargs: Expression | Term) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().annotate(**kwargs)) - @_disallow_queryset_methods_on_prepared_query - def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().group_by(*fields)) + def group_by(self, *fields: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().group_by(*fields)) - @_disallow_queryset_methods_on_prepared_query def values_list( self, *fields_: str, flat: bool = False ) -> PreparedValuesListQuery[Literal[False]]: fields_for_select_list = self._get_fields_list_for_select(*fields_) - - return PreparedValuesListQuery( + query = ValuesListQuery( db=self._db, model=self.model, q_objects=self._q_objects, @@ -349,14 +408,27 @@ def values_list( group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedValuesListQuery( + db=self._db, + model=self.model, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + flat=flat, + fields_for_select_list=fields_for_select_list, + annotations=self._annotations, + query=query.query, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: fields_for_select = self._get_fields_for_select(*args, **kwargs) - - return PreparedValuesQuery( + query = ValuesQuery( db=self._db, model=self.model, q_objects=self._q_objects, @@ -372,12 +444,25 @@ def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False group_bys=self._group_bys, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedValuesQuery( + db=self._db, + model=self.model, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=fields_for_select, + annotations=self._annotations, + query=query.query, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query - def delete(self) -> DeleteQuery: - return PreparedDeleteQuery( + def delete(self) -> PreparedDeleteQuery: + query = DeleteQuery( model=self.model, db=self._db, q_objects=self._q_objects, @@ -385,12 +470,21 @@ def delete(self) -> DeleteQuery: custom_filters=self._custom_filters, limit=self._limit, orderings=self._orderings, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedDeleteQuery( + model=self.model, + db=self._db, + query=query.query, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query def update(self, **kwargs: Any) -> PreparedUpdateQuery: - return PreparedUpdateQuery( + query = UpdateQuery( model=self.model, update_kwargs=kwargs, db=self._db, @@ -399,12 +493,21 @@ def update(self, **kwargs: Any) -> PreparedUpdateQuery: custom_filters=self._custom_filters, limit=self._limit, orderings=self._orderings, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedUpdateQuery( + model=self.model, + db=self._db, + query=query.query, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query def count(self) -> PreparedCountQuery: - return PreparedCountQuery( + query = CountQuery( model=self.model, db=self._db, q_objects=self._q_objects, @@ -414,12 +517,23 @@ def count(self) -> PreparedCountQuery: offset=self._offset, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedCountQuery( + model=self.model, + db=self._db, + query=query.query, + limit=self._limit, + offset=self._offset, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query def exists(self) -> PreparedExistsQuery: - return PreparedExistsQuery( + query = ExistsQuery( model=self.model, db=self._db, q_objects=self._q_objects, @@ -427,22 +541,28 @@ def exists(self) -> PreparedExistsQuery: custom_filters=self._custom_filters, force_indexes=self._force_indexes, use_indexes=self._use_indexes, + ) + query._db = query._choose_db(True) + query._make_query() + + prepared = PreparedExistsQuery( + model=self.model, + db=self._db, + query=query.query, cache_key=self._cache_key, ) + prepared._init_prepared() + return prepared - @_disallow_queryset_methods_on_prepared_query - def all(self) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().all()) + def all(self) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().all()) - @_disallow_queryset_methods_on_prepared_query def first(self) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, super().first()) - @_disallow_queryset_methods_on_prepared_query def last(self) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, super().last()) - @_disallow_queryset_methods_on_prepared_query def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) @@ -469,35 +589,82 @@ def bulk_update( ) -> BulkUpdateQuery[MODEL]: raise NotImplementedError("Prepared queries don't support bulk_update.") - @_disallow_queryset_methods_on_prepared_query def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, super().get_or_none(*args, **kwargs)) - @_disallow_queryset_methods_on_prepared_query - def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().only(*fields_for_select)) + def only(self, *fields_for_select: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().only(*fields_for_select)) - @_disallow_queryset_methods_on_prepared_query - def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().select_related(*fields)) + def select_related(self, *fields: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().select_related(*fields)) - @_disallow_queryset_methods_on_prepared_query - def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().force_index(*index_names)) + def force_index(self, *index_names: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().force_index(*index_names)) - @_disallow_queryset_methods_on_prepared_query - def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().use_index(*index_names)) + def use_index(self, *index_names: str) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().use_index(*index_names)) - @_disallow_queryset_methods_on_prepared_query - def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: - return cast(PreparedQuerySet, super().prefetch_related(*args)) + def prefetch_related(self, *args: str | Prefetch) -> PreparingQuerySet[MODEL]: + return cast(PreparingQuerySet, super().prefetch_related(*args)) -class PreparedUpdateQuery(UpdateQuery, _PreparedQueryMixin): +class PreparedQuerySet(_PreparedQueryMixin, QuerySet[MODEL]): + __slots__ = ( + "_cache_key", + "_custom_fields", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + "_db_for_write", + ) + + def __init__(self, model: type[MODEL], cache_key: str) -> None: + super().__init__(model) + self._cache_key: str = cache_key + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] + self._db_for_write = self._select_for_update + self._custom_fields: list[str] | None = None + + def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: + queryset = cast(Self, super()._clone(_new_cls)) + queryset._cache_key = self._cache_key + queryset._sql_cache = self._sql_cache + queryset._dynamic_params = self._dynamic_params + queryset._dynamic_params_names = self._dynamic_params_names + queryset._db_for_write = self._db_for_write + return cast(PreparedQuerySet, queryset) + + async def execute(self, **params) -> list[MODEL]: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, # type: ignore + ).execute_select( + cached_query.sql, + filled_params, + custom_fields=self._custom_fields, + ) + if self._single: + if len(instance_list) == 1: + return instance_list[0] + if not instance_list: + if self._raise_does_not_exist: + raise DoesNotExist(self.model) + return None # type: ignore + raise MultipleObjectsReturned(self.model) + return instance_list + + +class PreparedUpdateQuery(_PreparedQueryMixin): __slots__ = ( "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -507,46 +674,31 @@ class PreparedUpdateQuery(UpdateQuery, _PreparedQueryMixin): def __init__( self, model: type[MODEL], - update_kwargs: dict[str, Any], db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], + query: QueryBuilder, cache_key: str, ) -> None: - super().__init__( - model, - update_kwargs, - db, - q_objects, - annotations, - custom_filters, - limit, - orderings, - ) + super().__init__(model) + self._db = db + self.query = query self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True - def _clone(self) -> PreparedUpdateQuery: - query = self.__class__( - model=self.model, - update_kwargs=self.update_kwargs, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self) -> PreparedUpdateQuery[MODEL]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> int: @@ -555,10 +707,9 @@ async def execute(self, **params) -> int: return (await self._db.execute_query(cached_query.sql, filled_params))[0] -class PreparedDeleteQuery(DeleteQuery, _PreparedQueryMixin): +class PreparedDeleteQuery(_PreparedQueryMixin): __slots__ = ( "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -566,44 +717,33 @@ class PreparedDeleteQuery(DeleteQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - orderings: list[tuple[str, str]], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + cache_key: str, ) -> None: - super().__init__( - model, - db, - q_objects, - annotations, - custom_filters, - limit, - orderings, - ) + super().__init__(model) + self._db = db + self.query = query + self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True - def _clone(self) -> PreparedDeleteQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self) -> PreparedDeleteQuery[MODEL]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> int: @@ -612,10 +752,9 @@ async def execute(self, **params) -> int: return (await self._db.execute_query(cached_query.sql, filled_params))[0] -class PreparedExistsQuery(ExistsQuery, _PreparedQueryMixin): +class PreparedExistsQuery(_PreparedQueryMixin): __slots__ = ( "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -623,44 +762,33 @@ class PreparedExistsQuery(ExistsQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + cache_key: str, ) -> None: - super().__init__( - model, - db, - q_objects, - annotations, - custom_filters, - force_indexes, - use_indexes, - ) + super().__init__(model) + self._db = db + self.query = query + self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedExistsQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self) -> PreparedExistsQuery[MODEL]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> int: @@ -670,10 +798,11 @@ async def execute(self, **params) -> int: return bool(result) -class PreparedCountQuery(CountQuery, _PreparedQueryMixin): +class PreparedCountQuery(_PreparedQueryMixin): __slots__ = ( + "_limit", + "_offset", "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -681,50 +810,39 @@ class PreparedCountQuery(CountQuery, _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - limit: int | None, - offset: int | None, - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + limit: int | None, + offset: int | None, + cache_key: str, ) -> None: - super().__init__( - model, - db, - q_objects, - annotations, - custom_filters, - limit, - offset, - force_indexes, - use_indexes, - ) + super().__init__(model) + self._db = db + self.query = query + self._limit = limit or 0 + self._offset = offset or 0 + self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedCountQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - offset=self._offset, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self) -> PreparedCountQuery[MODEL]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._limit = self._limit + query._offset = self._offset + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> int: @@ -743,7 +861,6 @@ async def execute(self, **params) -> int: class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): __slots__ = ( "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -751,71 +868,57 @@ class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - q_objects: list[Q], - single: bool, - raise_does_not_exist: bool, - fields_for_select_list: tuple[str, ...] | list[str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], - flat: bool, - annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + flat: bool, + annotations: dict[str, Any], + query: QueryBuilder, + cache_key: str, ) -> None: super().__init__( - model, - db, - q_objects, - single, - raise_does_not_exist, - fields_for_select_list, - limit, - offset, - distinct, - orderings, - flat, - annotations, - custom_filters, - group_bys, - force_indexes, - use_indexes, + model=model, + db=db, + q_objects=[], + single=single, + raise_does_not_exist=raise_does_not_exist, + fields_for_select_list=fields_for_select_list, + limit=None, + offset=None, + distinct=False, + orderings=[], + flat=flat, + annotations=annotations, + custom_filters={}, + group_bys=(), + force_indexes=set(), + use_indexes=set(), ) + self.query = query + self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedValuesListQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select_list=self._fields_for_select_list, - limit=self._limit, - offset=self._offset, - distinct=self._distinct, - orderings=self._orderings, - flat=self._flat, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self) -> PreparedValuesListQuery[MODEL]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + + # TODO: clone rest of the fields + + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> list[Any] | tuple: @@ -829,7 +932,6 @@ async def execute(self, **params) -> list[Any] | tuple: class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): __slots__ = ( "_cache_key", - "_prepared", "_sql_cache", "_dynamic_params", "_dynamic_params_names", @@ -840,65 +942,65 @@ def __init__( self, model: type[MODEL], db: BaseDBAsyncClient, - q_objects: list[Q], single: bool, raise_does_not_exist: bool, fields_for_select: dict[str, str], - limit: int | None, - offset: int | None, - distinct: bool, - orderings: list[tuple[str, str]], annotations: dict[str, Any], - custom_filters: dict[str, FilterInfoDict], - group_bys: tuple[str, ...], - force_indexes: set[str], - use_indexes: set[str], + query: QueryBuilder, cache_key: str, ) -> None: super().__init__( - model, - db, - q_objects, - single, - raise_does_not_exist, - fields_for_select, - limit, - offset, - distinct, - orderings, - annotations, - custom_filters, - group_bys, - force_indexes, - use_indexes, + model=model, + db=db, + q_objects=[], + single=single, + raise_does_not_exist=raise_does_not_exist, + fields_for_select=fields_for_select, + limit=None, + offset=None, + distinct=False, + orderings=[], + annotations=annotations, + custom_filters={}, + group_bys=(), + force_indexes=set(), + use_indexes=set(), ) + + self.query = query self._cache_key: str = cache_key - self._prepared: bool = False self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedValuesQuery: - query = self.__class__( - model=self.model, - db=self._db, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select=self._fields_for_select, - limit=self._limit, - offset=self._offset, - distinct=self._distinct, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - cache_key=self._cache_key, - ) - query._prepared = self._prepared + def _clone(self, _new_cls: type[ValuesQuery[MODEL]] | None = None) -> ValuesQuery[MODEL]: + if _new_cls is None: + _new_cls = self.__class__ + + # TODO: rewrite ._clone() + + query = _new_cls.__new__(_new_cls) + query.model = self.model + query._fields_for_select = self._fields_for_select + query._limit = self._limit + query._offset = self._offset + query._distinct = self._distinct + query._orderings = self._orderings + query._custom_filters = self._custom_filters + query._q_objects = self._q_objects + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._db = self._db + query._group_bys = self._group_bys + query._force_indexes = self._force_indexes + query._use_indexes = self._use_indexes + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + return query async def execute(self, **params) -> list[dict] | dict: From dc4b90c92c5bb52afa48470c9bc177c4f4d379d8 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 24 Feb 2026 20:23:34 +0200 Subject: [PATCH 34/57] clone all fields in PreparedValuesListQuery._clone and PreparedValuesQuery._clone; only store query and used query info in PreparedQuerySet; fix style and typing issues; --- tortoise/models.py | 6 +- tortoise/queryset.py | 26 ++- tortoise/queryset_prepared.py | 297 ++++++++++++++++++++-------------- 3 files changed, 189 insertions(+), 140 deletions(-) diff --git a/tortoise/models.py b/tortoise/models.py index 2e5159ef3..64ba87df9 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -50,7 +50,7 @@ QuerySetSingle, RawSQLQuery, ) -from tortoise.queryset_prepared import PreparingQuerySet +from tortoise.queryset_prepared import PreparingQuerySet, _PreparedQueryMixin from tortoise.router import router from tortoise.signals import Signals from tortoise.transactions import in_transaction @@ -257,7 +257,7 @@ def __init__(self, meta: Model.Meta) -> None: self.db_native_fields: list[tuple[str, str, Field]] = [] self.db_default_fields: list[tuple[str, str, Field]] = [] self.db_complex_fields: list[tuple[str, str, Field]] = [] - self.query_cache: dict[str, PreparedQuerySet] = {} + self.query_cache: dict[str, _PreparedQueryMixin] = {} @property def full_name(self) -> str: @@ -1614,7 +1614,7 @@ async def fetch_for_list( await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) @classmethod - def prepare_sql(cls, key: str) -> PreparingQuerySet[MODEL]: + def prepare_sql(cls, key: str) -> PreparingQuerySet[MODEL] | _PreparedQueryMixin: return cls._meta.manager.get_queryset().prepare_sql(key) @classmethod diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 82b7d63d0..b824d45ac 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model - from tortoise.queryset_prepared import PreparingQuerySet + from tortoise.queryset_prepared import PreparingQuerySet, _PreparedQueryMixin MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -1278,30 +1278,24 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list - def prepare_sql(self, key: str) -> PreparingQuerySet[MODEL]: + def prepare_sql(self, key: str) -> PreparingQuerySet[MODEL] | _PreparedQueryMixin: """ Cache generated sql of this query set. If query set is already in cache, return cached version with already generated sql. """ + if key in self.model._meta.query_cache: - queryset = self.model._meta.query_cache[key]._clone() - queryset._db = None - queryset._db = queryset._choose_db(queryset._db_for_write) - return queryset + prepared_queryset = self.model._meta.query_cache[key]._clone() + prepared_queryset._db = None # type: ignore + prepared_queryset._db = prepared_queryset._choose_db(prepared_queryset._db_for_write) + return prepared_queryset from tortoise.queryset_prepared import PreparingQuerySet - queryset = cast(PreparingQuerySet[MODEL], self._clone(PreparingQuerySet)) - queryset._cache_key = key - queryset._prepared = False - queryset._sql_cache = {} - queryset._dynamic_params = {} - queryset._dynamic_params_names = [] - queryset._db_for_write = self._select_for_update - queryset._db = None - queryset._db = queryset._choose_db(queryset._db_for_write) + preparing_queryset = cast(PreparingQuerySet[MODEL], self._clone(PreparingQuerySet)) + preparing_queryset._cache_key = key - return queryset + return preparing_queryset class UpdateQuery(AwaitableQuery): diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 5b63c27d5..50b7020bc 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -4,9 +4,9 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable -from typing import Any, Literal, NoReturn, Protocol, cast +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, cast -from pypika_tortoise.queries import QueryBuilder +from pypika_tortoise.queries import QueryBuilder, Table from pypika_tortoise.terms import Term from tortoise.backends.base.client import BaseDBAsyncClient @@ -37,9 +37,12 @@ else: from typing_extensions import Self +if TYPE_CHECKING: + from tortoise import Model + class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): - def prepared(self) -> PreparedQuerySet[MODEL]: ... + def prepared(self) -> PreparedQuerySet: ... async def execute(self, **params) -> list[MODEL]: ... @@ -94,30 +97,6 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: return filled_params -class _PreparingQueryMixin(AwaitableQuery[MODEL], ABC): - _cache_key: str - _db_for_write: bool - - @abstractmethod - def _clone(self, new_cls: type[AwaitableQuery[MODEL]] | None = None) -> AwaitableQuery[MODEL]: ... - - @abstractmethod - def _clone_prepared(self) -> _PreparedQueryMixin[MODEL]: - ... - - def prepared(self) -> Self: - if self._cache_key in self.model._meta.query_cache: - return self.model._meta.query_cache[self._cache_key] - - queryset = self._clone_prepared() - - queryset._choose_db_if_not_chosen(self._db_for_write) - queryset._make_query() - queryset._init_prepared() - - return queryset - - class _PreparedQueryMixin(AwaitableQuery[MODEL], ABC): _cache_key: str _sql_cache: dict[str, CachedSql] @@ -125,11 +104,14 @@ class _PreparedQueryMixin(AwaitableQuery[MODEL], ABC): _dynamic_params_names: list[str] _db_for_write: bool + @abstractmethod + def _clone(self) -> Self: ... + def prepare_sql(self, key: str) -> NoReturn: raise NotImplementedError("QuerySets must be prepared only once") - def prepared(self) -> PreparedQuerySet[MODEL]: - return self + def prepared(self) -> PreparedQuerySet: + return cast(PreparedQuerySet, self) def _init_prepared(self) -> None: _, params = self.query.get_parameterized_sql() @@ -173,13 +155,13 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: @abstractmethod async def execute(self, **params) -> Any: ... - def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet[MODEL]: + def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def order_by(self, *orderings: str) -> PreparedQuerySet[MODEL]: + def order_by(self, *orderings: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: @@ -188,16 +170,16 @@ def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) - def limit(self, limit: int | Parameter) -> PreparedQuerySet[MODEL]: + def limit(self, limit: int | Parameter) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def offset(self, offset: int | Parameter) -> PreparedQuerySet[MODEL]: + def offset(self, offset: int | Parameter) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def __getitem__(self, key: slice) -> PreparedQuerySet[MODEL]: + def __getitem__(self, key: slice) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def distinct(self) -> PreparedQuerySet[MODEL]: + def distinct(self) -> PreparedQuerySet: return cast(PreparedQuerySet, self) def select_for_update( @@ -206,13 +188,13 @@ def select_for_update( skip_locked: bool = False, of: tuple[str, ...] = (), no_key: bool = False, - ) -> PreparedQuerySet[MODEL]: + ) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet[MODEL]: + def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def group_by(self, *fields: str) -> PreparedQuerySet[MODEL]: + def group_by(self, *fields: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) def values_list( @@ -235,7 +217,7 @@ def count(self) -> PreparedCountQuery: def exists(self) -> PreparedExistsQuery: return cast(PreparedExistsQuery, self) - def all(self) -> PreparedQuerySet[MODEL]: + def all(self) -> PreparedQuerySet: return cast(PreparedQuerySet, self) def first(self) -> PreparedQuerySetSingle[MODEL | None]: @@ -273,50 +255,60 @@ def bulk_update( def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) - def only(self, *fields_for_select: str) -> PreparedQuerySet[MODEL]: + def only(self, *fields_for_select: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def select_related(self, *fields: str) -> PreparedQuerySet[MODEL]: + def select_related(self, *fields: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def force_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + def force_index(self, *index_names: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def use_index(self, *index_names: str) -> PreparedQuerySet[MODEL]: + def use_index(self, *index_names: str) -> PreparedQuerySet: return cast(PreparedQuerySet, self) - def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet[MODEL]: + def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet: return cast(PreparedQuerySet, self) -class PreparingQuerySet(QuerySet[MODEL], _PreparingQueryMixin): - __slots__ = ( - "_cache_key", - "_db_for_write", - ) +class PreparingQuerySet(QuerySet[MODEL]): + __slots__ = ("_cache_key",) def __init__(self, model: type[MODEL], cache_key: str) -> None: super().__init__(model) self._cache_key: str = cache_key - self._db_for_write = self._select_for_update def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparingQuerySet[MODEL]: queryset = cast(Self, super()._clone(_new_cls)) queryset._cache_key = self._cache_key - queryset._db_for_write = self._select_for_update return cast(PreparingQuerySet, queryset) - def _clone_prepared(self) -> _PreparedQueryMixin[MODEL]: - return self._clone(PreparedQuerySet) + def prepared(self) -> PreparedQuerySet: + if self._cache_key in self.model._meta.query_cache: + return cast(PreparedQuerySet, self.model._meta.query_cache[self._cache_key]) + + self._db = self._choose_db(self._select_for_update) + self._make_query() + + prepared = PreparedQuerySet( + model=self.model, + db=self._db, + query=self.query, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + select_for_update=self._select_for_update, + custom_fields=list(self._annotations.keys()), + cache_key=self._cache_key, + ) + prepared._init_prepared() + return prepared def prepare_sql(self, key: str) -> NoReturn: raise NotImplementedError("QuerySets must be prepared only once") - def prepared(self) -> PreparedQuerySet[MODEL]: - queryset = cast(Self, super().prepared()) - queryset._custom_fields = list(self._annotations.keys()) - return queryset - def filter(self, *args: Q, **kwargs: Any) -> PreparingQuerySet[MODEL]: return cast(PreparingQuerySet, super().filter(*args, **kwargs)) @@ -391,7 +383,7 @@ def values_list( self, *fields_: str, flat: bool = False ) -> PreparedValuesListQuery[Literal[False]]: fields_for_select_list = self._get_fields_list_for_select(*fields_) - query = ValuesListQuery( + query: ValuesListQuery = ValuesListQuery( db=self._db, model=self.model, q_objects=self._q_objects, @@ -412,8 +404,8 @@ def values_list( query._db = query._choose_db(True) query._make_query() - prepared = PreparedValuesListQuery( - db=self._db, + prepared: PreparedValuesListQuery = PreparedValuesListQuery( + db=query._db, model=self.model, single=self._single, raise_does_not_exist=self._raise_does_not_exist, @@ -428,7 +420,7 @@ def values_list( def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: fields_for_select = self._get_fields_for_select(*args, **kwargs) - query = ValuesQuery( + query: ValuesQuery = ValuesQuery( db=self._db, model=self.model, q_objects=self._q_objects, @@ -448,8 +440,8 @@ def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False query._db = query._choose_db(True) query._make_query() - prepared = PreparedValuesQuery( - db=self._db, + prepared: PreparedValuesQuery = PreparedValuesQuery( + db=query._db, model=self.model, single=self._single, raise_does_not_exist=self._raise_does_not_exist, @@ -461,7 +453,7 @@ def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False prepared._init_prepared() return prepared - def delete(self) -> PreparedDeleteQuery: + def delete(self) -> PreparedDeleteQuery: # type: ignore query = DeleteQuery( model=self.model, db=self._db, @@ -476,14 +468,14 @@ def delete(self) -> PreparedDeleteQuery: prepared = PreparedDeleteQuery( model=self.model, - db=self._db, + db=query._db, query=query.query, cache_key=self._cache_key, ) prepared._init_prepared() return prepared - def update(self, **kwargs: Any) -> PreparedUpdateQuery: + def update(self, **kwargs: Any) -> PreparedUpdateQuery: # type: ignore query = UpdateQuery( model=self.model, update_kwargs=kwargs, @@ -499,14 +491,14 @@ def update(self, **kwargs: Any) -> PreparedUpdateQuery: prepared = PreparedUpdateQuery( model=self.model, - db=self._db, + db=query._db, query=query.query, cache_key=self._cache_key, ) prepared._init_prepared() return prepared - def count(self) -> PreparedCountQuery: + def count(self) -> PreparedCountQuery: # type: ignore query = CountQuery( model=self.model, db=self._db, @@ -523,7 +515,7 @@ def count(self) -> PreparedCountQuery: prepared = PreparedCountQuery( model=self.model, - db=self._db, + db=query._db, query=query.query, limit=self._limit, offset=self._offset, @@ -532,7 +524,7 @@ def count(self) -> PreparedCountQuery: prepared._init_prepared() return prepared - def exists(self) -> PreparedExistsQuery: + def exists(self) -> PreparedExistsQuery: # type: ignore query = ExistsQuery( model=self.model, db=self._db, @@ -547,7 +539,7 @@ def exists(self) -> PreparedExistsQuery: prepared = PreparedExistsQuery( model=self.model, - db=self._db, + db=query._db, query=query.query, cache_key=self._cache_key, ) @@ -608,7 +600,7 @@ def prefetch_related(self, *args: str | Prefetch) -> PreparingQuerySet[MODEL]: return cast(PreparingQuerySet, super().prefetch_related(*args)) -class PreparedQuerySet(_PreparedQueryMixin, QuerySet[MODEL]): +class PreparedQuerySet(_PreparedQueryMixin): __slots__ = ( "_cache_key", "_custom_fields", @@ -618,28 +610,65 @@ class PreparedQuerySet(_PreparedQueryMixin, QuerySet[MODEL]): "_db_for_write", ) - def __init__(self, model: type[MODEL], cache_key: str) -> None: + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + db: BaseDBAsyncClient, + prefetch_map: dict[str, set[str | Prefetch]], + prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]], + select_related_idx: list[ + tuple[type[Model], int, Table | str, type[Model], Iterable[str | None]] + ], + single: bool, + raise_does_not_exist: bool, + select_for_update: bool, + custom_fields: list[str] | None, + cache_key: str, + ) -> None: super().__init__(model) + self._db = db + self._prefetch_map = prefetch_map + self._prefetch_queries = prefetch_queries + self._select_related_idx = select_related_idx + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._db_for_write = select_for_update + self._custom_fields: list[str] | None = custom_fields + + self.query = query self._cache_key: str = cache_key self._sql_cache: dict[str, CachedSql] = {} self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] - self._db_for_write = self._select_for_update - self._custom_fields: list[str] | None = None - def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparedQuerySet[MODEL]: - queryset = cast(Self, super()._clone(_new_cls)) + def _clone(self) -> PreparedQuerySet: + queryset = self.__class__.__new__(self.__class__) + queryset.model = self.model + queryset.query = self.query + queryset._capabilities = self._capabilities + queryset._annotations = self._annotations + + queryset._db = self._db + queryset._prefetch_map = self._prefetch_map + queryset._prefetch_queries = self._prefetch_queries + queryset._select_related_idx = self._select_related_idx + queryset._single = self._single + queryset._raise_does_not_exist = self._raise_does_not_exist + queryset._db_for_write = self._db_for_write + queryset._custom_fields = self._custom_fields + queryset._cache_key = self._cache_key queryset._sql_cache = self._sql_cache queryset._dynamic_params = self._dynamic_params queryset._dynamic_params_names = self._dynamic_params_names - queryset._db_for_write = self._db_for_write - return cast(PreparedQuerySet, queryset) + return queryset async def execute(self, **params) -> list[MODEL]: cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) + self._choose_db_if_not_chosen(self._db_for_write) instance_list = await self._db.executor_class( model=self.model, db=self._db, @@ -688,11 +717,14 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True - def _clone(self) -> PreparedUpdateQuery[MODEL]: + def _clone(self) -> PreparedUpdateQuery: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query query._db = self._db + query._capabilities = self._capabilities + query._annotations = self._annotations + query._cache_key = self._cache_key query._db_for_write = self._db_for_write query._sql_cache = self._sql_cache @@ -717,11 +749,11 @@ class PreparedDeleteQuery(_PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + cache_key: str, ) -> None: super().__init__(model) self._db = db @@ -733,10 +765,12 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True - def _clone(self) -> PreparedDeleteQuery[MODEL]: + def _clone(self) -> PreparedDeleteQuery: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query + query._capabilities = self._capabilities + query._annotations = self._annotations query._db = self._db query._cache_key = self._cache_key query._db_for_write = self._db_for_write @@ -762,11 +796,11 @@ class PreparedExistsQuery(_PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + cache_key: str, ) -> None: super().__init__(model) self._db = db @@ -778,10 +812,12 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedExistsQuery[MODEL]: + def _clone(self) -> PreparedExistsQuery: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query + query._capabilities = self._capabilities + query._annotations = self._annotations query._db = self._db query._cache_key = self._cache_key query._db_for_write = self._db_for_write @@ -810,13 +846,13 @@ class PreparedCountQuery(_PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - limit: int | None, - offset: int | None, - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + query: QueryBuilder, + limit: int | None, + offset: int | None, + cache_key: str, ) -> None: super().__init__(model) self._db = db @@ -830,10 +866,12 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedCountQuery[MODEL]: + def _clone(self) -> PreparedCountQuery: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query + query._capabilities = self._capabilities + query._annotations = self._annotations query._db = self._db query._limit = self._limit query._offset = self._offset @@ -868,16 +906,16 @@ class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): ) def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - single: bool, - raise_does_not_exist: bool, - fields_for_select_list: tuple[str, ...] | list[str], - flat: bool, - annotations: dict[str, Any], - query: QueryBuilder, - cache_key: str, + self, + model: type[MODEL], + db: BaseDBAsyncClient, + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + flat: bool, + annotations: dict[str, Any], + query: QueryBuilder, + cache_key: str, ) -> None: super().__init__( model=model, @@ -905,13 +943,29 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self) -> PreparedValuesListQuery[MODEL]: + def _clone(self) -> PreparedValuesListQuery[SINGLE]: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query query._db = self._db + query._capabilities = self._capabilities - # TODO: clone rest of the fields + query.fields = self.fields + query._limit = self._limit + query._offset = self._offset + query._distinct = self._distinct + query._orderings = self._orderings + query._custom_filters = self._custom_filters + query._q_objects = self._q_objects + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._fields_for_select_list = self._fields_for_select_list + query._flat = self._flat + query._group_bys = self._group_bys + query._force_indexes = self._force_indexes + query._use_indexes = self._use_indexes + query._fields_to_select_sql = self._fields_to_select_sql + query._annotations = self._annotations query._cache_key = self._cache_key query._db_for_write = self._db_for_write @@ -974,14 +1028,13 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False - def _clone(self, _new_cls: type[ValuesQuery[MODEL]] | None = None) -> ValuesQuery[MODEL]: - if _new_cls is None: - _new_cls = self.__class__ - - # TODO: rewrite ._clone() - - query = _new_cls.__new__(_new_cls) + def _clone(self) -> PreparedValuesQuery[SINGLE]: + query = self.__class__.__new__(self.__class__) query.model = self.model + query.query = self.query + query._db = self._db + query._capabilities = self._capabilities + query._fields_for_select = self._fields_for_select query._limit = self._limit query._offset = self._offset @@ -995,6 +1048,8 @@ def _clone(self, _new_cls: type[ValuesQuery[MODEL]] | None = None) -> ValuesQuer query._group_bys = self._group_bys query._force_indexes = self._force_indexes query._use_indexes = self._use_indexes + query._annotations = self._annotations + query._cache_key = self._cache_key query._db_for_write = self._db_for_write query._sql_cache = self._sql_cache From 68c00544e90cfa59f7fa31988f4b039aca662995 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 1 Mar 2026 12:53:20 +0200 Subject: [PATCH 35/57] fix model type within prepared query not resolving --- tortoise/models.py | 2 +- tortoise/queryset.py | 2 +- tortoise/queryset_prepared.py | 86 ++++++++++++++++++++++++++--------- 3 files changed, 66 insertions(+), 24 deletions(-) diff --git a/tortoise/models.py b/tortoise/models.py index 64ba87df9..8235e2460 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1614,7 +1614,7 @@ async def fetch_for_list( await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) @classmethod - def prepare_sql(cls, key: str) -> PreparingQuerySet[MODEL] | _PreparedQueryMixin: + def prepare_sql(cls, key: str) -> PreparingQuerySet[Self] | _PreparedQueryMixin[Self]: return cls._meta.manager.get_queryset().prepare_sql(key) @classmethod diff --git a/tortoise/queryset.py b/tortoise/queryset.py index b824d45ac..8586b8846 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1303,7 +1303,7 @@ class UpdateQuery(AwaitableQuery): "update_kwargs", "_orderings", "_limit", - "values", # TODO: unused? + "values", ) def __init__( diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 50b7020bc..a77af32f5 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -42,9 +42,29 @@ class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): - def prepared(self) -> PreparedQuerySet: ... + def prefetch_related( + self, *args: str | Prefetch + ) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage - async def execute(self, **params) -> list[MODEL]: ... + def select_related(self, *args: str) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + + def annotate( + self, **kwargs: Expression | Term + ) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + + def only(self, *fields_for_select: str) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + + def values_list( + self, *fields_: str, flat: bool = False + ) -> PreparedValuesListQuery[Literal[True]]: ... # pragma: nocoverage + + def values( + self, *args: str, **kwargs: str + ) -> PreparedValuesQuery[Literal[True]]: ... # pragma: nocoverage + + def prepared(self) -> PreparedQuerySet[T_co]: ... + + async def execute(self, **params) -> list[T_co]: ... class CachedSql: @@ -110,8 +130,8 @@ def _clone(self) -> Self: ... def prepare_sql(self, key: str) -> NoReturn: raise NotImplementedError("QuerySets must be prepared only once") - def prepared(self) -> PreparedQuerySet: - return cast(PreparedQuerySet, self) + def prepared(self) -> Self: + return self def _init_prepared(self) -> None: _, params = self.query.get_parameterized_sql() @@ -155,13 +175,13 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: @abstractmethod async def execute(self, **params) -> Any: ... - def filter(self, *args: Q, **kwargs: Any) -> PreparedQuerySet: + def filter(self, *args: Q, **kwargs: Any) -> Self: return cast(PreparedQuerySet, self) - def exclude(self, *args: Q, **kwargs: Any) -> PreparedQuerySet: + def exclude(self, *args: Q, **kwargs: Any) -> Self: return cast(PreparedQuerySet, self) - def order_by(self, *orderings: str) -> PreparedQuerySet: + def order_by(self, *orderings: str) -> Self: return cast(PreparedQuerySet, self) def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: @@ -170,16 +190,16 @@ def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) - def limit(self, limit: int | Parameter) -> PreparedQuerySet: + def limit(self, limit: int | Parameter) -> Self: return cast(PreparedQuerySet, self) - def offset(self, offset: int | Parameter) -> PreparedQuerySet: + def offset(self, offset: int | Parameter) -> Self: return cast(PreparedQuerySet, self) - def __getitem__(self, key: slice) -> PreparedQuerySet: + def __getitem__(self, key: slice) -> Self: return cast(PreparedQuerySet, self) - def distinct(self) -> PreparedQuerySet: + def distinct(self) -> Self: return cast(PreparedQuerySet, self) def select_for_update( @@ -188,13 +208,13 @@ def select_for_update( skip_locked: bool = False, of: tuple[str, ...] = (), no_key: bool = False, - ) -> PreparedQuerySet: + ) -> Self: return cast(PreparedQuerySet, self) - def annotate(self, **kwargs: Expression | Term) -> PreparedQuerySet: + def annotate(self, **kwargs: Expression | Term) -> Self: return cast(PreparedQuerySet, self) - def group_by(self, *fields: str) -> PreparedQuerySet: + def group_by(self, *fields: str) -> Self: return cast(PreparedQuerySet, self) def values_list( @@ -217,7 +237,7 @@ def count(self) -> PreparedCountQuery: def exists(self) -> PreparedExistsQuery: return cast(PreparedExistsQuery, self) - def all(self) -> PreparedQuerySet: + def all(self) -> Self: return cast(PreparedQuerySet, self) def first(self) -> PreparedQuerySetSingle[MODEL | None]: @@ -255,19 +275,19 @@ def bulk_update( def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) - def only(self, *fields_for_select: str) -> PreparedQuerySet: + def only(self, *fields_for_select: str) -> Self: return cast(PreparedQuerySet, self) - def select_related(self, *fields: str) -> PreparedQuerySet: + def select_related(self, *fields: str) -> Self: return cast(PreparedQuerySet, self) - def force_index(self, *index_names: str) -> PreparedQuerySet: + def force_index(self, *index_names: str) -> Self: return cast(PreparedQuerySet, self) - def use_index(self, *index_names: str) -> PreparedQuerySet: + def use_index(self, *index_names: str) -> Self: return cast(PreparedQuerySet, self) - def prefetch_related(self, *args: str | Prefetch) -> PreparedQuerySet: + def prefetch_related(self, *args: str | Prefetch) -> Self: return cast(PreparedQuerySet, self) @@ -600,7 +620,8 @@ def prefetch_related(self, *args: str | Prefetch) -> PreparingQuerySet[MODEL]: return cast(PreparingQuerySet, super().prefetch_related(*args)) -class PreparedQuerySet(_PreparedQueryMixin): +# TODO: make it generic +class PreparedQuerySet(_PreparedQueryMixin[MODEL]): __slots__ = ( "_cache_key", "_custom_fields", @@ -642,7 +663,10 @@ def __init__( self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] - def _clone(self) -> PreparedQuerySet: + def prepared(self) -> PreparedQuerySet[MODEL]: + return self + + def _clone(self) -> PreparedQuerySet[MODEL]: queryset = self.__class__.__new__(self.__class__) queryset.model = self.model queryset.query = self.query @@ -717,6 +741,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True + def prepared(self) -> PreparedUpdateQuery: + return self + def _clone(self) -> PreparedUpdateQuery: query = self.__class__.__new__(self.__class__) query.model = self.model @@ -765,6 +792,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = True + def prepared(self) -> PreparedDeleteQuery: + return self + def _clone(self) -> PreparedDeleteQuery: query = self.__class__.__new__(self.__class__) query.model = self.model @@ -812,6 +842,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False + def prepared(self) -> PreparedExistsQuery: + return self + def _clone(self) -> PreparedExistsQuery: query = self.__class__.__new__(self.__class__) query.model = self.model @@ -866,6 +899,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False + def prepared(self) -> PreparedCountQuery: + return self + def _clone(self) -> PreparedCountQuery: query = self.__class__.__new__(self.__class__) query.model = self.model @@ -943,6 +979,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False + def prepared(self) -> PreparedValuesListQuery[SINGLE]: + return self + def _clone(self) -> PreparedValuesListQuery[SINGLE]: query = self.__class__.__new__(self.__class__) query.model = self.model @@ -1028,6 +1067,9 @@ def __init__( self._dynamic_params_names: list[str] = [] self._db_for_write: bool = False + def prepared(self) -> PreparedValuesQuery[SINGLE]: + return self + def _clone(self) -> PreparedValuesQuery[SINGLE]: query = self.__class__.__new__(self.__class__) query.model = self.model From 165522d99c68413e35465f2c124ffc4aace8276a Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 1 Mar 2026 13:13:53 +0200 Subject: [PATCH 36/57] fix style and typing issues --- tests/test_queryset_prepared.py | 2 +- tortoise/queryset_prepared.py | 52 +++++++++++++++++---------------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 43120c751..a6834352d 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -2,7 +2,7 @@ from tests.testmodels import Author, Book from tortoise.exceptions import ParamsError, ValidationError -from tortoise.expressions import Q, Subquery +from tortoise.expressions import Q, Subquery, F from tortoise.parameter import Parameter diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index a77af32f5..6b0218647 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, cast +from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, TypeVar, cast from pypika_tortoise.queries import QueryBuilder, Table from pypika_tortoise.terms import Term @@ -26,7 +26,6 @@ ExistsQuery, QuerySet, QuerySetSingle, - T_co, UpdateQuery, ValuesListQuery, ValuesQuery, @@ -41,18 +40,21 @@ from tortoise import Model -class PreparedQuerySetSingle(QuerySetSingle[T_co], Protocol): +T = TypeVar("T") + + +class PreparedQuerySetSingle(QuerySetSingle[T], Protocol[T]): def prefetch_related( self, *args: str | Prefetch - ) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + ) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - def select_related(self, *args: str) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + def select_related(self, *args: str) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage def annotate( self, **kwargs: Expression | Term - ) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + ) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - def only(self, *fields_for_select: str) -> PreparedQuerySetSingle[T_co]: ... # pragma: nocoverage + def only(self, *fields_for_select: str) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage def values_list( self, *fields_: str, flat: bool = False @@ -62,9 +64,9 @@ def values( self, *args: str, **kwargs: str ) -> PreparedValuesQuery[Literal[True]]: ... # pragma: nocoverage - def prepared(self) -> PreparedQuerySet[T_co]: ... + def prepared(self) -> PreparedQuerySetSingle[T]: ... - async def execute(self, **params) -> list[T_co]: ... + async def execute(self, **params) -> list[T]: ... class CachedSql: @@ -176,13 +178,13 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: async def execute(self, **params) -> Any: ... def filter(self, *args: Q, **kwargs: Any) -> Self: - return cast(PreparedQuerySet, self) + return self def exclude(self, *args: Q, **kwargs: Any) -> Self: - return cast(PreparedQuerySet, self) + return self def order_by(self, *orderings: str) -> Self: - return cast(PreparedQuerySet, self) + return self def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) @@ -191,16 +193,16 @@ def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) def limit(self, limit: int | Parameter) -> Self: - return cast(PreparedQuerySet, self) + return self def offset(self, offset: int | Parameter) -> Self: - return cast(PreparedQuerySet, self) + return self def __getitem__(self, key: slice) -> Self: - return cast(PreparedQuerySet, self) + return self def distinct(self) -> Self: - return cast(PreparedQuerySet, self) + return self def select_for_update( self, @@ -209,13 +211,13 @@ def select_for_update( of: tuple[str, ...] = (), no_key: bool = False, ) -> Self: - return cast(PreparedQuerySet, self) + return self def annotate(self, **kwargs: Expression | Term) -> Self: - return cast(PreparedQuerySet, self) + return self def group_by(self, *fields: str) -> Self: - return cast(PreparedQuerySet, self) + return self def values_list( self, *fields_: str, flat: bool = False @@ -238,7 +240,7 @@ def exists(self) -> PreparedExistsQuery: return cast(PreparedExistsQuery, self) def all(self) -> Self: - return cast(PreparedQuerySet, self) + return self def first(self) -> PreparedQuerySetSingle[MODEL | None]: return cast(PreparedQuerySetSingle, self) @@ -276,19 +278,19 @@ def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | return cast(PreparedQuerySetSingle, self) def only(self, *fields_for_select: str) -> Self: - return cast(PreparedQuerySet, self) + return self def select_related(self, *fields: str) -> Self: - return cast(PreparedQuerySet, self) + return self def force_index(self, *index_names: str) -> Self: - return cast(PreparedQuerySet, self) + return self def use_index(self, *index_names: str) -> Self: - return cast(PreparedQuerySet, self) + return self def prefetch_related(self, *args: str | Prefetch) -> Self: - return cast(PreparedQuerySet, self) + return self class PreparingQuerySet(QuerySet[MODEL]): From bf411ee206bfa1262d523c11ae314ec55b3a8eef Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Sun, 1 Mar 2026 15:50:28 +0200 Subject: [PATCH 37/57] fix return type of PreparedQuerySetSingle.execute --- tortoise/queryset_prepared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 6b0218647..eb7149e8b 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -66,7 +66,7 @@ def values( def prepared(self) -> PreparedQuerySetSingle[T]: ... - async def execute(self, **params) -> list[T]: ... + async def execute(self, **params) -> T: ... class CachedSql: From b678f76523e170fbc09ea34e1e9e5378ce489df3 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 13:36:39 +0200 Subject: [PATCH 38/57] remove unused "F" import --- tests/test_queryset_prepared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index a6834352d..43120c751 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -2,7 +2,7 @@ from tests.testmodels import Author, Book from tortoise.exceptions import ParamsError, ValidationError -from tortoise.expressions import Q, Subquery, F +from tortoise.expressions import Q, Subquery from tortoise.parameter import Parameter From 6c08eb0789763400f2f68a106f261b0921f78760 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 13:48:42 +0200 Subject: [PATCH 39/57] use QuerySet methods for creating queries in .values_list, .values, .update, .delete, .count, .exists --- tortoise/queryset_prepared.py | 83 +++-------------------------------- 1 file changed, 6 insertions(+), 77 deletions(-) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index eb7149e8b..c3dca37a4 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -21,12 +21,9 @@ AwaitableQuery, BulkCreateQuery, BulkUpdateQuery, - CountQuery, DeleteQuery, - ExistsQuery, QuerySet, QuerySetSingle, - UpdateQuery, ValuesListQuery, ValuesQuery, ) @@ -405,24 +402,7 @@ def values_list( self, *fields_: str, flat: bool = False ) -> PreparedValuesListQuery[Literal[False]]: fields_for_select_list = self._get_fields_list_for_select(*fields_) - query: ValuesListQuery = ValuesListQuery( - db=self._db, - model=self.model, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - flat=flat, - fields_for_select_list=fields_for_select_list, - distinct=self._distinct, - limit=self._limit, - offset=self._offset, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - ) + query = super().values_list(*fields_, flat=flat) query._db = query._choose_db(True) query._make_query() @@ -442,23 +422,7 @@ def values_list( def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: fields_for_select = self._get_fields_for_select(*args, **kwargs) - query: ValuesQuery = ValuesQuery( - db=self._db, - model=self.model, - q_objects=self._q_objects, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select=fields_for_select, - distinct=self._distinct, - limit=self._limit, - offset=self._offset, - orderings=self._orderings, - annotations=self._annotations, - custom_filters=self._custom_filters, - group_bys=self._group_bys, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - ) + query = super().values(*args, **kwargs) query._db = query._choose_db(True) query._make_query() @@ -476,15 +440,7 @@ def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False return prepared def delete(self) -> PreparedDeleteQuery: # type: ignore - query = DeleteQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - ) + query = super().delete() query._db = query._choose_db(True) query._make_query() @@ -498,16 +454,7 @@ def delete(self) -> PreparedDeleteQuery: # type: ignore return prepared def update(self, **kwargs: Any) -> PreparedUpdateQuery: # type: ignore - query = UpdateQuery( - model=self.model, - update_kwargs=kwargs, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - orderings=self._orderings, - ) + query = super().update(**kwargs) query._db = query._choose_db(True) query._make_query() @@ -521,17 +468,7 @@ def update(self, **kwargs: Any) -> PreparedUpdateQuery: # type: ignore return prepared def count(self) -> PreparedCountQuery: # type: ignore - query = CountQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - limit=self._limit, - offset=self._offset, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - ) + query = super().count() query._db = query._choose_db(True) query._make_query() @@ -547,15 +484,7 @@ def count(self) -> PreparedCountQuery: # type: ignore return prepared def exists(self) -> PreparedExistsQuery: # type: ignore - query = ExistsQuery( - model=self.model, - db=self._db, - q_objects=self._q_objects, - annotations=self._annotations, - custom_filters=self._custom_filters, - force_indexes=self._force_indexes, - use_indexes=self._use_indexes, - ) + query = super().exists() query._db = query._choose_db(True) query._make_query() From 5cd827aa1a2173c3248b84d0350d9228043a2a86 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 16:27:16 +0200 Subject: [PATCH 40/57] add ability to remove prepared query from model's cache --- tests/test_queryset_prepared.py | 9 +++++++++ tortoise/models.py | 4 ++++ tortoise/queryset.py | 1 + tortoise/queryset_prepared.py | 1 - 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 43120c751..8b9084b9f 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -378,3 +378,12 @@ async def test_update_pk_invalid_obj(db): with pytest.raises(ValidationError): await prepared.execute(search_id=book.pk, replace_author="not an Author object") + + +def test_remove_prepared_queryset_from_cache(db): + cache_key = "test_remove_query_from_cache" + prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() + assert Author.prepare_sql(cache_key).query is prepared.query + Author.remove_prepared_query(cache_key) + assert Author.prepare_sql(cache_key).query is not prepared.query + diff --git a/tortoise/models.py b/tortoise/models.py index 8235e2460..ee68666b1 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1617,6 +1617,10 @@ async def fetch_for_list( def prepare_sql(cls, key: str) -> PreparingQuerySet[Self] | _PreparedQueryMixin[Self]: return cls._meta.manager.get_queryset().prepare_sql(key) + @classmethod + def remove_prepared_query(cls, key: str) -> None: + cls._meta.query_cache.pop(key, None) + @classmethod def _check(cls) -> None: """ diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 8586b8846..ba4fbbba6 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1286,6 +1286,7 @@ def prepare_sql(self, key: str) -> PreparingQuerySet[MODEL] | _PreparedQueryMixi if key in self.model._meta.query_cache: prepared_queryset = self.model._meta.query_cache[key]._clone() + # TODO: select db in .prepared, not in here prepared_queryset._db = None # type: ignore prepared_queryset._db = prepared_queryset._choose_db(prepared_queryset._db_for_write) return prepared_queryset diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index c3dca37a4..c6c38478b 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -551,7 +551,6 @@ def prefetch_related(self, *args: str | Prefetch) -> PreparingQuerySet[MODEL]: return cast(PreparingQuerySet, super().prefetch_related(*args)) -# TODO: make it generic class PreparedQuerySet(_PreparedQueryMixin[MODEL]): __slots__ = ( "_cache_key", From 488a01ea6c56f2f3bbfe57c01889cb1be6f9d6f3 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 16:37:29 +0200 Subject: [PATCH 41/57] add .sql method implementation to prepared queries --- tests/test_queryset_prepared.py | 22 +++++++++++++++++++++- tortoise/queryset_prepared.py | 4 ++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 8b9084b9f..445382f65 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -1,6 +1,6 @@ import pytest -from tests.testmodels import Author, Book +from tests.testmodels import Author, Book, CharPkModel from tortoise.exceptions import ParamsError, ValidationError from tortoise.expressions import Q, Subquery from tortoise.parameter import Parameter @@ -387,3 +387,23 @@ def test_remove_prepared_queryset_from_cache(db): Author.remove_prepared_query(cache_key) assert Author.prepare_sql(cache_key).query is not prepared.query + +@pytest.mark.parametrize( + ("filter_kwargs", "cache_key_suffix",), + [ + ({"id": "123"}, "1"), + ({"id__gte": "321"}, "2"), + ({"id__in": ["321", 123, 987]}, "3"), + ], +) +def test_prepared_query_get_sql(db, filter_kwargs: dict[str, ...], cache_key_suffix: str): + expected_sql = CharPkModel.all().filter(**filter_kwargs).limit(10).offset(0).sql() + actual_sql = CharPkModel.prepare_sql( + f"test_prepared_query_get_sql-{cache_key_suffix}" + ).all().filter(**{ + key: Parameter(key) + for key in filter_kwargs + }).limit(10).offset(0).prepared().sql(**filter_kwargs) + + assert expected_sql == actual_sql + diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index c6c38478b..d48fe8517 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -174,6 +174,10 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: @abstractmethod async def execute(self, **params) -> Any: ... + def sql(self, **params) -> str: + cached_query = self._get_or_create_cached_sql(params) + return cached_query.sql + def filter(self, *args: Q, **kwargs: Any) -> Self: return self From 69311e12a415b29bcbc58ea901479f272800ec0a Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 16:39:42 +0200 Subject: [PATCH 42/57] fix style issue in test_prepared_query_get_sql; fix typing issue in _PreparedQueryMixin.sql --- tests/test_queryset_prepared.py | 25 ++++++++++++++++--------- tortoise/queryset_prepared.py | 2 +- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_prepared.py index 445382f65..cbd083b3a 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_prepared.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from tests.testmodels import Author, Book, CharPkModel @@ -389,21 +391,26 @@ def test_remove_prepared_queryset_from_cache(db): @pytest.mark.parametrize( - ("filter_kwargs", "cache_key_suffix",), + ( + "filter_kwargs", + "cache_key_suffix", + ), [ ({"id": "123"}, "1"), ({"id__gte": "321"}, "2"), ({"id__in": ["321", 123, 987]}, "3"), ], ) -def test_prepared_query_get_sql(db, filter_kwargs: dict[str, ...], cache_key_suffix: str): +def test_prepared_query_get_sql(db, filter_kwargs: dict[str, Any], cache_key_suffix: str): expected_sql = CharPkModel.all().filter(**filter_kwargs).limit(10).offset(0).sql() - actual_sql = CharPkModel.prepare_sql( - f"test_prepared_query_get_sql-{cache_key_suffix}" - ).all().filter(**{ - key: Parameter(key) - for key in filter_kwargs - }).limit(10).offset(0).prepared().sql(**filter_kwargs) + actual_sql = ( + CharPkModel.prepare_sql(f"test_prepared_query_get_sql-{cache_key_suffix}") + .all() + .filter(**{key: Parameter(key) for key in filter_kwargs}) + .limit(10) + .offset(0) + .prepared() + .sql(**filter_kwargs) + ) assert expected_sql == actual_sql - diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index d48fe8517..409b87cdd 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -174,7 +174,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: @abstractmethod async def execute(self, **params) -> Any: ... - def sql(self, **params) -> str: + def sql(self, params_inline=False, **params) -> str: cached_query = self._get_or_create_cached_sql(params) return cached_query.sql From 0da28ce91f2748cd70db5d143ff07884f64209aa Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Wed, 4 Mar 2026 20:53:15 +0200 Subject: [PATCH 43/57] return PreparedDeleteQuery from _PreparedQueryMixin.delete instead of DeleteQuery --- tortoise/queryset_prepared.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py index 409b87cdd..957a67b7c 100644 --- a/tortoise/queryset_prepared.py +++ b/tortoise/queryset_prepared.py @@ -21,7 +21,6 @@ AwaitableQuery, BulkCreateQuery, BulkUpdateQuery, - DeleteQuery, QuerySet, QuerySetSingle, ValuesListQuery, @@ -228,8 +227,8 @@ def values_list( def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: return cast(PreparedValuesQuery, self) - def delete(self) -> DeleteQuery: - return cast(DeleteQuery, self) + def delete(self) -> PreparedDeleteQuery: + return cast(PreparedDeleteQuery, self) def update(self, **kwargs: Any) -> PreparedUpdateQuery: return cast(PreparedUpdateQuery, self) From 8abf5c0502989dbf288694d4f4c74d7c03cc586c Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 08:57:35 +0200 Subject: [PATCH 44/57] rewrite compiled query classes without using QuerySet as a base class; add ability to only compile queries after queryset was constructed --- ..._prepared.py => test_queryset_compiled.py} | 100 ++-- tortoise/models.py | 10 +- tortoise/queryset.py | 180 ++++++- tortoise/queryset_compiled.py | 465 ++++++++++++++++++ 4 files changed, 681 insertions(+), 74 deletions(-) rename tests/{test_queryset_prepared.py => test_queryset_compiled.py} (81%) create mode 100644 tortoise/queryset_compiled.py diff --git a/tests/test_queryset_prepared.py b/tests/test_queryset_compiled.py similarity index 81% rename from tests/test_queryset_prepared.py rename to tests/test_queryset_compiled.py index cbd083b3a..c13401182 100644 --- a/tests/test_queryset_prepared.py +++ b/tests/test_queryset_compiled.py @@ -10,14 +10,8 @@ def test_prepared_queryset_query_always_same(db): cache_key = "test_prepared_queryset_always_same" - prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - assert Author.prepare_sql(cache_key).query is prepared.query - - -def test_disallow_filtering_on_prepared_queryset(db): - cache_key = "test_disallow_filtering_on_prepared_queryset" - prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - assert prepared is prepared.filter(id=1) + prepared = Author.filter(id=Parameter("some_param")).compile(cache_key) + assert Author.all().compile(cache_key).query is prepared.query @pytest.mark.asyncio @@ -28,15 +22,16 @@ async def test_gte_filter(db): expected = await Author.filter(id__gte=author2.pk).order_by("id") prepared = ( - Author.prepare_sql("test_gte_filter") + Author .filter(id__gte=Parameter("idgte")) .order_by("id") - .prepared() + .compile("test_gte_filter") ) + print(prepared.sql(idgte=author2.pk)) actual = await prepared.execute(idgte=author2.pk) assert len(actual) == 2 - assert actual[0].id == author2.pk - assert actual[1].id == author3.pk + assert actual[0].pk == author2.pk + assert actual[1].pk == author3.pk assert expected == actual @@ -48,10 +43,10 @@ async def test_string_param(db): expected = await Author.filter(name=author2.name) - prepared = Author.prepare_sql("test_string_param").filter(name=Parameter("name")).prepared() + prepared = Author.filter(name=Parameter("name")).compile("test_string_param") actual = await prepared.execute(name=author2.name) assert len(actual) == 1 - assert actual[0].id == author2.pk + assert actual[0].pk == author2.pk assert expected == actual @@ -62,14 +57,11 @@ async def test_startswith_filter(db): author3 = await Author.create(name="qwetest") prepared = ( - Author.prepare_sql("test_startswith_filter") + Author .filter(name__startswith=Parameter("name")) - .prepared() + .compile("test_startswith_filter") ) - # print(Author.filter(name__startswith="asd").sql()) - # print(prepared.sql()) - for test_name in (author2.pk, author1.name, author3.name, "asd"): expected = await Author.filter(name__startswith=test_name) actual = await prepared.execute(name=test_name) @@ -82,7 +74,7 @@ async def test_in_filter(db): author2 = await Author.create(name="testqwe") author3 = await Author.create(name="qwetest") - prepared = Author.prepare_sql("test_in_filter").filter(id__in=Parameter("ids")).prepared() + prepared = Author.filter(id__in=Parameter("ids")).compile("test_in_filter") for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): expected = await Author.filter(id__in=test_ids) @@ -97,13 +89,13 @@ async def test_subqueries(db): author3 = await Author.create(name="3") prepared = ( - Author.prepare_sql("test_subqueries") + Author .filter( id__in=Subquery( Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") ) ) - .prepared() + .compile("test_subqueries") ) for id1, id2 in ( @@ -124,9 +116,9 @@ async def test_subqueries_in_filter(db): author3 = await Author.create(name="3") prepared = ( - Author.prepare_sql("test_subqueries_in_filter") + Author .filter(id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id"))) - .prepared() + .compile("test_subqueries_in_filter") ) for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): @@ -145,10 +137,10 @@ async def test_update(db): new_name1 = f"{author1.name}_test" prepared = ( - Author.prepare_sql("test_update") + Author .filter(id=Parameter("search_id")) .update(name=Parameter("replace_name")) - .prepared() + .compile("test_update") ) await prepared.execute(search_id=author1.pk, replace_name=new_name1) @@ -169,12 +161,12 @@ async def test_delete(db): author3 = await Author.create(name="3") prepared = ( - Author.prepare_sql("test_delete") + Author .filter( id__in=Parameter("ids"), ) .delete() - .prepared() + .compile("test_delete") ) affected = await prepared.execute(ids=[author1.pk]) @@ -189,12 +181,12 @@ async def test_exists(db): author = await Author.create(name="1") prepared = ( - Author.prepare_sql("test_exists") + Author .filter( id__in=Parameter("ids"), ) .exists() - .prepared() + .compile("test_exists") ) assert await prepared.execute(ids=[author.pk]) @@ -208,12 +200,12 @@ async def test_count(db): author3 = await Author.create(name="3") prepared = ( - Author.prepare_sql("test_count") + Author .filter( id__gte=Parameter("idgte"), ) .count() - .prepared() + .compile("test_count") ) assert await prepared.execute(idgte=author1.pk) == 3 @@ -233,11 +225,11 @@ async def test_parameter_in_limit(db): ) prepared = ( - Author.prepare_sql("test_parameter_in_limit") + Author .all() .limit(Parameter("lim")) .order_by("id") - .prepared() + .compile("test_parameter_in_limit") ) assert len(await prepared.execute(lim=1)) == 1 @@ -260,11 +252,11 @@ async def test_parameter_in_offset(db): ) prepared = ( - Author.prepare_sql("test_parameter_in_offset") + Author .all() .offset(Parameter("off")) .order_by("id") - .prepared() + .compile("test_parameter_in_offset") ) assert len(await prepared.execute(off=1)) == 2 @@ -281,12 +273,12 @@ async def test_values(db): author = await Author.create(name="1") prepared = ( - Author.prepare_sql("test_values") + Author .filter( id=Parameter("id"), ) .values() - .prepared() + .compile("test_values") ) assert await prepared.execute(id=author.pk) == [{"id": author.pk, "name": author.name}] @@ -298,12 +290,12 @@ async def test_values_list_all_fields(db): author = await Author.create(name="1") prepared_all = ( - Author.prepare_sql("test_values_list_all_fields") + Author .filter( id=Parameter("id"), ) .values_list() - .prepared() + .compile("test_values_list_all_fields") ) assert await prepared_all.execute(id=author.pk) == [(author.pk, author.name)] assert await prepared_all.execute(id=author.pk * 2) == [] @@ -314,12 +306,12 @@ async def test_values_list_only_id_field(db): author = await Author.create(name="1") prepared_ids = ( - Author.prepare_sql("test_values_list_only_id_field") + Author .filter( id=Parameter("id"), ) .values_list("id") - .prepared() + .compile("test_values_list_only_id_field") ) assert await prepared_ids.execute(id=author.pk) == [(author.pk,)] assert await prepared_ids.execute(id=author.pk * 2) == [] @@ -330,12 +322,12 @@ async def test_values_list_only_id_field_flat(db): author = await Author.create(name="1") prepared_ids_flat = ( - Author.prepare_sql("test_values_list_only_id_field_flat") + Author .filter( id=Parameter("id"), ) .values_list("id", flat=True) - .prepared() + .compile("test_values_list_only_id_field_flat") ) assert await prepared_ids_flat.execute(id=author.pk) == [author.pk] assert await prepared_ids_flat.execute(id=author.pk * 2) == [] @@ -349,10 +341,10 @@ async def test_update_fk(db): book = await Book.create(name="test", author=author1, rating=5) prepared = ( - Book.prepare_sql("test_update_fk") + Book .filter(id=Parameter("search_id")) .update(author=Parameter("replace_author")) - .prepared() + .compile("test_update_fk") ) await prepared.execute(search_id=book.pk, replace_author=author2) @@ -372,10 +364,10 @@ async def test_update_pk_invalid_obj(db): book = await Book.create(name="test", author=author, rating=5) prepared = ( - Book.prepare_sql("test_update_pk_invalid_obj") + Book .filter(id=Parameter("search_id")) .update(author=Parameter("replace_author")) - .prepared() + .compile("test_update_pk_invalid_obj") ) with pytest.raises(ValidationError): @@ -384,10 +376,10 @@ async def test_update_pk_invalid_obj(db): def test_remove_prepared_queryset_from_cache(db): cache_key = "test_remove_query_from_cache" - prepared = Author.prepare_sql(cache_key).filter(id=Parameter("some_param")).prepared() - assert Author.prepare_sql(cache_key).query is prepared.query - Author.remove_prepared_query(cache_key) - assert Author.prepare_sql(cache_key).query is not prepared.query + prepared = Author.filter(id=Parameter("some_param")).compile(cache_key) + assert Author.all().compile(cache_key).query is prepared.query + Author.remove_compiled_query(cache_key) + assert Author.all().compile(cache_key).query is not prepared.query @pytest.mark.parametrize( @@ -404,12 +396,12 @@ def test_remove_prepared_queryset_from_cache(db): def test_prepared_query_get_sql(db, filter_kwargs: dict[str, Any], cache_key_suffix: str): expected_sql = CharPkModel.all().filter(**filter_kwargs).limit(10).offset(0).sql() actual_sql = ( - CharPkModel.prepare_sql(f"test_prepared_query_get_sql-{cache_key_suffix}") + CharPkModel .all() .filter(**{key: Parameter(key) for key in filter_kwargs}) .limit(10) .offset(0) - .prepared() + .compile(f"test_prepared_query_get_sql-{cache_key_suffix}") .sql(**filter_kwargs) ) diff --git a/tortoise/models.py b/tortoise/models.py index ee68666b1..db485f3c1 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -50,7 +50,7 @@ QuerySetSingle, RawSQLQuery, ) -from tortoise.queryset_prepared import PreparingQuerySet, _PreparedQueryMixin +from tortoise.queryset_compiled import BaseCompiledQuery from tortoise.router import router from tortoise.signals import Signals from tortoise.transactions import in_transaction @@ -257,7 +257,7 @@ def __init__(self, meta: Model.Meta) -> None: self.db_native_fields: list[tuple[str, str, Field]] = [] self.db_default_fields: list[tuple[str, str, Field]] = [] self.db_complex_fields: list[tuple[str, str, Field]] = [] - self.query_cache: dict[str, _PreparedQueryMixin] = {} + self.query_cache: dict[str, BaseCompiledQuery[Self]] = {} @property def full_name(self) -> str: @@ -1614,11 +1614,7 @@ async def fetch_for_list( await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) @classmethod - def prepare_sql(cls, key: str) -> PreparingQuerySet[Self] | _PreparedQueryMixin[Self]: - return cls._meta.manager.get_queryset().prepare_sql(key) - - @classmethod - def remove_prepared_query(cls, key: str) -> None: + def remove_compiled_query(cls, key: str) -> None: cls._meta.query_cache.pop(key, None) @classmethod diff --git a/tortoise/queryset.py b/tortoise/queryset.py index ba4fbbba6..06a662da1 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -45,7 +45,8 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model - from tortoise.queryset_prepared import PreparingQuerySet, _PreparedQueryMixin + from tortoise.queryset_compiled import CompiledQuerySet, CompiledUpdateQuery, CompiledDeleteQuery, \ + CompiledExistsQuery, CompiledCountQuery MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -373,6 +374,7 @@ def __init__(self, model: type[MODEL]) -> None: self._force_indexes: set[str] = set() self._use_indexes: set[str] = set() + # TODO: remove _new_cls def _clone(self, _new_cls: type[QuerySet] | None = None) -> QuerySet[MODEL]: if _new_cls is None: _new_cls = self.__class__ @@ -517,6 +519,7 @@ def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: queryset._orderings = self._parse_orderings(orderings) return queryset._as_single() + # TODO: support Parameter arguments def limit(self, limit: int) -> QuerySet[MODEL]: """ Limits QuerySet to given length. @@ -530,6 +533,7 @@ def limit(self, limit: int) -> QuerySet[MODEL]: queryset._limit = limit return queryset + # TODO: support Parameter arguments def offset(self, offset: int) -> QuerySet[MODEL]: """ Query offset for QuerySet. @@ -1278,25 +1282,38 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list - def prepare_sql(self, key: str) -> PreparingQuerySet[MODEL] | _PreparedQueryMixin: + def compile(self, key: str | None = None) -> CompiledQuerySet[MODEL]: """ - Cache generated sql of this query set. - If query set is already in cache, return cached version with already generated sql. + Compiles queryset sql. + :param key: Cache key for saving compiled query to model cache. """ + from tortoise.queryset_compiled import CompiledQuerySet + if key in self.model._meta.query_cache: - prepared_queryset = self.model._meta.query_cache[key]._clone() - # TODO: select db in .prepared, not in here - prepared_queryset._db = None # type: ignore - prepared_queryset._db = prepared_queryset._choose_db(prepared_queryset._db_for_write) - return prepared_queryset + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledQuerySet): + ... # TODO: raise an exception + return cached._clone() - from tortoise.queryset_prepared import PreparingQuerySet + self._choose_db_if_not_chosen(self._select_for_update) + self._make_query() + compiled = CompiledQuerySet( + model=self.model, + query=self.query, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + select_for_update=self._select_for_update, + custom_fields=list(self._annotations.keys()), + ) - preparing_queryset = cast(PreparingQuerySet[MODEL], self._clone(PreparingQuerySet)) - preparing_queryset._cache_key = key + if key is not None: + self.model._meta.query_cache[key] = compiled - return preparing_queryset + return compiled class UpdateQuery(AwaitableQuery): @@ -1389,6 +1406,29 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] + def compile(self, key: str | None = None) -> CompiledUpdateQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledUpdateQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledUpdateQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(True) + self._make_query() + compiled = CompiledUpdateQuery(model=self.model, query=self.query) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class DeleteQuery(AwaitableQuery): __slots__ = ( @@ -1438,6 +1478,29 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] + def compile(self, key: str | None = None) -> CompiledDeleteQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledDeleteQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledDeleteQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(True) + self._make_query() + compiled = CompiledDeleteQuery(model=self.model, query=self.query) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class ExistsQuery(AwaitableQuery): __slots__ = ( @@ -1487,6 +1550,29 @@ async def _execute( result, _ = await self._db.execute_query(*self.query.get_parameterized_sql()) return bool(result) + def compile(self, key: str | None = None) -> CompiledExistsQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledExistsQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledExistsQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledExistsQuery(model=self.model, query=self.query) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class CountQuery(AwaitableQuery): __slots__ = ( @@ -1550,6 +1636,34 @@ async def _execute(self) -> int: return self._limit return count + def compile(self, key: str | None = None) -> CompiledCountQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledCountQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledCountQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledCountQuery( + model=self.model, + query=self.query, + limit=self._limit, + offset=self._offset, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 @@ -1801,6 +1915,26 @@ async def _execute(self) -> list[Any] | tuple: _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) return self._process_results(result) + # TODO: add compiled values list query class + # def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: + # """ + # Compiles query sql. + # :param key: Cache key for saving compiled query to model cache. + # """ + # + # if key in self.model._meta.query_cache: + # cached = self.model._meta.query_cache[key] + # if not isinstance(cached, CompiledValuesListQuery): + # ... # TODO: raise an exception + # return cached._clone() + # + # compiled = CompiledValuesListQuery(model=self.model, query=self.query) + # + # if key is not None: + # self.model._meta.query_cache[key] = compiled + # + # return compiled + class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( @@ -1936,6 +2070,26 @@ async def _execute(self) -> list[dict] | dict: result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) return self._process_results(result) + # TODO: add compiled values list query class + # def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: + # """ + # Compiles query sql. + # :param key: Cache key for saving compiled query to model cache. + # """ + # + # if key in self.model._meta.query_cache: + # cached = self.model._meta.query_cache[key] + # if not isinstance(cached, CompiledValuesQuery): + # ... # TODO: raise an exception + # return cached._clone() + # + # compiled = CompiledValuesQuery(model=self.model, query=self.query) + # + # if key is not None: + # self.model._meta.query_cache[key] = compiled + # + # return compiled + class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py new file mode 100644 index 000000000..b2fe7ddc9 --- /dev/null +++ b/tortoise/queryset_compiled.py @@ -0,0 +1,465 @@ +from __future__ import annotations as _ + +import sys +from abc import ABC +from collections import defaultdict +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, TypeVar, cast + +from pypika_tortoise.queries import QueryBuilder, Table + +from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned +from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext +from tortoise.query_utils import Prefetch +from tortoise.queryset import MODEL, AwaitableQuery, QuerySet + +if sys.version_info >= (3, 11): # pragma: nocoverage + from typing import Self +else: + from typing_extensions import Self + +if TYPE_CHECKING: + from tortoise import Model + + +T = TypeVar("T") + + +class CachedSql: + __slots__ = ( + "sql", + "params", + "param_by_name", + "need_params", + "need_collection_params", + ) + + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): + continue + + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param + + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + # TODO: check for parameters mismatch + + filled_params = self.params.copy() + for name, idx in self.need_params.items(): + param = self.param_by_name[name] + filled_params[idx] = param.encode_value(params[name]) + + for name, indexes in self.need_collection_params.items(): + param = cast(CollectionParameter, self.param_by_name[name]) + collection = param.encode_collection(params[name]) + if len(collection) != len(indexes): + raise ValueError( + f"Provided value length ({len(collection)}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({len(indexes)})" + ) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + + return filled_params + + +class BaseCompiledQuery(AwaitableQuery[MODEL], ABC): + __slots__ = ("_sql_cache", "_dynamic_params", "_dynamic_params_names", "_dynamic_params_init") + + def __init__(self, model: type[MODEL], query: QueryBuilder) -> None: + super().__init__(model) + self.query = query + # TODO: use lru + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] + self._dynamic_params_init: bool = False + + def _clone(self) -> Self: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._capabilities = self._capabilities + query._annotations = self._annotations + + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + + return query + + async def execute(self, **params) -> ...: + ... + + def init_params_table(self) -> None: + _, params = self.query.get_parameterized_sql() + self._sql_cache = {} + self._dynamic_params = { + param.name: param for param in params if isinstance(param, CollectionParameter) + } + self._dynamic_params_names = sorted(self._dynamic_params.keys()) + self._dynamic_params_init = True + + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + if not self._dynamic_params_init: + self.init_params_table() + + reset_params = [] + + cache_key = f"{self._db.capabilities.dialect}-query" + for name in self._dynamic_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + # TODO: raise exception? + continue + + param = self._dynamic_params[name] + cache_key += f"-{name}{len(value)}" + param.collection_size = len(value) + reset_params.append(param) + + # TODO: add ability to limit cache, use lru? + if cache_key not in self._sql_cache: + # TODO: probably could be done in a better way? + ctx = TortoiseSqlContext.copy( + self.query.QUERY_CLS.SQL_CONTEXT, + dynamic_params=self._dynamic_params, + ) + sql, params_ = self.query.get_parameterized_sql(ctx) + self._sql_cache[cache_key] = CachedSql(sql, params_) + + for param in reset_params: + param.collection_size = None + + return self._sql_cache[cache_key] + + def sql(self, params_inline=False, **params) -> str: + old_db = self._db + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + self._db = old_db + return cached_query.sql + + +# TODO: type single queries +class CompiledQuerySet(BaseCompiledQuery[MODEL]): + __slots__ = ( + "_prefetch_map", + "_prefetch_queries", + "_select_related_idx", + "_single", + "_raise_does_not_exist", + "_select_for_update", + "_custom_fields", + ) + + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + prefetch_map: dict[str, set[str | Prefetch]], + prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]], + select_related_idx: list[ + tuple[type[Model], int, Table | str, type[Model], Iterable[str | None]] + ], + single: bool, + raise_does_not_exist: bool, + select_for_update: bool, + custom_fields: list[str] | None, + ) -> None: + super().__init__(model, query) + self._prefetch_map = prefetch_map + self._prefetch_queries = prefetch_queries + self._select_related_idx = select_related_idx + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._select_for_update = select_for_update + self._custom_fields: list[str] | None = custom_fields + + def _clone(self) -> Self: + queryset = super()._clone() + queryset._prefetch_map = self._prefetch_map + queryset._prefetch_queries = self._prefetch_queries + queryset._select_related_idx = self._select_related_idx + queryset._single = self._single + queryset._raise_does_not_exist = self._raise_does_not_exist + queryset._select_for_update = self._select_for_update + queryset._custom_fields = self._custom_fields + return queryset + + async def execute(self, **params) -> list[MODEL]: + self._choose_db_if_not_chosen(self._select_for_update) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, # type: ignore + ).execute_select( + cached_query.sql, + filled_params, + custom_fields=self._custom_fields, + ) + if self._single: + if len(instance_list) == 1: + return instance_list[0] + if not instance_list: + if self._raise_does_not_exist: + raise DoesNotExist(self.model) + return None # type: ignore + raise MultipleObjectsReturned(self.model) + return instance_list + + +class CompiledUpdateQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(True) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class CompiledDeleteQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(True) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class CompiledExistsQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + result, _ = await self._db.execute_query(cached_query.sql, filled_params) + return bool(result) + + +class CompiledCountQuery(BaseCompiledQuery[MODEL]): + __slots__ = ( + "_limit", + "_offset", + ) + + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + limit: int | None, + offset: int | None, + ) -> None: + super().__init__(model, query) + self._limit = limit or 0 + self._offset = offset or 0 + + def _clone(self) -> Self: + query = super()._clone() + query._limit = self._limit + query._offset = self._offset + return query + + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + if not result: + return 0 + count = list(dict(result[0]).values())[0] - self._offset + if self._limit and count > self._limit: + return self._limit + return count + +""" +class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): + __slots__ = ( + "_cache_key", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + "_db_for_write", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] | list[str], + flat: bool, + annotations: dict[str, Any], + query: QueryBuilder, + cache_key: str, + ) -> None: + super().__init__( + model=model, + db=db, + q_objects=[], + single=single, + raise_does_not_exist=raise_does_not_exist, + fields_for_select_list=fields_for_select_list, + limit=None, + offset=None, + distinct=False, + orderings=[], + flat=flat, + annotations=annotations, + custom_filters={}, + group_bys=(), + force_indexes=set(), + use_indexes=set(), + ) + self.query = query + + self._cache_key: str = cache_key + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] + self._db_for_write: bool = False + + def prepared(self) -> PreparedValuesListQuery[SINGLE]: + return self + + def _clone(self) -> PreparedValuesListQuery[SINGLE]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._capabilities = self._capabilities + + query.fields = self.fields + query._limit = self._limit + query._offset = self._offset + query._distinct = self._distinct + query._orderings = self._orderings + query._custom_filters = self._custom_filters + query._q_objects = self._q_objects + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._fields_for_select_list = self._fields_for_select_list + query._flat = self._flat + query._group_bys = self._group_bys + query._force_indexes = self._force_indexes + query._use_indexes = self._use_indexes + query._fields_to_select_sql = self._fields_to_select_sql + query._annotations = self._annotations + + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + + return query + + async def execute(self, **params) -> list[Any] | tuple: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + return self._process_results(result) + + +class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): + __slots__ = ( + "_cache_key", + "_sql_cache", + "_dynamic_params", + "_dynamic_params_names", + "_db_for_write", + ) + + def __init__( + self, + model: type[MODEL], + db: BaseDBAsyncClient, + single: bool, + raise_does_not_exist: bool, + fields_for_select: dict[str, str], + annotations: dict[str, Any], + query: QueryBuilder, + cache_key: str, + ) -> None: + super().__init__( + model=model, + db=db, + q_objects=[], + single=single, + raise_does_not_exist=raise_does_not_exist, + fields_for_select=fields_for_select, + limit=None, + offset=None, + distinct=False, + orderings=[], + annotations=annotations, + custom_filters={}, + group_bys=(), + force_indexes=set(), + use_indexes=set(), + ) + + self.query = query + self._cache_key: str = cache_key + self._sql_cache: dict[str, CachedSql] = {} + self._dynamic_params: dict[str, CollectionParameter] = {} + self._dynamic_params_names: list[str] = [] + self._db_for_write: bool = False + + def prepared(self) -> PreparedValuesQuery[SINGLE]: + return self + + def _clone(self) -> PreparedValuesQuery[SINGLE]: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = self._db + query._capabilities = self._capabilities + + query._fields_for_select = self._fields_for_select + query._limit = self._limit + query._offset = self._offset + query._distinct = self._distinct + query._orderings = self._orderings + query._custom_filters = self._custom_filters + query._q_objects = self._q_objects + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._db = self._db + query._group_bys = self._group_bys + query._force_indexes = self._force_indexes + query._use_indexes = self._use_indexes + query._annotations = self._annotations + + query._cache_key = self._cache_key + query._db_for_write = self._db_for_write + query._sql_cache = self._sql_cache + query._dynamic_params = self._dynamic_params + query._dynamic_params_names = self._dynamic_params_names + + return query + + async def execute(self, **params) -> list[dict] | dict: + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + result = await self._db.execute_query_dict(cached_query.sql, filled_params) + return self._process_results(result) +""" \ No newline at end of file From a8e72050ba8c89e7dcf8b81a80440c3f288e5419 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 09:34:07 +0200 Subject: [PATCH 45/57] implement CompiledValuesListQuery and CompiledValuesQuery --- tortoise/queryset.py | 126 ++++++++++++++++-------- tortoise/queryset_compiled.py | 174 +++++++++------------------------- 2 files changed, 129 insertions(+), 171 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 06a662da1..47cfd29d8 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model from tortoise.queryset_compiled import CompiledQuerySet, CompiledUpdateQuery, CompiledDeleteQuery, \ - CompiledExistsQuery, CompiledCountQuery + CompiledExistsQuery, CompiledCountQuery, CompiledValuesListQuery, CompiledValuesQuery MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -1665,6 +1665,14 @@ def compile(self, key: str | None = None) -> CompiledCountQuery[MODEL]: return compiled +class FieldsSelectProtocol(Protocol[MODEL]): + model: type[MODEL] + _annotations: dict[str, Any] + + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + ... + + class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 @@ -1732,7 +1740,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"') - def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + def resolve_to_python_value(self: FieldsSelectProtocol[MODEL], model: type[MODEL], field: str) -> Callable: if field in model._meta.fetch_fields: # return as is to get whole model objects return lambda x: x @@ -1777,6 +1785,13 @@ def _resolve_group_bys(self, *field_names: str) -> list: return group_bys +class ValuesListProtocol(FieldsSelectProtocol[MODEL], Protocol[MODEL]): + fields: dict[str, str] + _flat: bool + _single: bool + _raise_does_not_exist: bool + + class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( "fields", @@ -1888,7 +1903,7 @@ async def __aiter__(self: ValuesListQuery[Any]) -> AsyncIterator[Any]: for val in await self: yield val - def _process_results(self, result: Sequence[dict]) -> list[Any] | tuple: + def _process_results(self: ValuesListProtocol, result: Sequence[dict]) -> list[Any] | tuple: columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -1915,25 +1930,42 @@ async def _execute(self) -> list[Any] | tuple: _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) return self._process_results(result) - # TODO: add compiled values list query class - # def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: - # """ - # Compiles query sql. - # :param key: Cache key for saving compiled query to model cache. - # """ - # - # if key in self.model._meta.query_cache: - # cached = self.model._meta.query_cache[key] - # if not isinstance(cached, CompiledValuesListQuery): - # ... # TODO: raise an exception - # return cached._clone() - # - # compiled = CompiledValuesListQuery(model=self.model, query=self.query) - # - # if key is not None: - # self.model._meta.query_cache[key] = compiled - # - # return compiled + def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledValuesListQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledValuesListQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledValuesListQuery( + model=self.model, + query=self.query, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select_list=self._fields_for_select_list, + flat=self._flat, + annotations=self._annotations, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + + +class ValuesProtocol(FieldsSelectProtocol[MODEL], Protocol[MODEL]): + _fields_for_select: dict[str, str] + _single: bool + _raise_does_not_exist: bool class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): @@ -2041,7 +2073,7 @@ async def __aiter__(self: ValuesQuery[Any]) -> AsyncIterator[dict[str, Any]]: for val in await self: yield val - def _process_results(self, result: list[dict]) -> list[dict] | dict: + def _process_results(self: ValuesProtocol, result: list[dict]) -> list[dict] | dict: columns = [ val for val in [ @@ -2070,25 +2102,35 @@ async def _execute(self) -> list[dict] | dict: result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) return self._process_results(result) - # TODO: add compiled values list query class - # def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: - # """ - # Compiles query sql. - # :param key: Cache key for saving compiled query to model cache. - # """ - # - # if key in self.model._meta.query_cache: - # cached = self.model._meta.query_cache[key] - # if not isinstance(cached, CompiledValuesQuery): - # ... # TODO: raise an exception - # return cached._clone() - # - # compiled = CompiledValuesQuery(model=self.model, query=self.query) - # - # if key is not None: - # self.model._meta.query_cache[key] = compiled - # - # return compiled + def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledValuesQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledValuesQuery): + ... # TODO: raise an exception + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledValuesQuery( + model=self.model, + query=self.query, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=self._fields_for_select, + annotations=self._annotations, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled class RawSQLQuery(AwaitableQuery): diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index b2fe7ddc9..73d27f0f7 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -4,14 +4,14 @@ from abc import ABC from collections import defaultdict from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast, Callable from pypika_tortoise.queries import QueryBuilder, Table from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext from tortoise.query_utils import Prefetch -from tortoise.queryset import MODEL, AwaitableQuery, QuerySet +from tortoise.queryset import MODEL, AwaitableQuery, QuerySet, ValuesListQuery, FieldSelectQuery, ValuesQuery if sys.version_info >= (3, 11): # pragma: nocoverage from typing import Self @@ -287,179 +287,95 @@ async def execute(self, **params) -> int: return self._limit return count -""" -class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): + +# TODO: type single query +class CompiledValuesListQuery(BaseCompiledQuery[MODEL]): __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", + "fields", + "_single", + "_raise_does_not_exist", + "_flat", + "_annotations", ) def __init__( self, model: type[MODEL], - db: BaseDBAsyncClient, + query: QueryBuilder, single: bool, raise_does_not_exist: bool, fields_for_select_list: tuple[str, ...] | list[str], flat: bool, annotations: dict[str, Any], - query: QueryBuilder, - cache_key: str, ) -> None: - super().__init__( - model=model, - db=db, - q_objects=[], - single=single, - raise_does_not_exist=raise_does_not_exist, - fields_for_select_list=fields_for_select_list, - limit=None, - offset=None, - distinct=False, - orderings=[], - flat=flat, - annotations=annotations, - custom_filters={}, - group_bys=(), - force_indexes=set(), - use_indexes=set(), - ) - self.query = query - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedValuesListQuery[SINGLE]: - return self + super().__init__(model, query) - def _clone(self) -> PreparedValuesListQuery[SINGLE]: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._db = self._db - query._capabilities = self._capabilities + fields_for_select = {str(i): field for i, field in enumerate(fields_for_select_list)} + self.fields = fields_for_select + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._flat = flat + self._annotations = annotations + def _clone(self) -> Self: + query = super()._clone() query.fields = self.fields - query._limit = self._limit - query._offset = self._offset - query._distinct = self._distinct - query._orderings = self._orderings - query._custom_filters = self._custom_filters - query._q_objects = self._q_objects query._single = self._single query._raise_does_not_exist = self._raise_does_not_exist - query._fields_for_select_list = self._fields_for_select_list query._flat = self._flat - query._group_bys = self._group_bys - query._force_indexes = self._force_indexes - query._use_indexes = self._use_indexes - query._fields_to_select_sql = self._fields_to_select_sql query._annotations = self._annotations - - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - return query + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + return FieldSelectQuery.resolve_to_python_value(self, model, field) + async def execute(self, **params) -> list[Any] | tuple: + self._choose_db_if_not_chosen(False) cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) - _, result = await self._db.execute_query(cached_query.sql, filled_params) - return self._process_results(result) + return ValuesListQuery._process_results(self, result) -class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): +# TODO: type single query +class CompiledValuesQuery(BaseCompiledQuery[MODEL]): __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", + "_single", + "_raise_does_not_exist", + "_fields_for_select", + "_annotations", ) def __init__( self, model: type[MODEL], - db: BaseDBAsyncClient, + query: QueryBuilder, single: bool, raise_does_not_exist: bool, fields_for_select: dict[str, str], annotations: dict[str, Any], - query: QueryBuilder, - cache_key: str, ) -> None: - super().__init__( - model=model, - db=db, - q_objects=[], - single=single, - raise_does_not_exist=raise_does_not_exist, - fields_for_select=fields_for_select, - limit=None, - offset=None, - distinct=False, - orderings=[], - annotations=annotations, - custom_filters={}, - group_bys=(), - force_indexes=set(), - use_indexes=set(), - ) - - self.query = query - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedValuesQuery[SINGLE]: - return self + super().__init__(model, query) - def _clone(self) -> PreparedValuesQuery[SINGLE]: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._db = self._db - query._capabilities = self._capabilities + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._fields_for_select = fields_for_select + self._annotations = annotations - query._fields_for_select = self._fields_for_select - query._limit = self._limit - query._offset = self._offset - query._distinct = self._distinct - query._orderings = self._orderings - query._custom_filters = self._custom_filters - query._q_objects = self._q_objects + def _clone(self) -> Self: + query = super()._clone() query._single = self._single query._raise_does_not_exist = self._raise_does_not_exist - query._db = self._db - query._group_bys = self._group_bys - query._force_indexes = self._force_indexes - query._use_indexes = self._use_indexes + query._fields_for_select = self._fields_for_select query._annotations = self._annotations - - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - return query + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + return FieldSelectQuery.resolve_to_python_value(self, model, field) + async def execute(self, **params) -> list[dict] | dict: + self._choose_db_if_not_chosen(False) cached_query = self._get_or_create_cached_sql(params) filled_params = cached_query.make_filled_params(params) - result = await self._db.execute_query_dict(cached_query.sql, filled_params) - return self._process_results(result) -""" \ No newline at end of file + return ValuesQuery._process_results(self, result) From ff6c2574aaa287f7b99cee2951f3cc7d2eaa122a Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 09:38:41 +0200 Subject: [PATCH 46/57] support Parameters in QuerySet.limit and .offset --- tortoise/queryset.py | 35 +- tortoise/queryset_prepared.py | 1041 --------------------------------- 2 files changed, 25 insertions(+), 1051 deletions(-) delete mode 100644 tortoise/queryset_prepared.py diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 47cfd29d8..09fcfdb14 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -519,32 +519,47 @@ def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: queryset._orderings = self._parse_orderings(orderings) return queryset._as_single() - # TODO: support Parameter arguments - def limit(self, limit: int) -> QuerySet[MODEL]: + @staticmethod + def _validate_limit(value: int) -> int: + if value < 0: + raise ParamsError("Limit should be non-negative number") + return value + + def limit(self, limit: int | Parameter) -> QuerySet[MODEL]: """ Limits QuerySet to given length. :raises ParamsError: Limit should be non-negative number. """ - if limit < 0: - raise ParamsError("Limit should be non-negative number") + if isinstance(limit, int): + self._validate_limit(limit) + elif isinstance(limit, Parameter): + limit.encode = self._validate_limit queryset = self._clone() - queryset._limit = limit + queryset._limit = limit # type: ignore return queryset - # TODO: support Parameter arguments - def offset(self, offset: int) -> QuerySet[MODEL]: + @staticmethod + def _validate_offset(value: int) -> int: + if value < 0: + raise ParamsError("Offset should be non-negative number") + return value + + def offset(self, offset: int | Parameter) -> QuerySet[MODEL]: """ Query offset for QuerySet. :raises ParamsError: Offset should be non-negative number. """ - if offset < 0: - raise ParamsError("Offset should be non-negative number") + + if isinstance(offset, int) and offset < 0: + self._validate_offset(offset) + elif isinstance(offset, Parameter): + offset.encode = self._validate_offset queryset = self._clone() - queryset._offset = offset + queryset._offset = offset # type: ignore if self.capabilities.requires_limit and queryset._limit is None: queryset._limit = 1000000 return queryset diff --git a/tortoise/queryset_prepared.py b/tortoise/queryset_prepared.py deleted file mode 100644 index 957a67b7c..000000000 --- a/tortoise/queryset_prepared.py +++ /dev/null @@ -1,1041 +0,0 @@ -from __future__ import annotations as _ - -import sys -from abc import ABC, abstractmethod -from collections import defaultdict -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Literal, NoReturn, Protocol, TypeVar, cast - -from pypika_tortoise.queries import QueryBuilder, Table -from pypika_tortoise.terms import Term - -from tortoise.backends.base.client import BaseDBAsyncClient -from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned, ParamsError -from tortoise.expressions import Expression, Q -from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext -from tortoise.query_utils import Prefetch -from tortoise.queryset import ( - MODEL, - PRIMARY_KEY, - SINGLE, - AwaitableQuery, - BulkCreateQuery, - BulkUpdateQuery, - QuerySet, - QuerySetSingle, - ValuesListQuery, - ValuesQuery, -) - -if sys.version_info >= (3, 11): # pragma: nocoverage - from typing import Self -else: - from typing_extensions import Self - -if TYPE_CHECKING: - from tortoise import Model - - -T = TypeVar("T") - - -class PreparedQuerySetSingle(QuerySetSingle[T], Protocol[T]): - def prefetch_related( - self, *args: str | Prefetch - ) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - - def select_related(self, *args: str) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - - def annotate( - self, **kwargs: Expression | Term - ) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - - def only(self, *fields_for_select: str) -> PreparedQuerySetSingle[T]: ... # pragma: nocoverage - - def values_list( - self, *fields_: str, flat: bool = False - ) -> PreparedValuesListQuery[Literal[True]]: ... # pragma: nocoverage - - def values( - self, *args: str, **kwargs: str - ) -> PreparedValuesQuery[Literal[True]]: ... # pragma: nocoverage - - def prepared(self) -> PreparedQuerySetSingle[T]: ... - - async def execute(self, **params) -> T: ... - - -class CachedSql: - __slots__ = ( - "sql", - "params", - "param_by_name", - "need_params", - "need_collection_params", - ) - - def __init__(self, sql: str, params: list[Parameter | Any]) -> None: - self.sql = sql - self.params = params - self.param_by_name: dict[str, Parameter] = {} - self.need_params: dict[str, int] = {} - self.need_collection_params: dict[str, list[int]] = defaultdict(list) - for idx, param in enumerate(params): - if not isinstance(param, Parameter): - continue - - if param.name not in self.param_by_name: - self.param_by_name[param.name] = param - - if isinstance(param, CollectionParameter): - self.need_collection_params[param.name].append(idx) - else: - self.need_params[param.name] = idx - - def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - # TODO: check for parameters mismatch - - filled_params = self.params.copy() - for name, idx in self.need_params.items(): - param = self.param_by_name[name] - filled_params[idx] = param.encode_value(params[name]) - - for name, indexes in self.need_collection_params.items(): - param = cast(CollectionParameter, self.param_by_name[name]) - collection = param.encode_collection(params[name]) - if len(collection) != len(indexes): - raise ValueError( - f"Provided value length ({len(collection)}) " - f"for parameter {name!r} does not match " - f"parameter indexes length ({len(indexes)})" - ) - for idx, value in zip(indexes, collection): - filled_params[idx] = param.encode_value(value) - - return filled_params - - -class _PreparedQueryMixin(AwaitableQuery[MODEL], ABC): - _cache_key: str - _sql_cache: dict[str, CachedSql] - _dynamic_params: dict[str, CollectionParameter] - _dynamic_params_names: list[str] - _db_for_write: bool - - @abstractmethod - def _clone(self) -> Self: ... - - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("QuerySets must be prepared only once") - - def prepared(self) -> Self: - return self - - def _init_prepared(self) -> None: - _, params = self.query.get_parameterized_sql() - self._sql_cache = {} - self._dynamic_params = { - param.name: param for param in params if isinstance(param, CollectionParameter) - } - self._dynamic_params_names = sorted(self._dynamic_params.keys()) - self.model._meta.query_cache[self._cache_key] = self - - def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - reset_params = [] - - cache_key = f"{self._db.capabilities.dialect}-query" - for name in self._dynamic_params_names: - value = params[name] - if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue - - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" - param.collection_size = len(value) - reset_params.append(param) - - # TODO: add ability to limit cache, use lru? - if cache_key not in self._sql_cache: - # TODO: probably could be done in a better way? - ctx = TortoiseSqlContext.copy( - self.query.QUERY_CLS.SQL_CONTEXT, - dynamic_params=self._dynamic_params, - ) - sql, params_ = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params_) - - for param in reset_params: - param.collection_size = None - - return self._sql_cache[cache_key] - - @abstractmethod - async def execute(self, **params) -> Any: ... - - def sql(self, params_inline=False, **params) -> str: - cached_query = self._get_or_create_cached_sql(params) - return cached_query.sql - - def filter(self, *args: Q, **kwargs: Any) -> Self: - return self - - def exclude(self, *args: Q, **kwargs: Any) -> Self: - return self - - def order_by(self, *orderings: str) -> Self: - return self - - def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, self) - - def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, self) - - def limit(self, limit: int | Parameter) -> Self: - return self - - def offset(self, offset: int | Parameter) -> Self: - return self - - def __getitem__(self, key: slice) -> Self: - return self - - def distinct(self) -> Self: - return self - - def select_for_update( - self, - nowait: bool = False, - skip_locked: bool = False, - of: tuple[str, ...] = (), - no_key: bool = False, - ) -> Self: - return self - - def annotate(self, **kwargs: Expression | Term) -> Self: - return self - - def group_by(self, *fields: str) -> Self: - return self - - def values_list( - self, *fields_: str, flat: bool = False - ) -> PreparedValuesListQuery[Literal[False]]: - return cast(PreparedValuesListQuery, self) - - def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: - return cast(PreparedValuesQuery, self) - - def delete(self) -> PreparedDeleteQuery: - return cast(PreparedDeleteQuery, self) - - def update(self, **kwargs: Any) -> PreparedUpdateQuery: - return cast(PreparedUpdateQuery, self) - - def count(self) -> PreparedCountQuery: - return cast(PreparedCountQuery, self) - - def exists(self) -> PreparedExistsQuery: - return cast(PreparedExistsQuery, self) - - def all(self) -> Self: - return self - - def first(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, self) - - def last(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, self) - - def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: - return cast(PreparedQuerySetSingle, self) - - async def in_bulk( - self, id_list: Iterable[PRIMARY_KEY], field_name: str - ) -> dict[PRIMARY_KEY, MODEL]: - raise NotImplementedError("Prepared queries don't support in_bulk.") - - def bulk_create( - self, - objects: Iterable[MODEL], - batch_size: int | None = None, - ignore_conflicts: bool = False, - update_fields: Iterable[str] | None = None, - on_conflict: Iterable[str] | None = None, - ) -> BulkCreateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_create.") - - def bulk_update( - self, - objects: Iterable[MODEL], - fields: Iterable[str], - batch_size: int | None = None, - ) -> BulkUpdateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_update.") - - def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, self) - - def only(self, *fields_for_select: str) -> Self: - return self - - def select_related(self, *fields: str) -> Self: - return self - - def force_index(self, *index_names: str) -> Self: - return self - - def use_index(self, *index_names: str) -> Self: - return self - - def prefetch_related(self, *args: str | Prefetch) -> Self: - return self - - -class PreparingQuerySet(QuerySet[MODEL]): - __slots__ = ("_cache_key",) - - def __init__(self, model: type[MODEL], cache_key: str) -> None: - super().__init__(model) - self._cache_key: str = cache_key - - def _clone(self, _new_cls: type[QuerySet[MODEL]] | None = None) -> PreparingQuerySet[MODEL]: - queryset = cast(Self, super()._clone(_new_cls)) - queryset._cache_key = self._cache_key - return cast(PreparingQuerySet, queryset) - - def prepared(self) -> PreparedQuerySet: - if self._cache_key in self.model._meta.query_cache: - return cast(PreparedQuerySet, self.model._meta.query_cache[self._cache_key]) - - self._db = self._choose_db(self._select_for_update) - self._make_query() - - prepared = PreparedQuerySet( - model=self.model, - db=self._db, - query=self.query, - prefetch_map=self._prefetch_map, - prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - select_for_update=self._select_for_update, - custom_fields=list(self._annotations.keys()), - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def prepare_sql(self, key: str) -> NoReturn: - raise NotImplementedError("QuerySets must be prepared only once") - - def filter(self, *args: Q, **kwargs: Any) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().filter(*args, **kwargs)) - - def exclude(self, *args: Q, **kwargs: Any) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().exclude(*args, **kwargs)) - - def order_by(self, *orderings: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().order_by(*orderings)) - - def latest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().latest(*orderings)) - - def earliest(self, *orderings: str) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().earliest(*orderings)) - - @staticmethod - def _validate_limit(value: int) -> int: - if value < 0: - raise ParamsError("Limit should be non-negative number") - return value - - def limit(self, limit: int | Parameter) -> PreparingQuerySet[MODEL]: - if isinstance(limit, int) and limit < 0: - raise ParamsError("Limit should be non-negative number") - elif isinstance(limit, Parameter): - limit.encode = self._validate_limit - - queryset = self._clone() - queryset._limit = limit # type: ignore - return queryset - - @staticmethod - def _validate_offset(value: int) -> int: - if value < 0: - raise ParamsError("Offset should be non-negative number") - return value - - def offset(self, offset: int | Parameter) -> PreparingQuerySet[MODEL]: - if isinstance(offset, int) and offset < 0: - raise ParamsError("Offset should be non-negative number") - elif isinstance(offset, Parameter): - offset.encode = self._validate_offset - - queryset = self._clone() - queryset._offset = offset # type: ignore - if self.capabilities.requires_limit and queryset._limit is None: - queryset._limit = 1000000 - return queryset - - def __getitem__(self, key: slice) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().__getitem__(key)) - - def distinct(self) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().distinct()) - - def select_for_update( - self, - nowait: bool = False, - skip_locked: bool = False, - of: tuple[str, ...] = (), - no_key: bool = False, - ) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().select_for_update(nowait, skip_locked, of, no_key)) - - def annotate(self, **kwargs: Expression | Term) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().annotate(**kwargs)) - - def group_by(self, *fields: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().group_by(*fields)) - - def values_list( - self, *fields_: str, flat: bool = False - ) -> PreparedValuesListQuery[Literal[False]]: - fields_for_select_list = self._get_fields_list_for_select(*fields_) - query = super().values_list(*fields_, flat=flat) - query._db = query._choose_db(True) - query._make_query() - - prepared: PreparedValuesListQuery = PreparedValuesListQuery( - db=query._db, - model=self.model, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - flat=flat, - fields_for_select_list=fields_for_select_list, - annotations=self._annotations, - query=query.query, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def values(self, *args: str, **kwargs: str) -> PreparedValuesQuery[Literal[False]]: - fields_for_select = self._get_fields_for_select(*args, **kwargs) - query = super().values(*args, **kwargs) - query._db = query._choose_db(True) - query._make_query() - - prepared: PreparedValuesQuery = PreparedValuesQuery( - db=query._db, - model=self.model, - single=self._single, - raise_does_not_exist=self._raise_does_not_exist, - fields_for_select=fields_for_select, - annotations=self._annotations, - query=query.query, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def delete(self) -> PreparedDeleteQuery: # type: ignore - query = super().delete() - query._db = query._choose_db(True) - query._make_query() - - prepared = PreparedDeleteQuery( - model=self.model, - db=query._db, - query=query.query, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def update(self, **kwargs: Any) -> PreparedUpdateQuery: # type: ignore - query = super().update(**kwargs) - query._db = query._choose_db(True) - query._make_query() - - prepared = PreparedUpdateQuery( - model=self.model, - db=query._db, - query=query.query, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def count(self) -> PreparedCountQuery: # type: ignore - query = super().count() - query._db = query._choose_db(True) - query._make_query() - - prepared = PreparedCountQuery( - model=self.model, - db=query._db, - query=query.query, - limit=self._limit, - offset=self._offset, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def exists(self) -> PreparedExistsQuery: # type: ignore - query = super().exists() - query._db = query._choose_db(True) - query._make_query() - - prepared = PreparedExistsQuery( - model=self.model, - db=query._db, - query=query.query, - cache_key=self._cache_key, - ) - prepared._init_prepared() - return prepared - - def all(self) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().all()) - - def first(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().first()) - - def last(self) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().last()) - - def get(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL]: - return cast(PreparedQuerySetSingle, super().get(*args, **kwargs)) - - async def in_bulk( - self, id_list: Iterable[PRIMARY_KEY], field_name: str - ) -> dict[PRIMARY_KEY, MODEL]: - raise NotImplementedError("Prepared queries don't support in_bulk.") - - def bulk_create( - self, - objects: Iterable[MODEL], - batch_size: int | None = None, - ignore_conflicts: bool = False, - update_fields: Iterable[str] | None = None, - on_conflict: Iterable[str] | None = None, - ) -> BulkCreateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_create.") - - def bulk_update( - self, - objects: Iterable[MODEL], - fields: Iterable[str], - batch_size: int | None = None, - ) -> BulkUpdateQuery[MODEL]: - raise NotImplementedError("Prepared queries don't support bulk_update.") - - def get_or_none(self, *args: Q, **kwargs: Any) -> PreparedQuerySetSingle[MODEL | None]: - return cast(PreparedQuerySetSingle, super().get_or_none(*args, **kwargs)) - - def only(self, *fields_for_select: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().only(*fields_for_select)) - - def select_related(self, *fields: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().select_related(*fields)) - - def force_index(self, *index_names: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().force_index(*index_names)) - - def use_index(self, *index_names: str) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().use_index(*index_names)) - - def prefetch_related(self, *args: str | Prefetch) -> PreparingQuerySet[MODEL]: - return cast(PreparingQuerySet, super().prefetch_related(*args)) - - -class PreparedQuerySet(_PreparedQueryMixin[MODEL]): - __slots__ = ( - "_cache_key", - "_custom_fields", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - query: QueryBuilder, - db: BaseDBAsyncClient, - prefetch_map: dict[str, set[str | Prefetch]], - prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]], - select_related_idx: list[ - tuple[type[Model], int, Table | str, type[Model], Iterable[str | None]] - ], - single: bool, - raise_does_not_exist: bool, - select_for_update: bool, - custom_fields: list[str] | None, - cache_key: str, - ) -> None: - super().__init__(model) - self._db = db - self._prefetch_map = prefetch_map - self._prefetch_queries = prefetch_queries - self._select_related_idx = select_related_idx - self._single = single - self._raise_does_not_exist = raise_does_not_exist - self._db_for_write = select_for_update - self._custom_fields: list[str] | None = custom_fields - - self.query = query - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - - def prepared(self) -> PreparedQuerySet[MODEL]: - return self - - def _clone(self) -> PreparedQuerySet[MODEL]: - queryset = self.__class__.__new__(self.__class__) - queryset.model = self.model - queryset.query = self.query - queryset._capabilities = self._capabilities - queryset._annotations = self._annotations - - queryset._db = self._db - queryset._prefetch_map = self._prefetch_map - queryset._prefetch_queries = self._prefetch_queries - queryset._select_related_idx = self._select_related_idx - queryset._single = self._single - queryset._raise_does_not_exist = self._raise_does_not_exist - queryset._db_for_write = self._db_for_write - queryset._custom_fields = self._custom_fields - - queryset._cache_key = self._cache_key - queryset._sql_cache = self._sql_cache - queryset._dynamic_params = self._dynamic_params - queryset._dynamic_params_names = self._dynamic_params_names - return queryset - - async def execute(self, **params) -> list[MODEL]: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - self._choose_db_if_not_chosen(self._db_for_write) - instance_list = await self._db.executor_class( - model=self.model, - db=self._db, - prefetch_map=self._prefetch_map, - prefetch_queries=self._prefetch_queries, - select_related_idx=self._select_related_idx, # type: ignore - ).execute_select( - cached_query.sql, - filled_params, - custom_fields=self._custom_fields, - ) - if self._single: - if len(instance_list) == 1: - return instance_list[0] - if not instance_list: - if self._raise_does_not_exist: - raise DoesNotExist(self.model) - return None # type: ignore - raise MultipleObjectsReturned(self.model) - return instance_list - - -class PreparedUpdateQuery(_PreparedQueryMixin): - __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - cache_key: str, - ) -> None: - super().__init__(model) - self._db = db - self.query = query - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = True - - def prepared(self) -> PreparedUpdateQuery: - return self - - def _clone(self) -> PreparedUpdateQuery: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._db = self._db - query._capabilities = self._capabilities - query._annotations = self._annotations - - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - return (await self._db.execute_query(cached_query.sql, filled_params))[0] - - -class PreparedDeleteQuery(_PreparedQueryMixin): - __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - cache_key: str, - ) -> None: - super().__init__(model) - self._db = db - self.query = query - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = True - - def prepared(self) -> PreparedDeleteQuery: - return self - - def _clone(self) -> PreparedDeleteQuery: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._capabilities = self._capabilities - query._annotations = self._annotations - query._db = self._db - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - return (await self._db.execute_query(cached_query.sql, filled_params))[0] - - -class PreparedExistsQuery(_PreparedQueryMixin): - __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - cache_key: str, - ) -> None: - super().__init__(model) - self._db = db - self.query = query - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedExistsQuery: - return self - - def _clone(self) -> PreparedExistsQuery: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._capabilities = self._capabilities - query._annotations = self._annotations - query._db = self._db - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - result, _ = await self._db.execute_query(cached_query.sql, filled_params) - return bool(result) - - -class PreparedCountQuery(_PreparedQueryMixin): - __slots__ = ( - "_limit", - "_offset", - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - query: QueryBuilder, - limit: int | None, - offset: int | None, - cache_key: str, - ) -> None: - super().__init__(model) - self._db = db - self.query = query - self._limit = limit or 0 - self._offset = offset or 0 - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedCountQuery: - return self - - def _clone(self) -> PreparedCountQuery: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._capabilities = self._capabilities - query._annotations = self._annotations - query._db = self._db - query._limit = self._limit - query._offset = self._offset - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> int: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - _, result = await self._db.execute_query(cached_query.sql, filled_params) - if not result: - return 0 - count = list(dict(result[0]).values())[0] - self._offset - if self._limit and count > self._limit: - return self._limit - return count - - -class PreparedValuesListQuery(ValuesListQuery[SINGLE], _PreparedQueryMixin): - __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - single: bool, - raise_does_not_exist: bool, - fields_for_select_list: tuple[str, ...] | list[str], - flat: bool, - annotations: dict[str, Any], - query: QueryBuilder, - cache_key: str, - ) -> None: - super().__init__( - model=model, - db=db, - q_objects=[], - single=single, - raise_does_not_exist=raise_does_not_exist, - fields_for_select_list=fields_for_select_list, - limit=None, - offset=None, - distinct=False, - orderings=[], - flat=flat, - annotations=annotations, - custom_filters={}, - group_bys=(), - force_indexes=set(), - use_indexes=set(), - ) - self.query = query - - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedValuesListQuery[SINGLE]: - return self - - def _clone(self) -> PreparedValuesListQuery[SINGLE]: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._db = self._db - query._capabilities = self._capabilities - - query.fields = self.fields - query._limit = self._limit - query._offset = self._offset - query._distinct = self._distinct - query._orderings = self._orderings - query._custom_filters = self._custom_filters - query._q_objects = self._q_objects - query._single = self._single - query._raise_does_not_exist = self._raise_does_not_exist - query._fields_for_select_list = self._fields_for_select_list - query._flat = self._flat - query._group_bys = self._group_bys - query._force_indexes = self._force_indexes - query._use_indexes = self._use_indexes - query._fields_to_select_sql = self._fields_to_select_sql - query._annotations = self._annotations - - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> list[Any] | tuple: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - _, result = await self._db.execute_query(cached_query.sql, filled_params) - return self._process_results(result) - - -class PreparedValuesQuery(ValuesQuery[SINGLE], _PreparedQueryMixin): - __slots__ = ( - "_cache_key", - "_sql_cache", - "_dynamic_params", - "_dynamic_params_names", - "_db_for_write", - ) - - def __init__( - self, - model: type[MODEL], - db: BaseDBAsyncClient, - single: bool, - raise_does_not_exist: bool, - fields_for_select: dict[str, str], - annotations: dict[str, Any], - query: QueryBuilder, - cache_key: str, - ) -> None: - super().__init__( - model=model, - db=db, - q_objects=[], - single=single, - raise_does_not_exist=raise_does_not_exist, - fields_for_select=fields_for_select, - limit=None, - offset=None, - distinct=False, - orderings=[], - annotations=annotations, - custom_filters={}, - group_bys=(), - force_indexes=set(), - use_indexes=set(), - ) - - self.query = query - self._cache_key: str = cache_key - self._sql_cache: dict[str, CachedSql] = {} - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._db_for_write: bool = False - - def prepared(self) -> PreparedValuesQuery[SINGLE]: - return self - - def _clone(self) -> PreparedValuesQuery[SINGLE]: - query = self.__class__.__new__(self.__class__) - query.model = self.model - query.query = self.query - query._db = self._db - query._capabilities = self._capabilities - - query._fields_for_select = self._fields_for_select - query._limit = self._limit - query._offset = self._offset - query._distinct = self._distinct - query._orderings = self._orderings - query._custom_filters = self._custom_filters - query._q_objects = self._q_objects - query._single = self._single - query._raise_does_not_exist = self._raise_does_not_exist - query._db = self._db - query._group_bys = self._group_bys - query._force_indexes = self._force_indexes - query._use_indexes = self._use_indexes - query._annotations = self._annotations - - query._cache_key = self._cache_key - query._db_for_write = self._db_for_write - query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - - return query - - async def execute(self, **params) -> list[dict] | dict: - cached_query = self._get_or_create_cached_sql(params) - filled_params = cached_query.make_filled_params(params) - - result = await self._db.execute_query_dict(cached_query.sql, filled_params) - return self._process_results(result) From 4ae90043f2bc3f4250c5cb12ce2c5708463b4bda Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 09:50:44 +0200 Subject: [PATCH 47/57] fix style and typing issues --- tests/test_queryset_compiled.py | 78 +++++++++------------------------ tortoise/models.py | 2 +- tortoise/queryset.py | 60 ++++++++++++++++++++----- tortoise/queryset_compiled.py | 19 +++++--- 4 files changed, 83 insertions(+), 76 deletions(-) diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py index c13401182..0cacb8d80 100644 --- a/tests/test_queryset_compiled.py +++ b/tests/test_queryset_compiled.py @@ -21,12 +21,7 @@ async def test_gte_filter(db): expected = await Author.filter(id__gte=author2.pk).order_by("id") - prepared = ( - Author - .filter(id__gte=Parameter("idgte")) - .order_by("id") - .compile("test_gte_filter") - ) + prepared = Author.filter(id__gte=Parameter("idgte")).order_by("id").compile("test_gte_filter") print(prepared.sql(idgte=author2.pk)) actual = await prepared.execute(idgte=author2.pk) assert len(actual) == 2 @@ -56,11 +51,7 @@ async def test_startswith_filter(db): author2 = await Author.create(name="testqwe") author3 = await Author.create(name="qwetest") - prepared = ( - Author - .filter(name__startswith=Parameter("name")) - .compile("test_startswith_filter") - ) + prepared = Author.filter(name__startswith=Parameter("name")).compile("test_startswith_filter") for test_name in (author2.pk, author1.name, author3.name, "asd"): expected = await Author.filter(name__startswith=test_name) @@ -88,15 +79,9 @@ async def test_subqueries(db): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = ( - Author - .filter( - id__in=Subquery( - Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id") - ) - ) - .compile("test_subqueries") - ) + prepared = Author.filter( + id__in=Subquery(Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id")) + ).compile("test_subqueries") for id1, id2 in ( (author2.pk, author1.pk), @@ -115,11 +100,9 @@ async def test_subqueries_in_filter(db): author2 = await Author.create(name="2") author3 = await Author.create(name="3") - prepared = ( - Author - .filter(id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id"))) - .compile("test_subqueries_in_filter") - ) + prepared = Author.filter( + id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id")) + ).compile("test_subqueries_in_filter") for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): expected = await Author.filter(id__in=Subquery(Author.filter(id__in=test_ids).values("id"))) @@ -137,8 +120,7 @@ async def test_update(db): new_name1 = f"{author1.name}_test" prepared = ( - Author - .filter(id=Parameter("search_id")) + Author.filter(id=Parameter("search_id")) .update(name=Parameter("replace_name")) .compile("test_update") ) @@ -161,8 +143,7 @@ async def test_delete(db): author3 = await Author.create(name="3") prepared = ( - Author - .filter( + Author.filter( id__in=Parameter("ids"), ) .delete() @@ -181,8 +162,7 @@ async def test_exists(db): author = await Author.create(name="1") prepared = ( - Author - .filter( + Author.filter( id__in=Parameter("ids"), ) .exists() @@ -200,8 +180,7 @@ async def test_count(db): author3 = await Author.create(name="3") prepared = ( - Author - .filter( + Author.filter( id__gte=Parameter("idgte"), ) .count() @@ -225,11 +204,7 @@ async def test_parameter_in_limit(db): ) prepared = ( - Author - .all() - .limit(Parameter("lim")) - .order_by("id") - .compile("test_parameter_in_limit") + Author.all().limit(Parameter("lim")).order_by("id").compile("test_parameter_in_limit") ) assert len(await prepared.execute(lim=1)) == 1 @@ -252,11 +227,7 @@ async def test_parameter_in_offset(db): ) prepared = ( - Author - .all() - .offset(Parameter("off")) - .order_by("id") - .compile("test_parameter_in_offset") + Author.all().offset(Parameter("off")).order_by("id").compile("test_parameter_in_offset") ) assert len(await prepared.execute(off=1)) == 2 @@ -273,8 +244,7 @@ async def test_values(db): author = await Author.create(name="1") prepared = ( - Author - .filter( + Author.filter( id=Parameter("id"), ) .values() @@ -290,8 +260,7 @@ async def test_values_list_all_fields(db): author = await Author.create(name="1") prepared_all = ( - Author - .filter( + Author.filter( id=Parameter("id"), ) .values_list() @@ -306,8 +275,7 @@ async def test_values_list_only_id_field(db): author = await Author.create(name="1") prepared_ids = ( - Author - .filter( + Author.filter( id=Parameter("id"), ) .values_list("id") @@ -322,8 +290,7 @@ async def test_values_list_only_id_field_flat(db): author = await Author.create(name="1") prepared_ids_flat = ( - Author - .filter( + Author.filter( id=Parameter("id"), ) .values_list("id", flat=True) @@ -341,8 +308,7 @@ async def test_update_fk(db): book = await Book.create(name="test", author=author1, rating=5) prepared = ( - Book - .filter(id=Parameter("search_id")) + Book.filter(id=Parameter("search_id")) .update(author=Parameter("replace_author")) .compile("test_update_fk") ) @@ -364,8 +330,7 @@ async def test_update_pk_invalid_obj(db): book = await Book.create(name="test", author=author, rating=5) prepared = ( - Book - .filter(id=Parameter("search_id")) + Book.filter(id=Parameter("search_id")) .update(author=Parameter("replace_author")) .compile("test_update_pk_invalid_obj") ) @@ -396,8 +361,7 @@ def test_remove_prepared_queryset_from_cache(db): def test_prepared_query_get_sql(db, filter_kwargs: dict[str, Any], cache_key_suffix: str): expected_sql = CharPkModel.all().filter(**filter_kwargs).limit(10).offset(0).sql() actual_sql = ( - CharPkModel - .all() + CharPkModel.all() .filter(**{key: Parameter(key) for key in filter_kwargs}) .limit(10) .offset(0) diff --git a/tortoise/models.py b/tortoise/models.py index db485f3c1..5008b0850 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -257,7 +257,7 @@ def __init__(self, meta: Model.Meta) -> None: self.db_native_fields: list[tuple[str, str, Field]] = [] self.db_default_fields: list[tuple[str, str, Field]] = [] self.db_complex_fields: list[tuple[str, str, Field]] = [] - self.query_cache: dict[str, BaseCompiledQuery[Self]] = {} + self.query_cache: dict[str, BaseCompiledQuery] = {} @property def full_name(self) -> str: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 09fcfdb14..979b36be5 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -45,8 +45,15 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model - from tortoise.queryset_compiled import CompiledQuerySet, CompiledUpdateQuery, CompiledDeleteQuery, \ - CompiledExistsQuery, CompiledCountQuery, CompiledValuesListQuery, CompiledValuesQuery + from tortoise.queryset_compiled import ( + CompiledCountQuery, + CompiledDeleteQuery, + CompiledExistsQuery, + CompiledQuerySet, + CompiledUpdateQuery, + CompiledValuesListQuery, + CompiledValuesQuery, + ) MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -1308,7 +1315,11 @@ def compile(self, key: str | None = None) -> CompiledQuerySet[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledQuerySet): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(self._select_for_update) @@ -1432,7 +1443,11 @@ def compile(self, key: str | None = None) -> CompiledUpdateQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledUpdateQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(True) @@ -1504,7 +1519,11 @@ def compile(self, key: str | None = None) -> CompiledDeleteQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledDeleteQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(True) @@ -1576,7 +1595,11 @@ def compile(self, key: str | None = None) -> CompiledExistsQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledExistsQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(False) @@ -1662,7 +1685,11 @@ def compile(self, key: str | None = None) -> CompiledCountQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledCountQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(False) @@ -1684,8 +1711,7 @@ class FieldsSelectProtocol(Protocol[MODEL]): model: type[MODEL] _annotations: dict[str, Any] - def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: - ... + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: ... class FieldSelectQuery(AwaitableQuery): @@ -1755,7 +1781,9 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"') - def resolve_to_python_value(self: FieldsSelectProtocol[MODEL], model: type[MODEL], field: str) -> Callable: + def resolve_to_python_value( + self: FieldsSelectProtocol[MODEL], model: type[MODEL], field: str + ) -> Callable: if field in model._meta.fetch_fields: # return as is to get whole model objects return lambda x: x @@ -1956,7 +1984,11 @@ def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledValuesListQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(False) @@ -2128,7 +2160,11 @@ def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: if key in self.model._meta.query_cache: cached = self.model._meta.query_cache[key] if not isinstance(cached, CompiledValuesQuery): - ... # TODO: raise an exception + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) return cached._clone() self._choose_db_if_not_chosen(False) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 73d27f0f7..8d364cbd9 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -1,17 +1,24 @@ from __future__ import annotations as _ import sys -from abc import ABC +from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, TypeVar, cast, Callable +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, TypeVar, cast from pypika_tortoise.queries import QueryBuilder, Table from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext from tortoise.query_utils import Prefetch -from tortoise.queryset import MODEL, AwaitableQuery, QuerySet, ValuesListQuery, FieldSelectQuery, ValuesQuery +from tortoise.queryset import ( + MODEL, + AwaitableQuery, + FieldSelectQuery, + QuerySet, + ValuesListQuery, + ValuesQuery, +) if sys.version_info >= (3, 11): # pragma: nocoverage from typing import Self @@ -100,8 +107,8 @@ def _clone(self) -> Self: return query - async def execute(self, **params) -> ...: - ... + @abstractmethod + async def execute(self, **params) -> Any: ... def init_params_table(self) -> None: _, params = self.query.get_parameterized_sql() From dd3bbcb913c6e81ff90bac7d0055bec49fb457e7 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 10:06:52 +0200 Subject: [PATCH 48/57] type single CompiledQuerySet, CompiledValuesListQuery and CompiledValuesQuery --- tortoise/queryset.py | 11 +++++++---- tortoise/queryset_compiled.py | 32 +++++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 979b36be5..b9ff0b4c7 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -50,6 +50,7 @@ CompiledDeleteQuery, CompiledExistsQuery, CompiledQuerySet, + CompiledQuerySetSingle, CompiledUpdateQuery, CompiledValuesListQuery, CompiledValuesQuery, @@ -89,6 +90,8 @@ def values( self, *args: str, **kwargs: str ) -> ValuesQuery[Literal[True]]: ... # pragma: nocoverage + def compile(self, key: str | None = None) -> CompiledQuerySetSingle[T_co]: ... + class AwaitableQuery(Generic[MODEL]): __slots__ = ( @@ -1973,7 +1976,7 @@ async def _execute(self) -> list[Any] | tuple: _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) return self._process_results(result) - def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: + def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL, SINGLE]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -1993,7 +1996,7 @@ def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL]: self._choose_db_if_not_chosen(False) self._make_query() - compiled = CompiledValuesListQuery( + compiled: CompiledValuesListQuery[MODEL, SINGLE] = CompiledValuesListQuery( model=self.model, query=self.query, single=self._single, @@ -2149,7 +2152,7 @@ async def _execute(self) -> list[dict] | dict: result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) return self._process_results(result) - def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: + def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL, SINGLE]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -2169,7 +2172,7 @@ def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL]: self._choose_db_if_not_chosen(False) self._make_query() - compiled = CompiledValuesQuery( + compiled: CompiledValuesQuery[MODEL, SINGLE] = CompiledValuesQuery( model=self.model, query=self.query, single=self._single, diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 8d364cbd9..ef04e4e70 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Iterable -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload from pypika_tortoise.queries import QueryBuilder, Table @@ -13,9 +13,11 @@ from tortoise.query_utils import Prefetch from tortoise.queryset import ( MODEL, + SINGLE, AwaitableQuery, FieldSelectQuery, QuerySet, + T_co, ValuesListQuery, ValuesQuery, ) @@ -29,6 +31,12 @@ from tortoise import Model +class CompiledQuerySetSingle(Protocol[T_co]): + def sql(self, **params) -> str: ... + + async def execute(self, **params) -> MODEL: ... + + T = TypeVar("T") @@ -137,7 +145,6 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: param.collection_size = len(value) reset_params.append(param) - # TODO: add ability to limit cache, use lru? if cache_key not in self._sql_cache: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy( @@ -160,7 +167,6 @@ def sql(self, params_inline=False, **params) -> str: return cached_query.sql -# TODO: type single queries class CompiledQuerySet(BaseCompiledQuery[MODEL]): __slots__ = ( "_prefetch_map", @@ -295,8 +301,7 @@ async def execute(self, **params) -> int: return count -# TODO: type single query -class CompiledValuesListQuery(BaseCompiledQuery[MODEL]): +class CompiledValuesListQuery(BaseCompiledQuery[MODEL], Generic[MODEL, SINGLE]): __slots__ = ( "fields", "_single", @@ -336,6 +341,14 @@ def _clone(self) -> Self: def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: return FieldSelectQuery.resolve_to_python_value(self, model, field) + @overload + async def execute(self: CompiledValuesListQuery[MODEL, Literal[True]], **params) -> tuple: ... + + @overload + async def execute( + self: CompiledValuesListQuery[MODEL, Literal[False]], **params + ) -> list[Any]: ... + async def execute(self, **params) -> list[Any] | tuple: self._choose_db_if_not_chosen(False) cached_query = self._get_or_create_cached_sql(params) @@ -344,8 +357,7 @@ async def execute(self, **params) -> list[Any] | tuple: return ValuesListQuery._process_results(self, result) -# TODO: type single query -class CompiledValuesQuery(BaseCompiledQuery[MODEL]): +class CompiledValuesQuery(BaseCompiledQuery[MODEL], Generic[MODEL, SINGLE]): __slots__ = ( "_single", "_raise_does_not_exist", @@ -380,6 +392,12 @@ def _clone(self) -> Self: def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: return FieldSelectQuery.resolve_to_python_value(self, model, field) + @overload + async def execute(self: CompiledValuesQuery[MODEL, Literal[True]], **params) -> dict: ... + + @overload + async def execute(self: CompiledValuesQuery[MODEL, Literal[False]], **params) -> list[dict]: ... + async def execute(self, **params) -> list[dict] | dict: self._choose_db_if_not_chosen(False) cached_query = self._get_or_create_cached_sql(params) From e82a4be6b2e162c0dd2ad647d6fde3df9d15f1a0 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 10:33:31 +0200 Subject: [PATCH 49/57] add lru cache for queries that contain collections --- tests/test_queryset_compiled.py | 17 ++++++++ tortoise/queryset.py | 63 ++++++++++++++++++++------- tortoise/queryset_compiled.py | 76 +++++++++++++++++++++++++++------ 3 files changed, 127 insertions(+), 29 deletions(-) diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py index 0cacb8d80..f8159229e 100644 --- a/tests/test_queryset_compiled.py +++ b/tests/test_queryset_compiled.py @@ -370,3 +370,20 @@ def test_prepared_query_get_sql(db, filter_kwargs: dict[str, Any], cache_key_suf ) assert expected_sql == actual_sql + + +def test_compiled_query_auto_cache_size(db): + compiled = Author.filter(id=Parameter("id")).compile() + # Trigger sql generation + compiled.sql(id=1) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_SIMPLE + + compiled = Author.filter( + id__in=Subquery(Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2")))) + ).compile() + compiled.sql(id1=1, id2=2) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_SIMPLE + + compiled = Author.filter(id__in=Parameter("ids")).compile() + compiled.sql(ids=[1]) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_COLLECTIONS diff --git a/tortoise/queryset.py b/tortoise/queryset.py index b9ff0b4c7..340f8b7e3 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -90,7 +90,9 @@ def values( self, *args: str, **kwargs: str ) -> ValuesQuery[Literal[True]]: ... # pragma: nocoverage - def compile(self, key: str | None = None) -> CompiledQuerySetSingle[T_co]: ... + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledQuerySetSingle[T_co]: ... class AwaitableQuery(Generic[MODEL]): @@ -384,11 +386,8 @@ def __init__(self, model: type[MODEL]) -> None: self._force_indexes: set[str] = set() self._use_indexes: set[str] = set() - # TODO: remove _new_cls - def _clone(self, _new_cls: type[QuerySet] | None = None) -> QuerySet[MODEL]: - if _new_cls is None: - _new_cls = self.__class__ - queryset = _new_cls.__new__(_new_cls) + def _clone(self) -> QuerySet[MODEL]: + queryset = self.__class__.__new__(self.__class__) queryset.fields = self.fields queryset.model = self.model queryset.query = self.query @@ -1307,7 +1306,9 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list - def compile(self, key: str | None = None) -> CompiledQuerySet[MODEL]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledQuerySet[MODEL]: """ Compiles queryset sql. :param key: Cache key for saving compiled query to model cache. @@ -1330,6 +1331,7 @@ def compile(self, key: str | None = None) -> CompiledQuerySet[MODEL]: compiled = CompiledQuerySet( model=self.model, query=self.query, + sql_cache_maxsize=sql_cache_maxsize, prefetch_map=self._prefetch_map, prefetch_queries=self._prefetch_queries, select_related_idx=self._select_related_idx, @@ -1435,7 +1437,9 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] - def compile(self, key: str | None = None) -> CompiledUpdateQuery[MODEL]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledUpdateQuery[MODEL]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -1455,7 +1459,11 @@ def compile(self, key: str | None = None) -> CompiledUpdateQuery[MODEL]: self._choose_db_if_not_chosen(True) self._make_query() - compiled = CompiledUpdateQuery(model=self.model, query=self.query) + compiled = CompiledUpdateQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) if key is not None: self.model._meta.query_cache[key] = compiled @@ -1511,7 +1519,9 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] - def compile(self, key: str | None = None) -> CompiledDeleteQuery[MODEL]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledDeleteQuery[MODEL]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -1531,7 +1541,11 @@ def compile(self, key: str | None = None) -> CompiledDeleteQuery[MODEL]: self._choose_db_if_not_chosen(True) self._make_query() - compiled = CompiledDeleteQuery(model=self.model, query=self.query) + compiled = CompiledDeleteQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) if key is not None: self.model._meta.query_cache[key] = compiled @@ -1587,7 +1601,9 @@ async def _execute( result, _ = await self._db.execute_query(*self.query.get_parameterized_sql()) return bool(result) - def compile(self, key: str | None = None) -> CompiledExistsQuery[MODEL]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledExistsQuery[MODEL]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -1607,7 +1623,11 @@ def compile(self, key: str | None = None) -> CompiledExistsQuery[MODEL]: self._choose_db_if_not_chosen(False) self._make_query() - compiled = CompiledExistsQuery(model=self.model, query=self.query) + compiled = CompiledExistsQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) if key is not None: self.model._meta.query_cache[key] = compiled @@ -1677,10 +1697,14 @@ async def _execute(self) -> int: return self._limit return count - def compile(self, key: str | None = None) -> CompiledCountQuery[MODEL]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledCountQuery[MODEL]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. + :param sql_cache_maxsize: Maximum cache size for generated sql cache. + Only makes sense for queries that contain collections as a parameters. """ from tortoise.queryset_compiled import CompiledCountQuery @@ -1700,6 +1724,7 @@ def compile(self, key: str | None = None) -> CompiledCountQuery[MODEL]: compiled = CompiledCountQuery( model=self.model, query=self.query, + sql_cache_maxsize=sql_cache_maxsize, limit=self._limit, offset=self._offset, ) @@ -1976,7 +2001,9 @@ async def _execute(self) -> list[Any] | tuple: _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) return self._process_results(result) - def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL, SINGLE]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledValuesListQuery[MODEL, SINGLE]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -1999,6 +2026,7 @@ def compile(self, key: str | None = None) -> CompiledValuesListQuery[MODEL, SING compiled: CompiledValuesListQuery[MODEL, SINGLE] = CompiledValuesListQuery( model=self.model, query=self.query, + sql_cache_maxsize=sql_cache_maxsize, single=self._single, raise_does_not_exist=self._raise_does_not_exist, fields_for_select_list=self._fields_for_select_list, @@ -2152,7 +2180,9 @@ async def _execute(self) -> list[dict] | dict: result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) return self._process_results(result) - def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL, SINGLE]: + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledValuesQuery[MODEL, SINGLE]: """ Compiles query sql. :param key: Cache key for saving compiled query to model cache. @@ -2175,6 +2205,7 @@ def compile(self, key: str | None = None) -> CompiledValuesQuery[MODEL, SINGLE]: compiled: CompiledValuesQuery[MODEL, SINGLE] = CompiledValuesQuery( model=self.model, query=self.query, + sql_cache_maxsize=sql_cache_maxsize, single=self._single, raise_does_not_exist=self._raise_does_not_exist, fields_for_select=self._fields_for_select, diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index ef04e4e70..fbba2ae41 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -2,7 +2,7 @@ import sys from abc import ABC, abstractmethod -from collections import defaultdict +from collections import OrderedDict, defaultdict from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload @@ -90,14 +90,59 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: return filled_params +class _BoundedLRU(Generic[T]): + __slots__ = ("_data", "_maxsize") + + def __init__(self, maxsize: int) -> None: + self._data: OrderedDict[str, T] = OrderedDict() + self._maxsize = maxsize + + @property + def maxsize(self) -> int: + return self._maxsize + + @maxsize.setter + def maxsize(self, value: int) -> None: + self._maxsize = value + + def get(self, key: str) -> T | None: + try: + self._data.move_to_end(key) + return self._data[key] + except KeyError: + return None + + def put(self, key: str, value: T) -> None: + if key in self._data: + self._data.move_to_end(key) + self._data[key] = value + else: + if len(self._data) >= self._maxsize: + self._data.popitem(last=False) + self._data[key] = value + + class BaseCompiledQuery(AwaitableQuery[MODEL], ABC): - __slots__ = ("_sql_cache", "_dynamic_params", "_dynamic_params_names", "_dynamic_params_init") + DEFAULT_CACHE_SIZE_SIMPLE = 1 + DEFAULT_CACHE_SIZE_COLLECTIONS = 128 - def __init__(self, model: type[MODEL], query: QueryBuilder) -> None: + __slots__ = ( + "_sql_cache", + "_sql_cache_maxsize", + "_dynamic_params", + "_dynamic_params_names", + "_dynamic_params_init", + ) + + def __init__( + self, model: type[MODEL], query: QueryBuilder, sql_cache_maxsize: int | None + ) -> None: super().__init__(model) self.query = query - # TODO: use lru - self._sql_cache: dict[str, CachedSql] = {} + self._sql_cache_maxsize = sql_cache_maxsize + self._sql_cache: _BoundedLRU[CachedSql] = _BoundedLRU( + sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_SIMPLE, + ) self._dynamic_params: dict[str, CollectionParameter] = {} self._dynamic_params_names: list[str] = [] self._dynamic_params_init: bool = False @@ -120,12 +165,13 @@ async def execute(self, **params) -> Any: ... def init_params_table(self) -> None: _, params = self.query.get_parameterized_sql() - self._sql_cache = {} self._dynamic_params = { param.name: param for param in params if isinstance(param, CollectionParameter) } self._dynamic_params_names = sorted(self._dynamic_params.keys()) self._dynamic_params_init = True + if self._dynamic_params and self._sql_cache_maxsize is None: + self._sql_cache.maxsize = self.DEFAULT_CACHE_SIZE_COLLECTIONS def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if not self._dynamic_params_init: @@ -145,19 +191,19 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: param.collection_size = len(value) reset_params.append(param) - if cache_key not in self._sql_cache: + if self._sql_cache.get(cache_key) is None: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy( self.query.QUERY_CLS.SQL_CONTEXT, dynamic_params=self._dynamic_params, ) sql, params_ = self.query.get_parameterized_sql(ctx) - self._sql_cache[cache_key] = CachedSql(sql, params_) + self._sql_cache.put(cache_key, CachedSql(sql, params_)) for param in reset_params: param.collection_size = None - return self._sql_cache[cache_key] + return cast(CachedSql, self._sql_cache.get(cache_key)) def sql(self, params_inline=False, **params) -> str: old_db = self._db @@ -182,6 +228,7 @@ def __init__( self, model: type[MODEL], query: QueryBuilder, + sql_cache_maxsize: int | None, prefetch_map: dict[str, set[str | Prefetch]], prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]], select_related_idx: list[ @@ -192,7 +239,7 @@ def __init__( select_for_update: bool, custom_fields: list[str] | None, ) -> None: - super().__init__(model, query) + super().__init__(model, query, sql_cache_maxsize) self._prefetch_map = prefetch_map self._prefetch_queries = prefetch_queries self._select_related_idx = select_related_idx @@ -274,10 +321,11 @@ def __init__( self, model: type[MODEL], query: QueryBuilder, + sql_cache_maxsize: int | None, limit: int | None, offset: int | None, ) -> None: - super().__init__(model, query) + super().__init__(model, query, sql_cache_maxsize) self._limit = limit or 0 self._offset = offset or 0 @@ -314,13 +362,14 @@ def __init__( self, model: type[MODEL], query: QueryBuilder, + sql_cache_maxsize: int | None, single: bool, raise_does_not_exist: bool, fields_for_select_list: tuple[str, ...] | list[str], flat: bool, annotations: dict[str, Any], ) -> None: - super().__init__(model, query) + super().__init__(model, query, sql_cache_maxsize) fields_for_select = {str(i): field for i, field in enumerate(fields_for_select_list)} self.fields = fields_for_select @@ -369,12 +418,13 @@ def __init__( self, model: type[MODEL], query: QueryBuilder, + sql_cache_maxsize: int | None, single: bool, raise_does_not_exist: bool, fields_for_select: dict[str, str], annotations: dict[str, Any], ) -> None: - super().__init__(model, query) + super().__init__(model, query, sql_cache_maxsize) self._single = single self._raise_does_not_exist = raise_does_not_exist From 839970af0b7536e45aec98951ab7145e031d37e3 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 10:39:37 +0200 Subject: [PATCH 50/57] set _db to None when cloning compiled queries --- tortoise/queryset_compiled.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index fbba2ae41..d8c3df01a 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -151,6 +151,7 @@ def _clone(self) -> Self: query = self.__class__.__new__(self.__class__) query.model = self.model query.query = self.query + query._db = None # type: ignore query._capabilities = self._capabilities query._annotations = self._annotations From ebd33cbe56335afd6e42cb9cd80f89076afba558 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Thu, 5 Mar 2026 10:41:43 +0200 Subject: [PATCH 51/57] copy _sql_cache_maxsize and _dynamic_params_init when cloning compiled queries --- tortoise/queryset_compiled.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index d8c3df01a..0f4132e35 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -155,9 +155,11 @@ def _clone(self) -> Self: query._capabilities = self._capabilities query._annotations = self._annotations + query._sql_cache_maxsize = self._sql_cache_maxsize query._sql_cache = self._sql_cache query._dynamic_params = self._dynamic_params query._dynamic_params_names = self._dynamic_params_names + query._dynamic_params_init = self._dynamic_params_init return query From 7fc32eb63acef8a2423d612787c130a4d54cf3ad Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 9 Mar 2026 13:44:02 +0200 Subject: [PATCH 52/57] fix compiled queries using model object when filtering by foreign key instead of model pk --- tests/test_queryset_compiled.py | 15 ++++++++++++++- tortoise/expressions.py | 5 ++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py index f8159229e..9b6354077 100644 --- a/tests/test_queryset_compiled.py +++ b/tests/test_queryset_compiled.py @@ -22,7 +22,6 @@ async def test_gte_filter(db): expected = await Author.filter(id__gte=author2.pk).order_by("id") prepared = Author.filter(id__gte=Parameter("idgte")).order_by("id").compile("test_gte_filter") - print(prepared.sql(idgte=author2.pk)) actual = await prepared.execute(idgte=author2.pk) assert len(actual) == 2 assert actual[0].pk == author2.pk @@ -387,3 +386,17 @@ def test_compiled_query_auto_cache_size(db): compiled = Author.filter(id__in=Parameter("ids")).compile() compiled.sql(ids=[1]) assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_COLLECTIONS + + +@pytest.mark.asyncio +async def test_filter_by_model(db): + author1 = await Author.create(name="1") + book1 = await Book.create(name="test 1", author=author1, rating=5) + author2 = await Author.create(name="2") + book2 = await Book.create(name="test 2", author=author2, rating=3) + + compiled = Book.filter(author=Parameter("author")).compile() + book = await compiled.execute(author=author1) + assert book == [book1] + book = await compiled.execute(author=author2) + assert book == [book2] diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 8d833229e..c8ace8fdb 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -22,7 +22,7 @@ from tortoise.exceptions import FieldError, OperationalError from tortoise.fields.base import Field from tortoise.fields.data import JSONField -from tortoise.fields.relational import RelationalField +from tortoise.fields.relational import ForeignKeyFieldInstance, RelationalField from tortoise.filters import FilterInfoDict from tortoise.parameter import Parameter from tortoise.query_utils import ( @@ -408,6 +408,9 @@ def _process_filter_kwarg( field_object = model._meta.fields_map[filter_info["field"]] value_encoder = filter_info["value_encoder"] if "value_encoder" in filter_info else None if isinstance(value, Parameter): + if isinstance(field_object.reference, ForeignKeyFieldInstance): + fk_to_field = field_object.reference.to_field + value.value_getter = lambda obj: getattr(obj, fk_to_field) value.field_object = field_object value.value_encoder = value_encoder else: From e6af426880dc09a3f5db3500a5fd599aa78bfa97 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Mon, 9 Mar 2026 15:28:09 +0200 Subject: [PATCH 53/57] raise error if compiled parameters are not provided --- tests/test_queryset_compiled.py | 17 +++++++++++++++++ tortoise/queryset_compiled.py | 25 ++++++++++++++++--------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py index 9b6354077..fd575aa7e 100644 --- a/tests/test_queryset_compiled.py +++ b/tests/test_queryset_compiled.py @@ -400,3 +400,20 @@ async def test_filter_by_model(db): assert book == [book1] book = await compiled.execute(author=author2) assert book == [book2] + + +def test_collection_parameter_got_not_collection(db): + compiled = Author.filter(id__in=Parameter("ids")).compile() + with pytest.raises(ValueError): + compiled.sql(ids=123) + + +@pytest.mark.asyncio +async def test_missing_parameters(db): + compiled = Author.filter( + id__in=Parameter("ids"), name__startswith=Parameter("name_prefix") + ).compile() + with pytest.raises(KeyError): + await compiled.execute(ids=[123]) + with pytest.raises(KeyError): + await compiled.execute(name_prefix="test") diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 0f4132e35..54ede3618 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -68,7 +68,21 @@ def __init__(self, sql: str, params: list[Parameter | Any]) -> None: self.need_params[param.name] = idx def make_filled_params(self, params: dict[str, Any]) -> list[Any]: - # TODO: check for parameters mismatch + for name in self.need_params: + if name not in params: + raise KeyError(f'Expected parameter "{name}" is not provided!') + + for name, indexes in self.need_collection_params.items(): + if name not in params: + raise KeyError(f'Expected parameter "{name}" is not provided!') + collection_length = len(params[name]) + param_length = len(params[name]) + if collection_length != param_length: + raise ValueError( + f"Provided value length ({collection_length}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({param_length})" + ) filled_params = self.params.copy() for name, idx in self.need_params.items(): @@ -78,12 +92,6 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: for name, indexes in self.need_collection_params.items(): param = cast(CollectionParameter, self.param_by_name[name]) collection = param.encode_collection(params[name]) - if len(collection) != len(indexes): - raise ValueError( - f"Provided value length ({len(collection)}) " - f"for parameter {name!r} does not match " - f"parameter indexes length ({len(indexes)})" - ) for idx, value in zip(indexes, collection): filled_params[idx] = param.encode_value(value) @@ -186,8 +194,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: for name in self._dynamic_params_names: value = params[name] if not isinstance(value, (tuple, list, set)): - # TODO: raise exception? - continue + raise ValueError(f'Expected parameter "{name}" to be a collection, got {value!r}') param = self._dynamic_params[name] cache_key += f"-{name}{len(value)}" From 61d3a3168058ac3d3c9a8a9c492f986ad5967321 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 10 Mar 2026 13:46:06 +0200 Subject: [PATCH 54/57] calculate collection_params in BaseCompiledQuery constructor --- tortoise/queryset_compiled.py | 65 ++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 54ede3618..449b672f6 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -131,15 +131,13 @@ def put(self, key: str, value: T) -> None: class BaseCompiledQuery(AwaitableQuery[MODEL], ABC): - DEFAULT_CACHE_SIZE_SIMPLE = 1 + DEFAULT_CACHE_SIZE_SIMPLE = 2 DEFAULT_CACHE_SIZE_COLLECTIONS = 128 __slots__ = ( "_sql_cache", - "_sql_cache_maxsize", - "_dynamic_params", - "_dynamic_params_names", - "_dynamic_params_init", + "_collection_params", + "_collection_params_names", ) def __init__( @@ -147,13 +145,19 @@ def __init__( ) -> None: super().__init__(model) self.query = query - self._sql_cache_maxsize = sql_cache_maxsize - self._sql_cache: _BoundedLRU[CachedSql] = _BoundedLRU( - sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_SIMPLE, - ) - self._dynamic_params: dict[str, CollectionParameter] = {} - self._dynamic_params_names: list[str] = [] - self._dynamic_params_init: bool = False + self._sql_cache: _BoundedLRU[CachedSql] = _BoundedLRU(0) + self._collection_params: dict[str, CollectionParameter] = {} + self._collection_params_names: list[str] = [] + + sql, params = self.query.get_parameterized_sql() + self._collection_params = { + param.name: param for param in params if isinstance(param, CollectionParameter) + } + if self._collection_params: + self._collection_params_names = sorted(self._collection_params.keys()) + self._sql_cache.maxsize = sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_COLLECTIONS + else: + self._sql_cache.maxsize = sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_SIMPLE def _clone(self) -> Self: query = self.__class__.__new__(self.__class__) @@ -163,41 +167,38 @@ def _clone(self) -> Self: query._capabilities = self._capabilities query._annotations = self._annotations - query._sql_cache_maxsize = self._sql_cache_maxsize query._sql_cache = self._sql_cache - query._dynamic_params = self._dynamic_params - query._dynamic_params_names = self._dynamic_params_names - query._dynamic_params_init = self._dynamic_params_init + query._collection_params = self._collection_params + query._collection_params_names = self._collection_params_names return query @abstractmethod async def execute(self, **params) -> Any: ... - def init_params_table(self) -> None: - _, params = self.query.get_parameterized_sql() - self._dynamic_params = { - param.name: param for param in params if isinstance(param, CollectionParameter) - } - self._dynamic_params_names = sorted(self._dynamic_params.keys()) - self._dynamic_params_init = True - if self._dynamic_params and self._sql_cache_maxsize is None: - self._sql_cache.maxsize = self.DEFAULT_CACHE_SIZE_COLLECTIONS + def _get_or_create_cached_sql_simple(self) -> CachedSql: + cache_key = self._db.capabilities.dialect + if (cached := self._sql_cache.get(cache_key)) is None: + cached = CachedSql(*self.query.get_parameterized_sql()) + self._sql_cache.put(cache_key, cached) + return cached def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: - if not self._dynamic_params_init: - self.init_params_table() + if not self._collection_params: + return self._get_or_create_cached_sql_simple() + + cache_key = self._db.capabilities.dialect reset_params = [] - cache_key = f"{self._db.capabilities.dialect}-query" - for name in self._dynamic_params_names: + cache_key_parts = [] + for name in self._collection_params_names: value = params[name] if not isinstance(value, (tuple, list, set)): raise ValueError(f'Expected parameter "{name}" to be a collection, got {value!r}') - param = self._dynamic_params[name] - cache_key += f"-{name}{len(value)}" + param = self._collection_params[name] + cache_key_parts.append(f"-{name}:{len(value)}") param.collection_size = len(value) reset_params.append(param) @@ -205,7 +206,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy( self.query.QUERY_CLS.SQL_CONTEXT, - dynamic_params=self._dynamic_params, + dynamic_params=self._collection_params, ) sql, params_ = self.query.get_parameterized_sql(ctx) self._sql_cache.put(cache_key, CachedSql(sql, params_)) From 8931d741b368132e617dfa8765301374fee4b6bd Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 10 Mar 2026 13:49:31 +0200 Subject: [PATCH 55/57] set collection params' sizes when sql was not in cache --- tortoise/queryset_compiled.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 449b672f6..4fadf5bbb 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -187,22 +187,23 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if not self._collection_params: return self._get_or_create_cached_sql_simple() - cache_key = self._db.capabilities.dialect - - reset_params = [] - cache_key_parts = [] for name in self._collection_params_names: value = params[name] if not isinstance(value, (tuple, list, set)): raise ValueError(f'Expected parameter "{name}" to be a collection, got {value!r}') - param = self._collection_params[name] - cache_key_parts.append(f"-{name}:{len(value)}") - param.collection_size = len(value) - reset_params.append(param) + cache_key_parts.append(f"{name}:{len(value)}") + + cache_key = f"{self._db.capabilities.dialect}|{'-'.join(cache_key_parts)}" if self._sql_cache.get(cache_key) is None: + reset_params = [] + for name in self._collection_params_names: + param = self._collection_params[name] + param.collection_size = len(params[name]) + reset_params.append(param) + # TODO: probably could be done in a better way? ctx = TortoiseSqlContext.copy( self.query.QUERY_CLS.SQL_CONTEXT, @@ -211,8 +212,8 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: sql, params_ = self.query.get_parameterized_sql(ctx) self._sql_cache.put(cache_key, CachedSql(sql, params_)) - for param in reset_params: - param.collection_size = None + for param in reset_params: + param.collection_size = None return cast(CachedSql, self._sql_cache.get(cache_key)) From 513e3215f06b01d75a9626ca48f490675d14af29 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Tue, 10 Mar 2026 13:51:24 +0200 Subject: [PATCH 56/57] don't store collection param name in sql cache, only length --- tortoise/queryset_compiled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 4fadf5bbb..8dec047dd 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -193,7 +193,7 @@ def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: if not isinstance(value, (tuple, list, set)): raise ValueError(f'Expected parameter "{name}" to be a collection, got {value!r}') - cache_key_parts.append(f"{name}:{len(value)}") + cache_key_parts.append(str(len(value))) cache_key = f"{self._db.capabilities.dialect}|{'-'.join(cache_key_parts)}" From 472f36ef9ed93f01cfe14e00f9fa167cb5e22ff0 Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Fri, 13 Mar 2026 21:34:21 +0200 Subject: [PATCH 57/57] turn CollectionParameter into Criterion and generate 1=1 or 1=0 if provided parameter is empty (mirror `filters.is_in` and `filters.not_in` behaviour) --- tests/test_queryset_compiled.py | 32 +++++++++++++ tortoise/filters.py | 4 +- tortoise/parameter.py | 81 ++++++++++++++++++++++++--------- tortoise/queryset_compiled.py | 4 +- 4 files changed, 96 insertions(+), 25 deletions(-) diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py index fd575aa7e..c422544a2 100644 --- a/tests/test_queryset_compiled.py +++ b/tests/test_queryset_compiled.py @@ -417,3 +417,35 @@ async def test_missing_parameters(db): await compiled.execute(ids=[123]) with pytest.raises(KeyError): await compiled.execute(name_prefix="test") + + +@pytest.mark.asyncio +async def test_parameter_in_empty(db): + await Author.create(name="1") + await Author.create(name="2") + + compiled = Author.filter(name__in=Parameter("names")).compile() + authors = await compiled.execute(names=[]) + assert not authors + + +@pytest.mark.asyncio +async def test_parameter_not_in_empty(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + compiled = Author.filter(name__not_in=Parameter("names")).order_by("id").compile() + authors = await compiled.execute(names=[]) + assert authors == [author1, author2] + + +def test_parameter_in_empty_sql(db): + compiled = Author.filter(name__in=Parameter("names")).compile() + compiled_sql = compiled.sql(names=[]) + assert "IN ()" not in compiled_sql + + +def test_parameter_not_in_empty_sql(db): + compiled = Author.filter(name__not_in=Parameter("names")).order_by("id").compile() + compiled_sql = compiled.sql(names=[]) + assert "NOT IN ()" not in compiled_sql diff --git a/tortoise/filters.py b/tortoise/filters.py index 22657fef3..e70e7171b 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -104,7 +104,7 @@ def array_encoder(value: Any | Sequence[Any], instance: Model, field: Field) -> def is_in(field: Term, value: Any) -> Criterion: if value: if isinstance(value, Parameter): - value = CollectionParameter.from_simple_param(value) + return CollectionParameter(field, value, True) return field.isin(value) # SQL has no False, so we return 1=0 return BasicCriterion( @@ -117,7 +117,7 @@ def is_in(field: Term, value: Any) -> Criterion: def not_in(field: Term, value: Any) -> Criterion: if value: if isinstance(value, Parameter): - value = CollectionParameter.from_simple_param(value) + return CollectionParameter(field, value, False) return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 return BasicCriterion( diff --git a/tortoise/parameter.py b/tortoise/parameter.py index 405460123..dd452e76d 100644 --- a/tortoise/parameter.py +++ b/tortoise/parameter.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast from pypika_tortoise import SqlContext +from pypika_tortoise.enums import Equality +from pypika_tortoise.terms import BasicCriterion, Criterion, NullCriterion, Term, ValueWrapper from tortoise.fields import Field @@ -109,25 +111,36 @@ def encode_value(self, value: Any) -> Any: return encoded -class CollectionParameter(Parameter): +class CollectionParameter(Parameter, Criterion): + IS_IN_EMPTY = BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(0, allow_parametrize=False), + ) + IS_NOT_IN_EMPTY = BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(1, allow_parametrize=False), + ) + __slots__ = ( + "term", "collection_size", "collection_encoder", + "is_in", ) - def __init__(self, name: str) -> None: - super().__init__(name) + def __init__(self, term: Term, param: Parameter, is_in: bool) -> None: + super().__init__(param.name) + self.model = param.model + self.value_encoder = param.value_encoder + self.field_object = param.field_object + self.encode = param.encode + + self.term = term self.collection_size: int | None = None self.collection_encoder: FieldEncoder[Sequence[Any]] | None = None - - @classmethod - def from_simple_param(cls, param: Parameter) -> Self: - new_param = cls(param.name) - new_param.model = param.model - new_param.value_encoder = param.value_encoder - new_param.field_object = param.field_object - new_param.encode = param.encode - return new_param + self.is_in = is_in def encode_collection(self, value: Any) -> Sequence[Any]: if self.collection_encoder is None: @@ -142,18 +155,42 @@ def get_sql(self, ctx: SqlContext) -> str: if ctx.parameterizer is None: raise ValueError("Parametrization must be enabled when using tortoise.Parameter.") + term_sql = self.term.get_sql(ctx) + not_ = "" if self.is_in else "NOT " + fmt = "{term} {not_}IN {container}" + param = self if isinstance(ctx, TortoiseSqlContext) and ctx.dynamic_params is not None: param = ctx.dynamic_params.get(self.name, self) if param.collection_size is None: - return ctx.parameterizer.create_param(param).get_sql(ctx) - else: - placeholders = [] - for idx in range(param.collection_size): - new_param = param.clone() - new_param.collection_encoder = new_param.value_encoder - new_param.value_encoder = None - pypika_param = ctx.parameterizer.create_param(new_param) - placeholders.append(pypika_param.get_sql(ctx)) - return f"({','.join(placeholders)})" + return fmt.format( + term=term_sql, + container=ctx.parameterizer.create_param(param).get_sql(ctx), + not_=not_, + ) + + if not param.collection_size: + if self.is_in: + return self.IS_IN_EMPTY.get_sql(ctx) + return self.IS_NOT_IN_EMPTY.get_sql(ctx) + + placeholders = [] + for idx in range(param.collection_size): + new_param = param.clone() + new_param.collection_encoder = new_param.value_encoder + new_param.value_encoder = None + pypika_param = ctx.parameterizer.create_param(new_param) + placeholders.append(pypika_param.get_sql(ctx)) + + sql = fmt.format( + term=term_sql, + container=f"({','.join(placeholders)})", + not_=not_, + ) + + if not self.is_in: + null_crit = NullCriterion(self.term) + sql = f"({sql} OR {null_crit})" + + return sql diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py index 8dec047dd..7aa833182 100644 --- a/tortoise/queryset_compiled.py +++ b/tortoise/queryset_compiled.py @@ -76,13 +76,15 @@ def make_filled_params(self, params: dict[str, Any]) -> list[Any]: if name not in params: raise KeyError(f'Expected parameter "{name}" is not provided!') collection_length = len(params[name]) - param_length = len(params[name]) + param_length = len(indexes) if collection_length != param_length: raise ValueError( f"Provided value length ({collection_length}) " f"for parameter {name!r} does not match " f"parameter indexes length ({param_length})" ) + # if not collection_length: + # raise ValueError("Parameter must not be empty!") filled_params = self.params.copy() for name, idx in self.need_params.items():