diff --git a/tests/test_queryset_compiled.py b/tests/test_queryset_compiled.py new file mode 100644 index 000000000..c422544a2 --- /dev/null +++ b/tests/test_queryset_compiled.py @@ -0,0 +1,451 @@ +from typing import Any + +import pytest + +from tests.testmodels import Author, Book, CharPkModel +from tortoise.exceptions import ParamsError, ValidationError +from tortoise.expressions import Q, Subquery +from tortoise.parameter import Parameter + + +def test_prepared_queryset_query_always_same(db): + cache_key = "test_prepared_queryset_always_same" + prepared = Author.filter(id=Parameter("some_param")).compile(cache_key) + assert Author.all().compile(cache_key).query is prepared.query + + +@pytest.mark.asyncio +async def test_gte_filter(db): + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + expected = await Author.filter(id__gte=author2.pk).order_by("id") + + prepared = Author.filter(id__gte=Parameter("idgte")).order_by("id").compile("test_gte_filter") + actual = await prepared.execute(idgte=author2.pk) + assert len(actual) == 2 + assert actual[0].pk == author2.pk + assert actual[1].pk == author3.pk + assert expected == actual + + +@pytest.mark.asyncio +async def test_string_param(db): + await Author.create(name="1") + author2 = await Author.create(name="2") + await Author.create(name="3") + + expected = await Author.filter(name=author2.name) + + prepared = Author.filter(name=Parameter("name")).compile("test_string_param") + actual = await prepared.execute(name=author2.name) + assert len(actual) == 1 + assert actual[0].pk == author2.pk + assert expected == actual + + +@pytest.mark.asyncio +async def test_startswith_filter(db): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") + + prepared = Author.filter(name__startswith=Parameter("name")).compile("test_startswith_filter") + + for test_name in (author2.pk, author1.name, author3.name, "asd"): 
+ expected = await Author.filter(name__startswith=test_name) + actual = await prepared.execute(name=test_name) + assert expected == actual + + +@pytest.mark.asyncio +async def test_in_filter(db): + author1 = await Author.create(name="test") + author2 = await Author.create(name="testqwe") + author3 = await Author.create(name="qwetest") + + prepared = Author.filter(id__in=Parameter("ids")).compile("test_in_filter") + + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): + expected = await Author.filter(id__in=test_ids) + actual = await prepared.execute(ids=test_ids) + assert expected == actual + + +@pytest.mark.asyncio +async def test_subqueries(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.filter( + id__in=Subquery(Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2"))).values("id")) + ).compile("test_subqueries") + + for id1, id2 in ( + (author2.pk, author1.pk), + (author3.pk, author3.pk * 2), + ): + expected = await Author.filter( + id__in=Subquery(Author.filter(Q(id=id1) | Q(id=id2)).values("id")) + ) + actual = await prepared.execute(id1=id1, id2=id2) + assert expected == actual + + +@pytest.mark.asyncio +async def test_subqueries_in_filter(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = Author.filter( + id__in=Subquery(Author.filter(id__in=Parameter("ids")).values("id")) + ).compile("test_subqueries_in_filter") + + for test_ids in ([author2.pk, author1.pk], [author3.pk, author3.pk * 2, author3.pk * 10]): + expected = await Author.filter(id__in=Subquery(Author.filter(id__in=test_ids).values("id"))) + actual = await prepared.execute(ids=test_ids) + assert expected == actual + + +@pytest.mark.asyncio +async def test_update(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + 
original_name1 = author1.name + original_name2 = author2.name + new_name1 = f"{author1.name}_test" + + prepared = ( + Author.filter(id=Parameter("search_id")) + .update(name=Parameter("replace_name")) + .compile("test_update") + ) + + await prepared.execute(search_id=author1.pk, replace_name=new_name1) + await author1.refresh_from_db(["name"]) + await author2.refresh_from_db(["name"]) + assert author1.name == new_name1 + assert author2.name == original_name2 + + await prepared.execute(search_id=author1.pk, replace_name=original_name1) + await author1.refresh_from_db(["name"]) + assert author1.name == original_name1 + + +@pytest.mark.asyncio +async def test_delete(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = ( + Author.filter( + id__in=Parameter("ids"), + ) + .delete() + .compile("test_delete") + ) + + affected = await prepared.execute(ids=[author1.pk]) + assert affected == 1 + assert await Author.all().count() == 2 + existing = await Author.all().values_list("id", flat=True) + assert set(existing) == {author2.pk, author3.pk} + + +@pytest.mark.asyncio +async def test_exists(db): + author = await Author.create(name="1") + + prepared = ( + Author.filter( + id__in=Parameter("ids"), + ) + .exists() + .compile("test_exists") + ) + + assert await prepared.execute(ids=[author.pk]) + assert not await prepared.execute(ids=[author.pk * 2]) + + +@pytest.mark.asyncio +async def test_count(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + author3 = await Author.create(name="3") + + prepared = ( + Author.filter( + id__gte=Parameter("idgte"), + ) + .count() + .compile("test_count") + ) + + assert await prepared.execute(idgte=author1.pk) == 3 + assert await prepared.execute(idgte=author2.pk) == 2 + assert await prepared.execute(idgte=author3.pk) == 1 + assert await prepared.execute(idgte=author3.pk * 2) == 0 + + +@pytest.mark.asyncio +async 
def test_parameter_in_limit(db): + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) + + prepared = ( + Author.all().limit(Parameter("lim")).order_by("id").compile("test_parameter_in_limit") + ) + + assert len(await prepared.execute(lim=1)) == 1 + assert len(await prepared.execute(lim=2)) == 2 + assert len(await prepared.execute(lim=3)) == 3 + assert len(await prepared.execute(lim=4)) == 3 + + with pytest.raises(ParamsError): + await prepared.execute(lim=-1) + + +@pytest.mark.asyncio +async def test_parameter_in_offset(db): + await Author.bulk_create( + [ + Author(name="1"), + Author(name="2"), + Author(name="3"), + ] + ) + + prepared = ( + Author.all().offset(Parameter("off")).order_by("id").compile("test_parameter_in_offset") + ) + + assert len(await prepared.execute(off=1)) == 2 + assert len(await prepared.execute(off=2)) == 1 + assert len(await prepared.execute(off=3)) == 0 + assert len(await prepared.execute(off=4)) == 0 + + with pytest.raises(ParamsError): + await prepared.execute(off=-1) + + +@pytest.mark.asyncio +async def test_values(db): + author = await Author.create(name="1") + + prepared = ( + Author.filter( + id=Parameter("id"), + ) + .values() + .compile("test_values") + ) + + assert await prepared.execute(id=author.pk) == [{"id": author.pk, "name": author.name}] + assert await prepared.execute(id=author.pk * 2) == [] + + +@pytest.mark.asyncio +async def test_values_list_all_fields(db): + author = await Author.create(name="1") + + prepared_all = ( + Author.filter( + id=Parameter("id"), + ) + .values_list() + .compile("test_values_list_all_fields") + ) + assert await prepared_all.execute(id=author.pk) == [(author.pk, author.name)] + assert await prepared_all.execute(id=author.pk * 2) == [] + + +@pytest.mark.asyncio +async def test_values_list_only_id_field(db): + author = await Author.create(name="1") + + prepared_ids = ( + Author.filter( + id=Parameter("id"), + ) + .values_list("id") + 
.compile("test_values_list_only_id_field") + ) + assert await prepared_ids.execute(id=author.pk) == [(author.pk,)] + assert await prepared_ids.execute(id=author.pk * 2) == [] + + +@pytest.mark.asyncio +async def test_values_list_only_id_field_flat(db): + author = await Author.create(name="1") + + prepared_ids_flat = ( + Author.filter( + id=Parameter("id"), + ) + .values_list("id", flat=True) + .compile("test_values_list_only_id_field_flat") + ) + assert await prepared_ids_flat.execute(id=author.pk) == [author.pk] + assert await prepared_ids_flat.execute(id=author.pk * 2) == [] + + +@pytest.mark.asyncio +async def test_update_fk(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + book = await Book.create(name="test", author=author1, rating=5) + + prepared = ( + Book.filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .compile("test_update_fk") + ) + + await prepared.execute(search_id=book.pk, replace_author=author2) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + assert book.author == author2 + + await prepared.execute(search_id=book.pk, replace_author=author1) + book = await Book.get(id=book.pk).select_related("author") + # await book.refresh_from_db(["author_id"]) + assert book.author == author1 + + +@pytest.mark.asyncio +async def test_update_pk_invalid_obj(db): + author = await Author.create(name="1") + book = await Book.create(name="test", author=author, rating=5) + + prepared = ( + Book.filter(id=Parameter("search_id")) + .update(author=Parameter("replace_author")) + .compile("test_update_pk_invalid_obj") + ) + + with pytest.raises(ValidationError): + await prepared.execute(search_id=book.pk, replace_author="not an Author object") + + +def test_remove_prepared_queryset_from_cache(db): + cache_key = "test_remove_query_from_cache" + prepared = Author.filter(id=Parameter("some_param")).compile(cache_key) + assert 
Author.all().compile(cache_key).query is prepared.query + Author.remove_compiled_query(cache_key) + assert Author.all().compile(cache_key).query is not prepared.query + + +@pytest.mark.parametrize( + ( + "filter_kwargs", + "cache_key_suffix", + ), + [ + ({"id": "123"}, "1"), + ({"id__gte": "321"}, "2"), + ({"id__in": ["321", 123, 987]}, "3"), + ], +) +def test_prepared_query_get_sql(db, filter_kwargs: dict[str, Any], cache_key_suffix: str): + expected_sql = CharPkModel.all().filter(**filter_kwargs).limit(10).offset(0).sql() + actual_sql = ( + CharPkModel.all() + .filter(**{key: Parameter(key) for key in filter_kwargs}) + .limit(10) + .offset(0) + .compile(f"test_prepared_query_get_sql-{cache_key_suffix}") + .sql(**filter_kwargs) + ) + + assert expected_sql == actual_sql + + +def test_compiled_query_auto_cache_size(db): + compiled = Author.filter(id=Parameter("id")).compile() + # Trigger sql generation + compiled.sql(id=1) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_SIMPLE + + compiled = Author.filter( + id__in=Subquery(Author.filter(Q(id=Parameter("id1")) | Q(id=Parameter("id2")))) + ).compile() + compiled.sql(id1=1, id2=2) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_SIMPLE + + compiled = Author.filter(id__in=Parameter("ids")).compile() + compiled.sql(ids=[1]) + assert compiled._sql_cache.maxsize == compiled.DEFAULT_CACHE_SIZE_COLLECTIONS + + +@pytest.mark.asyncio +async def test_filter_by_model(db): + author1 = await Author.create(name="1") + book1 = await Book.create(name="test 1", author=author1, rating=5) + author2 = await Author.create(name="2") + book2 = await Book.create(name="test 2", author=author2, rating=3) + + compiled = Book.filter(author=Parameter("author")).compile() + book = await compiled.execute(author=author1) + assert book == [book1] + book = await compiled.execute(author=author2) + assert book == [book2] + + +def test_collection_parameter_got_not_collection(db): + compiled = 
Author.filter(id__in=Parameter("ids")).compile() + with pytest.raises(ValueError): + compiled.sql(ids=123) + + +@pytest.mark.asyncio +async def test_missing_parameters(db): + compiled = Author.filter( + id__in=Parameter("ids"), name__startswith=Parameter("name_prefix") + ).compile() + with pytest.raises(KeyError): + await compiled.execute(ids=[123]) + with pytest.raises(KeyError): + await compiled.execute(name_prefix="test") + + +@pytest.mark.asyncio +async def test_parameter_in_empty(db): + await Author.create(name="1") + await Author.create(name="2") + + compiled = Author.filter(name__in=Parameter("names")).compile() + authors = await compiled.execute(names=[]) + assert not authors + + +@pytest.mark.asyncio +async def test_parameter_not_in_empty(db): + author1 = await Author.create(name="1") + author2 = await Author.create(name="2") + + compiled = Author.filter(name__not_in=Parameter("names")).order_by("id").compile() + authors = await compiled.execute(names=[]) + assert authors == [author1, author2] + + +def test_parameter_in_empty_sql(db): + compiled = Author.filter(name__in=Parameter("names")).compile() + compiled_sql = compiled.sql(names=[]) + assert "IN ()" not in compiled_sql + + +def test_parameter_not_in_empty_sql(db): + compiled = Author.filter(name__not_in=Parameter("names")).order_by("id").compile() + compiled_sql = compiled.sql(names=[]) + assert "NOT IN ()" not in compiled_sql diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 6a19be073..373f0f97e 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,4 +1,6 @@ import enum +from collections.abc import Callable +from typing import Any from pypika_tortoise import SqlContext, functions from pypika_tortoise.enums import SqlTypes @@ -32,6 +34,7 @@ search, starts_with, ) +from tortoise.parameter import Parameter class MySQLRegexpComparators(enum.Enum): @@ -49,10 +52,47 @@ def get_value_sql(self, ctx: SqlContext) -> str: 
return format_quotes(value, quote_char)


+# TODO: maybe there is better way to do this?
+class StrParamWrapper(ValueWrapper):
+    value: Parameter
+
+    def get_value_sql(self, ctx: SqlContext) -> str:
+        quote_char = ctx.secondary_quote_char or ""
+        real_encoder = self.value.encode
+
+        def _encoder(val: str) -> str:
+            if real_encoder is not None:
+                val = real_encoder(val)
+            val = val.replace(quote_char, quote_char * 2)
+            return format_quotes(val, quote_char)
+
+        new_param = self.value.clone()
+        new_param.encode = _encoder
+        return new_param  # type: ignore[return-value]
+
+
 def escape_like(val: str) -> str:
     return val.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


+def _format_str_or_parameter(
+    value: str | Parameter,
+    like_start: bool = False,
+    like_end: bool = False,
+    escape_func: Callable[[Any], str] = escape_like,
+) -> Term:
+    like_at_start = "%" if like_start else ""
+    like_at_end = "%" if like_end else ""
+
+    if isinstance(value, Parameter):
+        value.encode = escape_func
+        if like_start or like_end:
+            value.encode = lambda val: f"{like_at_start}{escape_func(val)}{like_at_end}"
+        return StrParamWrapper(value)
+    else:
+        return StrWrapper(f"{like_at_start}{escape_func(value)}{like_at_end}")
+
+
 def mysql_contains(field: Term, value: str) -> Criterion:
     return Like(
-        functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"%{escape_like(value)}%"), escape=""
+        functions.Cast(field, SqlTypes.CHAR), _format_str_or_parameter(value, True, True), escape=""
@@ -61,24 +101,30 @@ def mysql_contains(field: Term, value: str) -> Criterion:

 def mysql_starts_with(field: Term, value: str) -> Criterion:
     return Like(
-        functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"{escape_like(value)}%"), escape=""
+        functions.Cast(field, SqlTypes.CHAR),
+        _format_str_or_parameter(value, False, True, escape_like),
+        escape="",
     )


 def mysql_ends_with(field: Term, value: str) -> Criterion:
     return Like(
-        functions.Cast(field, SqlTypes.CHAR), StrWrapper(f"%{escape_like(value)}"), escape=""
+        functions.Cast(field, SqlTypes.CHAR),
+        _format_str_or_parameter(value, True, False,
escape_like), + escape="", ) def mysql_insensitive_exact(field: Term, value: str) -> Criterion: - return functions.Upper(functions.Cast(field, SqlTypes.CHAR)).eq(functions.Upper(str(value))) + return functions.Upper(functions.Cast(field, SqlTypes.CHAR)).eq( + functions.Upper(_format_str_or_parameter(value, escape_func=str)) + ) def mysql_insensitive_contains(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"%{escape_like(value)}%")), + functions.Upper(_format_str_or_parameter(value, True, True, escape_like)), escape="", ) @@ -86,7 +132,7 @@ def mysql_insensitive_contains(field: Term, value: str) -> Criterion: def mysql_insensitive_starts_with(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"{escape_like(value)}%")), + functions.Upper(_format_str_or_parameter(value, False, True, escape_like)), escape="", ) @@ -94,18 +140,20 @@ def mysql_insensitive_starts_with(field: Term, value: str) -> Criterion: def mysql_insensitive_ends_with(field: Term, value: str) -> Criterion: return Like( functions.Upper(functions.Cast(field, SqlTypes.CHAR)), - functions.Upper(StrWrapper(f"%{escape_like(value)}")), + functions.Upper(_format_str_or_parameter(value, True, False, escape_like)), escape="", ) def mysql_search(field: Term, value: str) -> SearchCriterion: - return SearchCriterion(field, expr=StrWrapper(value)) + return SearchCriterion(field, expr=_format_str_or_parameter(value, escape_func=lambda x: x)) def mysql_posix_regex(field: Term, value: str) -> BasicCriterion: return BasicCriterion( - MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.CHAR)), StrWrapper(value) + MySQLRegexpComparators.REGEXP, + Coalesce(Cast(field, SqlTypes.CHAR)), + _format_str_or_parameter(value, escape_func=lambda x: x), ) diff --git a/tortoise/expressions.py b/tortoise/expressions.py index ff51e2cb9..c8ace8fdb 100644 
--- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -22,8 +22,9 @@ from tortoise.exceptions import FieldError, OperationalError from tortoise.fields.base import Field from tortoise.fields.data import JSONField -from tortoise.fields.relational import RelationalField +from tortoise.fields.relational import ForeignKeyFieldInstance, RelationalField from tortoise.filters import FilterInfoDict +from tortoise.parameter import Parameter from tortoise.query_utils import ( QueryModifier, TableCriterionTuple, @@ -386,6 +387,9 @@ def _process_filter_kwarg( else: filter_info = model._meta.get_filter(key) + if isinstance(value, Parameter): + value.model = model + field_object = None if "table" in filter_info: # join the table @@ -395,15 +399,26 @@ def _process_filter_kwarg( == filter_info["table"][filter_info["backward_key"]], ) if "value_encoder" in filter_info: - value = filter_info["value_encoder"](value, model) + if isinstance(value, Parameter): + value.value_encoder = filter_info["value_encoder"] + else: + value = filter_info["value_encoder"](value, model) table = filter_info["table"] elif not isinstance(value, Term): field_object = model._meta.fields_map[filter_info["field"]] - value = ( - filter_info["value_encoder"](value, model, field_object) - if "value_encoder" in filter_info - else field_object.to_db_value(value, model) - ) + value_encoder = filter_info["value_encoder"] if "value_encoder" in filter_info else None + if isinstance(value, Parameter): + if isinstance(field_object.reference, ForeignKeyFieldInstance): + fk_to_field = field_object.reference.to_field + value.value_getter = lambda obj: getattr(obj, fk_to_field) + value.field_object = field_object + value.value_encoder = value_encoder + else: + value = ( + value_encoder(value, model, field_object) + if value_encoder is not None + else field_object.to_db_value(value, model) + ) op = filter_info["operator"] term: Term = table[filter_info.get("source_field", filter_info["field"])] if field_object is 
not None: diff --git a/tortoise/filters.py b/tortoise/filters.py index 511941c95..e70e7171b 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -21,6 +21,7 @@ from tortoise.contrib.postgres.fields import ArrayField, TSVectorField from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance +from tortoise.parameter import CollectionParameter, Parameter if sys.version_info >= (3, 11): # pragma:nocoverage from typing import NotRequired @@ -102,6 +103,8 @@ def array_encoder(value: Any | Sequence[Any], instance: Model, field: Field) -> def is_in(field: Term, value: Any) -> Criterion: if value: + if isinstance(value, Parameter): + return CollectionParameter(field, value, True) return field.isin(value) # SQL has no False, so we return 1=0 return BasicCriterion( @@ -113,6 +116,8 @@ def is_in(field: Term, value: Any) -> Criterion: def not_in(field: Term, value: Any) -> Criterion: if value: + if isinstance(value, Parameter): + return CollectionParameter(field, value, False) return field.notin(value) | field.isnull() # SQL has no True, so we return 1=1 return BasicCriterion( @@ -142,8 +147,11 @@ def not_null(field: Term, value: Any) -> Criterion: return field.isnull() -def contains(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}%")) +def contains(field: Term, value: str | Parameter) -> Criterion: + return Like( + Cast(field, SqlTypes.VARCHAR), + field.wrap_constant(_format_str_or_parameter(field, value, True, True)), + ) def search(field: Term, value: str) -> Any: @@ -165,33 +173,58 @@ def insensitive_posix_regex(field: Term, value: str): ) -def starts_with(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"{escape_like(value)}%")) +def _format_str_or_parameter( + field: Term, + value: str | Parameter, + like_start: bool = False, + like_end: bool = False, + 
escape_func: Callable[[Any], str] = escape_like, +) -> Term: + like_at_start = "%" if like_start else "" + like_at_end = "%" if like_end else "" + if isinstance(value, Parameter): + value.encode = escape_func + wrapped = ValueWrapper(value) + if like_start or like_end: + value.encode = lambda val: f"{like_at_start}{escape_func(val)}{like_at_end}" + return wrapped + else: + return field.wrap_constant(f"{like_at_start}{escape_func(value)}{like_at_end}") -def ends_with(field: Term, value: str) -> Criterion: - return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}")) +def starts_with(field: Term, value: str | Parameter) -> Criterion: + return Like(Cast(field, SqlTypes.VARCHAR), _format_str_or_parameter(field, value, False, True)) -def insensitive_exact(field: Term, value: str) -> Criterion: - return Upper(Cast(field, SqlTypes.VARCHAR)).eq(Upper(str(value))) +def ends_with(field: Term, value: str | Parameter) -> Criterion: + return Like(Cast(field, SqlTypes.VARCHAR), _format_str_or_parameter(field, value, True, False)) -def insensitive_contains(field: Term, value: str) -> Criterion: + +def insensitive_exact(field: Term, value: str | Parameter) -> Criterion: + return Upper(Cast(field, SqlTypes.VARCHAR)).eq( + Upper(_format_str_or_parameter(field, value, escape_func=str)) + ) + + +def insensitive_contains(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"%{escape_like(value)}%")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, True))), ) -def insensitive_starts_with(field: Term, value: str) -> Criterion: +def insensitive_starts_with(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"{escape_like(value)}%")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, 
False, True))), ) -def insensitive_ends_with(field: Term, value: str) -> Criterion: +def insensitive_ends_with(field: Term, value: str | Parameter) -> Criterion: return Like( - Upper(Cast(field, SqlTypes.VARCHAR)), field.wrap_constant(Upper(f"%{escape_like(value)}")) + Upper(Cast(field, SqlTypes.VARCHAR)), + field.wrap_constant(Upper(_format_str_or_parameter(field, value, True, False))), ) @@ -240,7 +273,7 @@ def json_contained_by(field: Term, value: str) -> Criterion: def json_filter(field: Term, value: dict) -> Criterion: - raise NotImplementedError("must be overridden in each xecutor") + raise NotImplementedError("must be overridden in each executor") def array_contains(field: Term, value: Any | Sequence[Any]) -> Criterion: diff --git a/tortoise/models.py b/tortoise/models.py index 71fa3c43d..5008b0850 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -50,6 +50,7 @@ QuerySetSingle, RawSQLQuery, ) +from tortoise.queryset_compiled import BaseCompiledQuery from tortoise.router import router from tortoise.signals import Signals from tortoise.transactions import in_transaction @@ -215,6 +216,7 @@ class MetaInfo: "db_complex_fields", "_default_ordering", "_ordering_validated", + "query_cache", ) def __init__(self, meta: Model.Meta) -> None: @@ -255,6 +257,7 @@ def __init__(self, meta: Model.Meta) -> None: self.db_native_fields: list[tuple[str, str, Field]] = [] self.db_default_fields: list[tuple[str, str, Field]] = [] self.db_complex_fields: list[tuple[str, str, Field]] = [] + self.query_cache: dict[str, BaseCompiledQuery] = {} @property def full_name(self) -> str: @@ -1610,6 +1613,10 @@ async def fetch_for_list( db = using_db or cls._choose_db() await db.executor_class(model=cls, db=db).fetch_for_list(instance_list, *args) + @classmethod + def remove_compiled_query(cls, key: str) -> None: + cls._meta.query_cache.pop(key, None) + @classmethod def _check(cls) -> None: """ diff --git a/tortoise/parameter.py b/tortoise/parameter.py new file mode 100644 index 
000000000..dd452e76d --- /dev/null +++ b/tortoise/parameter.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast + +from pypika_tortoise import SqlContext +from pypika_tortoise.enums import Equality +from pypika_tortoise.terms import BasicCriterion, Criterion, NullCriterion, Term, ValueWrapper + +from tortoise.fields import Field + +if sys.version_info >= (3, 11): # pragma: nocoverage + from typing import Self +else: + from typing_extensions import Self + +if TYPE_CHECKING: + from tortoise import Model + +T_out = TypeVar("T_out", covariant=True) + + +class FieldEncoder(Protocol[T_out]): + def __call__( + self, + value: Any, + model: type[Model] | None, + field: Field | None = None, + ) -> T_out: ... + + +@dataclass(frozen=True) +class TortoiseSqlContext(SqlContext): + dynamic_params: dict[str, CollectionParameter] | None = None + + def copy(self: SqlContext, **kwargs) -> SqlContext: + existing_dynamic_params = ( + self.dynamic_params if isinstance(self, TortoiseSqlContext) else None + ) + return TortoiseSqlContext( + quote_char=kwargs.get("quote_char", self.quote_char), + secondary_quote_char=kwargs.get("secondary_quote_char", self.secondary_quote_char), + alias_quote_char=kwargs.get("alias_quote_char", self.alias_quote_char), + dialect=kwargs.get("dialect", self.dialect), + as_keyword=kwargs.get("as_keyword", self.as_keyword), + subquery=kwargs.get("subquery", self.subquery), + with_alias=kwargs.get("with_alias", self.with_alias), + with_namespace=kwargs.get("with_namespace", self.with_namespace), + subcriterion=kwargs.get("subcriterion", self.subcriterion), + parameterizer=kwargs.get("parameterizer", self.parameterizer), + groupby_alias=kwargs.get("groupby_alias", self.groupby_alias), + orderby_alias=kwargs.get("orderby_alias", self.orderby_alias), + dynamic_params=kwargs.get("dynamic_params", 
existing_dynamic_params), + ) + + +class Parameter: + __slots__ = ( + "name", + "model", + "value_encoder", + "field_object", + "encode", + "value_getter", + "value_validator", + ) + + def __init__(self, name: str) -> None: + self.name = name + self.model: type[Model] | None = None + self.value_encoder: FieldEncoder[Any] | None = None + self.field_object: Field | None = None + self.encode: Callable[[Any], Any] | None = None + self.value_getter: Callable[[Any], Any] | None = None + self.value_validator: Callable[[Any], Any] | None = None + + def clone(self) -> Self: + new = self.__new__(self.__class__) + new.name = self.name + new.model = self.model + new.value_encoder = self.value_encoder + new.field_object = self.field_object + new.encode = self.encode + new.value_getter = self.value_getter + new.value_validator = self.value_validator + + return new + + def encode_value(self, value: Any) -> Any: + if self.value_validator is not None: + self.value_validator(value) + + if self.value_getter is not None: + value = self.value_getter(value) + + encoded = value + + if self.value_encoder: + if self.field_object is not None: + encoded = self.value_encoder(value, self.model, self.field_object) + else: + encoded = self.value_encoder(value, self.model) + elif self.field_object is not None: + encoded = self.field_object.to_db_value(value, cast(type["Model"], self.model)) + + if self.encode: + encoded = self.encode(encoded) + + return encoded + + +class CollectionParameter(Parameter, Criterion): + IS_IN_EMPTY = BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(0, allow_parametrize=False), + ) + IS_NOT_IN_EMPTY = BasicCriterion( + Equality.eq, + ValueWrapper(1, allow_parametrize=False), + ValueWrapper(1, allow_parametrize=False), + ) + + __slots__ = ( + "term", + "collection_size", + "collection_encoder", + "is_in", + ) + + def __init__(self, term: Term, param: Parameter, is_in: bool) -> None: + super().__init__(param.name) + self.model = 
param.model
+        self.value_encoder = param.value_encoder
+        self.field_object = param.field_object
+        self.encode = param.encode
+
+        self.term = term
+        self.collection_size: int | None = None
+        self.collection_encoder: FieldEncoder[Sequence[Any]] | None = None
+        self.is_in = is_in
+
+    def encode_collection(self, value: Any) -> Sequence[Any]:
+        if self.collection_encoder is None:
+            return value  # TODO: probably raise exception
+
+        if self.field_object is not None:
+            return self.collection_encoder(value, self.model, self.field_object)
+        else:
+            return self.collection_encoder(value, self.model)
+
+    def get_sql(self, ctx: SqlContext) -> str:
+        if ctx.parameterizer is None:
+            raise ValueError("Parametrization must be enabled when using tortoise.Parameter.")
+
+        term_sql = self.term.get_sql(ctx)
+        not_ = "" if self.is_in else "NOT "
+        fmt = "{term} {not_}IN {container}"
+
+        param = self
+        if isinstance(ctx, TortoiseSqlContext) and ctx.dynamic_params is not None:
+            param = ctx.dynamic_params.get(self.name, self)
+
+        if param.collection_size is None:
+            return fmt.format(
+                term=term_sql,
+                container=ctx.parameterizer.create_param(param).get_sql(ctx),
+                not_=not_,
+            )
+
+        if not param.collection_size:
+            if self.is_in:
+                return self.IS_IN_EMPTY.get_sql(ctx)
+            return self.IS_NOT_IN_EMPTY.get_sql(ctx)
+
+        placeholders = []
+        for _ in range(param.collection_size):
+            new_param = param.clone()
+            new_param.collection_encoder = new_param.value_encoder
+            new_param.value_encoder = None
+            pypika_param = ctx.parameterizer.create_param(new_param)
+            placeholders.append(pypika_param.get_sql(ctx))
+
+        sql = fmt.format(
+            term=term_sql,
+            container=f"({','.join(placeholders)})",
+            not_=not_,
+        )
+
+        if not self.is_in:
+            # Render with ctx: str(criterion) would use the default context (wrong quoting).
+            sql = f"({sql} OR {NullCriterion(self.term).get_sql(ctx)})"
+
+        return sql
diff --git a/tortoise/queryset.py b/tortoise/queryset.py
index 0cfd2fc52..340f8b7e3 100644
--- a/tortoise/queryset.py
+++ b/tortoise/queryset.py
@@ -2,8 +2,9 @@

 import types
 from collections
import defaultdict -from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable +from collections.abc import AsyncIterator, Callable, Collection, Generator, Iterable, Sequence from copy import copy +from operator import attrgetter from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload from pypika_tortoise import JoinType, Order, Table @@ -27,6 +28,7 @@ RelationalField, ) from tortoise.filters import FilterInfoDict +from tortoise.parameter import Parameter from tortoise.query_utils import ( Prefetch, QueryModifier, @@ -43,6 +45,16 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model + from tortoise.queryset_compiled import ( + CompiledCountQuery, + CompiledDeleteQuery, + CompiledExistsQuery, + CompiledQuerySet, + CompiledQuerySetSingle, + CompiledUpdateQuery, + CompiledValuesListQuery, + CompiledValuesQuery, + ) MODEL = TypeVar("MODEL", bound="Model") PRIMARY_KEY = TypeVar("PRIMARY_KEY") @@ -78,6 +90,10 @@ def values( self, *args: str, **kwargs: str ) -> ValuesQuery[Literal[True]]: ... # pragma: nocoverage + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledQuerySetSingle[T_co]: ... + class AwaitableQuery(Generic[MODEL]): __slots__ = ( @@ -512,30 +528,47 @@ def earliest(self, *orderings: str) -> QuerySetSingle[MODEL | None]: queryset._orderings = self._parse_orderings(orderings) return queryset._as_single() - def limit(self, limit: int) -> QuerySet[MODEL]: + @staticmethod + def _validate_limit(value: int) -> int: + if value < 0: + raise ParamsError("Limit should be non-negative number") + return value + + def limit(self, limit: int | Parameter) -> QuerySet[MODEL]: """ Limits QuerySet to given length. :raises ParamsError: Limit should be non-negative number. 
""" - if limit < 0: - raise ParamsError("Limit should be non-negative number") + if isinstance(limit, int): + self._validate_limit(limit) + elif isinstance(limit, Parameter): + limit.encode = self._validate_limit queryset = self._clone() - queryset._limit = limit + queryset._limit = limit # type: ignore return queryset - def offset(self, offset: int) -> QuerySet[MODEL]: + @staticmethod + def _validate_offset(value: int) -> int: + if value < 0: + raise ParamsError("Offset should be non-negative number") + return value + + def offset(self, offset: int | Parameter) -> QuerySet[MODEL]: """ Query offset for QuerySet. :raises ParamsError: Offset should be non-negative number. """ - if offset < 0: - raise ParamsError("Offset should be non-negative number") + + if isinstance(offset, int) and offset < 0: + self._validate_offset(offset) + elif isinstance(offset, Parameter): + offset.encode = self._validate_offset queryset = self._clone() - queryset._offset = offset + queryset._offset = offset # type: ignore if self.capabilities.requires_limit and queryset._limit is None: queryset._limit = 1000000 return queryset @@ -648,6 +681,14 @@ def group_by(self, *fields: str) -> QuerySet[MODEL]: queryset._group_bys = fields return queryset + def _get_fields_list_for_select(self, *fields_: str) -> tuple[str, ...] | list[str]: + if self._fields_for_select: + raise ValueError(".values_list() cannot be used with .only()") + + return fields_ or [ + field for field in self.model._meta.fields_map if field in self.model._meta.db_fields + ] + list(self._annotations.keys()) + def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Literal[False]]: """ Make QuerySet returns list of tuples for given args instead of objects. @@ -659,12 +700,8 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite If no arguments are passed it will default to a tuple containing all fields in order of declaration. 
""" - if self._fields_for_select: - raise ValueError(".values_list() cannot be used with .only()") + fields_for_select_list = self._get_fields_list_for_select(*fields_) - fields_for_select_list = fields_ or [ - field for field in self.model._meta.fields_map if field in self.model._meta.db_fields - ] + list(self._annotations.keys()) return ValuesListQuery( db=self._db, model=self.model, @@ -684,20 +721,7 @@ def values_list(self, *fields_: str, flat: bool = False) -> ValuesListQuery[Lite use_indexes=self._use_indexes, ) - def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: - """ - Make QuerySet return dicts instead of objects. - - If called after `.get()`, `.get_or_none()` or `.first()`, returns a dict instead of an object. - - You can specify which fields to include by: - - Passing field names as positional arguments - - Using kwargs in the format `field_name='name_in_dict'` to customize the keys in the resulting dict - - If no arguments are passed, it will default to a dict containing all fields. - - :raises FieldError: If duplicate key has been provided. - """ + def _get_fields_for_select(self, *args: str, **kwargs: str) -> dict[str, str]: if self._fields_for_select: raise ValueError(".values() cannot be used with .only()") @@ -721,6 +745,24 @@ def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: fields_for_select = {field: field for field in _fields} + return fields_for_select + + def values(self, *args: str, **kwargs: str) -> ValuesQuery[Literal[False]]: + """ + Make QuerySet return dicts instead of objects. + + If called after `.get()`, `.get_or_none()` or `.first()`, returns a dict instead of an object. + + You can specify which fields to include by: + - Passing field names as positional arguments + - Using kwargs in the format `field_name='name_in_dict'` to customize the keys in the resulting dict + + If no arguments are passed, it will default to a dict containing all fields. 
+ + :raises FieldError: If duplicate key has been provided. + """ + fields_for_select = self._get_fields_for_select(*args, **kwargs) + return ValuesQuery( db=self._db, model=self.model, @@ -1264,6 +1306,46 @@ async def _execute(self) -> list[MODEL]: raise MultipleObjectsReturned(self.model) return instance_list + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledQuerySet[MODEL]: + """ + Compiles queryset sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledQuerySet + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledQuerySet): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(self._select_for_update) + self._make_query() + compiled = CompiledQuerySet( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + select_for_update=self._select_for_update, + custom_fields=list(self._annotations.keys()), + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class UpdateQuery(AwaitableQuery): __slots__ = ( @@ -1310,13 +1392,19 @@ def _make_query(self) -> None: if field_object.generated: raise IntegrityError(f"Field {key} is generated and can not be updated") if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)): - self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field - value = self.model._meta.fields_map[fk_field].to_db_value( - getattr(value, 
field_object.to_field_instance.model_field_name), - None, - ) + + if isinstance(value, Parameter): + value.field_object = self.model._meta.fields_map[fk_field] + value.value_getter = attrgetter(field_object.to_field_instance.model_field_name) + value.value_validator = lambda val: self.model._validate_relation_type(key, val) + else: + self.model._validate_relation_type(key, value) + value = self.model._meta.fields_map[fk_field].to_db_value( + getattr(value, field_object.to_field_instance.model_field_name), + None, + ) else: try: db_field = self.model._meta.fields_db_projection[key] @@ -1333,7 +1421,11 @@ def _make_query(self) -> None: ) ).term else: - value = self.model._meta.fields_map[key].to_db_value(value, None) + field_object = self.model._meta.fields_map[key] + if isinstance(value, Parameter): + value.field_object = field_object + else: + value = field_object.to_db_value(value, None) self.query = self.query.set(db_field, value) @@ -1345,6 +1437,39 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledUpdateQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. 
+ """ + + from tortoise.queryset_compiled import CompiledUpdateQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledUpdateQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(True) + self._make_query() + compiled = CompiledUpdateQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class DeleteQuery(AwaitableQuery): __slots__ = ( @@ -1394,6 +1519,39 @@ def __await__(self) -> Generator[Any, None, int]: async def _execute(self) -> int: return (await self._db.execute_query(*self.query.get_parameterized_sql()))[0] + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledDeleteQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. 
+ """ + + from tortoise.queryset_compiled import CompiledDeleteQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledDeleteQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(True) + self._make_query() + compiled = CompiledDeleteQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class ExistsQuery(AwaitableQuery): __slots__ = ( @@ -1443,6 +1601,39 @@ async def _execute( result, _ = await self._db.execute_query(*self.query.get_parameterized_sql()) return bool(result) + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledExistsQuery[MODEL]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. + """ + + from tortoise.queryset_compiled import CompiledExistsQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledExistsQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledExistsQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class CountQuery(AwaitableQuery): __slots__ = ( @@ -1506,6 +1697,50 @@ async def _execute(self) -> int: return self._limit return count + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledCountQuery[MODEL]: + """ + Compiles query sql. 
+ :param key: Cache key for saving compiled query to model cache. + :param sql_cache_maxsize: Maximum cache size for generated sql cache. + Only makes sense for queries that contain collections as a parameters. + """ + + from tortoise.queryset_compiled import CompiledCountQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledCountQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled = CompiledCountQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + limit=self._limit, + offset=self._offset, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + + +class FieldsSelectProtocol(Protocol[MODEL]): + model: type[MODEL] + _annotations: dict[str, Any] + + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: ... 
+ class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 @@ -1574,7 +1809,9 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"') - def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + def resolve_to_python_value( + self: FieldsSelectProtocol[MODEL], model: type[MODEL], field: str + ) -> Callable: if field in model._meta.fetch_fields: # return as is to get whole model objects return lambda x: x @@ -1619,6 +1856,13 @@ def _resolve_group_bys(self, *field_names: str) -> list: return group_bys +class ValuesListProtocol(FieldsSelectProtocol[MODEL], Protocol[MODEL]): + fields: dict[str, str] + _flat: bool + _single: bool + _raise_does_not_exist: bool + + class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( "fields", @@ -1730,8 +1974,7 @@ async def __aiter__(self: ValuesListQuery[Any]) -> AsyncIterator[Any]: for val in await self: yield val - async def _execute(self) -> list[Any] | tuple: - _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) + def _process_results(self: ValuesListProtocol, result: Sequence[dict]) -> list[Any] | tuple: columns = [ (key, self.resolve_to_python_value(self.model, name)) for key, name in self.fields.items() @@ -1754,6 +1997,54 @@ async def _execute(self) -> list[Any] | tuple: raise MultipleObjectsReturned(self.model) return lst_values + async def _execute(self) -> list[Any] | tuple: + _, result = await self._db.execute_query(*self.query.get_parameterized_sql()) + return self._process_results(result) + + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledValuesListQuery[MODEL, SINGLE]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. 
+ """ + + from tortoise.queryset_compiled import CompiledValuesListQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledValuesListQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled: CompiledValuesListQuery[MODEL, SINGLE] = CompiledValuesListQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select_list=self._fields_for_select_list, + flat=self._flat, + annotations=self._annotations, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + + +class ValuesProtocol(FieldsSelectProtocol[MODEL], Protocol[MODEL]): + _fields_for_select: dict[str, str] + _single: bool + _raise_does_not_exist: bool + class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): __slots__ = ( @@ -1860,8 +2151,7 @@ async def __aiter__(self: ValuesQuery[Any]) -> AsyncIterator[dict[str, Any]]: for val in await self: yield val - async def _execute(self) -> list[dict] | dict: - result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) + def _process_results(self: ValuesProtocol, result: list[dict]) -> list[dict] | dict: columns = [ val for val in [ @@ -1886,6 +2176,47 @@ async def _execute(self) -> list[dict] | dict: raise MultipleObjectsReturned(self.model) return result + async def _execute(self) -> list[dict] | dict: + result = await self._db.execute_query_dict(*self.query.get_parameterized_sql()) + return self._process_results(result) + + def compile( + self, key: str | None = None, sql_cache_maxsize: int | None = None + ) -> CompiledValuesQuery[MODEL, SINGLE]: + """ + Compiles query sql. + :param key: Cache key for saving compiled query to model cache. 
+ """ + + from tortoise.queryset_compiled import CompiledValuesQuery + + if key in self.model._meta.query_cache: + cached = self.model._meta.query_cache[key] + if not isinstance(cached, CompiledValuesQuery): + raise ValueError( + f"Cached query type mismatch: " + f"expected {self.__class__.__name__}, " + f"got {cached.__class__.__name__}" + ) + return cached._clone() + + self._choose_db_if_not_chosen(False) + self._make_query() + compiled: CompiledValuesQuery[MODEL, SINGLE] = CompiledValuesQuery( + model=self.model, + query=self.query, + sql_cache_maxsize=sql_cache_maxsize, + single=self._single, + raise_does_not_exist=self._raise_does_not_exist, + fields_for_select=self._fields_for_select, + annotations=self._annotations, + ) + + if key is not None: + self.model._meta.query_cache[key] = compiled + + return compiled + class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") diff --git a/tortoise/queryset_compiled.py b/tortoise/queryset_compiled.py new file mode 100644 index 000000000..7aa833182 --- /dev/null +++ b/tortoise/queryset_compiled.py @@ -0,0 +1,470 @@ +from __future__ import annotations as _ + +import sys +from abc import ABC, abstractmethod +from collections import OrderedDict, defaultdict +from collections.abc import Callable, Iterable +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, TypeVar, cast, overload + +from pypika_tortoise.queries import QueryBuilder, Table + +from tortoise.exceptions import DoesNotExist, MultipleObjectsReturned +from tortoise.parameter import CollectionParameter, Parameter, TortoiseSqlContext +from tortoise.query_utils import Prefetch +from tortoise.queryset import ( + MODEL, + SINGLE, + AwaitableQuery, + FieldSelectQuery, + QuerySet, + T_co, + ValuesListQuery, + ValuesQuery, +) + +if sys.version_info >= (3, 11): # pragma: nocoverage + from typing import Self +else: + from typing_extensions import Self + +if TYPE_CHECKING: + from tortoise import Model + + +class CompiledQuerySetSingle(Protocol[T_co]): + 
def sql(self, **params) -> str: ... + + async def execute(self, **params) -> MODEL: ... + + +T = TypeVar("T") + + +class CachedSql: + __slots__ = ( + "sql", + "params", + "param_by_name", + "need_params", + "need_collection_params", + ) + + def __init__(self, sql: str, params: list[Parameter | Any]) -> None: + self.sql = sql + self.params = params + self.param_by_name: dict[str, Parameter] = {} + self.need_params: dict[str, int] = {} + self.need_collection_params: dict[str, list[int]] = defaultdict(list) + for idx, param in enumerate(params): + if not isinstance(param, Parameter): + continue + + if param.name not in self.param_by_name: + self.param_by_name[param.name] = param + + if isinstance(param, CollectionParameter): + self.need_collection_params[param.name].append(idx) + else: + self.need_params[param.name] = idx + + def make_filled_params(self, params: dict[str, Any]) -> list[Any]: + for name in self.need_params: + if name not in params: + raise KeyError(f'Expected parameter "{name}" is not provided!') + + for name, indexes in self.need_collection_params.items(): + if name not in params: + raise KeyError(f'Expected parameter "{name}" is not provided!') + collection_length = len(params[name]) + param_length = len(indexes) + if collection_length != param_length: + raise ValueError( + f"Provided value length ({collection_length}) " + f"for parameter {name!r} does not match " + f"parameter indexes length ({param_length})" + ) + # if not collection_length: + # raise ValueError("Parameter must not be empty!") + + filled_params = self.params.copy() + for name, idx in self.need_params.items(): + param = self.param_by_name[name] + filled_params[idx] = param.encode_value(params[name]) + + for name, indexes in self.need_collection_params.items(): + param = cast(CollectionParameter, self.param_by_name[name]) + collection = param.encode_collection(params[name]) + for idx, value in zip(indexes, collection): + filled_params[idx] = param.encode_value(value) + + return 
filled_params + + +class _BoundedLRU(Generic[T]): + __slots__ = ("_data", "_maxsize") + + def __init__(self, maxsize: int) -> None: + self._data: OrderedDict[str, T] = OrderedDict() + self._maxsize = maxsize + + @property + def maxsize(self) -> int: + return self._maxsize + + @maxsize.setter + def maxsize(self, value: int) -> None: + self._maxsize = value + + def get(self, key: str) -> T | None: + try: + self._data.move_to_end(key) + return self._data[key] + except KeyError: + return None + + def put(self, key: str, value: T) -> None: + if key in self._data: + self._data.move_to_end(key) + self._data[key] = value + else: + if len(self._data) >= self._maxsize: + self._data.popitem(last=False) + self._data[key] = value + + +class BaseCompiledQuery(AwaitableQuery[MODEL], ABC): + DEFAULT_CACHE_SIZE_SIMPLE = 2 + DEFAULT_CACHE_SIZE_COLLECTIONS = 128 + + __slots__ = ( + "_sql_cache", + "_collection_params", + "_collection_params_names", + ) + + def __init__( + self, model: type[MODEL], query: QueryBuilder, sql_cache_maxsize: int | None + ) -> None: + super().__init__(model) + self.query = query + self._sql_cache: _BoundedLRU[CachedSql] = _BoundedLRU(0) + self._collection_params: dict[str, CollectionParameter] = {} + self._collection_params_names: list[str] = [] + + sql, params = self.query.get_parameterized_sql() + self._collection_params = { + param.name: param for param in params if isinstance(param, CollectionParameter) + } + if self._collection_params: + self._collection_params_names = sorted(self._collection_params.keys()) + self._sql_cache.maxsize = sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_COLLECTIONS + else: + self._sql_cache.maxsize = sql_cache_maxsize or self.DEFAULT_CACHE_SIZE_SIMPLE + + def _clone(self) -> Self: + query = self.__class__.__new__(self.__class__) + query.model = self.model + query.query = self.query + query._db = None # type: ignore + query._capabilities = self._capabilities + query._annotations = self._annotations + + query._sql_cache = 
self._sql_cache + query._collection_params = self._collection_params + query._collection_params_names = self._collection_params_names + + return query + + @abstractmethod + async def execute(self, **params) -> Any: ... + + def _get_or_create_cached_sql_simple(self) -> CachedSql: + cache_key = self._db.capabilities.dialect + if (cached := self._sql_cache.get(cache_key)) is None: + cached = CachedSql(*self.query.get_parameterized_sql()) + self._sql_cache.put(cache_key, cached) + return cached + + def _get_or_create_cached_sql(self, params: dict[str, Any]) -> CachedSql: + if not self._collection_params: + return self._get_or_create_cached_sql_simple() + + cache_key_parts = [] + for name in self._collection_params_names: + value = params[name] + if not isinstance(value, (tuple, list, set)): + raise ValueError(f'Expected parameter "{name}" to be a collection, got {value!r}') + + cache_key_parts.append(str(len(value))) + + cache_key = f"{self._db.capabilities.dialect}|{'-'.join(cache_key_parts)}" + + if self._sql_cache.get(cache_key) is None: + reset_params = [] + for name in self._collection_params_names: + param = self._collection_params[name] + param.collection_size = len(params[name]) + reset_params.append(param) + + # TODO: probably could be done in a better way? 
+ ctx = TortoiseSqlContext.copy( + self.query.QUERY_CLS.SQL_CONTEXT, + dynamic_params=self._collection_params, + ) + sql, params_ = self.query.get_parameterized_sql(ctx) + self._sql_cache.put(cache_key, CachedSql(sql, params_)) + + for param in reset_params: + param.collection_size = None + + return cast(CachedSql, self._sql_cache.get(cache_key)) + + def sql(self, params_inline=False, **params) -> str: + old_db = self._db + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + self._db = old_db + return cached_query.sql + + +class CompiledQuerySet(BaseCompiledQuery[MODEL]): + __slots__ = ( + "_prefetch_map", + "_prefetch_queries", + "_select_related_idx", + "_single", + "_raise_does_not_exist", + "_select_for_update", + "_custom_fields", + ) + + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + sql_cache_maxsize: int | None, + prefetch_map: dict[str, set[str | Prefetch]], + prefetch_queries: dict[str, list[tuple[str | None, QuerySet]]], + select_related_idx: list[ + tuple[type[Model], int, Table | str, type[Model], Iterable[str | None]] + ], + single: bool, + raise_does_not_exist: bool, + select_for_update: bool, + custom_fields: list[str] | None, + ) -> None: + super().__init__(model, query, sql_cache_maxsize) + self._prefetch_map = prefetch_map + self._prefetch_queries = prefetch_queries + self._select_related_idx = select_related_idx + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._select_for_update = select_for_update + self._custom_fields: list[str] | None = custom_fields + + def _clone(self) -> Self: + queryset = super()._clone() + queryset._prefetch_map = self._prefetch_map + queryset._prefetch_queries = self._prefetch_queries + queryset._select_related_idx = self._select_related_idx + queryset._single = self._single + queryset._raise_does_not_exist = self._raise_does_not_exist + queryset._select_for_update = self._select_for_update + queryset._custom_fields = 
self._custom_fields + return queryset + + async def execute(self, **params) -> list[MODEL]: + self._choose_db_if_not_chosen(self._select_for_update) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + instance_list = await self._db.executor_class( + model=self.model, + db=self._db, + prefetch_map=self._prefetch_map, + prefetch_queries=self._prefetch_queries, + select_related_idx=self._select_related_idx, # type: ignore + ).execute_select( + cached_query.sql, + filled_params, + custom_fields=self._custom_fields, + ) + if self._single: + if len(instance_list) == 1: + return instance_list[0] + if not instance_list: + if self._raise_does_not_exist: + raise DoesNotExist(self.model) + return None # type: ignore + raise MultipleObjectsReturned(self.model) + return instance_list + + +class CompiledUpdateQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(True) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class CompiledDeleteQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(True) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + return (await self._db.execute_query(cached_query.sql, filled_params))[0] + + +class CompiledExistsQuery(BaseCompiledQuery[MODEL]): + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + result, _ = await self._db.execute_query(cached_query.sql, filled_params) + return bool(result) + + +class CompiledCountQuery(BaseCompiledQuery[MODEL]): + __slots__ = ( + "_limit", + "_offset", + ) + + def __init__( + self, + model: 
type[MODEL], + query: QueryBuilder, + sql_cache_maxsize: int | None, + limit: int | None, + offset: int | None, + ) -> None: + super().__init__(model, query, sql_cache_maxsize) + self._limit = limit or 0 + self._offset = offset or 0 + + def _clone(self) -> Self: + query = super()._clone() + query._limit = self._limit + query._offset = self._offset + return query + + async def execute(self, **params) -> int: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + + _, result = await self._db.execute_query(cached_query.sql, filled_params) + if not result: + return 0 + count = list(dict(result[0]).values())[0] - self._offset + if self._limit and count > self._limit: + return self._limit + return count + + +class CompiledValuesListQuery(BaseCompiledQuery[MODEL], Generic[MODEL, SINGLE]): + __slots__ = ( + "fields", + "_single", + "_raise_does_not_exist", + "_flat", + "_annotations", + ) + + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + sql_cache_maxsize: int | None, + single: bool, + raise_does_not_exist: bool, + fields_for_select_list: tuple[str, ...] 
| list[str], + flat: bool, + annotations: dict[str, Any], + ) -> None: + super().__init__(model, query, sql_cache_maxsize) + + fields_for_select = {str(i): field for i, field in enumerate(fields_for_select_list)} + self.fields = fields_for_select + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._flat = flat + self._annotations = annotations + + def _clone(self) -> Self: + query = super()._clone() + query.fields = self.fields + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._flat = self._flat + query._annotations = self._annotations + return query + + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + return FieldSelectQuery.resolve_to_python_value(self, model, field) + + @overload + async def execute(self: CompiledValuesListQuery[MODEL, Literal[True]], **params) -> tuple: ... + + @overload + async def execute( + self: CompiledValuesListQuery[MODEL, Literal[False]], **params + ) -> list[Any]: ... 
+ + async def execute(self, **params) -> list[Any] | tuple: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + _, result = await self._db.execute_query(cached_query.sql, filled_params) + return ValuesListQuery._process_results(self, result) + + +class CompiledValuesQuery(BaseCompiledQuery[MODEL], Generic[MODEL, SINGLE]): + __slots__ = ( + "_single", + "_raise_does_not_exist", + "_fields_for_select", + "_annotations", + ) + + def __init__( + self, + model: type[MODEL], + query: QueryBuilder, + sql_cache_maxsize: int | None, + single: bool, + raise_does_not_exist: bool, + fields_for_select: dict[str, str], + annotations: dict[str, Any], + ) -> None: + super().__init__(model, query, sql_cache_maxsize) + + self._single = single + self._raise_does_not_exist = raise_does_not_exist + self._fields_for_select = fields_for_select + self._annotations = annotations + + def _clone(self) -> Self: + query = super()._clone() + query._single = self._single + query._raise_does_not_exist = self._raise_does_not_exist + query._fields_for_select = self._fields_for_select + query._annotations = self._annotations + return query + + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: + return FieldSelectQuery.resolve_to_python_value(self, model, field) + + @overload + async def execute(self: CompiledValuesQuery[MODEL, Literal[True]], **params) -> dict: ... + + @overload + async def execute(self: CompiledValuesQuery[MODEL, Literal[False]], **params) -> list[dict]: ... + + async def execute(self, **params) -> list[dict] | dict: + self._choose_db_if_not_chosen(False) + cached_query = self._get_or_create_cached_sql(params) + filled_params = cached_query.make_filled_params(params) + result = await self._db.execute_query_dict(cached_query.sql, filled_params) + return ValuesQuery._process_results(self, result)