From 55191ade617f812f7c02ccb14ae3d19f0c2e29e9 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 1 Jul 2026 17:38:09 +0000 Subject: [PATCH 1/2] feat(pymysql): add Cloud SQL connector support --- docs/reference/adapters/pymysql.rst | 29 +++ sqlspec/adapters/pymysql/config.py | 119 ++++++++++- sqlspec/adapters/pymysql/core.py | 1 + sqlspec/adapters/pymysql/pool.py | 9 +- .../test_pymysql/test_cloud_sql_connector.py | 194 ++++++++++++++++++ 5 files changed, 350 insertions(+), 2 deletions(-) create mode 100644 tests/unit/adapters/test_pymysql/test_cloud_sql_connector.py diff --git a/docs/reference/adapters/pymysql.rst b/docs/reference/adapters/pymysql.rst index b5b83d08b..6aa5a8815 100644 --- a/docs/reference/adapters/pymysql.rst +++ b/docs/reference/adapters/pymysql.rst @@ -11,6 +11,35 @@ Configuration :members: :show-inheritance: +Cloud SQL Connector +=================== + +PyMySQL configs can use the in-process Google Cloud SQL Python Connector by +installing the ``cloud-sql`` extra and enabling the connector in +``driver_features``: + +.. code-block:: python + + from sqlspec.adapters.pymysql import PyMysqlConfig + + config = PyMysqlConfig( + connection_config={ + "user": "app-user", + "password": "secret", + "database": "app", + }, + driver_features={ + "enable_cloud_sql": True, + "cloud_sql_instance": "project:region:instance", + "cloud_sql_ip_type": "PRIVATE", + }, + ) + +When ``enable_cloud_sql`` is true, ``cloud_sql_instance`` is required and must +use ``project:region:instance`` format. Host, port, socket, and direct auth +connection values are passed through the connector rather than opened directly +by PyMySQL. + Driver ====== diff --git a/sqlspec/adapters/pymysql/config.py b/sqlspec/adapters/pymysql/config.py index 39359bb0c..42621af1c 100644 --- a/sqlspec/adapters/pymysql/config.py +++ b/sqlspec/adapters/pymysql/config.py @@ -12,8 +12,9 @@ from sqlspec.adapters.pymysql.pool import PyMysqlConnectionPool from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory -from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError from sqlspec.extensions.events import EventRuntimeHints +from sqlspec.typing import CLOUD_SQL_CONNECTOR_INSTALLED from sqlspec.utils.config_tools import normalize_connection_config if TYPE_CHECKING: @@ -119,6 +120,17 @@ class PyMysqlDriverFeatures(TypedDict): events_backend: Event channel backend selection. enable_local_infile_bulk_load: Route load_from_arrow through LOAD DATA LOCAL INFILE. Requires local_infile=True in connection_config. + enable_cloud_sql: Enable Google Cloud SQL connector integration. + Requires cloud-sql-python-connector package. + Defaults to False (explicit opt-in required). + cloud_sql_instance: Cloud SQL instance connection name. + Format: "project:region:instance" + Required when enable_cloud_sql is True. + cloud_sql_enable_iam_auth: Enable IAM database authentication. + Defaults to False for passwordless authentication. + cloud_sql_ip_type: IP address type for connection. + Options: "PUBLIC", "PRIVATE", "PSC" + Defaults to "PRIVATE". """ json_serializer: NotRequired["Callable[[Any], str]"] @@ -127,6 +139,62 @@ class PyMysqlDriverFeatures(TypedDict): enable_events: NotRequired[bool] events_backend: NotRequired[str] enable_local_infile_bulk_load: NotRequired[bool] + enable_cloud_sql: NotRequired[bool] + cloud_sql_instance: NotRequired[str] + cloud_sql_enable_iam_auth: NotRequired[bool] + cloud_sql_ip_type: NotRequired[str] + + +_CLOUD_SQL_DIRECT_CONNECTION_KEYS = frozenset(( + "bind_address", + "database", + "host", + "password", + "port", + "ssl", + "unix_socket", + "user", +)) + + +class _PyMysqlCloudSqlConnector: + __slots__ = ("_config", "_database", "_driver_kwargs", "_password", "_user") + + def __init__( + self, + config: "PyMysqlConfig", + user: str | None, + password: str | None, + database: str | None, + driver_kwargs: "dict[str, Any]", + ) -> None: + self._config = config + self._user = user + self._password = password + self._database = database + self._driver_kwargs = driver_kwargs + + def __call__(self) -> "PyMysqlConnection": + connector = self._config.get_cloud_sql_connector() + if connector is None: + msg = "Cloud SQL connector is not initialized" + raise ImproperConfigurationError(msg) + + conn_kwargs: dict[str, Any] = { + **self._driver_kwargs, + "instance_connection_string": self._config.driver_features["cloud_sql_instance"], + "driver": "pymysql", + "enable_iam_auth": self._config.driver_features.get("cloud_sql_enable_iam_auth", False), + "ip_type": self._config.driver_features.get("cloud_sql_ip_type", "PRIVATE"), + } + if self._user: + conn_kwargs["user"] = self._user + if self._password: + conn_kwargs["password"] = self._password + if self._database: + conn_kwargs["db"] = self._database + + return cast("PyMysqlConnection", connector.connect(**conn_kwargs)) class PyMysqlConnectionContext(SyncPoolConnectionContext): @@ -198,20 +266,69 @@ def __init__( **kwargs, ) + self._cloud_sql_connector: Any | None = None + self._validate_connector_config() + + def get_cloud_sql_connector(self) -> Any | None: + """Return the configured Cloud SQL connector instance.""" + return self._cloud_sql_connector + + def _validate_connector_config(self) -> None: + """Validate Google Cloud SQL connector configuration.""" + if not self.driver_features.get("enable_cloud_sql", False): + return + + if not CLOUD_SQL_CONNECTOR_INSTALLED: + raise MissingDependencyError(package="cloud-sql-python-connector", install_package="cloud-sql") + + instance = self.driver_features.get("cloud_sql_instance") + if not instance: + msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'" + raise ImproperConfigurationError(msg) + + cloud_sql_instance_parts_expected = 2 + if instance.count(":") != cloud_sql_instance_parts_expected: + msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'" + raise ImproperConfigurationError(msg) + + def _setup_cloud_sql_connector(self, config: "dict[str, Any]") -> "_PyMysqlCloudSqlConnector": + """Setup Cloud SQL connector and return a pool connection factory.""" + from google.cloud.sql.connector import Connector # type: ignore[import-untyped,unused-ignore] + + self._cloud_sql_connector = Connector() + + user = config.get("user") + password = config.get("password") + database = config.get("database") + + for key in _CLOUD_SQL_DIRECT_CONNECTION_KEYS: + config.pop(key, None) + + return _PyMysqlCloudSqlConnector(self, user, password, database, dict(config)) + def _create_pool(self) -> "PyMysqlConnectionPool": config = dict(self.connection_config) pool_recycle = config.pop("pool_recycle_seconds", 86400) health_check = config.pop("health_check_interval", 30.0) + connection_factory = None + if self.driver_features.get("enable_cloud_sql", False): + connection_factory = self._setup_cloud_sql_connector(config) return PyMysqlConnectionPool( config, recycle_seconds=pool_recycle, health_check_interval=health_check, on_connection_create=self._user_connection_hook, + connection_factory=connection_factory, ) def _close_pool(self) -> None: if self.connection_instance: self.connection_instance.close() + self.connection_instance = None + + if self._cloud_sql_connector is not None: + self._cloud_sql_connector.close() + self._cloud_sql_connector = None def create_connection(self) -> PyMysqlConnection: pool = self.provide_pool() diff --git a/sqlspec/adapters/pymysql/core.py b/sqlspec/adapters/pymysql/core.py index 4edb948c8..5824ee034 100644 --- a/sqlspec/adapters/pymysql/core.py +++ b/sqlspec/adapters/pymysql/core.py @@ -271,6 +271,7 @@ def apply_driver_features( features: dict[str, Any] = dict(driver_features) if driver_features else {} json_serializer = features.setdefault("json_serializer", to_json) json_deserializer = features.setdefault("json_deserializer", from_json) + features.setdefault("enable_cloud_sql", False) if json_serializer is not None: parameter_config = statement_config.parameter_config.with_json_serializers( diff --git a/sqlspec/adapters/pymysql/pool.py b/sqlspec/adapters/pymysql/pool.py index 2486772a9..351a25fd0 100644 --- a/sqlspec/adapters/pymysql/pool.py +++ b/sqlspec/adapters/pymysql/pool.py @@ -27,6 +27,7 @@ class PyMysqlConnectionPool: """Thread-local connection manager for PyMySQL.""" __slots__ = ( + "_connection_factory", "_connection_parameters", "_health_check_interval", "_on_connection_create", @@ -41,6 +42,7 @@ def __init__( recycle_seconds: int = 86400, health_check_interval: float = 30.0, on_connection_create: "Callable[[PyMysqlConnection], None] | None" = None, + connection_factory: "Callable[[], PyMysqlConnection] | None" = None, ) -> None: """Initialize the thread-local connection manager. @@ -49,8 +51,10 @@ def __init__( recycle_seconds: Connection recycle time in seconds (default 24h) health_check_interval: Seconds of idle time before running health check on_connection_create: Callback executed when connection is created + connection_factory: Optional factory for custom connection creation """ self._connection_parameters = connection_parameters + self._connection_factory = connection_factory self._thread_local = threading.local() self._recycle_seconds = recycle_seconds self._health_check_interval = health_check_interval @@ -63,7 +67,10 @@ def _database_name(self) -> str: return str(self._connection_parameters.get("database", "unknown")) def _create_connection(self) -> PyMysqlConnection: - connection = pymysql.connect(**self._connection_parameters) + if self._connection_factory is not None: + connection = self._connection_factory() + else: + connection = pymysql.connect(**self._connection_parameters) # Call user-provided callback after connection creation if self._on_connection_create is not None: diff --git a/tests/unit/adapters/test_pymysql/test_cloud_sql_connector.py b/tests/unit/adapters/test_pymysql/test_cloud_sql_connector.py new file mode 100644 index 000000000..ce7f6efb0 --- /dev/null +++ b/tests/unit/adapters/test_pymysql/test_cloud_sql_connector.py @@ -0,0 +1,194 @@ +"""Unit tests for PyMySQL Cloud SQL connector integration.""" + +import sys +from collections.abc import Generator +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from sqlspec.adapters.pymysql.config import PyMysqlConfig +from sqlspec.adapters.pymysql.pool import PyMysqlConnectionPool +from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError + +# pyright: reportPrivateUsage=false + + +@pytest.fixture(autouse=True) +def disable_cloud_sql_by_default() -> Generator[None, None, None]: + """Disable Cloud SQL by default for clean test state.""" + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", False, create=True): + yield + + +@pytest.fixture +def mock_cloud_sql_module() -> Generator[MagicMock, None, None]: + """Create and register mock google.cloud.sql module.""" + mock_connector_class = MagicMock() + mock_module = MagicMock() + mock_module.connector.Connector = mock_connector_class + + sys.modules["google.cloud.sql"] = mock_module + sys.modules["google.cloud.sql.connector"] = mock_module.connector + + yield mock_connector_class + + sys.modules.pop("google.cloud.sql", None) + sys.modules.pop("google.cloud.sql.connector", None) + + +def test_cloud_sql_defaults_to_false() -> None: + """Cloud SQL connector should require explicit opt-in.""" + config = PyMysqlConfig(connection_config={}) + + assert config.driver_features["enable_cloud_sql"] is False + + +def test_cloud_sql_explicit_disable_uses_direct_pymysql_path() -> None: + """Disabling Cloud SQL should leave the pool on the normal PyMySQL path.""" + config = PyMysqlConfig(connection_config={}, driver_features={"enable_cloud_sql": False}) + pool = config._create_pool() + + assert config.driver_features["enable_cloud_sql"] is False + assert pool._connection_factory is None + assert pool._connection_parameters["host"] == "localhost" + assert pool._connection_parameters["port"] == 3306 + + +def test_cloud_sql_missing_package_raises_error() -> None: + """Enabling Cloud SQL without the connector package should raise.""" + with pytest.raises(MissingDependencyError, match="cloud-sql-python-connector"): + PyMysqlConfig( + connection_config={}, + driver_features={"enable_cloud_sql": True, "cloud_sql_instance": "project:region:instance"}, + ) + + +def test_cloud_sql_missing_instance_raises_error() -> None: + """Cloud SQL requires an instance connection name.""" + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", True, create=True): + with pytest.raises( + ImproperConfigurationError, match="cloud_sql_instance required when enable_cloud_sql is True" + ): + PyMysqlConfig(connection_config={}, driver_features={"enable_cloud_sql": True}) + + +@pytest.mark.parametrize("instance", ["invalid-format", "project:region:instance:extra"]) +def test_cloud_sql_invalid_instance_format_raises_error(instance: str) -> None: + """Cloud SQL instance names must be project:region:instance.""" + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", True, create=True): + with pytest.raises(ImproperConfigurationError, match="Invalid Cloud SQL instance format"): + PyMysqlConfig( + connection_config={}, driver_features={"enable_cloud_sql": True, "cloud_sql_instance": instance} + ) + + +def test_cloud_sql_setup_strips_direct_connection_parameters(mock_cloud_sql_module: MagicMock) -> None: + """Cloud SQL setup should remove direct network/auth params from pool kwargs.""" + mock_connector = MagicMock() + mock_cloud_sql_module.return_value = mock_connector + + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", True, create=True): + config = PyMysqlConfig( + connection_config={ + "host": "127.0.0.1", + "port": 3307, + "unix_socket": "/cloudsql/project:region:instance", + "bind_address": "127.0.0.1", + "user": "testuser", + "password": "testpass", + "database": "testdb", + "autocommit": True, + "charset": "utf8mb4", + }, + driver_features={"enable_cloud_sql": True, "cloud_sql_instance": "project:region:instance"}, + ) + pool = config._create_pool() + + mock_cloud_sql_module.assert_called_once() + assert config.get_cloud_sql_connector() is mock_connector + assert pool._connection_factory is not None + assert "host" not in pool._connection_parameters + assert "port" not in pool._connection_parameters + assert "unix_socket" not in pool._connection_parameters + assert "bind_address" not in pool._connection_parameters + assert "user" not in pool._connection_parameters + assert "password" not in pool._connection_parameters + assert "database" not in pool._connection_parameters + assert pool._connection_parameters["autocommit"] is True + assert pool._connection_parameters["charset"] == "utf8mb4" + + +def test_cloud_sql_connection_factory_calls_connector(mock_cloud_sql_module: MagicMock) -> None: + """The pool factory should connect through google.cloud.sql.connector.""" + cloud_connection = MagicMock() + mock_connector = MagicMock() + mock_connector.connect.return_value = cloud_connection + mock_cloud_sql_module.return_value = mock_connector + + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", True, create=True): + config = PyMysqlConfig( + connection_config={"user": "testuser", "password": "testpass", "database": "testdb", "autocommit": True}, + driver_features={ + "enable_cloud_sql": True, + "cloud_sql_instance": "project:region:instance", + "cloud_sql_enable_iam_auth": True, + "cloud_sql_ip_type": "PUBLIC", + }, + ) + pool = config._create_pool() + + with patch("sqlspec.adapters.pymysql.pool.pymysql.connect") as mock_pymysql_connect: + connection = pool._create_connection() + + assert connection is cloud_connection + mock_pymysql_connect.assert_not_called() + mock_connector.connect.assert_called_once_with( + instance_connection_string="project:region:instance", + driver="pymysql", + enable_iam_auth=True, + ip_type="PUBLIC", + autocommit=True, + local_infile=False, + user="testuser", + password="testpass", + db="testdb", + ) + + +def test_pool_runs_connection_create_callback_after_direct_or_factory_paths() -> None: + """Connection creation callbacks should run after direct and factory creation.""" + direct_connection = MagicMock() + cloud_connection = MagicMock() + seen_connections: list[Any] = [] + + def on_connection_create(connection: Any) -> None: + seen_connections.append(connection) + + direct_pool = PyMysqlConnectionPool({"host": "localhost"}, on_connection_create=on_connection_create) + with patch("sqlspec.adapters.pymysql.pool.pymysql.connect", return_value=direct_connection): + assert direct_pool._create_connection() is direct_connection + + factory_pool = PyMysqlConnectionPool( + {}, connection_factory=lambda: cloud_connection, on_connection_create=on_connection_create + ) + assert factory_pool._create_connection() is cloud_connection + + assert seen_connections == [direct_connection, cloud_connection] + + +def test_cloud_sql_connector_cleanup(mock_cloud_sql_module: MagicMock) -> None: + """Cloud SQL connector should be closed when the config closes.""" + mock_connector = MagicMock() + mock_cloud_sql_module.return_value = mock_connector + + with patch("sqlspec.adapters.pymysql.config.CLOUD_SQL_CONNECTOR_INSTALLED", True, create=True): + config = PyMysqlConfig( + connection_config={}, + driver_features={"enable_cloud_sql": True, "cloud_sql_instance": "project:region:instance"}, + ) + config.connection_instance = config._create_pool() + config._close_pool() + + mock_connector.close.assert_called_once() + assert config.get_cloud_sql_connector() is None From 3c512830a30d9baac83b3078ed6ff7220b52d6a8 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 2 Jul 2026 02:03:35 +0000 Subject: [PATCH 2/2] fix(pymysql): remove redundant connection cast --- sqlspec/adapters/pymysql/pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlspec/adapters/pymysql/pool.py b/sqlspec/adapters/pymysql/pool.py index 351a25fd0..742295a56 100644 --- a/sqlspec/adapters/pymysql/pool.py +++ b/sqlspec/adapters/pymysql/pool.py @@ -76,7 +76,7 @@ def _create_connection(self) -> PyMysqlConnection: if self._on_connection_create is not None: self._on_connection_create(connection) - return cast("PyMysqlConnection", connection) + return connection def _is_connection_alive(self, connection: PyMysqlConnection) -> bool: try: