Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/fields/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def test_list(db):
assert obj == obj2


@requireCapability(dialect=In("mysql", "postgres"))
@requireCapability(dialect=In("mysql", "postgres", "sqlite"))
@pytest.mark.asyncio
async def test_list_contains(db):
"""Test JSON contains filter on list."""
Expand Down
4 changes: 4 additions & 0 deletions tortoise/backends/sqlite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from tortoise.backends.sqlite.executor import SqliteExecutor
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.connection import get_connections
from tortoise.contrib.sqlite.json_functions import (
install_json_functions as install_json_functions_to_db,
)
from tortoise.contrib.sqlite.regex import (
install_regexp_functions as install_regexp_functions_to_db,
)
Expand Down Expand Up @@ -84,6 +87,7 @@ async def create_connection(self, with_db: bool) -> None:
for pragma, val in self.pragmas.items():
cursor = await self._connection.execute(f"PRAGMA {pragma}={val}")
await cursor.close()
await install_json_functions_to_db(self._connection)
self.log.debug(
"Created connection %s with params: filename=%s %s",
self._connection,
Expand Down
4 changes: 3 additions & 1 deletion tortoise/backends/sqlite/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.sqlite.json_functions import sqlite_json_contains
from tortoise.contrib.sqlite.regex import (
insensitive_posix_sqlite_regexp,
posix_sqlite_regexp,
)
from tortoise.fields import BigIntField, IntField, SmallIntField
from tortoise.filters import insensitive_posix_regex, posix_regex
from tortoise.filters import insensitive_posix_regex, json_contains, posix_regex

# Conversion for the cases where it's hard to know the
# related field, e.g. in raw queries, math or annotations.
Expand All @@ -24,6 +25,7 @@ class SqliteExecutor(BaseExecutor):
FILTER_FUNC_OVERRIDE = {
posix_regex: posix_sqlite_regexp,
insensitive_posix_regex: insensitive_posix_sqlite_regexp,
json_contains: sqlite_json_contains,
}

async def _process_insert_result(self, instance: Model, results: int) -> None:
Expand Down
50 changes: 50 additions & 0 deletions tortoise/contrib/sqlite/json_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import json

import aiosqlite
from pypika_tortoise.terms import Criterion, Term, ValueWrapper
from pypika_tortoise.terms import Function as PypikaFunction


class SQLiteJSONContains(PypikaFunction):
def __init__(self, column_name: Term, target: Term) -> None:
super().__init__("json_contains", column_name, target)


def sqlite_json_contains(field: Term, value: str) -> Criterion:
return SQLiteJSONContains(field, ValueWrapper(value))


def _json_contains_impl(target_str: str | None, candidate_str: str | None) -> bool:
"""Check if target JSON value contains the candidate JSON value.

Semantics match PostgreSQL's @> operator:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should be referencing postgresql implementation in sqlite contrib

- Arrays: every element of candidate appears in target.
- Objects: every key in candidate exists in target with a matching value.
- Scalars: equality.
"""
if target_str is None or candidate_str is None:
return False
try:
target = json.loads(target_str)
candidate = json.loads(candidate_str)
except (json.JSONDecodeError, TypeError):
return False
return _contains(target, candidate)


def _contains(target: object, candidate: object) -> bool:
if isinstance(candidate, dict):
if not isinstance(target, dict):
return False
return all(k in target and _contains(target[k], v) for k, v in candidate.items())
if isinstance(candidate, list):
if not isinstance(target, list):
return False
return all(any(_contains(t, c) for t in target) for c in candidate)
return target == candidate


async def install_json_functions(connection: aiosqlite.Connection) -> None:
await connection.create_function("json_contains", 2, _json_contains_impl)