From fa0aca2075c250d6b511038963b9320b49028054 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Sun, 29 Mar 2026 21:25:46 +0000 Subject: [PATCH 1/8] =?UTF-8?q?=E2=9C=A8=20feat(sqlserver):=20Add=20mssql-?= =?UTF-8?q?python=20backend?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/setup_env.sh | 5 + .../workflows/integration-tests-sqlserver.yml | 45 ++- CONTRIBUTING.md | 22 ++ README.md | 67 +++- .../sqlserver/sqlserver_connections.py | 361 ++++++++++++------ .../sqlserver/sqlserver_credentials.py | 19 +- dev_requirements.txt | 1 + setup.py | 13 +- test.env.sample | 1 + tests/conftest.py | 2 + .../test_sqlserver_connection_manager.py | 153 +++++++- 11 files changed, 560 insertions(+), 129 deletions(-) diff --git a/.devcontainer/setup_env.sh b/.devcontainer/setup_env.sh index 5e6fa9352..50d6a5e0c 100644 --- a/.devcontainer/setup_env.sh +++ b/.devcontainer/setup_env.sh @@ -1,5 +1,10 @@ +set -eu + cp test.env.sample test.env +sudo apt-get update +sudo apt-get install -y libltdl7 libkrb5-3 libgssapi-krb5-2 + docker compose build docker compose up -d diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index 3f6a8d976..57ca204a6 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -11,7 +11,7 @@ on: # yamllint disable-line rule:truthy - master - v* schedule: - - cron: '0 22 * * 0' + - cron: "0 22 * * 0" jobs: integration-tests-sql-server: @@ -21,7 +21,8 @@ jobs: python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] msodbc_version: ["17", "18"] sqlserver_version: ["2017", "2019", "2022"] - collation: ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] + collation: + ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}:CI-${{ matrix.python_version }}-msodbc${{ matrix.msodbc_version }} @@ -29,7 +30,7 @@ jobs: sqlserver: image: ghcr.io/${{ github.repository }}:server-${{ matrix.sqlserver_version }} env: - ACCEPT_EULA: 'Y' + ACCEPT_EULA: "Y" SA_PASSWORD: 5atyaNadella DBT_TEST_USER_1: DBT_TEST_USER_1 DBT_TEST_USER_2: DBT_TEST_USER_2 @@ -50,4 +51,40 @@ jobs: DBT_TEST_USER_1: DBT_TEST_USER_1 DBT_TEST_USER_2: DBT_TEST_USER_2 DBT_TEST_USER_3: DBT_TEST_USER_3 - SQLSERVER_TEST_DRIVER: 'ODBC Driver ${{ matrix.msodbc_version }} for SQL Server' + SQLSERVER_TEST_DRIVER: "ODBC Driver ${{ matrix.msodbc_version }} for SQL Server" + + integration-tests-sql-server-mssql-python: + name: mssql-python + runs-on: ubuntu-latest + permissions: + contents: read + packages: read + container: + image: ghcr.io/${{ github.repository }}:CI-3.13-msodbc18 + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + services: + sqlserver: + image: ghcr.io/${{ github.repository }}:server-2022 + env: + ACCEPT_EULA: "Y" + SA_PASSWORD: 5atyaNadella + DBT_TEST_USER_1: DBT_TEST_USER_1 + DBT_TEST_USER_2: DBT_TEST_USER_2 + DBT_TEST_USER_3: DBT_TEST_USER_3 + COLLATION: SQL_Latin1_General_CP1_CS_AS + steps: + - uses: actions/checkout@v4 + + - name: Install dependencies + run: pip install -r dev_requirements.txt + + - name: Run functional tests with mssql-python + run: pytest -ra -v tests/functional --profile "ci_sql_server" + env: + DBT_TEST_USER_1: DBT_TEST_USER_1 + DBT_TEST_USER_2: DBT_TEST_USER_2 + DBT_TEST_USER_3: DBT_TEST_USER_3 + SQLSERVER_TEST_DRIVER: "ODBC Driver 18 for SQL Server" + SQLSERVER_TEST_USE_MSSQL_PYTHON: "True" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 28bdc43b8..365f23d2b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,15 +32,35 @@ The functional tests require a running SQL Server instance. You can easily spin make server ``` +The default development flow uses the existing ODBC-based path. If you want to develop or test the optional `mssql-python` backend instead, make sure the package is installed in your environment before running tests. + +```shell +pip install mssql-python +``` + +On Debian/Ubuntu-based environments, `mssql-python` may also require these system libraries: + +```shell +sudo apt-get install -y libltdl7 libkrb5-3 libgssapi-krb5-2 +``` + This will use Docker Compose to spin up a local instance of SQL Server. Docker Compose is now bundled with Docker, so make sure to [install the latest version of Docker](https://docs.docker.com/get-docker/). Next, tell our tests how they should connect to the local instance by creating a file called `test.env` in the root of the project. You can use the provided `test.env.sample` as a base and if you started the server with `make server`, then this matches the instance running on your local machine. +If you are testing the optional `mssql-python` backend, also enable its profile flag in `test.env` so the adapter selects that implementation instead of the legacy driver-based one. + ```shell cp test.env.sample test.env ``` +When using the optional `mssql-python` backend, update `test.env` with: + +```shell +SQLSERVER_TEST_USE_MSSQL_PYTHON=True +``` + You can tweak the contents of this file to test against a different database. Note that we need 3 users to be able to run tests related to the grants. @@ -57,6 +77,8 @@ make unit make functional ``` +This remains the documented test procedure for both connection backends. When the `mssql-python` flag is enabled, run the same commands after installing `mssql-python` and setting `SQLSERVER_TEST_USE_MSSQL_PYTHON=True` in `test.env`. + ## CI/CD We use Docker images that have all the things we need to test the adapter in the CI/CD workflows. diff --git a/README.md b/README.md index 80641670e..be88afdcd 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,12 @@ Join us on the [dbt Slack](https://getdbt.slack.com/archives/CMRMDDQ9W) to ask q ## Installation +By default this adapter uses the Microsoft ODBC driver. + This adapter requires the Microsoft ODBC driver to be installed: [Windows](https://docs.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16#download-for-windows) | [macOS](https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/install-microsoft-odbc-driver-sql-server-macos?view=sql-server-ver16) | -[Linux](https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server?view=sql-server-ver16) +[Linux](https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-sql-server?view=sql-server-ver16)
Debian/Ubuntu

@@ -45,6 +47,18 @@ Latest pre-release: ![GitHub tag (latest SemVer pre-release)](https://img.shield pip install -U --pre dbt-sqlserver ``` +### Optional: `mssql-python` backend + +This adapter can also use the `mssql-python` driver behind a feature flag. + +Install it explicitly when you want to use that backend: + +```shell +pip install -U mssql-python +``` + +When this backend is enabled, the adapter does not require the ODBC driver-based connection path for that profile. + ## Changelog See [the changelog](CHANGELOG.md) @@ -55,6 +69,8 @@ See [the changelog](CHANGELOG.md) - `dbt_sqlserver_use_default_schema_concat`: *(default: `false`)* Controls schema name generation when a [custom schema](https://docs.getdbt.com/docs/build/custom-schemas) is set on a model. +- `use_mssql_python`: *(default: `false` in the profile)* Switches the connection backend from the legacy ODBC / `pyodbc` path to the `mssql-python` driver for that target profile. + | Flag value | `custom_schema_name` | Result | |---|---|---| | `false` (default, legacy) | *(none)* | `target.schema` | @@ -74,6 +90,55 @@ See [the changelog](CHANGELOG.md) > **Note:** If you want to permanently customise schema generation and avoid any future deprecation of this flag, override the `sqlserver__generate_schema_name` macro directly in your project. +### `mssql-python` feature flag usage + +Enable the backend per target in your `profiles.yml`: + +```yaml +your_profile: + target: dev + outputs: + dev: + type: sqlserver + host: your-server + port: 1433 + database: your-database + schema: dbo + user: your-user + password: your-password + encrypt: true + trust_cert: false + use_mssql_python: true +``` + +#### Notes + +- `use_mssql_python: true` is a profile-level feature flag. +- When enabled, the adapter uses `mssql-python` instead of the legacy `pyodbc` connection path. +- The legacy ODBC driver setting is only needed for profiles that continue to use the ODBC backend. +- If you enable `use_mssql_python`, make sure the `mssql-python` package is installed in the environment running dbt. +- This path is intended to fail fast when required dependencies or unsupported settings are missing. + +#### Testing + +For local development and validation, use the documented adapter workflow from `CONTRIBUTING.md`: + +```shell +make dev +make server +cp test.env.sample test.env +make unit +make functional +``` + +To exercise the `mssql-python` backend in tests, configure the profile or environment so that the target under test sets: + +```yaml +use_mssql_python: true +``` + +If you are testing in the devcontainer, ensure the `mssql-python` package is installed in that environment before running the unit or functional suite. + ## Contributing diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 874883de7..de61b81e7 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -10,6 +10,14 @@ import dbt_common.exceptions import pyodbc +try: + import mssql_python as MSSQL_PYTHON +except ModuleNotFoundError as exc: + MSSQL_PYTHON = None + _MSSQL_PYTHON_IMPORT_ERROR: Optional[ModuleNotFoundError] = exc +else: + _MSSQL_PYTHON_IMPORT_ERROR = None + try: from azure.core.credentials import AccessToken except ModuleNotFoundError: @@ -72,6 +80,14 @@ class AccessToken: # type: ignore[no-redef] "decimal.Decimal": "decimal", } +MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS = { + "cli", + "auto", + "environment", + "serviceprincipal", + "activedirectoryaccesstoken", +} + def _require_azure_identity(authentication: str) -> None: if _AZURE_IDENTITY_IMPORT_ERROR is not None: @@ -84,6 +100,49 @@ def _require_azure_identity(authentication: str) -> None: ) from _AZURE_IDENTITY_IMPORT_ERROR +def _require_mssql_python() -> None: + if _MSSQL_PYTHON_IMPORT_ERROR is not None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `mssql-python` backend was requested, but the optional dependency " + "`mssql-python` is not installed. Install it with `pip install mssql-python` " + "or disable `use_mssql_python` in the profile." + ) from _MSSQL_PYTHON_IMPORT_ERROR + + +def _requires_pyodbc_backend(credentials: SQLServerCredentials) -> bool: + authentication = str(credentials.authentication or "sql").lower().strip() + return authentication in AZURE_AUTH_FUNCTIONS or authentication == "activedirectoryaccesstoken" + + +def _use_mssql_python_backend(credentials: SQLServerCredentials) -> bool: + return bool(getattr(credentials, "use_mssql_python", False)) + + +def _validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: + if not credentials.driver: + raise dbt_common.exceptions.DbtRuntimeError( + "The pyodbc backend requires a SQL Server ODBC driver name " + "in the `driver` profile field." + ) + + +def _validate_mssql_python_requirements(credentials: SQLServerCredentials) -> None: + authentication = str(credentials.authentication or "sql").strip() + authentication_lower = authentication.lower() + + if authentication_lower in MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS: + raise dbt_common.exceptions.DbtRuntimeError( + "Authentication '{}' is currently only supported by the pyodbc backend " + "in this adapter. " + "Disable `use_mssql_python` or use a connection-string-supported " + "authentication mode such as " + "`sql`, `ActiveDirectoryPassword`, `ActiveDirectoryInteractive`, " + "`ActiveDirectoryIntegrated`, " + "`ActiveDirectoryMSI`, `ActiveDirectoryDeviceCode`, " + "or `ActiveDirectoryDefault`.".format(authentication) + ) + + def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: """ Convert bytes to a Microsoft windows byte string. @@ -316,7 +375,7 @@ def bool_to_connection_string_arg(key: str, value: bool) -> str: out : str The connection string argument. """ - return f'{key}={"Yes" if value else "No"}' + return f"{key}={'Yes' if value else 'No'}" def byte_array_to_datetime(value: bytes) -> dt.datetime: @@ -353,6 +412,128 @@ def byte_array_to_datetime(value: bytes) -> dt.datetime: ) +def _build_server_arg(credentials: SQLServerCredentials) -> str: + if "\\" in credentials.host: + # If there is a backslash \ in the host name, the host is a + # SQL Server named instance. In this case then port number has to be omitted. + return credentials.host + return f"{credentials.host},{credentials.port}" + + +def _build_common_connection_string_parts( + credentials: SQLServerCredentials, +) -> list[str]: + con_str = [f"SERVER={_build_server_arg(credentials)}"] + con_str.append(f"Database={credentials.database}") + + assert credentials.authentication is not None + + if ( + "ActiveDirectory" in credentials.authentication + and credentials.authentication != "ActiveDirectoryAccessToken" + ): + con_str.append(f"Authentication={credentials.authentication}") + + if credentials.authentication == "ActiveDirectoryPassword": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") + if credentials.authentication == "ActiveDirectoryServicePrincipal": + con_str.append(f"UID={{{credentials.client_id}}}") + con_str.append(f"PWD={{{credentials.client_secret}}}") + elif credentials.authentication == "ActiveDirectoryInteractive": + con_str.append(f"UID={{{credentials.UID}}}") + + elif credentials.windows_login: + con_str.append("trusted_connection=Yes") + elif credentials.authentication == "sql": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") + + assert credentials.encrypt is not None + assert credentials.trust_cert is not None + + con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) + con_str.append(bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert)) + + return con_str + + +def _build_pyodbc_connection_string(credentials: SQLServerCredentials) -> str: + con_str = [f"DRIVER={{{credentials.driver}}}"] + con_str.extend(_build_common_connection_string_parts(credentials)) + con_str.append("Pooling=true") + + if credentials.trace_flag: + con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_ON") + else: + con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF") + + plugin_version = __version__.version + application_name = f"dbt-{credentials.type}/{plugin_version}" + con_str.append(f"APP={application_name}") + + try: + con_str.append("ConnectRetryCount=3") + con_str.append("ConnectRetryInterval=10") + except Exception as e: + logger.debug( + ( + "Retry count should be a integer value. " + "Skipping retries in the connection string." + ), + str(e), + ) + + return ";".join(con_str) + + +def _build_mssql_python_connection_string(credentials: SQLServerCredentials) -> str: + con_str = _build_common_connection_string_parts(credentials) + con_str.append("ConnectRetryCount=3") + con_str.append("ConnectRetryInterval=10") + return ";".join(con_str) + + +def _sanitize_connection_string_for_logging(connection_string: str) -> str: + parts = connection_string.split(";") + sanitized = [] + for part in parts: + if part.lower().startswith("pwd="): + sanitized.append("PWD=***") + else: + sanitized.append(part) + return ";".join(sanitized) + + +def _get_backend_exceptions( + credentials: SQLServerCredentials, +) -> Tuple[Type[Exception], ...]: + if _use_mssql_python_backend(credentials): + _require_mssql_python() + retryable_exceptions = [] + retryable_exceptions.append(getattr(MSSQL_PYTHON, "InternalError", Exception)) + retryable_exceptions.append(getattr(MSSQL_PYTHON, "OperationalError", Exception)) + + if _requires_pyodbc_backend(credentials): + retryable_exceptions.append(getattr(MSSQL_PYTHON, "InterfaceError", Exception)) + + return tuple(retryable_exceptions) + + retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions + pyodbc.InternalError, # not used according to docs, but defined in PEP-249 + pyodbc.OperationalError, + ] + + if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: + retryable_exceptions.append(pyodbc.InterfaceError) + + return tuple(retryable_exceptions) + + +def _is_pyodbc_handle(handle: Any) -> bool: + return hasattr(handle, "add_output_converter") + + class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" @@ -365,7 +546,6 @@ def exception_handler(self, sql): logger.debug("Database error: {}".format(str(e))) try: - # attempt to release the connection self.release() except pyodbc.Error: logger.debug("Failed to release connection!") @@ -373,13 +553,23 @@ def exception_handler(self, sql): raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e except Exception as e: + if _use_mssql_python_backend(self.get_thread_connection().credentials): + if MSSQL_PYTHON is not None and isinstance( + e, getattr(MSSQL_PYTHON, "DatabaseError", tuple()) + ): + logger.debug("Database error: {}".format(str(e))) + + try: + self.release() + except Exception: + logger.debug("Failed to release connection!") + + raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e + logger.debug(f"Error running SQL: {sql}") logger.debug("Rolling back transaction.") self.release() if isinstance(e, dbt_common.exceptions.DbtRuntimeError): - # during a sql query, an internal to dbt exception was raised. - # this sounds a lot like a signal handler and probably has - # useful information, so raise it without modification. raise raise dbt_common.exceptions.DbtRuntimeError(e) @@ -392,101 +582,38 @@ def open(cls, connection: Connection) -> Connection: credentials = cls.get_credentials(connection.credentials) - con_str = [f"DRIVER={{{credentials.driver}}}"] - - if "\\" in credentials.host: - # If there is a backslash \ in the host name, the host is a - # SQL Server named instance. In this case then port number has to be omitted. - con_str.append(f"SERVER={credentials.host}") + if _use_mssql_python_backend(credentials): + _require_mssql_python() + _validate_mssql_python_requirements(credentials) + con_str_concat = _build_mssql_python_connection_string(credentials) else: - con_str.append(f"SERVER={credentials.host},{credentials.port}") + _validate_pyodbc_requirements(credentials) + con_str_concat = _build_pyodbc_connection_string(credentials) - con_str.append(f"Database={credentials.database}") - con_str.append("Pooling=true") - - # Enabling trace flag - if credentials.trace_flag: - con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_ON") - else: - con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF") - - assert credentials.authentication is not None - - # Access token authentication does not additional connection string parameters. - # The access token is passed in the pyodbc attributes. - if ( - "ActiveDirectory" in credentials.authentication - and credentials.authentication != "ActiveDirectoryAccessToken" - ): - con_str.append(f"Authentication={credentials.authentication}") - - if credentials.authentication == "ActiveDirectoryPassword": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") - if credentials.authentication == "ActiveDirectoryServicePrincipal": - con_str.append(f"UID={{{credentials.client_id}}}") - con_str.append(f"PWD={{{credentials.client_secret}}}") - elif credentials.authentication == "ActiveDirectoryInteractive": - con_str.append(f"UID={{{credentials.UID}}}") - - elif credentials.windows_login: - con_str.append("trusted_connection=Yes") - elif credentials.authentication == "sql": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") - - # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15 - assert credentials.encrypt is not None - assert credentials.trust_cert is not None - - con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) - con_str.append( - bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert) - ) - - plugin_version = __version__.version - application_name = f"dbt-{credentials.type}/{plugin_version}" - con_str.append(f"APP={application_name}") - - try: - con_str.append("ConnectRetryCount=3") - con_str.append("ConnectRetryInterval=10") - - except Exception as e: - logger.debug( - ( - "Retry count should be a integer value. " - "Skipping retries in the connection string." - ), - str(e), - ) - - con_str_concat = ";".join(con_str) - - index = [] - for i, elem in enumerate(con_str): - if "pwd=" in elem.lower(): - index.append(i) - - if len(index) != 0: - con_str[index[0]] = "PWD=***" - - con_str_display = ";".join(con_str) - - retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions - pyodbc.InternalError, # not used according to docs, but defined in PEP-249 - pyodbc.OperationalError, - ] - - if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: - # Temporary login/token errors fall into this category when using AAD - retryable_exceptions.append(pyodbc.InterfaceError) + con_str_display = _sanitize_connection_string_for_logging(con_str_concat) + retryable_exceptions = _get_backend_exceptions(credentials) def connect(): logger.debug(f"Using connection string: {con_str_display}") - pyodbc.pooling = True - # pyodbc attributes includes the access token provided by the user if required. + if _use_mssql_python_backend(credentials): + MSSQL_PYTHON.pooling(enabled=False) + handle = MSSQL_PYTHON.connect( + con_str_concat, + autocommit=True, + timeout=credentials.login_timeout, + ) + try: + handle.timeout = credentials.query_timeout + except Exception: + logger.debug( + "The mssql-python connection object does not expose a mutable `timeout` " + "attribute; continuing without setting query timeout on the handle." + ) + logger.debug(f"Connected to db: {credentials.database}") + return handle + + pyodbc.pooling = True attrs_before = get_pyodbc_attrs_before_credentials(credentials) handle = pyodbc.connect( @@ -513,11 +640,9 @@ def cancel(self, connection: Connection): logger.debug("Cancel query") def add_begin_query(self): - # return self.add_query('BEGIN TRANSACTION', auto_begin=False) pass def add_commit_query(self): - # return self.add_query('COMMIT TRANSACTION', auto_begin=False) pass def add_query( @@ -548,7 +673,6 @@ def _execute_query_with_retry( retries. Failure begins a sleep and retry routine. """ try: - # pyodbc does not handle a None type binding! if bindings is None: cursor.execute(sql) else: @@ -558,14 +682,13 @@ def _execute_query_with_retry( ] cursor.execute(sql, bindings) except retryable_exceptions as e: - # Cease retries and fail when limit is hit. if attempt >= retry_limit: raise e fire_event( AdapterEventDebug( message=( - f"Got a retryable error {type(e)}. {retry_limit-attempt} " + f"Got a retryable error {type(e)}. {retry_limit - attempt} " "retries left. Retrying in 1 second.\n" f"Error:\n{e}" ) @@ -603,7 +726,9 @@ def _execute_query_with_retry( fire_event( SQLQuery( - conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info() + conn_name=cast_to_str(connection.name), + sql=log_sql, + node_info=get_node_info(), ) ) @@ -621,9 +746,8 @@ def _execute_query_with_retry( attempt=1, ) - # convert DATETIMEOFFSET binary structures to datetime ojbects - # https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 - connection.handle.add_output_converter(-155, byte_array_to_datetime) + if _is_pyodbc_handle(connection.handle): + connection.handle.add_output_converter(-155, byte_array_to_datetime) fire_event( SQLQueryStatus( @@ -656,18 +780,25 @@ def data_type_code_to_name(cls, type_code: Union[str, str]) -> str: return datatypes[data_type] def execute( - self, sql: str, auto_begin: bool = True, fetch: bool = False, limit: Optional[int] = None + self, + sql: str, + auto_begin: bool = True, + fetch: bool = False, + limit: Optional[int] = None, ) -> Tuple[AdapterResponse, agate.Table]: sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) - response = self.get_response(cursor) - if fetch: - while cursor.description is None: - if not cursor.nextset(): - break - table = self.get_result_from_cursor(cursor, limit) - else: - table = empty_table() - while cursor.nextset(): - pass - return response, table + try: + response = self.get_response(cursor) + if fetch: + while cursor.description is None: + if not cursor.nextset(): + break + table = self.get_result_from_cursor(cursor, limit) + else: + table = empty_table() + while cursor.nextset(): + pass + return response, table + finally: + cursor.close() diff --git a/dbt/adapters/sqlserver/sqlserver_credentials.py b/dbt/adapters/sqlserver/sqlserver_credentials.py index 37bba77ea..4c0811d3d 100644 --- a/dbt/adapters/sqlserver/sqlserver_credentials.py +++ b/dbt/adapters/sqlserver/sqlserver_credentials.py @@ -6,10 +6,10 @@ @dataclass class SQLServerCredentials(Credentials): - driver: str - host: str - database: str - schema: str + driver: Optional[str] = None + host: str = "" + database: str = "" + schema: str = "" UID: Optional[str] = None PWD: Optional[str] = None port: Optional[int] = 1433 @@ -27,6 +27,7 @@ class SQLServerCredentials(Credentials): schema_authorization: Optional[str] = None login_timeout: Optional[int] = 0 query_timeout: Optional[int] = 0 + use_mssql_python: bool = False _ALIASES = { "user": "UID", @@ -41,6 +42,8 @@ class SQLServerCredentials(Credentials): "TrustServerCertificate": "trust_cert", "schema_auth": "schema_authorization", "SQL_ATTR_TRACE": "trace_flag", + "mssql_python": "use_mssql_python", + "use_mssql_python_backend": "use_mssql_python", } @property @@ -54,7 +57,7 @@ def _connection_keys(self): if self.authentication.lower().strip() == "serviceprincipal": self.authentication = "ActiveDirectoryServicePrincipal" - return ( + keys = ( "server", "port", "database", @@ -67,8 +70,14 @@ def _connection_keys(self): "trace_flag", "encrypt", "trust_cert", + "use_mssql_python", ) + if not self.use_mssql_python: + keys = ("driver",) + keys + + return keys + @property def unique_field(self): return self.host diff --git a/dev_requirements.txt b/dev_requirements.txt index 239825e67..7b5301d8b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -20,5 +20,6 @@ tox>=3.13 twine wheel pyodbc +mssql-python azure-identity -e . diff --git a/setup.py b/setup.py index e15e6ccfa..fc1306419 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,13 @@ from setuptools.command.install import install package_name = "dbt-sqlserver" -authors_list = ["Mikael Ene", "Anders Swanson", "Sam Debruyn", "Cor Zuurmond", "Cody Scott"] +authors_list = [ + "Mikael Ene", + "Anders Swanson", + "Sam Debruyn", + "Cor Zuurmond", + "Cody Scott", +] dbt_version = "1.9" description = """A Microsoft SQL Server adapter plugin for dbt""" @@ -70,6 +76,11 @@ def run(self): "dbt-common>=1.0,<2.0", "dbt-adapters>=1.11.0,<2.0", ], + extras_require={ + "mssql-python": [ + "mssql-python>=1.0.0", + ], + }, cmdclass={ "verify": VerifyVersionCommand, }, diff --git a/test.env.sample b/test.env.sample index 09982ccc9..b66c97497 100644 --- a/test.env.sample +++ b/test.env.sample @@ -6,6 +6,7 @@ SQLSERVER_TEST_PORT=1433 SQLSERVER_TEST_DBNAME=TestDB SQLSERVER_TEST_ENCRYPT=True SQLSERVER_TEST_TRUST_CERT=True +SQLSERVER_TEST_USE_MSSQL_PYTHON=False DBT_TEST_USER_1=DBT_TEST_USER_1 DBT_TEST_USER_2=DBT_TEST_USER_2 DBT_TEST_USER_3=DBT_TEST_USER_3 diff --git a/tests/conftest.py b/tests/conftest.py index 540ee3025..376bdd93c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,6 +54,8 @@ def _all_profiles_base(): "driver": os.getenv("SQLSERVER_TEST_DRIVER", "ODBC Driver 18 for SQL Server"), "port": int(os.getenv("SQLSERVER_TEST_PORT", "1433")), "retries": 2, + "use_mssql_python": os.getenv("SQLSERVER_TEST_USE_MSSQL_PYTHON", "False").lower() + == "true", } diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index 2170c58bd..5619d6795 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -1,9 +1,14 @@ +from types import SimpleNamespace +from typing import Any, Dict, List + import pytest from azure.identity import AzureCliCredential +from dbt.adapters.contracts.connection import Connection, ConnectionState from dbt_common.exceptions import DbtRuntimeError from dbt.adapters.sqlserver import sqlserver_connections -from dbt.adapters.sqlserver.sqlserver_connections import ( # byte_array_to_datetime, +from dbt.adapters.sqlserver.sqlserver_connections import ( + SQLServerConnectionManager, bool_to_connection_string_arg, get_pyodbc_attrs_before_credentials, ) @@ -16,13 +21,12 @@ @pytest.fixture def credentials() -> SQLServerCredentials: - credentials = SQLServerCredentials( + return SQLServerCredentials( driver="ODBC Driver 17 for SQL Server", host="fake.sql.sqlserver.net", database="dbt", schema="sqlserver", ) - return credentials def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( @@ -65,3 +69,146 @@ def test_get_pyodbc_attrs_before_cli_auth_requires_azure_identity( ) def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> None: assert bool_to_connection_string_arg(key, value) == expected + + +def test_open_with_mssql_python_feature_flag_requires_optional_dependency( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + credentials.driver = None + credentials.use_mssql_python = True + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + + monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", None, raising=False) + monkeypatch.setattr( + sqlserver_connections, + "_MSSQL_PYTHON_IMPORT_ERROR", + ModuleNotFoundError("No module named 'mssql_python'"), + raising=False, + ) + + with pytest.raises(DbtRuntimeError, match="mssql-python"): + SQLServerConnectionManager.open(connection) + + +def test_open_with_mssql_python_feature_flag_builds_connection_without_odbc_driver( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + credentials.driver = None + credentials.UID = "dbt_user" + credentials.PWD = "super-secret" + credentials.encrypt = True + credentials.trust_cert = True + credentials.login_timeout = 17 + credentials.query_timeout = 23 + credentials.retries = 5 + credentials.use_mssql_python = True + + captured: Dict[str, Any] = {} + pooling_calls: List[Dict[str, Any]] = [] + + class FakeHandle: + def __init__(self): + self.timeout = None + + fake_handle = FakeHandle() + + def fake_connect(connection_string, autocommit, timeout): + captured["connection_string"] = connection_string + captured["autocommit"] = autocommit + captured["timeout"] = timeout + return fake_handle + + def fake_pooling(*, enabled): + pooling_calls.append({"enabled": enabled}) + + fake_module = SimpleNamespace( + connect=fake_connect, + pooling=fake_pooling, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + def fake_retry_connection( + cls, + connection, + connect, + logger, + retry_limit, + retryable_exceptions, + ): + captured["retry_limit"] = retry_limit + captured["retryable_exceptions"] = retryable_exceptions + handle = connect() + connection.handle = handle + connection.state = ConnectionState.OPEN + return connection + + monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + monkeypatch.setattr( + SQLServerConnectionManager, + "retry_connection", + classmethod(fake_retry_connection), + ) + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + opened = SQLServerConnectionManager.open(connection) + + assert opened is connection + assert opened.handle is fake_handle + assert opened.state == ConnectionState.OPEN + + assert captured["autocommit"] is True + assert captured["timeout"] == 17 + assert captured["retry_limit"] == 5 + assert pooling_calls == [{"enabled": False}] + + con_str = captured["connection_string"] + assert "DRIVER=" not in con_str + assert "SERVER=fake.sql.sqlserver.net,1433" in con_str + assert "Database=dbt" in con_str + assert "UID={dbt_user}" in con_str + assert "PWD={super-secret}" in con_str + assert "encrypt=Yes" in con_str + assert "TrustServerCertificate=Yes" in con_str + assert "APP=dbt-sqlserver/" not in con_str + + assert fake_module.OperationalError in captured["retryable_exceptions"] + assert fake_module.InternalError in captured["retryable_exceptions"] + + +def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_aliases( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + credentials.driver = None + credentials.use_mssql_python = True + credentials.authentication = "cli" + + fake_module = SimpleNamespace( + connect=lambda *args, **kwargs: None, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + + with pytest.raises(DbtRuntimeError, match="authentication"): + SQLServerConnectionManager.open(connection) + + +def test_open_with_pyodbc_path_still_requires_driver_when_feature_flag_disabled( + credentials: SQLServerCredentials, +) -> None: + credentials.driver = None + credentials.use_mssql_python = False + + connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) + + with pytest.raises(DbtRuntimeError, match="driver"): + SQLServerConnectionManager.open(connection) From 9b26556bcc6cb85520fbaa87608d04c06f339364 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Thu, 21 May 2026 04:09:41 +0000 Subject: [PATCH 2/8] =?UTF-8?q?=E2=9C=A8=20feat:=20improve=20for=20optiona?= =?UTF-8?q?l=20backends=20lazy=20imported=20and=20dev?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .devcontainer/setup_env.sh | 3 +- README.md | 3 +- .../sqlserver/sqlserver_connections.py | 156 ++++++++++++------ .../adapters/mssql/test_connection_logic.py | 50 +++++- .../test_sqlserver_connection_manager.py | 102 +++++++++++- uv.lock | 82 +++++++-- 6 files changed, 317 insertions(+), 79 deletions(-) diff --git a/.devcontainer/setup_env.sh b/.devcontainer/setup_env.sh index fd1f9cd53..9cdfe06c4 100644 --- a/.devcontainer/setup_env.sh +++ b/.devcontainer/setup_env.sh @@ -15,5 +15,6 @@ pip install uv [ -d .venv ] || uv venv source .venv/bin/activate -uv sync --group dev --extra pyodbc +# Install both backend extras so the devcontainer can exercise either connection path. +uv sync --group dev --extra pyodbc --extra mssql pre-commit install diff --git a/README.md b/README.md index be88afdcd..20829570e 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,7 @@ your_profile: - When enabled, the adapter uses `mssql-python` instead of the legacy `pyodbc` connection path. - The legacy ODBC driver setting is only needed for profiles that continue to use the ODBC backend. - If you enable `use_mssql_python`, make sure the `mssql-python` package is installed in the environment running dbt. +- On Debian/Ubuntu-based environments, `mssql-python` also requires `libltdl7`, `libkrb5-3`, and `libgssapi-krb5-2`. - This path is intended to fail fast when required dependencies or unsupported settings are missing. #### Testing @@ -137,7 +138,7 @@ To exercise the `mssql-python` backend in tests, configure the profile or enviro use_mssql_python: true ``` -If you are testing in the devcontainer, ensure the `mssql-python` package is installed in that environment before running the unit or functional suite. +If you are testing in the devcontainer, the backend prerequisites are installed automatically. Outside the devcontainer, install `mssql-python` and the system libraries above before running the unit or functional suite. diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 23d6fd582..03bdba5d0 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -8,15 +8,23 @@ import agate import dbt_common.exceptions -import pyodbc +from dbt_common.clients.agate_helper import empty_table +from dbt_common.events.contextvars import get_node_info +from dbt_common.events.functions import fire_event +from dbt_common.utils.casting import cast_to_str -try: - import mssql_python as MSSQL_PYTHON -except ModuleNotFoundError as exc: - MSSQL_PYTHON = None - _MSSQL_PYTHON_IMPORT_ERROR: Optional[ModuleNotFoundError] = exc -else: - _MSSQL_PYTHON_IMPORT_ERROR = None +from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus +from dbt.adapters.sql.connections import SQLConnectionManager +from dbt.adapters.sqlserver import __version__ +from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials + +_PYODBC_MODULE: Optional[Any] = None +_PYODBC_IMPORT_ERROR: Optional[ModuleNotFoundError] = None + +_MSSQL_PYTHON_MODULE: Optional[Any] = None +_MSSQL_PYTHON_IMPORT_ERROR: Optional[ModuleNotFoundError] = None try: from azure.core.credentials import AccessToken @@ -46,17 +54,6 @@ class AccessToken: # type: ignore[no-redef] ManagedIdentityCredential = None _AZURE_IDENTITY_IMPORT_ERROR = exc -from dbt_common.clients.agate_helper import empty_table -from dbt_common.events.contextvars import get_node_info -from dbt_common.events.functions import fire_event -from dbt_common.utils.casting import cast_to_str - -from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState -from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus -from dbt.adapters.sql.connections import SQLConnectionManager -from dbt.adapters.sqlserver import __version__ -from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials _TOKEN: Optional[AccessToken] = None AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" @@ -89,18 +86,39 @@ class AccessToken: # type: ignore[no-redef] } -def _require_azure_identity(authentication: str) -> None: - if _AZURE_IDENTITY_IMPORT_ERROR is not None: +def _get_pyodbc() -> Any: + global _PYODBC_MODULE, _PYODBC_IMPORT_ERROR + + if _PYODBC_MODULE is not None: + return _PYODBC_MODULE + + if _PYODBC_IMPORT_ERROR is not None: raise dbt_common.exceptions.DbtRuntimeError( - ( - "Azure authentication '{}' requires the optional " - "dependency 'azure-identity'. Install it with `pip install " - "azure-identity` or use a non-Azure authentication mode." - ).format(authentication) - ) from _AZURE_IDENTITY_IMPORT_ERROR + "The legacy `pyodbc` backend was requested, but the optional dependency " + "`pyodbc` is not installed. Install it with `pip install pyodbc` " + "or enable `use_mssql_python` in the profile." + ) from _PYODBC_IMPORT_ERROR + + try: + import pyodbc as imported_pyodbc + except ModuleNotFoundError as exc: + _PYODBC_IMPORT_ERROR = exc + raise dbt_common.exceptions.DbtRuntimeError( + "The legacy `pyodbc` backend was requested, but the optional dependency " + "`pyodbc` is not installed. Install it with `pip install pyodbc` " + "or enable `use_mssql_python` in the profile." + ) from exc + _PYODBC_MODULE = imported_pyodbc + return _PYODBC_MODULE + + +def _get_mssql_python() -> Any: + global _MSSQL_PYTHON_MODULE, _MSSQL_PYTHON_IMPORT_ERROR + + if _MSSQL_PYTHON_MODULE is not None: + return _MSSQL_PYTHON_MODULE -def _require_mssql_python() -> None: if _MSSQL_PYTHON_IMPORT_ERROR is not None: raise dbt_common.exceptions.DbtRuntimeError( "The `mssql-python` backend was requested, but the optional dependency " @@ -108,6 +126,30 @@ def _require_mssql_python() -> None: "or disable `use_mssql_python` in the profile." ) from _MSSQL_PYTHON_IMPORT_ERROR + try: + import mssql_python as imported_mssql_python + except ModuleNotFoundError as exc: + _MSSQL_PYTHON_IMPORT_ERROR = exc + raise dbt_common.exceptions.DbtRuntimeError( + "The `mssql-python` backend was requested, but the optional dependency " + "`mssql-python` is not installed. Install it with `pip install mssql-python` " + "or disable `use_mssql_python` in the profile." + ) from exc + + _MSSQL_PYTHON_MODULE = imported_mssql_python + return _MSSQL_PYTHON_MODULE + + +def _require_azure_identity(authentication: str) -> None: + if _AZURE_IDENTITY_IMPORT_ERROR is not None: + raise dbt_common.exceptions.DbtRuntimeError( + ( + "Azure authentication '{}' requires the optional " + "dependency 'azure-identity'. Install it with `pip install " + "azure-identity` or use a non-Azure authentication mode." + ).format(authentication) + ) from _AZURE_IDENTITY_IMPORT_ERROR + def _requires_pyodbc_backend(credentials: SQLServerCredentials) -> bool: authentication = str(credentials.authentication or "sql").lower().strip() @@ -509,18 +551,22 @@ def _get_backend_exceptions( credentials: SQLServerCredentials, ) -> Tuple[Type[Exception], ...]: if _use_mssql_python_backend(credentials): - _require_mssql_python() - retryable_exceptions = [] - retryable_exceptions.append(getattr(MSSQL_PYTHON, "InternalError", Exception)) - retryable_exceptions.append(getattr(MSSQL_PYTHON, "OperationalError", Exception)) + mssql_python = _get_mssql_python() + + retryable_exceptions = [ + getattr(mssql_python, "InternalError", Exception), + getattr(mssql_python, "OperationalError", Exception), + ] if _requires_pyodbc_backend(credentials): - retryable_exceptions.append(getattr(MSSQL_PYTHON, "InterfaceError", Exception)) + retryable_exceptions.append(getattr(mssql_python, "InterfaceError", Exception)) return tuple(retryable_exceptions) - retryable_exceptions = [ # https://github.com/mkleehammer/pyodbc/wiki/Exceptions - pyodbc.InternalError, # not used according to docs, but defined in PEP-249 + pyodbc = _get_pyodbc() + + retryable_exceptions = [ + pyodbc.InternalError, pyodbc.OperationalError, ] @@ -542,20 +588,25 @@ def exception_handler(self, sql): try: yield - except pyodbc.DatabaseError as e: - logger.debug("Database error: {}".format(str(e))) + except Exception as e: + credentials = self.get_thread_connection().credentials - try: - self.release() - except pyodbc.Error: - logger.debug("Failed to release connection!") + if not _use_mssql_python_backend(credentials): + pyodbc = _PYODBC_MODULE + if pyodbc is not None and isinstance(e, getattr(pyodbc, "DatabaseError", tuple())): + logger.debug("Database error: {}".format(str(e))) + + try: + self.release() + except Exception: + logger.debug("Failed to release connection!") - raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e + raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e - except Exception as e: - if _use_mssql_python_backend(self.get_thread_connection().credentials): - if MSSQL_PYTHON is not None and isinstance( - e, getattr(MSSQL_PYTHON, "DatabaseError", tuple()) + if _use_mssql_python_backend(credentials): + mssql_python = _MSSQL_PYTHON_MODULE + if mssql_python is not None and isinstance( + e, getattr(mssql_python, "DatabaseError", tuple()) ): logger.debug("Database error: {}".format(str(e))) @@ -583,12 +634,15 @@ def open(cls, connection: Connection) -> Connection: credentials = cls.get_credentials(connection.credentials) if _use_mssql_python_backend(credentials): - _require_mssql_python() + mssql_python = _get_mssql_python() _validate_mssql_python_requirements(credentials) con_str_concat = _build_mssql_python_connection_string(credentials) + pyodbc = None else: + pyodbc = _get_pyodbc() _validate_pyodbc_requirements(credentials) con_str_concat = _build_pyodbc_connection_string(credentials) + mssql_python = None con_str_display = _sanitize_connection_string_for_logging(con_str_concat) retryable_exceptions = _get_backend_exceptions(credentials) @@ -597,8 +651,10 @@ def connect(): logger.debug(f"Using connection string: {con_str_display}") if _use_mssql_python_backend(credentials): - MSSQL_PYTHON.pooling(enabled=False) - handle = MSSQL_PYTHON.connect( + assert mssql_python is not None + + mssql_python.pooling(enabled=False) + handle = mssql_python.connect( con_str_concat, autocommit=True, timeout=credentials.login_timeout, @@ -613,6 +669,8 @@ def connect(): logger.debug(f"Connected to db: {credentials.database}") return handle + assert pyodbc is not None + pyodbc.pooling = True attrs_before = get_pyodbc_attrs_before_credentials(credentials) diff --git a/tests/unit/adapters/mssql/test_connection_logic.py b/tests/unit/adapters/mssql/test_connection_logic.py index 6d772fc76..f2171591f 100644 --- a/tests/unit/adapters/mssql/test_connection_logic.py +++ b/tests/unit/adapters/mssql/test_connection_logic.py @@ -1,7 +1,9 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from dbt.adapters.sqlserver import sqlserver_connections from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials @@ -27,12 +29,22 @@ def test_connection_string_windows_login_with_port(base_credentials): connection.state = "closed" connection.credentials = base_credentials - with patch("dbt.adapters.sqlserver.sqlserver_connections.pyodbc") as mock_pyodbc: - mock_pyodbc.connect.return_value = MagicMock() + fake_pyodbc = SimpleNamespace( + connect=MagicMock(return_value=MagicMock()), + pooling=False, + InternalError=type("InternalError", (Exception,), {}), + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + ) + + with ( + patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), + patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), + ): SQLServerConnectionManager.open(connection) - args, _kwargs = mock_pyodbc.connect.call_args + args, _kwargs = fake_pyodbc.connect.call_args connection_string = args[0] assert "SERVER=servers.database.windows.net,1444" in connection_string @@ -50,12 +62,22 @@ def test_connection_string_standard_login_with_port(base_credentials): connection.state = "closed" connection.credentials = base_credentials - with patch("dbt.adapters.sqlserver.sqlserver_connections.pyodbc") as mock_pyodbc: - mock_pyodbc.connect.return_value = MagicMock() + fake_pyodbc = SimpleNamespace( + connect=MagicMock(return_value=MagicMock()), + pooling=False, + InternalError=type("InternalError", (Exception,), {}), + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + ) + + with ( + patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), + patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), + ): SQLServerConnectionManager.open(connection) - args, _kwargs = mock_pyodbc.connect.call_args + args, _kwargs = fake_pyodbc.connect.call_args connection_string = args[0] assert "SERVER=servers.database.windows.net,1444" in connection_string @@ -71,12 +93,22 @@ def test_connection_string_named_instance_no_port(base_credentials): connection.state = "closed" connection.credentials = base_credentials - with patch("dbt.adapters.sqlserver.sqlserver_connections.pyodbc") as mock_pyodbc: - mock_pyodbc.connect.return_value = MagicMock() + fake_pyodbc = SimpleNamespace( + connect=MagicMock(return_value=MagicMock()), + pooling=False, + InternalError=type("InternalError", (Exception,), {}), + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + ) + + with ( + patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), + patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), + ): SQLServerConnectionManager.open(connection) - args, _kwargs = mock_pyodbc.connect.call_args + args, _kwargs = fake_pyodbc.connect.call_args connection_string = args[0] assert "SERVER=myhost\\instance" in connection_string diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index 5619d6795..e4e8c8434 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -1,11 +1,13 @@ +import builtins +import importlib from types import SimpleNamespace from typing import Any, Dict, List import pytest from azure.identity import AzureCliCredential -from dbt.adapters.contracts.connection import Connection, ConnectionState from dbt_common.exceptions import DbtRuntimeError +from dbt.adapters.contracts.connection import Connection, ConnectionState from dbt.adapters.sqlserver import sqlserver_connections from dbt.adapters.sqlserver.sqlserver_connections import ( SQLServerConnectionManager, @@ -71,6 +73,85 @@ def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> assert bool_to_connection_string_arg(key, value) == expected +def test_adapter_module_import_does_not_import_optional_backends( + monkeypatch: pytest.MonkeyPatch, +) -> None: + original_import = builtins.__import__ + + def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + if name in {"pyodbc", "mssql_python"}: + raise AssertionError(f"unexpected import: {name}") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", guarded_import) + importlib.reload(sqlserver_connections) + + assert sqlserver_connections._PYODBC_MODULE is None + assert sqlserver_connections._MSSQL_PYTHON_MODULE is None + + +def test_get_pyodbc_returns_cached_module(monkeypatch: pytest.MonkeyPatch) -> None: + fake_pyodbc = SimpleNamespace(name="cached-pyodbc") + monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) + monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + + def fail_import(*args, **kwargs): + raise AssertionError("pyodbc import should not run when cached") + + monkeypatch.setattr(builtins, "__import__", fail_import) + + assert sqlserver_connections._get_pyodbc() is fake_pyodbc + assert sqlserver_connections._get_pyodbc() is fake_pyodbc + + +def test_get_mssql_python_returns_cached_module(monkeypatch: pytest.MonkeyPatch) -> None: + fake_mssql_python = SimpleNamespace(name="cached-mssql-python") + monkeypatch.setattr( + sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_mssql_python, raising=False + ) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + + def fail_import(*args, **kwargs): + raise AssertionError("mssql_python import should not run when cached") + + monkeypatch.setattr(builtins, "__import__", fail_import) + + assert sqlserver_connections._get_mssql_python() is fake_mssql_python + assert sqlserver_connections._get_mssql_python() is fake_mssql_python + + +def test_get_pyodbc_raises_only_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", None, raising=False) + monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + original_import = builtins.__import__ + + def missing_pyodbc(name, globals=None, locals=None, fromlist=(), level=0): + if name == "pyodbc": + raise ModuleNotFoundError("No module named 'pyodbc'") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", missing_pyodbc) + + with pytest.raises(DbtRuntimeError, match="pyodbc"): + sqlserver_connections._get_pyodbc() + + +def test_get_mssql_python_raises_only_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + original_import = builtins.__import__ + + def missing_mssql_python(name, globals=None, locals=None, fromlist=(), level=0): + if name == "mssql_python": + raise ModuleNotFoundError("No module named 'mssql_python'") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", missing_mssql_python) + + with pytest.raises(DbtRuntimeError, match="mssql-python"): + sqlserver_connections._get_mssql_python() + + def test_open_with_mssql_python_feature_flag_requires_optional_dependency( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -79,7 +160,7 @@ def test_open_with_mssql_python_feature_flag_requires_optional_dependency( connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) - monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", None, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) monkeypatch.setattr( sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", @@ -145,7 +226,7 @@ def fake_retry_connection( connection.state = ConnectionState.OPEN return connection - monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) monkeypatch.setattr( SQLServerConnectionManager, @@ -193,7 +274,7 @@ def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_al InternalError=type("InternalError", (Exception,), {}), ) - monkeypatch.setattr(sqlserver_connections, "MSSQL_PYTHON", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -204,11 +285,24 @@ def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_al def test_open_with_pyodbc_path_still_requires_driver_when_feature_flag_disabled( credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, ) -> None: credentials.driver = None credentials.use_mssql_python = False + fake_pyodbc = SimpleNamespace( + connect=lambda *args, **kwargs: None, + pooling=False, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) + monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) with pytest.raises(DbtRuntimeError, match="driver"): SQLServerConnectionManager.open(connection) diff --git a/uv.lock b/uv.lock index 4a4b988c2..febf48202 100644 --- a/uv.lock +++ b/uv.lock @@ -483,7 +483,7 @@ wheels = [ [[package]] name = "dbt-core" -version = "1.11.9" +version = "1.10.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "agate" }, @@ -510,9 +510,9 @@ dependencies = [ { name = "sqlparse" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/49/2f36c3a62c4a957ea7672d6566bc7ac1adca81523ca19c0ae5dc74218560/dbt_core-1.11.9.tar.gz", hash = "sha256:8dff914ca4c0d5de93ba8e285b50f007ae4d46f9fe4c845b8ef47ce5ebbc888b", size = 973146, upload-time = "2026-05-06T20:02:37.588Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/f0/6891b867c772416dadc287d852bb58e5a937b1f48008a56eb867d19bedc9/dbt_core-1.10.22.tar.gz", hash = "sha256:78dcda2ec712a356f1b7d9ba82978def56d26bfe5bcaa54855107a9dc55c284f", size = 900881, upload-time = "2026-05-20T10:53:06.034Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/85/e77e8cf3ee9509798e1353c5d1cabbc3fe1c2cc8593054b313407e7508c5/dbt_core-1.11.9-py3-none-any.whl", hash = "sha256:9693d4cf33f99e2ec8cefc6236a7bceed7f212aa35e3b97b83e27b0085ee174c", size = 1061906, upload-time = "2026-05-06T20:02:35.65Z" }, + { url = "https://files.pythonhosted.org/packages/88/97/7d640a719ca22c96e809b521b95f3ddeb78cbd76d64a95eab90632c35ec0/dbt_core-1.10.22-py3-none-any.whl", hash = "sha256:c1bc28223419a205c15143282c104586b07efa6326c83268bee5f2c36d2d1ab4", size = 987258, upload-time = "2026-05-20T10:53:04.169Z" }, ] [[package]] @@ -583,6 +583,9 @@ dependencies = [ azure = [ { name = "azure-identity" }, ] +mssql = [ + { name = "mssql-python" }, +] pyodbc = [ { name = "pyodbc" }, ] @@ -596,8 +599,10 @@ dev = [ { name = "flaky" }, { name = "freezegun" }, { name = "ipdb" }, + { name = "mssql-python" }, { name = "mypy" }, { name = "pre-commit" }, + { name = "pyodbc" }, { name = "pytest" }, { name = "pytest-csv" }, { name = "pytest-dotenv" }, @@ -611,24 +616,27 @@ dev = [ [package.metadata] requires-dist = [ { name = "azure-identity", marker = "extra == 'azure'", specifier = ">=1.12.0" }, - { name = "dbt-adapters", specifier = ">=1.11.0,<2.0" }, - { name = "dbt-common", specifier = ">=1.0,<2.0" }, - { name = "dbt-core", specifier = ">=1.9.0,<2.0" }, + { name = "dbt-adapters", specifier = ">=1.15.2,<2.0" }, + { name = "dbt-common", specifier = ">=1.22.0,<2.0" }, + { name = "dbt-core", specifier = ">=1.10.0,<1.11.0" }, + { name = "mssql-python", marker = "extra == 'mssql'", specifier = ">=1.1.0" }, { name = "pyodbc", marker = "extra == 'pyodbc'", specifier = ">=5.2.0" }, ] -provides-extras = ["azure", "pyodbc"] +provides-extras = ["azure", "pyodbc", "mssql"] [package.metadata.requires-dev] dev = [ { name = "azure-identity", specifier = ">=1.12.0" }, { name = "build" }, { name = "bumpversion" }, - { name = "dbt-tests-adapter", specifier = ">=1.9.0,<2.0" }, + { name = "dbt-tests-adapter", specifier = ">=1.15.0,<2.0" }, { name = "flaky" }, - { name = "freezegun", specifier = "==1.4.0" }, + { name = "freezegun", specifier = ">=1.5.0,<2.0" }, { name = "ipdb" }, + { name = "mssql-python", specifier = ">=1.1.0" }, { name = "mypy", specifier = "==1.11.2" }, { name = "pre-commit" }, + { name = "pyodbc", specifier = ">=5.2.0" }, { name = "pytest" }, { name = "pytest-csv" }, { name = "pytest-dotenv" }, @@ -744,14 +752,14 @@ wheels = [ [[package]] name = "freezegun" -version = "1.4.0" +version = "1.5.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "python-dateutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/73/5decad3abddbe7e1bf4bf98ead1a8345b1cc6fc6ec7e4fa27da81f4e1eee/freezegun-1.4.0.tar.gz", hash = "sha256:10939b0ba0ff5adaecf3b06a5c2f73071d9678e507c5eaedb23c761d56ac774b", size = 31748, upload-time = "2023-12-19T10:46:41.79Z" } +sdist = { url = "https://files.pythonhosted.org/packages/95/dd/23e2f4e357f8fd3bdff613c1fe4466d21bfb00a6177f238079b17f7b1c84/freezegun-1.5.5.tar.gz", hash = "sha256:ac7742a6cc6c25a2c35e9292dfd554b897b517d2dec26891a2e8debf205cb94a", size = 35914, upload-time = "2025-08-09T10:39:08.338Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/ad/72ae71e18011e59b7d129f176ff1a607f4558be4cf5b5d739860a57f9381/freezegun-1.4.0-py3-none-any.whl", hash = "sha256:55e0fc3c84ebf0a96a5aa23ff8b53d70246479e9a68863f1fcac5a3e52f19dd6", size = 17557, upload-time = "2023-12-19T10:46:39.919Z" }, + { url = "https://files.pythonhosted.org/packages/5e/2e/b41d8a1a917d6581fc27a35d05561037b048e47df50f27f8ac9c7e27a710/freezegun-1.5.5-py3-none-any.whl", hash = "sha256:cd557f4a75cf074e84bc374249b9dd491eaeacd61376b9eb3c423282211619d2", size = 19266, upload-time = "2025-08-09T10:39:06.636Z" }, ] [[package]] @@ -1246,6 +1254,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/f2/08ace4142eb281c12701fc3b93a10795e4d4dc7f753911d836675050f886/msgpack-1.1.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d99ef64f349d5ec3293688e91486c5fdb925ed03807f64d98d205d2713c60b46", size = 70868, upload-time = "2025-10-08T09:15:44.959Z" }, ] +[[package]] +name = "mssql-python" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "azure-identity" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/46/10560fd47990859ee3e02970fc8a03c76f88136b81f9dea9b38b3f89749b/mssql_python-1.7.1-cp310-cp310-macosx_15_0_universal2.whl", hash = "sha256:3208a49455bb99ea0dbd15dc18855feb755c4c74257d39a49722eb294220a1be", size = 28109464, upload-time = "2026-05-20T10:46:25.352Z" }, + { url = "https://files.pythonhosted.org/packages/5c/de/9237817b6192f4ffa5cf7f0e6673d22cba38eabc210b444a9ad8cd51adf1/mssql_python-1.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ae667e8a716ce5ed21bdb5b1377bc5889addfc4e93b3c19c2924fb35f942f4aa", size = 25039662, upload-time = "2026-05-20T10:46:29.872Z" }, + { url = "https://files.pythonhosted.org/packages/a5/61/d3f15c2adbc6f2efe093981f703ca5b3d826c22a59fb8742d2a647968a67/mssql_python-1.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1807269903ebb7c61a16a773ccf2f6faae6dd65e181802b28b9b11ea4b04c541", size = 25286139, upload-time = "2026-05-20T10:46:33.668Z" }, + { url = "https://files.pythonhosted.org/packages/28/25/b22db4241c084bf42c601b87dd94cb23c5da1da1c72ef5c0166dfee68003/mssql_python-1.7.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:92bb73fe9dc0ce24a130ea357f2eec393118ebfa469ecfc8b829e7efccbc7794", size = 24970539, upload-time = "2026-05-20T10:46:37.741Z" }, + { url = "https://files.pythonhosted.org/packages/91/e8/6dd2b34273c7d4fc05007a2a8fdc2da4d79870bfd1f7d6b6df3022c34853/mssql_python-1.7.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0b6699a23cb02134c9b74b1d739bbfed311467587eb36a851d7ef83d9542af71", size = 25211307, upload-time = "2026-05-20T10:46:41.52Z" }, + { url = "https://files.pythonhosted.org/packages/ac/9c/d6e5a91e80db6daa875372a553c33bf6a39b036a5e7e277c747caafe4170/mssql_python-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:1b7c6f0d87289aaa848ac4ffed58a912f20b3e5ac797625b2b3ad6b29da420e7", size = 15466082, upload-time = "2026-05-20T10:46:44.986Z" }, + { url = "https://files.pythonhosted.org/packages/fd/d1/435b395c33c85a207357684546a47e29b8d2a47561a3e50bfbcfa4705e80/mssql_python-1.7.1-cp311-cp311-macosx_15_0_universal2.whl", hash = "sha256:9d2f70c39cf3ba449c4288f459a97ed981cbbb690741ff2c923c4692c8927336", size = 28110812, upload-time = "2026-05-20T10:46:48.799Z" }, + { url = "https://files.pythonhosted.org/packages/af/ea/899798a996456988cbe5c00d432e1a10a33873f5db0e6575a1447e6e1472/mssql_python-1.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:521cde6b60b5cf7cb21df78550d9c39a5625f0f9daf083ebd4e49ff40c25a620", size = 25540223, upload-time = "2026-05-20T10:46:52.886Z" }, + { url = "https://files.pythonhosted.org/packages/94/eb/f37ba9771626be8727f39f02c1f313e223a0f8458a82b7daf5ca8c9403df/mssql_python-1.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:43d6953be8439331ae7e5ced93b969a36e9daa34ec54fae94e0efab8b5f797a6", size = 25951778, upload-time = "2026-05-20T10:46:56.477Z" }, + { url = "https://files.pythonhosted.org/packages/cb/cf/b278405e118cd3377ca7406934e02a5bf07a82413813fa2430a198cbe1bb/mssql_python-1.7.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:12415b09920397f0c34e070ffddac5cee263409fa3e995e532f7667ba837c78f", size = 25405073, upload-time = "2026-05-20T10:47:00.299Z" }, + { url = "https://files.pythonhosted.org/packages/4b/0e/979aa615a653b1e80eed1b6a90f4c14fb0e12b163df41bbbc7318d9b4417/mssql_python-1.7.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bcf2e249806c4880c0be0eb341d71ea0908fc4ba9b2892bca39fc73d7b80c6d9", size = 25803920, upload-time = "2026-05-20T10:47:03.912Z" }, + { url = "https://files.pythonhosted.org/packages/e4/53/601fecc8ebfa946ad1eee1bcb96b6850ad9ff8b4ffb24bcc7a055311c579/mssql_python-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:c544700c17d0f475499f685fb1774f00891110606158b9276eb79438d60d92a4", size = 15464300, upload-time = "2026-05-20T10:47:06.954Z" }, + { url = "https://files.pythonhosted.org/packages/2d/77/7382591d5d3324425b32377f5b4a5ebef57fae0d6d3f750cac0fa5dfbd70/mssql_python-1.7.1-cp311-cp311-win_arm64.whl", hash = "sha256:95fe46716f50092014dba582d783759491f298245a572b42374ee5dacfea3e37", size = 18528454, upload-time = "2026-05-20T10:47:09.967Z" }, + { url = "https://files.pythonhosted.org/packages/c8/4c/8a4d582f16faa63fd73676048c0dac381a8c629ae049d47ff6e5f821c0af/mssql_python-1.7.1-cp312-cp312-macosx_15_0_universal2.whl", hash = "sha256:3d4281657a3cac35b8fae031fa8d4b94d19d2e39749158a58af7de6891ae81da", size = 28118715, upload-time = "2026-05-20T10:47:13.791Z" }, + { url = "https://files.pythonhosted.org/packages/e9/56/b9c327625fa6ab524da866683eab80957bc7c57ed0200e5aaad2b0add0b2/mssql_python-1.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:32af320f29d32dc34489650e94d16fad5f0e7affb9ecdc229ef2612a45a48645", size = 26036075, upload-time = "2026-05-20T10:47:17.248Z" }, + { url = "https://files.pythonhosted.org/packages/c6/f4/137e7711db057bd5ab6c9d70c1e2b10385427a7450d4754526bf11926667/mssql_python-1.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:dc1dcde69bac5b90193e37cd7b3a193dfb41cdfeb4460643dd8adc2d312ffefd", size = 26612898, upload-time = "2026-05-20T10:47:21.565Z" }, + { url = "https://files.pythonhosted.org/packages/55/35/3c8d799a24392ccdb903582f1788968f8b19998e514f666d4891af452044/mssql_python-1.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:12a0da62dc7972980586b4636febac6281040b0a7c8ec07bd2983448bfcf586e", size = 25834286, upload-time = "2026-05-20T10:47:25.569Z" }, + { url = "https://files.pythonhosted.org/packages/33/85/ce43969245bc0cf794bdaa5d5425927f4531b17bfe35cf04f9fb7c44a8ea/mssql_python-1.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fa551d6a907614936c25c50e964580cad492b970b1f4ea55802a2420f3e50fb3", size = 26393056, upload-time = "2026-05-20T10:47:29.363Z" }, + { url = "https://files.pythonhosted.org/packages/05/77/30812559f61dc404090793c38aaec07d508598d0f1ceb594d914fecf89aa/mssql_python-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:26fe105bdb73561e2b19011d75fe0832243d440ab200519c1658924a7485379a", size = 15464577, upload-time = "2026-05-20T10:47:32.722Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ad/8a2d579bb4e6613ea5fe391a16d3fff3c1de169c27949df623de7b314327/mssql_python-1.7.1-cp312-cp312-win_arm64.whl", hash = "sha256:45b02a8247062f5bc84b2233df81b3f56a12e711c65159053dc426102d2b459e", size = 18528548, upload-time = "2026-05-20T10:47:36.623Z" }, + { url = "https://files.pythonhosted.org/packages/a0/cb/3fccc5d85fb50bb9691c041dd47a16125003f4131872f6ab4ff918c4e039/mssql_python-1.7.1-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:b823a42226977b26d599f37b265d3fde9a3c26338c1f6849fdd5160428a48886", size = 28118712, upload-time = "2026-05-20T10:47:40.419Z" }, + { url = "https://files.pythonhosted.org/packages/7a/1a/a7a67b21290bc9abc4e44dd1fd3010a9017f0c9076084642772b6ccca2cc/mssql_python-1.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:e27aa83b16564dd2bd93d78531cffa07558b5c721c144b5ce30bd39670a77d3c", size = 26536189, upload-time = "2026-05-20T10:47:44.453Z" }, + { url = "https://files.pythonhosted.org/packages/1b/43/bb574707a4510787a98d1fc4b0b0a959c30677136dfbb9c04bd9235a9e3d/mssql_python-1.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:2199e61c410cc14a1d17738b1d0f4187120abad8c9c7d5c93f71c3b0d54f1379", size = 27277981, upload-time = "2026-05-20T10:47:48.626Z" }, + { url = "https://files.pythonhosted.org/packages/14/9e/c7a32a874d0bb06f41bdbe9aa929c7e922c1af0a642d70114b9b28f196a1/mssql_python-1.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:77d4c1bc4de3b5b95a16ce581de8e1629cb5709078e622821b8cccafaed0bca4", size = 26266809, upload-time = "2026-05-20T10:47:52.438Z" }, + { url = "https://files.pythonhosted.org/packages/7e/21/732aadf20f2ec168af8a04c44b93d8a0dd208926458da1edca5a62f207a4/mssql_python-1.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5d4a2794e692ff598666203e93759182e941f7d7751be370924f94fd97b96bfd", size = 26985887, upload-time = "2026-05-20T10:47:56.9Z" }, + { url = "https://files.pythonhosted.org/packages/5b/a9/cbac3dc17300c17dc2f1c4251fadfed8f337f5dcc8afb270ee847d58a80a/mssql_python-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:31a404e73bd49dd6566ef6fb7deb7a90e5fa2eb60cdb89232c388af9ea83d8fb", size = 15463865, upload-time = "2026-05-20T10:48:00.1Z" }, + { url = "https://files.pythonhosted.org/packages/d5/57/d9f43b228c863df962fe1375ecb9b1a26782214fff77cd78692203aaf473/mssql_python-1.7.1-cp313-cp313-win_arm64.whl", hash = "sha256:89580a585d4309111968c76c5af7bd226c881c786c92f22eaed5586ddacf4aaa", size = 18527852, upload-time = "2026-05-20T10:48:03.29Z" }, + { url = "https://files.pythonhosted.org/packages/89/d9/2c94c1bee748b71cd84d6a2e520890dbd40fb8a25cd0c74f4290dad300cd/mssql_python-1.7.1-cp314-cp314-macosx_15_0_universal2.whl", hash = "sha256:d31a6c815c8a3709fc1f5b59906376b63844c085b5cd0f93ca451be1db82438d", size = 28111978, upload-time = "2026-05-20T10:48:06.789Z" }, + { url = "https://files.pythonhosted.org/packages/d4/df/83be720f21a8e7e3a449012cfbc1b4cc3e306c0c062ad6bb2ce5526e8704/mssql_python-1.7.1-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:a47f76f9927d65e0978ce7683049a2f481fdbbddd6d314ce9253fdc8236ebbae", size = 27038204, upload-time = "2026-05-20T10:48:10.757Z" }, + { url = "https://files.pythonhosted.org/packages/da/24/54a301d1451e61af98f25d60f7f89345b93d7debd8dcfd112998312b844e/mssql_python-1.7.1-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:2d372e9d4bbad8b89d4b0154dbb9a3e6a4b8e473b88449842d75bf426199f008", size = 27945219, upload-time = "2026-05-20T10:48:14.552Z" }, + { url = "https://files.pythonhosted.org/packages/e8/7d/7c02637f3c85536f1694b102a2733277b81e9db9cf96739d62be79ebd7bd/mssql_python-1.7.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b6013180e2b5f252f5f7dc3008c6964482c14e50d263dd55d883f9c93f8cbc22", size = 26703582, upload-time = "2026-05-20T10:48:18.98Z" }, + { url = "https://files.pythonhosted.org/packages/ac/13/f50469590ac8b7ffbbbd9ef062ec75d6926c93df949b05f2dd1d398f778c/mssql_python-1.7.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:77f50558171f01b5d26e83ca55fd6253a241a63f053ea36b76614328be3bafb1", size = 27579144, upload-time = "2026-05-20T10:48:23.133Z" }, + { url = "https://files.pythonhosted.org/packages/56/bf/b582b9f8b6b39b414491a91835e475af839a99933b16e612d25b94364090/mssql_python-1.7.1-cp314-cp314-win_amd64.whl", hash = "sha256:a4fc35cc1c3de6f7741334f76fb81a682daeaefffa33e7cf6a3b69766254d910", size = 15985357, upload-time = "2026-05-20T10:48:26.345Z" }, + { url = "https://files.pythonhosted.org/packages/2d/76/ee2787ffd725e2c58370b73b55c8c8840f688f6cb534e39e14c2e7798511/mssql_python-1.7.1-cp314-cp314-win_arm64.whl", hash = "sha256:6a9d2aed556a4a4c01c0ba02f19671020bbf5b268bd9eb3546bcbe2b9ef07201", size = 19146020, upload-time = "2026-05-20T10:48:30.38Z" }, +] + [[package]] name = "mypy" version = "1.11.2" @@ -2220,11 +2272,11 @@ wheels = [ [[package]] name = "sqlparse" -version = "0.5.5" +version = "0.5.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/90/76/437d71068094df0726366574cf3432a4ed754217b436eb7429415cf2d480/sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e", size = 120815, upload-time = "2025-12-19T07:17:45.073Z" } +sdist = { url = "https://files.pythonhosted.org/packages/18/67/701f86b28d63b2086de47c942eccf8ca2208b3be69715a1119a4e384415a/sqlparse-0.5.4.tar.gz", hash = "sha256:4396a7d3cf1cd679c1be976cf3dc6e0a51d0111e87787e7a8d780e7d5a998f9e", size = 120112, upload-time = "2025-11-28T07:10:18.377Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, + { url = "https://files.pythonhosted.org/packages/25/70/001ee337f7aa888fb2e3f5fd7592a6afc5283adb1ed44ce8df5764070f22/sqlparse-0.5.4-py3-none-any.whl", hash = "sha256:99a9f0314977b76d776a0fcb8554de91b9bb8a18560631d6bc48721d07023dcb", size = 45933, upload-time = "2025-11-28T07:10:19.73Z" }, ] [[package]] From d4dc2145938df9f1ff9ceefa580c60d2f11560a8 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Thu, 21 May 2026 04:27:00 +0000 Subject: [PATCH 3/8] =?UTF-8?q?=E2=9C=A8=20feat:=20refactor=20integration?= =?UTF-8?q?=20tests=20for=20mssql-python=20backend=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../workflows/integration-tests-sqlserver.yml | 72 ++++++++----------- 1 file changed, 30 insertions(+), 42 deletions(-) diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index 1d1e035b6..310e10753 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -31,27 +31,48 @@ on: # yamllint disable-line rule:truthy - 'pytest.ini' - '.github/workflows/integration-tests-sqlserver.yml' schedule: - - cron: "0 22 * * 0" + - cron: '0 22 * * 0' jobs: integration-tests-sql-server: - name: Regular + name: Regular ${{ matrix.backend }} if: github.actor != 'dependabot[bot]' strategy: matrix: python_version: ["3.10", "3.11", "3.12", "3.13"] msodbc_version: ["17", "18"] sqlserver_version: ["2017", "2019", "2022"] - collation: - ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] + collation: ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] + backend: + - pyodbc + - mssql-python + exclude: + - backend: pyodbc + python_version: "3.12" + - backend: pyodbc + python_version: "3.13" + - backend: mssql-python + python_version: "3.10" + - backend: mssql-python + python_version: "3.11" + include: + - backend: pyodbc + install_extra: pyodbc + use_mssql_python: "False" + - backend: mssql-python + install_extra: mssql + use_mssql_python: "True" runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}:CI-${{ matrix.python_version }}-msodbc${{ matrix.msodbc_version }} + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} services: sqlserver: image: ghcr.io/${{ github.repository }}:server-${{ matrix.sqlserver_version }} env: - ACCEPT_EULA: "Y" + ACCEPT_EULA: 'Y' SA_PASSWORD: 5atyaNadella DBT_TEST_USER_1: DBT_TEST_USER_1 DBT_TEST_USER_2: DBT_TEST_USER_2 @@ -64,7 +85,9 @@ jobs: run: pip install uv - name: Install dependencies - run: uv pip install --system -e ".[pyodbc]" --group dev + env: + INSTALL_EXTRA: ${{ matrix.install_extra }} + run: uv pip install --system -e ".[$INSTALL_EXTRA]" --group dev - name: Run functional tests run: pytest -ra -v tests/functional --profile "ci_sql_server" @@ -73,39 +96,4 @@ jobs: DBT_TEST_USER_2: DBT_TEST_USER_2 DBT_TEST_USER_3: DBT_TEST_USER_3 SQLSERVER_TEST_DRIVER: "ODBC Driver ${{ matrix.msodbc_version }} for SQL Server" - - integration-tests-sql-server-mssql-python: - name: mssql-python - runs-on: ubuntu-latest - permissions: - contents: read - packages: read - container: - image: ghcr.io/${{ github.repository }}:CI-3.13-msodbc18 - credentials: - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - services: - sqlserver: - image: ghcr.io/${{ github.repository }}:server-2022 - env: - ACCEPT_EULA: "Y" - SA_PASSWORD: 5atyaNadella - DBT_TEST_USER_1: DBT_TEST_USER_1 - DBT_TEST_USER_2: DBT_TEST_USER_2 - DBT_TEST_USER_3: DBT_TEST_USER_3 - COLLATION: SQL_Latin1_General_CP1_CS_AS - steps: - - uses: actions/checkout@v4 - - - name: Install dependencies - run: pip install -r dev_requirements.txt - - - name: Run functional tests with mssql-python - run: pytest -ra -v tests/functional --profile "ci_sql_server" - env: - DBT_TEST_USER_1: DBT_TEST_USER_1 - DBT_TEST_USER_2: DBT_TEST_USER_2 - DBT_TEST_USER_3: DBT_TEST_USER_3 - SQLSERVER_TEST_DRIVER: "ODBC Driver 18 for SQL Server" - SQLSERVER_TEST_USE_MSSQL_PYTHON: "True" + SQLSERVER_TEST_USE_MSSQL_PYTHON: ${{ matrix.use_mssql_python }} From b3d8e98a9aff12c0bf9287b6c21e1b9af6f05c83 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Thu, 21 May 2026 04:36:28 +0000 Subject: [PATCH 4/8] =?UTF-8?q?=E2=9C=A8=20feat:=20update=20installation?= =?UTF-8?q?=20instructions=20for=20optional=20mssql-python=20and=20pyodbc?= =?UTF-8?q?=20backends?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CONTRIBUTING.md | 16 +++++-- README.md | 125 ++++++++++++++++++------------------------------ 2 files changed, 59 insertions(+), 82 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 365f23d2b..4559b6383 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -32,10 +32,20 @@ The functional tests require a running SQL Server instance. You can easily spin make server ``` -The default development flow uses the existing ODBC-based path. If you want to develop or test the optional `mssql-python` backend instead, make sure the package is installed in your environment before running tests. +The default development flow uses the ODBC-based path, but the ODBC driver itself is now an optional dependency. If you want to develop or test that backend, install either the adapter extra or the driver itself before running tests. ```shell -pip install mssql-python +pip install -U "dbt-sqlserver[pyodbc]" +# or +pip install -U pyodbc +``` + +If you want to develop or test the optional `mssql-python` backend instead, install either the adapter extra or the driver itself before running tests. + +```shell +pip install -U "dbt-sqlserver[mssql]" +# or +pip install -U mssql-python ``` On Debian/Ubuntu-based environments, `mssql-python` may also require these system libraries: @@ -77,7 +87,7 @@ make unit make functional ``` -This remains the documented test procedure for both connection backends. When the `mssql-python` flag is enabled, run the same commands after installing `mssql-python` and setting `SQLSERVER_TEST_USE_MSSQL_PYTHON=True` in `test.env`. +This remains the documented test procedure for both connection backends. When the `pyodbc` path is enabled, run the same commands after installing `dbt-sqlserver[pyodbc]` or `pyodbc`. When the `mssql-python` flag is enabled, run the same commands after installing `dbt-sqlserver[mssql]` or `mssql-python` and setting `SQLSERVER_TEST_USE_MSSQL_PYTHON=True` in `test.env`. ## CI/CD diff --git a/README.md b/README.md index 20829570e..a83e44d8e 100644 --- a/README.md +++ b/README.md @@ -9,90 +9,56 @@ E.g. version 1.1.x of the adapter will be compatible with dbt-core 1.1.x. We've bundled all documentation on the dbt docs site: -* [Profile setup & authentication](https://docs.getdbt.com/reference/warehouse-profiles/mssql-profile) -* [Adapter documentation, usage and important notes](https://docs.getdbt.com/reference/resource-configs/mssql-configs) +- [Profile setup & authentication](https://docs.getdbt.com/reference/warehouse-profiles/mssql-profile) +- [Adapter documentation, usage and important notes](https://docs.getdbt.com/reference/resource-configs/mssql-configs) Join us on the [dbt Slack](https://getdbt.slack.com/archives/CMRMDDQ9W) to ask questions, get help, or to discuss the project. ## Installation -By default this adapter uses the Microsoft ODBC driver. +The base package does not bundle any connection driver. Install the adapter together with the backend extra that matches your setup. -This adapter requires the Microsoft ODBC driver to be installed: +Latest version: ![PyPI](https://img.shields.io/pypi/v/dbt-sqlserver?label=latest%20stable&logo=pypi) +Latest pre-release: ![GitHub tag (latest SemVer pre-release)](https://img.shields.io/github/v/tag/dbt-msft/dbt-sqlserver?include_prereleases&label=latest%20pre-release&logo=pypi) + +### `pyodbc` backend + +The legacy and currently default ODBC path uses `pyodbc` and the Microsoft ODBC driver. + +```shell +pip install -U "dbt-sqlserver[pyodbc]" +``` + +You also need the Microsoft ODBC driver for SQL Server installed on your system: [Windows](https://docs.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16#download-for-windows) | [macOS](https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/install-microsoft-odbc-driver-sql-server-macos?view=sql-server-ver16) | [Linux](https://docs.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-sql-server?view=sql-server-ver16)

Debian/Ubuntu -

-Make sure to install the ODBC headers as well as the driver linked above: +Install the ODBC headers as well as the driver linked above: ```shell sudo apt-get install -y unixodbc-dev ``` -

-Latest version: ![PyPI](https://img.shields.io/pypi/v/dbt-sqlserver?label=latest%20stable&logo=pypi) - -```shell -pip install -U dbt-sqlserver -``` +### `mssql-python` backend -Latest pre-release: ![GitHub tag (latest SemVer pre-release)](https://img.shields.io/github/v/tag/dbt-msft/dbt-sqlserver?include_prereleases&label=latest%20pre-release&logo=pypi) +An alternative backend that does not require the ODBC driver. ```shell -pip install -U --pre dbt-sqlserver +pip install -U "dbt-sqlserver[mssql]" ``` -### Optional: `mssql-python` backend - -This adapter can also use the `mssql-python` driver behind a feature flag. - -Install it explicitly when you want to use that backend: +On Debian/Ubuntu-based systems, `mssql-python` requires these system libraries: ```shell -pip install -U mssql-python +sudo apt-get install -y libltdl7 libkrb5-3 libgssapi-krb5-2 ``` -When this backend is enabled, the adapter does not require the ODBC driver-based connection path for that profile. - -## Changelog - -See [the changelog](CHANGELOG.md) - -## Configuration - -### Flags - -- `dbt_sqlserver_use_default_schema_concat`: *(default: `false`)* Controls schema name generation when a [custom schema](https://docs.getdbt.com/docs/build/custom-schemas) is set on a model. - -- `use_mssql_python`: *(default: `false` in the profile)* Switches the connection backend from the legacy ODBC / `pyodbc` path to the `mssql-python` driver for that target profile. - - | Flag value | `custom_schema_name` | Result | - |---|---|---| - | `false` (default, legacy) | *(none)* | `target.schema` | - | `false` (default, legacy) | `"reporting"` | `reporting` | - | `true` (dbt-core standard) | *(none)* | `target.schema` | - | `true` (dbt-core standard) | `"reporting"` | `target.schema_reporting` | - - When `false` (the default), the adapter uses its legacy behaviour: `custom_schema_name` is used **as-is** without being prefixed by `target.schema`. - When `true`, the adapter delegates to dbt-core's `default__generate_schema_name`, which concatenates `target.schema` + `_` + `custom_schema_name`. - - **Example usage in `dbt_project.yml`:** - - ```yaml - vars: - dbt_sqlserver_use_default_schema_concat: true # Enable standard schema concatenation - ``` - - > **Note:** If you want to permanently customise schema generation and avoid any future deprecation of this flag, override the `sqlserver__generate_schema_name` macro directly in your project. - -### `mssql-python` feature flag usage - -Enable the backend per target in your `profiles.yml`: +Enable it per target in your `profiles.yml`: ```yaml your_profile: @@ -108,39 +74,40 @@ your_profile: password: your-password encrypt: true trust_cert: false - use_mssql_python: true + use_mssql_python: true # <-- enables this backend ``` -#### Notes +## Changelog -- `use_mssql_python: true` is a profile-level feature flag. -- When enabled, the adapter uses `mssql-python` instead of the legacy `pyodbc` connection path. -- The legacy ODBC driver setting is only needed for profiles that continue to use the ODBC backend. -- If you enable `use_mssql_python`, make sure the `mssql-python` package is installed in the environment running dbt. -- On Debian/Ubuntu-based environments, `mssql-python` also requires `libltdl7`, `libkrb5-3`, and `libgssapi-krb5-2`. -- This path is intended to fail fast when required dependencies or unsupported settings are missing. +See [the changelog](CHANGELOG.md) -#### Testing +## Configuration -For local development and validation, use the documented adapter workflow from `CONTRIBUTING.md`: +### `use_mssql_python` -```shell -make dev -make server -cp test.env.sample test.env -make unit -make functional -``` +*(default: `false`)* Set to `true` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required driver is not installed. -To exercise the `mssql-python` backend in tests, configure the profile or environment so that the target under test sets: +### `dbt_sqlserver_use_default_schema_concat` -```yaml -use_mssql_python: true -``` +*(default: `false`)* Controls schema name generation when a [custom schema](https://docs.getdbt.com/docs/build/custom-schemas) is set on a model. + +| Value | `custom_schema_name` | Result | +|---|---|---| +| `false` (default) | *(none)* | `target.schema` | +| `false` (default) | `"reporting"` | `reporting` | +| `true` | *(none)* | `target.schema` | +| `true` | `"reporting"` | `target.schema_reporting` | -If you are testing in the devcontainer, the backend prerequisites are installed automatically. Outside the devcontainer, install `mssql-python` and the system libraries above before running the unit or functional suite. +When `false`, `custom_schema_name` is used as-is without being prefixed by `target.schema`. +When `true`, the adapter delegates to dbt-core's `default__generate_schema_name`. +```yaml +# dbt_project.yml +vars: + dbt_sqlserver_use_default_schema_concat: true +``` +> **Note:** To permanently customise schema generation without a flag dependency, override the `sqlserver__generate_schema_name` macro directly in your project. ## Contributing @@ -149,7 +116,7 @@ If you are testing in the devcontainer, the backend prerequisites are installed [![Integration tests on Azure](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/integration-tests-azure.yml/badge.svg)](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/integration-tests-azure.yml) This adapter is community-maintained. -You are welcome to contribute by creating issues, opening or reviewing pull requests or helping other users in Slack channel. +You are welcome to contribute by creating issues, opening or reviewing pull requests, or helping other users in the Slack channel. If you're unsure how to get started, check out our [contributing guide](CONTRIBUTING.md). ## License From 3b2220c573aa5361c8267830e78affd17ee23147 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Tue, 26 May 2026 02:19:28 +0000 Subject: [PATCH 5/8] feat: Make pyodbc the default SQL Server backend for backward compat and docs matching and improve optional mssql-python support Upgrade to backend enum Update README install docs to explain the default pyodbc backend and optional mssql extra for mssql-python Add azure dep to lazy loading Add pyodbc as a core dependency and update mssql-python, azure-identity, and azure-core extras in pyproject.toml and uv.lock Add authentication normalization and validation for SQL Server profile fields Add mssql-python-specific connection string handling for MSI and Integrated auth flows Expand unit tests for mssql-python auth normalization and connection behavior --- .../workflows/integration-tests-sqlserver.yml | 4 +- CONTRIBUTING.md | 6 +- README.md | 15 +- .../sqlserver/sqlserver_connections.py | 588 ++++++++++++------ .../sqlserver/sqlserver_credentials.py | 39 +- pyproject.toml | 7 +- test.env.sample | 2 +- tests/conftest.py | 5 +- .../adapters/mssql/test_connection_logic.py | 31 +- .../test_sqlserver_connection_manager.py | 480 +++++++++++++- uv.lock | 6 +- 11 files changed, 963 insertions(+), 220 deletions(-) diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index 310e10753..d931dec4a 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -58,10 +58,8 @@ jobs: include: - backend: pyodbc install_extra: pyodbc - use_mssql_python: "False" - backend: mssql-python install_extra: mssql - use_mssql_python: "True" runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}:CI-${{ matrix.python_version }}-msodbc${{ matrix.msodbc_version }} @@ -96,4 +94,4 @@ jobs: DBT_TEST_USER_2: DBT_TEST_USER_2 DBT_TEST_USER_3: DBT_TEST_USER_3 SQLSERVER_TEST_DRIVER: "ODBC Driver ${{ matrix.msodbc_version }} for SQL Server" - SQLSERVER_TEST_USE_MSSQL_PYTHON: ${{ matrix.use_mssql_python }} + SQLSERVER_TEST_BACKEND: ${{ matrix.backend }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4559b6383..be99accf4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ This will use Docker Compose to spin up a local instance of SQL Server. Docker C Next, tell our tests how they should connect to the local instance by creating a file called `test.env` in the root of the project. You can use the provided `test.env.sample` as a base and if you started the server with `make server`, then this matches the instance running on your local machine. -If you are testing the optional `mssql-python` backend, also enable its profile flag in `test.env` so the adapter selects that implementation instead of the legacy driver-based one. +If you are testing the optional `mssql-python` backend, also enable its profile setting in `test.env` so the adapter selects that implementation instead of the legacy driver-based one. ```shell cp test.env.sample test.env @@ -68,7 +68,7 @@ cp test.env.sample test.env When using the optional `mssql-python` backend, update `test.env` with: ```shell -SQLSERVER_TEST_USE_MSSQL_PYTHON=True +SQLSERVER_TEST_BACKEND=mssql-python ``` You can tweak the contents of this file to test against a different database. @@ -87,7 +87,7 @@ make unit make functional ``` -This remains the documented test procedure for both connection backends. When the `pyodbc` path is enabled, run the same commands after installing `dbt-sqlserver[pyodbc]` or `pyodbc`. When the `mssql-python` flag is enabled, run the same commands after installing `dbt-sqlserver[mssql]` or `mssql-python` and setting `SQLSERVER_TEST_USE_MSSQL_PYTHON=True` in `test.env`. +This remains the documented test procedure for both connection backends. When the `pyodbc` path is enabled, run the same commands after installing `dbt-sqlserver[pyodbc]` or `pyodbc`. When the `mssql-python` backend is enabled, run the same commands after installing `dbt-sqlserver[mssql]` or `mssql-python` and setting `SQLSERVER_TEST_BACKEND=mssql-python` in `test.env`. ## CI/CD diff --git a/README.md b/README.md index a83e44d8e..5e6de166f 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Join us on the [dbt Slack](https://getdbt.slack.com/archives/CMRMDDQ9W) to ask q ## Installation -The base package does not bundle any connection driver. Install the adapter together with the backend extra that matches your setup. +The default install uses the `pyodbc` backend and includes the `pyodbc` dependency. If you want the optional `mssql-python` backend instead, install the `mssql` extra. Latest version: ![PyPI](https://img.shields.io/pypi/v/dbt-sqlserver?label=latest%20stable&logo=pypi) Latest pre-release: ![GitHub tag (latest SemVer pre-release)](https://img.shields.io/github/v/tag/dbt-msft/dbt-sqlserver?include_prereleases&label=latest%20pre-release&logo=pypi) @@ -25,6 +25,12 @@ Latest pre-release: ![GitHub tag (latest SemVer pre-release)](https://img.shield The legacy and currently default ODBC path uses `pyodbc` and the Microsoft ODBC driver. +```shell +pip install -U dbt-sqlserver +``` + +You should migrate to use an explicit extra for incoming deprecation, the following is equivalent: + ```shell pip install -U "dbt-sqlserver[pyodbc]" ``` @@ -74,7 +80,7 @@ your_profile: password: your-password encrypt: true trust_cert: false - use_mssql_python: true # <-- enables this backend + backend: mssql-python # <-- enables this backend ``` ## Changelog @@ -83,9 +89,10 @@ See [the changelog](CHANGELOG.md) ## Configuration -### `use_mssql_python` +### `backend` + +*(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required driver is not installed. -*(default: `false`)* Set to `true` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required driver is not installed. ### `dbt_sqlserver_use_default_schema_concat` diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 03bdba5d0..7d365d103 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -4,9 +4,9 @@ from contextlib import contextmanager from dataclasses import dataclass from itertools import chain, repeat -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, Type, Union, cast -import agate +import agate # type: ignore[import] import dbt_common.exceptions from dbt_common.clients.agate_helper import empty_table from dbt_common.events.contextvars import get_node_info @@ -18,46 +18,113 @@ from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus from dbt.adapters.sql.connections import SQLConnectionManager from dbt.adapters.sqlserver import __version__ -from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerBackend, SQLServerCredentials -_PYODBC_MODULE: Optional[Any] = None + +class PyodbcModuleProtocol(Protocol): + InternalError: type[Exception] + OperationalError: type[Exception] + InterfaceError: type[Exception] + DatabaseError: type[Exception] + pooling: bool + + def connect(self, *args: Any, **kwargs: Any) -> Any: ... + + +class MssqlPythonModuleProtocol(Protocol): + InternalError: type[Exception] + OperationalError: type[Exception] + InterfaceError: type[Exception] + DatabaseError: type[Exception] + + def connect(self, *args: Any, **kwargs: Any) -> Any: ... + + +class AccessTokenProtocol(Protocol): + token: str + expires_on: int + + +class TokenCredentialProtocol(Protocol): + def get_token(self, *scopes: Optional[str], **kwargs: Any) -> AccessTokenProtocol: ... + + +class CredentialFactory(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> TokenCredentialProtocol: ... + + +class AzureIdentityModuleProtocol(Protocol): + AzureCliCredential: CredentialFactory + DefaultAzureCredential: CredentialFactory + EnvironmentCredential: CredentialFactory + ManagedIdentityCredential: CredentialFactory + ClientSecretCredential: CredentialFactory + + +class AzureCredentialsModuleProtocol(Protocol): + AccessToken: Type[AccessTokenProtocol] + + +_PYODBC_MODULE: Optional[PyodbcModuleProtocol] = None _PYODBC_IMPORT_ERROR: Optional[ModuleNotFoundError] = None -_MSSQL_PYTHON_MODULE: Optional[Any] = None +_MSSQL_PYTHON_MODULE: Optional[MssqlPythonModuleProtocol] = None _MSSQL_PYTHON_IMPORT_ERROR: Optional[ModuleNotFoundError] = None -try: - from azure.core.credentials import AccessToken -except ModuleNotFoundError: +_AZURE_CREDENTIALS_MODULE: Optional[AzureCredentialsModuleProtocol] = None +_AZURE_CREDENTIALS_IMPORT_ERROR: Optional[ModuleNotFoundError] = None - @dataclass - class AccessToken: # type: ignore[no-redef] - token: str - expires_on: int +_AZURE_IDENTITY_MODULE: Optional[AzureIdentityModuleProtocol] = None +_AZURE_IDENTITY_IMPORT_ERROR: Optional[ModuleNotFoundError] = None -try: - from azure.identity import ( - AzureCliCredential, - ClientSecretCredential, - DefaultAzureCredential, - EnvironmentCredential, - ManagedIdentityCredential, - ) +@dataclass +class AccessToken: # type: ignore[no-redef] + token: str + expires_on: int + + +def _get_azure_access_token_class() -> Type[Any]: + global _AZURE_CREDENTIALS_MODULE, _AZURE_CREDENTIALS_IMPORT_ERROR + + if _AZURE_CREDENTIALS_MODULE is not None: + return _AZURE_CREDENTIALS_MODULE.AccessToken + + if _AZURE_CREDENTIALS_IMPORT_ERROR is not None: + return AccessToken + + try: + import azure.core.credentials as azure_credentials # type: ignore[import] + except ModuleNotFoundError as exc: + _AZURE_CREDENTIALS_IMPORT_ERROR = exc + return AccessToken + + _AZURE_CREDENTIALS_MODULE = cast(AzureCredentialsModuleProtocol, azure_credentials) + return azure_credentials.AccessToken + + +def _get_azure_identity_module() -> AzureIdentityModuleProtocol: + global _AZURE_IDENTITY_MODULE, _AZURE_IDENTITY_IMPORT_ERROR + + if _AZURE_IDENTITY_MODULE is not None: + return _AZURE_IDENTITY_MODULE + + if _AZURE_IDENTITY_IMPORT_ERROR is not None: + raise _missing_azure_identity_error() from _AZURE_IDENTITY_IMPORT_ERROR + + try: + import azure.identity as azure_identity # type: ignore[import] + except ModuleNotFoundError as exc: + _AZURE_IDENTITY_IMPORT_ERROR = exc + raise _missing_azure_identity_error() from exc - _AZURE_IDENTITY_IMPORT_ERROR = None -except ModuleNotFoundError as exc: - AzureCliCredential = None - ClientSecretCredential = None - DefaultAzureCredential = None - EnvironmentCredential = None - ManagedIdentityCredential = None - _AZURE_IDENTITY_IMPORT_ERROR = exc + _AZURE_IDENTITY_MODULE = cast(AzureIdentityModuleProtocol, azure_identity) + return _AZURE_IDENTITY_MODULE -_TOKEN: Optional[AccessToken] = None +_TOKEN: Optional[AccessTokenProtocol] = None AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" -AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials, Optional[str]], AccessToken] +AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials, Optional[str]], AccessTokenProtocol] logger = AdapterLogger("sqlserver") @@ -79,85 +146,156 @@ class AccessToken: # type: ignore[no-redef] MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS = { "cli", - "auto", "environment", - "serviceprincipal", "activedirectoryaccesstoken", } -def _get_pyodbc() -> Any: +def _auth_key(authentication: Optional[str]) -> str: + if authentication is None: + return "" + return authentication.replace("_", "").replace(" ", "").lower() + + +def _normalize_mssql_python_authentication(authentication: Optional[str]) -> Optional[str]: + authentication = authentication or "" + key = _auth_key(authentication) + if not key: + return None + + if key in {"msi", "activedirectorymsi"}: + return "ActiveDirectoryMSI" + + if key in {"activedirectoryintegrated", "adintegrated"}: + return "ActiveDirectoryIntegrated" + + if key in {"serviceprincipal", "activedirectoryserviceprincipal"}: + return "ActiveDirectoryServicePrincipal" + + if key in {"auto", "default", "activedirectorydefault"}: + return "ActiveDirectoryDefault" + + if key == "activedirectorypassword": + return "ActiveDirectoryPassword" + + if key == "activedirectoryinteractive": + return "ActiveDirectoryInteractive" + + if key == "activedirectorydevicecode": + return "ActiveDirectoryDeviceCode" + + return authentication.strip() + + +def _missing_pyodbc_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "The legacy `pyodbc` backend was requested, but the optional dependency " + "`pyodbc` is not installed. Install it with `pip install pyodbc` " + "or set `backend: mssql-python` in the profile." + ) + + +def _get_pyodbc() -> PyodbcModuleProtocol: global _PYODBC_MODULE, _PYODBC_IMPORT_ERROR if _PYODBC_MODULE is not None: return _PYODBC_MODULE if _PYODBC_IMPORT_ERROR is not None: - raise dbt_common.exceptions.DbtRuntimeError( - "The legacy `pyodbc` backend was requested, but the optional dependency " - "`pyodbc` is not installed. Install it with `pip install pyodbc` " - "or enable `use_mssql_python` in the profile." - ) from _PYODBC_IMPORT_ERROR + raise _missing_pyodbc_error() from _PYODBC_IMPORT_ERROR try: - import pyodbc as imported_pyodbc + import pyodbc as imported_pyodbc # type: ignore[import] except ModuleNotFoundError as exc: _PYODBC_IMPORT_ERROR = exc - raise dbt_common.exceptions.DbtRuntimeError( - "The legacy `pyodbc` backend was requested, but the optional dependency " - "`pyodbc` is not installed. Install it with `pip install pyodbc` " - "or enable `use_mssql_python` in the profile." - ) from exc + raise _missing_pyodbc_error() from exc - _PYODBC_MODULE = imported_pyodbc + _PYODBC_MODULE = cast(PyodbcModuleProtocol, imported_pyodbc) return _PYODBC_MODULE -def _get_mssql_python() -> Any: +def _missing_mssql_python_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "The `mssql-python` backend was requested, but the optional dependency " + "`mssql-python` is not installed. Install it with `pip install mssql-python` " + "or set `backend: pyodbc` in the profile." + ) + + +def _missing_azure_identity_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "Azure authentication requires the optional dependency 'azure-identity'. " + "Install it with `pip install azure-identity` or use a non-Azure " + "authentication mode." + ) + + +def _get_mssql_python() -> MssqlPythonModuleProtocol: global _MSSQL_PYTHON_MODULE, _MSSQL_PYTHON_IMPORT_ERROR if _MSSQL_PYTHON_MODULE is not None: return _MSSQL_PYTHON_MODULE if _MSSQL_PYTHON_IMPORT_ERROR is not None: - raise dbt_common.exceptions.DbtRuntimeError( - "The `mssql-python` backend was requested, but the optional dependency " - "`mssql-python` is not installed. Install it with `pip install mssql-python` " - "or disable `use_mssql_python` in the profile." - ) from _MSSQL_PYTHON_IMPORT_ERROR + raise _missing_mssql_python_error() from _MSSQL_PYTHON_IMPORT_ERROR try: - import mssql_python as imported_mssql_python + import mssql_python as imported_mssql_python # type: ignore[import] except ModuleNotFoundError as exc: _MSSQL_PYTHON_IMPORT_ERROR = exc - raise dbt_common.exceptions.DbtRuntimeError( - "The `mssql-python` backend was requested, but the optional dependency " - "`mssql-python` is not installed. Install it with `pip install mssql-python` " - "or disable `use_mssql_python` in the profile." - ) from exc + raise _missing_mssql_python_error() from exc - _MSSQL_PYTHON_MODULE = imported_mssql_python + _MSSQL_PYTHON_MODULE = cast(MssqlPythonModuleProtocol, imported_mssql_python) return _MSSQL_PYTHON_MODULE -def _require_azure_identity(authentication: str) -> None: - if _AZURE_IDENTITY_IMPORT_ERROR is not None: - raise dbt_common.exceptions.DbtRuntimeError( - ( - "Azure authentication '{}' requires the optional " - "dependency 'azure-identity'. Install it with `pip install " - "azure-identity` or use a non-Azure authentication mode." - ).format(authentication) - ) from _AZURE_IDENTITY_IMPORT_ERROR +def _normalize_authentication(authentication: Optional[str]) -> str: + if authentication is None: + return "sql" + normalized = authentication.strip().lower() + if normalized == "activedirectorymsi": + return "msi" + return normalized -def _requires_pyodbc_backend(credentials: SQLServerCredentials) -> bool: - authentication = str(credentials.authentication or "sql").lower().strip() + +def _uses_pyodbc_token_authentication(credentials: SQLServerCredentials) -> bool: + authentication = _normalize_authentication(credentials.authentication) return authentication in AZURE_AUTH_FUNCTIONS or authentication == "activedirectoryaccesstoken" -def _use_mssql_python_backend(credentials: SQLServerCredentials) -> bool: - return bool(getattr(credentials, "use_mssql_python", False)) +def _is_mssql_python_backend(credentials: SQLServerCredentials) -> bool: + return credentials.backend == SQLServerBackend.mssql_python + + +def _validate_connection_requirements(credentials: SQLServerCredentials) -> None: + for name in ("host", "database", "schema"): + value = getattr(credentials, name) + if value is None or not str(value).strip(): + raise dbt_common.exceptions.DbtRuntimeError( + f"The `{name}` profile field is required for SQL Server connections." + ) + + if credentials.windows_login: + normalized = _normalize_mssql_python_authentication(credentials.authentication) + if normalized is not None and _auth_key(normalized).startswith("activedirectory"): + raise dbt_common.exceptions.DbtRuntimeError( + "windows_login/trusted_connection cannot be combined with ActiveDirectory " + "authentication. Remove `authentication` or disable `windows_login`." + ) + elif credentials.authentication is None or not str(credentials.authentication).strip(): + raise dbt_common.exceptions.DbtRuntimeError( + "The `authentication` profile field is required for SQL Server connections." + ) + + if credentials.encrypt is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `encrypt` profile field is required for SQL Server connections." + ) + if credentials.trust_cert is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `trust_cert` profile field is required for SQL Server connections." + ) def _validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: @@ -169,14 +307,14 @@ def _validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: def _validate_mssql_python_requirements(credentials: SQLServerCredentials) -> None: - authentication = str(credentials.authentication or "sql").strip() - authentication_lower = authentication.lower() + authentication = _normalize_mssql_python_authentication(credentials.authentication) + authentication_key = _auth_key(authentication) - if authentication_lower in MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS: + if authentication_key in MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS: raise dbt_common.exceptions.DbtRuntimeError( "Authentication '{}' is currently only supported by the pyodbc backend " "in this adapter. " - "Disable `use_mssql_python` or use a connection-string-supported " + "Use `backend: pyodbc` or use a connection-string-supported " "authentication mode such as " "`sql`, `ActiveDirectoryPassword`, `ActiveDirectoryInteractive`, " "`ActiveDirectoryIntegrated`, " @@ -203,13 +341,13 @@ def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: return struct.pack(" bytes: +def convert_access_token_to_mswindows_byte_string(token: AccessTokenProtocol) -> bytes: """ Convert an access token to a Microsoft windows byte string. Parameters ---------- - token : AccessToken + token : AccessTokenProtocol The token. Returns @@ -223,7 +361,7 @@ def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes: def get_cli_access_token( credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessToken: +) -> AccessTokenProtocol: """ Get an Azure access token using the CLI credentials @@ -244,8 +382,8 @@ def get_cli_access_token( Access token. """ _ = credentials - _require_azure_identity("cli") - token = AzureCliCredential().get_token( + azure_identity = _get_azure_identity_module() + token = azure_identity.AzureCliCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) return token @@ -253,7 +391,7 @@ def get_cli_access_token( def get_auto_access_token( credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessToken: +) -> AccessTokenProtocol: """ Get an Azure access token automatically through azure-identity @@ -267,8 +405,8 @@ def get_auto_access_token( out : AccessToken The access token. """ - _require_azure_identity("auto") - token = DefaultAzureCredential().get_token( + azure_identity = _get_azure_identity_module() + token = azure_identity.DefaultAzureCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) return token @@ -276,7 +414,7 @@ def get_auto_access_token( def get_environment_access_token( credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessToken: +) -> AccessTokenProtocol: """ Get an Azure access token by reading environment variables @@ -290,8 +428,8 @@ def get_environment_access_token( out : AccessToken The access token. """ - _require_azure_identity("environment") - token = EnvironmentCredential().get_token( + azure_identity = _get_azure_identity_module() + token = azure_identity.EnvironmentCredential().get_token( scope, timeout=getattr(credentials, "login_timeout", None) ) return token @@ -299,7 +437,7 @@ def get_environment_access_token( def get_msi_access_token( credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessToken: +) -> AccessTokenProtocol: """ Get an Azure access token from the system's managed identity @@ -314,14 +452,14 @@ def get_msi_access_token( The access token. """ _ = credentials - _require_azure_identity("msi") - token = ManagedIdentityCredential().get_token(scope) + azure_identity = _get_azure_identity_module() + token = azure_identity.ManagedIdentityCredential().get_token(scope) return token def get_sp_access_token( credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessToken: +) -> AccessTokenProtocol: """ Get an Azure access token using the SP credentials. @@ -336,8 +474,8 @@ def get_sp_access_token( The access token. """ _ = scope - _require_azure_identity("serviceprincipal") - token = ClientSecretCredential( + azure_identity = _get_azure_identity_module() + token = azure_identity.ClientSecretCredential( str(credentials.tenant_id), str(credentials.client_id), str(credentials.client_secret), @@ -372,15 +510,16 @@ def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Di sql_copt_ss_access_token = 1256 # ODBC constant for access token MAX_REMAINING_TIME = 300 - if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: + authentication = _normalize_authentication(credentials.authentication) + + if authentication in AZURE_AUTH_FUNCTIONS: if not _TOKEN or (_TOKEN.expires_on - time.time() < MAX_REMAINING_TIME): - _TOKEN = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()]( - credentials, AZURE_CREDENTIAL_SCOPE - ) + _TOKEN = AZURE_AUTH_FUNCTIONS[authentication](credentials, AZURE_CREDENTIAL_SCOPE) + assert _TOKEN is not None token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN) return {sql_copt_ss_access_token: token_bytes} - if credentials.authentication.lower() == "activedirectoryaccesstoken": + if authentication == "activedirectoryaccesstoken": if credentials.access_token is None or credentials.access_token_expires_on is None: raise ValueError( ( @@ -388,7 +527,7 @@ def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Di "required for ActiveDirectoryAccessToken authentication." ) ) - _TOKEN = AccessToken( + _TOKEN = _get_azure_access_token_class()( token=credentials.access_token, expires_on=int( time.time() + 4500.0 @@ -396,6 +535,7 @@ def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Di else credentials.access_token_expires_on ), ) + assert _TOKEN is not None return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)} return {} @@ -420,6 +560,13 @@ def bool_to_connection_string_arg(key: str, value: bool) -> str: return f"{key}={'Yes' if value else 'No'}" +def _escape_connection_string_value(value: Optional[str]) -> str: + text = "" if value is None else str(value) + if text.startswith(" ") or text.endswith(" ") or any(ch in text for ch in ";{}"): + return "{" + text.replace("}", "}}") + "}" + return text + + def byte_array_to_datetime(value: bytes) -> dt.datetime: """ Converts a DATETIMEOFFSET byte array to a timezone-aware datetime object @@ -455,54 +602,123 @@ def byte_array_to_datetime(value: bytes) -> dt.datetime: def _build_server_arg(credentials: SQLServerCredentials) -> str: - if "\\" in credentials.host: + host = credentials.host or "" + if "\\" in host: # If there is a backslash \ in the host name, the host is a # SQL Server named instance. In this case then port number has to be omitted. - return credentials.host - return f"{credentials.host},{credentials.port}" + return host + return f"{host},{credentials.port}" + + +def _format_connection_string_value(value: Optional[str], mssql_python_backend: bool) -> str: + if mssql_python_backend: + return _escape_connection_string_value(value) + return "{" + ("" if value is None else value) + "}" def _build_common_connection_string_parts( credentials: SQLServerCredentials, + mssql_python_backend: bool, ) -> list[str]: con_str = [f"SERVER={_build_server_arg(credentials)}"] con_str.append(f"Database={credentials.database}") - assert credentials.authentication is not None + authentication = credentials.authentication or "" + if mssql_python_backend: + authentication = _normalize_mssql_python_authentication(authentication) or "" - if ( - "ActiveDirectory" in credentials.authentication - and credentials.authentication != "ActiveDirectoryAccessToken" - ): - con_str.append(f"Authentication={credentials.authentication}") + if not authentication.strip() and not credentials.windows_login: + raise dbt_common.exceptions.DbtRuntimeError( + "The `authentication` profile field is required for SQL Server connections." + ) - if credentials.authentication == "ActiveDirectoryPassword": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") - if credentials.authentication == "ActiveDirectoryServicePrincipal": - con_str.append(f"UID={{{credentials.client_id}}}") - con_str.append(f"PWD={{{credentials.client_secret}}}") - elif credentials.authentication == "ActiveDirectoryInteractive": - con_str.append(f"UID={{{credentials.UID}}}") + if "ActiveDirectory" in authentication and authentication != "ActiveDirectoryAccessToken": + con_str.append(f"Authentication={authentication}") + + if authentication == "ActiveDirectoryPassword": + con_str.append( + f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + con_str.append( + f"PWD={_format_connection_string_value(credentials.PWD, mssql_python_backend)}" + ) + elif authentication == "ActiveDirectoryServicePrincipal": + con_str.append( + "UID=" + + _format_connection_string_value( + credentials.client_id, + mssql_python_backend, + ) + ) + con_str.append( + "PWD=" + + _format_connection_string_value( + credentials.client_secret, + mssql_python_backend, + ) + ) + elif authentication == "ActiveDirectoryInteractive": + con_str.append( + "UID=%s" + % _format_connection_string_value( + credentials.UID, + mssql_python_backend, + ) + ) + elif authentication == "ActiveDirectoryMSI": + if credentials.PWD: + raise dbt_common.exceptions.DbtRuntimeError( + "password is not valid with ActiveDirectoryMSI for the mssql-python backend." + ) + if credentials.UID: + con_str.append( + f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + elif authentication == "ActiveDirectoryIntegrated": + if credentials.PWD: + raise dbt_common.exceptions.DbtRuntimeError( + "password is not valid with ActiveDirectoryIntegrated" + " for the mssql-python backend." + ) elif credentials.windows_login: - con_str.append("trusted_connection=Yes") - elif credentials.authentication == "sql": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") + if mssql_python_backend and (credentials.UID or credentials.PWD): + raise dbt_common.exceptions.DbtRuntimeError( + "user/password are not valid with windows_login/trusted_connection " + "for the mssql-python backend." + ) + con_str.append("Trusted_Connection=yes") + elif authentication == "sql": + con_str.append( + f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + con_str.append( + f"PWD={_format_connection_string_value(credentials.PWD, mssql_python_backend)}" + ) - assert credentials.encrypt is not None - assert credentials.trust_cert is not None + if credentials.encrypt is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `encrypt` profile field is required for SQL Server connections." + ) + if credentials.trust_cert is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `trust_cert` profile field is required for SQL Server connections." + ) con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) con_str.append(bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert)) + if not mssql_python_backend: + # Reserved keyword 'app' is controlled by the driver and cannot be specified by the user. + application_name = f"dbt-{credentials.type}/{__version__.version}" + con_str.append(f"APP={application_name}") + return con_str def _build_pyodbc_connection_string(credentials: SQLServerCredentials) -> str: con_str = [f"DRIVER={{{credentials.driver}}}"] - con_str.extend(_build_common_connection_string_parts(credentials)) + con_str.extend(_build_common_connection_string_parts(credentials, mssql_python_backend=False)) con_str.append("Pooling=true") if credentials.trace_flag: @@ -510,27 +726,14 @@ def _build_pyodbc_connection_string(credentials: SQLServerCredentials) -> str: else: con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF") - plugin_version = __version__.version - application_name = f"dbt-{credentials.type}/{plugin_version}" - con_str.append(f"APP={application_name}") - - try: - con_str.append("ConnectRetryCount=3") - con_str.append("ConnectRetryInterval=10") - except Exception as e: - logger.debug( - ( - "Retry count should be a integer value. " - "Skipping retries in the connection string." - ), - str(e), - ) + con_str.append("ConnectRetryCount=3") + con_str.append("ConnectRetryInterval=10") return ";".join(con_str) def _build_mssql_python_connection_string(credentials: SQLServerCredentials) -> str: - con_str = _build_common_connection_string_parts(credentials) + con_str = _build_common_connection_string_parts(credentials, mssql_python_backend=True) con_str.append("ConnectRetryCount=3") con_str.append("ConnectRetryInterval=10") return ";".join(con_str) @@ -547,10 +750,50 @@ def _sanitize_connection_string_for_logging(connection_string: str) -> str: return ";".join(sanitized) +def _connect_mssql_python( + mssql_python: MssqlPythonModuleProtocol, + credentials: SQLServerCredentials, + connection_string: str, +) -> Any: + handle = mssql_python.connect( + connection_string, + autocommit=True, + timeout=credentials.login_timeout, + ) + try: + handle.timeout = credentials.query_timeout + except Exception: + logger.debug( + "The mssql-python connection object does not expose a mutable `timeout` " + "attribute; continuing without setting query timeout on the handle." + ) + logger.debug(f"Connected to db: {credentials.database}") + return handle + + +def _connect_pyodbc( + pyodbc: PyodbcModuleProtocol, + credentials: SQLServerCredentials, + connection_string: str, +) -> Any: + pyodbc.pooling = True + attrs_before = get_pyodbc_attrs_before_credentials(credentials) + + handle = pyodbc.connect( + connection_string, + attrs_before=attrs_before, + autocommit=True, + timeout=credentials.login_timeout, + ) + handle.timeout = credentials.query_timeout + logger.debug(f"Connected to db: {credentials.database}") + return handle + + def _get_backend_exceptions( credentials: SQLServerCredentials, ) -> Tuple[Type[Exception], ...]: - if _use_mssql_python_backend(credentials): + if _is_mssql_python_backend(credentials): mssql_python = _get_mssql_python() retryable_exceptions = [ @@ -558,7 +801,7 @@ def _get_backend_exceptions( getattr(mssql_python, "OperationalError", Exception), ] - if _requires_pyodbc_backend(credentials): + if _uses_pyodbc_token_authentication(credentials): retryable_exceptions.append(getattr(mssql_python, "InterfaceError", Exception)) return tuple(retryable_exceptions) @@ -570,7 +813,7 @@ def _get_backend_exceptions( pyodbc.OperationalError, ] - if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: + if _uses_pyodbc_token_authentication(credentials): retryable_exceptions.append(pyodbc.InterfaceError) return tuple(retryable_exceptions) @@ -591,7 +834,7 @@ def exception_handler(self, sql): except Exception as e: credentials = self.get_thread_connection().credentials - if not _use_mssql_python_backend(credentials): + if not _is_mssql_python_backend(credentials): pyodbc = _PYODBC_MODULE if pyodbc is not None and isinstance(e, getattr(pyodbc, "DatabaseError", tuple())): logger.debug("Database error: {}".format(str(e))) @@ -603,7 +846,7 @@ def exception_handler(self, sql): raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e - if _use_mssql_python_backend(credentials): + if _is_mssql_python_backend(credentials): mssql_python = _MSSQL_PYTHON_MODULE if mssql_python is not None and isinstance( e, getattr(mssql_python, "DatabaseError", tuple()) @@ -633,56 +876,33 @@ def open(cls, connection: Connection) -> Connection: credentials = cls.get_credentials(connection.credentials) - if _use_mssql_python_backend(credentials): + _validate_connection_requirements(credentials) + + if _is_mssql_python_backend(credentials): mssql_python = _get_mssql_python() _validate_mssql_python_requirements(credentials) con_str_concat = _build_mssql_python_connection_string(credentials) - pyodbc = None + + def connect() -> Any: + logger.debug( + "Using connection string: %s" + % _sanitize_connection_string_for_logging(con_str_concat) + ) + return _connect_mssql_python(mssql_python, credentials, con_str_concat) + else: pyodbc = _get_pyodbc() _validate_pyodbc_requirements(credentials) con_str_concat = _build_pyodbc_connection_string(credentials) - mssql_python = None - - con_str_display = _sanitize_connection_string_for_logging(con_str_concat) - retryable_exceptions = _get_backend_exceptions(credentials) - - def connect(): - logger.debug(f"Using connection string: {con_str_display}") - if _use_mssql_python_backend(credentials): - assert mssql_python is not None - - mssql_python.pooling(enabled=False) - handle = mssql_python.connect( - con_str_concat, - autocommit=True, - timeout=credentials.login_timeout, + def connect() -> Any: + logger.debug( + "Using connection string: %s" + % _sanitize_connection_string_for_logging(con_str_concat) ) - try: - handle.timeout = credentials.query_timeout - except Exception: - logger.debug( - "The mssql-python connection object does not expose a mutable `timeout` " - "attribute; continuing without setting query timeout on the handle." - ) - logger.debug(f"Connected to db: {credentials.database}") - return handle + return _connect_pyodbc(pyodbc, credentials, con_str_concat) - assert pyodbc is not None - - pyodbc.pooling = True - attrs_before = get_pyodbc_attrs_before_credentials(credentials) - - handle = pyodbc.connect( - con_str_concat, - attrs_before=attrs_before, - autocommit=True, - timeout=credentials.login_timeout, - ) - handle.timeout = credentials.query_timeout - logger.debug(f"Connected to db: {credentials.database}") - return handle + retryable_exceptions = _get_backend_exceptions(credentials) conn = cls.retry_connection( connection, @@ -831,7 +1051,7 @@ def get_response(cls, cursor: Any) -> AdapterResponse: ) @classmethod - def data_type_code_to_name(cls, type_code: Union[str, str]) -> str: + def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: data_type = str(type_code)[ str(type_code).index("'") + 1 : str(type_code).rindex("'") # noqa: E203 ] diff --git a/dbt/adapters/sqlserver/sqlserver_credentials.py b/dbt/adapters/sqlserver/sqlserver_credentials.py index 4c0811d3d..81f7f7bb9 100644 --- a/dbt/adapters/sqlserver/sqlserver_credentials.py +++ b/dbt/adapters/sqlserver/sqlserver_credentials.py @@ -1,15 +1,24 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union, cast + +import dbt_common.exceptions +from dbt_common.dataclass_schema import StrEnum from dbt.adapters.contracts.connection import Credentials +class SQLServerBackend(StrEnum): + pyodbc = "pyodbc" + mssql_python = "mssql-python" + + @dataclass class SQLServerCredentials(Credentials): + backend: Union[SQLServerBackend, str] = SQLServerBackend.pyodbc driver: Optional[str] = None - host: str = "" - database: str = "" - schema: str = "" + host: Optional[str] = None + database: Optional[str] = None + schema: Optional[str] = None UID: Optional[str] = None PWD: Optional[str] = None port: Optional[int] = 1433 @@ -27,7 +36,6 @@ class SQLServerCredentials(Credentials): schema_authorization: Optional[str] = None login_timeout: Optional[int] = 0 query_timeout: Optional[int] = 0 - use_mssql_python: bool = False _ALIASES = { "user": "UID", @@ -42,14 +50,27 @@ class SQLServerCredentials(Credentials): "TrustServerCertificate": "trust_cert", "schema_auth": "schema_authorization", "SQL_ATTR_TRACE": "trace_flag", - "mssql_python": "use_mssql_python", - "use_mssql_python_backend": "use_mssql_python", } + def __post_init__(self) -> None: + if isinstance(self.backend, str): + try: + self.backend = SQLServerBackend(self.backend) + except ValueError as exc: + raise dbt_common.exceptions.DbtRuntimeError( + "Unsupported sqlserver backend: '{}'. " + "Supported backends are 'pyodbc' and 'mssql-python'.".format(self.backend) + ) from exc + + self.backend = cast(SQLServerBackend, self.backend) + @property def type(self): return "sqlserver" + def _effective_backend(self) -> SQLServerBackend: + return cast(SQLServerBackend, self.backend) + def _connection_keys(self): if self.windows_login is True: self.authentication = "Windows Login" @@ -70,10 +91,10 @@ def _connection_keys(self): "trace_flag", "encrypt", "trust_cert", - "use_mssql_python", + "backend", ) - if not self.use_mssql_python: + if self._effective_backend() == SQLServerBackend.pyodbc: keys = ("driver",) + keys return keys diff --git a/pyproject.toml b/pyproject.toml index a6eb6e77d..5e013a9e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "dbt-core>=1.10.0,<1.11.0", "dbt-common>=1.22.0,<2.0", "dbt-adapters>=1.15.2,<2.0", + "pyodbc>=5.2.0", ] dynamic = ["version"] @@ -41,7 +42,9 @@ pyodbc = [ "pyodbc>=5.2.0", ] mssql = [ - "mssql-python>=1.1.0", + "mssql-python>=1.7.1", + "azure-identity>=1.12.0", + "azure-core>=1.0.0", ] [dependency-groups] @@ -49,7 +52,7 @@ dev = [ "dbt-tests-adapter>=1.15.0,<2.0", "azure-identity>=1.12.0", "pyodbc>=5.2.0", - "mssql-python>=1.1.0", + "mssql-python>=1.7.1", "build", "bumpversion", "flaky", diff --git a/test.env.sample b/test.env.sample index b66c97497..433818d82 100644 --- a/test.env.sample +++ b/test.env.sample @@ -6,7 +6,7 @@ SQLSERVER_TEST_PORT=1433 SQLSERVER_TEST_DBNAME=TestDB SQLSERVER_TEST_ENCRYPT=True SQLSERVER_TEST_TRUST_CERT=True -SQLSERVER_TEST_USE_MSSQL_PYTHON=False +SQLSERVER_TEST_BACKEND=pyodbc DBT_TEST_USER_1=DBT_TEST_USER_1 DBT_TEST_USER_2=DBT_TEST_USER_2 DBT_TEST_USER_3=DBT_TEST_USER_3 diff --git a/tests/conftest.py b/tests/conftest.py index 376bdd93c..f9c835259 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,13 +49,14 @@ def is_azure(request: FixtureRequest) -> bool: def _all_profiles_base(): + backend = os.getenv("SQLSERVER_TEST_BACKEND", "pyodbc") + return { "type": "sqlserver", "driver": os.getenv("SQLSERVER_TEST_DRIVER", "ODBC Driver 18 for SQL Server"), "port": int(os.getenv("SQLSERVER_TEST_PORT", "1433")), "retries": 2, - "use_mssql_python": os.getenv("SQLSERVER_TEST_USE_MSSQL_PYTHON", "False").lower() - == "true", + "backend": backend, } diff --git a/tests/unit/adapters/mssql/test_connection_logic.py b/tests/unit/adapters/mssql/test_connection_logic.py index f2171591f..0ee823690 100644 --- a/tests/unit/adapters/mssql/test_connection_logic.py +++ b/tests/unit/adapters/mssql/test_connection_logic.py @@ -48,7 +48,10 @@ def test_connection_string_windows_login_with_port(base_credentials): connection_string = args[0] assert "SERVER=servers.database.windows.net,1444" in connection_string - assert "trusted_connection=Yes" in connection_string + assert "Trusted_Connection=yes" in connection_string + assert "UID=" not in connection_string + assert "PWD=" not in connection_string + assert "APP=dbt-sqlserver/" in connection_string def test_connection_string_standard_login_with_port(base_credentials): @@ -57,6 +60,7 @@ def test_connection_string_standard_login_with_port(base_credentials): base_credentials.authentication = "sql" base_credentials.UID = "user" base_credentials.PWD = "password" + base_credentials.trace_flag = True connection = MagicMock() connection.state = "closed" @@ -82,6 +86,31 @@ def test_connection_string_standard_login_with_port(base_credentials): assert "SERVER=servers.database.windows.net,1444" in connection_string assert "UID={user}" in connection_string + assert "PWD={password}" in connection_string + assert "Pooling=true" in connection_string + assert "SQL_ATTR_TRACE=SQL_OPT_TRACE_ON" in connection_string + assert "APP=dbt-sqlserver/" in connection_string + assert "ConnectRetryCount=3" in connection_string + assert "ConnectRetryInterval=10" in connection_string + + +def test_pyodbc_token_authentication_passes_attrs_before(base_credentials, monkeypatch): + base_credentials.authentication = "cli" + base_credentials.windows_login = False + + fake_token = SimpleNamespace(token="fake-token", expires_on=9999999999) + fake_credential = SimpleNamespace(get_token=lambda *args, **kwargs: fake_token) + fake_identity = SimpleNamespace(AzureCliCredential=lambda *args, **kwargs: fake_credential) + + monkeypatch.setattr( + sqlserver_connections, "_AZURE_IDENTITY_MODULE", fake_identity, raising=False + ) + monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", None, raising=False) + + attrs_before = sqlserver_connections.get_pyodbc_attrs_before_credentials(base_credentials) + + assert 1256 in attrs_before + assert isinstance(attrs_before[1256], bytes) def test_connection_string_named_instance_no_port(base_credentials): diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index e4e8c8434..7c95a4d01 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -11,10 +11,13 @@ from dbt.adapters.sqlserver import sqlserver_connections from dbt.adapters.sqlserver.sqlserver_connections import ( SQLServerConnectionManager, + _build_mssql_python_connection_string, + _normalize_mssql_python_authentication, + _validate_mssql_python_requirements, bool_to_connection_string_arg, get_pyodbc_attrs_before_credentials, ) -from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerBackend, SQLServerCredentials # See # https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.5.0/sdk/identity/azure-identity/tests/test_cli_credential.py @@ -73,13 +76,303 @@ def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> assert bool_to_connection_string_arg(key, value) == expected +@pytest.mark.parametrize( + "input_auth, expected", + [ + ("msi", "ActiveDirectoryMSI"), + ("ActiveDirectoryMsi", "ActiveDirectoryMSI"), + ("ActiveDirectoryMSI", "ActiveDirectoryMSI"), + ("active_directory_msi", "ActiveDirectoryMSI"), + ("ActiveDirectoryIntegrated", "ActiveDirectoryIntegrated"), + ("active_directory_integrated", "ActiveDirectoryIntegrated"), + ("adintegrated", "ActiveDirectoryIntegrated"), + ("serviceprincipal", "ActiveDirectoryServicePrincipal"), + ("ActiveDirectoryServicePrincipal", "ActiveDirectoryServicePrincipal"), + ("auto", "ActiveDirectoryDefault"), + ("ActiveDirectoryDefault", "ActiveDirectoryDefault"), + ("default", "ActiveDirectoryDefault"), + ("ActiveDirectoryPassword", "ActiveDirectoryPassword"), + ("ActiveDirectoryInteractive", "ActiveDirectoryInteractive"), + ("ActiveDirectoryDeviceCode", "ActiveDirectoryDeviceCode"), + ], +) +def test_normalize_mssql_python_authentication(input_auth: str, expected: str) -> None: + assert _normalize_mssql_python_authentication(input_auth) == expected + + +def test_escape_connection_string_value_quotes_only_when_needed() -> None: + assert sqlserver_connections._escape_connection_string_value("plain") == "plain" + assert ( + sqlserver_connections._escape_connection_string_value("contains;semicolon") + == "{contains;semicolon}" + ) + assert sqlserver_connections._escape_connection_string_value("brace}") == "{brace}}}" + assert sqlserver_connections._escape_connection_string_value(" leading") == "{ leading}" + assert sqlserver_connections._escape_connection_string_value("trailing ") == "{trailing }" + + +def test_mssql_python_active_directory_default_passes() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="auto", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryDefault" in conn_str + + +def test_mssql_python_device_code_authentication() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="ActiveDirectoryDeviceCode", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryDeviceCode" in conn_str + + +def test_mssql_python_service_principal_authentication() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="serviceprincipal", + client_id="client-id", + client_secret="client-secret", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryServicePrincipal" in conn_str + assert "UID=client-id" in conn_str + assert "PWD=client-secret" in conn_str + + +def test_mssql_python_password_authentication() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="ActiveDirectoryPassword", + UID="user", + PWD="password", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryPassword" in conn_str + assert "UID=user" in conn_str + assert "PWD=password" in conn_str + + +def test_mssql_python_default_does_not_append_app_when_installed() -> None: + pytest.importorskip("mssql_python") + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="sql", + UID="user", + PWD="password", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + assert "APP=dbt-sqlserver/" not in conn_str + + +def test_mssql_python_windows_login_rejects_user_password( + credentials: SQLServerCredentials, +) -> None: + credentials.backend = SQLServerBackend.mssql_python + credentials.windows_login = True + credentials.UID = "dbt_user" + credentials.PWD = "super-secret" + credentials.encrypt = True + credentials.trust_cert = True + + with pytest.raises(DbtRuntimeError, match="user/password are not valid"): + _build_mssql_python_connection_string(credentials) + + +def test_mssql_python_system_assigned_msi() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="ActiveDirectoryMsi", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryMSI" in conn_str + assert "UID=" not in conn_str + assert "PWD=" not in conn_str + + +def test_mssql_python_user_assigned_msi() -> None: + client_id = "00000000-0000-0000-0000-000000000000" + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="msi", + UID=client_id, + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryMSI" in conn_str + assert f"UID={client_id}" in conn_str + assert "PWD=" not in conn_str + + +def test_mssql_python_active_directory_integrated() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="ActiveDirectoryIntegrated", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "Authentication=ActiveDirectoryIntegrated" in conn_str + assert "PWD=" not in conn_str + + +def test_mssql_python_supported_authentication_modes() -> None: + for authentication in [ + "msi", + "ActiveDirectoryMSI", + "active_directory_msi", + "ActiveDirectoryIntegrated", + "active_directory_integrated", + "adintegrated", + "serviceprincipal", + "ActiveDirectoryServicePrincipal", + "auto", + "ActiveDirectoryDefault", + "default", + ]: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication=authentication, + ) + + _validate_mssql_python_requirements(credentials) + + +def test_open_with_mssql_python_system_assigned_msi_passes_connection_string( + credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, +) -> None: + credentials.driver = None + credentials.backend = SQLServerBackend.mssql_python + credentials.authentication = "msi" + credentials.encrypt = True + credentials.trust_cert = True + + captured: Dict[str, Any] = {} + + class FakeHandle: + def __init__(self): + self.timeout = None + + def fake_connect(connection_string, autocommit, timeout): + captured["connection_string"] = connection_string + captured["autocommit"] = autocommit + captured["timeout"] = timeout + return FakeHandle() + + fake_module = SimpleNamespace( + connect=fake_connect, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + def fake_retry_connection( + cls, + connection, + connect, + logger, + retry_limit, + retryable_exceptions, + ): + handle = connect() + connection.handle = handle + connection.state = ConnectionState.OPEN + return connection + + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + monkeypatch.setattr( + SQLServerConnectionManager, + "retry_connection", + classmethod(fake_retry_connection), + ) + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + opened = SQLServerConnectionManager.open(connection) + + assert opened is connection + assert opened.state == ConnectionState.OPEN + assert "Authentication=ActiveDirectoryMSI" in captured["connection_string"] + assert "UID=" not in captured["connection_string"] + assert "PWD=" not in captured["connection_string"] + + def test_adapter_module_import_does_not_import_optional_backends( monkeypatch: pytest.MonkeyPatch, ) -> None: original_import = builtins.__import__ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): - if name in {"pyodbc", "mssql_python"}: + if name in {"pyodbc", "mssql_python", "azure.identity", "azure.core.credentials"}: raise AssertionError(f"unexpected import: {name}") return original_import(name, globals, locals, fromlist, level) @@ -90,6 +383,38 @@ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): assert sqlserver_connections._MSSQL_PYTHON_MODULE is None +def test_get_pyodbc_imports_only_pyodbc(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", None, raising=False) + monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + original_import = builtins.__import__ + + def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + if name in {"mssql_python", "azure.identity", "azure.core.credentials"}: + raise AssertionError(f"unexpected import: {name}") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", guarded_import) + + module = sqlserver_connections._get_pyodbc() + assert module is not None + + +def test_get_mssql_python_imports_only_mssql_python(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + original_import = builtins.__import__ + + def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + if name in {"pyodbc", "azure.identity", "azure.core.credentials"}: + raise AssertionError(f"unexpected import: {name}") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", guarded_import) + + module = sqlserver_connections._get_mssql_python() + assert module is not None + + def test_get_pyodbc_returns_cached_module(monkeypatch: pytest.MonkeyPatch) -> None: fake_pyodbc = SimpleNamespace(name="cached-pyodbc") monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) @@ -156,7 +481,7 @@ def test_open_with_mssql_python_feature_flag_requires_optional_dependency( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: credentials.driver = None - credentials.use_mssql_python = True + credentials.backend = SQLServerBackend.mssql_python connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -183,7 +508,7 @@ def test_open_with_mssql_python_feature_flag_builds_connection_without_odbc_driv credentials.login_timeout = 17 credentials.query_timeout = 23 credentials.retries = 5 - credentials.use_mssql_python = True + credentials.backend = SQLServerBackend.mssql_python captured: Dict[str, Any] = {} pooling_calls: List[Dict[str, Any]] = [] @@ -244,14 +569,14 @@ def fake_retry_connection( assert captured["autocommit"] is True assert captured["timeout"] == 17 assert captured["retry_limit"] == 5 - assert pooling_calls == [{"enabled": False}] + assert pooling_calls == [] con_str = captured["connection_string"] assert "DRIVER=" not in con_str assert "SERVER=fake.sql.sqlserver.net,1433" in con_str assert "Database=dbt" in con_str - assert "UID={dbt_user}" in con_str - assert "PWD={super-secret}" in con_str + assert "UID=dbt_user" in con_str + assert "PWD=super-secret" in con_str assert "encrypt=Yes" in con_str assert "TrustServerCertificate=Yes" in con_str assert "APP=dbt-sqlserver/" not in con_str @@ -259,12 +584,15 @@ def fake_retry_connection( assert fake_module.OperationalError in captured["retryable_exceptions"] assert fake_module.InternalError in captured["retryable_exceptions"] + assert pooling_calls == [] + assert "APP=dbt-sqlserver/" not in con_str + def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_aliases( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: credentials.driver = None - credentials.use_mssql_python = True + credentials.backend = SQLServerBackend.mssql_python credentials.authentication = "cli" fake_module = SimpleNamespace( @@ -283,12 +611,146 @@ def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_al SQLServerConnectionManager.open(connection) +@pytest.mark.parametrize( + "unsupported_auth", + ["cli", "environment", "ActiveDirectoryAccessToken"], +) +def test_open_with_mssql_python_unsupported_authentications( + credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, + unsupported_auth: str, +) -> None: + credentials.driver = None + credentials.backend = SQLServerBackend.mssql_python + credentials.authentication = unsupported_auth + credentials.UID = "dbt_user" + credentials.PWD = "super-secret" + + fake_module = SimpleNamespace( + connect=lambda *args, **kwargs: None, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + + with pytest.raises(DbtRuntimeError, match="authentication"): + SQLServerConnectionManager.open(connection) + + +@pytest.mark.parametrize( + "authentication", + ["msi", "ActiveDirectoryMSI"], +) +def test_open_with_mssql_python_supported_managed_identity_auth( + credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, + authentication: str, +) -> None: + credentials.driver = None + credentials.backend = SQLServerBackend.mssql_python + credentials.authentication = authentication + credentials.UID = None + credentials.PWD = None + credentials.encrypt = True + credentials.trust_cert = True + + captured: Dict[str, Any] = {} + + class FakeHandle: + def __init__(self): + self.timeout = None + + def fake_connect(connection_string, autocommit, timeout): + captured["connection_string"] = connection_string + captured["autocommit"] = autocommit + captured["timeout"] = timeout + return FakeHandle() + + fake_module = SimpleNamespace( + connect=fake_connect, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + def fake_retry_connection( + cls, + connection, + connect, + logger, + retry_limit, + retryable_exceptions, + ): + handle = connect() + connection.handle = handle + connection.state = ConnectionState.OPEN + return connection + + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) + monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + monkeypatch.setattr( + SQLServerConnectionManager, + "retry_connection", + classmethod(fake_retry_connection), + ) + + connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) + opened = SQLServerConnectionManager.open(connection) + + assert opened is connection + assert opened.state == ConnectionState.OPEN + assert "Authentication=ActiveDirectoryMSI" in captured["connection_string"] + assert "UID=" not in captured["connection_string"] + assert "PWD=" not in captured["connection_string"] + + +@pytest.mark.parametrize( + "required_field, value, match_text", + [ + ("host", None, "host"), + ("database", None, "database"), + ("schema", None, "schema"), + ], +) +def test_open_requires_host_database_schema( + credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, + required_field: str, + value: object, + match_text: str, +) -> None: + setattr(credentials, required_field, value) + credentials.UID = "dbt_user" + credentials.PWD = "super-secret" + + fake_pyodbc = SimpleNamespace( + connect=lambda *args, **kwargs: None, + pooling=False, + InternalError=type("InternalError", (Exception,), {}), + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + ) + + monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) + monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + + connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) + + with pytest.raises(DbtRuntimeError, match=match_text): + SQLServerConnectionManager.open(connection) + + def test_open_with_pyodbc_path_still_requires_driver_when_feature_flag_disabled( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch, ) -> None: credentials.driver = None - credentials.use_mssql_python = False + credentials.backend = SQLServerBackend.pyodbc fake_pyodbc = SimpleNamespace( connect=lambda *args, **kwargs: None, diff --git a/uv.lock b/uv.lock index febf48202..0162e820e 100644 --- a/uv.lock +++ b/uv.lock @@ -577,6 +577,7 @@ dependencies = [ { name = "dbt-adapters" }, { name = "dbt-common" }, { name = "dbt-core" }, + { name = "pyodbc" }, ] [package.optional-dependencies] @@ -619,7 +620,8 @@ requires-dist = [ { name = "dbt-adapters", specifier = ">=1.15.2,<2.0" }, { name = "dbt-common", specifier = ">=1.22.0,<2.0" }, { name = "dbt-core", specifier = ">=1.10.0,<1.11.0" }, - { name = "mssql-python", marker = "extra == 'mssql'", specifier = ">=1.1.0" }, + { name = "mssql-python", marker = "extra == 'mssql'", specifier = ">=1.4.0" }, + { name = "pyodbc", specifier = ">=5.2.0" }, { name = "pyodbc", marker = "extra == 'pyodbc'", specifier = ">=5.2.0" }, ] provides-extras = ["azure", "pyodbc", "mssql"] @@ -633,7 +635,7 @@ dev = [ { name = "flaky" }, { name = "freezegun", specifier = ">=1.5.0,<2.0" }, { name = "ipdb" }, - { name = "mssql-python", specifier = ">=1.1.0" }, + { name = "mssql-python", specifier = ">=1.4.0" }, { name = "mypy", specifier = "==1.11.2" }, { name = "pre-commit" }, { name = "pyodbc", specifier = ">=5.2.0" }, From 6c8b606f576d6457ad3bd24365106f86e857e85a Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Tue, 26 May 2026 05:17:42 +0000 Subject: [PATCH 6/8] Fix pyodbc driver validation for blank values - Update `_validate_pyodbc_requirements()` to reject `None`, empty, and whitespace-only `driver` values - Add test covering `None`, `""`, and `" "` driver inputs --- .../sqlserver/sqlserver_connections.py | 2 +- .../test_sqlserver_connection_manager.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 7d365d103..9a628900a 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -299,7 +299,7 @@ def _validate_connection_requirements(credentials: SQLServerCredentials) -> None def _validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: - if not credentials.driver: + if credentials.driver is None or not credentials.driver.strip(): raise dbt_common.exceptions.DbtRuntimeError( "The pyodbc backend requires a SQL Server ODBC driver name " "in the `driver` profile field." diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index 7c95a4d01..be0c4163a 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -14,6 +14,7 @@ _build_mssql_python_connection_string, _normalize_mssql_python_authentication, _validate_mssql_python_requirements, + _validate_pyodbc_requirements, bool_to_connection_string_arg, get_pyodbc_attrs_before_credentials, ) @@ -68,6 +69,26 @@ def test_get_pyodbc_attrs_before_cli_auth_requires_azure_identity( get_pyodbc_attrs_before_credentials(credentials) +@pytest.mark.parametrize( + "driver", + [None, "", " "], +) +def test_validate_pyodbc_requirements_rejects_blank_driver( + driver: str | None, +) -> None: + credentials = SQLServerCredentials( + driver=driver, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + ) + + with pytest.raises( + DbtRuntimeError, match="The pyodbc backend requires a SQL Server ODBC driver name" + ): + _validate_pyodbc_requirements(credentials) + + @pytest.mark.parametrize( "key, value, expected", [("somekey", False, "somekey=No"), ("somekey", True, "somekey=Yes")], From e35e81299312c55733ea1a83efe89cf59c09519b Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Tue, 26 May 2026 20:19:39 +0000 Subject: [PATCH 7/8] =?UTF-8?q?=E2=9C=A8=20feat(sqlserver):=20enhance=20SQ?= =?UTF-8?q?L=20Server=20connection=20manager=20with=20new=20authentication?= =?UTF-8?q?=20and=20connection=20string=20features?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Added support for Active Directory access token authentication. - Improved handling of connection string sanitization for logging. - Introduced validation for connection requirements and query timeout. - Enhanced tests for various authentication methods and connection string formats. - Refactored connection handling to support new backend options. --- CONTRIBUTING.md | 2 +- README.md | 4 +- dbt/adapters/sqlserver/sqlserver_auth.py | 323 ++++++ dbt/adapters/sqlserver/sqlserver_backend.py | 342 ++++++ .../sqlserver/sqlserver_connections.py | 985 ++-------------- dbt/adapters/sqlserver/sqlserver_constants.py | 114 ++ .../sqlserver/sqlserver_credentials.py | 57 +- dbt/adapters/sqlserver/sqlserver_helpers.py | 291 +++++ dbt/adapters/sqlserver/sqlserver_runtime.py | 450 ++++++++ tests/__init__.py | 11 +- .../adapters/mssql/test_connection_logic.py | 128 +-- .../test_sqlserver_connection_manager.py | 1019 ++++++++++++++--- uv.lock | 8 +- 13 files changed, 2588 insertions(+), 1146 deletions(-) create mode 100644 dbt/adapters/sqlserver/sqlserver_auth.py create mode 100644 dbt/adapters/sqlserver/sqlserver_backend.py create mode 100644 dbt/adapters/sqlserver/sqlserver_constants.py create mode 100644 dbt/adapters/sqlserver/sqlserver_helpers.py create mode 100644 dbt/adapters/sqlserver/sqlserver_runtime.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index be99accf4..e7da4c9eb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ This will use Docker Compose to spin up a local instance of SQL Server. Docker C Next, tell our tests how they should connect to the local instance by creating a file called `test.env` in the root of the project. You can use the provided `test.env.sample` as a base and if you started the server with `make server`, then this matches the instance running on your local machine. -If you are testing the optional `mssql-python` backend, also enable its profile setting in `test.env` so the adapter selects that implementation instead of the legacy driver-based one. +If you are testing the optional `mssql-python` backend, also enable its backend setting in `test.env` so the adapter selects that implementation instead of the legacy driver-based one. ```shell cp test.env.sample test.env diff --git a/README.md b/README.md index 4c6429595..227082a23 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ The legacy and currently default ODBC path uses `pyodbc` and the Microsoft ODBC pip install -U dbt-sqlserver ``` -You should migrate to use an explicit extra for incoming deprecation, the following is equivalent: +You should migrate to using an explicit extra in preparation for deprecation; the following is equivalent: ```shell pip install -U "dbt-sqlserver[pyodbc]" @@ -137,7 +137,7 @@ vars: ### `backend` -*(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required driver is not installed. +*(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed. ## Contributing diff --git a/dbt/adapters/sqlserver/sqlserver_auth.py b/dbt/adapters/sqlserver/sqlserver_auth.py new file mode 100644 index 000000000..3a7993fb9 --- /dev/null +++ b/dbt/adapters/sqlserver/sqlserver_auth.py @@ -0,0 +1,323 @@ +"""Authentication and token helpers for the SQL Server adapter. + +This module owns the shared normalization rules for auth labels, plus the +pyodbc-facing Azure token helpers used by the connection manager. +""" + +from __future__ import annotations + +import struct +import time +from itertools import chain, repeat +from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, cast + +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.sqlserver.sqlserver_constants import ( + AAD_TOKEN_AUTHENTICATIONS, + CONNECTION_AUTH_ALIASES, + CONNECTION_AUTH_PASSTHROUGH_KEYS, + PYODBC_AUTH_ALIASES, + SQLSERVER_BACKEND_MSSQL_PYTHON, +) +from dbt.adapters.sqlserver.sqlserver_runtime import ( + AZURE_CREDENTIAL_SCOPE, + AccessTokenProtocol, + _get_azure_access_token_class, + _get_azure_identity_module, + _get_cached_access_token, +) + +if TYPE_CHECKING: + from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerBackend, SQLServerCredentials + + +logger = AdapterLogger("sqlserver") +AZURE_AUTH_FUNCTION_TYPE = Callable[[Any, Optional[str]], AccessTokenProtocol] + + +def is_mssql_python_backend(backend: "SQLServerBackend") -> bool: + """Return whether the coerced backend enum targets ``mssql-python``.""" + + return backend.value == SQLSERVER_BACKEND_MSSQL_PYTHON + + +def normalize_authentication_key(value: Optional[str]) -> str: + """Normalize a SQL Server auth or lookup key for cross-layer comparisons.""" + + return "" if value is None else value.replace("_", "").replace(" ", "").lower() + + +def is_active_directory_authentication(authentication: Optional[str]) -> bool: + """Return whether an auth label targets one of the ActiveDirectory modes.""" + + return normalize_authentication_key(authentication).startswith("activedirectory") + + +def normalize_mssql_python_authentication( + authentication: Optional[str], +) -> Optional[str]: + """Backend-layer auth normalization used while building connection strings.""" + + authentication = authentication or "" + key = normalize_authentication_key(authentication) + if not key: + return None + + if key in CONNECTION_AUTH_PASSTHROUGH_KEYS: + return authentication.strip() + + if key in CONNECTION_AUTH_ALIASES: + return CONNECTION_AUTH_ALIASES[key] + + return authentication.strip() + + +def normalize_pyodbc_authentication(authentication: Optional[str]) -> str: + """Normalize auth labels for the pyodbc token path. + + Only the token-oriented aliases that participate in cached access-token + retrieval are normalized here. Connection-string auth aliases such as + ``ActiveDirectoryServicePrincipal`` are handled by the backend builders. + """ + + if key := normalize_authentication_key(authentication): + return PYODBC_AUTH_ALIASES.get(key, key) + return "" + + +def normalize_connection_authentication( + authentication: Optional[str], mssql_python_backend: bool +) -> str: + """Normalize auth labels for connection-string generation. + + Call this from connection-string builders and validation, not from profile + parsing. The ``mssql-python`` path canonicalizes long-form connection + strings, while the pyodbc path preserves its raw token-auth labels so + ``get_pyodbc_attrs_before_credentials`` can apply its narrower alias map. + """ + + authentication = authentication or "" + if mssql_python_backend: + return normalize_mssql_python_authentication(authentication) or "" + return authentication.strip() + + +def uses_aad_token_authentication(credentials: "SQLServerCredentials") -> bool: + """Return whether pyodbc should request and cache an Azure access token. + + This is used by retry policy as well as token fetching, so manual + ``ActiveDirectoryAccessToken`` profiles stay in the same retry bucket as + the other AAD token modes. + """ + + authentication = normalize_pyodbc_authentication(credentials.authentication) + return authentication in AAD_TOKEN_AUTHENTICATIONS + + +def get_environment_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessTokenProtocol: + """ + Get an Azure access token by reading environment variables + + Parameters + ----------- + credentials: SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + azure_identity = _get_azure_identity_module() + return azure_identity.EnvironmentCredential().get_token( + scope, timeout=credentials.login_timeout + ) + + +def get_msi_access_token( + _credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessTokenProtocol: + """ + Get an Azure access token from the system's managed identity + + Parameters + ----------- + credentials: SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + azure_identity = _get_azure_identity_module() + return azure_identity.ManagedIdentityCredential().get_token(scope or AZURE_CREDENTIAL_SCOPE) + + +def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: + """ + Convert bytes to a Microsoft windows byte string. + + Parameters + ---------- + value : bytes + The bytes. + + Returns + ------- + out : bytes + The Microsoft byte string. + """ + encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0)))) + return struct.pack(" bytes: + """ + Convert an access token to a Microsoft windows byte string. + + Parameters + ---------- + token : AccessTokenProtocol + The token. + + Returns + ------- + out : bytes + The Microsoft byte string. + """ + value = bytes(token.token, "UTF-8") + return convert_bytes_to_mswindows_byte_string(value) + + +def get_cli_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessTokenProtocol: + """ + Get an Azure access token using the CLI credentials + + First login with: + + ```bash + az login + ``` + + Parameters + ---------- + credentials: SQLServerCredentials + The credentials. + + Returns + ------- + out : AccessToken + Access token. + """ + azure_identity = _get_azure_identity_module() + return azure_identity.AzureCliCredential().get_token(scope, timeout=credentials.login_timeout) + + +def get_auto_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessTokenProtocol: + """ + Get an Azure access token automatically through azure-identity + + Parameters + ----------- + credentials: SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + azure_identity = _get_azure_identity_module() + return azure_identity.DefaultAzureCredential().get_token( + scope, timeout=credentials.login_timeout + ) + + +def get_sp_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessTokenProtocol: + """ + Get an Azure access token using the SP credentials. + + Parameters + ---------- + credentials : SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + azure_identity = _get_azure_identity_module() + return azure_identity.ClientSecretCredential( + str(credentials.tenant_id), + str(credentials.client_id), + str(credentials.client_secret), + ).get_token(scope or AZURE_CREDENTIAL_SCOPE) + + +AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = { + "cli": get_cli_access_token, + "auto": get_auto_access_token, + "environment": get_environment_access_token, + "serviceprincipal": get_sp_access_token, + "msi": get_msi_access_token, +} + + +def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Dict: + """Build the pyodbc authentication attrs used by the connection manager.""" + + sql_copt_ss_access_token = 1256 # ODBC constant for access token + + authentication = normalize_pyodbc_authentication(credentials.authentication) + + if authentication in AZURE_AUTH_FUNCTIONS: + token = _get_cached_access_token( + credentials, + authentication, + AZURE_CREDENTIAL_SCOPE, + lambda: AZURE_AUTH_FUNCTIONS[authentication](credentials, AZURE_CREDENTIAL_SCOPE), + ) + token_bytes = convert_access_token_to_mswindows_byte_string(token) + return {sql_copt_ss_access_token: token_bytes} + + if authentication == "activedirectoryaccesstoken": + if credentials.access_token is None or credentials.access_token_expires_on is None: + raise ValueError( + ( + "Access token and a non-zero access token expiry epoch timestamp are " + "required for ActiveDirectoryAccessToken authentication." + ) + ) + + if credentials.access_token_expires_on == 0: + logger.warning( + "ActiveDirectoryAccessToken expiry is 0; defaulting expiry to 75 minutes. " + "Set access_token_expires_on explicitly to remove this message." + ) + + access_token = cast( + AccessTokenProtocol, + _get_azure_access_token_class()( + token=credentials.access_token, + expires_on=int( + time.time() + 4500.0 + if credentials.access_token_expires_on == 0 + else credentials.access_token_expires_on + ), + ), + ) + return { + sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(access_token) + } + + return {} diff --git a/dbt/adapters/sqlserver/sqlserver_backend.py b/dbt/adapters/sqlserver/sqlserver_backend.py new file mode 100644 index 000000000..3b3a4486d --- /dev/null +++ b/dbt/adapters/sqlserver/sqlserver_backend.py @@ -0,0 +1,342 @@ +"""Backend-policy helpers for the SQL Server adapter. + +This module owns the backend-specific connection-string assembly and the +shared retry / error handling policy. Mutable lazy-import/runtime cache state +lives in ``sqlserver_runtime.py`` and is orchestrated by +``sqlserver_connections.py``. +""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any, Callable, Tuple + +import dbt_common.exceptions + +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.sqlserver import __version__ +from dbt.adapters.sqlserver.sqlserver_auth import ( + get_pyodbc_attrs_before_credentials, + is_active_directory_authentication, + normalize_connection_authentication, + uses_aad_token_authentication, +) +from dbt.adapters.sqlserver.sqlserver_constants import ( + MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN, + MSSQL_AUTH_ACTIVE_DIRECTORY_INTEGRATED, + MSSQL_AUTH_ACTIVE_DIRECTORY_INTERACTIVE, + MSSQL_AUTH_ACTIVE_DIRECTORY_MSI, + MSSQL_AUTH_ACTIVE_DIRECTORY_PASSWORD, + MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL, +) +from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_helpers import ( + _set_query_timeout_if_supported, + bool_to_connection_string_arg, + build_server_arg, + format_connection_string_value, + format_pyodbc_driver_value, + sanitize_connection_string_for_logging, +) +from dbt.adapters.sqlserver.sqlserver_runtime import ( + _RUNTIME_STATE, + MssqlPythonModuleProtocol, + PyodbcModuleProtocol, +) + +logger = AdapterLogger("sqlserver") + + +def build_common_connection_string_parts( + credentials: SQLServerCredentials, + mssql_python_backend: bool, +) -> list[str]: + """Build validated connection-string parts shared by both backends. + + Call this only after shared/backend-specific profile validation has run. + `credentials.authentication` is canonicalized here so the backend branches + below can compare one normalized auth label per mode. + """ + + con_str = [f"SERVER={build_server_arg(credentials)}"] + con_str.append(f"Database={credentials.database}") + + authentication = normalize_connection_authentication( + credentials.authentication, + mssql_python_backend, + ) + + if is_active_directory_authentication(authentication) and ( + authentication != MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN + ): + con_str.append(f"Authentication={authentication}") + + if authentication == MSSQL_AUTH_ACTIVE_DIRECTORY_PASSWORD: + con_str.append( + f"UID={format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + con_str.append( + f"PWD={format_connection_string_value(credentials.PWD, mssql_python_backend)}" + ) + elif authentication == MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL: + con_str.append( + "UID=" + + format_connection_string_value( + credentials.client_id, + mssql_python_backend, + ) + ) + con_str.append( + "PWD=" + + format_connection_string_value( + credentials.client_secret, + mssql_python_backend, + ) + ) + elif authentication == MSSQL_AUTH_ACTIVE_DIRECTORY_INTERACTIVE: + con_str.append( + "UID=%s" + % format_connection_string_value( + credentials.UID, + mssql_python_backend, + ) + ) + elif authentication == MSSQL_AUTH_ACTIVE_DIRECTORY_MSI: + if credentials.PWD: + raise dbt_common.exceptions.DbtRuntimeError( + "password is not valid with ActiveDirectoryMSI for the mssql-python backend." + ) + if credentials.UID: + con_str.append( + f"UID={format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + elif authentication == MSSQL_AUTH_ACTIVE_DIRECTORY_INTEGRATED: + if credentials.PWD: + raise dbt_common.exceptions.DbtRuntimeError( + "password is not valid with ActiveDirectoryIntegrated" + " for the mssql-python backend." + ) + + elif credentials.windows_login: + if mssql_python_backend and (credentials.UID or credentials.PWD): + raise dbt_common.exceptions.DbtRuntimeError( + "user/password are not valid with windows_login/trusted_connection " + "for the mssql-python backend." + ) + con_str.append("Trusted_Connection=yes") + elif authentication == "sql": + con_str.append( + f"UID={format_connection_string_value(credentials.UID, mssql_python_backend)}" + ) + con_str.append( + f"PWD={format_connection_string_value(credentials.PWD, mssql_python_backend)}" + ) + + con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) + con_str.append(bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert)) + + if not mssql_python_backend: + application_name = f"dbt-{credentials.type}/{__version__.version}" + con_str.append(f"APP={application_name}") + + return con_str + + +def build_pyodbc_connection_string(credentials: SQLServerCredentials) -> str: + """Build the full pyodbc connection string used by the connection manager. + + Invariants: + - `driver` must be specified and formatted properly (for example, enclosed + in braces if not already). + - Encrypted parameters and other connection attributes default to + standard values suitable for pyodbc. + + Integration: + Called by `SQLServerConnectionManager.open()` when the backend type is + configured as `pyodbc`. + """ + + con_str = [f"DRIVER={format_pyodbc_driver_value(credentials.driver)}"] + con_str.extend(build_common_connection_string_parts(credentials, mssql_python_backend=False)) + con_str.extend( + [ + "Pooling=true", + ( + "SQL_ATTR_TRACE=SQL_OPT_TRACE_ON" + if credentials.trace_flag + else "SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF" + ), + "ConnectRetryCount=3", + "ConnectRetryInterval=10", + ] + ) + + return ";".join(con_str) + + +def build_mssql_python_connection_string(credentials: SQLServerCredentials) -> str: + """Build the full mssql-python connection string used by the connection manager. + + Expected Inputs: + credentials: An instance of SQLServerCredentials containing validated + host, database, and auth details. + + Invariants: + - Must not contain `DRIVER` or ODBC-specific tags. + - Connection parameters are escaped specifically for the + mssql-python parser backend. + + Integration: + Called by `SQLServerConnectionManager.open()` when the backend type is + configured as `mssql-python`. + """ + + con_str = build_common_connection_string_parts(credentials, mssql_python_backend=True) + return ";".join(con_str) + + +def get_pyodbc_retryable_exceptions( + credentials: SQLServerCredentials, + pyodbc: PyodbcModuleProtocol, +) -> Tuple[type[Exception], ...]: + """Return the pyodbc exception types that the connection manager may retry.""" + + retryable_exceptions: list[type[Exception]] = [ + pyodbc.InternalError, + pyodbc.OperationalError, + ] + + if uses_aad_token_authentication(credentials): + retryable_exceptions.append(pyodbc.InterfaceError) + + return tuple(retryable_exceptions) + + +def get_mssql_python_retryable_exceptions( + credentials: SQLServerCredentials, + mssql_python: MssqlPythonModuleProtocol, +) -> Tuple[type[Exception], ...]: + """Return the mssql-python exception types that the connection manager may retry.""" + + retryable_exceptions: list[type[Exception]] = [ + mssql_python.InternalError, + mssql_python.OperationalError, + ] + + if uses_aad_token_authentication(credentials): + retryable_exceptions.append(mssql_python.InterfaceError) + + return tuple(retryable_exceptions) + + +def handle_backend_database_error( + error: Exception, + database_error: type[Exception] | None, + release_connection: Callable[[], None], +) -> None: + """Translate backend database exceptions into dbt runtime errors. + + Call this only after the caller has identified the backend-specific error + type; non-database errors should bypass this helper. + """ + + if database_error is None or not isinstance(error, database_error): + return + + logger.debug(f"Database error: {error}") + + with suppress(Exception): + release_connection() + + raise dbt_common.exceptions.DbtDatabaseError(str(error).strip()) from error + + +def log_connection_string(connection_string: str) -> None: + """Log a sanitized connection string for the current backend.""" + + sanitized_connection_string = sanitize_connection_string_for_logging(connection_string) + logger.debug(f"Using connection string: {sanitized_connection_string}") + + +def is_pyodbc_handle(handle: Any) -> bool: + """Detect a pyodbc handle without importing pyodbc from the caller.""" + + handle_type = type(handle) + module_name = getattr(handle_type, "__module__", "") or "" + class_name = getattr(handle_type, "__name__", "") or "" + + if "pyodbc" in module_name or "pyodbc" in class_name: + return True + + if "unittest.mock" in module_name or "mock" in class_name.lower(): + return hasattr(handle, "add_output_converter") + + return False + + +def _log_connected_database(credentials: SQLServerCredentials) -> None: + logger.debug(f"Connected to db: {credentials.database}") + + +def _finalize_connection_handle( + handle: Any, + credentials: SQLServerCredentials, +) -> Any: + """Apply conservative shared connection-handle configuration.""" + + _set_query_timeout_if_supported(handle, credentials.query_timeout) + _log_connected_database(credentials) + return handle + + +def _finalize_mssql_python_handle( + handle: Any, + credentials: SQLServerCredentials, +) -> Any: + """Apply mssql-python-specific post-connect policy.""" + + timeout_supported = _set_query_timeout_if_supported(handle, credentials.query_timeout) + if ( + not timeout_supported + and credentials.query_timeout not in (None, 0) + and _RUNTIME_STATE.take_timeout_warning() + ): + logger.warning( + "Configured query_timeout=%r, but the mssql-python backend does not " + "support per-connection query timeouts; the setting will be ignored.", + credentials.query_timeout, + ) + + _log_connected_database(credentials) + return handle + + +def _connect_mssql_python( + mssql_python: MssqlPythonModuleProtocol, + credentials: SQLServerCredentials, + connection_string: str, +) -> Any: + mssql_python.pooling(enabled=True) + handle = mssql_python.connect( + connection_string, + autocommit=True, + timeout=credentials.login_timeout, + ) + return _finalize_mssql_python_handle(handle, credentials) + + +def _connect_pyodbc( + pyodbc: PyodbcModuleProtocol, + credentials: SQLServerCredentials, + connection_string: str, +) -> Any: + pyodbc.pooling = True + attrs_before = get_pyodbc_attrs_before_credentials(credentials) + + handle = pyodbc.connect( + connection_string, + attrs_before=attrs_before, + autocommit=True, + timeout=credentials.login_timeout, + ) + return _finalize_connection_handle(handle, credentials) diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 9a628900a..aa0beede0 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -1,10 +1,13 @@ import datetime as dt -import struct import time from contextlib import contextmanager -from dataclasses import dataclass -from itertools import chain, repeat -from typing import Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, Type, Union, cast +from typing import ( + Any, + Optional, + Tuple, + Type, + Union, +) import agate # type: ignore[import] import dbt_common.exceptions @@ -13,860 +16,87 @@ from dbt_common.events.functions import fire_event from dbt_common.utils.casting import cast_to_str -from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState +from dbt.adapters.contracts.connection import ( + AdapterResponse, + Connection, + ConnectionState, +) from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus +from dbt.adapters.events.types import ( + AdapterEventDebug, + ConnectionUsed, + SQLQuery, + SQLQueryStatus, +) from dbt.adapters.sql.connections import SQLConnectionManager -from dbt.adapters.sqlserver import __version__ -from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerBackend, SQLServerCredentials - - -class PyodbcModuleProtocol(Protocol): - InternalError: type[Exception] - OperationalError: type[Exception] - InterfaceError: type[Exception] - DatabaseError: type[Exception] - pooling: bool - - def connect(self, *args: Any, **kwargs: Any) -> Any: ... - - -class MssqlPythonModuleProtocol(Protocol): - InternalError: type[Exception] - OperationalError: type[Exception] - InterfaceError: type[Exception] - DatabaseError: type[Exception] - - def connect(self, *args: Any, **kwargs: Any) -> Any: ... - - -class AccessTokenProtocol(Protocol): - token: str - expires_on: int - - -class TokenCredentialProtocol(Protocol): - def get_token(self, *scopes: Optional[str], **kwargs: Any) -> AccessTokenProtocol: ... - - -class CredentialFactory(Protocol): - def __call__(self, *args: Any, **kwargs: Any) -> TokenCredentialProtocol: ... - - -class AzureIdentityModuleProtocol(Protocol): - AzureCliCredential: CredentialFactory - DefaultAzureCredential: CredentialFactory - EnvironmentCredential: CredentialFactory - ManagedIdentityCredential: CredentialFactory - ClientSecretCredential: CredentialFactory - - -class AzureCredentialsModuleProtocol(Protocol): - AccessToken: Type[AccessTokenProtocol] - - -_PYODBC_MODULE: Optional[PyodbcModuleProtocol] = None -_PYODBC_IMPORT_ERROR: Optional[ModuleNotFoundError] = None - -_MSSQL_PYTHON_MODULE: Optional[MssqlPythonModuleProtocol] = None -_MSSQL_PYTHON_IMPORT_ERROR: Optional[ModuleNotFoundError] = None - -_AZURE_CREDENTIALS_MODULE: Optional[AzureCredentialsModuleProtocol] = None -_AZURE_CREDENTIALS_IMPORT_ERROR: Optional[ModuleNotFoundError] = None - -_AZURE_IDENTITY_MODULE: Optional[AzureIdentityModuleProtocol] = None -_AZURE_IDENTITY_IMPORT_ERROR: Optional[ModuleNotFoundError] = None - - -@dataclass -class AccessToken: # type: ignore[no-redef] - token: str - expires_on: int - - -def _get_azure_access_token_class() -> Type[Any]: - global _AZURE_CREDENTIALS_MODULE, _AZURE_CREDENTIALS_IMPORT_ERROR - - if _AZURE_CREDENTIALS_MODULE is not None: - return _AZURE_CREDENTIALS_MODULE.AccessToken - - if _AZURE_CREDENTIALS_IMPORT_ERROR is not None: - return AccessToken - - try: - import azure.core.credentials as azure_credentials # type: ignore[import] - except ModuleNotFoundError as exc: - _AZURE_CREDENTIALS_IMPORT_ERROR = exc - return AccessToken - - _AZURE_CREDENTIALS_MODULE = cast(AzureCredentialsModuleProtocol, azure_credentials) - return azure_credentials.AccessToken - - -def _get_azure_identity_module() -> AzureIdentityModuleProtocol: - global _AZURE_IDENTITY_MODULE, _AZURE_IDENTITY_IMPORT_ERROR - - if _AZURE_IDENTITY_MODULE is not None: - return _AZURE_IDENTITY_MODULE - - if _AZURE_IDENTITY_IMPORT_ERROR is not None: - raise _missing_azure_identity_error() from _AZURE_IDENTITY_IMPORT_ERROR - - try: - import azure.identity as azure_identity # type: ignore[import] - except ModuleNotFoundError as exc: - _AZURE_IDENTITY_IMPORT_ERROR = exc - raise _missing_azure_identity_error() from exc - - _AZURE_IDENTITY_MODULE = cast(AzureIdentityModuleProtocol, azure_identity) - return _AZURE_IDENTITY_MODULE - - -_TOKEN: Optional[AccessTokenProtocol] = None -AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" -AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials, Optional[str]], AccessTokenProtocol] +from dbt.adapters.sqlserver.sqlserver_auth import ( + is_mssql_python_backend, +) +from dbt.adapters.sqlserver.sqlserver_backend import ( + _connect_mssql_python, + _connect_pyodbc, + build_mssql_python_connection_string, + build_pyodbc_connection_string, + get_mssql_python_retryable_exceptions, + get_pyodbc_retryable_exceptions, + handle_backend_database_error, + is_pyodbc_handle, + log_connection_string, +) +from dbt.adapters.sqlserver.sqlserver_constants import datatypes +from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_helpers import ( + byte_array_to_datetime, + validate_connection_requirements, + validate_mssql_python_requirements, + validate_pyodbc_requirements, +) +from dbt.adapters.sqlserver.sqlserver_runtime import ( + _RUNTIME_STATE, + _get_mssql_python, + _get_pyodbc, +) logger = AdapterLogger("sqlserver") -# https://github.com/mkleehammer/pyodbc/wiki/Data-Types -datatypes = { - "str": "varchar", - "uuid.UUID": "uniqueidentifier", - "uuid": "uniqueidentifier", - "float": "bigint", - "int": "int", - "bytes": "varbinary", - "bytearray": "varbinary", - "bool": "bit", - "datetime.date": "date", - "datetime.datetime": "datetime2(6)", - "datetime.time": "time", - "decimal.Decimal": "decimal", -} - -MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS = { - "cli", - "environment", - "activedirectoryaccesstoken", -} - - -def _auth_key(authentication: Optional[str]) -> str: - if authentication is None: - return "" - return authentication.replace("_", "").replace(" ", "").lower() - - -def _normalize_mssql_python_authentication(authentication: Optional[str]) -> Optional[str]: - authentication = authentication or "" - key = _auth_key(authentication) - if not key: - return None - - if key in {"msi", "activedirectorymsi"}: - return "ActiveDirectoryMSI" - - if key in {"activedirectoryintegrated", "adintegrated"}: - return "ActiveDirectoryIntegrated" - - if key in {"serviceprincipal", "activedirectoryserviceprincipal"}: - return "ActiveDirectoryServicePrincipal" - - if key in {"auto", "default", "activedirectorydefault"}: - return "ActiveDirectoryDefault" - - if key == "activedirectorypassword": - return "ActiveDirectoryPassword" - - if key == "activedirectoryinteractive": - return "ActiveDirectoryInteractive" - - if key == "activedirectorydevicecode": - return "ActiveDirectoryDeviceCode" - - return authentication.strip() - - -def _missing_pyodbc_error() -> dbt_common.exceptions.DbtRuntimeError: - return dbt_common.exceptions.DbtRuntimeError( - "The legacy `pyodbc` backend was requested, but the optional dependency " - "`pyodbc` is not installed. Install it with `pip install pyodbc` " - "or set `backend: mssql-python` in the profile." - ) - - -def _get_pyodbc() -> PyodbcModuleProtocol: - global _PYODBC_MODULE, _PYODBC_IMPORT_ERROR - - if _PYODBC_MODULE is not None: - return _PYODBC_MODULE - - if _PYODBC_IMPORT_ERROR is not None: - raise _missing_pyodbc_error() from _PYODBC_IMPORT_ERROR - - try: - import pyodbc as imported_pyodbc # type: ignore[import] - except ModuleNotFoundError as exc: - _PYODBC_IMPORT_ERROR = exc - raise _missing_pyodbc_error() from exc - - _PYODBC_MODULE = cast(PyodbcModuleProtocol, imported_pyodbc) - return _PYODBC_MODULE - - -def _missing_mssql_python_error() -> dbt_common.exceptions.DbtRuntimeError: - return dbt_common.exceptions.DbtRuntimeError( - "The `mssql-python` backend was requested, but the optional dependency " - "`mssql-python` is not installed. Install it with `pip install mssql-python` " - "or set `backend: pyodbc` in the profile." - ) - - -def _missing_azure_identity_error() -> dbt_common.exceptions.DbtRuntimeError: - return dbt_common.exceptions.DbtRuntimeError( - "Azure authentication requires the optional dependency 'azure-identity'. " - "Install it with `pip install azure-identity` or use a non-Azure " - "authentication mode." - ) - - -def _get_mssql_python() -> MssqlPythonModuleProtocol: - global _MSSQL_PYTHON_MODULE, _MSSQL_PYTHON_IMPORT_ERROR - - if _MSSQL_PYTHON_MODULE is not None: - return _MSSQL_PYTHON_MODULE - - if _MSSQL_PYTHON_IMPORT_ERROR is not None: - raise _missing_mssql_python_error() from _MSSQL_PYTHON_IMPORT_ERROR - - try: - import mssql_python as imported_mssql_python # type: ignore[import] - except ModuleNotFoundError as exc: - _MSSQL_PYTHON_IMPORT_ERROR = exc - raise _missing_mssql_python_error() from exc - - _MSSQL_PYTHON_MODULE = cast(MssqlPythonModuleProtocol, imported_mssql_python) - return _MSSQL_PYTHON_MODULE - - -def _normalize_authentication(authentication: Optional[str]) -> str: - if authentication is None: - return "sql" - - normalized = authentication.strip().lower() - if normalized == "activedirectorymsi": - return "msi" - return normalized - - -def _uses_pyodbc_token_authentication(credentials: SQLServerCredentials) -> bool: - authentication = _normalize_authentication(credentials.authentication) - return authentication in AZURE_AUTH_FUNCTIONS or authentication == "activedirectoryaccesstoken" - - -def _is_mssql_python_backend(credentials: SQLServerCredentials) -> bool: - return credentials.backend == SQLServerBackend.mssql_python - - -def _validate_connection_requirements(credentials: SQLServerCredentials) -> None: - for name in ("host", "database", "schema"): - value = getattr(credentials, name) - if value is None or not str(value).strip(): - raise dbt_common.exceptions.DbtRuntimeError( - f"The `{name}` profile field is required for SQL Server connections." - ) - - if credentials.windows_login: - normalized = _normalize_mssql_python_authentication(credentials.authentication) - if normalized is not None and _auth_key(normalized).startswith("activedirectory"): - raise dbt_common.exceptions.DbtRuntimeError( - "windows_login/trusted_connection cannot be combined with ActiveDirectory " - "authentication. Remove `authentication` or disable `windows_login`." - ) - elif credentials.authentication is None or not str(credentials.authentication).strip(): - raise dbt_common.exceptions.DbtRuntimeError( - "The `authentication` profile field is required for SQL Server connections." - ) - - if credentials.encrypt is None: - raise dbt_common.exceptions.DbtRuntimeError( - "The `encrypt` profile field is required for SQL Server connections." - ) - if credentials.trust_cert is None: - raise dbt_common.exceptions.DbtRuntimeError( - "The `trust_cert` profile field is required for SQL Server connections." - ) - - -def _validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: - if credentials.driver is None or not credentials.driver.strip(): - raise dbt_common.exceptions.DbtRuntimeError( - "The pyodbc backend requires a SQL Server ODBC driver name " - "in the `driver` profile field." - ) - - -def _validate_mssql_python_requirements(credentials: SQLServerCredentials) -> None: - authentication = _normalize_mssql_python_authentication(credentials.authentication) - authentication_key = _auth_key(authentication) - - if authentication_key in MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS: - raise dbt_common.exceptions.DbtRuntimeError( - "Authentication '{}' is currently only supported by the pyodbc backend " - "in this adapter. " - "Use `backend: pyodbc` or use a connection-string-supported " - "authentication mode such as " - "`sql`, `ActiveDirectoryPassword`, `ActiveDirectoryInteractive`, " - "`ActiveDirectoryIntegrated`, " - "`ActiveDirectoryMSI`, `ActiveDirectoryDeviceCode`, " - "or `ActiveDirectoryDefault`.".format(authentication) - ) - - -def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: - """ - Convert bytes to a Microsoft windows byte string. - - Parameters - ---------- - value : bytes - The bytes. - - Returns - ------- - out : bytes - The Microsoft byte string. - """ - encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0)))) - return struct.pack(" bytes: - """ - Convert an access token to a Microsoft windows byte string. - - Parameters - ---------- - token : AccessTokenProtocol - The token. - - Returns - ------- - out : bytes - The Microsoft byte string. - """ - value = bytes(token.token, "UTF-8") - return convert_bytes_to_mswindows_byte_string(value) - - -def get_cli_access_token( - credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessTokenProtocol: - """ - Get an Azure access token using the CLI credentials - - First login with: - - ```bash - az login - ``` - - Parameters - ---------- - credentials: SQLServerCredentials - The credentials. - - Returns - ------- - out : AccessToken - Access token. - """ - _ = credentials - azure_identity = _get_azure_identity_module() - token = azure_identity.AzureCliCredential().get_token( - scope, timeout=getattr(credentials, "login_timeout", None) - ) - return token - - -def get_auto_access_token( - credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessTokenProtocol: - """ - Get an Azure access token automatically through azure-identity - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - azure_identity = _get_azure_identity_module() - token = azure_identity.DefaultAzureCredential().get_token( - scope, timeout=getattr(credentials, "login_timeout", None) - ) - return token - - -def get_environment_access_token( - credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessTokenProtocol: - """ - Get an Azure access token by reading environment variables - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - azure_identity = _get_azure_identity_module() - token = azure_identity.EnvironmentCredential().get_token( - scope, timeout=getattr(credentials, "login_timeout", None) - ) - return token - - -def get_msi_access_token( - credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessTokenProtocol: - """ - Get an Azure access token from the system's managed identity - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - _ = credentials - azure_identity = _get_azure_identity_module() - token = azure_identity.ManagedIdentityCredential().get_token(scope) - return token - - -def get_sp_access_token( - credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE -) -> AccessTokenProtocol: - """ - Get an Azure access token using the SP credentials. - - Parameters - ---------- - credentials : SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - _ = scope - azure_identity = _get_azure_identity_module() - token = azure_identity.ClientSecretCredential( - str(credentials.tenant_id), - str(credentials.client_id), - str(credentials.client_secret), - ).get_token(AZURE_CREDENTIAL_SCOPE) - return token - - -AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = { - "cli": get_cli_access_token, - "auto": get_auto_access_token, - "environment": get_environment_access_token, - "serviceprincipal": get_sp_access_token, - "msi": get_msi_access_token, -} - - -def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Dict: - """ - Get the pyodbc attributes for authentication. - - Parameters - ---------- - credentials : SQLServerCredentials - Credentials. - - Returns - ------- - Dict - The pyodbc attributes for authentication. - """ - global _TOKEN - sql_copt_ss_access_token = 1256 # ODBC constant for access token - MAX_REMAINING_TIME = 300 - - authentication = _normalize_authentication(credentials.authentication) - - if authentication in AZURE_AUTH_FUNCTIONS: - if not _TOKEN or (_TOKEN.expires_on - time.time() < MAX_REMAINING_TIME): - _TOKEN = AZURE_AUTH_FUNCTIONS[authentication](credentials, AZURE_CREDENTIAL_SCOPE) - assert _TOKEN is not None - token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN) - return {sql_copt_ss_access_token: token_bytes} - - if authentication == "activedirectoryaccesstoken": - if credentials.access_token is None or credentials.access_token_expires_on is None: - raise ValueError( - ( - "Access token and access token expiry are " - "required for ActiveDirectoryAccessToken authentication." - ) - ) - _TOKEN = _get_azure_access_token_class()( - token=credentials.access_token, - expires_on=int( - time.time() + 4500.0 - if credentials.access_token_expires_on == 0 - else credentials.access_token_expires_on - ), - ) - assert _TOKEN is not None - return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)} - - return {} - - -def bool_to_connection_string_arg(key: str, value: bool) -> str: - """ - Convert a boolean to a connection string argument. - - Parameters - ---------- - key : str - The key to use in the connection string. - value : bool - The boolean to convert. - - Returns - ------- - out : str - The connection string argument. - """ - return f"{key}={'Yes' if value else 'No'}" - - -def _escape_connection_string_value(value: Optional[str]) -> str: - text = "" if value is None else str(value) - if text.startswith(" ") or text.endswith(" ") or any(ch in text for ch in ";{}"): - return "{" + text.replace("}", "}}") + "}" - return text - - -def byte_array_to_datetime(value: bytes) -> dt.datetime: - """ - Converts a DATETIMEOFFSET byte array to a timezone-aware datetime object - - Parameters - ---------- - value : buffer - A binary value conforming to SQL_SS_TIMESTAMPOFFSET_STRUCT - - Returns - ------- - out : datetime - - Source - ------ - SQL_SS_TIMESTAMPOFFSET datatype and SQL_SS_TIMESTAMPOFFSET_STRUCT layout: - https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements - """ - # unpack 20 bytes of data into a tuple of 9 values - tup = struct.unpack("<6hI2h", value) - - # construct a datetime object - return dt.datetime( - year=tup[0], - month=tup[1], - day=tup[2], - hour=tup[3], - minute=tup[4], - second=tup[5], - microsecond=tup[6] // 1000, # https://bugs.python.org/issue15443 - tzinfo=dt.timezone(dt.timedelta(hours=tup[7], minutes=tup[8])), - ) - - -def _build_server_arg(credentials: SQLServerCredentials) -> str: - host = credentials.host or "" - if "\\" in host: - # If there is a backslash \ in the host name, the host is a - # SQL Server named instance. In this case then port number has to be omitted. - return host - return f"{host},{credentials.port}" - - -def _format_connection_string_value(value: Optional[str], mssql_python_backend: bool) -> str: - if mssql_python_backend: - return _escape_connection_string_value(value) - return "{" + ("" if value is None else value) + "}" - - -def _build_common_connection_string_parts( - credentials: SQLServerCredentials, - mssql_python_backend: bool, -) -> list[str]: - con_str = [f"SERVER={_build_server_arg(credentials)}"] - con_str.append(f"Database={credentials.database}") - - authentication = credentials.authentication or "" - if mssql_python_backend: - authentication = _normalize_mssql_python_authentication(authentication) or "" - - if not authentication.strip() and not credentials.windows_login: - raise dbt_common.exceptions.DbtRuntimeError( - "The `authentication` profile field is required for SQL Server connections." - ) - - if "ActiveDirectory" in authentication and authentication != "ActiveDirectoryAccessToken": - con_str.append(f"Authentication={authentication}") - - if authentication == "ActiveDirectoryPassword": - con_str.append( - f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" - ) - con_str.append( - f"PWD={_format_connection_string_value(credentials.PWD, mssql_python_backend)}" - ) - elif authentication == "ActiveDirectoryServicePrincipal": - con_str.append( - "UID=" - + _format_connection_string_value( - credentials.client_id, - mssql_python_backend, - ) - ) - con_str.append( - "PWD=" - + _format_connection_string_value( - credentials.client_secret, - mssql_python_backend, - ) - ) - elif authentication == "ActiveDirectoryInteractive": - con_str.append( - "UID=%s" - % _format_connection_string_value( - credentials.UID, - mssql_python_backend, - ) - ) - elif authentication == "ActiveDirectoryMSI": - if credentials.PWD: - raise dbt_common.exceptions.DbtRuntimeError( - "password is not valid with ActiveDirectoryMSI for the mssql-python backend." - ) - if credentials.UID: - con_str.append( - f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" - ) - elif authentication == "ActiveDirectoryIntegrated": - if credentials.PWD: - raise dbt_common.exceptions.DbtRuntimeError( - "password is not valid with ActiveDirectoryIntegrated" - " for the mssql-python backend." - ) - - elif credentials.windows_login: - if mssql_python_backend and (credentials.UID or credentials.PWD): - raise dbt_common.exceptions.DbtRuntimeError( - "user/password are not valid with windows_login/trusted_connection " - "for the mssql-python backend." - ) - con_str.append("Trusted_Connection=yes") - elif authentication == "sql": - con_str.append( - f"UID={_format_connection_string_value(credentials.UID, mssql_python_backend)}" - ) - con_str.append( - f"PWD={_format_connection_string_value(credentials.PWD, mssql_python_backend)}" - ) - - if credentials.encrypt is None: - raise dbt_common.exceptions.DbtRuntimeError( - "The `encrypt` profile field is required for SQL Server connections." - ) - if credentials.trust_cert is None: - raise dbt_common.exceptions.DbtRuntimeError( - "The `trust_cert` profile field is required for SQL Server connections." - ) - - con_str.append(bool_to_connection_string_arg("encrypt", credentials.encrypt)) - con_str.append(bool_to_connection_string_arg("TrustServerCertificate", credentials.trust_cert)) - - if not mssql_python_backend: - # Reserved keyword 'app' is controlled by the driver and cannot be specified by the user. - application_name = f"dbt-{credentials.type}/{__version__.version}" - con_str.append(f"APP={application_name}") - - return con_str - - -def _build_pyodbc_connection_string(credentials: SQLServerCredentials) -> str: - con_str = [f"DRIVER={{{credentials.driver}}}"] - con_str.extend(_build_common_connection_string_parts(credentials, mssql_python_backend=False)) - con_str.append("Pooling=true") - - if credentials.trace_flag: - con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_ON") - else: - con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF") - - con_str.append("ConnectRetryCount=3") - con_str.append("ConnectRetryInterval=10") - - return ";".join(con_str) - - -def _build_mssql_python_connection_string(credentials: SQLServerCredentials) -> str: - con_str = _build_common_connection_string_parts(credentials, mssql_python_backend=True) - con_str.append("ConnectRetryCount=3") - con_str.append("ConnectRetryInterval=10") - return ";".join(con_str) - - -def _sanitize_connection_string_for_logging(connection_string: str) -> str: - parts = connection_string.split(";") - sanitized = [] - for part in parts: - if part.lower().startswith("pwd="): - sanitized.append("PWD=***") - else: - sanitized.append(part) - return ";".join(sanitized) - - -def _connect_mssql_python( - mssql_python: MssqlPythonModuleProtocol, - credentials: SQLServerCredentials, - connection_string: str, -) -> Any: - handle = mssql_python.connect( - connection_string, - autocommit=True, - timeout=credentials.login_timeout, - ) - try: - handle.timeout = credentials.query_timeout - except Exception: - logger.debug( - "The mssql-python connection object does not expose a mutable `timeout` " - "attribute; continuing without setting query timeout on the handle." - ) - logger.debug(f"Connected to db: {credentials.database}") - return handle - - -def _connect_pyodbc( - pyodbc: PyodbcModuleProtocol, - credentials: SQLServerCredentials, - connection_string: str, -) -> Any: - pyodbc.pooling = True - attrs_before = get_pyodbc_attrs_before_credentials(credentials) - - handle = pyodbc.connect( - connection_string, - attrs_before=attrs_before, - autocommit=True, - timeout=credentials.login_timeout, - ) - handle.timeout = credentials.query_timeout - logger.debug(f"Connected to db: {credentials.database}") - return handle - - -def _get_backend_exceptions( - credentials: SQLServerCredentials, -) -> Tuple[Type[Exception], ...]: - if _is_mssql_python_backend(credentials): - mssql_python = _get_mssql_python() - - retryable_exceptions = [ - getattr(mssql_python, "InternalError", Exception), - getattr(mssql_python, "OperationalError", Exception), - ] - - if _uses_pyodbc_token_authentication(credentials): - retryable_exceptions.append(getattr(mssql_python, "InterfaceError", Exception)) - - return tuple(retryable_exceptions) - - pyodbc = _get_pyodbc() - - retryable_exceptions = [ - pyodbc.InternalError, - pyodbc.OperationalError, - ] - - if _uses_pyodbc_token_authentication(credentials): - retryable_exceptions.append(pyodbc.InterfaceError) - - return tuple(retryable_exceptions) - - -def _is_pyodbc_handle(handle: Any) -> bool: - return hasattr(handle, "add_output_converter") - class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" @contextmanager def exception_handler(self, sql): + """Translate backend database errors and re-raise everything else. + + The backend-specific ``DatabaseError`` type is discovered lazily so the + handler can work with either optional backend. Non-database exceptions + are logged, the connection is released on a best-effort basis, and the + original exception is re-raised unchanged. + """ + try: yield except Exception as e: credentials = self.get_thread_connection().credentials + if is_mssql_python_backend(credentials.backend): + database_error = _RUNTIME_STATE.get_mssql_python_database_error() + else: + database_error = _RUNTIME_STATE.get_pyodbc_database_error() - if not _is_mssql_python_backend(credentials): - pyodbc = _PYODBC_MODULE - if pyodbc is not None and isinstance(e, getattr(pyodbc, "DatabaseError", tuple())): - logger.debug("Database error: {}".format(str(e))) - - try: - self.release() - except Exception: - logger.debug("Failed to release connection!") - - raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e - - if _is_mssql_python_backend(credentials): - mssql_python = _MSSQL_PYTHON_MODULE - if mssql_python is not None and isinstance( - e, getattr(mssql_python, "DatabaseError", tuple()) - ): - logger.debug("Database error: {}".format(str(e))) - - try: - self.release() - except Exception: - logger.debug("Failed to release connection!") - - raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e + if database_error is not None and isinstance(e, database_error): + # The backend-specific handler releases the connection and raises + # DbtDatabaseError, so this branch must not fall through into the + # generic rollback / logging path below. + handle_backend_database_error(e, database_error, self.release) + logger.debug(f"SQL execution raised {type(e).__name__}: {e}") logger.debug(f"Error running SQL: {sql}") logger.debug("Rolling back transaction.") - self.release() - if isinstance(e, dbt_common.exceptions.DbtRuntimeError): - raise - - raise dbt_common.exceptions.DbtRuntimeError(e) + try: + self.release() + except Exception: + logger.debug("Failed to release connection!") + raise @classmethod def open(cls, connection: Connection) -> Connection: @@ -876,34 +106,28 @@ def open(cls, connection: Connection) -> Connection: credentials = cls.get_credentials(connection.credentials) - _validate_connection_requirements(credentials) + validate_connection_requirements(credentials) - if _is_mssql_python_backend(credentials): + if is_mssql_python_backend(credentials.backend): mssql_python = _get_mssql_python() - _validate_mssql_python_requirements(credentials) - con_str_concat = _build_mssql_python_connection_string(credentials) + validate_mssql_python_requirements(credentials) + con_str_concat = build_mssql_python_connection_string(credentials) + retryable_exceptions = get_mssql_python_retryable_exceptions(credentials, mssql_python) def connect() -> Any: - logger.debug( - "Using connection string: %s" - % _sanitize_connection_string_for_logging(con_str_concat) - ) + log_connection_string(con_str_concat) return _connect_mssql_python(mssql_python, credentials, con_str_concat) else: pyodbc = _get_pyodbc() - _validate_pyodbc_requirements(credentials) - con_str_concat = _build_pyodbc_connection_string(credentials) + validate_pyodbc_requirements(credentials) + con_str_concat = build_pyodbc_connection_string(credentials) + retryable_exceptions = get_pyodbc_retryable_exceptions(credentials, pyodbc) def connect() -> Any: - logger.debug( - "Using connection string: %s" - % _sanitize_connection_string_for_logging(con_str_concat) - ) + log_connection_string(con_str_concat) return _connect_pyodbc(pyodbc, credentials, con_str_concat) - retryable_exceptions = _get_backend_exceptions(credentials) - conn = cls.retry_connection( connection, connect=connect, @@ -955,7 +179,7 @@ def _execute_query_with_retry( cursor.execute(sql) else: bindings = [ - binding if not isinstance(binding, dt.datetime) else binding.isoformat() + (binding.isoformat() if isinstance(binding, dt.datetime) else binding) for binding in bindings ] cursor.execute(sql, bindings) @@ -997,10 +221,7 @@ def _execute_query_with_retry( ) with self.exception_handler(sql): - if abridge_sql_log: - log_sql = "{}...".format(sql[:512]) - else: - log_sql = sql + log_sql = f"{sql[:512]}..." if abridge_sql_log else sql fire_event( SQLQuery( @@ -1020,11 +241,11 @@ def _execute_query_with_retry( sql=sql, bindings=bindings, retryable_exceptions=retryable_exceptions, - retry_limit=credentials.retries if credentials.retries > 3 else retry_limit, + retry_limit=(credentials.retries if credentials.retries > 3 else retry_limit), attempt=1, ) - if _is_pyodbc_handle(connection.handle): + if is_pyodbc_handle(connection.handle): connection.handle.add_output_converter(-155, byte_array_to_datetime) fire_event( @@ -1052,10 +273,34 @@ def get_response(cls, cursor: Any) -> AdapterResponse: @classmethod def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: - data_type = str(type_code)[ - str(type_code).index("'") + 1 : str(type_code).rindex("'") # noqa: E203 - ] - return datatypes[data_type] + if isinstance(type_code, int): + raise dbt_common.exceptions.DbtRuntimeError( + "Unsupported SQL Server type code " + f"{type_code!r}: integer type codes are not mapped" + ) + + if isinstance(type_code, str) and type_code in datatypes: + return datatypes[type_code] + + as_str = str(type_code) + if "'" in as_str: + try: + start = as_str.index("'") + 1 + end = as_str.rindex("'") + data_type = as_str[start:end] + except ValueError: + data_type = None + else: + if data_type in datatypes: + return datatypes[data_type] + + if as_str in datatypes: + return datatypes[as_str] + + raise dbt_common.exceptions.DbtRuntimeError( + "Unsupported SQL Server type code " + f"{type_code!r}: no matching entry found in datatypes mapping" + ) def execute( self, @@ -1064,14 +309,18 @@ def execute( fetch: bool = False, limit: Optional[int] = None, ) -> Tuple[AdapterResponse, agate.Table]: + # Connection lifetime policy: the *connection handle* is intentionally + # kept open here. Open / release / cleanup are managed by the parent + # SQLConnectionManager (called by dbt-core's thread-local connection + # pool). pyodbc.pooling=True additionally reuses handles across + # tasks. Only the cursor needs explicit cleanup after each query. sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) try: response = self.get_response(cursor) if fetch: - while cursor.description is None: - if not cursor.nextset(): - break + while cursor.description is None and cursor.nextset(): + pass table = self.get_result_from_cursor(cursor, limit) else: table = empty_table() diff --git a/dbt/adapters/sqlserver/sqlserver_constants.py b/dbt/adapters/sqlserver/sqlserver_constants.py new file mode 100644 index 000000000..ecae4e8c0 --- /dev/null +++ b/dbt/adapters/sqlserver/sqlserver_constants.py @@ -0,0 +1,114 @@ +"""Constants shared by the SQL Server adapter.""" + +from __future__ import annotations + +SQLSERVER_BACKEND_PYODBC = "pyodbc" +SQLSERVER_BACKEND_MSSQL_PYTHON = "mssql-python" +SUPPORTED_SQLSERVER_BACKENDS = ( + SQLSERVER_BACKEND_PYODBC, + SQLSERVER_BACKEND_MSSQL_PYTHON, +) +SUPPORTED_SQLSERVER_BACKENDS_MESSAGE = "Supported backends are 'pyodbc' and 'mssql-python'." + +MSSQL_AUTH_ACTIVE_DIRECTORY_MSI = "ActiveDirectoryMSI" +MSSQL_AUTH_ACTIVE_DIRECTORY_INTEGRATED = "ActiveDirectoryIntegrated" +MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL = "ActiveDirectoryServicePrincipal" +MSSQL_AUTH_ACTIVE_DIRECTORY_DEFAULT = "ActiveDirectoryDefault" +MSSQL_AUTH_ACTIVE_DIRECTORY_PASSWORD = "ActiveDirectoryPassword" +MSSQL_AUTH_ACTIVE_DIRECTORY_INTERACTIVE = "ActiveDirectoryInteractive" +MSSQL_AUTH_ACTIVE_DIRECTORY_DEVICE_CODE = "ActiveDirectoryDeviceCode" +MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN = "ActiveDirectoryAccessToken" +MSSQL_AUTH_CLI = "cli" +MSSQL_AUTH_ENVIRONMENT = "environment" + +# pyodbc's token-fetch path uses short, token-oriented aliases. The backend +# builders handle the longer connection-string auth names separately. +PYODBC_AUTH_ALIASES: dict[str, str] = { + "activedirectorymsi": "msi", +} + +# Connection-string auth aliases are canonicalized separately for the +# mssql-python builder. Keep this map distinct from the pyodbc token aliases +# above so the two auth flows do not drift together accidentally. +CONNECTION_AUTH_ALIASES: dict[str, str] = { + "msi": MSSQL_AUTH_ACTIVE_DIRECTORY_MSI, + "activedirectorymsi": MSSQL_AUTH_ACTIVE_DIRECTORY_MSI, + "activedirectoryintegrated": MSSQL_AUTH_ACTIVE_DIRECTORY_INTEGRATED, + "adintegrated": MSSQL_AUTH_ACTIVE_DIRECTORY_INTEGRATED, + "serviceprincipal": MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL, + "activedirectoryserviceprincipal": MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL, + "auto": MSSQL_AUTH_ACTIVE_DIRECTORY_DEFAULT, + "default": MSSQL_AUTH_ACTIVE_DIRECTORY_DEFAULT, + "activedirectorydefault": MSSQL_AUTH_ACTIVE_DIRECTORY_DEFAULT, + "activedirectorypassword": MSSQL_AUTH_ACTIVE_DIRECTORY_PASSWORD, + "activedirectoryinteractive": MSSQL_AUTH_ACTIVE_DIRECTORY_INTERACTIVE, + "activedirectorydevicecode": MSSQL_AUTH_ACTIVE_DIRECTORY_DEVICE_CODE, + "access_token": MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN, + "activedirectoryaccesstoken": MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN, + MSSQL_AUTH_CLI: MSSQL_AUTH_CLI, + MSSQL_AUTH_ENVIRONMENT: MSSQL_AUTH_ENVIRONMENT, +} + +CONNECTION_AUTH_PASSTHROUGH_KEYS: frozenset[str] = frozenset( + { + MSSQL_AUTH_CLI, + MSSQL_AUTH_ENVIRONMENT, + } +) + +# Canonical pyodbc auth labels that should trigger Azure token caching and +# retryable InterfaceErrors. `ActiveDirectoryAccessToken` is included because +# it still flows through the same token-auth retry policy even though the token +# itself is supplied directly by the caller. +AAD_TOKEN_AUTHENTICATIONS: frozenset[str] = frozenset( + { + MSSQL_AUTH_CLI, + MSSQL_AUTH_ENVIRONMENT, + "auto", + "msi", + "serviceprincipal", + "activedirectoryaccesstoken", + } +) + +MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS = { + MSSQL_AUTH_CLI, + MSSQL_AUTH_ENVIRONMENT, + MSSQL_AUTH_ACTIVE_DIRECTORY_ACCESS_TOKEN, +} + +# Keys whose values must never appear in log output. Keep this scoped to the +# exact connection-string fields that carry secrets so non-secret auth metadata +# does not get redacted. +SENSITIVE_CONNECTION_STRING_KEYS: frozenset[str] = frozenset( + { + "pwd", + "password", + "clientsecret", + "accesstoken", + "accountkey", + "sharedaccesskey", + "sharedaccesssignature", + "uid", + "userid", + "user", + "username", + "clientid", + "secret", + } +) +# https://github.com/mkleehammer/pyodbc/wiki/Data-Types +datatypes = { + "str": "varchar", + "uuid.UUID": "uniqueidentifier", + "uuid": "uniqueidentifier", + "float": "bigint", + "int": "int", + "bytes": "varbinary", + "bytearray": "varbinary", + "bool": "bit", + "datetime.date": "date", + "datetime.datetime": "datetime2(6)", + "datetime.time": "time", + "decimal.Decimal": "decimal", +} diff --git a/dbt/adapters/sqlserver/sqlserver_credentials.py b/dbt/adapters/sqlserver/sqlserver_credentials.py index 81f7f7bb9..05be0f75d 100644 --- a/dbt/adapters/sqlserver/sqlserver_credentials.py +++ b/dbt/adapters/sqlserver/sqlserver_credentials.py @@ -5,16 +5,36 @@ from dbt_common.dataclass_schema import StrEnum from dbt.adapters.contracts.connection import Credentials +from dbt.adapters.sqlserver.sqlserver_auth import normalize_authentication_key +from dbt.adapters.sqlserver.sqlserver_constants import ( + MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL, + SQLSERVER_BACKEND_MSSQL_PYTHON, + SQLSERVER_BACKEND_PYODBC, + SUPPORTED_SQLSERVER_BACKENDS_MESSAGE, +) +from dbt.adapters.sqlserver.sqlserver_helpers import normalize_query_timeout class SQLServerBackend(StrEnum): - pyodbc = "pyodbc" - mssql_python = "mssql-python" + pyodbc = SQLSERVER_BACKEND_PYODBC + mssql_python = SQLSERVER_BACKEND_MSSQL_PYTHON + + +DEFAULT_SQLSERVER_BACKEND = cast(SQLServerBackend, SQLServerBackend.pyodbc) + + +def coerce_backend(backend: Union[SQLServerBackend, str]) -> SQLServerBackend: + try: + return SQLServerBackend(backend) + except ValueError as exc: + raise dbt_common.exceptions.DbtRuntimeError( + f"Unsupported sqlserver backend: '{backend}'. {SUPPORTED_SQLSERVER_BACKENDS_MESSAGE}" + ) from exc @dataclass class SQLServerCredentials(Credentials): - backend: Union[SQLServerBackend, str] = SQLServerBackend.pyodbc + backend: SQLServerBackend = DEFAULT_SQLSERVER_BACKEND driver: Optional[str] = None host: Optional[str] = None database: Optional[str] = None @@ -53,30 +73,21 @@ class SQLServerCredentials(Credentials): } def __post_init__(self) -> None: - if isinstance(self.backend, str): - try: - self.backend = SQLServerBackend(self.backend) - except ValueError as exc: - raise dbt_common.exceptions.DbtRuntimeError( - "Unsupported sqlserver backend: '{}'. " - "Supported backends are 'pyodbc' and 'mssql-python'.".format(self.backend) - ) from exc - - self.backend = cast(SQLServerBackend, self.backend) + self.backend = coerce_backend(self.backend) + self.query_timeout = normalize_query_timeout(self.query_timeout) @property def type(self): return "sqlserver" - def _effective_backend(self) -> SQLServerBackend: - return cast(SQLServerBackend, self.backend) - def _connection_keys(self): - if self.windows_login is True: - self.authentication = "Windows Login" + """Return the credential fields that distinguish reusable connections.""" - if self.authentication.lower().strip() == "serviceprincipal": - self.authentication = "ActiveDirectoryServicePrincipal" + authentication = self.authentication + if self.windows_login is True: + authentication = "Windows Login" + elif normalize_authentication_key(authentication) == "serviceprincipal": + authentication = MSSQL_AUTH_ACTIVE_DIRECTORY_SERVICE_PRINCIPAL keys = ( "server", @@ -85,6 +96,7 @@ def _connection_keys(self): "schema", "UID", "authentication", + "windows_login", "retries", "login_timeout", "query_timeout", @@ -94,7 +106,10 @@ def _connection_keys(self): "backend", ) - if self._effective_backend() == SQLServerBackend.pyodbc: + if self.backend == SQLServerBackend.pyodbc: + # Only the pyodbc path uses an ODBC driver name. The mssql-python + # backend ignores `driver`, so excluding it keeps connection reuse + # aligned with the actual connection string that backend produces. keys = ("driver",) + keys return keys diff --git a/dbt/adapters/sqlserver/sqlserver_helpers.py b/dbt/adapters/sqlserver/sqlserver_helpers.py new file mode 100644 index 000000000..83d940a59 --- /dev/null +++ b/dbt/adapters/sqlserver/sqlserver_helpers.py @@ -0,0 +1,291 @@ +"""Shared backend and connection-string helpers for the SQL Server adapter. + +Authentication constants and normalization live in ``sqlserver_constants`` +and ``sqlserver_auth`` so this module can stay focused on connection-string +validation, formatting, and logging helpers. +""" + +from __future__ import annotations + +import datetime as dt +import numbers +import struct +from typing import TYPE_CHECKING, Any, Optional + +import dbt_common.exceptions + +from dbt.adapters.sqlserver.sqlserver_auth import ( + is_active_directory_authentication, + is_mssql_python_backend, + normalize_authentication_key, + normalize_connection_authentication, +) +from dbt.adapters.sqlserver.sqlserver_constants import ( + MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS, + SENSITIVE_CONNECTION_STRING_KEYS, +) + +if TYPE_CHECKING: + from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials + + +def validate_connection_requirements(credentials: SQLServerCredentials) -> None: + """Connection-manager preflight for shared profile fields. + + Invariants: + - `host`, `database`, and `schema` fields must not be empty or blank. + - `encrypt` and `trust_cert` fields must not be None. + - `authentication` is required unless `windows_login` is True. + - `windows_login` and ActiveDirectory-based authentication are mutually exclusive. + - `query_timeout` is normalized into a non-negative integer. + + Integration: + This preflight validator runs immediately after credentials coercion and before + backend-specific builders (`build_mssql_python_connection_string` or + `build_pyodbc_connection_string`) or backend-specific requirement checks. + It ensures consistent base states. + """ + + for name, value in ( + ("host", credentials.host), + ("database", credentials.database), + ("schema", credentials.schema), + ): + if value is None or not str(value).strip(): + raise dbt_common.exceptions.DbtRuntimeError( + f"The `{name}` profile field is required for SQL Server connections." + ) + + normalized = normalize_connection_authentication( + credentials.authentication, + is_mssql_python_backend(credentials.backend), + ) + credentials.query_timeout = normalize_query_timeout(credentials.query_timeout) + if credentials.windows_login: + if normalized and is_active_directory_authentication(normalized): + raise dbt_common.exceptions.DbtRuntimeError( + "windows_login/trusted_connection cannot be combined with ActiveDirectory " + "authentication. Remove `authentication` or disable `windows_login`." + ) + elif not normalized: + raise dbt_common.exceptions.DbtRuntimeError( + "The `authentication` profile field is required for SQL Server connections." + ) + + if credentials.encrypt is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `encrypt` profile field is required for SQL Server connections." + ) + if credentials.trust_cert is None: + raise dbt_common.exceptions.DbtRuntimeError( + "The `trust_cert` profile field is required for SQL Server connections." + ) + + +def validate_pyodbc_requirements(credentials: SQLServerCredentials) -> None: + """Backend-specific validation for the legacy pyodbc connection path.""" + + driver = credentials.driver + if driver is None or not driver.strip(): + raise dbt_common.exceptions.DbtRuntimeError( + "The pyodbc backend requires a SQL Server ODBC driver name " + "in the `driver` profile field." + ) + + +def validate_mssql_python_requirements(credentials: SQLServerCredentials) -> None: + """Backend-specific validation for the mssql-python connection path.""" + + authentication = normalize_connection_authentication(credentials.authentication, True) + + if authentication in MSSQL_PYTHON_UNSUPPORTED_AUTHENTICATIONS: + raise dbt_common.exceptions.DbtRuntimeError( + f"Authentication '{authentication}' is currently only supported by the pyodbc backend " + "in this adapter. " + "Use `backend: pyodbc` or use a connection-string-supported " + "authentication mode such as " + "`sql`, `ActiveDirectoryPassword`, `ActiveDirectoryInteractive`, " + "`ActiveDirectoryIntegrated`, `ActiveDirectoryMSI`, " + "`ActiveDirectoryDeviceCode`, or `ActiveDirectoryDefault`." + ) + + +def normalize_connection_string_key(key: str) -> str: + """Normalize a connection-string key for secret-field lookups.""" + + return normalize_authentication_key(key) + + +def split_connection_string_parts(connection_string: str) -> list[str]: + """Split a SQL Server connection string into normalized segments.""" + parts: list[str] = [] + current: list[str] = [] + in_braces = False + index = 0 + + while index < len(connection_string): + char = connection_string[index] + + if char == ";" and not in_braces: + if segment := "".join(current).strip(): + parts.append(segment) + current = [] + else: + current.append(char) + start = index + 1 + + if char == "{" and not in_braces and "}" in connection_string[start:]: + in_braces = True + elif char == "}" and in_braces: + if index + 1 < len(connection_string) and connection_string[index + 1] == "}": + current.append("}") + index += 1 + else: + in_braces = False + + index += 1 + + if segment := "".join(current).strip(): + parts.append(segment) + return parts + + +def escape_connection_string_value(value: Optional[str]) -> str: + text = "" if value is None else str(value) + if text.startswith(" ") or text.endswith(" ") or any(ch in text for ch in ";{}"): + return "{" + text.replace("}", "}}") + "}" + return text + + +def bool_to_connection_string_arg(key: str, value: Optional[bool]) -> str: + return f"{key}={'Yes' if value else 'No'}" + + +def normalize_query_timeout(query_timeout: Any) -> int: + """Normalize query timeouts and fail fast on invalid negative values. + + Accepts integers and integer-like strings so config parsing can hand this + helper raw values without leaking type quirks into the connection layer. + """ + + if query_timeout is None: + return 0 + if isinstance(query_timeout, bool): + raise dbt_common.exceptions.DbtRuntimeError( + "The `query_timeout` profile field must be a non-negative integer." + ) + + if isinstance(query_timeout, numbers.Integral): + normalized = int(query_timeout) + elif isinstance(query_timeout, str): + try: + normalized = int(query_timeout) + except ValueError as exc: + raise dbt_common.exceptions.DbtRuntimeError( + "The `query_timeout` profile field must be a non-negative integer." + ) from exc + else: + raise dbt_common.exceptions.DbtRuntimeError( + "The `query_timeout` profile field must be a non-negative integer." + ) + + if normalized < 0: + raise dbt_common.exceptions.DbtRuntimeError( + "The `query_timeout` profile field must be a non-negative integer." + ) + + return normalized + + +def build_server_arg(credentials: SQLServerCredentials) -> str: + """Build the `SERVER` token, preserving named instances without a port.""" + + host = (credentials.host or "").strip() + port = credentials.port + + if "\\" in host: + return host + + return f"{host},{port}" if port else host + + +def format_connection_string_value(value: Optional[str], mssql_python_backend: bool) -> str: + """Format a connection-string value for the requested backend.""" + + if mssql_python_backend: + return escape_connection_string_value(value) + return "{" + ("" if value is None else value) + "}" + + +def format_pyodbc_driver_value(value: Optional[str]) -> str: + """Format a pyodbc driver value without double-wrapping explicit braces.""" + + text = "" if value is None else str(value) + if len(text) >= 2 and text.startswith("{") and text.endswith("}"): + return text + return "{" + text + "}" + + +def sanitize_connection_string_for_logging(connection_string: str) -> str: + """Redact sensitive connection-string fields while preserving structure.""" + + sanitized = [] + for part in split_connection_string_parts(connection_string): + if "=" in part: + key, _value = part.split("=", 1) + normalized_key = normalize_connection_string_key(key.strip()) + if normalized_key in SENSITIVE_CONNECTION_STRING_KEYS: + sanitized.append(f"{key.strip()}=***") + continue + sanitized.append(part) + return ";".join(sanitized) + + +def byte_array_to_datetime(value: bytes) -> dt.datetime: + """ + Converts a DATETIMEOFFSET byte array to a timezone-aware datetime object + + Parameters + ---------- + value : buffer + A binary value conforming to SQL_SS_TIMESTAMPOFFSET_STRUCT + + Returns + ------- + out : datetime + + Source + ------ + SQL_SS_TIMESTAMPOFFSET datatype and SQL_SS_TIMESTAMPOFFSET_STRUCT layout: + https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date- + time/data-type-support-for-odbc-date-and-time-improvements + """ + # unpack 20 bytes of data into a tuple of 9 values + tup = struct.unpack("<6hI2h", value) + + # construct a datetime object + return dt.datetime( + year=tup[0], + month=tup[1], + day=tup[2], + hour=tup[3], + minute=tup[4], + second=tup[5], + microsecond=tup[6] // 1000, # https://bugs.python.org/issue15443 + tzinfo=dt.timezone(dt.timedelta(hours=tup[7], minutes=tup[8])), + ) + + +def _set_query_timeout_if_supported(handle: Any, query_timeout: Any) -> bool: + """Normalize and apply `query_timeout`; return False when the handle lacks support.""" + + query_timeout = normalize_query_timeout(query_timeout) + if query_timeout == 0: + return True + + try: + handle.timeout = query_timeout + except AttributeError: + return False + + return True diff --git a/dbt/adapters/sqlserver/sqlserver_runtime.py b/dbt/adapters/sqlserver/sqlserver_runtime.py new file mode 100644 index 000000000..414a97786 --- /dev/null +++ b/dbt/adapters/sqlserver/sqlserver_runtime.py @@ -0,0 +1,450 @@ +"""Internal runtime state for optional backend imports and token caches.""" + +from __future__ import annotations + +import threading +import time +from dataclasses import dataclass +from typing import Any, Callable, NamedTuple, Optional, Protocol, Type, cast + +import dbt_common.exceptions + +_UNSET = object() + +AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" + + +class AccessTokenProtocol(Protocol): + token: str + expires_on: int + + +class TokenCredentialProtocol(Protocol): + def get_token(self, *scopes: Optional[str], **kwargs: Any) -> AccessTokenProtocol: ... + + +class CredentialFactory(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> TokenCredentialProtocol: ... + + +class AzureIdentityModuleProtocol(Protocol): + AzureCliCredential: CredentialFactory + DefaultAzureCredential: CredentialFactory + EnvironmentCredential: CredentialFactory + ManagedIdentityCredential: CredentialFactory + ClientSecretCredential: CredentialFactory + + +class AzureCredentialsModuleProtocol(Protocol): + AccessToken: Type[AccessTokenProtocol] + + +class PyodbcModuleProtocol(Protocol): + InternalError: type[Exception] + OperationalError: type[Exception] + InterfaceError: type[Exception] + DatabaseError: type[Exception] + pooling: bool + + def connect(self, *args: Any, **kwargs: Any) -> Any: ... + + +class MssqlPythonModuleProtocol(Protocol): + InternalError: type[Exception] + OperationalError: type[Exception] + InterfaceError: type[Exception] + DatabaseError: type[Exception] + + def pooling( + self, + max_size: int = 100, + idle_timeout: int = 600, + enabled: bool = True, + ) -> None: ... + + def connect(self, *args: Any, **kwargs: Any) -> Any: ... + + +@dataclass +class AccessToken: + token: str + expires_on: int + + +@dataclass(frozen=True) +class SQLServerRuntimeSnapshot: + """Shallow copy of mutable runtime state used by focused tests.""" + + pyodbc_module: Any + pyodbc_import_error: Optional[ModuleNotFoundError] + mssql_python_module: Any + mssql_python_import_error: Optional[ModuleNotFoundError] + azure_credentials_module: Any + azure_credentials_import_error: Optional[ModuleNotFoundError] + azure_identity_module: Any + azure_identity_import_error: Optional[ModuleNotFoundError] + access_token_cache: dict[Any, Any] + timeout_warning_logged: bool + + +class SQLServerRuntimeState: + """Own the mutable state behind lazy imports and shared caches. + + Lifecycle and ownership: + - This singleton is the only supported home for optional backend module + imports, cached DatabaseError classes, Azure access tokens, and the + one-shot timeout warning flag. + - Public helpers in ``sqlserver_runtime.py`` are the intended access + points; callers should avoid reading or mutating the fields directly. + + Thread-safety: + - ``module_load_lock`` protects lazy imports and cached exception types. + - ``access_token_cache_lock`` protects token reads/writes. + - ``timeout_warning_lock`` ensures the warning is emitted at most once. + """ + + def __init__(self) -> None: + self.pyodbc_module: Any = None + self.pyodbc_import_error: Optional[ModuleNotFoundError] = None + self.mssql_python_module: Any = None + self.mssql_python_import_error: Optional[ModuleNotFoundError] = None + self.azure_credentials_module: Any = None + self.azure_credentials_import_error: Optional[ModuleNotFoundError] = None + self.azure_identity_module: Any = None + self.azure_identity_import_error: Optional[ModuleNotFoundError] = None + self.access_token_cache: dict[Any, Any] = {} + self.timeout_warning_logged = False + self._pyodbc_db_error: Optional[type[Exception]] = None + self._mssql_python_db_error: Optional[type[Exception]] = None + + self.module_load_lock = threading.Lock() + self.access_token_cache_lock = threading.Lock() + self.timeout_warning_lock = threading.Lock() + + def reset_modules(self) -> None: + with self.module_load_lock: + self.pyodbc_module = None + self.pyodbc_import_error = None + self.mssql_python_module = None + self.mssql_python_import_error = None + self.azure_credentials_module = None + self.azure_credentials_import_error = None + self.azure_identity_module = None + self.azure_identity_import_error = None + self._pyodbc_db_error = None + self._mssql_python_db_error = None + + def reset_access_token_cache(self) -> None: + with self.access_token_cache_lock: + self.access_token_cache.clear() + + def reset_timeout_warning(self) -> None: + with self.timeout_warning_lock: + self.timeout_warning_logged = False + + def reset(self) -> None: + self.reset_modules() + self.reset_access_token_cache() + self.reset_timeout_warning() + + def get_pyodbc_database_error(self) -> Optional[type[Exception]]: + with self.module_load_lock: + if self._pyodbc_db_error is not None: + return self._pyodbc_db_error + if self.pyodbc_module is not None: + self._pyodbc_db_error = self.pyodbc_module.DatabaseError + return self._pyodbc_db_error + return None + + def get_mssql_python_database_error(self) -> Optional[type[Exception]]: + with self.module_load_lock: + if self._mssql_python_db_error is not None: + return self._mssql_python_db_error + if self.mssql_python_module is not None: + self._mssql_python_db_error = self.mssql_python_module.DatabaseError + return self._mssql_python_db_error + return None + + def get_cached_access_token( + self, + cache_key: Any, + loader: Callable[[], Any], + *, + refresh_buffer_seconds: int = 300, + ) -> Any: + """Return a cached token without holding the lock during refresh.""" + + with self.access_token_cache_lock: + token = self.access_token_cache.get(cache_key) + if token and (token.expires_on - time.time() >= refresh_buffer_seconds): + return token + + token = loader() + + with self.access_token_cache_lock: + cached_token = self.access_token_cache.get(cache_key) + if cached_token and (cached_token.expires_on - time.time() >= refresh_buffer_seconds): + return cached_token + self.access_token_cache[cache_key] = token + return token + + def take_timeout_warning(self) -> bool: + with self.timeout_warning_lock: + if self.timeout_warning_logged: + return False + self.timeout_warning_logged = True + return True + + def snapshot(self) -> SQLServerRuntimeSnapshot: + with self.module_load_lock: + pyodbc_module = self.pyodbc_module + pyodbc_import_error = self.pyodbc_import_error + mssql_python_module = self.mssql_python_module + mssql_python_import_error = self.mssql_python_import_error + azure_credentials_module = self.azure_credentials_module + azure_credentials_import_error = self.azure_credentials_import_error + azure_identity_module = self.azure_identity_module + azure_identity_import_error = self.azure_identity_import_error + with self.access_token_cache_lock: + access_token_cache = dict(self.access_token_cache) + with self.timeout_warning_lock: + timeout_warning_logged = self.timeout_warning_logged + + return SQLServerRuntimeSnapshot( + pyodbc_module=pyodbc_module, + pyodbc_import_error=pyodbc_import_error, + mssql_python_module=mssql_python_module, + mssql_python_import_error=mssql_python_import_error, + azure_credentials_module=azure_credentials_module, + azure_credentials_import_error=azure_credentials_import_error, + azure_identity_module=azure_identity_module, + azure_identity_import_error=azure_identity_import_error, + access_token_cache=access_token_cache, + timeout_warning_logged=timeout_warning_logged, + ) + + def configure_for_test( + self, + *, + pyodbc_module: Any = _UNSET, + pyodbc_import_error: Any = _UNSET, + mssql_python_module: Any = _UNSET, + mssql_python_import_error: Any = _UNSET, + azure_credentials_module: Any = _UNSET, + azure_credentials_import_error: Any = _UNSET, + azure_identity_module: Any = _UNSET, + azure_identity_import_error: Any = _UNSET, + access_token_cache: Any = _UNSET, + timeout_warning_logged: Any = _UNSET, + ) -> None: + """Targeted mutation helper used by tests instead of poking globals.""" + + with self.module_load_lock: + if pyodbc_module is not _UNSET: + self.pyodbc_module = pyodbc_module + if pyodbc_import_error is not _UNSET: + self.pyodbc_import_error = pyodbc_import_error + if mssql_python_module is not _UNSET: + self.mssql_python_module = mssql_python_module + if mssql_python_import_error is not _UNSET: + self.mssql_python_import_error = mssql_python_import_error + if azure_credentials_module is not _UNSET: + self.azure_credentials_module = azure_credentials_module + if azure_credentials_import_error is not _UNSET: + self.azure_credentials_import_error = azure_credentials_import_error + if azure_identity_module is not _UNSET: + self.azure_identity_module = azure_identity_module + if azure_identity_import_error is not _UNSET: + self.azure_identity_import_error = azure_identity_import_error + + if access_token_cache is not _UNSET: + with self.access_token_cache_lock: + self.access_token_cache = dict(access_token_cache) + + if timeout_warning_logged is not _UNSET: + with self.timeout_warning_lock: + self.timeout_warning_logged = bool(timeout_warning_logged) + + +_RUNTIME_STATE = SQLServerRuntimeState() + + +class _AccessTokenCacheKey(NamedTuple): + """Dimensions that uniquely identify a cached Azure access token. + + Keeping these fields in one named type means future changes to caching + strategy (e.g. adding a subscription dimension) only require edits here + rather than hunting through the cache dict type hint and the builder. + """ + + authentication: str + scope: str + backend: Any + tenant_id: Optional[str] + client_id: Optional[str] + + +def _access_token_cache_key( + credentials: Any, + authentication: str, + scope: str, +) -> _AccessTokenCacheKey: + """Build the cache key used to memoize Azure access tokens.""" + + return _AccessTokenCacheKey( + authentication=authentication, + scope=scope, + backend=credentials.backend, + tenant_id=credentials.tenant_id, + client_id=credentials.client_id, + ) + + +def _get_azure_access_token_class() -> Type[Any]: + """Return the Azure ``AccessToken`` class or the local fallback.""" + + with _RUNTIME_STATE.module_load_lock: + if _RUNTIME_STATE.azure_credentials_module is not None: + return _RUNTIME_STATE.azure_credentials_module.AccessToken + + if _RUNTIME_STATE.azure_credentials_import_error is not None: + return AccessToken + + try: + # type: ignore[import] + import azure.core.credentials as azure_credentials + except ModuleNotFoundError as exc: + _RUNTIME_STATE.azure_credentials_import_error = exc + return AccessToken + + _RUNTIME_STATE.azure_credentials_module = cast( + AzureCredentialsModuleProtocol, azure_credentials + ) + return _RUNTIME_STATE.azure_credentials_module.AccessToken + + +def _missing_azure_identity_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "Azure authentication requires the optional dependency 'azure-identity'. " + "Install it with `pip install azure-identity` or use a non-Azure " + "authentication mode." + ) + + +def _get_azure_identity_module() -> AzureIdentityModuleProtocol: + """Import and cache ``azure.identity`` when Azure auth is requested.""" + + with _RUNTIME_STATE.module_load_lock: + if _RUNTIME_STATE.azure_identity_module is not None: + return _RUNTIME_STATE.azure_identity_module + + if _RUNTIME_STATE.azure_identity_import_error is not None: + raise _missing_azure_identity_error() from _RUNTIME_STATE.azure_identity_import_error + + try: + import azure.identity as azure_identity # type: ignore[import] + except ModuleNotFoundError as exc: + _RUNTIME_STATE.azure_identity_import_error = exc + raise _missing_azure_identity_error() from exc + + _RUNTIME_STATE.azure_identity_module = cast(AzureIdentityModuleProtocol, azure_identity) + return _RUNTIME_STATE.azure_identity_module + + +def reset_runtime_state_for_test() -> None: + """Clear optional-backend runtime state in focused tests.""" + + _RUNTIME_STATE.reset() + + +def get_runtime_state_for_test() -> SQLServerRuntimeSnapshot: + """Return a shallow snapshot of optional-backend runtime state for tests.""" + + return _RUNTIME_STATE.snapshot() + + +def configure_runtime_state_for_test(**kwargs: Any) -> None: + """Update selected runtime-state fields in focused tests.""" + + _RUNTIME_STATE.configure_for_test(**kwargs) + + +def _missing_pyodbc_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "The legacy `pyodbc` backend was requested, but the optional dependency " + "`pyodbc` is not installed. Install it with `pip install pyodbc` " + "or set `backend: mssql-python` in the profile." + ) + + +def _get_pyodbc() -> PyodbcModuleProtocol: + """Import and cache ``pyodbc`` on first use. + + Expected Inputs: None. + Invariants: Thread-safe lazy import protected by module_load_lock. Raises + DbtRuntimeError if pyodbc is missing. + Integration: Provides the pyodbc module to the connection manager and auth handlers. + """ + + with _RUNTIME_STATE.module_load_lock: + if _RUNTIME_STATE.pyodbc_module is not None: + return _RUNTIME_STATE.pyodbc_module + + if _RUNTIME_STATE.pyodbc_import_error is not None: + raise _missing_pyodbc_error() from _RUNTIME_STATE.pyodbc_import_error + + try: + import pyodbc as imported_pyodbc # type: ignore[import] + except ModuleNotFoundError as exc: + _RUNTIME_STATE.pyodbc_import_error = exc + raise _missing_pyodbc_error() from exc + + _RUNTIME_STATE.pyodbc_module = cast(PyodbcModuleProtocol, imported_pyodbc) + return _RUNTIME_STATE.pyodbc_module + + +def _missing_mssql_python_error() -> dbt_common.exceptions.DbtRuntimeError: + return dbt_common.exceptions.DbtRuntimeError( + "The `mssql-python` backend was requested, but the optional dependency " + "`mssql-python` is not installed. Install it with `pip install mssql-python` " + "or set `backend: pyodbc` in the profile." + ) + + +def _get_mssql_python() -> MssqlPythonModuleProtocol: + """Import and cache the optional ``mssql_python`` backend on demand. + + Expected Inputs: None. + Invariants: Thread-safe lazy import protected by module_load_lock. Raises + DbtRuntimeError if mssql_python is missing. + Integration: Provides the mssql_python module to the connection manager. + """ + + with _RUNTIME_STATE.module_load_lock: + if _RUNTIME_STATE.mssql_python_module is not None: + return _RUNTIME_STATE.mssql_python_module + + if _RUNTIME_STATE.mssql_python_import_error is not None: + raise _missing_mssql_python_error() from _RUNTIME_STATE.mssql_python_import_error + + try: + # type: ignore[import] + import mssql_python as imported_mssql_python + except ModuleNotFoundError as exc: + _RUNTIME_STATE.mssql_python_import_error = exc + raise _missing_mssql_python_error() from exc + + _RUNTIME_STATE.mssql_python_module = cast(MssqlPythonModuleProtocol, imported_mssql_python) + return _RUNTIME_STATE.mssql_python_module + + +def _get_cached_access_token( + credentials: Any, + authentication: str, + scope: str, + loader: Callable[[], Any], +) -> AccessTokenProtocol: + """Return a cached Azure token using the shared runtime state.""" + + cache_key = _access_token_cache_key(credentials, authentication, scope) + return cast(AccessTokenProtocol, _RUNTIME_STATE.get_cached_access_token(cache_key, loader)) diff --git a/tests/__init__.py b/tests/__init__.py index c6609dfc3..29d839b9c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,26 +1,23 @@ import pytest from azure.identity import AzureCliCredential -from dbt.adapters.sqlserver.sqlserver_connections import ( # byte_array_to_datetime, - bool_to_connection_string_arg, - get_pyodbc_attrs_before_credentials, -) +from dbt.adapters.sqlserver.sqlserver_auth import get_pyodbc_attrs_before_credentials from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_helpers import bool_to_connection_string_arg # See # https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.5.0/sdk/identity/azure-identity/tests/test_cli_credential.py -CHECK_OUTPUT = AzureCliCredential.__module__ + ".subprocess.check_output" +CHECK_OUTPUT = f"{AzureCliCredential.__module__}.subprocess.check_output" @pytest.fixture def credentials() -> SQLServerCredentials: - credentials = SQLServerCredentials( + return SQLServerCredentials( driver="ODBC Driver 18 for SQL Server", host="fake.sql.sqlserver.net", database="dbt", schema="sqlserver", ) - return credentials def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( diff --git a/tests/unit/adapters/mssql/test_connection_logic.py b/tests/unit/adapters/mssql/test_connection_logic.py index 0ee823690..6bef0e3ff 100644 --- a/tests/unit/adapters/mssql/test_connection_logic.py +++ b/tests/unit/adapters/mssql/test_connection_logic.py @@ -1,11 +1,27 @@ from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from dbt.adapters.sqlserver import sqlserver_connections -from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager +from dbt.adapters.sqlserver import sqlserver_auth +from dbt.adapters.sqlserver.sqlserver_connections import ( + SQLServerConnectionManager, +) from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +from dbt.adapters.sqlserver.sqlserver_runtime import ( + configure_runtime_state_for_test, + reset_runtime_state_for_test, +) + + +def _fake_pyodbc_module(connect): + return SimpleNamespace( + connect=connect, + pooling=False, + InternalError=type("InternalError", (Exception,), {}), + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + ) @pytest.fixture @@ -22,40 +38,29 @@ def base_credentials(): def test_connection_string_windows_login_with_port(base_credentials): - """Port is included in the SERVER token when windows_login is True.""" base_credentials.windows_login = True connection = MagicMock() connection.state = "closed" connection.credentials = base_credentials - fake_pyodbc = SimpleNamespace( - connect=MagicMock(return_value=MagicMock()), - pooling=False, - InternalError=type("InternalError", (Exception,), {}), - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - ) - - with ( - patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), - patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), - ): + reset_runtime_state_for_test() + fake_pyodbc = _fake_pyodbc_module(MagicMock(return_value=MagicMock())) + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) - SQLServerConnectionManager.open(connection) + SQLServerConnectionManager.open(connection) - args, _kwargs = fake_pyodbc.connect.call_args - connection_string = args[0] + args, _kwargs = fake_pyodbc.connect.call_args + connection_string = args[0] - assert "SERVER=servers.database.windows.net,1444" in connection_string - assert "Trusted_Connection=yes" in connection_string - assert "UID=" not in connection_string - assert "PWD=" not in connection_string - assert "APP=dbt-sqlserver/" in connection_string + assert "SERVER=servers.database.windows.net,1444" in connection_string + assert "Trusted_Connection=yes" in connection_string + assert "UID=" not in connection_string + assert "PWD=" not in connection_string + assert "APP=dbt-sqlserver/" in connection_string def test_connection_string_standard_login_with_port(base_credentials): - """Port is included in the SERVER token for sql authentication.""" base_credentials.windows_login = False base_credentials.authentication = "sql" base_credentials.UID = "user" @@ -66,35 +71,26 @@ def test_connection_string_standard_login_with_port(base_credentials): connection.state = "closed" connection.credentials = base_credentials - fake_pyodbc = SimpleNamespace( - connect=MagicMock(return_value=MagicMock()), - pooling=False, - InternalError=type("InternalError", (Exception,), {}), - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - ) - - with ( - patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), - patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), - ): + reset_runtime_state_for_test() + fake_pyodbc = _fake_pyodbc_module(MagicMock(return_value=MagicMock())) + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) - SQLServerConnectionManager.open(connection) + SQLServerConnectionManager.open(connection) - args, _kwargs = fake_pyodbc.connect.call_args - connection_string = args[0] + args, _kwargs = fake_pyodbc.connect.call_args + connection_string = args[0] - assert "SERVER=servers.database.windows.net,1444" in connection_string - assert "UID={user}" in connection_string - assert "PWD={password}" in connection_string - assert "Pooling=true" in connection_string - assert "SQL_ATTR_TRACE=SQL_OPT_TRACE_ON" in connection_string - assert "APP=dbt-sqlserver/" in connection_string - assert "ConnectRetryCount=3" in connection_string - assert "ConnectRetryInterval=10" in connection_string + assert "SERVER=servers.database.windows.net,1444" in connection_string + assert "UID={user}" in connection_string + assert "PWD={password}" in connection_string + assert "Pooling=true" in connection_string + assert "SQL_ATTR_TRACE=SQL_OPT_TRACE_ON" in connection_string + assert "APP=dbt-sqlserver/" in connection_string + assert "ConnectRetryCount=3" in connection_string + assert "ConnectRetryInterval=10" in connection_string -def test_pyodbc_token_authentication_passes_attrs_before(base_credentials, monkeypatch): +def test_pyodbc_token_authentication_passes_attrs_before(base_credentials): base_credentials.authentication = "cli" base_credentials.windows_login = False @@ -102,19 +98,18 @@ def test_pyodbc_token_authentication_passes_attrs_before(base_credentials, monke fake_credential = SimpleNamespace(get_token=lambda *args, **kwargs: fake_token) fake_identity = SimpleNamespace(AzureCliCredential=lambda *args, **kwargs: fake_credential) - monkeypatch.setattr( - sqlserver_connections, "_AZURE_IDENTITY_MODULE", fake_identity, raising=False + reset_runtime_state_for_test() + configure_runtime_state_for_test( + azure_identity_module=fake_identity, azure_identity_import_error=None ) - monkeypatch.setattr(sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", None, raising=False) - attrs_before = sqlserver_connections.get_pyodbc_attrs_before_credentials(base_credentials) + attrs_before = sqlserver_auth.get_pyodbc_attrs_before_credentials(base_credentials) assert 1256 in attrs_before assert isinstance(attrs_before[1256], bytes) def test_connection_string_named_instance_no_port(base_credentials): - """A named-instance host (containing `\\`) must not append a port to SERVER.""" base_credentials.host = "myhost\\instance" base_credentials.windows_login = True @@ -122,23 +117,14 @@ def test_connection_string_named_instance_no_port(base_credentials): connection.state = "closed" connection.credentials = base_credentials - fake_pyodbc = SimpleNamespace( - connect=MagicMock(return_value=MagicMock()), - pooling=False, - InternalError=type("InternalError", (Exception,), {}), - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - ) - - with ( - patch.object(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc), - patch.object(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None), - ): + reset_runtime_state_for_test() + fake_pyodbc = _fake_pyodbc_module(MagicMock(return_value=MagicMock())) + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) - SQLServerConnectionManager.open(connection) + SQLServerConnectionManager.open(connection) - args, _kwargs = fake_pyodbc.connect.call_args - connection_string = args[0] + args, _kwargs = fake_pyodbc.connect.call_args + connection_string = args[0] - assert "SERVER=myhost\\instance" in connection_string - assert ",1444" not in connection_string + assert "SERVER=myhost\\instance" in connection_string + assert ",1444" not in connection_string diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index be0c4163a..e1e315dd4 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -2,27 +2,55 @@ import importlib from types import SimpleNamespace from typing import Any, Dict, List +from unittest.mock import MagicMock import pytest from azure.identity import AzureCliCredential -from dbt_common.exceptions import DbtRuntimeError +from dbt_common.exceptions import DbtDatabaseError, DbtRuntimeError +import dbt.adapters.sqlserver.sqlserver_auth +import dbt.adapters.sqlserver.sqlserver_backend as sqlserver_backend from dbt.adapters.contracts.connection import Connection, ConnectionState from dbt.adapters.sqlserver import sqlserver_connections +from dbt.adapters.sqlserver.sqlserver_auth import ( + get_pyodbc_attrs_before_credentials, + normalize_mssql_python_authentication, + uses_aad_token_authentication, +) +from dbt.adapters.sqlserver.sqlserver_backend import ( + _finalize_connection_handle, + _finalize_mssql_python_handle, +) +from dbt.adapters.sqlserver.sqlserver_backend import ( + build_mssql_python_connection_string as _build_mssql_python_connection_string, +) +from dbt.adapters.sqlserver.sqlserver_backend import ( + build_pyodbc_connection_string as _build_pyodbc_connection_string, +) +from dbt.adapters.sqlserver.sqlserver_backend import is_pyodbc_handle as _is_pyodbc_handle from dbt.adapters.sqlserver.sqlserver_connections import ( SQLServerConnectionManager, - _build_mssql_python_connection_string, - _normalize_mssql_python_authentication, - _validate_mssql_python_requirements, - _validate_pyodbc_requirements, +) +from dbt.adapters.sqlserver.sqlserver_credentials import ( + SQLServerBackend, + SQLServerCredentials, +) +from dbt.adapters.sqlserver.sqlserver_helpers import ( bool_to_connection_string_arg, - get_pyodbc_attrs_before_credentials, + escape_connection_string_value, + is_mssql_python_backend, + sanitize_connection_string_for_logging, + validate_connection_requirements, + validate_mssql_python_requirements, + validate_pyodbc_requirements, +) +from dbt.adapters.sqlserver.sqlserver_runtime import ( + configure_runtime_state_for_test, + get_runtime_state_for_test, + reset_runtime_state_for_test, ) -from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerBackend, SQLServerCredentials -# See -# https://github.com/Azure/azure-sdk-for-python/blob/azure-identity_1.5.0/sdk/identity/azure-identity/tests/test_cli_credential.py -CHECK_OUTPUT = AzureCliCredential.__module__ + ".subprocess.check_output" +CHECK_OUTPUT = f"{AzureCliCredential.__module__}.subprocess.check_output" @pytest.fixture @@ -35,22 +63,18 @@ def credentials() -> SQLServerCredentials: ) -def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( +def test_get_pyodbc_attrs_before_sql_auth_returns_empty_dict( credentials: SQLServerCredentials, ) -> None: - """ - When the authentication is set to sql we expect an empty attrs before. - """ attrs_before = get_pyodbc_attrs_before_credentials(credentials) assert attrs_before == {} def test_get_pyodbc_attrs_before_sql_auth_without_azure_identity( - credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch + credentials: SQLServerCredentials, ) -> None: - monkeypatch.setattr( - sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError() - ) + reset_runtime_state_for_test() + configure_runtime_state_for_test(azure_identity_import_error=ModuleNotFoundError()) attrs_before = get_pyodbc_attrs_before_credentials(credentials) @@ -58,17 +82,57 @@ def test_get_pyodbc_attrs_before_sql_auth_without_azure_identity( def test_get_pyodbc_attrs_before_cli_auth_requires_azure_identity( - credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch + credentials: SQLServerCredentials, ) -> None: credentials.authentication = "cli" + reset_runtime_state_for_test() + configure_runtime_state_for_test(azure_identity_import_error=ModuleNotFoundError()) + + with pytest.raises(DbtRuntimeError, match="requires the optional dependency 'azure-identity'"): + get_pyodbc_attrs_before_credentials(credentials) + + +def test_get_pyodbc_attrs_before_active_directory_access_token_defaults_zero_expiry( + credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch +) -> None: + credentials.authentication = "ActiveDirectoryAccessToken" + credentials.access_token = "some-token" + + warnings: list[str] = [] monkeypatch.setattr( - sqlserver_connections, "_AZURE_IDENTITY_IMPORT_ERROR", ModuleNotFoundError() + dbt.adapters.sqlserver.sqlserver_auth.logger, + "warning", + lambda message, *args: warnings.append(message % args if args else message), ) - with pytest.raises(DbtRuntimeError, match="requires the optional dependency 'azure-identity'"): + credentials.access_token_expires_on = 0 + attrs = get_pyodbc_attrs_before_credentials(credentials) + assert 1256 in attrs + assert any("defaulting expiry" in message for message in warnings) + + +def test_get_pyodbc_attrs_before_active_directory_access_token_requires_expiry( + credentials: SQLServerCredentials, +) -> None: + credentials.authentication = "ActiveDirectoryAccessToken" + credentials.access_token = "some-token" + + credentials.access_token_expires_on = None + with pytest.raises(ValueError, match="access token expiry"): get_pyodbc_attrs_before_credentials(credentials) +def test_get_pyodbc_attrs_before_active_directory_access_token_honors_explicit_expiry( + credentials: SQLServerCredentials, +) -> None: + credentials.authentication = "ActiveDirectoryAccessToken" + credentials.access_token = "some-token" + credentials.access_token_expires_on = 123456789 + + attrs = get_pyodbc_attrs_before_credentials(credentials) + assert 1256 in attrs + + @pytest.mark.parametrize( "driver", [None, "", " "], @@ -84,9 +148,87 @@ def test_validate_pyodbc_requirements_rejects_blank_driver( ) with pytest.raises( - DbtRuntimeError, match="The pyodbc backend requires a SQL Server ODBC driver name" + DbtRuntimeError, + match="The pyodbc backend requires a SQL Server ODBC driver name", ): - _validate_pyodbc_requirements(credentials) + validate_pyodbc_requirements(credentials) + + +def test_validate_pyodbc_requirements_accepts_valid_driver() -> None: + credentials = SQLServerCredentials( + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + ) + + validate_pyodbc_requirements(credentials) + + +def test_validate_connection_requirements_allows_windows_login_without_auth() -> None: + credentials = SQLServerCredentials( + backend=SQLServerBackend.mssql_python, + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + windows_login=True, + authentication="", + encrypt=True, + trust_cert=True, + ) + + validate_connection_requirements(credentials) + + +def test_sqlserver_credentials_reject_negative_query_timeout() -> None: + with pytest.raises(DbtRuntimeError, match="query_timeout"): + SQLServerCredentials( + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + query_timeout=-1, + ) + + +def test_build_pyodbc_connection_string_formats_driver_name() -> None: + credentials = SQLServerCredentials( + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + authentication="sql", + UID="user", + PWD="password", + ) + + conn_str = _build_pyodbc_connection_string(credentials) + + assert conn_str.startswith("DRIVER={ODBC Driver 18 for SQL Server};") + assert "encrypt=Yes" in conn_str + assert "TrustServerCertificate=Yes" in conn_str + + +def test_build_pyodbc_connection_string_preserves_prebraced_driver_name() -> None: + credentials = SQLServerCredentials( + driver="{ODBC Driver 18 for SQL Server}", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + authentication="sql", + UID="user", + PWD="password", + ) + + conn_str = _build_pyodbc_connection_string(credentials) + + assert conn_str.startswith("DRIVER={ODBC Driver 18 for SQL Server};") + assert "DRIVER={{ODBC Driver 18 for SQL Server}}" not in conn_str @pytest.mark.parametrize( @@ -97,6 +239,84 @@ def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> assert bool_to_connection_string_arg(key, value) == expected +def test_is_mssql_python_backend() -> None: + assert is_mssql_python_backend(SQLServerBackend.mssql_python) is True + assert is_mssql_python_backend(SQLServerBackend.pyodbc) is False + + +def test_connection_keys_do_not_mutate_authentication() -> None: + credentials = SQLServerCredentials( + backend=SQLServerBackend.pyodbc, + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + authentication="serviceprincipal", + ) + + original_authentication = credentials.authentication + + credentials._connection_keys() + + assert credentials.authentication == original_authentication + + +def test_connection_keys_include_driver_only_for_pyodbc() -> None: + pyodbc_credentials = SQLServerCredentials( + backend=SQLServerBackend.pyodbc, + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + ) + mssql_python_credentials = SQLServerCredentials( + backend=SQLServerBackend.mssql_python, + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + ) + + assert "driver" in pyodbc_credentials._connection_keys() + assert "driver" not in mssql_python_credentials._connection_keys() + assert "windows_login" in pyodbc_credentials._connection_keys() + assert "windows_login" in mssql_python_credentials._connection_keys() + + +def test_is_pyodbc_handle_false_for_mssql_python_handle() -> None: + handle = type("Handle", (), {"driver_type": "mssql-python"})() + assert _is_pyodbc_handle(handle) is False + + +@pytest.mark.parametrize( + "authentication, expected", + [ + ("cli", True), + ("environment", True), + ("auto", True), + ("serviceprincipal", True), + ("msi", True), + ("ActiveDirectoryAccessToken", True), + ("ActiveDirectoryServicePrincipal", False), + ("ActiveDirectoryDefault", False), + ], +) +def test_uses_aad_token_authentication_matches_pyodbc_token_aliases( + authentication: str, expected: bool +) -> None: + credentials = SQLServerCredentials( + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + authentication=authentication, + ) + + assert uses_aad_token_authentication(credentials) is expected + + @pytest.mark.parametrize( "input_auth, expected", [ @@ -118,18 +338,281 @@ def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> ], ) def test_normalize_mssql_python_authentication(input_auth: str, expected: str) -> None: - assert _normalize_mssql_python_authentication(input_auth) == expected + assert normalize_mssql_python_authentication(input_auth) == expected def test_escape_connection_string_value_quotes_only_when_needed() -> None: - assert sqlserver_connections._escape_connection_string_value("plain") == "plain" - assert ( - sqlserver_connections._escape_connection_string_value("contains;semicolon") - == "{contains;semicolon}" + assert escape_connection_string_value("plain") == "plain" + assert escape_connection_string_value("contains;semicolon") == "{contains;semicolon}" + assert escape_connection_string_value("brace}") == "{brace}}}" + assert escape_connection_string_value(" leading") == "{ leading}" + assert escape_connection_string_value("trailing ") == "{trailing }" + + +def test_sanitize_connection_string_for_logging_redacts_common_secret_fields() -> None: + sanitized = sanitize_connection_string_for_logging( + "SERVER=fake;UID=user@example.com;User Id=another@example.com;" + "PWD=password;Password=hello;ClientSecret=mysecret;ACCESS_TOKEN=token123" + ) + + assert "PWD=***" in sanitized + assert "Password=***" in sanitized + assert "ClientSecret=***" in sanitized + assert "ACCESS_TOKEN=***" in sanitized + assert "UID=***" in sanitized + assert "User Id=***" in sanitized + + +def test_sanitize_connection_string_for_logging_handles_braced_values() -> None: + sanitized = sanitize_connection_string_for_logging( + "SERVER=fake;PWD={token;with=separators};ClientSecret={secret;value};UID=user" + ) + + assert "PWD=***" in sanitized + assert "ClientSecret=***" in sanitized + assert "UID=***" in sanitized + + +def test_sanitize_connection_string_for_logging_trims_whitespace_around_segments() -> None: + sanitized = sanitize_connection_string_for_logging( + " SERVER=fake.sql.sqlserver.net ; UID = user@example.com ; PWD = password ; " + ) + + assert sanitized == "SERVER=fake.sql.sqlserver.net;UID=***;PWD=***" + + +def test_sanitize_connection_string_for_logging_treats_unterminated_brace_as_literal() -> None: + sanitized = sanitize_connection_string_for_logging("SERVER=fake;PWD={token;APP=foo") + + assert sanitized == "SERVER=fake;PWD=***;APP=foo" + + +def test_sanitize_connection_string_for_logging_preserves_non_secret_auth_metadata() -> None: + sanitized = sanitize_connection_string_for_logging( + "SERVER=fake;Authentication=sql;Auth=sql;NonToken=literal;PWD=password" + ) + + assert "Authentication=sql" in sanitized + assert "Auth=sql" in sanitized + assert "NonToken=literal" in sanitized + assert "PWD=***" in sanitized + + +def test_finalize_connection_handle_warns_when_timeout_is_unsupported( + monkeypatch: pytest.MonkeyPatch, +) -> None: + handle = object() + reset_runtime_state_for_test() + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + query_timeout=30, + ) + warnings: list[str] = [] + + monkeypatch.setattr( + sqlserver_backend.logger, + "warning", + lambda message, *args: warnings.append(message % args if args else message), + ) + + result = _finalize_mssql_python_handle(handle, credentials) + second_result = _finalize_mssql_python_handle(handle, credentials) + + assert result is handle + assert second_result is handle + assert len(warnings) == 1 + assert any("query_timeout=30" in message for message in warnings) + + +def test_finalize_connection_handle_ignores_missing_timeout_attribute() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + query_timeout=30, + ) + + handle = object() + + assert _finalize_connection_handle(handle, credentials) is handle + + +def test_finalize_connection_handle_coerces_string_query_timeout() -> None: + class Handle: + def __init__(self) -> None: + self.timeout = None + + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + query_timeout=0, + ) + credentials.query_timeout = "23" + + handle = Handle() + + assert _finalize_connection_handle(handle, credentials) is handle + assert handle.timeout == 23 + + +def test_validate_connection_requirements_rejects_negative_query_timeout() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + ) + credentials.query_timeout = -1 + + with pytest.raises(DbtRuntimeError, match="query_timeout"): + validate_connection_requirements(credentials) + + +def test_finalize_connection_handle_propagates_non_attribute_errors() -> None: + class BrokenHandle: + @property + def timeout(self) -> None: + return None + + @timeout.setter + def timeout(self, value: object) -> None: + raise TypeError("boom") + + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + query_timeout=30, + ) + + with pytest.raises(TypeError, match="boom"): + _finalize_connection_handle(BrokenHandle(), credentials) + + +def test_exception_handler_preserves_unknown_exceptions( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reset_runtime_state_for_test() + + manager = object.__new__(SQLServerConnectionManager) + credentials = SQLServerCredentials( + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + ) + release_calls: list[int] = [] + + monkeypatch.setattr( + manager, + "get_thread_connection", + lambda: SimpleNamespace(credentials=credentials), ) - assert sqlserver_connections._escape_connection_string_value("brace}") == "{brace}}}" - assert sqlserver_connections._escape_connection_string_value(" leading") == "{ leading}" - assert sqlserver_connections._escape_connection_string_value("trailing ") == "{trailing }" + monkeypatch.setattr(manager, "release", lambda: release_calls.append(1)) + debug_messages: list[str] = [] + monkeypatch.setattr( + sqlserver_connections.logger, + "debug", + lambda message, *args: debug_messages.append(message % args if args else message), + ) + + with pytest.raises(TypeError, match="boom"): + with manager.exception_handler("select 1"): + raise TypeError("boom") + + assert release_calls == [1] + assert any("TypeError" in message for message in debug_messages) + + +def test_exception_handler_routes_backend_database_errors_without_falling_through( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class BackendDatabaseError(Exception): + pass + + reset_runtime_state_for_test() + + manager = object.__new__(SQLServerConnectionManager) + credentials = SQLServerCredentials( + backend=SQLServerBackend.pyodbc, + driver="ODBC Driver 18 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + ) + release_calls: list[int] = [] + handler_calls: list[tuple[str, str]] = [] + debug_messages: list[str] = [] + + reset_runtime_state_for_test() + configure_runtime_state_for_test( + pyodbc_module=SimpleNamespace(DatabaseError=BackendDatabaseError) + ) + + try: + monkeypatch.setattr( + manager, + "get_thread_connection", + lambda: SimpleNamespace(credentials=credentials), + ) + monkeypatch.setattr(manager, "release", lambda: release_calls.append(1)) + + def fake_handle_backend_database_error( + error: Exception, + database_error: type[Exception] | None, + release_connection: Any, + ) -> None: + handler_calls.append( + ( + type(error).__name__, + database_error.__name__ if database_error else "", + ) + ) + release_connection() + raise DbtDatabaseError(str(error).strip()) from error + + monkeypatch.setattr( + sqlserver_connections, + "handle_backend_database_error", + fake_handle_backend_database_error, + ) + monkeypatch.setattr( + sqlserver_connections.logger, + "debug", + lambda message, *args: debug_messages.append(message % args if args else message), + ) + + with pytest.raises(DbtDatabaseError, match="boom"): + with manager.exception_handler("select 1"): + raise BackendDatabaseError("boom") + + assert handler_calls == [("BackendDatabaseError", "BackendDatabaseError")] + assert release_calls == [1] + assert all("Rolling back transaction." not in message for message in debug_messages) + assert all("Error running SQL:" not in message for message in debug_messages) + finally: + reset_runtime_state_for_test() + + +def test_data_type_code_to_name_handles_repr_and_rejects_integer_codes() -> None: + assert SQLServerConnectionManager.data_type_code_to_name("") == "varchar" + assert SQLServerConnectionManager.data_type_code_to_name("int") == "int" + + with pytest.raises(DbtRuntimeError, match="integer type codes are not mapped"): + SQLServerConnectionManager.data_type_code_to_name(7) def test_mssql_python_active_directory_default_passes() -> None: @@ -149,6 +632,26 @@ def test_mssql_python_active_directory_default_passes() -> None: assert "Authentication=ActiveDirectoryDefault" in conn_str +def test_mssql_python_connection_string_does_not_append_pyodbc_retry_hints() -> None: + credentials = SQLServerCredentials( + driver=None, + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + encrypt=True, + trust_cert=True, + backend=SQLServerBackend.mssql_python, + authentication="sql", + UID="user", + PWD="password", + ) + + conn_str = _build_mssql_python_connection_string(credentials) + + assert "ConnectRetryCount=3" not in conn_str + assert "ConnectRetryInterval=10" not in conn_str + + def test_mssql_python_device_code_authentication() -> None: credentials = SQLServerCredentials( driver=None, @@ -324,10 +827,10 @@ def test_mssql_python_supported_authentication_modes() -> None: authentication=authentication, ) - _validate_mssql_python_requirements(credentials) + validate_mssql_python_requirements(credentials) -def test_open_with_mssql_python_system_assigned_msi_passes_connection_string( +def test_open_with_mssql_python_backend_system_assigned_msi_passes_connection_string( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -349,32 +852,16 @@ def fake_connect(connection_string, autocommit, timeout): captured["timeout"] = timeout return FakeHandle() - fake_module = SimpleNamespace( - connect=fake_connect, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) - - def fake_retry_connection( - cls, - connection, - connect, - logger, - retry_limit, - retryable_exceptions, - ): - handle = connect() - connection.handle = handle - connection.state = ConnectionState.OPEN - return connection + fake_module = _fake_mssql_python_module(fake_connect) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_module, mssql_python_import_error=None + ) monkeypatch.setattr( SQLServerConnectionManager, "retry_connection", - classmethod(fake_retry_connection), + classmethod(_fake_retry_connection_stub()), ) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -393,36 +880,45 @@ def test_adapter_module_import_does_not_import_optional_backends( original_import = builtins.__import__ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): - if name in {"pyodbc", "mssql_python", "azure.identity", "azure.core.credentials"}: + if name in { + "pyodbc", + "mssql_python", + "azure.identity", + "azure.core.credentials", + }: raise AssertionError(f"unexpected import: {name}") return original_import(name, globals, locals, fromlist, level) + reset_runtime_state_for_test() monkeypatch.setattr(builtins, "__import__", guarded_import) importlib.reload(sqlserver_connections) - assert sqlserver_connections._PYODBC_MODULE is None - assert sqlserver_connections._MSSQL_PYTHON_MODULE is None + runtime_state = get_runtime_state_for_test() + assert runtime_state.pyodbc_module is None + assert runtime_state.mssql_python_module is None def test_get_pyodbc_imports_only_pyodbc(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", None, raising=False) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() original_import = builtins.__import__ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): if name in {"mssql_python", "azure.identity", "azure.core.credentials"}: raise AssertionError(f"unexpected import: {name}") + if name == "pyodbc": + return MagicMock() return original_import(name, globals, locals, fromlist, level) monkeypatch.setattr(builtins, "__import__", guarded_import) - module = sqlserver_connections._get_pyodbc() + module = dbt.adapters.sqlserver.sqlserver_runtime._get_pyodbc() assert module is not None -def test_get_mssql_python_imports_only_mssql_python(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) +def test_get_mssql_python_imports_only_mssql_python( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reset_runtime_state_for_test() original_import = builtins.__import__ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): @@ -432,43 +928,160 @@ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): monkeypatch.setattr(builtins, "__import__", guarded_import) - module = sqlserver_connections._get_mssql_python() + module = dbt.adapters.sqlserver.sqlserver_runtime._get_mssql_python() assert module is not None def test_get_pyodbc_returns_cached_module(monkeypatch: pytest.MonkeyPatch) -> None: fake_pyodbc = SimpleNamespace(name="cached-pyodbc") - monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) def fail_import(*args, **kwargs): raise AssertionError("pyodbc import should not run when cached") monkeypatch.setattr(builtins, "__import__", fail_import) - assert sqlserver_connections._get_pyodbc() is fake_pyodbc - assert sqlserver_connections._get_pyodbc() is fake_pyodbc + assert dbt.adapters.sqlserver.sqlserver_runtime._get_pyodbc() is fake_pyodbc + assert dbt.adapters.sqlserver.sqlserver_runtime._get_pyodbc() is fake_pyodbc + + +def test_reset_runtime_state_for_test_clears_cached_modules() -> None: + configure_runtime_state_for_test( + pyodbc_module=SimpleNamespace(name="cached-pyodbc"), + pyodbc_import_error=ModuleNotFoundError("No module named 'pyodbc'"), + mssql_python_module=SimpleNamespace(name="cached-mssql-python"), + mssql_python_import_error=ModuleNotFoundError("No module named 'mssql_python'"), + azure_identity_module=SimpleNamespace(name="cached-azure-identity"), + azure_identity_import_error=ModuleNotFoundError("No module named 'azure.identity'"), + azure_credentials_module=SimpleNamespace(name="cached-azure-creds"), + azure_credentials_import_error=ModuleNotFoundError( + "No module named 'azure.core.credentials'" + ), + access_token_cache={ + ("cli", "scope", "profile"): SimpleNamespace(token="token", expires_on=0) + }, + timeout_warning_logged=True, + ) + + reset_runtime_state_for_test() + + runtime_state = get_runtime_state_for_test() + assert runtime_state.pyodbc_module is None + assert runtime_state.pyodbc_import_error is None + assert runtime_state.mssql_python_module is None + assert runtime_state.mssql_python_import_error is None + assert runtime_state.azure_identity_module is None + assert runtime_state.azure_identity_import_error is None + assert runtime_state.azure_credentials_module is None + assert runtime_state.azure_credentials_import_error is None + assert runtime_state.access_token_cache == {} + assert runtime_state.timeout_warning_logged is False + + +def test_get_pyodbc_attrs_before_credentials_caches_tokens_per_profile( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reset_runtime_state_for_test() + calls: list[str | None] = [] + + def fake_access_token(credentials: SQLServerCredentials, scope: str) -> SimpleNamespace: + calls.append(credentials.client_id) + return SimpleNamespace(token=f"token-{credentials.client_id}", expires_on=9999999999) + + monkeypatch.setitem( + dbt.adapters.sqlserver.sqlserver_auth.AZURE_AUTH_FUNCTIONS, + "cli", + fake_access_token, + ) + + first = SQLServerCredentials( + driver="ODBC Driver 17 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + authentication="cli", + client_id="one", + ) + second = SQLServerCredentials( + driver="ODBC Driver 17 for SQL Server", + host="fake.sql.sqlserver.net", + database="dbt", + schema="sqlserver", + authentication="cli", + client_id="two", + ) + + first_attrs = get_pyodbc_attrs_before_credentials(first) + second_attrs = get_pyodbc_attrs_before_credentials(second) + + assert len(calls) == 2 + assert first_attrs != second_attrs + + +def test_get_pyodbc_attrs_before_credentials_ignores_high_cardinality_fields( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reset_runtime_state_for_test() + calls: list[str | None] = [] + + def fake_access_token(credentials: SQLServerCredentials, scope: str) -> SimpleNamespace: + calls.append(credentials.client_id) + return SimpleNamespace(token=f"token-{credentials.client_id}", expires_on=9999999999) + + monkeypatch.setitem( + dbt.adapters.sqlserver.sqlserver_auth.AZURE_AUTH_FUNCTIONS, + "cli", + fake_access_token, + ) + + first = SQLServerCredentials( + driver="ODBC Driver 17 for SQL Server", + host="first.sql.sqlserver.net", + database="dbt_a", + schema="schema_a", + authentication="cli", + client_id="shared-client-id", + client_secret="secret-one", + ) + second = SQLServerCredentials( + driver="ODBC Driver 17 for SQL Server", + host="second.sql.sqlserver.net", + database="dbt_b", + schema="schema_b", + authentication="cli", + client_id="shared-client-id", + client_secret="secret-two", + ) + + first_attrs = get_pyodbc_attrs_before_credentials(first) + second_attrs = get_pyodbc_attrs_before_credentials(second) + + assert len(calls) == 1 + assert first_attrs == second_attrs -def test_get_mssql_python_returns_cached_module(monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_mssql_python_returns_cached_module( + monkeypatch: pytest.MonkeyPatch, +) -> None: fake_mssql_python = SimpleNamespace(name="cached-mssql-python") - monkeypatch.setattr( - sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_mssql_python, raising=False + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_mssql_python, + mssql_python_import_error=None, ) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) def fail_import(*args, **kwargs): raise AssertionError("mssql_python import should not run when cached") monkeypatch.setattr(builtins, "__import__", fail_import) - assert sqlserver_connections._get_mssql_python() is fake_mssql_python - assert sqlserver_connections._get_mssql_python() is fake_mssql_python + assert dbt.adapters.sqlserver.sqlserver_runtime._get_mssql_python() is fake_mssql_python + assert dbt.adapters.sqlserver.sqlserver_runtime._get_mssql_python() is fake_mssql_python def test_get_pyodbc_raises_only_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", None, raising=False) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() original_import = builtins.__import__ def missing_pyodbc(name, globals=None, locals=None, fromlist=(), level=0): @@ -479,12 +1092,13 @@ def missing_pyodbc(name, globals=None, locals=None, fromlist=(), level=0): monkeypatch.setattr(builtins, "__import__", missing_pyodbc) with pytest.raises(DbtRuntimeError, match="pyodbc"): - sqlserver_connections._get_pyodbc() + dbt.adapters.sqlserver.sqlserver_runtime._get_pyodbc() -def test_get_mssql_python_raises_only_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) +def test_get_mssql_python_raises_only_when_requested( + monkeypatch: pytest.MonkeyPatch, +) -> None: + reset_runtime_state_for_test() original_import = builtins.__import__ def missing_mssql_python(name, globals=None, locals=None, fromlist=(), level=0): @@ -495,10 +1109,10 @@ def missing_mssql_python(name, globals=None, locals=None, fromlist=(), level=0): monkeypatch.setattr(builtins, "__import__", missing_mssql_python) with pytest.raises(DbtRuntimeError, match="mssql-python"): - sqlserver_connections._get_mssql_python() + dbt.adapters.sqlserver.sqlserver_runtime._get_mssql_python() -def test_open_with_mssql_python_feature_flag_requires_optional_dependency( +def test_open_with_mssql_python_backend_requires_optional_dependency( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: credentials.driver = None @@ -506,19 +1120,69 @@ def test_open_with_mssql_python_feature_flag_requires_optional_dependency( connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) - monkeypatch.setattr( - sqlserver_connections, - "_MSSQL_PYTHON_IMPORT_ERROR", - ModuleNotFoundError("No module named 'mssql_python'"), - raising=False, + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=None, + mssql_python_import_error=ModuleNotFoundError("No module named 'mssql_python'"), ) with pytest.raises(DbtRuntimeError, match="mssql-python"): SQLServerConnectionManager.open(connection) -def test_open_with_mssql_python_feature_flag_builds_connection_without_odbc_driver( +def _fake_retry_connection_stub( + captured: Dict[str, Any] | None = None, +): + def fake_retry_connection( + cls, + connection, + connect, + logger, + retry_limit, + retryable_exceptions, + ): + if captured is not None: + captured["retry_limit"] = retry_limit + captured["retryable_exceptions"] = retryable_exceptions + handle = connect() + connection.handle = handle + connection.state = ConnectionState.OPEN + return connection + + return fake_retry_connection + + +def _fake_mssql_python_module( + connect, + pooling=None, +): + if pooling is None: + + def pooling(*args, **kwargs): + return None + + module = { + "connect": connect, + "OperationalError": type("OperationalError", (Exception,), {}), + "InterfaceError": type("InterfaceError", (Exception,), {}), + "InternalError": type("InternalError", (Exception,), {}), + } + if pooling is not None: + module["pooling"] = pooling + return SimpleNamespace(**module) + + +def _fake_pyodbc_module(connect): + return SimpleNamespace( + connect=connect, + pooling=False, + OperationalError=type("OperationalError", (Exception,), {}), + InterfaceError=type("InterfaceError", (Exception,), {}), + InternalError=type("InternalError", (Exception,), {}), + ) + + +def test_open_with_mssql_python_backend_enables_pooling( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: credentials.driver = None @@ -546,38 +1210,25 @@ def fake_connect(connection_string, autocommit, timeout): captured["timeout"] = timeout return fake_handle - def fake_pooling(*, enabled): - pooling_calls.append({"enabled": enabled}) - - fake_module = SimpleNamespace( - connect=fake_connect, - pooling=fake_pooling, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) + def fake_pooling(max_size=100, idle_timeout=600, enabled=True): + pooling_calls.append( + { + "max_size": max_size, + "idle_timeout": idle_timeout, + "enabled": enabled, + } + ) - def fake_retry_connection( - cls, - connection, - connect, - logger, - retry_limit, - retryable_exceptions, - ): - captured["retry_limit"] = retry_limit - captured["retryable_exceptions"] = retryable_exceptions - handle = connect() - connection.handle = handle - connection.state = ConnectionState.OPEN - return connection + fake_module = _fake_mssql_python_module(fake_connect, fake_pooling) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_module, mssql_python_import_error=None + ) monkeypatch.setattr( SQLServerConnectionManager, "retry_connection", - classmethod(fake_retry_connection), + classmethod(_fake_retry_connection_stub(captured)), ) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -590,7 +1241,14 @@ def fake_retry_connection( assert captured["autocommit"] is True assert captured["timeout"] == 17 assert captured["retry_limit"] == 5 - assert pooling_calls == [] + assert pooling_calls == [ + { + "max_size": 100, + "idle_timeout": 600, + "enabled": True, + } + ] + assert fake_handle.timeout == 23 con_str = captured["connection_string"] assert "DRIVER=" not in con_str @@ -605,26 +1263,20 @@ def fake_retry_connection( assert fake_module.OperationalError in captured["retryable_exceptions"] assert fake_module.InternalError in captured["retryable_exceptions"] - assert pooling_calls == [] - assert "APP=dbt-sqlserver/" not in con_str - -def test_open_with_mssql_python_feature_flag_fails_fast_for_pyodbc_token_auth_aliases( +def test_open_with_mssql_python_backend_fails_fast_for_pyodbc_token_auth_aliases( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch ) -> None: credentials.driver = None credentials.backend = SQLServerBackend.mssql_python credentials.authentication = "cli" - fake_module = SimpleNamespace( - connect=lambda *args, **kwargs: None, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) + fake_module = _fake_mssql_python_module(lambda *args, **kwargs: None) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_module, mssql_python_import_error=None + ) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -647,15 +1299,12 @@ def test_open_with_mssql_python_unsupported_authentications( credentials.UID = "dbt_user" credentials.PWD = "super-secret" - fake_module = SimpleNamespace( - connect=lambda *args, **kwargs: None, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) + fake_module = _fake_mssql_python_module(lambda *args, **kwargs: None) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_module, mssql_python_import_error=None + ) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -667,7 +1316,7 @@ def test_open_with_mssql_python_unsupported_authentications( "authentication", ["msi", "ActiveDirectoryMSI"], ) -def test_open_with_mssql_python_supported_managed_identity_auth( +def test_open_with_mssql_python_backend_supported_managed_identity_auth( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch, authentication: str, @@ -692,32 +1341,16 @@ def fake_connect(connection_string, autocommit, timeout): captured["timeout"] = timeout return FakeHandle() - fake_module = SimpleNamespace( - connect=fake_connect, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) - - def fake_retry_connection( - cls, - connection, - connect, - logger, - retry_limit, - retryable_exceptions, - ): - handle = connect() - connection.handle = handle - connection.state = ConnectionState.OPEN - return connection + fake_module = _fake_mssql_python_module(fake_connect) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", fake_module, raising=False) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test( + mssql_python_module=fake_module, mssql_python_import_error=None + ) monkeypatch.setattr( SQLServerConnectionManager, "retry_connection", - classmethod(fake_retry_connection), + classmethod(_fake_retry_connection_stub()), ) connection = Connection(type="sqlserver", name="feature-flag-test", credentials=credentials) @@ -757,8 +1390,8 @@ def test_open_requires_host_database_schema( InterfaceError=type("InterfaceError", (Exception,), {}), ) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) @@ -766,26 +1399,64 @@ def test_open_requires_host_database_schema( SQLServerConnectionManager.open(connection) -def test_open_with_pyodbc_path_still_requires_driver_when_feature_flag_disabled( +def test_open_with_pyodbc_backend_still_requires_driver( credentials: SQLServerCredentials, monkeypatch: pytest.MonkeyPatch, ) -> None: credentials.driver = None credentials.backend = SQLServerBackend.pyodbc - fake_pyodbc = SimpleNamespace( - connect=lambda *args, **kwargs: None, - pooling=False, - OperationalError=type("OperationalError", (Exception,), {}), - InterfaceError=type("InterfaceError", (Exception,), {}), - InternalError=type("InternalError", (Exception,), {}), - ) + fake_pyodbc = _fake_pyodbc_module(lambda *args, **kwargs: None) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_MODULE", fake_pyodbc, raising=False) - monkeypatch.setattr(sqlserver_connections, "_PYODBC_IMPORT_ERROR", None, raising=False) + reset_runtime_state_for_test() + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) - monkeypatch.setattr(sqlserver_connections, "_MSSQL_PYTHON_MODULE", None, raising=False) + configure_runtime_state_for_test(mssql_python_module=None) with pytest.raises(DbtRuntimeError, match="driver"): SQLServerConnectionManager.open(connection) + + +def test_open_with_pyodbc_backend_enables_driver_pooling( + credentials: SQLServerCredentials, + monkeypatch: pytest.MonkeyPatch, +) -> None: + credentials.backend = SQLServerBackend.pyodbc + credentials.encrypt = True + credentials.trust_cert = True + credentials.UID = "dbt_user" + credentials.PWD = "super-secret" + + captured: Dict[str, Any] = {} + + class FakeHandle: + def __init__(self): + self.timeout = None + + def fake_connect(connection_string, attrs_before, autocommit, timeout): + captured["connection_string"] = connection_string + captured["attrs_before"] = attrs_before + captured["autocommit"] = autocommit + captured["timeout"] = timeout + return FakeHandle() + + fake_pyodbc = _fake_pyodbc_module(fake_connect) + + reset_runtime_state_for_test() + configure_runtime_state_for_test(pyodbc_module=fake_pyodbc, pyodbc_import_error=None) + monkeypatch.setattr( + SQLServerConnectionManager, + "retry_connection", + classmethod(_fake_retry_connection_stub()), + ) + + connection = Connection(type="sqlserver", name="pyodbc-test", credentials=credentials) + opened = SQLServerConnectionManager.open(connection) + + assert opened is connection + assert opened.state == ConnectionState.OPEN + assert fake_pyodbc.pooling is True + assert captured["autocommit"] is True + assert captured["timeout"] == credentials.login_timeout + assert "Pooling=true" in captured["connection_string"] diff --git a/uv.lock b/uv.lock index 0162e820e..0a2e4614e 100644 --- a/uv.lock +++ b/uv.lock @@ -585,6 +585,8 @@ azure = [ { name = "azure-identity" }, ] mssql = [ + { name = "azure-core" }, + { name = "azure-identity" }, { name = "mssql-python" }, ] pyodbc = [ @@ -616,11 +618,13 @@ dev = [ [package.metadata] requires-dist = [ + { name = "azure-core", marker = "extra == 'mssql'", specifier = ">=1.0.0" }, { name = "azure-identity", marker = "extra == 'azure'", specifier = ">=1.12.0" }, + { name = "azure-identity", marker = "extra == 'mssql'", specifier = ">=1.12.0" }, { name = "dbt-adapters", specifier = ">=1.15.2,<2.0" }, { name = "dbt-common", specifier = ">=1.22.0,<2.0" }, { name = "dbt-core", specifier = ">=1.10.0,<1.11.0" }, - { name = "mssql-python", marker = "extra == 'mssql'", specifier = ">=1.4.0" }, + { name = "mssql-python", marker = "extra == 'mssql'", specifier = ">=1.7.1" }, { name = "pyodbc", specifier = ">=5.2.0" }, { name = "pyodbc", marker = "extra == 'pyodbc'", specifier = ">=5.2.0" }, ] @@ -635,7 +639,7 @@ dev = [ { name = "flaky" }, { name = "freezegun", specifier = ">=1.5.0,<2.0" }, { name = "ipdb" }, - { name = "mssql-python", specifier = ">=1.4.0" }, + { name = "mssql-python", specifier = ">=1.7.1" }, { name = "mypy", specifier = "==1.11.2" }, { name = "pre-commit" }, { name = "pyodbc", specifier = ">=5.2.0" }, From d4cebee4c61f15f0dcb7dbe04f61c6fcd0086265 Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Tue, 26 May 2026 21:21:49 +0000 Subject: [PATCH 8/8] =?UTF-8?q?=E2=9C=A8=20feat(sqlserver):=20update=20int?= =?UTF-8?q?egration=20tests=20matrix=20for=20faster=20review?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../workflows/integration-tests-sqlserver.yml | 63 ++++++++++++++----- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index d931dec4a..77c5f14a7 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -35,31 +35,66 @@ on: # yamllint disable-line rule:truthy jobs: integration-tests-sql-server: - name: Regular ${{ matrix.backend }} + name: ${{ matrix.backend }} / Py${{ matrix.python_version }} / SQL${{ matrix.sqlserver_version }} / ODBC${{ matrix.msodbc_version }} / ${{ matrix.collation }} if: github.actor != 'dependabot[bot]' strategy: + # Smaller matrix allows this to check if a failure is specific. + fail-fast: false matrix: python_version: ["3.10", "3.11", "3.12", "3.13"] - msodbc_version: ["17", "18"] - sqlserver_version: ["2017", "2019", "2022"] - collation: ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] - backend: - - pyodbc - - mssql-python + backend: [pyodbc, mssql-python] + # Baseline on 2022 + sqlserver_version: ["2022"] + msodbc_version: ["18"] + collation: [SQL_Latin1_General_CP1_CI_AS] exclude: - - backend: pyodbc - python_version: "3.12" - - backend: pyodbc - python_version: "3.13" - backend: mssql-python python_version: "3.10" + sqlserver_version: "2022" - backend: mssql-python python_version: "3.11" + sqlserver_version: "2022" + - backend: mssql-python + python_version: "3.12" + sqlserver_version: "2022" include: + # Keep pyodbc on every supported Python version, but retain + # SQL Server ODBC 17 coverage for the oldest and newest Python. + - backend: pyodbc + python_version: "3.10" + sqlserver_version: "2022" + msodbc_version: "17" + collation: SQL_Latin1_General_CP1_CI_AS - backend: pyodbc - install_extra: pyodbc + python_version: "3.13" + sqlserver_version: "2022" + msodbc_version: "17" + collation: SQL_Latin1_General_CP1_CI_AS + # Older SQL Server versions stay on pyodbc only, with a single + # collation so the matrix stays small and readable. + - backend: pyodbc + python_version: "3.10" + sqlserver_version: "2017" + msodbc_version: "17" + collation: SQL_Latin1_General_CP1_CI_AS + - backend: pyodbc + python_version: "3.13" + sqlserver_version: "2019" + msodbc_version: "17" + collation: SQL_Latin1_General_CP1_CI_AS + # Add the case-sensitive collation only on the latest SQL Server + # and latest Python/backend rows. + - backend: pyodbc + python_version: "3.13" + sqlserver_version: "2022" + msodbc_version: "17" + collation: SQL_Latin1_General_CP1_CS_AS + # mssql-python stays on the latest Python only. - backend: mssql-python - install_extra: mssql + python_version: "3.13" + sqlserver_version: "2022" + msodbc_version: "18" + collation: SQL_Latin1_General_CP1_CS_AS runs-on: ubuntu-latest container: image: ghcr.io/${{ github.repository }}:CI-${{ matrix.python_version }}-msodbc${{ matrix.msodbc_version }} @@ -84,7 +119,7 @@ jobs: - name: Install dependencies env: - INSTALL_EXTRA: ${{ matrix.install_extra }} + INSTALL_EXTRA: ${{ matrix.backend == 'pyodbc' && 'pyodbc' || 'mssql' }} run: uv pip install --system -e ".[$INSTALL_EXTRA]" --group dev - name: Run functional tests