From f5fa457b18da71bbf3bdaae0bd0f480367f48206 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 1 Jul 2026 17:59:31 +0000 Subject: [PATCH 1/4] feat(pymssql): add sync SQL Server adapter --- docs/reference/adapters/index.rst | 13 + docs/reference/adapters/pymssql.rst | 57 ++ docs/usage/drivers_and_querying.rst | 1 + pyproject.toml | 2 + sqlspec/adapters/pymssql/__init__.py | 26 + sqlspec/adapters/pymssql/_typing.py | 95 ++ sqlspec/adapters/pymssql/adk/__init__.py | 10 + sqlspec/adapters/pymssql/adk/store.py | 850 ++++++++++++++++++ sqlspec/adapters/pymssql/config.py | 184 ++++ sqlspec/adapters/pymssql/core.py | 229 +++++ sqlspec/adapters/pymssql/data_dictionary.py | 256 ++++++ sqlspec/adapters/pymssql/driver.py | 178 ++++ sqlspec/adapters/pymssql/events/__init__.py | 5 + sqlspec/adapters/pymssql/events/store.py | 83 ++ sqlspec/adapters/pymssql/litestar/__init__.py | 5 + sqlspec/adapters/pymssql/litestar/store.py | 255 ++++++ sqlspec/adapters/pymssql/migrations.py | 166 ++++ sqlspec/adapters/pymssql/pool.py | 175 ++++ sqlspec/adapters/pymssql/type_converter.py | 92 ++ .../integration/adapters/contracts/_cases.py | 14 + tests/unit/adapters/test_pymssql/__init__.py | 1 + tests/unit/adapters/test_pymssql/conftest.py | 92 ++ .../unit/adapters/test_pymssql/test_config.py | 96 ++ tests/unit/adapters/test_pymssql/test_core.py | 88 ++ .../unit/adapters/test_pymssql/test_driver.py | 117 +++ .../adapters/test_pymssql/test_extensions.py | 70 ++ tests/unit/adapters/test_pymssql/test_pool.py | 72 ++ .../unit/adapters/test_pymssql/test_wiring.py | 34 + 28 files changed, 3266 insertions(+) create mode 100644 docs/reference/adapters/pymssql.rst create mode 100644 sqlspec/adapters/pymssql/__init__.py create mode 100644 sqlspec/adapters/pymssql/_typing.py create mode 100644 sqlspec/adapters/pymssql/adk/__init__.py create mode 100644 sqlspec/adapters/pymssql/adk/store.py create mode 100644 sqlspec/adapters/pymssql/config.py create mode 100644 sqlspec/adapters/pymssql/core.py create mode 100644 sqlspec/adapters/pymssql/data_dictionary.py create mode 100644 sqlspec/adapters/pymssql/driver.py create mode 100644 sqlspec/adapters/pymssql/events/__init__.py create mode 100644 sqlspec/adapters/pymssql/events/store.py create mode 100644 sqlspec/adapters/pymssql/litestar/__init__.py create mode 100644 sqlspec/adapters/pymssql/litestar/store.py create mode 100644 sqlspec/adapters/pymssql/migrations.py create mode 100644 sqlspec/adapters/pymssql/pool.py create mode 100644 sqlspec/adapters/pymssql/type_converter.py create mode 100644 tests/unit/adapters/test_pymssql/__init__.py create mode 100644 tests/unit/adapters/test_pymssql/conftest.py create mode 100644 tests/unit/adapters/test_pymssql/test_config.py create mode 100644 tests/unit/adapters/test_pymssql/test_core.py create mode 100644 tests/unit/adapters/test_pymssql/test_driver.py create mode 100644 tests/unit/adapters/test_pymssql/test_extensions.py create mode 100644 tests/unit/adapters/test_pymssql/test_pool.py create mode 100644 tests/unit/adapters/test_pymssql/test_wiring.py diff --git a/docs/reference/adapters/index.rst b/docs/reference/adapters/index.rst index 389687a4a..f8556f4ac 100644 --- a/docs/reference/adapters/index.rst +++ b/docs/reference/adapters/index.rst @@ -115,6 +115,12 @@ exports a typed config class and a driver implementation. Sync + Async SQL Server via Microsoft's official mssql-python driver. + .. grid-item-card:: pymssql + :link: pymssql + :link-type: doc + + Sync SQL Server via pymssql / FreeTDS. + Feature Comparison ================== @@ -235,6 +241,12 @@ Feature Comparison - Yes - - + * - pymssql + - Yes + - + - Yes + - + - .. toctree:: :hidden: @@ -257,3 +269,4 @@ Feature Comparison adbc arrow_odbc mssql_python + pymssql diff --git a/docs/reference/adapters/pymssql.rst b/docs/reference/adapters/pymssql.rst new file mode 100644 index 000000000..2aca16f94 --- /dev/null +++ b/docs/reference/adapters/pymssql.rst @@ -0,0 +1,57 @@ +======== +pymssql +======== + +Sync SQL Server adapter using `pymssql `_ +and FreeTDS. It uses pyformat parameters (``%s`` and ``%(name)s``) and exposes +sync SQLSpec config, driver, pooling, data dictionary, migration, and extension +store integrations. + +Configuration +============= + +.. autoclass:: sqlspec.adapters.pymssql.PymssqlConfig + :members: + :show-inheritance: + +Connection Parameters +===================== + +.. autoclass:: sqlspec.adapters.pymssql.config.PymssqlConnectionParams + :members: + :show-inheritance: + +Driver Features +=============== + +.. autoclass:: sqlspec.adapters.pymssql.config.PymssqlDriverFeatures + :members: + :show-inheritance: + +Driver +====== + +.. autoclass:: sqlspec.adapters.pymssql.PymssqlDriver + :members: + :show-inheritance: + +Connection Pool +=============== + +.. autoclass:: sqlspec.adapters.pymssql.PymssqlConnectionPool + :members: + :show-inheritance: + +Data Dictionary +=============== + +.. autoclass:: sqlspec.adapters.pymssql.data_dictionary.PymssqlSyncDataDictionary + :members: + :show-inheritance: + +Migrations +========== + +.. autoclass:: sqlspec.adapters.pymssql.migrations.PymssqlSyncMigrationTracker + :members: + :show-inheritance: diff --git a/docs/usage/drivers_and_querying.rst b/docs/usage/drivers_and_querying.rst index 01378265f..696fbe8ed 100644 --- a/docs/usage/drivers_and_querying.rst +++ b/docs/usage/drivers_and_querying.rst @@ -10,6 +10,7 @@ Supported Drivers (High Level) - **PostgreSQL**: asyncpg, psycopg (sync/async), psqlpy, ADBC - **SQLite**: sqlite3, aiosqlite, ADBC - **MySQL**: asyncmy, mysql-connector, pymysql +- **SQL Server**: mssql-python, pymssql, arrow-odbc - **Analytics / Cloud**: DuckDB, BigQuery, Spanner, Oracle, ADBC Core Execution Pattern diff --git a/pyproject.toml b/pyproject.toml index d1e8989d0..7b5a6c38e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -497,6 +497,8 @@ module = [ "asyncmy.*", "mssql_python", "mssql_python.*", + "pymssql", + "pymssql.*", "pyarrow", "pyarrow.*", "opentelemetry.*", diff --git a/sqlspec/adapters/pymssql/__init__.py b/sqlspec/adapters/pymssql/__init__.py new file mode 100644 index 000000000..f109b9063 --- /dev/null +++ b/sqlspec/adapters/pymssql/__init__.py @@ -0,0 +1,26 @@ +"""pymssql adapter for SQLSpec.""" + +from sqlspec.adapters.pymssql._typing import PymssqlConnection, PymssqlCursor +from sqlspec.adapters.pymssql.config import ( + PymssqlConfig, + PymssqlConnectionParams, + PymssqlDriverFeatures, + PymssqlPoolParams, +) +from sqlspec.adapters.pymssql.core import default_statement_config, driver_profile +from sqlspec.adapters.pymssql.driver import PymssqlDriver, PymssqlExceptionHandler +from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool + +__all__ = ( + "PymssqlConfig", + "PymssqlConnection", + "PymssqlConnectionParams", + "PymssqlConnectionPool", + "PymssqlCursor", + "PymssqlDriver", + "PymssqlDriverFeatures", + "PymssqlExceptionHandler", + "PymssqlPoolParams", + "default_statement_config", + "driver_profile", +) diff --git a/sqlspec/adapters/pymssql/_typing.py b/sqlspec/adapters/pymssql/_typing.py new file mode 100644 index 000000000..a061428df --- /dev/null +++ b/sqlspec/adapters/pymssql/_typing.py @@ -0,0 +1,95 @@ +"""pymssql adapter type definitions. + +This module contains type aliases and classes that are excluded from mypyc +compilation to avoid ABI boundary issues. +""" + +import contextlib +from typing import TYPE_CHECKING, Any + +from sqlspec.typing import import_optional, import_optional_attr + +PYMSSQL_MODULE: Any = import_optional("pymssql") + +if TYPE_CHECKING: + from collections.abc import Callable + from types import TracebackType + from typing import TypeAlias + + from sqlspec.adapters.pymssql.driver import PymssqlDriver + from sqlspec.core import StatementConfig + + PymssqlConnection: TypeAlias = Any + PymssqlRawCursor: TypeAlias = Any + +if not TYPE_CHECKING: + PymssqlConnection = import_optional_attr("pymssql", "Connection") or Any + PymssqlRawCursor = import_optional_attr("pymssql", "Cursor") or Any + +__all__ = ("PYMSSQL_MODULE", "PymssqlConnection", "PymssqlCursor", "PymssqlRawCursor", "PymssqlSessionContext") + + +class PymssqlCursor: + """Context manager for pymssql cursor operations.""" + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "PymssqlConnection") -> None: + self.connection = connection + self.cursor: PymssqlRawCursor | None = None + + def __enter__(self) -> "PymssqlRawCursor": + self.cursor = self.connection.cursor() + return self.cursor + + def __exit__(self, *_: Any) -> None: + if self.cursor is not None: + with contextlib.suppress(Exception): + self.cursor.close() + + +class PymssqlSessionContext: + """Sync context manager for pymssql sessions.""" + + __slots__ = ( + "_acquire_connection", + "_connection", + "_driver", + "_driver_features", + "_prepare_driver", + "_release_connection", + "_statement_config", + ) + + def __init__( + self, + acquire_connection: "Callable[[], Any]", + release_connection: "Callable[[Any], Any]", + statement_config: "StatementConfig", + driver_features: "dict[str, Any]", + prepare_driver: "Callable[[PymssqlDriver], PymssqlDriver]", + ) -> None: + self._acquire_connection = acquire_connection + self._release_connection = release_connection + self._statement_config = statement_config + self._driver_features = driver_features + self._prepare_driver = prepare_driver + self._connection: Any = None + self._driver: PymssqlDriver | None = None + + def __enter__(self) -> "PymssqlDriver": + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + self._connection = self._acquire_connection() + self._driver = PymssqlDriver( + connection=self._connection, statement_config=self._statement_config, driver_features=self._driver_features + ) + return self._prepare_driver(self._driver) + + def __exit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": + if self._connection is not None: + self._release_connection(self._connection) + self._connection = None + return None diff --git a/sqlspec/adapters/pymssql/adk/__init__.py b/sqlspec/adapters/pymssql/adk/__init__.py new file mode 100644 index 000000000..143bde066 --- /dev/null +++ b/sqlspec/adapters/pymssql/adk/__init__.py @@ -0,0 +1,10 @@ +"""pymssql ADK extension.""" + +from sqlspec.adapters.pymssql.adk.store import ( + PymssqlADKConfig, + PymssqlADKMemoryStore, + PymssqlADKStore, + PymssqlSyncADKStore, +) + +__all__ = ("PymssqlADKConfig", "PymssqlADKMemoryStore", "PymssqlADKStore", "PymssqlSyncADKStore") diff --git a/sqlspec/adapters/pymssql/adk/store.py b/sqlspec/adapters/pymssql/adk/store.py new file mode 100644 index 000000000..c091592d1 --- /dev/null +++ b/sqlspec/adapters/pymssql/adk/store.py @@ -0,0 +1,850 @@ +"""pymssql ADK stores for Google Agent Development Kit session storage.""" + +import re +from datetime import datetime +from typing import TYPE_CHECKING, Any, ClassVar, Final, cast + +from typing_extensions import NotRequired + +from sqlspec.adapters.pymssql._typing import PYMSSQL_MODULE, PymssqlCursor +from sqlspec.adapters.pymssql.data_dictionary import MssqlVersionInfo +from sqlspec.config import ADKConfig +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore +from sqlspec.utils.serializers import from_json, to_json + +if TYPE_CHECKING: + from datetime import timedelta + + from sqlspec.adapters.pymssql.config import PymssqlConfig + from sqlspec.adapters.pymssql.driver import PymssqlDriver + from sqlspec.extensions.adk.memory._types import MemoryRecord + +__all__ = ("PymssqlADKConfig", "PymssqlADKMemoryStore", "PymssqlADKStore", "PymssqlSyncADKStore") + +MSSQL_TABLE_NOT_FOUND_ERROR: Final[int] = 208 +MSSQL_DUPLICATE_OBJECT_ERROR: Final[int] = 2714 +MSSQL_DUPLICATE_INDEX_ERROR: Final[int] = 1913 +MSSQL_SCHEMA: Final[str] = "dbo" +MSSQL_ERROR_NUMBER_PATTERN: Final[re.Pattern[str]] = re.compile(r"\(([-]?\d+)\)") +JSON_FALLBACK_COLUMN_TYPE: Final[str] = "NVARCHAR(MAX)" +JSON_NATIVE_COLUMN_TYPE: Final[str] = "JSON" + + +class _UnavailablePymssqlError(Exception): + """Fallback pymssql exception base when pymssql is unavailable.""" + + +MSSQL_ERROR: Final[type[BaseException]] = cast( + "type[BaseException]", + getattr(PYMSSQL_MODULE, "Error", _UnavailablePymssqlError) + if PYMSSQL_MODULE is not None + else _UnavailablePymssqlError, +) + + +class PymssqlADKConfig(ADKConfig): + """pymssql ADK extension settings.""" + + native_json: NotRequired[bool] + """Force native SQL Server JSON columns when True, or NVARCHAR(MAX) when False.""" + + +class PymssqlSyncADKStore(BaseSyncADKStore["PymssqlConfig"]): + """Synchronous pymssql ADK session/event store.""" + + connector_name: ClassVar[str] = "pymssql" + __slots__ = ("_json_column_type", "_native_json") + + def __init__(self, config: "PymssqlConfig") -> None: + super().__init__(config) + adk_config = _get_mssql_adk_config(config) + native_json = adk_config.get("native_json") + self._native_json: bool | None = native_json if isinstance(native_json, bool) else None + self._json_column_type: str | None = None + + def create_tables(self) -> None: + """Create all ADK session tables if they do not exist.""" + with self._config.provide_session() as driver: + driver.execute_script(self._get_create_sessions_table_sql()) + driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(self._get_create_app_states_table_sql()) + driver.execute_script(self._get_create_user_states_table_sql()) + driver.execute_script(self._get_create_metadata_table_sql()) + driver.execute_script(self._get_seed_metadata_sql()) + driver.commit() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new ADK session.""" + owner_column = f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else "" + owner_param = ", %s" if self._owner_id_column_name else "" + sql = f""" + INSERT INTO {_table_ref(self._session_table)} ( + id, app_name, user_id{owner_column}, state, create_time, update_time + ) + OUTPUT inserted.id, inserted.app_name, inserted.user_id, inserted.state, inserted.create_time, inserted.update_time + VALUES (%s, %s, %s{owner_param}, %s, SYSUTCDATETIME(), SYSUTCDATETIME()) + """ + params: tuple[Any, ...] + if self._owner_id_column_name: + params = (session_id, app_name, user_id, owner_id, to_json(state)) + else: + params = (session_id, app_name, user_id, to_json(state)) + row = self._execute_fetchone(sql, params, commit=True) + if row is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return _session_record_from_row(row) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Return a scoped session or ``None`` if absent.""" + try: + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + self._execute( + f""" + UPDATE {_table_ref(self._session_table)} + SET update_time = SYSUTCDATETIME() + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + commit=True, + ) + row = self._execute_fetchone( + f""" + SELECT TOP (1) id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(self._session_table)} + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return None + raise + return _session_record_from_row(row) if row is not None else None + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Replace a session's durable state.""" + self._execute( + f""" + UPDATE {_table_ref(self._session_table)} + SET state = %s, update_time = SYSUTCDATETIME() + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (to_json(state), app_name, user_id, session_id), + commit=True, + ) + + def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + """List ADK sessions for an application, optionally scoped to a user.""" + if user_id is None: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(self._session_table)} + WHERE app_name = %s + ORDER BY update_time DESC + """ + params: tuple[Any, ...] = (app_name,) + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {_table_ref(self._session_table)} + WHERE app_name = %s AND user_id = %s + ORDER BY update_time DESC + """ + params = (app_name, user_id) + try: + rows = self._execute_fetchall(sql, params) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return [] + raise + return [_session_record_from_row(row) for row in rows] + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete a session. Event rows cascade through the FK.""" + self._execute( + f"DELETE FROM {_table_ref(self._session_table)} WHERE app_name = %s AND user_id = %s AND id = %s", + (app_name, user_id, session_id), + commit=True, + ) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record), commit=True) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update durable session/scoped state.""" + update_sql = f""" + UPDATE {_table_ref(self._session_table)} + SET state = %s, update_time = SYSUTCDATETIME() + OUTPUT inserted.id, inserted.app_name, inserted.user_id, inserted.state, inserted.create_time, inserted.update_time + WHERE app_name = %s AND user_id = %s AND id = %s + """ + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + try: + cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + row = cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + cursor.execute(_get_insert_event_sql(self._events_table), _event_insert_params(event_record)) + if app_state is not None: + cursor.execute(self._get_upsert_app_state_sql(), (app_name, to_json(app_state))) + if user_state is not None: + cursor.execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(user_state))) + except Exception: + conn.rollback() + raise + conn.commit() + return _session_record_from_row(row) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Return events for a scoped session ordered by event timestamp.""" + if limit == 0: + return [] + sql, params = self._get_events_query(app_name, user_id, session_id, after_timestamp, limit) + try: + rows = self._execute_fetchall(sql, params) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return [] + raise + return [_event_record_from_row(row) for row in rows] + + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than ``before``.""" + try: + return self._execute( + f"DELETE FROM {_table_ref(self._events_table)} WHERE timestamp < %s", (before,), commit=True + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return 0 + raise + + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time is older than ``updated_before``.""" + try: + return self._execute( + f"DELETE FROM {_table_ref(self._session_table)} WHERE update_time < %s", (updated_before,), commit=True + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return 0 + raise + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state.""" + try: + row = self._execute_fetchone( + f"SELECT TOP (1) state FROM {_table_ref(self._app_state_table)} WHERE app_name = %s", (app_name,) + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return None + raise + return _json_dict(row[0]) if row is not None else None + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state.""" + try: + row = self._execute_fetchone( + f""" + SELECT TOP (1) state + FROM {_table_ref(self._user_state_table)} + WHERE app_name = %s AND user_id = %s + """, + (app_name, user_id), + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return None + raise + return _json_dict(row[0]) if row is not None else None + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state.""" + self._execute(self._get_upsert_app_state_sql(), (app_name, to_json(state)), commit=True) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state.""" + self._execute(self._get_upsert_user_state_sql(), (app_name, user_id, to_json(state)), commit=True) + + def get_metadata(self, key: str) -> "str | None": + """Return an ADK metadata value.""" + try: + row = self._execute_fetchone( + f"SELECT TOP (1) value FROM {_table_ref(self._metadata_table)} WHERE [key] = %s", (key,) + ) + except MSSQL_ERROR as exc: + if _is_mssql_table_missing(exc): + return None + raise + return str(row[0]) if row is not None else None + + def set_metadata(self, key: str, value: str) -> None: + """Set an ADK metadata value.""" + self._execute(_get_upsert_metadata_sql(self._metadata_table), (key, value), commit=True) + + def _get_create_sessions_table_sql(self) -> str: + """Return T-SQL DDL for the ADK session table.""" + return _get_create_sessions_table_sql( + self._session_table, self._json_column_type_sync(), self._owner_id_column_ddl + ) + + def _get_create_events_table_sql(self) -> str: + """Return T-SQL DDL for the ADK event table.""" + return _get_create_events_table_sql(self._events_table, self._session_table, self._json_column_type_sync()) + + def _get_create_app_states_table_sql(self) -> str: + """Return T-SQL DDL for the app-scoped state table.""" + return _get_create_app_states_table_sql(self._app_state_table, self._json_column_type_sync()) + + def _get_create_user_states_table_sql(self) -> str: + """Return T-SQL DDL for the user-scoped state table.""" + return _get_create_user_states_table_sql(self._user_state_table, self._json_column_type_sync()) + + def _get_create_metadata_table_sql(self) -> str: + """Return T-SQL DDL for the ADK metadata table.""" + return _get_create_metadata_table_sql(self._metadata_table) + + def _get_seed_metadata_sql(self) -> str: + """Return T-SQL to seed schema-version metadata.""" + return _get_seed_metadata_sql(self._metadata_table) + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._app_state_table)}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._user_state_table)}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {_table_ref(self._metadata_table)}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {_table_ref(self._events_table)}", + f"DROP TABLE IF EXISTS {_table_ref(self._session_table)}", + ] + + def _get_upsert_app_state_sql(self) -> str: + return _get_upsert_state_sql(self._app_state_table, ("app_name",), ("%s",)) + + def _get_upsert_user_state_sql(self) -> str: + return _get_upsert_state_sql(self._user_state_table, ("app_name", "user_id"), ("%s", "%s")) + + def _get_events_query( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "tuple[str, tuple[Any, ...]]": + return _get_events_query(self._events_table, app_name, user_id, session_id, after_timestamp, limit) + + def _json_column_type_sync(self) -> str: + if self._json_column_type is not None: + return self._json_column_type + configured = _configured_json_column_type(self._native_json) + if configured is not None: + self._json_column_type = configured + return configured + with self._config.provide_session() as driver: + self._json_column_type = _json_column_type_from_sync_driver(driver) + return self._json_column_type + + def _execute_fetchone(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> "Any | None": + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + cursor.execute(sql, params) + row = cursor.fetchone() + if commit: + conn.commit() + return row + + def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[Any]": + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + cursor.execute(sql, params) + return list(cursor.fetchall()) + + def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int: + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + cursor.execute(sql, params) + rowcount = _cursor_rowcount(cursor) + if commit: + conn.commit() + return rowcount + + +class PymssqlADKMemoryStore(BaseSyncADKMemoryStore["PymssqlConfig"]): + """SQL Server ADK memory store using pymssql.""" + + __slots__ = () + + def __init__(self, config: "PymssqlConfig") -> None: + super().__init__(config) + + def create_tables(self) -> None: + """Create memory tables if they don't exist.""" + if not self._enabled: + return + statements = self._get_create_memory_table_sql() + if isinstance(statements, str): + statements = [statements] + with self._config.provide_session() as driver: + for statement in statements: + driver.execute_script(statement) + driver.commit() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with event-id deduplication.""" + if not self._enabled: + msg = "ADK memory store is disabled" + raise RuntimeError(msg) + if not entries: + return 0 + + owner_column = f", {_quote_identifier(self._owner_id_column_name)}" if self._owner_id_column_name else "" + owner_value = ", %s" if self._owner_id_column_name else "" + sql = f""" + IF NOT EXISTS (SELECT 1 FROM {_table_ref(self._memory_table)} WHERE event_id = %s) + BEGIN + INSERT INTO {_table_ref(self._memory_table)} ( + id, session_id, app_name, user_id, event_id, author, timestamp, + content_json, content_text, metadata_json{owner_column} + ) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s{owner_value}); + END; + """ + inserted = 0 + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + for entry in entries: + params = ( + entry["event_id"], + entry["id"], + entry["session_id"], + entry["app_name"], + entry["user_id"], + entry["event_id"], + entry.get("author"), + entry["timestamp"], + to_json(entry["content_json"]), + entry["content_text"], + to_json(entry.get("metadata_json")), + ) + if self._owner_id_column_name: + params = (*params, owner_id) + cursor.execute(sql, params) + inserted += _cursor_rowcount(cursor) + conn.commit() + return inserted + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + if not self._enabled: + msg = "ADK memory store is disabled" + raise RuntimeError(msg) + limit_value = limit or self._max_results + sql = f""" + SELECT TOP (%s) + id, session_id, app_name, user_id, event_id, author, timestamp, + content_json, content_text, metadata_json, inserted_at + FROM {_table_ref(self._memory_table)} + WHERE app_name = %s AND user_id = %s AND content_text LIKE %s + ORDER BY timestamp DESC + """ + rows = self._execute_fetchall(sql, (limit_value, app_name, user_id, f"%{query}%")) + return [_memory_record_from_row(row) for row in rows] + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._execute( + f"DELETE FROM {_table_ref(self._memory_table)} WHERE session_id = %s", (session_id,), commit=True + ) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than the retention window.""" + return self._execute( + f"DELETE FROM {_table_ref(self._memory_table)} WHERE inserted_at < DATEADD(day, -%s, SYSUTCDATETIME())", + (days,), + commit=True, + ) + + def _get_create_memory_table_sql(self) -> "str | list[str]": + owner_line = f",\n {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" + return [ + f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(self._memory_table)}' + AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(self._memory_table)} ( + id NVARCHAR(128) NOT NULL, + session_id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + event_id NVARCHAR(128) NOT NULL, + author NVARCHAR(256) NULL, + timestamp DATETIME2(6) NOT NULL, + content_json NVARCHAR(MAX) NOT NULL, + content_text NVARCHAR(MAX) NOT NULL, + metadata_json NVARCHAR(MAX) NULL, + inserted_at DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", self._memory_table, "inserted_at")} + DEFAULT SYSUTCDATETIME(){owner_line}, + CONSTRAINT {_constraint_ref("pk", self._memory_table, "id")} PRIMARY KEY (id), + CONSTRAINT {_constraint_ref("uq", self._memory_table, "event_id")} UNIQUE (event_id) + ); +END; +""", + _get_create_index_sql(self._memory_table, f"idx_{self._memory_table}_scope", "app_name, user_id"), + _get_create_index_sql(self._memory_table, f"idx_{self._memory_table}_session", "session_id"), + _get_create_index_sql(self._memory_table, f"idx_{self._memory_table}_timestamp", "timestamp DESC"), + ] + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {_table_ref(self._memory_table)}"] + + def _execute_fetchall(self, sql: str, params: "tuple[Any, ...]" = ()) -> "list[Any]": + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + cursor.execute(sql, params) + return list(cursor.fetchall()) + + def _execute(self, sql: str, params: "tuple[Any, ...]" = (), *, commit: bool = False) -> int: + with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: + cursor.execute(sql, params) + rowcount = _cursor_rowcount(cursor) + if commit: + conn.commit() + return rowcount + + +def _get_mssql_adk_config(config: Any) -> PymssqlADKConfig: + extension_config = getattr(config, "extension_config", {}) + if not isinstance(extension_config, dict): + return {} + adk_config = extension_config.get("adk", {}) + if not isinstance(adk_config, dict): + return {} + return cast("PymssqlADKConfig", adk_config) + + +def _configured_json_column_type(native_json: "bool | None") -> "str | None": + if native_json is True: + return JSON_NATIVE_COLUMN_TYPE + return JSON_FALLBACK_COLUMN_TYPE + + +def _json_column_type_from_sync_driver(driver: "PymssqlDriver") -> str: + version_info = driver.data_dictionary.get_version(driver) + if isinstance(version_info, MssqlVersionInfo) and version_info.supports_native_json(): + return JSON_NATIVE_COLUMN_TYPE + return JSON_FALLBACK_COLUMN_TYPE + + +def _get_create_sessions_table_sql(table: str, json_column_type: str, owner_id_column_ddl: "str | None") -> str: + owner_line = f",\n {owner_id_column_ddl}" if owner_id_column_ddl else "" + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + row_id UNIQUEIDENTIFIER NOT NULL CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(), + id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL{owner_line}, + state {json_column_type} NOT NULL, + create_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "create_time")} DEFAULT SYSUTCDATETIME(), + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id), + CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id) + ); +END; +{_get_create_index_sql(table, f"idx_{table}_app_user", "app_name, user_id")} +{_get_create_index_sql(table, f"idx_{table}_update_time", "update_time DESC")} +""" + + +def _get_create_events_table_sql(table: str, session_table: str, json_column_type: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + row_id UNIQUEIDENTIFIER NOT NULL CONSTRAINT {_constraint_ref("df", table, "row_id")} DEFAULT NEWSEQUENTIALID(), + id NVARCHAR(128) NOT NULL, + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + session_id NVARCHAR(128) NOT NULL, + invocation_id NVARCHAR(256) NOT NULL, + timestamp DATETIME2(6) NOT NULL, + event_data {json_column_type} NOT NULL, + CONSTRAINT {_constraint_ref("pk", table, "row_id")} PRIMARY KEY (row_id), + CONSTRAINT {_constraint_ref("uq", table, "id")} UNIQUE (id), + CONSTRAINT {_constraint_ref("fk", table, "session")} FOREIGN KEY (session_id) + REFERENCES {_table_ref(session_table)}(id) ON DELETE CASCADE + ); +END; +{_get_create_index_sql(table, f"idx_{table}_scope", "app_name, user_id, session_id, timestamp ASC")} +{_get_create_index_sql(table, f"idx_{table}_session", "session_id, timestamp ASC")} +{_get_create_index_sql(table, f"idx_{table}_invocation", "invocation_id")} +{_get_create_index_sql(table, f"idx_{table}_timestamp", "timestamp ASC")} +""" + + +def _get_create_app_states_table_sql(table: str, json_column_type: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + app_name NVARCHAR(128) NOT NULL, + state {json_column_type} NOT NULL, + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "app_name")} PRIMARY KEY (app_name) + ); +END; +""" + + +def _get_create_user_states_table_sql(table: str, json_column_type: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + app_name NVARCHAR(128) NOT NULL, + user_id NVARCHAR(128) NOT NULL, + state {json_column_type} NOT NULL, + update_time DATETIME2(6) NOT NULL CONSTRAINT {_constraint_ref("df", table, "update_time")} DEFAULT SYSUTCDATETIME(), + CONSTRAINT {_constraint_ref("pk", table, "app_user")} PRIMARY KEY (app_name, user_id) + ); +END; +""" + + +def _get_create_metadata_table_sql(table: str) -> str: + return f""" +IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = N'{_escape_sql_literal(table)}' AND schema_id = SCHEMA_ID(N'dbo')) +BEGIN + CREATE TABLE {_table_ref(table)} ( + [key] NVARCHAR(128) NOT NULL, + value NVARCHAR(512) NOT NULL, + CONSTRAINT {_constraint_ref("pk", table, "key")} PRIMARY KEY ([key]) + ); +END; +""" + + +def _get_create_index_sql(table: str, index_name: str, columns: str) -> str: + return f""" +IF NOT EXISTS ( + SELECT 1 FROM sys.indexes + WHERE name = N'{_escape_sql_literal(index_name)}' + AND object_id = OBJECT_ID(N'{_escape_sql_literal(MSSQL_SCHEMA)}.{_escape_sql_literal(table)}') +) +BEGIN + CREATE INDEX {_quote_identifier(index_name)} ON {_table_ref(table)} ({columns}); +END; +""" + + +def _get_insert_event_sql(table: str) -> str: + return f""" + INSERT INTO {_table_ref(table)} ( + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) + VALUES (%s, %s, %s, %s, %s, %s, %s) + """ + + +def _get_upsert_state_sql(table: str, key_columns: "tuple[str, ...]", key_params: "tuple[str, ...]") -> str: + source_columns = ", ".join( + f"{param} AS {_quote_identifier(column)}" for column, param in zip(key_columns, key_params, strict=False) + ) + source_columns = f"{source_columns}, %s AS state" + insert_columns = ", ".join(_quote_identifier(column) for column in (*key_columns, "state", "update_time")) + insert_values = ", ".join(f"source.{_quote_identifier(column)}" for column in (*key_columns, "state")) + match_clause = " AND ".join( + f"target.{_quote_identifier(column)} = source.{_quote_identifier(column)}" for column in key_columns + ) + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT {source_columns}) AS source + ON ({match_clause}) + WHEN MATCHED THEN + UPDATE SET state = source.state, update_time = SYSUTCDATETIME() + WHEN NOT MATCHED THEN + INSERT ({insert_columns}) + VALUES ({insert_values}, SYSUTCDATETIME()); + """ + + +def _get_upsert_metadata_sql(table: str) -> str: + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT %s AS [key], %s AS value) AS source + ON (target.[key] = source.[key]) + WHEN MATCHED THEN + UPDATE SET value = source.value + WHEN NOT MATCHED THEN + INSERT ([key], value) + VALUES (source.[key], source.value); + """ + + +def _get_seed_metadata_sql(table: str) -> str: + return f""" + MERGE INTO {_table_ref(table)} WITH (HOLDLOCK) AS target + USING (SELECT N'schema_version' AS [key], N'1' AS value) AS source + ON (target.[key] = source.[key]) + WHEN MATCHED THEN + UPDATE SET value = source.value + WHEN NOT MATCHED THEN + INSERT ([key], value) + VALUES (source.[key], source.value); + """ + + +def _get_events_query( + table: str, app_name: str, user_id: str, session_id: str, after_timestamp: "datetime | None", limit: "int | None" +) -> "tuple[str, tuple[Any, ...]]": + top_clause = "TOP (%s) " if limit is not None else "" + params: list[Any] = [limit] if limit is not None else [] + params.extend([app_name, user_id, session_id]) + after_clause = "" + if after_timestamp is not None: + after_clause = " AND timestamp > %s" + params.append(after_timestamp) + sql = f""" + SELECT {top_clause}id, app_name, user_id, session_id, invocation_id, timestamp, event_data + FROM {_table_ref(table)} + WHERE app_name = %s AND user_id = %s AND session_id = %s{after_clause} + ORDER BY timestamp ASC + """ + return sql, tuple(params) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + to_json(event_record["event_data"]), + ) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=row[0], app_name=row[1], user_id=row[2], state=_json_dict(row[3]), create_time=row[4], update_time=row[5] + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=row[5], + event_data=_json_dict(row[6]), + ) + + +def _memory_record_from_row(row: Any) -> "MemoryRecord": + return cast( + "MemoryRecord", + { + "id": row[0], + "session_id": row[1], + "app_name": row[2], + "user_id": row[3], + "event_id": row[4], + "author": row[5], + "timestamp": row[6], + "content_json": _json_dict(row[7]), + "content_text": row[8], + "metadata_json": _json_dict(row[9]) if row[9] is not None else None, + "inserted_at": row[10], + }, + ) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if value is None: + return {} + if isinstance(value, dict): + return cast("dict[str, Any]", value) + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, (bytes, str)): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", from_json(str(value))) + + +def _cursor_rowcount(cursor: Any) -> int: + rowcount = getattr(cursor, "rowcount", 0) + return rowcount if isinstance(rowcount, int) and rowcount > 0 else 0 + + +def _is_mssql_table_missing(exc: BaseException) -> bool: + text = str(exc).lower() + return "invalid object name" in text or _mssql_error_number(exc) == MSSQL_TABLE_NOT_FOUND_ERROR + + +def _mssql_error_number(exc: BaseException) -> "int | None": + matches = MSSQL_ERROR_NUMBER_PATTERN.findall(str(exc)) + if not matches: + return None + try: + return int(matches[-1]) + except ValueError: + return None + + +def _quote_identifier(identifier: str) -> str: + return f"[{identifier.replace(']', ']]')}]" + + +def _table_ref(table: str) -> str: + return f"{_quote_identifier(MSSQL_SCHEMA)}.{_quote_identifier(table)}" + + +def _constraint_ref(prefix: str, table: str, suffix: str) -> str: + return _quote_identifier(f"{prefix}_{table}_{suffix}") + + +def _escape_sql_literal(value: str) -> str: + return value.replace("'", "''") + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + + +PymssqlADKStore = PymssqlSyncADKStore diff --git a/sqlspec/adapters/pymssql/config.py b/sqlspec/adapters/pymssql/config.py new file mode 100644 index 000000000..9c01e80c5 --- /dev/null +++ b/sqlspec/adapters/pymssql/config.py @@ -0,0 +1,184 @@ +"""pymssql database configuration.""" + +from collections.abc import Callable, Mapping +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast + +from typing_extensions import NotRequired + +from sqlspec.adapters.pymssql._typing import PymssqlConnection, PymssqlCursor, PymssqlRawCursor, PymssqlSessionContext +from sqlspec.adapters.pymssql.core import apply_driver_features, default_statement_config +from sqlspec.adapters.pymssql.driver import PymssqlDriver, PymssqlExceptionHandler +from sqlspec.adapters.pymssql.migrations import PymssqlSyncMigrationTracker +from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool +from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig +from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory +from sqlspec.extensions.events import EventRuntimeHints +from sqlspec.utils.config_tools import normalize_connection_config + +if TYPE_CHECKING: + from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig + +__all__ = ("PymssqlConfig", "PymssqlConnectionParams", "PymssqlDriverFeatures", "PymssqlPoolParams", "PymssqlTimeout") + +PymssqlTimeout = int | float + + +class PymssqlConnectionParams(TypedDict): + """pymssql connection parameters.""" + + server: NotRequired[str] + host: NotRequired[str] + user: NotRequired[str] + password: NotRequired[str] + database: NotRequired[str] + port: NotRequired[int | str] + timeout: NotRequired[PymssqlTimeout] + login_timeout: NotRequired[PymssqlTimeout] + charset: NotRequired[str] + as_dict: NotRequired[bool] + appname: NotRequired[str] + conn_properties: NotRequired[str] + autocommit: NotRequired[bool] + tds_version: NotRequired[str] + use_datetime2: NotRequired[bool] + arraysize: NotRequired[int] + conv: NotRequired[Mapping[int | type[Any], Callable[..., Any]]] + read_only: NotRequired[bool] + pool_recycle_seconds: NotRequired[int] + health_check_interval: NotRequired[float] + extra: NotRequired["dict[str, Any]"] + + +class PymssqlPoolParams(PymssqlConnectionParams): + """pymssql pool parameters.""" + + pool_recycle_seconds: NotRequired[int] + health_check_interval: NotRequired[float] + + +class PymssqlDriverFeatures(TypedDict): + """pymssql driver feature flags. + + json_serializer: Custom JSON serializer function. + Defaults to sqlspec.utils.serializers.to_json. + json_deserializer: Custom JSON deserializer function. + Defaults to sqlspec.utils.serializers.from_json. + on_connection_create: Callback executed when a connection is created. + Receives the raw pymssql connection for low-level driver configuration. + Runs after connection creation. + enable_events: Enable database event channel support. + events_backend: Event channel backend selection. + """ + + json_serializer: NotRequired["Callable[[Any], str]"] + json_deserializer: NotRequired["Callable[[str], Any]"] + on_connection_create: "NotRequired[Callable[[PymssqlConnection], None]]" + enable_events: NotRequired[bool] + events_backend: NotRequired[str] + + +class PymssqlConnectionContext(SyncPoolConnectionContext): + """Context manager for pymssql connections.""" + + __slots__ = () + + +class _PymssqlSessionConnectionHandler(SyncPoolSessionFactory): + __slots__ = () + + +class PymssqlConfig(SyncDatabaseConfig[PymssqlConnection, PymssqlConnectionPool, PymssqlDriver]): + """Configuration for pymssql synchronous connections.""" + + driver_type: "ClassVar[type[PymssqlDriver]]" = PymssqlDriver + connection_type: "ClassVar[type[PymssqlConnection]]" = cast("type[PymssqlConnection]", PymssqlConnection) + migration_tracker_type: "ClassVar[type[PymssqlSyncMigrationTracker]]" = PymssqlSyncMigrationTracker + supports_transactional_ddl: "ClassVar[bool]" = True + supports_native_arrow_export: "ClassVar[bool]" = False + supports_native_arrow_import: "ClassVar[bool]" = False + supports_native_parquet_export: "ClassVar[bool]" = False + supports_native_parquet_import: "ClassVar[bool]" = False + supports_native_row_streaming: "ClassVar[bool]" = False + _connection_context_class: "ClassVar[type[PymssqlConnectionContext]]" = PymssqlConnectionContext + _session_factory_class: "ClassVar[type[_PymssqlSessionConnectionHandler]]" = _PymssqlSessionConnectionHandler + _session_context_class: "ClassVar[type[PymssqlSessionContext]]" = PymssqlSessionContext + _default_statement_config = default_statement_config + + def __init__( + self, + *, + connection_config: "PymssqlPoolParams | dict[str, Any] | None" = None, + connection_instance: "PymssqlConnectionPool | None" = None, + migration_config: "dict[str, Any] | None" = None, + statement_config: "StatementConfig | None" = None, + driver_features: "PymssqlDriverFeatures | dict[str, Any] | None" = None, + bind_key: "str | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, + ) -> None: + connection_config = normalize_connection_config(connection_config) + connection_config.setdefault("server", connection_config.pop("host", "localhost")) + connection_config.setdefault("port", 1433) + + statement_config = statement_config or default_statement_config + statement_config, driver_features = apply_driver_features(statement_config, driver_features) + + features_dict = dict(driver_features) if driver_features else {} + self._user_connection_hook: Callable[[PymssqlConnection], None] | None = features_dict.pop( + "on_connection_create", None + ) + + super().__init__( + connection_config=connection_config, + connection_instance=connection_instance, + migration_config=migration_config, + statement_config=statement_config, + driver_features=features_dict, + bind_key=bind_key, + extension_config=extension_config, + observability_config=observability_config, + **kwargs, + ) + + def _create_pool(self) -> "PymssqlConnectionPool": + config = dict(self.connection_config) + pool_recycle = config.pop("pool_recycle_seconds", 86400) + health_check = config.pop("health_check_interval", 30.0) + return PymssqlConnectionPool( + config, + recycle_seconds=pool_recycle, + health_check_interval=health_check, + on_connection_create=self._user_connection_hook, + ) + + def _close_pool(self) -> None: + if self.connection_instance: + self.connection_instance.close() + self.connection_instance = None + + def create_connection(self) -> "PymssqlConnection": + pool = self.provide_pool() + return pool.acquire() + + def get_signature_namespace(self) -> "dict[str, Any]": + namespace = super().get_signature_namespace() + namespace.update({ + "PymssqlConfig": PymssqlConfig, + "PymssqlConnection": PymssqlConnection, + "PymssqlConnectionContext": PymssqlConnectionContext, + "PymssqlConnectionParams": PymssqlConnectionParams, + "PymssqlConnectionPool": PymssqlConnectionPool, + "PymssqlCursor": PymssqlCursor, + "PymssqlDriver": PymssqlDriver, + "PymssqlDriverFeatures": PymssqlDriverFeatures, + "PymssqlExceptionHandler": PymssqlExceptionHandler, + "PymssqlPoolParams": PymssqlPoolParams, + "PymssqlRawCursor": PymssqlRawCursor, + "PymssqlSessionContext": PymssqlSessionContext, + }) + return namespace + + def get_event_runtime_hints(self) -> "EventRuntimeHints": + return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True) diff --git a/sqlspec/adapters/pymssql/core.py b/sqlspec/adapters/pymssql/core.py new file mode 100644 index 000000000..2ecf47add --- /dev/null +++ b/sqlspec/adapters/pymssql/core.py @@ -0,0 +1,229 @@ +"""pymssql adapter compiled helpers.""" + +import re +from collections.abc import Sized +from typing import TYPE_CHECKING, Any, Final + +from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile +from sqlspec.exceptions import ( + DatabaseConnectionError, + DataError, + DeadlockError, + ForeignKeyViolationError, + IntegrityError, + NotNullViolationError, + OperationalError, + PermissionDeniedError, + QueryTimeoutError, + SQLSpecError, + UniqueViolationError, +) +from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.text import split_qualified_identifier +from sqlspec.utils.type_converters import build_uuid_coercions +from sqlspec.utils.type_guards import has_rowcount + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping, Sequence + from logging import Logger + +__all__ = ( + "apply_driver_features", + "build_insert_statement", + "build_profile", + "build_statement_config", + "collect_rows", + "create_mapped_exception", + "default_statement_config", + "driver_profile", + "format_identifier", + "normalize_execute_many_parameters", + "normalize_execute_parameters", + "resolve_column_names", + "resolve_many_rowcount", + "resolve_rowcount", +) + +_ERROR_NUMBER_PATTERN: Final[re.Pattern[str]] = re.compile(r"\(([-]?\d+)\)") +_ERROR_CODE_MAPPING: Final[dict[int, tuple[type[SQLSpecError], str]]] = { + 2601: (UniqueViolationError, "unique constraint violation"), + 2627: (UniqueViolationError, "unique constraint violation"), + 547: (ForeignKeyViolationError, "foreign key or check constraint violation"), + 515: (NotNullViolationError, "not-null constraint violation"), + 18456: (PermissionDeniedError, "permission denied"), + 4060: (DatabaseConnectionError, "database connection error"), + 53: (DatabaseConnectionError, "database connection error"), + 1205: (DeadlockError, "deadlock detected"), + -2: (QueryTimeoutError, "query timeout"), + 8114: (DataError, "data conversion error"), + 1105: (OperationalError, "operational error"), +} + + +def format_identifier(identifier: str) -> str: + """Format a T-SQL identifier with bracket quoting.""" + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + parts = split_qualified_identifier(cleaned, quote_chars='"', allow_bracket_quotes=True) + return ".".join(_quote_bracket_identifier(part) for part in parts) + + +def build_insert_statement(table: str, columns: "list[str]") -> str: + """Build a pymssql-compatible INSERT statement.""" + column_clause = ", ".join(_quote_bracket_identifier(column) for column in columns) + placeholders = ", ".join("%s" for _ in columns) + return f"INSERT INTO {format_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + +def normalize_execute_parameters(parameters: Any) -> Any: + """Normalize parameters for pymssql execute calls.""" + if parameters is None: + return None + if isinstance(parameters, list): + return tuple(parameters) + return parameters + + +def normalize_execute_many_parameters(parameters: Any) -> Any: + """Normalize parameters for pymssql executemany calls.""" + if not parameters: + msg = "execute_many requires parameters" + raise ValueError(msg) + return parameters + + +def _bool_to_int(value: bool) -> int: + return int(value) + + +def build_profile() -> "DriverParameterProfile": + """Create the pymssql driver parameter profile.""" + return DriverParameterProfile( + name="pymssql", + default_style=ParameterStyle.QMARK, + supported_styles={ParameterStyle.QMARK, ParameterStyle.NAMED_PYFORMAT}, + default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT, + supported_execution_styles={ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT}, + has_native_list_expansion=False, + preserve_parameter_format=True, + needs_static_script_compilation=False, + allow_mixed_parameter_styles=False, + preserve_original_params_for_many=False, + json_serializer_strategy="helper", + custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, + default_dialect="tsql", + ) + + +driver_profile = build_profile() + + +def build_statement_config( + *, json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None +) -> "StatementConfig": + """Construct the pymssql statement configuration.""" + return build_statement_config_from_profile( + driver_profile, + statement_overrides={"dialect": "tsql"}, + json_serializer=json_serializer or to_json, + json_deserializer=json_deserializer or from_json, + ) + + +default_statement_config = build_statement_config() + + +def apply_driver_features( + statement_config: "StatementConfig", driver_features: "Mapping[str, Any] | None" +) -> "tuple[StatementConfig, dict[str, Any]]": + """Apply pymssql driver feature defaults to statement config.""" + features: dict[str, Any] = dict(driver_features) if driver_features else {} + json_serializer = features.setdefault("json_serializer", to_json) + json_deserializer = features.setdefault("json_deserializer", from_json) + + if json_serializer is not None: + parameter_config = statement_config.parameter_config.with_json_serializers( + json_serializer, deserializer=json_deserializer + ) + statement_config = statement_config.replace(parameter_config=parameter_config) + + return statement_config, features + + +def create_mapped_exception(exc: Exception, logger: "Logger | None" = None) -> Exception: + """Map a pymssql exception to SQLSpec's exception hierarchy.""" + error_number = _extract_error_number(exc) + if error_number is not None: + mapping = _ERROR_CODE_MAPPING.get(error_number) + if mapping is not None: + error_class, description = mapping + return error_class(f"SQL Server error {error_number}: {description}. Original error: {exc}") + if logger is not None: + logger.debug("Unmapped SQL Server error number: %s", error_number) + + exc_name = type(exc).__name__ + if exc_name == "IntegrityError": + return IntegrityError(f"SQL Server integrity error. Original error: {exc}") + if exc_name == "OperationalError": + return OperationalError(f"SQL Server operational error. Original error: {exc}") + if exc_name == "DataError": + return DataError(f"SQL Server data error. Original error: {exc}") + return SQLSpecError(f"SQL Server database error. Original error: {exc}") + + +def resolve_column_names(description: "Sequence[Any] | None") -> "list[str]": + """Resolve ordered column names from cursor metadata.""" + if not description: + return [] + return [desc[0] for desc in description] + + +def collect_rows( + fetched_data: "Sequence[Any] | None", description: "Sequence[Any] | None" +) -> "tuple[list[Any], list[str], str]": + """Collect pymssql rows, preserving tuple row shape.""" + column_names = resolve_column_names(description) + if not fetched_data: + return [], column_names, "tuple" + return list(fetched_data), column_names, "tuple" + + +def resolve_rowcount(cursor: Any) -> int: + """Resolve rowcount from a pymssql cursor.""" + if not has_rowcount(cursor): + return 0 + rowcount = cursor.rowcount + if isinstance(rowcount, int) and rowcount >= 0: + return rowcount + return 0 + + +def resolve_many_rowcount(cursor: Any, parameters: Any, *, fallback_count: "int | None" = None) -> int: + """Resolve executemany rowcount using cursor metadata with payload fallback.""" + rowcount = resolve_rowcount(cursor) + if rowcount > 0: + return rowcount + if fallback_count is not None: + return fallback_count + if isinstance(parameters, Sized): + return len(parameters) + return 0 + + +def _quote_bracket_identifier(identifier: str) -> str: + cleaned = identifier.strip() + if cleaned.startswith("[") and cleaned.endswith("]"): + cleaned = cleaned[1:-1].replace("]]", "]") + return f"[{cleaned.replace(']', ']]')}]" + + +def _extract_error_number(exc: Exception) -> "int | None": + matches = _ERROR_NUMBER_PATTERN.findall(str(exc)) + if not matches: + return None + try: + return int(matches[-1]) + except ValueError: + return None diff --git a/sqlspec/adapters/pymssql/data_dictionary.py b/sqlspec/adapters/pymssql/data_dictionary.py new file mode 100644 index 000000000..c66875fef --- /dev/null +++ b/sqlspec/adapters/pymssql/data_dictionary.py @@ -0,0 +1,256 @@ +"""pymssql data dictionary.""" + +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from mypy_extensions import mypyc_attr + +from sqlspec.data_dictionary import ( + ColumnMetadata, + ForeignKeyMetadata, + IndexMetadata, + TableMetadata, + VersionInfo, + get_dialect_config, +) +from sqlspec.data_dictionary.dialects.mssql import ( + extract_mssql_version_value, + is_mssql_azure_sql, + list_mssql_available_features, + merge_mssql_table_lists, + mssql_supports_native_json, + parse_mssql_engine_edition, + parse_mssql_version_components, + resolve_mssql_feature_flag, +) +from sqlspec.driver import SyncDataDictionaryBase +from sqlspec.utils.logging import get_logger + +if TYPE_CHECKING: + from sqlspec.adapters.pymssql.driver import PymssqlDriver + from sqlspec.data_dictionary._types import DialectConfig + +__all__ = ("MssqlVersionInfo", "PymssqlSyncDataDictionary") + +logger = get_logger("sqlspec.adapters.pymssql.data_dictionary") + + +class MssqlVersionInfo(VersionInfo): + """MSSQL database version info with build, revision, and Azure SQL detection.""" + + def __init__( + self, + major: int, + minor: int = 0, + build: int = 0, + revision: int = 0, + edition: str | None = None, + engine_edition: int | None = None, + ) -> None: + super().__init__(major, minor, 0) + self.build = build + self.revision = revision + self.edition = edition + self.engine_edition = engine_edition + self.is_azure_sql = is_mssql_azure_sql(engine_edition) + + def supports_native_json(self) -> bool: + """Return whether this server supports the native JSON type.""" + return mssql_supports_native_json(self.major, is_azure_sql=self.is_azure_sql) + + @property + def version_tuple(self) -> "tuple[int, int, int]": + """Get version tuple using the MSSQL build number as the third component.""" + return (self.major, self.minor, self.build) + + def __str__(self) -> str: + """String representation of version info.""" + version_str = f"{self.major}.{self.minor}.{self.build}.{self.revision}" + if self.edition: + version_str += f" ({self.edition})" + if self.is_azure_sql: + version_str += " [Azure]" + return version_str + + +class _MssqlDataDictionaryMixin: + """Shared helpers for MSSQL data dictionaries.""" + + dialect: ClassVar[str] = "mssql" + + def get_dialect_config(self) -> "DialectConfig": + """Return the dialect configuration for this data dictionary.""" + return get_dialect_config(type(self).dialect) + + def resolve_schema(self, schema: str | None) -> str | None: + """Return a schema name using dialect defaults when missing.""" + if schema is not None: + return schema + return self.get_dialect_config().default_schema + + def list_available_features(self) -> list[str]: + """List available feature flags for this dialect.""" + return list_mssql_available_features(self.get_dialect_config()) + + def _build_version_info( + self, version_value: str | None, edition: str | None, engine_edition_value: Any + ) -> MssqlVersionInfo | None: + if not version_value: + return None + major, minor, build, revision = parse_mssql_version_components(version_value) + return MssqlVersionInfo( + major, + minor, + build, + revision, + edition=edition, + engine_edition=parse_mssql_engine_edition(engine_edition_value), + ) + + def _get_optimal_type_from_version(self, version_info: MssqlVersionInfo | None, type_category: str) -> str: + if type_category in {"json", "jsonb"} and version_info is not None and version_info.supports_native_json(): + return "JSON" + return self.get_dialect_config().get_optimal_type(type_category) + + +@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) +class PymssqlSyncDataDictionary(_MssqlDataDictionaryMixin, SyncDataDictionaryBase): + """MSSQL sync data dictionary.""" + + dialect: ClassVar[str] = "mssql" + + def __init__(self) -> None: + super().__init__() + + def get_version(self, driver: "PymssqlDriver") -> MssqlVersionInfo | None: + """Get SQL Server version information.""" + driver_id = id(driver) + if driver_id in self._version_fetch_attempted: + return cast("MssqlVersionInfo | None", self._version_cache.get(driver_id)) + + row = driver.select_one_or_none(self.get_query_text("version")) + if not row: + self._log_version_unavailable(type(self).dialect, "missing") + self.cache_version(driver_id, None) + return None + + version_value = extract_mssql_version_value( + _row_value(row, "product_version") or _row_value(row, "version_string", "version") + ) + edition_value = _row_value(row, "edition") + edition = str(edition_value) if edition_value is not None else None + version_info = self._build_version_info(version_value, edition, _row_value(row, "engine_edition")) + if version_info is None: + self._log_version_unavailable(type(self).dialect, "parse_failed") + self.cache_version(driver_id, None) + return None + + self._log_version_detected(type(self).dialect, version_info) + self.cache_version(driver_id, version_info) + return version_info + + def get_feature_flag(self, driver: "PymssqlDriver", feature: str) -> bool: + """Check whether SQL Server supports a feature.""" + version_info = self.get_version(driver) + return resolve_mssql_feature_flag( + feature, + major=version_info.major if version_info is not None else 0, + is_azure_sql=bool(version_info and version_info.is_azure_sql), + config=self.get_dialect_config(), + version_info=version_info, + ) + + def get_optimal_type(self, driver: "PymssqlDriver", type_category: str) -> str: + """Get optimal SQL Server type for a category.""" + return self._get_optimal_type_from_version(self.get_version(driver), type_category) + + def get_tables(self, driver: "PymssqlDriver", schema: str | None = None) -> list[TableMetadata]: + """Get tables sorted by dependency order with catalog fallback.""" + schema_name = self.resolve_schema(schema) + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables") + ordered = cast( + "list[TableMetadata]", + driver.select(self.get_query("tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata), + ) + all_rows = cast( + "list[TableMetadata]", + driver.select(self.get_query("all_tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata), + ) + return merge_mssql_table_lists(ordered, all_rows) + + def get_columns( + self, driver: "PymssqlDriver", table: str | None = None, schema: str | None = None + ) -> list[ColumnMetadata]: + """Get column information for a table or schema.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="columns") + return cast( + "list[ColumnMetadata]", + driver.select(self.get_query("columns_by_schema"), schema_name=schema_name, schema_type=ColumnMetadata), + ) + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="columns") + return cast( + "list[ColumnMetadata]", + driver.select( + self.get_query("columns_by_table"), + schema_name=schema_name, + table_name=table, + schema_type=ColumnMetadata, + ), + ) + + def get_indexes( + self, driver: "PymssqlDriver", table: str | None = None, schema: str | None = None + ) -> list[IndexMetadata]: + """Get index metadata for a table or schema.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="indexes") + return cast( + "list[IndexMetadata]", + driver.select(self.get_query("indexes_by_schema"), schema_name=schema_name, schema_type=IndexMetadata), + ) + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="indexes") + return cast( + "list[IndexMetadata]", + driver.select( + self.get_query("indexes_by_table"), schema_name=schema_name, table_name=table, schema_type=IndexMetadata + ), + ) + + def get_foreign_keys( + self, driver: "PymssqlDriver", table: str | None = None, schema: str | None = None + ) -> list[ForeignKeyMetadata]: + """Get foreign key metadata.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="foreign_keys") + return cast( + "list[ForeignKeyMetadata]", + driver.select( + self.get_query("foreign_keys_by_schema"), schema_name=schema_name, schema_type=ForeignKeyMetadata + ), + ) + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="foreign_keys") + return cast( + "list[ForeignKeyMetadata]", + driver.select( + self.get_query("foreign_keys_by_table"), + schema_name=schema_name, + table_name=table, + schema_type=ForeignKeyMetadata, + ), + ) + + +def _row_value(row: object, *names: str) -> Any: + """Return the first named value from a row-like object.""" + if isinstance(row, dict): + for name in names: + if name in row: + return row[name] + upper_name = name.upper() + if upper_name in row: + return row[upper_name] + return None + return getattr(row, names[0], None) if names else None diff --git a/sqlspec/adapters/pymssql/driver.py b/sqlspec/adapters/pymssql/driver.py new file mode 100644 index 000000000..dd0695ca2 --- /dev/null +++ b/sqlspec/adapters/pymssql/driver.py @@ -0,0 +1,178 @@ +"""pymssql SQL Server driver implementation.""" + +from collections.abc import Sized +from typing import TYPE_CHECKING, Any, cast + +from sqlspec.adapters.pymssql._typing import PYMSSQL_MODULE, PymssqlCursor, PymssqlSessionContext +from sqlspec.adapters.pymssql.core import ( + collect_rows, + create_mapped_exception, + default_statement_config, + driver_profile, + normalize_execute_many_parameters, + normalize_execute_parameters, + resolve_column_names, + resolve_many_rowcount, + resolve_rowcount, +) +from sqlspec.adapters.pymssql.data_dictionary import PymssqlSyncDataDictionary +from sqlspec.core import get_cache_config, register_driver_profile +from sqlspec.driver import BaseSyncExceptionHandler, SyncDriverAdapterBase +from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger + +if TYPE_CHECKING: + from sqlspec.adapters.pymssql._typing import PymssqlConnection, PymssqlRawCursor + from sqlspec.core import SQL, StatementConfig + from sqlspec.driver import ExecutionResult + +__all__ = ("PymssqlCursor", "PymssqlDriver", "PymssqlExceptionHandler", "PymssqlSessionContext") + +logger = get_logger("sqlspec.adapters.pymssql") +pymssql: Any = PYMSSQL_MODULE + + +class _UnavailablePymssqlError(Exception): + """Fallback pymssql exception base when pymssql is unavailable.""" + + +class PymssqlExceptionHandler(BaseSyncExceptionHandler): + """Context manager for handling pymssql exceptions.""" + + __slots__ = () + + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: + if exc_type is None: + return False + error_type = _pymssql_error_type() + if isinstance(exc_val, error_type): + self.pending_exception = create_mapped_exception(cast("Exception", exc_val), logger=logger) + return True + return False + + +class PymssqlDriver(SyncDriverAdapterBase): + """SQL Server database driver using pymssql.""" + + __slots__ = ("_data_dictionary",) + dialect = "tsql" + + def __init__( + self, + connection: "PymssqlConnection", + statement_config: "StatementConfig | None" = None, + driver_features: "dict[str, Any] | None" = None, + ) -> None: + if statement_config is None: + statement_config = default_statement_config.replace( + enable_caching=get_cache_config().compiled_cache_enabled + ) + + super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) + self._data_dictionary: PymssqlSyncDataDictionary | None = None + + def dispatch_execute(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult": + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) + + if statement.returns_rows(): + fetched_data = cursor.fetchall() + description = cursor.description or None + rows, column_names, row_format = collect_rows(fetched_data, description) + return self.create_execution_result( + cursor, + selected_data=rows, + column_names=column_names, + data_row_count=len(rows), + is_select_result=True, + row_format=row_format, + ) + + return self.create_execution_result(cursor, rowcount_override=resolve_rowcount(cursor)) + + def dispatch_execute_many(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult": + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + + prepared_parameters = normalize_execute_many_parameters(prepared_parameters) + parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None + cursor.executemany(sql, prepared_parameters) + + affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) + return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) + + def dispatch_execute_script(self, cursor: "PymssqlRawCursor", statement: "SQL") -> "ExecutionResult": + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + if prepared_parameters and len(statements) > 1: + msg = "execute_script with parameters is not supported for multi-statement scripts; use execute or execute_many for parameterized statements" + raise SQLSpecError(msg) + + successful_count = 0 + for stmt in statements: + cursor.execute(stmt, normalize_execute_parameters(prepared_parameters)) + successful_count += 1 + return self.create_execution_result( + cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True + ) + + def begin(self) -> None: + try: + with PymssqlCursor(self.connection) as cursor: + cursor.execute("BEGIN TRANSACTION") + except _pymssql_error_type() as exc: + msg = f"Failed to begin SQL Server transaction: {exc}" + raise SQLSpecError(msg) from exc + + def commit(self) -> None: + try: + self.connection.commit() + except _pymssql_error_type() as exc: + msg = f"Failed to commit SQL Server transaction: {exc}" + raise SQLSpecError(msg) from exc + + def rollback(self) -> None: + try: + self.connection.rollback() + except _pymssql_error_type() as exc: + msg = f"Failed to rollback SQL Server transaction: {exc}" + raise SQLSpecError(msg) from exc + + def with_cursor(self, connection: "PymssqlConnection") -> "PymssqlCursor": + return PymssqlCursor(connection) + + def handle_database_exceptions(self) -> "PymssqlExceptionHandler": + return PymssqlExceptionHandler() + + def create_savepoint(self, name: str) -> None: + self.execute_script(f"SAVE TRANSACTION {name}") + + def release_savepoint(self, name: str) -> None: + return None + + def rollback_to_savepoint(self, name: str) -> None: + self.execute_script(f"ROLLBACK TRANSACTION {name}") + + @property + def data_dictionary(self) -> "PymssqlSyncDataDictionary": + if self._data_dictionary is None: + self._data_dictionary = PymssqlSyncDataDictionary() + return self._data_dictionary + + def collect_rows(self, cursor: "PymssqlRawCursor", fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + column_names = resolve_column_names(cursor.description or None) + return fetched, column_names, len(fetched) + + def resolve_rowcount(self, cursor: "PymssqlRawCursor") -> int: + return resolve_rowcount(cursor) + + def _connection_in_transaction(self) -> bool: + return False + + +def _pymssql_error_type() -> "type[BaseException]": + if pymssql is None: + return _UnavailablePymssqlError + return cast("type[BaseException]", getattr(pymssql, "Error", _UnavailablePymssqlError)) + + +register_driver_profile("pymssql", driver_profile) diff --git a/sqlspec/adapters/pymssql/events/__init__.py b/sqlspec/adapters/pymssql/events/__init__.py new file mode 100644 index 000000000..d49368030 --- /dev/null +++ b/sqlspec/adapters/pymssql/events/__init__.py @@ -0,0 +1,5 @@ +"""pymssql event extension.""" + +from sqlspec.adapters.pymssql.events.store import PymssqlEventQueueStore, PymssqlSyncEventQueueStore + +__all__ = ("PymssqlEventQueueStore", "PymssqlSyncEventQueueStore") diff --git a/sqlspec/adapters/pymssql/events/store.py b/sqlspec/adapters/pymssql/events/store.py new file mode 100644 index 000000000..966ae8cb0 --- /dev/null +++ b/sqlspec/adapters/pymssql/events/store.py @@ -0,0 +1,83 @@ +"""pymssql event queue store with T-SQL-specific DDL.""" + +import re + +from sqlspec.adapters.pymssql.config import PymssqlConfig +from sqlspec.extensions.events import BaseEventQueueStore +from sqlspec.utils.text import split_qualified_identifier + +__all__ = ("PymssqlEventQueueStore", "PymssqlSyncEventQueueStore") + +_NVARCHAR_MAX_THRESHOLD = 4000 +_QUALIFIED_IDENTIFIER_MIN_PARTS = 2 + + +class _PymssqlEventStoreMixin: + """Shared T-SQL DDL hooks for sync and async event queue stores.""" + + __slots__ = () + + def _column_types(self) -> tuple[str, str, str]: + return "NVARCHAR(MAX)", "NVARCHAR(MAX)", "DATETIME2(6)" + + def _string_type(self, length: int) -> str: + if length >= _NVARCHAR_MAX_THRESHOLD: + return "NVARCHAR(MAX)" + return f"NVARCHAR({length})" + + def _integer_type(self) -> str: + return "INT" + + def _timestamp_default(self) -> str: + return "SYSUTCDATETIME()" + + def _wrap_create_statement(self, statement: str, object_type: str) -> str: + if object_type == "table": + match = re.search(r"CREATE TABLE\s+(\S+)", statement, re.IGNORECASE) + if match: + table_name = match.group(1) + return f"IF OBJECT_ID(N'{_object_name(table_name)}', N'U') IS NULL BEGIN {statement}; END" + if object_type == "index": + match = re.search(r"CREATE INDEX\s+(\S+)\s+ON\s+(\S+)", statement, re.IGNORECASE) + if match: + index_name = match.group(1).strip("[]") + table_name = match.group(2) + return ( + "IF NOT EXISTS (SELECT 1 FROM sys.indexes " + f"WHERE name = N'{index_name}' AND object_id = OBJECT_ID(N'{_object_name(table_name)}')) " + f"BEGIN {statement}; END" + ) + return statement + + def _wrap_drop_statement(self, statement: str) -> str: + match = re.search(r"DROP TABLE\s+(\S+)", statement, re.IGNORECASE) + if match: + table_name = match.group(1) + return f"IF OBJECT_ID(N'{_object_name(table_name)}', N'U') IS NOT NULL DROP TABLE {table_name};" + return statement + + +class PymssqlEventQueueStore(_PymssqlEventStoreMixin, BaseEventQueueStore[PymssqlConfig]): + """Event queue DDL for pymssql sync configs.""" + + __slots__ = () + + +PymssqlSyncEventQueueStore = PymssqlEventQueueStore + + +def _split_table_name(table_name: str) -> tuple[str, str]: + parts = split_qualified_identifier(table_name, quote_chars='"') + if len(parts) < _QUALIFIED_IDENTIFIER_MIN_PARTS: + return "dbo", parts[0] if parts else table_name + schema_name = ".".join(parts[:-1]) + return schema_name or "dbo", parts[-1] + + +def _object_name(table_name: str) -> str: + schema_name, bare_table_name = _split_table_name(table_name) + return f"{_quote_bracket_identifier(schema_name)}.{_quote_bracket_identifier(bare_table_name)}" + + +def _quote_bracket_identifier(identifier: str) -> str: + return f"[{identifier.replace(']', ']]')}]" diff --git a/sqlspec/adapters/pymssql/litestar/__init__.py b/sqlspec/adapters/pymssql/litestar/__init__.py new file mode 100644 index 000000000..92a19a0f4 --- /dev/null +++ b/sqlspec/adapters/pymssql/litestar/__init__.py @@ -0,0 +1,5 @@ +"""pymssql Litestar extension.""" + +from sqlspec.adapters.pymssql.litestar.store import PymssqlStore + +__all__ = ("PymssqlStore",) diff --git a/sqlspec/adapters/pymssql/litestar/store.py b/sqlspec/adapters/pymssql/litestar/store.py new file mode 100644 index 000000000..425d8c6b0 --- /dev/null +++ b/sqlspec/adapters/pymssql/litestar/store.py @@ -0,0 +1,255 @@ +"""pymssql Litestar Store implementation.""" + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.adapters.pymssql.config import PymssqlConfig + +__all__ = ("PymssqlStore",) + + +class PymssqlStore(BaseSQLSpecStore["PymssqlConfig"]): + """SQL Server-backed session store using pymssql sync sessions.""" + + __slots__ = () + + def __init__(self, config: "PymssqlConfig") -> None: + super().__init__(config) + + def _get_create_table_sql(self) -> str: + """Get SQL Server CREATE TABLE SQL with idempotent guards.""" + return f""" + IF NOT EXISTS ( + SELECT 1 + FROM sys.tables + WHERE name = N'{self._table_name}' + AND schema_id = SCHEMA_ID(N'dbo') + ) + BEGIN + CREATE TABLE {self._table_name} ( + session_id NVARCHAR(255) PRIMARY KEY, + data VARBINARY(MAX) NOT NULL, + expires_at DATETIME2(6) NULL, + created_at DATETIME2(6) NOT NULL DEFAULT SYSUTCDATETIME(), + updated_at DATETIME2(6) NOT NULL DEFAULT SYSUTCDATETIME() + ); + + CREATE INDEX IX_{self._table_name}_expires_at + ON {self._table_name}(expires_at) + WHERE expires_at IS NOT NULL; + END; + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get SQL Server DROP TABLE statements.""" + return [f"IF OBJECT_ID(N'dbo.{self._table_name}', N'U') IS NOT NULL DROP TABLE dbo.{self._table_name};"] + + def _create_table(self) -> None: + with self._config.provide_session() as driver: + driver.execute_script(self._get_create_table_sql()) + driver.commit() + self._log_table_created() + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > SYSUTCDATETIME()) + """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + finally: + cursor.close() + + if row is None: + return None + + expires_at = _normalize_utc(_row_value(row, "expires_at", 1)) + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_cursor = conn.cursor() + try: + update_cursor.execute( + f""" + UPDATE {self._table_name} + SET expires_at = %s, updated_at = SYSUTCDATETIME() + WHERE session_id = %s + """, + (new_expires_at, key), + ) + finally: + update_cursor.close() + conn.commit() + + return _coerce_bytes(_row_value(row, "data", 0)) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key.""" + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + sql = f""" + MERGE INTO {self._table_name} AS target + USING (SELECT %s AS session_id, %s AS data, %s AS expires_at) AS src + ON target.session_id = src.session_id + WHEN MATCHED THEN + UPDATE SET + data = src.data, + expires_at = src.expires_at, + updated_at = SYSUTCDATETIME() + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at) + VALUES (src.session_id, src.data, src.expires_at); + """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key, data, expires_at)) + finally: + cursor.close() + conn.commit() + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value.""" + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f"DELETE FROM {self._table_name} WHERE session_id = %s", (key,)) + finally: + cursor.close() + conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key.""" + await async_(self._delete)(key) + + def _delete_all(self) -> None: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f"TRUNCATE TABLE {self._table_name}") + finally: + cursor.close() + conn.commit() + self._log_delete_all() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + sql = f""" + SELECT 1 + FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > SYSUTCDATETIME()) + """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + return cursor.fetchone() is not None + finally: + cursor.close() + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired.""" + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(f"SELECT expires_at FROM {self._table_name} WHERE session_id = %s", (key,)) + row = cursor.fetchone() + finally: + cursor.close() + + if row is None: + return None + expires_at = _normalize_utc(_row_value(row, "expires_at", 0)) + if expires_at is None: + return None + remaining = expires_at - datetime.now(timezone.utc) + return max(0, int(remaining.total_seconds())) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires.""" + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + sql = f""" + DELETE FROM {self._table_name} + WHERE expires_at IS NOT NULL + AND expires_at < SYSUTCDATETIME() + """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql) + count = int(getattr(cursor, "rowcount", 0) or 0) + finally: + cursor.close() + conn.commit() + if count > 0: + self._log_delete_expired(count) + return count + + async def delete_expired(self) -> int: + """Delete all expired sessions.""" + return await async_(self._delete_expired)() + + +def _row_value(row: object, key: str, index: int) -> Any: + """Return a value from dict-like or sequence-like driver rows.""" + if isinstance(row, dict): + if key in row: + return row[key] + upper_key = key.upper() + if upper_key in row: + return row[upper_key] + return None + if isinstance(row, (list, tuple)) and len(row) > index: + return row[index] + return getattr(row, key, None) + + +def _normalize_utc(value: Any) -> "datetime | None": + if value is None: + return None + if not isinstance(value, datetime): + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _coerce_bytes(value: Any) -> bytes: + if value is None: + return b"" + if isinstance(value, bytes): + return value + if isinstance(value, bytearray): + return bytes(value) + if isinstance(value, memoryview): + return value.tobytes() + if isinstance(value, str): + return value.encode("utf-8") + return bytes(value) diff --git a/sqlspec/adapters/pymssql/migrations.py b/sqlspec/adapters/pymssql/migrations.py new file mode 100644 index 000000000..28e5f6417 --- /dev/null +++ b/sqlspec/adapters/pymssql/migrations.py @@ -0,0 +1,166 @@ +"""pymssql-specific migration tracker.""" + +import logging +import os +from contextlib import suppress +from typing import TYPE_CHECKING + +from sqlspec.builder import CreateTable, sql +from sqlspec.migrations.tracker import SyncMigrationTracker +from sqlspec.migrations.version import parse_version +from sqlspec.observability import resolve_db_system +from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.text import split_qualified_identifier + +if TYPE_CHECKING: + from sqlspec.driver import SyncDriverAdapterBase + +__all__ = ("PymssqlSyncMigrationTracker",) + +logger = get_logger("sqlspec.migrations.pymssql") +_QUALIFIED_IDENTIFIER_MIN_PARTS = 2 + + +class PymssqlMigrationTrackerMixin: + """T-SQL-specific migration table DDL and schema maintenance.""" + + __slots__ = () + + version_table: str + + def _get_create_table_sql(self) -> CreateTable: + """Return T-SQL-compatible migration tracking table DDL.""" + return ( + sql + .create_table(self.version_table) + .column("version_num", "NVARCHAR(32)", primary_key=True) + .column("version_type", "NVARCHAR(16)") + .column("execution_sequence", "INT") + .column("description", "NVARCHAR(MAX)") + .column("applied_at", "DATETIME2(6)", default="SYSUTCDATETIME()", not_null=True) + .column("execution_time_ms", "INT") + .column("checksum", "NVARCHAR(64)") + .column("applied_by", "NVARCHAR(255)") + .column("replaces", "NVARCHAR(MAX)") + ) + + def _get_idempotent_create_table_sql_text(self) -> str: + """Wrap CREATE TABLE in a T-SQL sys.tables existence probe.""" + schema_name, table_name = _split_schema_table(self.version_table) + create_sql = self._get_create_table_sql_text().rstrip().rstrip(";") + return ( + "IF NOT EXISTS (SELECT 1 FROM sys.tables " + f"WHERE name = '{_escape_sql_literal(table_name)}' " + f"AND schema_id = SCHEMA_ID('{_escape_sql_literal(schema_name)}')) " + f"BEGIN {create_sql}; END;" + ) + + def _get_create_table_sql_text(self) -> str: + """Render CREATE TABLE text without routing SQL Server types through sqlglot.""" + column_lines: list[str] = [] + for column_def in self._get_create_table_sql().columns: + default_clause = f" DEFAULT {column_def.default}" if column_def.default else "" + not_null_clause = " NOT NULL" if column_def.not_null else "" + primary_key_clause = " PRIMARY KEY" if column_def.primary_key else "" + column_lines.append( + f" {column_def.name} {column_def.dtype}{primary_key_clause}{default_clause}{not_null_clause}" + ) + return f"CREATE TABLE {self.version_table} (\n" + ",\n".join(column_lines) + "\n)" + + def _get_existing_columns_sql(self) -> str: + """Return T-SQL query text for migration tracking table columns.""" + schema_name, table_name = _split_schema_table(self.version_table) + return f""" + SELECT c.name AS column_name + FROM sys.columns c + INNER JOIN sys.tables t ON c.object_id = t.object_id + WHERE t.name = '{_escape_sql_literal(table_name)}' + AND t.schema_id = SCHEMA_ID('{_escape_sql_literal(schema_name)}') + """ + + def _get_add_column_sql_text(self, column_name: str) -> str | None: + """Return T-SQL ALTER TABLE text for a missing migration column.""" + target_create = self._get_create_table_sql() + column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None) + if column_def is None: + return None + default_clause = f" DEFAULT {column_def.default}" if column_def.default else "" + nullable_clause = " NOT NULL" if column_def.not_null else " NULL" + return f"ALTER TABLE {self.version_table} ADD {column_def.name} {column_def.dtype}{default_clause}{nullable_clause};" + + +class PymssqlSyncMigrationTracker(PymssqlMigrationTrackerMixin, SyncMigrationTracker): + """T-SQL sync migration tracker.""" + + def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None: + """Create the migration tracking table if it does not exist.""" + driver.execute_script(self._get_idempotent_create_table_sql_text()) + driver.commit() + self._migrate_schema_if_needed(driver) + + def record_migration( + self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str + ) -> None: + """Record a successfully applied migration with T-SQL-compatible metadata.""" + parsed_version = parse_version(version) + version_type = parsed_version.type.value + result = driver.execute(self._get_next_execution_sequence_sql()) + next_sequence = result.get_data()[0]["next_seq"] if result.data else 1 + driver.execute( + self._get_record_migration_sql( + version, + version_type, + next_sequence, + description, + execution_time_ms, + checksum, + os.environ.get("USER", "unknown"), + ) + ) + driver.commit() + + def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None: + """Check and add missing tracking table columns through SQL Server catalog views.""" + try: + rows = driver.select(self._get_existing_columns_sql()) + existing_columns = {str(row["column_name"]).lower() for row in rows if row.get("column_name") is not None} + missing_columns = self._detect_missing_columns(existing_columns) + if not missing_columns: + return + for column_name in sorted(missing_columns): + self._add_column(driver, column_name) + driver.commit() + except Exception as exc: + with suppress(Exception): + driver.rollback() + log_with_context( + logger, + logging.ERROR, + "migration.track", + db_system=resolve_db_system(type(driver).__name__), + table=self.version_table, + operation="schema_check", + status="failed", + error_type=type(exc).__name__, + ) + + def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None: + """Add a single missing migration tracking column.""" + add_column_sql = self._get_add_column_sql_text(column_name) + if add_column_sql is None: + return + driver.execute_script(add_column_sql) + + +def _escape_sql_literal(value: str) -> str: + """Escape a string for inclusion in a T-SQL string literal.""" + return value.replace("'", "''") + + +def _split_schema_table(table_name: str) -> tuple[str, str]: + """Split a schema-qualified table name into schema and table parts.""" + parts = split_qualified_identifier(table_name, quote_chars='"') + if len(parts) < _QUALIFIED_IDENTIFIER_MIN_PARTS: + return "dbo", parts[0] if parts else table_name + schema_name = ".".join(parts[:-1]) + return schema_name or "dbo", parts[-1] diff --git a/sqlspec/adapters/pymssql/pool.py b/sqlspec/adapters/pymssql/pool.py new file mode 100644 index 000000000..3e8396071 --- /dev/null +++ b/sqlspec/adapters/pymssql/pool.py @@ -0,0 +1,175 @@ +"""pymssql database configuration with thread-local connections.""" + +import contextlib +import logging +import threading +import time +import uuid +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, cast + +from sqlspec.adapters.pymssql._typing import PYMSSQL_MODULE, PymssqlConnection +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.utils.logging import POOL_LOGGER_NAME, get_logger, log_with_context + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + +__all__ = ("PymssqlConnectionPool",) + + +logger = get_logger(POOL_LOGGER_NAME) +_ADAPTER_NAME = "pymssql" +pymssql: Any = PYMSSQL_MODULE + + +class PymssqlConnectionPool: + """Thread-local connection manager for pymssql.""" + + __slots__ = ( + "_connection_parameters", + "_health_check_interval", + "_on_connection_create", + "_pool_id", + "_recycle_seconds", + "_thread_local", + ) + + def __init__( + self, + connection_parameters: "dict[str, Any]", + recycle_seconds: int = 86400, + health_check_interval: float = 30.0, + on_connection_create: "Callable[[PymssqlConnection], None] | None" = None, + ) -> None: + """Initialize the thread-local connection manager. + + Args: + connection_parameters: pymssql connection parameters + recycle_seconds: Connection recycle time in seconds (default 24h) + health_check_interval: Seconds of idle time before running health check + on_connection_create: Callback executed when connection is created + """ + self._connection_parameters = connection_parameters + self._thread_local = threading.local() + self._recycle_seconds = recycle_seconds + self._health_check_interval = health_check_interval + self._on_connection_create = on_connection_create + self._pool_id = str(uuid.uuid4())[:8] + + @property + def _database_name(self) -> str: + """Get sanitized database name for logging.""" + return str(self._connection_parameters.get("database", "unknown")) + + def _create_connection(self) -> PymssqlConnection: + if pymssql is None: + msg = "pymssql is not installed. Install SQLSpec with the 'pymssql' extra to use this adapter." + raise ImproperConfigurationError(msg) + connection = pymssql.connect(**self._connection_parameters) + + # Call user-provided callback after connection creation + if self._on_connection_create is not None: + self._on_connection_create(connection) + + return cast("PymssqlConnection", connection) + + def _is_connection_alive(self, connection: PymssqlConnection) -> bool: + try: + cursor = connection.cursor() + try: + cursor.execute("SELECT 1") + cursor.fetchone() + finally: + cursor.close() + except Exception: + return False + return True + + def _get_thread_connection(self) -> PymssqlConnection: + thread_state = self._thread_local.__dict__ + if "connection" not in thread_state: + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() + return cast("PymssqlConnection", self._thread_local.connection) + + if self._recycle_seconds > 0 and time.time() - self._thread_local.created_at > self._recycle_seconds: + log_with_context( + logger, + logging.DEBUG, + "pool.connection.recycle", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + database=self._database_name, + recycle_seconds=self._recycle_seconds, + reason="exceeded_recycle_time", + ) + with contextlib.suppress(Exception): + self._thread_local.connection.close() + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + self._thread_local.last_used = time.time() + return cast("PymssqlConnection", self._thread_local.connection) + + idle_time = time.time() - thread_state.get("last_used", 0) + if idle_time > self._health_check_interval and not self._is_connection_alive(self._thread_local.connection): + log_with_context( + logger, + logging.DEBUG, + "pool.connection.recycle", + adapter=_ADAPTER_NAME, + pool_id=self._pool_id, + database=self._database_name, + idle_seconds=round(idle_time, 1), + reason="failed_health_check", + ) + with contextlib.suppress(Exception): + self._thread_local.connection.close() + self._thread_local.connection = self._create_connection() + self._thread_local.created_at = time.time() + + self._thread_local.last_used = time.time() + return cast("PymssqlConnection", self._thread_local.connection) + + def _close_thread_connection(self) -> None: + thread_state = self._thread_local.__dict__ + if "connection" in thread_state: + with contextlib.suppress(Exception): + self._thread_local.connection.close() + del self._thread_local.connection + if "created_at" in thread_state: + del self._thread_local.created_at + if "last_used" in thread_state: + del self._thread_local.last_used + + @contextmanager + def get_connection(self) -> "Generator[PymssqlConnection, None, None]": + """Get a thread-local connection.""" + connection = self._get_thread_connection() + try: + yield connection + except Exception: + with contextlib.suppress(Exception): + self._close_thread_connection() + raise + + def close(self) -> None: + self._close_thread_connection() + + def acquire(self) -> PymssqlConnection: + return self._get_thread_connection() + + def release(self, connection: PymssqlConnection) -> None: + _ = connection + + def size(self) -> int: + try: + _ = self._thread_local.connection + except AttributeError: + return 0 + else: + return 1 + + def checked_out(self) -> int: + return 0 diff --git a/sqlspec/adapters/pymssql/type_converter.py b/sqlspec/adapters/pymssql/type_converter.py new file mode 100644 index 000000000..3e15b0545 --- /dev/null +++ b/sqlspec/adapters/pymssql/type_converter.py @@ -0,0 +1,92 @@ +"""Type converters for pymssql parameter binding.""" + +from typing import TYPE_CHECKING, Any, Final, cast +from uuid import UUID + +from sqlspec.utils.module_loader import ensure_pyarrow +from sqlspec.utils.serializers import from_json, to_json + +if TYPE_CHECKING: + from collections.abc import Callable + + import pyarrow as pa + +__all__ = ("PymssqlTypeConverter", "mssql_type_to_arrow") + +_MSSQL_ARROW_TYPE_SPECS: Final[dict[str, tuple[str, tuple[Any, ...], dict[str, Any]]]] = { + "bit": ("bool_", (), {}), + "tinyint": ("uint8", (), {}), + "smallint": ("int16", (), {}), + "int": ("int32", (), {}), + "bigint": ("int64", (), {}), + "float": ("float64", (), {}), + "real": ("float32", (), {}), + "smallmoney": ("decimal128", (10, 4), {}), + "money": ("decimal128", (19, 4), {}), + "date": ("date32", (), {}), + "datetime": ("timestamp", ("ms",), {}), + "datetime2": ("timestamp", ("us",), {}), + "smalldatetime": ("timestamp", ("s",), {}), + "datetimeoffset": ("timestamp", ("us",), {"tz": "UTC"}), + "uniqueidentifier": ("string", (), {}), + "xml": ("string", (), {}), + "image": ("binary", (), {}), + "binary": ("binary", (), {}), + "varbinary": ("binary", (), {}), + "timestamp": ("binary", (), {}), + "rowversion": ("binary", (), {}), + "char": ("string", (), {}), + "varchar": ("string", (), {}), + "nchar": ("string", (), {}), + "nvarchar": ("string", (), {}), + "text": ("string", (), {}), + "ntext": ("string", (), {}), +} + + +class PymssqlTypeConverter: + """Utility converter for explicit pymssql value coercion. + + The driver pipeline builds its internal coercions directly from feature + flags. This class is a public utility for callers that need per-value bind + or result coercion outside the driver execution path. + """ + + __slots__ = ("_json_deserializer", "_json_serializer") + + def __init__( + self, json_serializer: "Callable[[Any], str]" = to_json, json_deserializer: "Callable[[str], Any]" = from_json + ) -> None: + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + def coerce_bind_value(self, value: "Any") -> "Any": + """Coerce Python values before pymssql parameter binding.""" + if isinstance(value, (dict, list)): + return self._json_serializer(value) + if isinstance(value, UUID): + return value + return value + + def coerce_read_value(self, value: "Any") -> "Any": + """Coerce pymssql result values after fetching.""" + return value + + +def mssql_type_to_arrow(sql_type: str, *, precision: int | None = None, scale: int | None = None) -> "pa.DataType": + """Resolve a T-SQL type name to an Arrow data type.""" + normalized_type = sql_type.lower().split("(", 1)[0].strip() + if normalized_type in {"decimal", "numeric"} and precision is not None and scale is not None: + return _arrow_type("decimal128", (precision, scale)) + spec = _MSSQL_ARROW_TYPE_SPECS.get(normalized_type) + if spec is None: + return _arrow_type("string") + name, args, kwargs = spec + return _arrow_type(name, args, kwargs) + + +def _arrow_type(name: str, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None) -> "pa.DataType": + ensure_pyarrow() + import pyarrow as pa + + return cast("pa.DataType", getattr(pa, name)(*args, **(kwargs or {}))) diff --git a/tests/integration/adapters/contracts/_cases.py b/tests/integration/adapters/contracts/_cases.py index 97d553842..a8abc780c 100644 --- a/tests/integration/adapters/contracts/_cases.py +++ b/tests/integration/adapters/contracts/_cases.py @@ -662,6 +662,20 @@ class DriverCaseContext: integration_status="deferred", reason="No active integration fixture exists for mssql_python.", ), + DriverCase( + "pymssql-sync", + "", + "pymssql", + "tsql", + "sync", + integration_status="deferred", + reason="No active SQL Server fixture exists for pymssql.", + supports_execute_many=True, + supports_migrations=True, + supports_pooling=True, + supports_connection_hook=True, + supports_data_dictionary=True, + ), DriverCase( "spanner-sync", "", diff --git a/tests/unit/adapters/test_pymssql/__init__.py b/tests/unit/adapters/test_pymssql/__init__.py new file mode 100644 index 000000000..4d6a28a19 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/__init__.py @@ -0,0 +1 @@ +"""Unit tests for the pymssql adapter.""" diff --git a/tests/unit/adapters/test_pymssql/conftest.py b/tests/unit/adapters/test_pymssql/conftest.py new file mode 100644 index 000000000..e0fbaff81 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/conftest.py @@ -0,0 +1,92 @@ +"""Shared pymssql adapter test doubles.""" + +from typing import Any + + +class FakePymssqlError(Exception): + """Base fake pymssql exception.""" + + +class FakePymssqlOperationalError(FakePymssqlError): + """Fake operational error.""" + + +class FakePymssqlIntegrityError(FakePymssqlError): + """Fake integrity error.""" + + +class FakeCursor: + """Minimal DB-API cursor for pymssql unit tests.""" + + def __init__( + self, rows: "list[Any] | None" = None, description: "list[tuple[str, ...]] | None" = None, rowcount: int = -1 + ) -> None: + self.rows = rows or [] + self.description = description + self.rowcount = rowcount + self.closed = False + self.calls: list[tuple[str, Any]] = [] + self.many_calls: list[tuple[str, Any]] = [] + + def execute(self, sql: str, parameters: Any = None) -> None: + self.calls.append((sql, parameters)) + + def executemany(self, sql: str, parameters: Any = None) -> None: + self.many_calls.append((sql, parameters)) + + def fetchall(self) -> "list[Any]": + return self.rows + + def fetchone(self) -> Any: + return self.rows[0] if self.rows else None + + def close(self) -> None: + self.closed = True + + +class FakeConnection: + """Minimal pymssql connection for config, pool, and driver tests.""" + + def __init__(self, cursor: "FakeCursor | None" = None) -> None: + self.cursor_obj = cursor or FakeCursor() + self.closed = False + self.commits = 0 + self.rollbacks = 0 + self.autocommit_values: list[bool] = [] + + def cursor(self, *args: Any, **kwargs: Any) -> FakeCursor: + self.cursor_args = args + self.cursor_kwargs = kwargs + return self.cursor_obj + + def commit(self) -> None: + self.commits += 1 + + def rollback(self) -> None: + self.rollbacks += 1 + + def close(self) -> None: + self.closed = True + + def autocommit(self, value: bool) -> None: + self.autocommit_values.append(value) + + +class FakePymssqlModule: + """Patch target that behaves like the pymssql module surface used by SQLSpec.""" + + Error = FakePymssqlError + OperationalError = FakePymssqlOperationalError + IntegrityError = FakePymssqlIntegrityError + DatabaseError = FakePymssqlError + DataError = FakePymssqlError + InterfaceError = FakePymssqlError + ProgrammingError = FakePymssqlError + + def __init__(self, connection: "FakeConnection | None" = None) -> None: + self.connection = connection or FakeConnection() + self.connect_calls: list[dict[str, Any]] = [] + + def connect(self, **kwargs: Any) -> FakeConnection: + self.connect_calls.append(kwargs) + return self.connection diff --git a/tests/unit/adapters/test_pymssql/test_config.py b/tests/unit/adapters/test_pymssql/test_config.py new file mode 100644 index 000000000..d365aa50a --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_config.py @@ -0,0 +1,96 @@ +"""pymssql configuration tests.""" + +from typing import get_type_hints + +from sqlspec.adapters.pymssql.config import PymssqlConfig, PymssqlConnectionParams +from sqlspec.adapters.pymssql.driver import PymssqlDriver +from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool + + +def test_connection_params_cover_common_pymssql_keywords() -> None: + """Typed connection params should cover the common pymssql connect kwargs.""" + annotations = get_type_hints(PymssqlConnectionParams, include_extras=True) + + expected_keys = { + "server", + "host", + "user", + "password", + "database", + "port", + "timeout", + "login_timeout", + "charset", + "as_dict", + "appname", + "conn_properties", + "autocommit", + "tds_version", + "pool_recycle_seconds", + "health_check_interval", + } + + assert expected_keys <= set(annotations) + + +def test_config_defaults_server_port_and_features() -> None: + """PymssqlConfig should normalize connection defaults and driver features.""" + config = PymssqlConfig(connection_config={}, driver_features={"enable_events": True}) + + assert config.connection_config["server"] == "localhost" + assert config.connection_config["port"] == 1433 + assert config.driver_type is PymssqlDriver + assert config.supports_transactional_ddl is True + assert config.supports_native_arrow_export is False + assert config.driver_features["enable_events"] is True + + +def test_config_create_pool_splits_pool_options_and_hook() -> None: + """Pool lifecycle options should not be forwarded to pymssql.connect.""" + seen: list[object] = [] + config = PymssqlConfig( + connection_config={ + "server": "sql.example.test", + "user": "sa", + "password": "secret", + "database": "app", + "pool_recycle_seconds": 5, + "health_check_interval": 0.25, + }, + driver_features={"on_connection_create": seen.append}, + ) + + pool = config.create_pool() + + assert isinstance(pool, PymssqlConnectionPool) + assert pool._connection_parameters == { + "server": "sql.example.test", + "user": "sa", + "password": "secret", + "database": "app", + "port": 1433, + } + assert pool._recycle_seconds == 5 + assert pool._health_check_interval == 0.25 + assert "on_connection_create" not in config.driver_features + + +def test_config_close_pool_clears_connection_instance() -> None: + """Closing the config pool should clear the stored pool reference.""" + config = PymssqlConfig(connection_instance=PymssqlConnectionPool({})) + + config._close_pool() + + assert config.connection_instance is None + + +def test_signature_namespace_exposes_public_adapter_types() -> None: + """Config signature namespaces should include public pymssql types.""" + config = PymssqlConfig(connection_config={"server": "localhost"}) + + namespace = config.get_signature_namespace() + + assert namespace["PymssqlConfig"] is PymssqlConfig + assert namespace["PymssqlConnectionParams"] is PymssqlConnectionParams + assert namespace["PymssqlConnectionPool"] is PymssqlConnectionPool + assert namespace["PymssqlDriver"] is PymssqlDriver diff --git a/tests/unit/adapters/test_pymssql/test_core.py b/tests/unit/adapters/test_pymssql/test_core.py new file mode 100644 index 000000000..81435a2d2 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_core.py @@ -0,0 +1,88 @@ +"""pymssql core helper tests.""" + +from typing import Any + +import pytest + +from sqlspec.core import SQL, ParameterStyle +from sqlspec.exceptions import DatabaseConnectionError, UniqueViolationError + + +def test_profile_uses_tsql_and_pyformat_execution() -> None: + """The pymssql profile should compile T-SQL to pyformat placeholders.""" + from sqlspec.adapters.pymssql.core import default_statement_config, driver_profile + + parameter_config = default_statement_config.parameter_config + + assert default_statement_config.dialect == "tsql" + assert driver_profile.default_execution_style is ParameterStyle.POSITIONAL_PYFORMAT + assert ParameterStyle.POSITIONAL_PYFORMAT in driver_profile.supported_execution_styles + assert ParameterStyle.NAMED_PYFORMAT in driver_profile.supported_execution_styles + assert parameter_config.default_execution_parameter_style is ParameterStyle.POSITIONAL_PYFORMAT + assert ParameterStyle.POSITIONAL_PYFORMAT in parameter_config.supported_execution_parameter_styles + assert ParameterStyle.NAMED_PYFORMAT in parameter_config.supported_execution_parameter_styles + + +def test_statement_config_compiles_qmark_input_to_percent_s() -> None: + """Qmark input should execute as positional pyformat for pymssql.""" + from sqlspec.adapters.pymssql.core import default_statement_config + + statement = SQL("SELECT * FROM dbo.users WHERE id = ?", 3, statement_config=default_statement_config) + + compiled_sql, parameters = statement.compile() + + assert "WHERE id = %s" in compiled_sql + assert parameters in ([3], (3,)) + + +def test_statement_config_preserves_named_pyformat_input() -> None: + """Named pyformat input should remain a mapping for pymssql.""" + from sqlspec.adapters.pymssql.core import default_statement_config + + statement = SQL( + "SELECT * FROM dbo.users WHERE id = %(user_id)s", {"user_id": 3}, statement_config=default_statement_config + ) + + compiled_sql, parameters = statement.compile() + + assert "WHERE id = %(user_id)s" in compiled_sql + assert parameters == {"user_id": 3} + + +def test_format_identifier_and_insert_statement_use_tsql_identifiers() -> None: + """Generated DML helpers should quote T-SQL identifiers and use %s placeholders.""" + from sqlspec.adapters.pymssql.core import build_insert_statement, format_identifier + + assert format_identifier("dbo.users") == "[dbo].[users]" + assert format_identifier("[sales].[order]]items]") == "[sales].[order]]items]" + assert build_insert_statement("dbo.users", ["id", "display_name"]) == ( + "INSERT INTO [dbo].[users] ([id], [display_name]) VALUES (%s, %s)" + ) + + +@pytest.mark.parametrize( + ("message", "expected_type"), + [ + ("Violation of UNIQUE KEY constraint (2627)", UniqueViolationError), + ("Cannot open database requested by the login (4060)", DatabaseConnectionError), + ], +) +def test_create_mapped_exception_maps_tsql_error_numbers(message: str, expected_type: type[Exception]) -> None: + """SQL Server error numbers should map to SQLSpec exceptions.""" + from sqlspec.adapters.pymssql.core import create_mapped_exception + + exc = create_mapped_exception(Exception(message)) + + assert isinstance(exc, expected_type) + assert "SQL Server error" in str(exc) + + +def test_normalize_execute_many_parameters_requires_payload() -> None: + """execute_many should reject missing batch parameters before hitting pymssql.""" + from sqlspec.adapters.pymssql.core import normalize_execute_many_parameters + + with pytest.raises(ValueError, match="execute_many requires parameters"): + normalize_execute_many_parameters([]) + + rows: list[tuple[Any, ...]] = [(1,), (2,)] + assert normalize_execute_many_parameters(rows) is rows diff --git a/tests/unit/adapters/test_pymssql/test_driver.py b/tests/unit/adapters/test_pymssql/test_driver.py new file mode 100644 index 000000000..179881975 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_driver.py @@ -0,0 +1,117 @@ +"""pymssql driver tests.""" + +import pytest + +from sqlspec.core import SQL +from sqlspec.exceptions import SQLSpecError, UniqueViolationError +from tests.unit.adapters.test_pymssql.conftest import ( + FakeConnection, + FakeCursor, + FakePymssqlIntegrityError, + FakePymssqlModule, +) + + +def test_dispatch_execute_select_compiles_to_pyformat_and_collects_rows() -> None: + """SELECT dispatch should execute pyformat SQL and return fetched rows.""" + from sqlspec.adapters.pymssql.core import default_statement_config + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + cursor = FakeCursor(rows=[(1, "Ada")], description=[("id",), ("name",)]) + driver = PymssqlDriver(FakeConnection(cursor), statement_config=default_statement_config) + statement = SQL("SELECT id, name FROM dbo.users WHERE id = ?", 1, statement_config=default_statement_config) + + result = driver.dispatch_execute(cursor, statement) + + assert cursor.calls == [("SELECT id, name FROM dbo.users WHERE id = %s", (1,))] + assert result.selected_data == [(1, "Ada")] + assert result.column_names == ["id", "name"] + assert result.data_row_count == 1 + + +def test_dispatch_execute_many_uses_executemany_and_rowcount() -> None: + """execute_many dispatch should forward batch parameters to pymssql.""" + from sqlspec.adapters.pymssql.core import default_statement_config + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + cursor = FakeCursor(rowcount=2) + driver = PymssqlDriver(FakeConnection(cursor), statement_config=default_statement_config) + statement = SQL( + "INSERT INTO dbo.users (id) VALUES (?)", [(1,), (2,)], statement_config=default_statement_config, is_many=True + ) + + result = driver.dispatch_execute_many(cursor, statement) + + assert cursor.many_calls == [("INSERT INTO dbo.users (id) VALUES (%s)", [(1,), (2,)])] + assert result.rowcount_override == 2 + assert result.is_many_result is True + + +def test_transaction_methods_use_tsql_begin_and_connection_commit_rollback() -> None: + """Transaction operations should use pymssql-compatible calls.""" + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + cursor = FakeCursor() + connection = FakeConnection(cursor) + driver = PymssqlDriver(connection) + + driver.begin() + driver.commit() + driver.rollback() + + assert cursor.calls == [("BEGIN TRANSACTION", None)] + assert connection.commits == 1 + assert connection.rollbacks == 1 + + +def test_exception_handler_maps_pymssql_errors(monkeypatch: pytest.MonkeyPatch) -> None: + """pymssql exception handlers should surface mapped SQLSpec exceptions.""" + import sqlspec.adapters.pymssql.driver as driver_module + from sqlspec.adapters.pymssql.driver import PymssqlExceptionHandler + + monkeypatch.setattr(driver_module, "pymssql", FakePymssqlModule()) + handler = PymssqlExceptionHandler() + + handled = handler._handle_exception( + FakePymssqlIntegrityError, FakePymssqlIntegrityError("Violation of UNIQUE KEY constraint (2627)") + ) + + assert handled is True + assert isinstance(handler.pending_exception, UniqueViolationError) + + +def test_commit_wraps_driver_errors(monkeypatch: pytest.MonkeyPatch) -> None: + """Commit failures should be wrapped in SQLSpecError.""" + import sqlspec.adapters.pymssql.driver as driver_module + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + class FailingConnection(FakeConnection): + def commit(self) -> None: + raise FakePymssqlIntegrityError("commit failed") + + monkeypatch.setattr(driver_module, "pymssql", FakePymssqlModule()) + driver = PymssqlDriver(FailingConnection()) + + with pytest.raises(SQLSpecError, match="Failed to commit SQL Server transaction"): + driver.commit() + + +def test_collect_rows_returns_column_names() -> None: + """The direct row collection hook should match SyncDriverAdapterBase expectations.""" + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + cursor = FakeCursor(description=[("id",), ("name",)]) + driver = PymssqlDriver(FakeConnection(cursor)) + + rows, column_names, row_count = driver.collect_rows(cursor, [(1, "Ada")]) + + assert rows == [(1, "Ada")] + assert column_names == ["id", "name"] + assert row_count == 1 + + +def test_connection_in_transaction_is_false_without_supported_state() -> None: + """pymssql does not expose a reliable transaction-state flag.""" + from sqlspec.adapters.pymssql.driver import PymssqlDriver + + assert PymssqlDriver(FakeConnection())._connection_in_transaction() is False diff --git a/tests/unit/adapters/test_pymssql/test_extensions.py b/tests/unit/adapters/test_pymssql/test_extensions.py new file mode 100644 index 000000000..db3e49640 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_extensions.py @@ -0,0 +1,70 @@ +"""pymssql extension package tests.""" + +import pytest + +from sqlspec.adapters.pymssql.config import PymssqlConfig + + +def test_event_store_uses_tsql_column_types_and_idempotent_wrappers() -> None: + """Event queue DDL should use SQL Server types and object-existence guards.""" + from sqlspec.adapters.pymssql.events.store import PymssqlEventQueueStore + + store = PymssqlEventQueueStore(PymssqlConfig(extension_config={"events": {"queue_table": "event_queue"}})) + + assert store._column_types() == ("NVARCHAR(MAX)", "NVARCHAR(MAX)", "DATETIME2(6)") + assert store._timestamp_default() == "SYSUTCDATETIME()" + assert "OBJECT_ID" in store._wrap_create_statement("CREATE TABLE event_queue (id INT)", "table") + assert "sys.indexes" in store._wrap_create_statement("CREATE INDEX idx_events ON event_queue (channel)", "index") + + +def test_litestar_store_ddl_is_tsql_idempotent() -> None: + """Litestar store DDL should be SQL Server-specific and idempotent.""" + from sqlspec.adapters.pymssql.litestar.store import PymssqlStore + + store = PymssqlStore(PymssqlConfig(extension_config={"litestar": {"session_table": "litestar_session"}})) + ddl = store._get_create_table_sql() + + assert "IF NOT EXISTS" in ddl + assert "CREATE TABLE litestar_session" in ddl + assert "VARBINARY(MAX)" in ddl + assert "SYSUTCDATETIME()" in ddl + assert "%s" not in ddl + + +@pytest.mark.anyio +async def test_litestar_store_async_methods_bridge_sync_operations(monkeypatch: pytest.MonkeyPatch) -> None: + """The async Litestar interface should bridge to sync methods through async_.""" + from sqlspec.adapters.pymssql.litestar.store import PymssqlStore + + calls: list[str] = [] + monkeypatch.setattr(PymssqlStore, "_create_table", lambda self: calls.append("create")) + store = PymssqlStore(PymssqlConfig(extension_config={"litestar": {"session_table": "litestar_session"}})) + + await store.create_table() + + assert calls == ["create"] + + +def test_adk_store_ddl_uses_tsql_tables_and_json_fallback() -> None: + """ADK DDL should use T-SQL table shape and NVARCHAR JSON fallback by default.""" + from sqlspec.adapters.pymssql.adk.store import PymssqlADKStore + + store = PymssqlADKStore(PymssqlConfig(extension_config={"adk": {}})) + + sessions_ddl = store._get_create_sessions_table_sql() + events_ddl = store._get_create_events_table_sql() + + assert "CREATE TABLE" in sessions_ddl + assert "NVARCHAR(MAX)" in sessions_ddl + assert "SYSUTCDATETIME()" in sessions_ddl + assert "event_data" in events_ddl + assert "DATETIME2(6)" in events_ddl + + +def test_adk_store_can_force_native_json_column_type() -> None: + """ADK config should allow native SQL Server JSON columns when requested.""" + from sqlspec.adapters.pymssql.adk.store import PymssqlADKStore + + store = PymssqlADKStore(PymssqlConfig(extension_config={"adk": {"native_json": True}})) + + assert "state JSON NOT NULL" in store._get_create_sessions_table_sql() diff --git a/tests/unit/adapters/test_pymssql/test_pool.py b/tests/unit/adapters/test_pymssql/test_pool.py new file mode 100644 index 000000000..cb1253651 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_pool.py @@ -0,0 +1,72 @@ +"""pymssql pool tests.""" + +from typing import Any + +from tests.unit.adapters.test_pymssql.conftest import FakeConnection, FakePymssqlModule + + +def test_pool_connects_with_config_and_runs_hook(monkeypatch) -> None: + """The pool should create pymssql connections lazily and call the hook.""" + import sqlspec.adapters.pymssql.pool as pool_module + from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool + + connection = FakeConnection() + fake_module = FakePymssqlModule(connection) + seen: list[FakeConnection] = [] + monkeypatch.setattr(pool_module, "pymssql", fake_module) + + pool = PymssqlConnectionPool( + {"server": "sql.example.test", "user": "sa"}, + recycle_seconds=0, + health_check_interval=999.0, + on_connection_create=seen.append, + ) + + acquired = pool.acquire() + + assert acquired is connection + assert fake_module.connect_calls == [{"server": "sql.example.test", "user": "sa"}] + assert seen == [connection] + assert pool.size() == 1 + + +def test_pool_recycles_failed_health_check(monkeypatch) -> None: + """A failed idle health check should close and replace the thread-local connection.""" + import sqlspec.adapters.pymssql.pool as pool_module + from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool + + first = FakeConnection() + second = FakeConnection() + connections = [first, second] + + class SequencedModule(FakePymssqlModule): + def connect(self, **kwargs: Any) -> FakeConnection: + self.connect_calls.append(kwargs) + return connections.pop(0) + + fake_module = SequencedModule() + monkeypatch.setattr(pool_module, "pymssql", fake_module) + monkeypatch.setattr(PymssqlConnectionPool, "_is_connection_alive", lambda *_: False) + + pool = PymssqlConnectionPool({"server": "sql.example.test"}, health_check_interval=-1.0) + + assert pool.acquire() is first + assert pool.acquire() is second + assert first.closed is True + assert len(fake_module.connect_calls) == 2 + + +def test_pool_close_removes_thread_local_connection(monkeypatch) -> None: + """close() should close and forget the current thread's connection.""" + import sqlspec.adapters.pymssql.pool as pool_module + from sqlspec.adapters.pymssql.pool import PymssqlConnectionPool + + connection = FakeConnection() + monkeypatch.setattr(pool_module, "pymssql", FakePymssqlModule(connection)) + pool = PymssqlConnectionPool({"server": "sql.example.test"}) + + assert pool.acquire() is connection + pool.close() + + assert connection.closed is True + assert pool.size() == 0 diff --git a/tests/unit/adapters/test_pymssql/test_wiring.py b/tests/unit/adapters/test_pymssql/test_wiring.py new file mode 100644 index 000000000..2fb34cc85 --- /dev/null +++ b/tests/unit/adapters/test_pymssql/test_wiring.py @@ -0,0 +1,34 @@ +"""pymssql contract and docs wiring tests.""" + +from pathlib import Path + +from tests.integration.adapters.contracts._cases import get_driver_case + + +def test_deferred_driver_case_is_registered_with_sync_metadata() -> None: + """pymssql should be visible to the contract registry even before a live fixture exists.""" + case = get_driver_case("pymssql-sync") + + assert case.adapter == "pymssql" + assert case.dialect == "tsql" + assert case.mode == "sync" + assert case.integration_status == "deferred" + assert case.supports_migrations is True + assert case.supports_pooling is True + assert case.supports_connection_hook is True + assert case.supports_execute_many is True + assert case.supports_arrow is False + + +def test_reference_docs_include_pymssql_page_and_index_entry() -> None: + """The adapter reference should include pymssql.""" + docs_root = Path("docs/reference/adapters") + + assert (docs_root / "pymssql.rst").is_file() + index = (docs_root / "index.rst").read_text() + page = (docs_root / "pymssql.rst").read_text() + + assert ":link: pymssql" in index + assert "pymssql" in index.split(".. toctree::", 1)[1] + assert "PymssqlConfig" in page + assert "PymssqlDriver" in page From 9781ee9d70bd7ea2ef2bed22d3b59abab6d97591 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 2 Jul 2026 02:11:07 +0000 Subject: [PATCH 2/4] fix(pymssql): satisfy adapter lint checks --- sqlspec/adapters/pymssql/adk/store.py | 2 +- sqlspec/adapters/pymssql/config.py | 3 --- tests/unit/adapters/test_pymssql/test_core.py | 12 ++++++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sqlspec/adapters/pymssql/adk/store.py b/sqlspec/adapters/pymssql/adk/store.py index c091592d1..d137fdb9b 100644 --- a/sqlspec/adapters/pymssql/adk/store.py +++ b/sqlspec/adapters/pymssql/adk/store.py @@ -443,7 +443,7 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object inserted = 0 with self._config.provide_connection() as conn, PymssqlCursor(conn) as cursor: for entry in entries: - params = ( + params: tuple[Any, ...] = ( entry["event_id"], entry["id"], entry["session_id"], diff --git a/sqlspec/adapters/pymssql/config.py b/sqlspec/adapters/pymssql/config.py index 9c01e80c5..f99050a06 100644 --- a/sqlspec/adapters/pymssql/config.py +++ b/sqlspec/adapters/pymssql/config.py @@ -53,9 +53,6 @@ class PymssqlConnectionParams(TypedDict): class PymssqlPoolParams(PymssqlConnectionParams): """pymssql pool parameters.""" - pool_recycle_seconds: NotRequired[int] - health_check_interval: NotRequired[float] - class PymssqlDriverFeatures(TypedDict): """pymssql driver feature flags. diff --git a/tests/unit/adapters/test_pymssql/test_core.py b/tests/unit/adapters/test_pymssql/test_core.py index 81435a2d2..5f53a2acf 100644 --- a/tests/unit/adapters/test_pymssql/test_core.py +++ b/tests/unit/adapters/test_pymssql/test_core.py @@ -16,11 +16,15 @@ def test_profile_uses_tsql_and_pyformat_execution() -> None: assert default_statement_config.dialect == "tsql" assert driver_profile.default_execution_style is ParameterStyle.POSITIONAL_PYFORMAT - assert ParameterStyle.POSITIONAL_PYFORMAT in driver_profile.supported_execution_styles - assert ParameterStyle.NAMED_PYFORMAT in driver_profile.supported_execution_styles + supported_execution_styles = driver_profile.supported_execution_styles + assert supported_execution_styles is not None + assert ParameterStyle.POSITIONAL_PYFORMAT in supported_execution_styles + assert ParameterStyle.NAMED_PYFORMAT in supported_execution_styles assert parameter_config.default_execution_parameter_style is ParameterStyle.POSITIONAL_PYFORMAT - assert ParameterStyle.POSITIONAL_PYFORMAT in parameter_config.supported_execution_parameter_styles - assert ParameterStyle.NAMED_PYFORMAT in parameter_config.supported_execution_parameter_styles + supported_execution_parameter_styles = parameter_config.supported_execution_parameter_styles + assert supported_execution_parameter_styles is not None + assert ParameterStyle.POSITIONAL_PYFORMAT in supported_execution_parameter_styles + assert ParameterStyle.NAMED_PYFORMAT in supported_execution_parameter_styles def test_statement_config_compiles_qmark_input_to_percent_s() -> None: From d31163cef62dcb59b9698cde5b0edb2539d933f4 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 2 Jul 2026 02:20:16 +0000 Subject: [PATCH 3/4] fix(pymssql): align deferred lifecycle metadata --- tests/integration/adapters/contracts/_cases.py | 2 -- tests/unit/adapters/test_pymssql/test_wiring.py | 5 +++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration/adapters/contracts/_cases.py b/tests/integration/adapters/contracts/_cases.py index a8abc780c..c1ecd723f 100644 --- a/tests/integration/adapters/contracts/_cases.py +++ b/tests/integration/adapters/contracts/_cases.py @@ -672,8 +672,6 @@ class DriverCaseContext: reason="No active SQL Server fixture exists for pymssql.", supports_execute_many=True, supports_migrations=True, - supports_pooling=True, - supports_connection_hook=True, supports_data_dictionary=True, ), DriverCase( diff --git a/tests/unit/adapters/test_pymssql/test_wiring.py b/tests/unit/adapters/test_pymssql/test_wiring.py index 2fb34cc85..b12332938 100644 --- a/tests/unit/adapters/test_pymssql/test_wiring.py +++ b/tests/unit/adapters/test_pymssql/test_wiring.py @@ -14,9 +14,10 @@ def test_deferred_driver_case_is_registered_with_sync_metadata() -> None: assert case.mode == "sync" assert case.integration_status == "deferred" assert case.supports_migrations is True - assert case.supports_pooling is True - assert case.supports_connection_hook is True + assert case.supports_pooling is False + assert case.supports_connection_hook is False assert case.supports_execute_many is True + assert case.supports_data_dictionary is True assert case.supports_arrow is False From dc4f5e1ccf60a9f496c24b1b28adc7fa454bbb1e Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 2 Jul 2026 03:16:17 +0000 Subject: [PATCH 4/4] fix(pymssql): type custom coercion map for mypyc --- sqlspec/adapters/pymssql/core.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sqlspec/adapters/pymssql/core.py b/sqlspec/adapters/pymssql/core.py index 2ecf47add..a66a3ac88 100644 --- a/sqlspec/adapters/pymssql/core.py +++ b/sqlspec/adapters/pymssql/core.py @@ -1,7 +1,7 @@ """pymssql adapter compiled helpers.""" import re -from collections.abc import Sized +from collections.abc import Callable, Sized from typing import TYPE_CHECKING, Any, Final from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile @@ -24,7 +24,7 @@ from sqlspec.utils.type_guards import has_rowcount if TYPE_CHECKING: - from collections.abc import Callable, Mapping, Sequence + from collections.abc import Mapping, Sequence from logging import Logger __all__ = ( @@ -98,6 +98,13 @@ def _bool_to_int(value: bool) -> int: return int(value) +def _build_pymssql_custom_type_coercions() -> dict[type, Callable[[Any], Any]]: + """Return custom type coercions for pymssql.""" + coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int} + coercions.update(build_uuid_coercions()) + return coercions + + def build_profile() -> "DriverParameterProfile": """Create the pymssql driver parameter profile.""" return DriverParameterProfile( @@ -112,7 +119,7 @@ def build_profile() -> "DriverParameterProfile": allow_mixed_parameter_styles=False, preserve_original_params_for_many=False, json_serializer_strategy="helper", - custom_type_coercions={bool: _bool_to_int, **build_uuid_coercions()}, + custom_type_coercions=_build_pymssql_custom_type_coercions(), default_dialect="tsql", )