diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fbd0f55ed..52ef43da4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -235,7 +235,7 @@ jobs: test-integration: needs: [ci-scope, test-unit] if: needs.ci-scope.outputs.ci_required == 'true' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: fail-fast: true matrix: @@ -288,12 +288,22 @@ jobs: sudo apt-get clean df -h / + - name: Install SQL Server ODBC driver + run: | + set -eu + curl -sSL -O https://packages.microsoft.com/config/ubuntu/24.04/packages-microsoft-prod.deb + sudo dpkg -i packages-microsoft-prod.deb + rm packages-microsoft-prod.deb + sudo apt-get update + sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 unixodbc-dev + odbcinst -q -d -n "ODBC Driver 18 for SQL Server" + - name: Cache docker images id: docker-image-cache uses: actions/cache@v6 with: path: /tmp/docker-images - key: docker-images-v1-${{ hashFiles('.github/workflows/ci.yml') }} + key: docker-images-v2-${{ hashFiles('.github/workflows/ci.yml') }} - name: Load cached docker images if: steps.docker-image-cache.outputs.cache-hit == 'true' @@ -312,6 +322,7 @@ jobs: paradedb/paradedb:0.21.5-pg16 cockroachdb/cockroach:latest mysql:8.4 + mcr.microsoft.com/mssql/server:2022-latest gvenzl/oracle-free:23-slim-faststart ghcr.io/goccy/bigquery-emulator:latest gcr.io/cloud-spanner-emulator/emulator:latest diff --git a/docs/extensions/adk/adapters.rst b/docs/extensions/adk/adapters.rst index f2da9d3a1..01c3f7b60 100644 --- a/docs/extensions/adk/adapters.rst +++ b/docs/extensions/adk/adapters.rst @@ -18,9 +18,11 @@ Use async adapters for best performance with ADK runners: - **DuckDB**: ``duckdb`` (analytics; reduced-scope for ADK) - **ADBC**: ``adbc`` (Arrow-native portability; reduced-scope for ADK) - **Spanner**: ``spanner`` (Google Cloud, globally distributed) +- **SQL Server over ODBC**: ``arrow_odbc`` with Microsoft ODBC Driver 18 -Sync adapters (``psycopg`` sync mode, ``sqlite``, ``mysqlconnector``, ``pymysql``) -work but require wrapping with ``anyio`` for async ADK runners. +Sync adapters (``psycopg`` sync mode, ``sqlite``, ``mysqlconnector``, +``pymysql``, ``arrow_odbc``) work but require wrapping with ``anyio`` for async +ADK runners. Each Adapter Provides ===================== diff --git a/docs/extensions/adk/backends.rst b/docs/extensions/adk/backends.rst index 5827b65d9..df5831a31 100644 --- a/docs/extensions/adk/backends.rst +++ b/docs/extensions/adk/backends.rst @@ -93,6 +93,11 @@ The table below classifies every backend by its ADK support level. - Full - Basic - Portability layer; native adapters provide optimized search. + * - arrow_odbc + - Supported + - Full + - Basic + - SQL Server through Microsoft ODBC Driver 18. * - spanner - Supported - Full @@ -270,6 +275,20 @@ ADBC (Arrow Database Connectivity) provides a driver-agnostic interface: - Memory search uses the portable baseline path; choose a native adapter for backend-specific FTS, retention, and storage tuning. +arrow-odbc +---------- + +``arrow_odbc`` provides SQL Server-backed ADK storage through Microsoft ODBC +Driver 18: + +- Session and event storage use SQL Server tables with ``DATETIME2(6)`` and + ``NVARCHAR(MAX)`` JSON payload columns. +- ``append_event_and_update_state()`` writes the session update, event row, and + scoped state in one committed session. +- Memory storage uses the portable text-search path with ``LIKE`` matching. +- Row-oriented ``execute_many()`` is not available; use native Arrow ingest for + bulk database writes outside the ADK store. + Spanner ------- diff --git a/docs/reference/adapters/arrow_odbc.rst b/docs/reference/adapters/arrow_odbc.rst index 8df0b0324..b32cbf304 100644 --- a/docs/reference/adapters/arrow_odbc.rst +++ b/docs/reference/adapters/arrow_odbc.rst @@ -7,6 +7,16 @@ Streams ``pyarrow.RecordBatchReader`` results from any ODBC-compliant driver, making it a good fit for read-heavy analytical transfer between SQL Server, PostgreSQL, MySQL, or other ODBC sources and the Arrow ecosystem. +SQL Server coverage is exercised in CI against SQL Server 2022 through +``pytest-databases`` and Microsoft ODBC Driver 18. The shared contract matrix +verifies native Arrow reads, Arrow reader/batch output, and Arrow bulk ingest +for this adapter. Row-oriented ``execute_many()`` is intentionally unsupported; +use ``load_from_arrow()`` for bulk writes. + +Extension support is SQL Server-backed. The adapter exports a table-backed +events queue store, a Litestar session store, and Google ADK session/event and +memory stores for SQL Server connections through Microsoft ODBC Driver 18. + Configuration ============= @@ -36,6 +46,25 @@ Data Dictionary :members: :show-inheritance: +Extensions +========== + +.. autoclass:: sqlspec.adapters.arrow_odbc.events.ArrowOdbcEventQueueStore + :members: + :show-inheritance: + +.. autoclass:: sqlspec.adapters.arrow_odbc.litestar.ArrowOdbcStore + :members: + :show-inheritance: + +.. autoclass:: sqlspec.adapters.arrow_odbc.adk.ArrowOdbcADKStore + :members: + :show-inheritance: + +.. autoclass:: sqlspec.adapters.arrow_odbc.adk.ArrowOdbcADKMemoryStore + :members: + :show-inheritance: + Schema Discovery ================ diff --git a/docs/reference/extensions/events.rst b/docs/reference/extensions/events.rst index c2071f7aa..c6ce77b01 100644 --- a/docs/reference/extensions/events.rst +++ b/docs/reference/extensions/events.rst @@ -107,6 +107,10 @@ Listeners Event Queue =========== +The durable table queue is available for SQL Server through ``arrow_odbc`` when +configured with Microsoft ODBC Driver 18. It uses SQL Server ``DATETIME2(6)`` +timestamps and ``NVARCHAR`` payload columns. + .. autoclass:: sqlspec.extensions.events.AsyncTableEventQueue :members: :show-inheritance: diff --git a/docs/usage/drivers_and_querying.rst b/docs/usage/drivers_and_querying.rst index 01378265f..65153af7d 100644 --- a/docs/usage/drivers_and_querying.rst +++ b/docs/usage/drivers_and_querying.rst @@ -97,7 +97,10 @@ asyncpg and CockroachDB-asyncpg (cursors inside a stream-owned transaction), pymysql/aiomysql/asyncmy (``SSCursor``), mysql-connector (unbuffered cursors), sqlite/aiosqlite and oracledb (chunked ``fetchmany``), psqlpy (server-side cursor with ``array_size``), and BigQuery (page-wise result iteration). -ADBC, DuckDB, mssql-python, arrow-odbc, and Spanner are eager-fallback only. +ADBC, DuckDB, mssql-python, and Spanner are eager-fallback only for row +streaming. ``arrow-odbc`` row streaming also materializes dict rows eagerly, but +``select_to_arrow(..., return_format="reader" | "batches")`` uses the native +``arrow_odbc`` ``RecordBatchReader`` path. Lifetime and transaction rules: diff --git a/docs/usage/frameworks/litestar/session_stores.rst b/docs/usage/frameworks/litestar/session_stores.rst index 455f2c11c..b00e00a8d 100644 --- a/docs/usage/frameworks/litestar/session_stores.rst +++ b/docs/usage/frameworks/litestar/session_stores.rst @@ -32,6 +32,7 @@ SQLSpec provides stores for async adapters: - ``AsyncpgStore`` - PostgreSQL via asyncpg - ``AiosqliteStore`` - SQLite via aiosqlite +- ``ArrowOdbcStore`` - SQL Server via arrow-odbc and Microsoft ODBC Driver 18 Each store automatically creates its session table on first use if it doesn't exist. diff --git a/sqlspec/adapters/arrow_odbc/adk/__init__.py b/sqlspec/adapters/arrow_odbc/adk/__init__.py new file mode 100644 index 000000000..4f1f2d34c --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/adk/__init__.py @@ -0,0 +1,5 @@ +"""arrow-odbc ADK extension exports.""" + +from sqlspec.adapters.arrow_odbc.adk.store import ArrowOdbcADKConfig, ArrowOdbcADKMemoryStore, ArrowOdbcADKStore + +__all__ = ("ArrowOdbcADKConfig", "ArrowOdbcADKMemoryStore", "ArrowOdbcADKStore") diff --git a/sqlspec/adapters/arrow_odbc/adk/store.py b/sqlspec/adapters/arrow_odbc/adk/store.py new file mode 100644 index 000000000..6611058a5 --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/adk/store.py @@ -0,0 +1,871 @@ +"""arrow-odbc ADK stores for Google Agent Development Kit session storage.""" + +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, ClassVar, Final, cast + +from typing_extensions import NotRequired + +from sqlspec.config import ADKConfig +from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory import BaseSyncADKMemoryStore, MemoryRecord +from sqlspec.utils.serializers import from_json, to_json + +if TYPE_CHECKING: + from datetime import timedelta + + from sqlspec.adapters.arrow_odbc.config import ArrowOdbcConfig +else: + ArrowOdbcConfig = Any + + +__all__ = ("ArrowOdbcADKConfig", "ArrowOdbcADKMemoryStore", "ArrowOdbcADKStore") + +MSSQL_SCHEMA: Final[str] = "dbo" +JSON_COLUMN_TYPE: Final[str] = "NVARCHAR(MAX)" + + +class ArrowOdbcADKConfig(ADKConfig): + """arrow-odbc ADK extension settings.""" + + native_json: NotRequired[bool] + """Accepted for parity with SQL Server adapters; arrow-odbc uses NVARCHAR(MAX).""" + + +class ArrowOdbcADKStore(BaseSyncADKStore["ArrowOdbcConfig"]): + """Synchronous SQL Server ADK session/event store using arrow-odbc.""" + + connector_name: ClassVar[str] = "arrow_odbc" + __slots__ = () + + def create_tables(self) -> None: + """Create all ADK session tables if they do not exist.""" + with self._config.provide_session() as driver: + driver.execute(self._get_create_sessions_table_sql()) + driver.execute(self._get_create_events_table_sql()) + driver.execute(self._get_create_app_states_table_sql()) + driver.execute(self._get_create_user_states_table_sql()) + driver.execute(self._get_create_metadata_table_sql()) + driver.execute(self._get_seed_metadata_sql()) + driver.commit() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new ADK session.""" + owner_column = f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else "" + owner_param = ", ?" if self._owner_id_column_name else "" + params: tuple[Any, ...] + if self._owner_id_column_name: + params = (session_id, app_name, user_id, owner_id, to_json(state)) + else: + params = (session_id, app_name, user_id, to_json(state)) + with self._config.provide_session() as driver: + driver.execute( + f""" + INSERT INTO {_table_ref(self._session_table)} ( + id, app_name, user_id{owner_column}, state, create_time, update_time + ) + VALUES (?, ?, ?{owner_param}, ?, SYSUTCDATETIME(), SYSUTCDATETIME()) + """, + params, + ) + row = driver.select_one_or_none( + _get_session_select_sql(self._session_table), (app_name, user_id, session_id) + ) + driver.commit() + if row is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return _session_record_from_row(row) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Return a scoped session or ``None`` if absent.""" + try: + with self._config.provide_session() as driver: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + driver.execute( + f""" + UPDATE {_table_ref(self._session_table)} + SET update_time = SYSUTCDATETIME() + WHERE app_name = ? AND user_id = ? AND id = ? + """, + (app_name, user_id, session_id), + ) + row = driver.select_one_or_none( + _get_session_select_sql(self._session_table), (app_name, user_id, session_id) + ) + if renew_for is not None: + driver.commit() + except SQLSpecError as exc: + if _is_table_missing(exc): + return None + raise + return _session_record_from_row(row) if row is not None else None + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Replace a session's durable state.""" + self._execute( + f""" + UPDATE {_table_ref(self._session_table)} + SET state = ?, update_time = SYSUTCDATETIME() + WHERE app_name = ? AND user_id = ? AND id = ? + """, + (to_json(state), app_name, user_id, session_id), + commit=True, + ) + + def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + """List ADK sessions for an application, optionally scoped to a user.""" + if user_id is None: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(self._session_table)} + WHERE app_name = ? + ORDER BY update_time DESC + """ + params: tuple[Any, ...] = (app_name,) + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(self._session_table)} + WHERE app_name = ? AND user_id = ? + ORDER BY update_time DESC + """ + params = (app_name, user_id) + try: + rows = self._execute_fetchall(sql, params) + except SQLSpecError as exc: + if _is_table_missing(exc): + return [] + raise + return [_session_record_from_row(row) for row in rows] + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete a session. Event rows cascade through the FK.""" + self._execute( + f"DELETE FROM {_table_ref(self._session_table)} WHERE app_name = ? AND user_id = ? AND id = ?", + (app_name, user_id, session_id), + commit=True, + ) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record), commit=True) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update durable session/scoped state.""" + with self._config.provide_session() as driver: + driver.execute( + f""" + UPDATE {_table_ref(self._session_table)} + SET state = ?, update_time = SYSUTCDATETIME() + WHERE app_name = ? AND user_id = ? AND id = ? + """, + (to_json(state), app_name, user_id, session_id), + ) + row = driver.select_one_or_none( + _get_session_select_sql(self._session_table), (app_name, user_id, session_id) + ) + if row is None: + _raise_session_not_found(session_id) + driver.execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record)) + if app_state is not None: + driver.execute(self._get_upsert_app_state_sql(), (app_name, to_json(app_state))) + if user_state is not None: + driver.execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(user_state))) + driver.commit() + return _session_record_from_row(row) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Return events for a scoped session ordered by event timestamp.""" + if limit is not None and limit <= 0: + return [] + sql, params = self._get_events_query(app_name, user_id, session_id, after_timestamp, limit) + try: + rows = self._execute_fetchall(sql, params) + except SQLSpecError as exc: + if _is_table_missing(exc): + return [] + raise + return [_event_record_from_row(row) for row in rows] + + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than ``before``.""" + try: + count = self._select_count( + f"SELECT COUNT(*) AS row_count FROM {_table_ref(self._events_table)} WHERE timestamp < ?", + (_format_datetime(before),), + ) + self._execute( + f"DELETE FROM {_table_ref(self._events_table)} WHERE timestamp < ?", + (_format_datetime(before),), + commit=True, + ) + except SQLSpecError as exc: + if _is_table_missing(exc): + return 0 + raise + else: + return count + + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time is older than ``updated_before``.""" + try: + count = self._select_count( + f"SELECT COUNT(*) AS row_count FROM {_table_ref(self._session_table)} WHERE update_time < ?", + (_format_datetime(updated_before),), + ) + self._execute( + f"DELETE FROM {_table_ref(self._session_table)} WHERE update_time < ?", + (_format_datetime(updated_before),), + commit=True, + ) + except SQLSpecError as exc: + if _is_table_missing(exc): + return 0 + raise + else: + return count + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state.""" + try: + row = self._execute_fetchone( + f"SELECT TOP 1 state FROM {_table_ref(self._app_state_table)} WHERE app_name = ?", (app_name,) + ) + except SQLSpecError as exc: + if _is_table_missing(exc): + return None + raise + return _json_dict(_row_value(row, "state", 0)) if row is not None else None + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state.""" + try: + row = self._execute_fetchone( + f""" + SELECT TOP 1 state + FROM {_table_ref(self._user_state_table)} + WHERE app_name = ? AND user_id = ? + """, + (app_name, user_id), + ) + except SQLSpecError as exc: + if _is_table_missing(exc): + return None + raise + return _json_dict(_row_value(row, "state", 0)) if row is not None else None + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state.""" + self._execute(self._get_upsert_app_state_sql(), (app_name, to_json(state)), commit=True) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state.""" + self._execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(state)), commit=True) + + def get_metadata(self, key: str) -> "str | None": + """Return an ADK metadata value.""" + try: + row = self._execute_fetchone( + f"SELECT TOP 1 value FROM {_table_ref(self._metadata_table)} WHERE [key] = ?", (key,) + ) + except SQLSpecError as exc: + if _is_table_missing(exc): + return None + raise + value = _row_value(row, "value", 0) if row is not None else None + return str(value) if value is not None else None + + def set_metadata(self, key: str, value: str) -> None: + """Set an ADK metadata value.""" + self._execute(_get_upsert_metadata_sql(self._metadata_table), (key, value), commit=True) + + def _get_create_sessions_table_sql(self) -> str: + """Return T-SQL DDL for the ADK session table.""" + return _get_create_sessions_table_sql(self._session_table, self._owner_id_column_ddl) + + def _get_create_events_table_sql(self) -> str: + """Return T-SQL DDL for the ADK event table.""" + return _get_create_events_table_sql(self._events_table, self._session_table) + + def _get_create_app_states_table_sql(self) -> str: + """Return T-SQL DDL for the app-scoped state table.""" + return _get_create_app_states_table_sql(self._app_state_table) + + def _get_create_user_states_table_sql(self) -> str: + """Return T-SQL DDL for the user-scoped state table.""" + return _get_create_user_states_table_sql(self._user_state_table) + + def _get_create_metadata_table_sql(self) -> str: + """Return T-SQL DDL for the ADK metadata table.""" + return _get_create_metadata_table_sql(self._metadata_table) + + def _get_seed_metadata_sql(self) -> str: + """Return T-SQL to seed schema-version metadata.""" + return _get_seed_metadata_sql(self._metadata_table) + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._app_state_table)}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._user_state_table)}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._metadata_table)}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {_table_ref(self._events_table)}", + f"DROP TABLE IF EXISTS {_table_ref(self._session_table)}", + ] + + def _get_upsert_app_state_sql(self) -> str: + return _get_upsert_state_sql(self._app_state_table, ("app_name",), ("?",)) + + def _get_upsert_user_state_sql(self) -> str: + return _get_upsert_state_sql(self._user_state_table, ("app_name", "user_id"), ("?", "?")) + + def _get_events_query( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "tuple[str, tuple[Any, ...]]": + return _get_events_query(self._events_table, app_name, user_id, session_id, after_timestamp, limit) + + def _execute_fetchone(self, sql: str, params: "tuple[Any, ...]" = ()) -> "dict[str, Any] | None": + with self._config.provide_session() as driver: + return driver.select_one_or_none(sql, params) + + def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[dict[str, Any]]": + with self._config.provide_session() as driver: + return driver.select(sql, params) + + def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int: + with self._config.provide_session() as driver: + result = driver.execute(sql, params) + if commit: + driver.commit() + return int(result.rows_affected) + + def _select_count(self, sql: str, params: "tuple[Any, ...]" = ()) -> int: + with self._config.provide_session() as driver: + value = driver.select_value(sql, params) + return int(value or 0) + + +class ArrowOdbcADKMemoryStore(BaseSyncADKMemoryStore["ArrowOdbcConfig"]): + """SQL Server ADK memory store using arrow-odbc.""" + + __slots__ = () + + def create_tables(self) -> None: + """Create the memory table if memory storage is enabled.""" + if not self._enabled: + return + with self._config.provide_session() as driver: + for statement in self._get_create_memory_table_sql(): + driver.execute(statement) + driver.commit() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Insert memory entries, skipping duplicates by event_id.""" + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + if not entries: + return 0 + + inserted_count = 0 + with self._config.provide_session() as driver: + for entry in entries: + exists = driver.select_one_or_none( + f"SELECT TOP 1 id FROM {_table_ref(self._memory_table)} WHERE event_id = ?", (entry["event_id"],) + ) + if exists is not None: + continue + owner_column = ( + f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else "" + ) + owner_param = ", ?" if self._owner_id_column_name else "" + params: tuple[Any, ...] + if self._owner_id_column_name: + params = (*_memory_insert_params(entry), owner_id) + else: + params = _memory_insert_params(entry) + driver.execute( + f""" + INSERT INTO {_table_ref(self._memory_table)} ( + id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at{owner_column} + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?{owner_param}) + """, + params, + ) + inserted_count += 1 + driver.commit() + return inserted_count + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries with SQL Server LIKE matching.""" + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + effective_limit = max(0, int(limit if limit is not None else self._max_results)) + if effective_limit == 0: + return [] + rows = self._execute_fetchall( + f""" + SELECT id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at + FROM {_table_ref(self._memory_table)} + WHERE app_name = ? + AND user_id = ? + AND content_text LIKE ? + ORDER BY timestamp DESC + OFFSET 0 ROWS FETCH NEXT {effective_limit} ROWS ONLY + """, + (app_name, user_id, f"%{query}%"), + ) + return [_memory_record_from_row(row) for row in rows] + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + count = self._select_count( + f"SELECT COUNT(*) AS row_count FROM {_table_ref(self._memory_table)} WHERE session_id = ?", (session_id,) + ) + self._execute(f"DELETE FROM {_table_ref(self._memory_table)} WHERE session_id = ?", (session_id,), commit=True) + return count + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than ``days`` days.""" + cutoff = datetime.now(timezone.utc).timestamp() - (days * 86_400) + cutoff_dt = datetime.fromtimestamp(cutoff, tz=timezone.utc) + count = self._select_count( + f"SELECT COUNT(*) AS row_count FROM {_table_ref(self._memory_table)} WHERE inserted_at < ?", + (_format_datetime(cutoff_dt),), + ) + self._execute( + f"DELETE FROM {_table_ref(self._memory_table)} WHERE inserted_at < ?", + (_format_datetime(cutoff_dt),), + commit=True, + ) + return count + + def _get_create_memory_table_sql(self) -> "list[str]": + owner_line = f",\n {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" + return [ + f""" +IF NOT EXISTS ( + SELECT 1 FROM sys.tables + WHERE name = N'{_escape_sql_literal(self._memory_table)}' + AND schema_id = SCHEMA_ID(N'dbo') +) +BEGIN + CREATE TABLE {_table_ref(self._memory_table)} ( + id NVARCHAR(128) NOT NULL, + session_id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + event_id NVARCHAR(128) NOT NULL, + author NVARCHAR(256) NULL, + timestamp DATETIME2(6) NOT NULL, + content_json NVARCHAR(MAX) NOT NULL, + content_text NVARCHAR(MAX) NOT NULL, + metadata_json NVARCHAR(MAX) NULL, + inserted_at DATETIME2(6) NOT NULL{owner_line}, + CONSTRAINT {_constraint_ref("pk", self._memory_table, "id")} PRIMARY KEY (id), + CONSTRAINT {_constraint_ref("uq", self._memory_table, "event_id")} UNIQUE (event_id) + ); +END; +""", + _get_create_index_sql( + self._memory_table, f"idx_{self._memory_table}_app_user_time", "app_name, user_id, timestamp DESC" + ), + _get_create_index_sql(self._memory_table, f"idx_{self._memory_table}_session", "session_id"), + ] + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {_table_ref(self._memory_table)}"] + + def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[dict[str, Any]]": + with self._config.provide_session() as driver: + return driver.select(sql, params) + + def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int: + with self._config.provide_session() as driver: + result = driver.execute(sql, params) + if commit: + driver.commit() + return int(result.rows_affected) + + def _select_count(self, sql: str, params: "tuple[Any, ...]" = ()) -> int: + with self._config.provide_session() as driver: + value = driver.select_value(sql, params) + return int(value or 0) + + +def _get_session_select_sql(table: str) -> str: + return f""" + SELECT TOP 1 id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(table)} + WHERE app_name = ? AND user_id = ? AND id = ? + """ + + +def _get_create_sessions_table_sql(table: str, owner_id_column_ddl: "str | None") -> str: + owner_line = f",\n {owner_id_column_ddl}" if owner_id_column_ddl else "" + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + row_id UNIQUEIDENTIFIER NOT NULL CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(), + id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL{owner_line}, + state {JSON_COLUMN_TYPE} NOT NULL, + create_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "create_time")} DEFAULT SYSUTCDATETIME(), + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id), + CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id) + ); +END; +{_get_create_index_sql(table, f"idx_{table}_app_user", "app_name, user_id")} +{_get_create_index_sql(table, f"idx_{table}_update_time", "update_time DESC")} +""" + + +def _get_create_events_table_sql(table: str, session_table: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + row_id UNIQUEIDENTIFIER NOT NULL CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(), + id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + session_id NVARCHAR(128) NOT NULL, + invocation_id NVARCHAR(256) NOT NULL, + timestamp DATETIME2(6) NOT NULL, + event_data {JSON_COLUMN_TYPE} NOT NULL, + CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id), + CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id), + CONSTRAINT {_constraint_ref("fk", table, "session")} FOREIGN KEY (session_id) + REFERENCES {_table_ref(session_table)}(id) ON DELETE CASCADE + ); +END; +{_get_create_index_sql(table, f"idx_{table}_scope", "app_name, user_id, session_id, timestamp ASC")} +{_get_create_index_sql(table, f"idx_{table}_session", "session_id, timestamp ASC")} +{_get_create_index_sql(table, f"idx_{table}_invocation", "invocation_id")} +{_get_create_index_sql(table, f"idx_{table}_timestamp", "timestamp ASC")} +""" + + +def _get_create_app_states_table_sql(table: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + app_name NVARCHAR(128) NOT NULL, + state {JSON_COLUMN_TYPE} NOT NULL, + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "app_name")} PRIMARY KEY (app_name) + ); +END; +""" + + +def _get_create_user_states_table_sql(table: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + state {JSON_COLUMN_TYPE} NOT NULL, + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "app_user")} PRIMARY KEY (app_name, user_id) + ); +END; +""" + + +def _get_create_metadata_table_sql(table: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + [key] NVARCHAR(128) NOT NULL, + value NVARCHAR(512) NOT NULL, + CONSTRAINT {_constraint_ref("pk", table, "key")} PRIMARY KEY ([key]) + ); +END; +""" + + +def _get_create_index_sql(table: str, index_name: str, columns: str) -> str: + return f""" +IF NOT EXISTS ( + SELECT 1 FROM sys.indexes + WHERE name = N'{_escape_sql_literal(index_name)}' + AND object_id = OBJECT_ID(N'{_escape_sql_literal(MSSQL_SCHEMA)}.{_escape_sql_literal(table)}') +) +BEGIN + CREATE INDEX {_quote_identifier(index_name)} ON {_table_ref(table)} ({columns}); +END; +""" + + +def _get_insert_event_sql(table: str) -> str: + return f""" + INSERT INTO {_table_ref(table)} ( + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + """ + + +def _get_upsert_state_sql(table: str, key_columns: "tuple[str, ...]", key_params: "tuple[str, ...]") -> str: + source_columns = ", ".join( + f"{param} AS {_quote_identifier(column)}" for column, param in zip(key_columns, key_params, strict=False) + ) + source_columns = f"{source_columns}, ? AS state" + insert_columns = ", ".join(_quote_identifier(column) for column in (*key_columns, "state", "update_time")) + insert_values = ", ".join(f"source.{_quote_identifier(column)}" for column in (*key_columns, "state")) + match_clause = " AND ".join( + f"target.{_quote_identifier(column)} = source.{_quote_identifier(column)}" for column in key_columns + ) + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT {source_columns}) AS source + ON ({match_clause}) + WHEN MATCHED THEN + UPDATE SET state = source.state, update_time = SYSUTCDATETIME() + WHEN NOT MATCHED THEN + INSERT ({insert_columns}) + VALUES ({insert_values}, SYSUTCDATETIME()); + """ + + +def _get_upsert_metadata_sql(table: str) -> str: + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT ? AS [key], ? AS value) AS source + ON (target.[key] = source.[key]) + WHEN MATCHED THEN + UPDATE SET value = source.value + WHEN NOT MATCHED THEN + INSERT ([key], value) + VALUES (source.[key], source.value); + """ + + +def _get_seed_metadata_sql(table: str) -> str: + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT N'schema_version' AS [key], N'1' AS value) AS source + ON (target.[key] = source.[key]) + WHEN MATCHED THEN + UPDATE SET value = source.value + WHEN NOT MATCHED THEN + INSERT ([key], value) + VALUES (source.[key], source.value); + """ + + +def _get_events_query( + table: str, app_name: str, user_id: str, session_id: str, after_timestamp: "datetime | None", limit: "int | None" +) -> "tuple[str, tuple[Any, ...]]": + top_clause = f"TOP {int(limit)} " if limit is not None else "" + params: list[Any] = [app_name, user_id, session_id] + after_clause = "" + if after_timestamp is not None: + after_clause = " AND timestamp > ?" + params.append(_format_datetime(after_timestamp)) + sql = f""" + SELECT {top_clause}id, app_name, user_id, session_id, invocation_id, timestamp, event_data + FROM {_table_ref(table)} + WHERE app_name = ? AND user_id = ? AND session_id = ?{after_clause} + ORDER BY timestamp ASC + """ + return sql, tuple(params) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + _format_datetime(event_record["timestamp"]), + to_json(event_record["event_data"]), + ) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=str(_row_value(row, "id", 0)), + app_name=str(_row_value(row, "app_name", 1)), + user_id=str(_row_value(row, "user_id", 2)), + state=_json_dict(_row_value(row, "state", 3)), + create_time=_datetime_value(_row_value(row, "create_time", 4)), + update_time=_datetime_value(_row_value(row, "update_time", 5)), + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=str(_row_value(row, "id", 0)), + app_name=str(_row_value(row, "app_name", 1)), + user_id=str(_row_value(row, "user_id", 2)), + session_id=str(_row_value(row, "session_id", 3)), + invocation_id=str(_row_value(row, "invocation_id", 4)), + timestamp=_datetime_value(_row_value(row, "timestamp", 5)), + event_data=_json_dict(_row_value(row, "event_data", 6)), + ) + + +def _memory_insert_params(entry: MemoryRecord) -> "tuple[Any, ...]": + return ( + entry["id"], + entry["session_id"], + entry["app_name"], + entry["user_id"], + entry["event_id"], + entry["author"], + _format_datetime(entry["timestamp"]), + to_json(entry["content_json"]), + entry["content_text"], + to_json(entry["metadata_json"]) if entry["metadata_json"] is not None else None, + _format_datetime(entry["inserted_at"]), + ) + + +def _memory_record_from_row(row: Any) -> MemoryRecord: + return MemoryRecord( + id=str(_row_value(row, "id", 0)), + session_id=str(_row_value(row, "session_id", 1)), + app_name=str(_row_value(row, "app_name", 2)), + user_id=str(_row_value(row, "user_id", 3)), + event_id=str(_row_value(row, "event_id", 4)), + author=cast("str | None", _row_value(row, "author", 5)), + timestamp=_datetime_value(_row_value(row, "timestamp", 6)), + content_json=_json_dict(_row_value(row, "content_json", 7)), + content_text=str(_row_value(row, "content_text", 8) or ""), + metadata_json=_optional_json_dict(_row_value(row, "metadata_json", 9)), + inserted_at=_datetime_value(_row_value(row, "inserted_at", 10)), + ) + + +def _row_value(row: Any, key: str, index: int) -> Any: + if isinstance(row, dict): + if key in row: + return row[key] + upper_key = key.upper() + if upper_key in row: + return row[upper_key] + return None + if isinstance(row, (list, tuple)) and len(row) > index: + return row[index] + return getattr(row, key, None) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if value is None: + return {} + if isinstance(value, dict): + return cast("dict[str, Any]", value) + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, bytes): + value = value.decode("utf-8") + if isinstance(value, str): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", from_json(str(value))) + + +def _optional_json_dict(value: Any) -> "dict[str, Any] | None": + if value is None: + return None + return _json_dict(value) + + +def _datetime_value(value: Any) -> datetime: + if isinstance(value, datetime): + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, bytes): + value = value.decode("utf-8") + if isinstance(value, str): + normalized = value.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=timezone.utc) + return parsed.astimezone(timezone.utc) + return datetime.now(timezone.utc) + + +def _format_datetime(value: "datetime | None") -> "str | None": + if value is None: + return None + normalized = value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value.astimezone(timezone.utc) + return normalized.replace(tzinfo=None).isoformat(timespec="microseconds") + + +def _is_table_missing(exc: BaseException) -> bool: + text = str(exc).lower() + return "invalid object name" in text or "42s02" in text or "(208)" in text + + +def _quote_identifier(identifier: str) -> str: + return f"[{identifier.replace(']', ']]')}]" + + +def _table_ref(table: str) -> str: + return f"{_quote_identifier(MSSQL_SCHEMA)}.{_quote_identifier(table)}" + + +def _constraint_ref(prefix: str, table: str, suffix: str) -> str: + return _quote_identifier(f"{prefix}_{table}_{suffix}") + + +def _escape_sql_literal(value: str) -> str: + return value.replace("'", "''") + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/arrow_odbc/config.py b/sqlspec/adapters/arrow_odbc/config.py index 036286322..aa151dd82 100644 --- a/sqlspec/adapters/arrow_odbc/config.py +++ b/sqlspec/adapters/arrow_odbc/config.py @@ -5,7 +5,13 @@ from typing_extensions import NotRequired from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection, ArrowOdbcSessionContext, arrow_odbc_connect -from sqlspec.adapters.arrow_odbc.core import apply_driver_features, build_connection_config, default_statement_config +from sqlspec.adapters.arrow_odbc.core import ( + apply_driver_features, + build_connection_config, + build_statement_config, + default_statement_config, + resolve_dialect_from_dbms_name, +) from sqlspec.adapters.arrow_odbc.driver import ArrowOdbcDriver from sqlspec.config import ExtensionConfigs, NoPoolSyncConfig from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory @@ -147,7 +153,7 @@ def __init__( connection_config=normalized, connection_instance=connection_instance, migration_config=migration_config, - statement_config=statement_config or default_statement_config, + statement_config=statement_config or _resolve_statement_config(features), driver_features=features, bind_key=bind_key, extension_config=extension_config, @@ -207,3 +213,12 @@ def _close_arrow_odbc_connection(connection: "ArrowOdbcConnection") -> None: """Close connection objects from compatible wrappers when they expose close().""" if isinstance(connection, SupportsCloseProtocol): connection.close() + + +def _resolve_statement_config(features: dict[str, Any]) -> "StatementConfig": + dialect = resolve_dialect_from_dbms_name(str(features.get("dbms_name") or features.get("connection_string") or "")) + if dialect == "sqlite": + return default_statement_config + if dialect == "mssql": + return build_statement_config(dialect="tsql") + return build_statement_config(dialect=dialect) diff --git a/sqlspec/adapters/arrow_odbc/core.py b/sqlspec/adapters/arrow_odbc/core.py index b0ce3b47e..509cecba0 100644 --- a/sqlspec/adapters/arrow_odbc/core.py +++ b/sqlspec/adapters/arrow_odbc/core.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Final from sqlspec.core import DriverParameterProfile, ParameterStyle, build_statement_config_from_profile -from sqlspec.exceptions import ImproperConfigurationError, SQLSpecError +from sqlspec.exceptions import ImproperConfigurationError, SQLParsingError, SQLSpecError from sqlspec.utils.serializers import from_json, to_json from sqlspec.utils.type_converters import build_uuid_coercions @@ -61,6 +61,9 @@ def resolve_dialect_from_dbms_name(dbms_name: str | None) -> str: def create_mapped_exception(exc: Exception) -> Exception: """Map an arrow-odbc exception to SQLSpec's exception hierarchy.""" + message = str(exc) + if "Native error: 102" in message or "Incorrect syntax near" in message: + return SQLParsingError(f"ODBC SQL parsing error. Original error: {exc}") return SQLSpecError(f"ODBC database error. Original error: {exc}") diff --git a/sqlspec/adapters/arrow_odbc/driver.py b/sqlspec/adapters/arrow_odbc/driver.py index 707cd0d15..20d6cbc80 100644 --- a/sqlspec/adapters/arrow_odbc/driver.py +++ b/sqlspec/adapters/arrow_odbc/driver.py @@ -1,8 +1,9 @@ """arrow-odbc sync driver.""" +import re from collections.abc import Iterable, Mapping from itertools import chain -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Final, cast from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection, ArrowOdbcCursor, ArrowOdbcError, ArrowOdbcRawCursor from sqlspec.adapters.arrow_odbc.core import ( @@ -103,6 +104,8 @@ def data_dictionary(self) -> "ArrowOdbcDataDictionary": def dispatch_execute(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + if self._dialect == "mssql": + sql, prepared_parameters = _inline_mssql_pagination_parameters(sql, prepared_parameters) parameters = _odbc_parameters(prepared_parameters) if statement.returns_rows(): @@ -120,7 +123,7 @@ def dispatch_execute(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "E ) cursor.execute(query=sql, parameters=parameters) - return self.create_execution_result(cursor, rowcount_override=0) + return self.create_execution_result(cursor, rowcount_override=-1) def dispatch_execute_many(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult": msg = "arrow-odbc does not expose a row-oriented executemany API; use bulk_insert_arrow() for Arrow ingestion." @@ -128,6 +131,11 @@ def dispatch_execute_many(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") def dispatch_execute_script(self, cursor: "ArrowOdbcRawCursor", statement: "SQL") -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + if self._dialect == "mssql": + cursor.execute(query=sql, parameters=_odbc_parameters(prepared_parameters)) + return self.create_execution_result( + cursor, rowcount_override=-1, statement_count=1, successful_statements=1, is_script_result=True + ) statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) parameters = _odbc_parameters(prepared_parameters) successful_count = 0 @@ -144,6 +152,14 @@ def collect_rows(self, cursor: "ArrowOdbcRawCursor", fetched: "list[Any]") -> "t def resolve_rowcount(self, cursor: "ArrowOdbcRawCursor") -> int: return 0 + def _connection_in_transaction(self) -> bool: + """Return whether the generic ODBC connection is in an active transaction. + + arrow-odbc does not expose a portable transaction-state API, so stack + execution relies on SQLSpec's explicit transaction management. + """ + return False + def begin(self) -> None: statement = "BEGIN TRANSACTION" if self._dialect == "mssql" else "BEGIN" self.connection.execute(statement) @@ -325,6 +341,33 @@ def _statement_dialect_for(dialect: str) -> str: return dialect +_MSSQL_OFFSET_FETCH_PATTERN = re.compile( + r"OFFSET\s+\?\s+ROWS\s+FETCH\s+(?PNEXT|FIRST)\s+\?\s+ROWS\s+ONLY", re.IGNORECASE +) +_MSSQL_PAGINATION_PARAMETER_COUNT: Final = 2 + + +def _inline_mssql_pagination_parameters(sql: str, parameters: object) -> tuple[str, object]: + match = _MSSQL_OFFSET_FETCH_PATTERN.search(sql) + if ( + match is None + or not isinstance(parameters, (list, tuple)) + or len(parameters) < _MSSQL_PAGINATION_PARAMETER_COUNT + ): + return sql, parameters + + offset_value = _pagination_int(parameters[-2]) + limit_value = _pagination_int(parameters[-1]) + replacement = f"OFFSET {offset_value} ROWS FETCH {match.group('fetch_keyword')} {limit_value} ROWS ONLY" + remaining_parameters = parameters[:-2] + return _MSSQL_OFFSET_FETCH_PATTERN.sub(replacement, sql, count=1), remaining_parameters + + +def _pagination_int(value: object) -> int: + unwrapped = getattr(value, "value", value) + return int(cast("Any", unwrapped)) + + def _unwrap_parameter(value: Any) -> Any: wrapped = getattr(value, "value", value) return None if wrapped is None else str(wrapped) diff --git a/sqlspec/adapters/arrow_odbc/events/__init__.py b/sqlspec/adapters/arrow_odbc/events/__init__.py new file mode 100644 index 000000000..c01bdcc1a --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/events/__init__.py @@ -0,0 +1,5 @@ +"""Event queue store for the arrow-odbc adapter.""" + +from sqlspec.adapters.arrow_odbc.events.store import ArrowOdbcEventQueueStore + +__all__ = ("ArrowOdbcEventQueueStore",) diff --git a/sqlspec/adapters/arrow_odbc/events/store.py b/sqlspec/adapters/arrow_odbc/events/store.py new file mode 100644 index 000000000..c692d0209 --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/events/store.py @@ -0,0 +1,74 @@ +"""arrow-odbc event queue store with T-SQL-specific DDL.""" + +import re + +from sqlspec.adapters.arrow_odbc.config import ArrowOdbcConfig +from sqlspec.extensions.events import BaseEventQueueStore +from sqlspec.utils.text import split_qualified_identifier + +__all__ = ("ArrowOdbcEventQueueStore",) + +_NVARCHAR_MAX_THRESHOLD = 4000 +_QUALIFIED_IDENTIFIER_MIN_PARTS = 2 + + +class ArrowOdbcEventQueueStore(BaseEventQueueStore[ArrowOdbcConfig]): + """Event queue DDL for arrow-odbc SQL Server configs.""" + + __slots__ = () + + def _column_types(self) -> tuple[str, str, str]: + return "NVARCHAR(MAX)", "NVARCHAR(MAX)", "DATETIME2(6)" + + def _string_type(self, length: int) -> str: + if length >= _NVARCHAR_MAX_THRESHOLD: + return "NVARCHAR(MAX)" + return f"NVARCHAR({length})" + + def _integer_type(self) -> str: + return "INT" + + def _timestamp_default(self) -> str: + return "SYSUTCDATETIME()" + + def _wrap_create_statement(self, statement: str, object_type: str) -> str: + if object_type == "table": + match = re.search(r"CREATE TABLE\s+(\S+)", statement, re.IGNORECASE) + if match: + table_name = match.group(1) + return f"IF OBJECT_ID(N'{_object_name(table_name)}', N'U') IS NULL BEGIN {statement}; END" + if object_type == "index": + match = re.search(r"CREATE INDEX\s+(\S+)\s+ON\s+(\S+)", statement, re.IGNORECASE) + if match: + index_name = match.group(1).strip("[]") + table_name = match.group(2) + return ( + "IF NOT EXISTS (SELECT 1 FROM sys.indexes " + f"WHERE name = N'{index_name}' AND object_id = OBJECT_ID(N'{_object_name(table_name)}')) " + f"BEGIN {statement}; END" + ) + return statement + + def _wrap_drop_statement(self, statement: str) -> str: + match = re.search(r"DROP TABLE\s+(\S+)", statement, re.IGNORECASE) + if match: + table_name = match.group(1) + return f"IF OBJECT_ID(N'{_object_name(table_name)}', N'U') IS NOT NULL DROP TABLE {table_name};" + return statement + + +def _split_table_name(table_name: str) -> tuple[str, str]: + parts = split_qualified_identifier(table_name, quote_chars='"') + if len(parts) < _QUALIFIED_IDENTIFIER_MIN_PARTS: + return "dbo", parts[0] if parts else table_name + schema_name = ".".join(parts[:-1]) + return schema_name or "dbo", parts[-1] + + +def _object_name(table_name: str) -> str: + schema_name, bare_table_name = _split_table_name(table_name) + return f"{_quote_bracket_identifier(schema_name)}.{_quote_bracket_identifier(bare_table_name)}" + + +def _quote_bracket_identifier(identifier: str) -> str: + return f"[{identifier.replace(']', ']]')}]" diff --git a/sqlspec/adapters/arrow_odbc/litestar/__init__.py b/sqlspec/adapters/arrow_odbc/litestar/__init__.py new file mode 100644 index 000000000..2d9e042ac --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/litestar/__init__.py @@ -0,0 +1,5 @@ +"""Litestar Store integration for arrow-odbc.""" + +from sqlspec.adapters.arrow_odbc.litestar.store import ArrowOdbcStore + +__all__ = ("ArrowOdbcStore",) diff --git a/sqlspec/adapters/arrow_odbc/litestar/store.py b/sqlspec/adapters/arrow_odbc/litestar/store.py new file mode 100644 index 000000000..c3f889e9e --- /dev/null +++ b/sqlspec/adapters/arrow_odbc/litestar/store.py @@ -0,0 +1,307 @@ +"""arrow-odbc Litestar Store implementation.""" + +import base64 +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.adapters.arrow_odbc.config import ArrowOdbcConfig + +__all__ = ("ArrowOdbcStore",) + +_MAX_INLINE_DATA_LENGTH = 3500 + + +class ArrowOdbcStore(BaseSQLSpecStore["ArrowOdbcConfig"]): + """SQL Server-backed session store using arrow-odbc sessions.""" + + __slots__ = () + + def __init__(self, config: "ArrowOdbcConfig") -> None: + super().__init__(config) + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key.""" + return await async_(self._get)(key, renew_for) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value.""" + await async_(self._set)(key, value, expires_in) + + async def delete(self, key: str) -> None: + """Delete a session by key.""" + await async_(self._delete)(key) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired.""" + return await async_(self._exists)(key) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires.""" + return await async_(self._expires_in)(key) + + async def delete_expired(self) -> int: + """Delete all expired sessions.""" + return await async_(self._delete_expired)() + + def _create_table(self) -> None: + with self._config.provide_session() as driver: + driver.execute_script(self._get_create_table_sql()) + driver.commit() + self._log_table_created() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + with self._config.provide_session() as driver: + row = driver.select_one_or_none( + f"SELECT data, expires_at FROM {self._table_name} WHERE session_id = ?", (key,) + ) + if row is None: + return None + + expires_at = _normalize_utc(_row_value(row, "expires_at", 1)) + if expires_at is not None and expires_at < datetime.now(timezone.utc): + self._delete(key) + return None + + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + driver.execute( + f""" + UPDATE {self._table_name} + SET expires_at = ?, updated_at = SYSUTCDATETIME() + WHERE session_id = ? + """, + (_format_datetime(new_expires_at), key), + ) + driver.commit() + + data = _row_value(row, "data", 0) + if data is None: + data = self._get_chunked_data(driver, key) + return _decode_bytes(data) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + data = _encode_bytes(self._value_to_bytes(value)) + expires_at = self._calculate_expires_at(expires_in) + with self._config.provide_session() as driver: + inline_data = data if len(data) <= _MAX_INLINE_DATA_LENGTH else None + existing = driver.select_one_or_none( + f"SELECT session_id FROM {self._table_name} WHERE session_id = ?", (key,) + ) + if existing is None: + driver.execute( + f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES (?, ?, ?) + """, + (key, inline_data, _format_datetime(expires_at)), + ) + else: + driver.execute( + f""" + UPDATE {self._table_name} + SET data = ?, expires_at = ?, updated_at = SYSUTCDATETIME() + WHERE session_id = ? + """, + (inline_data, _format_datetime(expires_at), key), + ) + driver.execute(f"DELETE FROM {self._chunk_table_name} WHERE session_id = ?", (key,)) + if inline_data is None: + for index, chunk in enumerate(_chunk_text(data, _MAX_INLINE_DATA_LENGTH)): + driver.execute( + f""" + INSERT INTO {self._chunk_table_name} (session_id, chunk_index, data) + VALUES (?, ?, ?) + """, + (key, index, chunk), + ) + driver.commit() + + def _delete(self, key: str) -> None: + with self._config.provide_session() as driver: + driver.execute(f"DELETE FROM {self._table_name} WHERE session_id = ?", (key,)) + driver.commit() + + def _delete_all(self) -> None: + with self._config.provide_session() as driver: + driver.execute(f"DELETE FROM {self._table_name}") + driver.commit() + self._log_delete_all() + + def _exists(self, key: str) -> bool: + with self._config.provide_session() as driver: + row = driver.select_one_or_none( + f""" + SELECT 1 AS exists_flag + FROM {self._table_name} + WHERE session_id = ? + AND (expires_at IS NULL OR expires_at > SYSUTCDATETIME()) + """, + (key,), + ) + return row is not None + + def _expires_in(self, key: str) -> "int | None": + with self._config.provide_session() as driver: + row = driver.select_one_or_none(f"SELECT expires_at FROM {self._table_name} WHERE session_id = ?", (key,)) + if row is None: + return None + expires_at = _normalize_utc(_row_value(row, "expires_at", 0)) + if expires_at is None: + return None + remaining = expires_at - datetime.now(timezone.utc) + return max(0, int(remaining.total_seconds())) + + def _delete_expired(self) -> int: + with self._config.provide_session() as driver: + count = int( + driver.select_value( + f""" + SELECT COUNT(*) AS expired_count + FROM {self._table_name} + WHERE expires_at IS NOT NULL + AND expires_at < SYSUTCDATETIME() + """ + ) + or 0 + ) + driver.execute( + f""" + DELETE FROM {self._table_name} + WHERE expires_at IS NOT NULL + AND expires_at < SYSUTCDATETIME() + """ + ) + driver.commit() + if count > 0: + self._log_delete_expired(count) + return count + + @property + def _chunk_table_name(self) -> str: + return f"{self._table_name}_chunks" + + def _get_chunked_data(self, driver: Any, key: str) -> str: + rows = driver.select( + f""" + SELECT data + FROM {self._chunk_table_name} + WHERE session_id = ? + ORDER BY chunk_index + """, + (key,), + ) + return "".join(str(_row_value(row, "data", 0) or "") for row in rows) + + def _get_create_table_sql(self) -> str: + """Get SQL Server CREATE TABLE SQL with idempotent guards.""" + return f""" + IF NOT EXISTS ( + SELECT 1 + FROM sys.tables + WHERE name = N'{self._table_name}' + AND schema_id = SCHEMA_ID(N'dbo') + ) + BEGIN + CREATE TABLE {self._table_name} ( + session_id NVARCHAR(255) PRIMARY KEY, + data NVARCHAR(MAX) NULL, + expires_at DATETIME2(6) NULL, + created_at DATETIME2(6) NOT NULL DEFAULT SYSUTCDATETIME(), + updated_at DATETIME2(6) NOT NULL DEFAULT SYSUTCDATETIME() + ); + + CREATE INDEX IX_{self._table_name}_expires_at + ON {self._table_name}(expires_at) + WHERE expires_at IS NOT NULL; + END; + + IF NOT EXISTS ( + SELECT 1 + FROM sys.tables + WHERE name = N'{self._chunk_table_name}' + AND schema_id = SCHEMA_ID(N'dbo') + ) + BEGIN + CREATE TABLE {self._chunk_table_name} ( + session_id NVARCHAR(255) NOT NULL, + chunk_index INT NOT NULL, + data NVARCHAR(3500) NOT NULL, + CONSTRAINT PK_{self._chunk_table_name} PRIMARY KEY (session_id, chunk_index), + CONSTRAINT FK_{self._chunk_table_name}_session FOREIGN KEY (session_id) + REFERENCES {self._table_name}(session_id) ON DELETE CASCADE + ); + END; + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get SQL Server DROP TABLE statements.""" + return [ + f"IF OBJECT_ID(N'dbo.{self._chunk_table_name}', N'U') IS NOT NULL DROP TABLE dbo.{self._chunk_table_name};", + f"IF OBJECT_ID(N'dbo.{self._table_name}', N'U') IS NOT NULL DROP TABLE dbo.{self._table_name};", + ] + + +def _row_value(row: object, key: str, index: int) -> Any: + """Return a value from dict-like or sequence-like driver rows.""" + if isinstance(row, dict): + if key in row: + return row[key] + upper_key = key.upper() + if upper_key in row: + return row[upper_key] + return None + if isinstance(row, (list, tuple)) and len(row) > index: + return row[index] + return getattr(row, key, None) + + +def _normalize_utc(value: Any) -> "datetime | None": + if value is None: + return None + if not isinstance(value, datetime): + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _format_datetime(value: "datetime | None") -> "str | None": + if value is None: + return None + normalized = _normalize_utc(value) + if normalized is None: + return None + return normalized.replace(tzinfo=None).isoformat(timespec="microseconds") + + +def _encode_bytes(value: bytes) -> str: + return base64.b64encode(value).decode("ascii") + + +def _decode_bytes(value: Any) -> bytes: + if value is None: + return b"" + if isinstance(value, bytes): + value = value.decode("ascii") + if isinstance(value, bytearray): + value = bytes(value).decode("ascii") + if not isinstance(value, str): + value = str(value) + return base64.b64decode(value.encode("ascii")) + + +def _chunk_text(value: str, size: int) -> "list[str]": + return [value[index : index + size] for index in range(0, len(value), size)] diff --git a/sqlspec/extensions/events/_queue.py b/sqlspec/extensions/events/_queue.py index 863340b87..bd9a7cb23 100644 --- a/sqlspec/extensions/events/_queue.py +++ b/sqlspec/extensions/events/_queue.py @@ -90,9 +90,10 @@ def _build_insert_sql(self) -> str: return f"INSERT INTO {self._table_name} ({columns}) VALUES ({values})" def _build_select_sql(self, select_for_update: bool, skip_locked: bool) -> str: - limit_clause = " FETCH FIRST 1 ROWS ONLY" if "oracle" in self._dialect else " LIMIT 1" + top_clause = "TOP 1 " if self._uses_tsql_limit() else "" + limit_clause = self._row_limit_clause() base = ( - f"SELECT event_id, channel, payload_json, metadata_json, attempts, available_at, lease_expires_at, created_at " + f"SELECT {top_clause}event_id, channel, payload_json, metadata_json, attempts, available_at, lease_expires_at, created_at " f"FROM {self._table_name} " "WHERE channel = :channel AND available_at <= :available_cutoff AND (" "status = :pending_status OR (status = :leased_status AND (lease_expires_at IS NULL OR lease_expires_at <= :lease_cutoff))" @@ -106,12 +107,23 @@ def _build_select_sql(self, select_for_update: bool, skip_locked: bool) -> str: return base + limit_clause + locking_clause def _build_select_by_id_sql(self) -> str: - limit_clause = " FETCH FIRST 1 ROWS ONLY" if "oracle" in self._dialect else " LIMIT 1" + top_clause = "TOP 1 " if self._uses_tsql_limit() else "" + limit_clause = self._row_limit_clause() return ( - f"SELECT event_id, channel, payload_json, metadata_json, attempts, available_at, lease_expires_at, created_at " + f"SELECT {top_clause}event_id, channel, payload_json, metadata_json, attempts, available_at, lease_expires_at, created_at " f"FROM {self._table_name} WHERE event_id = :event_id" + limit_clause ) + def _uses_tsql_limit(self) -> bool: + return self._dialect in {"mssql", "tsql"} or "sql server" in self._dialect + + def _row_limit_clause(self) -> str: + if self._uses_tsql_limit(): + return "" + if "oracle" in self._dialect: + return " FETCH FIRST 1 ROWS ONLY" + return " LIMIT 1" + def _build_claim_sql(self) -> str: return ( f"UPDATE {self._table_name} SET status = :claimed_status, lease_expires_at = :lease_expires_at, attempts = attempts + 1 " diff --git a/tests/conftest.py b/tests/conftest.py index 71a7bb032..90b78b09a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ "pytest_databases.docker.postgres", "pytest_databases.docker.oracle", "pytest_databases.docker.mysql", + "pytest_databases.docker.mssql", "pytest_databases.docker.bigquery", "pytest_databases.docker.spanner", "pytest_databases.docker.gizmosql", diff --git a/tests/integration/adapters/contracts/README.md b/tests/integration/adapters/contracts/README.md index 19f1c99f5..95864384c 100644 --- a/tests/integration/adapters/contracts/README.md +++ b/tests/integration/adapters/contracts/README.md @@ -63,7 +63,7 @@ Every contract is the cross product of three things: | `test_statement_inputs_contract.py` | Statement input variants (raw SQL, `SQL(...)` objects, filtered statements, loader input). | | `test_result_contract.py` | `SQLResult` API (`get_first`/`get_count`/`is_empty`/`one_or_none`). | | `test_query_contract.py` | Filters (`InCollection`/`LimitOffset`/`OrderBy`/`Search`) and complex queries (joins/subqueries/aggregates/CTEs). | -| `test_execute_many_contract.py` | `execute_many` mutation/input variants and per-adapter specifics. | +| `test_execute_many_contract.py` | `execute_many` mutation/input variants and per-adapter specifics; skipped for active bulk-only cases. | | `test_explain_contract.py` | `EXPLAIN` plans (gated by `supports_explain`). | | `test_arrow_contract.py` | Arrow result export (gated by `supports_arrow`). | | `test_script_error_contract.py` | `execute_script` and script error handling. | @@ -84,7 +84,8 @@ exactly as it did before the flag existed. Complete flag set: -- **Result / IO**: `supports_arrow`, `supports_explain`, `supports_storage_bridge`, `supports_copy` +- **Result / IO**: `supports_arrow`, `supports_arrow_streaming`, `supports_native_arrow`, + `supports_storage_bridge`, `supports_native_bulk_ingest`, `supports_copy` - **Statements**: `supports_execute_many`, `supports_execute_script`, `supports_filtered_statement`, `supports_loader_input`, `supports_merge`, `supports_returning`, `supports_for_update` - **Types / codecs**: `supports_json`, `supports_json_native`, `supports_arrays`, @@ -190,7 +191,7 @@ deliberately leaves **irreducible** per-adapter tests in place. Do not fold thes pool sizing, read-only/PRAGMA settings) — the shared "data persists across pooled sessions" guarantee is contracted, but adapter-internal assertions stay local. - Extension/vector detection deferred to its own chapter (e.g. pgvector/paradedb "not enabled"). -- Deferred adapters (spanner, arrow_odbc, mssql_python) until their cases move from +- Deferred adapters (spanner, mssql_python) until their cases move from `DEFERRED_DRIVER_CASES` to active rows. ## Adding A Case diff --git a/tests/integration/adapters/contracts/_adk_cases.py b/tests/integration/adapters/contracts/_adk_cases.py index 6528f97bf..b2b114bb0 100644 --- a/tests/integration/adapters/contracts/_adk_cases.py +++ b/tests/integration/adapters/contracts/_adk_cases.py @@ -7,8 +7,11 @@ from tests.integration.adapters.contracts._cases import ( ADBC_MARK, + ARROW_ODBC_MARK, COCKROACH_XDIST_MARK, DUCKDB_XDIST_MARK, + MSSQL_MARK, + MSSQL_XDIST_MARK, MYSQL_XDIST_MARK, ORACLE_XDIST_MARK, POSTGRES_XDIST_MARK, @@ -95,6 +98,12 @@ class AdkStoreCaseContext: AdkStoreCase( "adbc-postgres", "adk_store_adbc_postgres", "adbc", marks=(ADBC_MARK, POSTGRES_XDIST_MARK, pytest.mark.anyio) ), + AdkStoreCase( + "arrow-odbc", + "adk_store_arrow_odbc_mssql", + "arrow_odbc", + marks=(MSSQL_MARK, MSSQL_XDIST_MARK, ARROW_ODBC_MARK, pytest.mark.anyio), + ), ) ADK_STORE_PARAMS = tuple(pytest.param(case, id=case.id, marks=case.marks) for case in ADK_STORE_CASES) diff --git a/tests/integration/adapters/contracts/_cases.py b/tests/integration/adapters/contracts/_cases.py index 97d553842..ea3d420cc 100644 --- a/tests/integration/adapters/contracts/_cases.py +++ b/tests/integration/adapters/contracts/_cases.py @@ -11,6 +11,7 @@ from tests.integration.adapters.contracts._schema import ( DEFAULT_CONTRACT_TABLE, DUCKDB_CONTRACT_TABLE, + MSSQL_CONTRACT_TABLE, MYSQL_CONTRACT_TABLE, ORACLE_CONTRACT_TABLE, POSTGRES_CONTRACT_TABLE, @@ -88,9 +89,12 @@ class DriverCaseContext: SQLITE_XDIST_MARK = pytest.mark.xdist_group("sqlite") DUCKDB_XDIST_MARK = pytest.mark.xdist_group("duckdb") MYSQL_XDIST_MARK = pytest.mark.xdist_group("mysql") +MSSQL_XDIST_MARK = pytest.mark.xdist_group("mssql") POSTGRES_XDIST_MARK = pytest.mark.xdist_group("postgres") COCKROACH_XDIST_MARK = pytest.mark.xdist_group("cockroachdb") ADBC_MARK = pytest.mark.adbc +ARROW_ODBC_MARK = pytest.mark.arrow_odbc +MSSQL_MARK = pytest.mark.mssql ORACLE_XDIST_MARK = pytest.mark.xdist_group("oracle") BIGQUERY_MARK = pytest.mark.bigquery BIGQUERY_XDIST_MARK = pytest.mark.xdist_group("bigquery") @@ -338,6 +342,23 @@ class DriverCaseContext: "streaming_native:oracledb", ), ), + DriverCase( + id="arrow-odbc-sync", + fixture_name="contract_arrow_odbc_mssql_driver", + adapter="arrow_odbc", + dialect="mssql", + mode="sync", + marks=(MSSQL_MARK, MSSQL_XDIST_MARK, ARROW_ODBC_MARK), + table=MSSQL_CONTRACT_TABLE, + supports_arrow=True, + supports_arrow_streaming=True, + supports_native_arrow=True, + supports_native_bulk_ingest=True, + supports_execute_many=False, + supports_load_from_records=False, + supports_exception_translation=False, + deviations=("execute-rows-affected-unavailable",), + ), DriverCase( id="bigquery-sync", fixture_name="contract_bigquery_driver", @@ -644,15 +665,6 @@ class DriverCaseContext: ) DEFERRED_DRIVER_CASES = ( - DriverCase( - "arrow-odbc-sync", - "", - "arrow_odbc", - "odbc", - "sync", - integration_status="deferred", - reason="No active integration fixture exists for arrow_odbc.", - ), DriverCase( "mssql-python-sync", "", diff --git a/tests/integration/adapters/contracts/_events_cases.py b/tests/integration/adapters/contracts/_events_cases.py index 732430416..1b0247066 100644 --- a/tests/integration/adapters/contracts/_events_cases.py +++ b/tests/integration/adapters/contracts/_events_cases.py @@ -7,7 +7,10 @@ from _pytest.mark.structures import Mark, MarkDecorator from tests.integration.adapters.contracts._cases import ( + ARROW_ODBC_MARK, DUCKDB_XDIST_MARK, + MSSQL_MARK, + MSSQL_XDIST_MARK, MYSQL_XDIST_MARK, POSTGRES_XDIST_MARK, SQLITE_XDIST_MARK, @@ -37,6 +40,13 @@ class EventsCaseContext: SYNC_EVENTS_CASES = ( EventsCase("sqlite-sync", "events_config_sqlite", "sqlite", "sync", marks=(SQLITE_XDIST_MARK,)), EventsCase("duckdb-sync", "events_config_duckdb", "duckdb", "sync", marks=(DUCKDB_XDIST_MARK,)), + EventsCase( + "arrow-odbc-sync", + "events_config_arrow_odbc_mssql", + "arrow_odbc", + "sync", + marks=(MSSQL_MARK, MSSQL_XDIST_MARK, ARROW_ODBC_MARK), + ), EventsCase("pymysql-sync", "events_config_pymysql", "pymysql", "sync", marks=(MYSQL_XDIST_MARK,)), EventsCase( "psycopg-sync", diff --git a/tests/integration/adapters/contracts/_schema.py b/tests/integration/adapters/contracts/_schema.py index f7fa1e17c..9bee4385e 100644 --- a/tests/integration/adapters/contracts/_schema.py +++ b/tests/integration/adapters/contracts/_schema.py @@ -87,6 +87,27 @@ class ContractTable: ) +MSSQL_CONTRACT_TABLE = ContractTable( + name="contract_items", + create_sql=""" + CREATE TABLE contract_items ( + id INT IDENTITY(1,1) PRIMARY KEY, + name NVARCHAR(255) NOT NULL, + value INT NOT NULL, + note NVARCHAR(4000) + ) + """, + pooling_create_sql="CREATE TABLE {table} (id INT PRIMARY KEY, value NVARCHAR(50))", + delete_sql="DELETE FROM contract_items", + insert_named_sql="INSERT INTO contract_items (name, value, note) VALUES (:name, :value, :note)", + insert_qmark_sql="INSERT INTO contract_items (name, value, note) VALUES (?, ?, ?)", + select_by_name_named_sql="SELECT name, value, note FROM contract_items WHERE name = :name", + select_by_name_qmark_sql="SELECT name, value, note FROM contract_items WHERE name = ?", + select_count_sql="SELECT COUNT(*) AS count FROM contract_items", + select_ordered_sql="SELECT name, value, note FROM contract_items ORDER BY value", +) + + def build_bigquery_contract_table(table_name: str) -> ContractTable: """Build a BigQuery ContractTable for a fully-qualified table identifier. diff --git a/tests/integration/adapters/contracts/_store_cases.py b/tests/integration/adapters/contracts/_store_cases.py index 3483100e2..92e9ad593 100644 --- a/tests/integration/adapters/contracts/_store_cases.py +++ b/tests/integration/adapters/contracts/_store_cases.py @@ -6,7 +6,10 @@ from _pytest.mark.structures import Mark, MarkDecorator from tests.integration.adapters.contracts._cases import ( + ARROW_ODBC_MARK, DUCKDB_XDIST_MARK, + MSSQL_MARK, + MSSQL_XDIST_MARK, MYSQL_XDIST_MARK, POSTGRES_XDIST_MARK, SQLITE_XDIST_MARK, @@ -35,6 +38,12 @@ class StoreCaseContext: StoreCase("sqlite", "contract_sqlite_store", "sqlite", marks=(SQLITE_XDIST_MARK, pytest.mark.anyio)), StoreCase("aiosqlite", "contract_aiosqlite_store", "aiosqlite", marks=(SQLITE_XDIST_MARK, pytest.mark.anyio)), StoreCase("duckdb", "contract_duckdb_store", "duckdb", marks=(DUCKDB_XDIST_MARK, pytest.mark.anyio)), + StoreCase( + "arrow-odbc", + "contract_arrow_odbc_store", + "arrow_odbc", + marks=(MSSQL_MARK, MSSQL_XDIST_MARK, ARROW_ODBC_MARK, pytest.mark.anyio), + ), StoreCase("asyncpg", "contract_asyncpg_store", "asyncpg", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), StoreCase("psqlpy", "contract_psqlpy_store", "psqlpy", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), StoreCase( diff --git a/tests/integration/adapters/contracts/behaviors.py b/tests/integration/adapters/contracts/behaviors.py index 72b274b74..5d68500f6 100644 --- a/tests/integration/adapters/contracts/behaviors.py +++ b/tests/integration/adapters/contracts/behaviors.py @@ -230,18 +230,50 @@ def _should_assert_execute_rows_affected(case: DriverCase) -> bool: return "execute-rows-affected-unavailable" not in case.deviations +def _sqlglot_dialect(case: DriverCase) -> str: + return "tsql" if case.dialect == "mssql" else case.dialect + + +def _contract_rows_to_arrow(rows: tuple[ContractRow, ...]) -> Any: + import pyarrow as pa + + return pa.table({ + "name": pa.array([row.name for row in rows], type=pa.string()), + "value": pa.array([row.value for row in rows], type=pa.int64()), + "note": pa.array([None if row.note is None else str(row.note) for row in rows], type=pa.string()), + }) + + def _seed_sync( - driver: SyncContractDriver, rows: tuple[ContractRow, ...], table: ContractTable = DEFAULT_CONTRACT_TABLE + driver: SyncContractDriver, + rows: tuple[ContractRow, ...], + table: ContractTable = DEFAULT_CONTRACT_TABLE, + case: DriverCase | None = None, ) -> None: if rows: + if case is not None and not case.supports_execute_many: + if not case.supports_native_bulk_ingest: + pytest.skip(f"{case.adapter} cannot seed contract rows without execute_many or native bulk ingest") + driver.load_from_arrow(table.name, _contract_rows_to_arrow(rows), overwrite=True) + driver.commit() + return driver.execute_many(table.insert_qmark_sql, _row_parameters(rows)) driver.commit() async def _seed_async( - driver: AsyncContractDriver, rows: tuple[ContractRow, ...], table: ContractTable = DEFAULT_CONTRACT_TABLE + driver: AsyncContractDriver, + rows: tuple[ContractRow, ...], + table: ContractTable = DEFAULT_CONTRACT_TABLE, + case: DriverCase | None = None, ) -> None: if rows: + if case is not None and not case.supports_execute_many: + if not case.supports_native_bulk_ingest: + pytest.skip(f"{case.adapter} cannot seed contract rows without execute_many or native bulk ingest") + await driver.load_from_arrow(table.name, _contract_rows_to_arrow(rows), overwrite=True) + await driver.commit() + return await driver.execute_many(table.insert_qmark_sql, _row_parameters(rows)) await driver.commit() @@ -347,7 +379,7 @@ def assert_sync_streaming_unsupported_contract(driver: object, case: DriverCase) table = case.table sync_driver.execute(table.delete_sql) sync_driver.commit() - _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2)), table) + _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2)), table, case) with pytest.raises(ImproperConfigurationError): sync_driver.select_stream(table.select_ordered_sql, native_only=True) @@ -365,7 +397,7 @@ async def assert_async_streaming_unsupported_contract(driver: object, case: Driv table = case.table await async_driver.execute(table.delete_sql) await async_driver.commit() - await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2)), table) + await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2)), table, case) with pytest.raises(ImproperConfigurationError): async_driver.select_stream(table.select_ordered_sql, native_only=True) @@ -688,7 +720,7 @@ def assert_sync_filter_contract(driver: object, case: DriverCase) -> None: """Assert sync drivers apply OrderBy/LimitOffset, InCollection, and Search filters.""" sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, _FILTER_SEED_ROWS, table) + _seed_sync(sync_driver, _FILTER_SEED_ROWS, table, case) base = f"SELECT name, value FROM {table.name}" paged = sync_driver.execute(base, OrderByFilter("value", "desc"), LimitOffsetFilter(limit=2, offset=1)) @@ -706,7 +738,7 @@ async def assert_async_filter_contract(driver: object, case: DriverCase) -> None """Assert async drivers apply OrderBy/LimitOffset, InCollection, and Search filters.""" async_driver = cast("AsyncContractDriver", driver) table = case.table - await _seed_async(async_driver, _FILTER_SEED_ROWS, table) + await _seed_async(async_driver, _FILTER_SEED_ROWS, table, case) base = f"SELECT name, value FROM {table.name}" paged = await async_driver.execute(base, OrderByFilter("value", "desc"), LimitOffsetFilter(limit=2, offset=1)) @@ -728,7 +760,7 @@ def assert_sync_complex_query_contract(driver: object, case: DriverCase) -> None pytest.skip(f"{case.adapter} emulator does not support grouped aggregation / correlated subqueries") sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, _GROUPED_SEED_ROWS, table) + _seed_sync(sync_driver, _GROUPED_SEED_ROWS, table, case) grouped = sync_driver.execute( f"SELECT value, COUNT(*) AS count FROM {table.name} GROUP BY value HAVING COUNT(*) >= 2 ORDER BY value" @@ -747,7 +779,7 @@ async def assert_async_complex_query_contract(driver: object, case: DriverCase) pytest.skip(f"{case.adapter} emulator does not support grouped aggregation / correlated subqueries") async_driver = cast("AsyncContractDriver", driver) table = case.table - await _seed_async(async_driver, _GROUPED_SEED_ROWS, table) + await _seed_async(async_driver, _GROUPED_SEED_ROWS, table, case) grouped = await async_driver.execute( f"SELECT value, COUNT(*) AS count FROM {table.name} GROUP BY value HAVING COUNT(*) >= 2 ORDER BY value" @@ -820,6 +852,8 @@ async def assert_async_statement_stack_contract(driver: object, case: DriverCase def assert_sync_execute_many_contract(driver: object, case: DriverCase) -> None: """Assert sync execute-many behavior for a driver case.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") sync_driver = cast("SyncContractDriver", driver) table = case.table @@ -840,6 +874,8 @@ def assert_sync_execute_many_contract(driver: object, case: DriverCase) -> None: async def assert_async_execute_many_contract(driver: object, case: DriverCase) -> None: """Assert async execute-many behavior for a driver case.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") async_driver = cast("AsyncContractDriver", driver) table = case.table @@ -860,6 +896,8 @@ async def assert_async_execute_many_contract(driver: object, case: DriverCase) - def assert_sync_execute_many_mutation_contract(driver: object, case: DriverCase) -> None: """Assert sync drivers batch insert, update, and delete with accurate row counts.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") sync_driver = cast("SyncContractDriver", driver) table = case.table @@ -880,6 +918,8 @@ def assert_sync_execute_many_mutation_contract(driver: object, case: DriverCase) async def assert_async_execute_many_mutation_contract(driver: object, case: DriverCase) -> None: """Assert async drivers batch insert, update, and delete with accurate row counts.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") async_driver = cast("AsyncContractDriver", driver) table = case.table @@ -900,6 +940,8 @@ async def assert_async_execute_many_mutation_contract(driver: object, case: Driv def assert_sync_execute_many_input_contract(driver: object, case: DriverCase) -> None: """Assert sync drivers batch a large sequence and an is_many SQL object.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") sync_driver = cast("SyncContractDriver", driver) table = case.table @@ -918,6 +960,8 @@ def assert_sync_execute_many_input_contract(driver: object, case: DriverCase) -> async def assert_async_execute_many_input_contract(driver: object, case: DriverCase) -> None: """Assert async drivers batch a large sequence and an is_many SQL object.""" + if not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") async_driver = cast("AsyncContractDriver", driver) table = case.table @@ -3092,9 +3136,9 @@ def assert_sync_custom_type_adapters_contract(make_config: SyncConfigFactory, ca def assert_sync_statement_input_contract(driver: object, case: DriverCase, input_case: StatementInputCase) -> None: """Assert sync drivers return equivalent rows for one statement input shape.""" sync_driver = cast("SyncContractDriver", driver) - _seed_sync(sync_driver, input_case.setup_rows, case.table) + _seed_sync(sync_driver, input_case.setup_rows, case.table, case) - statement = input_case.statement_factory(case.table.name, case.dialect) + statement = input_case.statement_factory(case.table.name, _sqlglot_dialect(case)) result = _execute_sync(sync_driver, statement, input_case.parameters) assert_result_data(result, input_case.expected_data) @@ -3107,9 +3151,9 @@ async def assert_async_statement_input_contract( ) -> None: """Assert async drivers return equivalent rows for one statement input shape.""" async_driver = cast("AsyncContractDriver", driver) - await _seed_async(async_driver, input_case.setup_rows, case.table) + await _seed_async(async_driver, input_case.setup_rows, case.table, case) - statement = input_case.statement_factory(case.table.name, case.dialect) + statement = input_case.statement_factory(case.table.name, _sqlglot_dialect(case)) result = await _execute_async(async_driver, statement, input_case.parameters) assert_result_data(result, input_case.expected_data) @@ -3120,7 +3164,7 @@ async def assert_async_statement_input_contract( def assert_sync_parameter_contract(driver: object, case: DriverCase, parameter_case: ParameterProfileCase) -> None: """Assert sync drivers bind one parameter profile case correctly.""" sync_driver = cast("SyncContractDriver", driver) - _seed_sync(sync_driver, parameter_case.setup_rows, case.table) + _seed_sync(sync_driver, parameter_case.setup_rows, case.table, case) result = _execute_sync(sync_driver, _with_table(parameter_case.statement, case.table), parameter_case.parameters) if parameter_case.expected_rows_affected is not None and _should_assert_execute_rows_affected(case): @@ -3141,7 +3185,7 @@ async def assert_async_parameter_contract( ) -> None: """Assert async drivers bind one parameter profile case correctly.""" async_driver = cast("AsyncContractDriver", driver) - await _seed_async(async_driver, parameter_case.setup_rows, case.table) + await _seed_async(async_driver, parameter_case.setup_rows, case.table, case) result = await _execute_async( async_driver, _with_table(parameter_case.statement, case.table), parameter_case.parameters @@ -3163,8 +3207,10 @@ def assert_sync_parameter_style_contract( driver: object, case: DriverCase, parameter_style_case: ParameterStyleCase ) -> None: """Assert sync drivers bind one parameter style case correctly.""" + if parameter_style_case.method == "execute_many" and not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") sync_driver = cast("SyncContractDriver", driver) - _seed_sync(sync_driver, parameter_style_case.setup_rows, case.table) + _seed_sync(sync_driver, parameter_style_case.setup_rows, case.table, case) style_statement = _with_table(parameter_style_case.statement, case.table) if parameter_style_case.method == "execute_many": @@ -3194,8 +3240,10 @@ async def assert_async_parameter_style_contract( driver: object, case: DriverCase, parameter_style_case: ParameterStyleCase ) -> None: """Assert async drivers bind one parameter style case correctly.""" + if parameter_style_case.method == "execute_many" and not case.supports_execute_many: + pytest.skip(f"{case.adapter} has no verified execute_many support") async_driver = cast("AsyncContractDriver", driver) - await _seed_async(async_driver, parameter_style_case.setup_rows, case.table) + await _seed_async(async_driver, parameter_style_case.setup_rows, case.table, case) style_statement = _with_table(parameter_style_case.statement, case.table) if parameter_style_case.method == "execute_many": @@ -3225,7 +3273,9 @@ def assert_sync_result_contract(driver: object, case: DriverCase) -> None: """Assert sync drivers expose common result helper behavior.""" sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, (ContractRow("result1", 10), ContractRow("result2", 20), ContractRow("result3", 30)), table) + _seed_sync( + sync_driver, (ContractRow("result1", 10), ContractRow("result2", 20), ContractRow("result3", 30)), table, case + ) result = assert_sql_result(sync_driver.execute(table.select_ordered_sql)) assert result.get_first() == {"name": "result1", "value": 10, "note": None} @@ -3249,7 +3299,7 @@ async def assert_async_result_contract(driver: object, case: DriverCase) -> None async_driver = cast("AsyncContractDriver", driver) table = case.table await _seed_async( - async_driver, (ContractRow("result1", 10), ContractRow("result2", 20), ContractRow("result3", 30)), table + async_driver, (ContractRow("result1", 10), ContractRow("result2", 20), ContractRow("result3", 30)), table, case ) result = assert_sql_result(await async_driver.execute(table.select_ordered_sql)) @@ -3328,7 +3378,7 @@ def assert_sync_explain_contract(driver: object, case: DriverCase, explain_case: if not case.supports_explain: pytest.skip(_explain_skip_reason(case)) sync_driver = cast("SyncContractDriver", driver) - result = assert_sql_result(sync_driver.execute(explain_case.build(case.table, case.dialect))) + result = assert_sql_result(sync_driver.execute(explain_case.build(case.table, _sqlglot_dialect(case)))) assert result.data is not None @@ -3337,7 +3387,7 @@ async def assert_async_explain_contract(driver: object, case: DriverCase, explai if not case.supports_explain: pytest.skip(_explain_skip_reason(case)) async_driver = cast("AsyncContractDriver", driver) - result = assert_sql_result(await async_driver.execute(explain_case.build(case.table, case.dialect))) + result = assert_sql_result(await async_driver.execute(explain_case.build(case.table, _sqlglot_dialect(case)))) assert result.data is not None @@ -3460,7 +3510,7 @@ def assert_sync_arrow_contract(driver: object, case: DriverCase) -> None: sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2), ContractRow("c", 3)), table) + _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2), ContractRow("c", 3)), table, case) table_result = sync_driver.select_to_arrow(table.select_ordered_sql) assert isinstance(table_result.data, pa.Table) @@ -3490,7 +3540,7 @@ async def assert_async_arrow_contract(driver: object, case: DriverCase) -> None: async_driver = cast("AsyncContractDriver", driver) table = case.table - await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2), ContractRow("c", 3)), table) + await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2), ContractRow("c", 3)), table, case) table_result = await async_driver.select_to_arrow(table.select_ordered_sql) assert isinstance(table_result.data, pa.Table) @@ -3606,14 +3656,14 @@ def assert_sync_arrow_extras_contract(driver: object, case: DriverCase) -> None: sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, (ContractRow("a", 1, None), ContractRow("b", 2, "noted")), table) + _seed_sync(sync_driver, (ContractRow("a", 1, None), ContractRow("b", 2, "noted")), table, case) null_result = sync_driver.select_to_arrow(table.select_ordered_sql) assert isinstance(null_result.data, pa.Table) assert null_result.data.column("note").to_pylist() == [None, "noted"] sync_driver.execute(table.delete_sql) sync_driver.commit() - _seed_sync(sync_driver, tuple(ContractRow(f"n{i}", i) for i in range(1, _ARROW_LARGE_ROW_COUNT + 1)), table) + _seed_sync(sync_driver, tuple(ContractRow(f"n{i}", i) for i in range(1, _ARROW_LARGE_ROW_COUNT + 1)), table, case) large_result = sync_driver.select_to_arrow(table.select_ordered_sql) assert large_result.rows_affected == _ARROW_LARGE_ROW_COUNT assert sum(large_result.data.column("value").to_pylist()) == sum(range(1, _ARROW_LARGE_ROW_COUNT + 1)) @@ -3628,14 +3678,16 @@ async def assert_async_arrow_extras_contract(driver: object, case: DriverCase) - async_driver = cast("AsyncContractDriver", driver) table = case.table - await _seed_async(async_driver, (ContractRow("a", 1, None), ContractRow("b", 2, "noted")), table) + await _seed_async(async_driver, (ContractRow("a", 1, None), ContractRow("b", 2, "noted")), table, case) null_result = await async_driver.select_to_arrow(table.select_ordered_sql) assert isinstance(null_result.data, pa.Table) assert null_result.data.column("note").to_pylist() == [None, "noted"] await async_driver.execute(table.delete_sql) await async_driver.commit() - await _seed_async(async_driver, tuple(ContractRow(f"n{i}", i) for i in range(1, _ARROW_LARGE_ROW_COUNT + 1)), table) + await _seed_async( + async_driver, tuple(ContractRow(f"n{i}", i) for i in range(1, _ARROW_LARGE_ROW_COUNT + 1)), table, case + ) large_result = await async_driver.select_to_arrow(table.select_ordered_sql) assert large_result.rows_affected == _ARROW_LARGE_ROW_COUNT assert sum(large_result.data.column("value").to_pylist()) == sum(range(1, _ARROW_LARGE_ROW_COUNT + 1)) @@ -3649,7 +3701,7 @@ def assert_sync_arrow_polars_contract(driver: object, case: DriverCase) -> None: sync_driver = cast("SyncContractDriver", driver) table = case.table - _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2)), table) + _seed_sync(sync_driver, (ContractRow("a", 1), ContractRow("b", 2)), table, case) frame = sync_driver.select_to_arrow(table.select_ordered_sql).to_polars() assert len(frame) == 2 @@ -3664,7 +3716,7 @@ async def assert_async_arrow_polars_contract(driver: object, case: DriverCase) - async_driver = cast("AsyncContractDriver", driver) table = case.table - await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2)), table) + await _seed_async(async_driver, (ContractRow("a", 1), ContractRow("b", 2)), table, case) frame = (await async_driver.select_to_arrow(table.select_ordered_sql)).to_polars() assert len(frame) == 2 @@ -4184,7 +4236,7 @@ def assert_sync_storage_bridge_rustfs_contract( storage_registry.clear() try: _register_rustfs_alias(alias, rustfs_service, rustfs_bucket_name, prefix=prefix) - _seed_sync(sync_driver, (ContractRow("alpha", 1, "first"), ContractRow("beta", 2, "second")), table) + _seed_sync(sync_driver, (ContractRow("alpha", 1, "first"), ContractRow("beta", 2, "second")), table, case) export_job = sync_driver.select_to_storage( _storage_bridge_export_sql(table), destination, 1, format_hint="parquet" @@ -4216,7 +4268,9 @@ async def assert_async_storage_bridge_rustfs_contract( storage_registry.clear() try: _register_rustfs_alias(alias, rustfs_service, rustfs_bucket_name, prefix=prefix) - await _seed_async(async_driver, (ContractRow("alpha", 1, "first"), ContractRow("beta", 2, "second")), table) + await _seed_async( + async_driver, (ContractRow("alpha", 1, "first"), ContractRow("beta", 2, "second")), table, case + ) export_job = await async_driver.select_to_storage( _storage_bridge_export_sql(table), destination, 1, format_hint="parquet" diff --git a/tests/integration/adapters/contracts/conftest.py b/tests/integration/adapters/contracts/conftest.py index a5df51927..79d7866b9 100644 --- a/tests/integration/adapters/contracts/conftest.py +++ b/tests/integration/adapters/contracts/conftest.py @@ -12,6 +12,7 @@ from google.auth.credentials import AnonymousCredentials from pytest_databases.docker.bigquery import BigQueryService from pytest_databases.docker.cockroachdb import CockroachDBService +from pytest_databases.docker.mssql import MSSQLService from pytest_databases.docker.mysql import MySQLService from pytest_databases.docker.oracle import OracleService from pytest_databases.docker.postgres import PostgresService @@ -24,6 +25,9 @@ from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver, AiosqliteDriverFeatures from sqlspec.adapters.aiosqlite.adk import AiosqliteADKStore from sqlspec.adapters.aiosqlite.litestar import AiosqliteStore +from sqlspec.adapters.arrow_odbc import ArrowOdbcConfig, ArrowOdbcDriver +from sqlspec.adapters.arrow_odbc.adk import ArrowOdbcADKStore +from sqlspec.adapters.arrow_odbc.litestar import ArrowOdbcStore from sqlspec.adapters.asyncmy import AsyncmyConfig, AsyncmyDriver, AsyncmyDriverFeatures from sqlspec.adapters.asyncmy.adk import AsyncmyADKStore from sqlspec.adapters.asyncmy.litestar import AsyncmyStore @@ -108,6 +112,7 @@ from tests.integration.adapters.contracts._schema import ( DEFAULT_CONTRACT_TABLE, DUCKDB_CONTRACT_TABLE, + MSSQL_CONTRACT_TABLE, MYSQL_CONTRACT_TABLE, ORACLE_CONTRACT_TABLE, POSTGRES_CONTRACT_TABLE, @@ -322,6 +327,27 @@ def contract_oracle_sync_driver(oracle_23ai_service: OracleService) -> Generator config.close_pool() +@pytest.fixture +def contract_arrow_odbc_mssql_driver(mssql_service: MSSQLService) -> Generator[ArrowOdbcDriver, None, None]: + """Provide a fresh arrow-odbc driver backed by SQL Server.""" + config = ArrowOdbcConfig( + connection_config={"connection_string": mssql_service.connection_string}, + driver_features={"dbms_name": "Microsoft SQL Server"}, + ) + try: + with config.provide_session() as driver: + driver.execute_script("DROP TABLE IF EXISTS contract_items") + driver.execute_script(MSSQL_CONTRACT_TABLE.create_sql) + driver.commit() + yield driver + with contextlib.suppress(Exception): + driver.rollback() + driver.execute_script("DROP TABLE IF EXISTS contract_items") + driver.commit() + finally: + config.close_pool() + + @pytest.fixture async def contract_oracle_async_driver(oracle_23ai_service: OracleService) -> "AsyncGenerator[OracleAsyncDriver, None]": """Provide a fresh Oracle async driver for contract tests.""" @@ -832,6 +858,21 @@ def make(*, extension_config: dict[str, Any], suffix: str) -> PsqlpyConfig: return make +@pytest.fixture +def events_config_arrow_odbc_mssql(mssql_service: MSSQLService, tmp_path: Path) -> Callable[..., Any]: + """Build arrow-odbc SQL Server event-channel configs for contract tests.""" + + def make(*, extension_config: dict[str, Any], suffix: str) -> ArrowOdbcConfig: + return ArrowOdbcConfig( + connection_config={"connection_string": mssql_service.connection_string}, + migration_config=_events_migration_config(tmp_path, suffix), + extension_config=extension_config, + driver_features={"dbms_name": "Microsoft SQL Server"}, + ) + + return make + + def _resolve_events_case(request: pytest.FixtureRequest, case: EventsCase) -> EventsCaseContext: return EventsCaseContext(case=case, make_config=request.getfixturevalue(case.factory_fixture)) @@ -1504,6 +1545,22 @@ async def contract_pymysql_store(mysql_service: MySQLService) -> "AsyncGenerator config.close_pool() +@pytest.fixture +async def contract_arrow_odbc_store(mssql_service: MSSQLService) -> "AsyncGenerator[ArrowOdbcStore, None]": + """Provide a ready arrow-odbc SQL Server Litestar store for contract tests.""" + config = ArrowOdbcConfig( + connection_config={"connection_string": mssql_service.connection_string}, + extension_config=_STORE_EXTENSION_CONFIG, + driver_features={"dbms_name": "Microsoft SQL Server"}, + ) + store = ArrowOdbcStore(config) + await store.create_table() + yield store + with contextlib.suppress(Exception): + await store.delete_all() + config.close_pool() + + def _adk_extension_config(suffix: str) -> dict[str, Any]: return { "adk": { @@ -1841,6 +1898,22 @@ def make() -> "tuple[Any, Any]": return make +@pytest.fixture +def adk_store_arrow_odbc_mssql(mssql_service: MSSQLService) -> Callable[..., Any]: + """Build a fresh arrow-odbc SQL Server ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = ArrowOdbcConfig( + connection_config={"connection_string": mssql_service.connection_string}, + extension_config=_adk_extension_config(suffix), + driver_features={"dbms_name": "Microsoft SQL Server"}, + ) + return config, ArrowOdbcADKStore(config) + + return make + + def _resolve_adk_store_case(request: pytest.FixtureRequest, case: AdkStoreCase) -> AdkStoreCaseContext: return AdkStoreCaseContext(case=case, make_store=request.getfixturevalue(case.factory_fixture)) diff --git a/tests/integration/adapters/contracts/test_driver_contract.py b/tests/integration/adapters/contracts/test_driver_contract.py index 82594ef3f..8841ef45d 100644 --- a/tests/integration/adapters/contracts/test_driver_contract.py +++ b/tests/integration/adapters/contracts/test_driver_contract.py @@ -59,4 +59,4 @@ def test_driver_case_metadata_resolves_fixture(driver_case: DriverCaseContext) - assert driver_case.case.adapter assert driver_case.case.dialect assert driver_case.case.fixture_name - assert driver_case.case.supports_execute_many + assert driver_case.case.supports_execute_many or driver_case.case.supports_native_bulk_ingest diff --git a/tests/unit/adapters/test_arrow_odbc/test_driver.py b/tests/unit/adapters/test_arrow_odbc/test_driver.py index 0c08285a7..6bccf1643 100644 --- a/tests/unit/adapters/test_arrow_odbc/test_driver.py +++ b/tests/unit/adapters/test_arrow_odbc/test_driver.py @@ -14,11 +14,13 @@ ArrowOdbcDriver, ArrowOdbcDriverFeatures, build_connection_config, + create_mapped_exception, odbc_type_to_arrow, resolve_dialect_from_dbms_name, ) from sqlspec.adapters.arrow_odbc.data_dictionary import ArrowOdbcDataDictionary -from sqlspec.exceptions import SQLFileNotFoundError, SQLSpecError +from sqlspec.core import LimitOffsetFilter, OrderByFilter +from sqlspec.exceptions import SQLFileNotFoundError, SQLParsingError, SQLSpecError if TYPE_CHECKING: from sqlspec.adapters.arrow_odbc._typing import ArrowOdbcConnection @@ -368,6 +370,7 @@ def test_arrow_odbc_config_init_no_pre_super_assign_connection_string() -> None: assert config.connection_config == {"connection_string": connection_string} assert config.driver_features.get("connection_string") == connection_string + assert config.statement_config.dialect == "tsql" def test_arrow_odbc_config_init_no_pre_super_assign_driver_key() -> None: @@ -378,6 +381,7 @@ def test_arrow_odbc_config_init_no_pre_super_assign_driver_key() -> None: assert config.connection_config["driver"] == "ODBC Driver 17 for SQL Server" assert config.driver_features.get("dbms_name") == "ODBC Driver 17 for SQL Server" + assert config.statement_config.dialect == "tsql" def test_arrow_odbc_config_init_no_pre_super_assign_none_input() -> None: @@ -403,6 +407,32 @@ def test_arrow_odbc_mssql_driver_uses_tsql_statement_dialect() -> None: assert parameters == (1,) +def test_arrow_odbc_mssql_pagination_inlines_offset_fetch_integers() -> None: + """SQL Server ODBC requires literal integer OFFSET/FETCH control values.""" + connection = FakeConnection() + driver = ArrowOdbcDriver( + cast("ArrowOdbcConnection", connection), driver_features={"dbms_name": "Microsoft SQL Server"} + ) + + driver.execute("SELECT name, value FROM dbo.items", OrderByFilter("value", "desc"), LimitOffsetFilter(2, 1)) + + call = connection.read_calls[-1] + assert "OFFSET 1 ROWS FETCH FIRST 2 ROWS ONLY" in call["query"] + assert call["parameters"] is None + + +def test_arrow_odbc_mssql_syntax_error_maps_to_sql_parsing_error() -> None: + """SQL Server syntax errors should satisfy the shared parsing-error contract.""" + mapped = create_mapped_exception( + FakeOdbcError( + "ODBC emitted an error calling 'SQLExecDirect': State: 42000, Native error: 102, " + "Message: [Microsoft][ODBC Driver 18 for SQL Server][SQL Server]Incorrect syntax near '*'." + ) + ) + + assert isinstance(mapped, SQLParsingError) + + def test_arrow_odbc_driver_dialect_set_from_dbms_name() -> None: """The public dialect slot stores the resolved statement dialect without a private mirror slot.""" connection = FakeConnection() @@ -415,6 +445,39 @@ def test_arrow_odbc_driver_dialect_set_from_dbms_name() -> None: assert not hasattr(driver, "_statement_dialect") +def test_arrow_odbc_connection_in_transaction_uses_explicit_management_fallback() -> None: + """Generic ODBC connections do not expose a portable transaction-state API.""" + connection = FakeConnection() + driver = ArrowOdbcDriver(cast("ArrowOdbcConnection", connection)) + + assert driver._connection_in_transaction() is False # pyright: ignore[reportPrivateUsage] + + +def test_arrow_odbc_execute_marks_dml_rowcount_unknown() -> None: + """arrow-odbc does not expose portable rows-affected metadata for DML.""" + connection = FakeConnection() + driver = ArrowOdbcDriver(cast("ArrowOdbcConnection", connection)) + + result = driver.execute("DELETE FROM items WHERE id = ?", (1,)) + + assert result.rows_affected == -1 + assert connection.executed == [("DELETE FROM items WHERE id = ?", ["1"])] + + +def test_arrow_odbc_mssql_execute_script_sends_batch_without_splitting() -> None: + """SQL Server IF/BEGIN batches should stay intact for ODBC execution.""" + connection = FakeConnection() + driver = ArrowOdbcDriver( + cast("ArrowOdbcConnection", connection), driver_features={"dbms_name": "Microsoft SQL Server"} + ) + script = "IF OBJECT_ID(N'dbo.items', N'U') IS NULL BEGIN CREATE TABLE dbo.items (id INT); END;" + + result = driver.execute_script(script) + + assert result.rows_affected == -1 + assert connection.executed == [(script, None)] + + def test_arrow_odbc_driver_slots_populated_from_features() -> None: """All driver_features values are cached as typed slots at initialization.""" connection = FakeConnection() diff --git a/tests/unit/adapters/test_arrow_odbc/test_extensions.py b/tests/unit/adapters/test_arrow_odbc/test_extensions.py new file mode 100644 index 000000000..3f14edc9f --- /dev/null +++ b/tests/unit/adapters/test_arrow_odbc/test_extensions.py @@ -0,0 +1,30 @@ +"""Unit tests for arrow-odbc extension package boundaries.""" + +import ast +from pathlib import Path + + +def test_arrow_odbc_extensions_do_not_import_other_adapters() -> None: + """Extension modules should stay inside the arrow-odbc adapter boundary.""" + adapter_root = Path(__file__).parents[4] / "sqlspec" / "adapters" / "arrow_odbc" + extension_files = [ + *adapter_root.joinpath("adk").glob("*.py"), + *adapter_root.joinpath("events").glob("*.py"), + *adapter_root.joinpath("litestar").glob("*.py"), + ] + + violations: list[tuple[str, str]] = [] + for path in extension_files: + tree = ast.parse(path.read_text(), filename=str(path)) + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + module = node.module + if module.startswith("sqlspec.adapters.") and not module.startswith("sqlspec.adapters.arrow_odbc"): + violations.append((str(path.relative_to(adapter_root.parents[2])), module)) + if isinstance(node, ast.Import): + for alias in node.names: + module = alias.name + if module.startswith("sqlspec.adapters.") and not module.startswith("sqlspec.adapters.arrow_odbc"): + violations.append((str(path.relative_to(adapter_root.parents[2])), module)) + + assert violations == [] diff --git a/tests/unit/adapters/test_contract_arrow_odbc_harness.py b/tests/unit/adapters/test_contract_arrow_odbc_harness.py new file mode 100644 index 000000000..ad5deb231 --- /dev/null +++ b/tests/unit/adapters/test_contract_arrow_odbc_harness.py @@ -0,0 +1,87 @@ +"""Unit coverage for arrow-odbc contract harness capability gates.""" + +from typing import Any, cast + +import pytest + +from tests.integration.adapters.contracts import behaviors +from tests.integration.adapters.contracts._cases import DriverCase, get_driver_case +from tests.integration.adapters.contracts._inputs import PARAMETER_STYLE_CASES +from tests.integration.adapters.contracts._schema import DEFAULT_CONTRACT_TABLE, ContractRow +from tests.integration.adapters.contracts.behaviors import SyncContractDriver + + +class _BulkOnlyDriver: + def __init__(self) -> None: + self.commits = 0 + self.execute_many_calls: list[tuple[object, object]] = [] + self.load_from_arrow_calls: list[tuple[str, dict[str, Any], dict[str, Any]]] = [] + + def execute_many(self, statement: object, parameters: object, /, **kwargs: Any) -> object: + self.execute_many_calls.append((statement, parameters)) + msg = "execute_many should not be used for this case" + raise AssertionError(msg) + + def load_from_arrow(self, table: str, source: Any, /, **kwargs: Any) -> object: + self.load_from_arrow_calls.append((table, source.to_pydict(), kwargs)) + return object() + + def commit(self) -> None: + self.commits += 1 + + +def _bulk_only_case() -> DriverCase: + return DriverCase( + id="bulk-only-sync", + fixture_name="contract_bulk_only_driver", + adapter="arrow_odbc", + dialect="mssql", + mode="sync", + supports_execute_many=False, + supports_native_bulk_ingest=True, + ) + + +def test_arrow_odbc_sync_case_is_active_sql_server_bulk_only() -> None: + """The arrow-odbc contract case is active against SQL Server without row execute_many.""" + case = get_driver_case("arrow-odbc-sync") + + assert case.integration_status == "active" + assert case.fixture_name == "contract_arrow_odbc_mssql_driver" + assert case.dialect == "mssql" + assert case.supports_arrow + assert case.supports_arrow_streaming + assert case.supports_native_arrow + assert case.supports_native_bulk_ingest + assert not case.supports_execute_many + assert not case.supports_load_from_records + assert "execute-rows-affected-unavailable" in case.deviations + + +def test_seed_sync_uses_arrow_bulk_ingest_when_execute_many_is_disabled() -> None: + """Contract seeding uses native Arrow ingest for sync cases without execute_many.""" + driver = _BulkOnlyDriver() + + behaviors._seed_sync( # pyright: ignore[reportPrivateUsage] + cast("SyncContractDriver", driver), + (ContractRow("alpha", 10, None), ContractRow("beta", 20, "note")), + DEFAULT_CONTRACT_TABLE, + _bulk_only_case(), + ) + + assert driver.execute_many_calls == [] + assert driver.commits == 1 + assert driver.load_from_arrow_calls == [ + ("contract_items", {"name": ["alpha", "beta"], "value": [10, 20], "note": [None, "note"]}, {"overwrite": True}) + ] + + +def test_parameter_style_execute_many_case_skips_when_execute_many_is_disabled() -> None: + """Parameter-style cases that call execute_many skip for bulk-only sync drivers.""" + driver = _BulkOnlyDriver() + parameter_style_case = next(case for case in PARAMETER_STYLE_CASES if case.method == "execute_many") + + with pytest.raises(pytest.skip.Exception): + behaviors.assert_sync_parameter_style_contract( + cast("SyncContractDriver", driver), _bulk_only_case(), parameter_style_case + )