diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index dc8da639..e6e547f7 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -12,12 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import mysql.connector +from mysql.connector import CMySQLConnection +from mysql.connector.errors import Error from sqlalchemy.dialects.mysql.mysqlconnector import \ MySQLDialect_mysqlconnector +from sqlalchemy.engine import default + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.properties import (Properties, + PropertiesUtils) + +if TYPE_CHECKING: + from sqlalchemy import Connection + + from aws_advanced_python_wrapper.hostinfo import HostInfo class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): + supports_statement_cache = True + """ SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used @@ -27,3 +46,154 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' + + @classmethod + def import_dbapi(cls): + """ + Return the DB-API 2.0 module. + SQLAlchemy calls this to get the driver module. + """ + import aws_advanced_python_wrapper + return aws_advanced_python_wrapper + + def create_connect_args(self, url): + """ + Transform SQLAlchemy URL into connection arguments. + Must include the 'target' parameter for our wrapper driver. + """ + # Extract standard connection parameters + opts = url.translate_connect_args(username='user') + + # Add query string parameters + opts.update(url.query) + + # Add the required 'target' parameter for our wrapper + if 'target' not in opts: + opts['target'] = mysql.connector.Connect + if 'wrapper_plugins' not in opts: + opts['plugins'] = "aurora_connection_tracker,failover" + else: + opts['plugins'] = opts['wrapper_plugins'] + opts.pop('wrapper_plugins', None) + if 'connect_timeout' in opts: + opts['connect_timeout'] = int(opts['connect_timeout']) + + # Return empty args list and kwargs dict + return [], opts + + def _detect_charset(self, connection: Connection) -> str: + if isinstance(connection, CMySQLConnection): + return connection.charset + else: + raise Exception("Could not detect charset because connection was not a CMySQLConnection.") + + def _extract_error_code(self, exception: BaseException) -> int: + if isinstance(exception, AwsWrapperError): + err = exception.driver_error + if err and isinstance(err, Error): + return err.errno + else: + raise Exception("Could not extract error code because driver_error was not a BaseException.") + else: + raise Exception("Could not extract error code because exception was not an AwsWrapperError.") + + def initialize(self, connection): + """ + Override initialization to handle type introspection. + The parent class tries to use TypeInfo.fetch() which requires + a native SQLAlchemy connection, not AwsWrapperConnection. + """ + + # Unwrap SQLAlchemy's connection object + wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) + + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset: Optional[str] = self._detect_charset( + wrapper_conn.target_connection + ) + + # call super().initialize() because we need to have + # server_version_info set up. in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self.use_mysql_for_share = ( + self._is_mysql and self.server_version_info >= (8, 0, 1) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + + self._warn_for_known_db_issues() + + def _get_wrapper_connection_and_parent(self, connection): + """ + Traverse the connection chain to find AwsWrapperConnection and its parent connection. + + Args: + connection: SQLAlchemy Connection object + + Returns: + AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None + """ + # Start with the DBAPI connection + parent = connection + child = connection.connection + + # Traverse up to 5 levels deep (reasonable limit) + for _ in range(5): + if isinstance(child, AwsWrapperConnection): + return child, parent + + # Try to go deeper if there's a .connection attribute + if hasattr(child, 'connection'): + parent = child + child = child.connection + else: + break + + return None + + def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties: + prop_copy: Properties = Properties(props.copy()) + + prop_copy["host"] = host_info.host + + if host_info.is_port_specified(): + prop_copy["port"] = str(host_info.port) + + PropertiesUtils.remove_wrapper_props(prop_copy) + return prop_copy diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 82427f95..d85f105c 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -29,7 +29,6 @@ subqueryload) from sqlalchemy.sql import func -from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, enable_on_engines) from ..utils.database_engine import DatabaseEngine @@ -114,41 +113,29 @@ class Book(Base): TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) class TestSqlAlchemy: - @pytest.fixture(scope='class') - def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() - return RdsTestUtility(region) - - - @pytest.fixture(scope="class") + @pytest.fixture(scope="function") def engine(self, conn_utils): conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}' engine = create_engine(conn_str) Base.metadata.create_all(engine) yield engine Base.metadata.drop_all(engine) + engine.dispose() - @pytest.fixture(scope="class") - def Session(self, engine): - Session = sessionmaker(bind=engine) - yield Session - - @pytest.fixture(scope="class") - def session(self, Session): - session = Session() + @pytest.fixture(scope="function") + def session(self, engine): + session = sessionmaker(bind=engine)() yield session session.rollback() session.close() - def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine): + def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, session): """Test SQLAlchemy backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper - with engine.connect() as connection: - assert connection.connection is not None + assert session.connection().connection is not None # Test basic connection functionality - with Session(engine) as session: - assert session.query(TestModel).count() == 0 + assert session.query(TestModel).count() == 0 def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): """Test basic SQLAlchemy ORM operations (CRUD)""" diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py new file mode 100644 index 00000000..0ab30d28 --- /dev/null +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -0,0 +1,607 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 + +from __future__ import annotations + +import json +import uuid +from datetime import date, datetime, time, timezone +from decimal import Decimal +from time import perf_counter_ns, sleep +from typing import Any, ClassVar, Dict, List, Optional + +import boto3 +import pytest +from boto3 import client +from botocore.exceptions import ClientError +from sqlalchemy import (JSON, BigInteger, Boolean, Column, Date, DateTime, + Float, ForeignKey, Integer, Numeric, SmallInteger, + String, Text, Time, create_engine, text) +from sqlalchemy.exc import DBAPIError +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, mapped_column, + relationship, sessionmaker) + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines, enable_on_features, + enable_on_num_instances) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +class Base(DeclarativeBase): + pass + +class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) + + +class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' + + id: Mapped[int] = mapped_column(primary_key=True) + + # String fields + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) + + # Numeric fields + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) + + # Boolean field + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) + + # Date/Time fields + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) + + # JSON field (MySQL 5.7+) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + + +class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') + + +class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' + + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) + + author: Mapped[Author] = relationship(back_populates='books') + +def _build_url(user, password, host, port, dbname, wrapper_plugins=None, **extra_options): + """Build a SQLAlchemy connection URL using the aws wrapper dialect.""" + query_params = {} + if wrapper_plugins: + query_params['wrapper_plugins'] = wrapper_plugins + query_params['connect_timeout'] = str(extra_options.get('connect_timeout', 10)) + for k, v in extra_options.items(): + if k != 'connect_timeout': + query_params[k] = str(v) + + from sqlalchemy.engine import URL + return URL.create( + drivername="mysql+aws_wrapper_mysqlconnector", + username=user, + password=password, + host=host, + port=port, + database=dbname, + query=query_params, + ) + +@enable_on_engines([DatabaseEngine.MYSQL]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemyPlugins: + endpoint_id: ClassVar[str] = f"test-sqlalchemy-endpoint-{uuid.uuid4()}" + endpoint_info: ClassVar[Dict[str, Any]] = {} + reuse_existing_endpoint: ClassVar[bool] = False + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + sm_client = boto3.client('secretsmanager', region_name=region) + env = TestEnvironment.get_current() + + secret_name = f"TestSecret-{uuid.uuid4()}" + engine = "postgres" if env.get_engine() == "pg" else "mysql" + secret_value = { + "engine": engine, + "dbname": env.get_info().get_database_info().get_default_db_name(), + "host": env.get_info().get_database_info().get_cluster_endpoint(), + "username": conn_utils.user, + "password": conn_utils.password, + "description": "Test secret generated by integration tests." + } + + try: + response = sm_client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_arn = response['ARN'] + yield secret_name, secret_arn + finally: + try: + sm_client.delete_secret(SecretId=secret_name, ForceDeleteWithoutRecovery=True) + except Exception: + pass + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + if not self.reuse_existing_endpoint: + instances = env_info.get_database_info().get_instances() + self._create_endpoint(rds_client, instances[0:1]) + + self._wait_until_endpoint_available(rds_client) + yield + if not self.reuse_existing_endpoint: + self._delete_endpoint(rds_client) + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 + available = False + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + endpoints = response["DBClusterEndpoints"] + if len(endpoints) != 1: + sleep(3) + continue + TestSqlAlchemyPlugins.endpoint_info = endpoints[0] + if endpoints[0]["Status"] == "available": + available = True + break + sleep(3) + if not available: + pytest.fail(f"Timed out waiting for custom endpoint: '{self.endpoint_id}'.") + + def _create_endpoint(self, rds_client, instances): + instance_ids = [i.get_instance_id() for i in instances] + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + def _delete_endpoint(self, rds_client): + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + self._wait_until_endpoint_deleted(rds_client) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pytest.fail(e) + + def _wait_until_endpoint_deleted(self, rds_client): + end_ns = perf_counter_ns() + 3 * 60 * 1_000_000_000 + deleted = False + while perf_counter_ns() < end_ns: + try: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + if len(response["DBClusterEndpoints"]) == 0: + deleted = True + break + sleep(3) + except ClientError as e: + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + deleted = True + break + sleep(3) + if deleted: + print(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") + else: + print(f"Warning: Timed out waiting for custom endpoint deletion: '{self.endpoint_id}'.") + + @pytest.fixture(scope='function') + def sa_models(self, sa_setup): + """Create SQLAlchemy tables and provide model classes.""" + engine = sa_setup['engine'] + test_id = str(uuid.uuid4())[:8] + + Base.metadata.create_all(engine, tables=[ + TestModel.__table__, DataTypeModel.__table__, + Author.__table__, Book.__table__ + ]) + + models = { + 'TestModel': TestModel, + 'DataTypeModel': DataTypeModel, + 'Author': Author, + 'Book': Book, + } + + yield models + + Base.metadata.drop_all(engine, tables=[ + Book.__table__, Author.__table__, + DataTypeModel.__table__, TestModel.__table__ + ]) + + + @pytest.fixture(scope='function') + def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=None): + """Setup SQLAlchemy engine with configurable plugins.""" + if hasattr(request, 'param') and isinstance(request.param, dict): + config = request.param + plugins_config = config.get('wrapper_plugins', 'aurora_connection_tracker,failover_v2') + extra_options = config.get('options', {}) + use_custom_endpoint = config.get('use_custom_endpoint', False) + custom_endpoint_host = None + if use_custom_endpoint and create_custom_endpoint: + custom_endpoint_host = self.endpoint_info.get('Endpoint') + + if 'iam' in plugins_config: + user = conn_utils.iam_user + extra_options['auth_plugin'] = 'mysql_clear_password' + elif 'aws_secrets_manager' in plugins_config: + user = None + _, secret_arn = create_secret + extra_options['secrets_manager_secret_id'] = secret_arn + else: + user = config.get('user', conn_utils.user) + + if 'iam' in plugins_config or 'aws_secrets_manager' in plugins_config: + password = None + else: + password = config.get('password', conn_utils.password) + + host = custom_endpoint_host or config.get('host', conn_utils.writer_cluster_host) + else: + plugins_config = 'aurora_connection_tracker,failover_v2' + extra_options = {} + user = conn_utils.user + password = conn_utils.password + host = conn_utils.writer_host + + url = _build_url(user, password, host, conn_utils.port, conn_utils.dbname, + wrapper_plugins=plugins_config, **extra_options) + engine = create_engine(url) + SessionLocal = sessionmaker(bind=engine) + + yield {'engine': engine, 'SessionLocal': SessionLocal, + 'plugins': plugins_config, 'options': extra_options} + + engine.dispose() + + def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test basic SQLAlchemy insert operation with plugins enabled""" + TestModel = sa_models['TestModel'] + session: Session = sa_setup['SessionLocal']() + + try: + session.query(TestModel).delete() + obj = TestModel(name="Plugin Test User", email="plugin@example.com", age=25, is_active=True) + session.add(obj) + session.commit() + + assert obj.id is not None + assert obj.name == "Plugin Test User" + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.name == "Plugin Test User" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': ''}], indirect=True) + def test_sqlalchemy_with_no_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with no plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == '' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="No Plugins User", email="noplugins@example.com", age=30) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "No Plugins User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'failover_v2'}], indirect=True) + def test_sqlalchemy_with_failover_only(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with only failover plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Only User", email="failover@example.com", age=35) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Failover Only User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'aurora_connection_tracker,failover_v2'}], indirect=True) + def test_sqlalchemy_with_multiple_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with multiple plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aurora_connection_tracker,failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Multi Plugin User", email="multiplugin@example.com", age=40) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Multi Plugin User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'aws_secrets_manager', + 'use_secrets_manager': True + }], indirect=True) + def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvironment, sa_setup, sa_models): + """Test SQLAlchemy with AWS Secrets Manager plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aws_secrets_manager' + assert 'secrets_manager_secret_id' in config['options'] + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Secrets Manager User", email="secrets@example.com", age=45) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "secrets@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'iam', + 'password': '', + 'options': {} + }], indirect=True) + def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_models, sa_setup, conn_utils): + """Test SQLAlchemy with IAM authentication plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'iam' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="IAM User", email="iam@example.com", age=50) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "iam@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'failover_v2', + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironment, sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Test User", email="failover@example.com", age=30) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(DBAPIError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + if row: + current_writer_id = row._tuple()[0] + else: + raise Exception("Failed to get current_writer_id from row because row was None.") + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + ''' + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'custom_endpoint,failover_v2', + 'use_custom_endpoint': True, + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_custom_endpoint_failover_during_query( + self, test_environment: TestEnvironment, create_custom_endpoint, + sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover with custom endpoint during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'custom_endpoint' in config['plugins'] + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Custom Endpoint Failover Test User", email="custom_failover@example.com", age=35) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(DBAPIError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + current_writer_id = row[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + ''' + + @pytest.fixture(scope='function') + def sa_rw_split_setup(self, conn_utils): + """Setup SQLAlchemy with read/write splitting configuration""" + writer_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') + reader_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') + + writer_engine = create_engine(writer_url) + reader_engine = create_engine(reader_url) + WriterSession = sessionmaker(bind=writer_engine) + ReaderSession = sessionmaker(bind=reader_engine) + + test_id = str(uuid.uuid4())[:8] + + class RWSplitTestModel(Base): + __tablename__ = f'sa_rw_split_test_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + value = Column(Integer, nullable=False) + + Base.metadata.create_all(writer_engine, tables=[RWSplitTestModel.__table__]) + + yield { + 'model': RWSplitTestModel, + 'writer_engine': writer_engine, + 'reader_engine': reader_engine, + 'WriterSession': WriterSession, + 'ReaderSession': ReaderSession, + } + + Base.metadata.drop_all(writer_engine, tables=[RWSplitTestModel.__table__]) + writer_engine.dispose() + reader_engine.dispose() + +