Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/reference/adapters/pymysql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@ Configuration
:members:
:show-inheritance:

Cloud SQL Connector
===================

PyMySQL configs can use the in-process Google Cloud SQL Python Connector by
installing the ``cloud-sql`` extra and enabling the connector in
``driver_features``:

.. code-block:: python

from sqlspec.adapters.pymysql import PyMysqlConfig

config = PyMysqlConfig(
connection_config={
"user": "app-user",
"password": "secret",
"database": "app",
},
driver_features={
"enable_cloud_sql": True,
"cloud_sql_instance": "project:region:instance",
"cloud_sql_ip_type": "PRIVATE",
},
)

When ``enable_cloud_sql`` is true, ``cloud_sql_instance`` is required and must
use ``project:region:instance`` format. Host, port, socket, and direct auth
connection values are passed through the connector rather than opened directly
by PyMySQL.

Driver
======

Expand Down
119 changes: 118 additions & 1 deletion sqlspec/adapters/pymysql/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from sqlspec.adapters.pymysql.pool import PyMysqlConnectionPool
from sqlspec.config import ExtensionConfigs, SyncDatabaseConfig
from sqlspec.driver._sync import SyncPoolConnectionContext, SyncPoolSessionFactory
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.exceptions import ImproperConfigurationError, MissingDependencyError
from sqlspec.extensions.events import EventRuntimeHints
from sqlspec.typing import CLOUD_SQL_CONNECTOR_INSTALLED
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,6 +120,17 @@ class PyMysqlDriverFeatures(TypedDict):
events_backend: Event channel backend selection.
enable_local_infile_bulk_load: Route load_from_arrow through LOAD DATA LOCAL INFILE.
Requires local_infile=True in connection_config.
enable_cloud_sql: Enable Google Cloud SQL connector integration.
Requires cloud-sql-python-connector package.
Defaults to False (explicit opt-in required).
cloud_sql_instance: Cloud SQL instance connection name.
Format: "project:region:instance"
Required when enable_cloud_sql is True.
cloud_sql_enable_iam_auth: Enable IAM database authentication.
Defaults to False for passwordless authentication.
cloud_sql_ip_type: IP address type for connection.
Options: "PUBLIC", "PRIVATE", "PSC"
Defaults to "PRIVATE".
"""

json_serializer: NotRequired["Callable[[Any], str]"]
Expand All @@ -127,6 +139,62 @@ class PyMysqlDriverFeatures(TypedDict):
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
enable_local_infile_bulk_load: NotRequired[bool]
enable_cloud_sql: NotRequired[bool]
cloud_sql_instance: NotRequired[str]
cloud_sql_enable_iam_auth: NotRequired[bool]
cloud_sql_ip_type: NotRequired[str]


_CLOUD_SQL_DIRECT_CONNECTION_KEYS = frozenset((
"bind_address",
"database",
"host",
"password",
"port",
"ssl",
"unix_socket",
"user",
))


class _PyMysqlCloudSqlConnector:
__slots__ = ("_config", "_database", "_driver_kwargs", "_password", "_user")

def __init__(
self,
config: "PyMysqlConfig",
user: str | None,
password: str | None,
database: str | None,
driver_kwargs: "dict[str, Any]",
) -> None:
self._config = config
self._user = user
self._password = password
self._database = database
self._driver_kwargs = driver_kwargs

def __call__(self) -> "PyMysqlConnection":
connector = self._config.get_cloud_sql_connector()
if connector is None:
msg = "Cloud SQL connector is not initialized"
raise ImproperConfigurationError(msg)

conn_kwargs: dict[str, Any] = {
**self._driver_kwargs,
"instance_connection_string": self._config.driver_features["cloud_sql_instance"],
"driver": "pymysql",
"enable_iam_auth": self._config.driver_features.get("cloud_sql_enable_iam_auth", False),
"ip_type": self._config.driver_features.get("cloud_sql_ip_type", "PRIVATE"),
}
if self._user:
conn_kwargs["user"] = self._user
if self._password:
conn_kwargs["password"] = self._password
if self._database:
conn_kwargs["db"] = self._database

return cast("PyMysqlConnection", connector.connect(**conn_kwargs))


class PyMysqlConnectionContext(SyncPoolConnectionContext):
Expand Down Expand Up @@ -198,20 +266,69 @@ def __init__(
**kwargs,
)

self._cloud_sql_connector: Any | None = None
self._validate_connector_config()

def get_cloud_sql_connector(self) -> Any | None:
"""Return the configured Cloud SQL connector instance."""
return self._cloud_sql_connector

def _validate_connector_config(self) -> None:
"""Validate Google Cloud SQL connector configuration."""
if not self.driver_features.get("enable_cloud_sql", False):
return

if not CLOUD_SQL_CONNECTOR_INSTALLED:
raise MissingDependencyError(package="cloud-sql-python-connector", install_package="cloud-sql")

instance = self.driver_features.get("cloud_sql_instance")
if not instance:
msg = "cloud_sql_instance required when enable_cloud_sql is True. Format: 'project:region:instance'"
raise ImproperConfigurationError(msg)

cloud_sql_instance_parts_expected = 2
if instance.count(":") != cloud_sql_instance_parts_expected:
msg = f"Invalid Cloud SQL instance format: {instance}. Expected format: 'project:region:instance'"
raise ImproperConfigurationError(msg)

def _setup_cloud_sql_connector(self, config: "dict[str, Any]") -> "_PyMysqlCloudSqlConnector":
"""Setup Cloud SQL connector and return a pool connection factory."""
from google.cloud.sql.connector import Connector # type: ignore[import-untyped,unused-ignore]

self._cloud_sql_connector = Connector()

user = config.get("user")
password = config.get("password")
database = config.get("database")

for key in _CLOUD_SQL_DIRECT_CONNECTION_KEYS:
config.pop(key, None)

return _PyMysqlCloudSqlConnector(self, user, password, database, dict(config))

def _create_pool(self) -> "PyMysqlConnectionPool":
config = dict(self.connection_config)
pool_recycle = config.pop("pool_recycle_seconds", 86400)
health_check = config.pop("health_check_interval", 30.0)
connection_factory = None
if self.driver_features.get("enable_cloud_sql", False):
connection_factory = self._setup_cloud_sql_connector(config)
return PyMysqlConnectionPool(
config,
recycle_seconds=pool_recycle,
health_check_interval=health_check,
on_connection_create=self._user_connection_hook,
connection_factory=connection_factory,
)

def _close_pool(self) -> None:
if self.connection_instance:
self.connection_instance.close()
self.connection_instance = None

if self._cloud_sql_connector is not None:
self._cloud_sql_connector.close()
self._cloud_sql_connector = None

def create_connection(self) -> PyMysqlConnection:
pool = self.provide_pool()
Expand Down
1 change: 1 addition & 0 deletions sqlspec/adapters/pymysql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def apply_driver_features(
features: dict[str, Any] = dict(driver_features) if driver_features else {}
json_serializer = features.setdefault("json_serializer", to_json)
json_deserializer = features.setdefault("json_deserializer", from_json)
features.setdefault("enable_cloud_sql", False)

if json_serializer is not None:
parameter_config = statement_config.parameter_config.with_json_serializers(
Expand Down
11 changes: 9 additions & 2 deletions sqlspec/adapters/pymysql/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PyMysqlConnectionPool:
"""Thread-local connection manager for PyMySQL."""

__slots__ = (
"_connection_factory",
"_connection_parameters",
"_health_check_interval",
"_on_connection_create",
Expand All @@ -41,6 +42,7 @@ def __init__(
recycle_seconds: int = 86400,
health_check_interval: float = 30.0,
on_connection_create: "Callable[[PyMysqlConnection], None] | None" = None,
connection_factory: "Callable[[], PyMysqlConnection] | None" = None,
) -> None:
"""Initialize the thread-local connection manager.

Expand All @@ -49,8 +51,10 @@ def __init__(
recycle_seconds: Connection recycle time in seconds (default 24h)
health_check_interval: Seconds of idle time before running health check
on_connection_create: Callback executed when connection is created
connection_factory: Optional factory for custom connection creation
"""
self._connection_parameters = connection_parameters
self._connection_factory = connection_factory
self._thread_local = threading.local()
self._recycle_seconds = recycle_seconds
self._health_check_interval = health_check_interval
Expand All @@ -63,13 +67,16 @@ def _database_name(self) -> str:
return str(self._connection_parameters.get("database", "unknown"))

def _create_connection(self) -> PyMysqlConnection:
connection = pymysql.connect(**self._connection_parameters)
if self._connection_factory is not None:
connection = self._connection_factory()
else:
connection = pymysql.connect(**self._connection_parameters)

# Call user-provided callback after connection creation
if self._on_connection_create is not None:
self._on_connection_create(connection)

return cast("PyMysqlConnection", connection)
return connection

def _is_connection_alive(self, connection: PyMysqlConnection) -> bool:
try:
Expand Down
Loading
Loading