diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..dacf0f4186 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -52,7 +52,8 @@ from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, - SniEndPointFactory, ConnectionBusy, locally_supported_compressions) + SniEndPointFactory, ConnectionBusy, locally_supported_compressions, + SSLSessionCache) from cassandra.cqltypes import UserType import cassandra.cqltypes as types from cassandra.encoder import Encoder @@ -876,6 +877,39 @@ def default_retry_policy(self, policy): .. versionadded:: 3.17.0 """ + ssl_session_cache = None + """ + An optional :class:`~cassandra.connection.SSLSessionCache` instance used to + enable TLS session resumption (via session tickets or PSK) for all + connections managed by this cluster. + + When :attr:`~Cluster.ssl_context` is set, a cache is created automatically + so that reconnections to the same host can skip the full TLS handshake. + Set this to :const:`None` explicitly to disable session caching. + + Note: automatic caching is **not** enabled for the legacy + :attr:`~Cluster.ssl_options` path because each connection builds a fresh + ``SSLContext``, making session reuse impossible. If you migrate to + ``ssl_context``, the cache will be created automatically. + + You may also pass a custom :class:`~cassandra.connection.SSLSessionCache` + instance with specific ``max_size`` and ``ttl`` parameters:: + + from cassandra.connection import SSLSessionCache + + cluster = Cluster( + ssl_context=ssl_context, + ssl_session_cache=SSLSessionCache(max_size=200, ttl=7200), + ) + + Note: TLS 1.2 sessions are cached immediately after connect. TLS 1.3 + sessions are cached after the CQL handshake completes (Ready / AuthSuccess), + because session tickets are sent asynchronously by the server. 
+ + Works with all connection classes: stdlib ``ssl`` (asyncore, libev, gevent, + asyncio) and PyOpenSSL (Twisted, Eventlet). + """ + sockopts = None """ An optional list of tuples which will be used as arguments to @@ -1217,7 +1251,8 @@ def __init__(self, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, application_info:Optional[ApplicationInfoBase]=None, - client_routes_config:Optional[ClientRoutesConfig]=None + client_routes_config:Optional[ClientRoutesConfig]=None, + ssl_session_cache=_NOT_SET ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1461,6 +1496,17 @@ def __init__(self, self.ssl_options = ssl_options self.ssl_context = ssl_context + + # Auto-create a session cache when TLS is enabled, unless the caller + # explicitly passed ssl_session_cache (including None to opt out). + if ssl_session_cache is _NOT_SET: + if ssl_context is not None: + self.ssl_session_cache = SSLSessionCache() + else: + self.ssl_session_cache = None + else: + self.ssl_session_cache = ssl_session_cache + self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait @@ -1706,6 +1752,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict): kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) kwargs_dict.setdefault('ssl_context', self.ssl_context) + kwargs_dict.setdefault('ssl_session_cache', self.ssl_session_cache) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) diff --git a/cassandra/connection.py b/cassandra/connection.py index 72b273ec37..8e2d7e3b6e 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import absolute_import # to enable import io from stdlib -from collections import defaultdict, deque +from collections import defaultdict, deque, namedtuple, OrderedDict import errno from functools import wraps, partial, total_ordering from heapq import heappush, heappop @@ -22,7 +22,7 @@ import socket import struct import sys -from threading import Thread, Event, RLock, Condition +from threading import Thread, Event, Lock, RLock, Condition import time import ssl import uuid @@ -163,6 +163,15 @@ def socket_family(self): """ return socket.AF_UNSPEC + @property + def tls_session_cache_key(self): + """ + Returns the cache key components for TLS session caching. + This is a tuple that uniquely identifies this endpoint for TLS session purposes. + Subclasses may override this to include additional components (e.g., SNI server name). + """ + return (self.address, self.port) + def resolve(self): """ Resolve the endpoint to an address/port. This is called @@ -210,6 +219,10 @@ def port(self): def resolve(self): return self._address, self._port + @property + def tls_session_cache_key(self): + return (self.address, self.port) + def __eq__(self, other): return isinstance(other, DefaultEndPoint) and \ self.address == other.address and self.port == other.port @@ -277,6 +290,14 @@ def port(self): def ssl_options(self): return self._ssl_options + @property + def tls_session_cache_key(self): + """ + Returns the cache key including server_name for SNI endpoints. + This prevents cache collisions when multiple SNI endpoints use the same proxy. + """ + return (self.address, self.port, self._server_name) + def resolve(self): try: resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, @@ -395,6 +416,14 @@ def port(self): def socket_family(self): return socket.AF_UNIX + @property + def tls_session_cache_key(self): + """ + Returns the cache key for Unix socket endpoints. + Since Unix sockets don't have a port, only the path is used. 
+ """ + return (self._unix_socket_path,) + + def resolve(self): return self.address, None @@ -455,6 +484,14 @@ def port(self) -> Optional[int]: def host_id(self) -> uuid.UUID: return self._host_id + @property + def tls_session_cache_key(self): + """ + Returns the cache key for Client Routes endpoints. + Uses host_id and original address for uniqueness. + """ + return (str(self._host_id), self._original_address, self._original_port) + def resolve(self) -> Tuple[str, int]: """ Resolve endpoint by delegating to the handler. @@ -783,6 +820,132 @@ def generate(self, shard_id: int, total_shards: int): DefaultShardAwarePortGenerator = ShardAwarePortGenerator(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH) +_SessionCacheEntry = namedtuple('_SessionCacheEntry', ['session', 'timestamp']) + + +class SSLSessionCache(object): + """ + A thread-safe cache of TLS session objects, keyed by connection TLS + identity, with LRU eviction and TTL expiration. + + When TLS is enabled, the driver stores the negotiated session after each + successful handshake and reuses it for subsequent connections to the same + host, enabling TLS session resumption (tickets / PSK) without any extra + configuration. + + This cache is created automatically by :class:`.Cluster` when + ``ssl_context`` is set (automatic creation does not apply to the legacy + ``ssl_options`` path). Pass ``ssl_session_cache=None`` + to :class:`.Cluster` to opt out. + + Works with both the stdlib ``ssl`` module (asyncore, libev, gevent, asyncio + reactors) and PyOpenSSL (Twisted and Eventlet reactors). + + TLS session resumption works with both TLS 1.2 and TLS 1.3: + + - TLS 1.2: Session IDs (RFC 5246) and optionally Session Tickets (RFC 5077) + - TLS 1.3: Session Tickets (RFC 8446) + """ + + # Cleanup expired sessions every N set() calls + _EXPIRY_CLEANUP_INTERVAL = 100 + + def __init__(self, max_size=100, ttl=3600): + """ + Initialize the TLS session cache. + + :param max_size: Maximum number of sessions to cache. Must be at + least ``1``.
When full, the least recently used entry is evicted. + Default: ``100``. + :param ttl: Time-to-live for cached sessions in seconds. Must be + greater than ``0``. Expired entries are lazily removed on access + and periodically during :meth:`set`. Default: ``3600`` (one hour). + """ + if max_size < 1: + raise ValueError("max_size must be >= 1, got %r" % (max_size,)) + if ttl <= 0: + raise ValueError("ttl must be > 0, got %r" % (ttl,)) + self._sessions = OrderedDict() + self._lock = Lock() + self._max_size = max_size + self._ttl = ttl + self._operation_count = 0 + + @property + def max_size(self): + return self._max_size + + @property + def ttl(self): + return self._ttl + + def get(self, key): + """ + Return the cached TLS session for *key*, or ``None`` if none + is stored or if the entry has expired. Accessing an entry + marks it as recently used. + """ + with self._lock: + entry = self._sessions.get(key) + if entry is None: + return None + if time.time() - entry.timestamp > self._ttl: + del self._sessions[key] + return None + # Mark as recently used + self._sessions.move_to_end(key) + return entry.session + + def set(self, key, session): + """ + Store *session* for *key*. ``None`` sessions are silently ignored. 
+ """ + if session is None: + return + + current_time = time.time() + with self._lock: + self._operation_count += 1 + if self._operation_count >= self._EXPIRY_CLEANUP_INTERVAL: + self._operation_count = 0 + self._clear_expired_unlocked(current_time) + + if key in self._sessions: + self._sessions[key] = _SessionCacheEntry(session, current_time) + self._sessions.move_to_end(key) + return + + if len(self._sessions) >= self._max_size: + self._sessions.popitem(last=False) + + self._sessions[key] = _SessionCacheEntry(session, current_time) + + def clear(self): + """Clear all sessions from the cache.""" + with self._lock: + self._sessions.clear() + + def clear_expired(self): + """Remove all expired sessions from the cache.""" + with self._lock: + self._clear_expired_unlocked() + + def size(self): + """Return the current number of cached sessions.""" + with self._lock: + return len(self._sessions) + + def _clear_expired_unlocked(self, current_time=None): + """Remove all expired sessions (must be called with lock held).""" + if current_time is None: + current_time = time.time() + expired_keys = [ + key for key, entry in self._sessions.items() + if current_time - entry.timestamp > self._ttl + ] + for key in expired_keys: + del self._sessions[key] + + class Connection(object): CALLBACK_ERR_THREAD_THRESHOLD = 100 @@ -803,6 +966,7 @@ class Connection(object): endpoint = None ssl_options = None ssl_context = None + _ssl_session_cache = None last_error = None # The current number of operations that are in flight. 
More precisely, @@ -880,13 +1044,15 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, - on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None): + on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None, + ssl_session_cache=None): # TODO next major rename host to endpoint and remove port kwarg. self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else {} self.ssl_context = ssl_context + self._ssl_session_cache = ssl_session_cache self.sockopts = sockopts self.compression = compression self.cql_version = cql_version @@ -1029,7 +1195,25 @@ def _wrap_socket_from_context(self): server_hostname = self.endpoint.address opts['server_hostname'] = server_hostname - return self.ssl_context.wrap_socket(self._socket, **opts) + ssl_sock = self.ssl_context.wrap_socket(self._socket, **opts) + + # Restore a previously cached session to enable TLS session resumption + # (session tickets / PSK). The session must be set *after* + # wrap_socket() (which only creates the SSLSocket) but *before* + # connect(), because connect() triggers the actual TLS handshake + # (via do_handshake_on_connect, which defaults to True). + # _initiate_connection, called after this method returns, performs + # the connect(). 
+ if self._ssl_session_cache is not None: + cached_session = self._ssl_session_cache.get( + self._ssl_session_cache_key()) + if cached_session is not None: + try: + ssl_sock.session = cached_session + except (AttributeError, ssl.SSLError, ValueError): + log.debug("Could not restore TLS session for %s", self.endpoint) + + return ssl_sock def _initiate_connection(self, sockaddr): if self.features.shard_id is not None: @@ -1043,6 +1227,26 @@ def _initiate_connection(self, sockaddr): self._socket.connect(sockaddr) + def _cache_tls_session_if_needed(self): + """ + Store the current TLS session in the cache (if any) so that future + connections to the same endpoint can resume it. + """ + if self._ssl_session_cache is not None and self.ssl_context is not None: + session = getattr(self._socket, 'session', None) + if session is not None: + self._ssl_session_cache.set(self._ssl_session_cache_key(), session) + + def _ssl_session_cache_key(self): + """ + Return a cache key that matches the TLS peer identity. + + Delegates to the endpoint's ``tls_session_cache_key`` property, which + returns appropriate components for each endpoint type (e.g., includes + ``server_name`` for SNI endpoints to prevent cache collisions). + """ + return self.endpoint.tls_session_cache_key + # PYTHON-1331 # # Allow implementations specific to an event loop to add additional behaviours @@ -1074,6 +1278,16 @@ def _connect_socket(self): self._initiate_connection(sockaddr) self._socket.settimeout(None) + # Cache the negotiated TLS session for future resumption. + # For TLS 1.2 the session is available right after connect(). + # For TLS 1.3 the server sends the session ticket + # asynchronously after the first application-data exchange, + # so socket.session may still be None here; a second + # attempt is made in _cache_tls_session_if_needed() after + # the CQL handshake completes (see _handle_startup_response + # and _handle_auth_response). 
+ self._cache_tls_session_if_needed() + local_addr = self._socket.getsockname() log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr) @@ -1578,6 +1792,9 @@ def _handle_startup_response(self, startup_response, did_authenticate=False): if ProtocolVersion.has_checksumming_support(self.protocol_version): self._enable_checksumming() + # TLS 1.3: the session ticket is sent after the first + # application-data exchange, so try caching it now. + self._cache_tls_session_if_needed() self.connected_event.set() elif isinstance(startup_response, AuthenticateMessage): log.debug("Got AuthenticateMessage on new connection (%s) from %s: %s", @@ -1634,6 +1851,9 @@ def _handle_auth_response(self, auth_response): self.authenticator.on_authentication_success(auth_response.token) if self._compressor: self.compressor = self._compressor + # TLS 1.3: the session ticket is sent after the first + # application-data exchange, so try caching it now. + self._cache_tls_session_if_needed() self.connected_event.set() elif isinstance(auth_response, AuthChallengeMessage): response = self.authenticator.evaluate_challenge(auth_response.challenge) diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index 234a4a574c..7cec610eed 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -108,6 +108,16 @@ def _wrap_socket_from_context(self): if self.ssl_options and 'server_hostname' in self.ssl_options: # This is necessary for SNI self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + # Apply cached TLS session for resumption (PyOpenSSL) + if self._ssl_session_cache: + cached_session = self._ssl_session_cache.get( + self._ssl_session_cache_key()) + if cached_session: + try: + self._socket.set_session(cached_session) + log.debug("Using cached TLS session for %s", self.endpoint) + except Exception: + log.debug("Could not restore TLS session for %s", self.endpoint) def _initiate_connection(self, 
sockaddr): if self.uses_legacy_ssl_options: @@ -116,6 +126,8 @@ def _initiate_connection(self, sockaddr): self._socket.connect(sockaddr) if self.ssl_context or self.ssl_options: self._socket.do_handshake() + # Store TLS session after successful handshake (PyOpenSSL) + self._cache_pyopenssl_session() def _match_hostname(self): if self.uses_legacy_ssl_options: @@ -126,6 +138,19 @@ def _match_hostname(self): raise Exception("Hostname verification failed! Certificate name '{}' " "doesn't endpoint '{}'".format(cert_name, self.endpoint.address)) + def _cache_pyopenssl_session(self): + """Store the PyOpenSSL TLS session in the cache after a successful handshake.""" + if self._ssl_session_cache is not None: + try: + session = self._socket.get_session() + if session: + self._ssl_session_cache.set( + self._ssl_session_cache_key(), session) + if self._socket.session_reused(): + log.debug("TLS session was reused for %s", self.endpoint) + except Exception: + log.debug("Could not cache TLS session for %s", self.endpoint) + def close(self): with self.lock: if self.is_closed: diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 446200bf63..b73b23ba8e 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -139,11 +139,12 @@ def _on_loop_timer(self): @implementer(IOpenSSLClientConnectionCreator) class _SSLCreator(object): - def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout): + def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout, ssl_session_cache=None): self.endpoint = endpoint self.ssl_options = ssl_options self.check_hostname = check_hostname self.timeout = timeout + self.ssl_session_cache = ssl_session_cache if ssl_context: self.context = ssl_context @@ -170,12 +171,36 @@ def info_callback(self, connection, where, ret): if self.check_hostname and self.endpoint.address != connection.get_peer_certificate().get_subject().commonName: transport = connection.get_app_data() 
transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint))) + return + # Store TLS session after successful handshake (PyOpenSSL) + if self.ssl_session_cache is not None: + try: + session = connection.get_session() + if session: + self.ssl_session_cache.set( + self.endpoint.tls_session_cache_key, session) + if connection.session_reused(): + log.debug("TLS session was reused for %s", self.endpoint) + except Exception: + log.debug("Could not cache TLS session for %s", self.endpoint) def clientConnectionForTLS(self, tlsProtocol): connection = SSL.Connection(self.context, None) connection.set_app_data(tlsProtocol) if self.ssl_options and "server_hostname" in self.ssl_options: connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + + # Apply cached TLS session for resumption (PyOpenSSL) + if self.ssl_session_cache is not None: + cached_session = self.ssl_session_cache.get( + self.endpoint.tls_session_cache_key) + if cached_session: + try: + connection.set_session(cached_session) + log.debug("Using cached TLS session for %s", self.endpoint) + except Exception: + log.debug("Could not restore TLS session for %s", self.endpoint) + return connection @@ -241,6 +266,7 @@ def add_connection(self): self.ssl_options, self._check_hostname, self.connect_timeout, + ssl_session_cache=self._ssl_session_cache, ) endpoint = SSL4ClientEndpoint( diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index 8ba9ca5b1d..41a72aea33 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -13,9 +13,9 @@ # limitations under the License. 
import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock -from cassandra.connection import DefaultEndPoint +from cassandra.connection import DefaultEndPoint, SSLSessionCache try: from twisted.test import proto_helpers @@ -197,3 +197,91 @@ def test_push(self, mock_connectTCP): self.obj_ut.push('123 pickup') self.mock_reactor_cft.assert_called_with( transport_mock.write, '123 pickup') + + +class TestSSLCreatorInfoCallback(unittest.TestCase): + """Verify that _SSLCreator.info_callback does not cache TLS sessions + when hostname verification fails.""" + + def setUp(self): + if twistedreactor is None: + raise unittest.SkipTest("Twisted libraries not available") + from OpenSSL import SSL + self.SSL = SSL + + def _make_creator(self, check_hostname, endpoint_address, cert_cn, + ssl_session_cache=None): + from cassandra.io.twistedreactor import _SSLCreator + endpoint = Mock() + endpoint.address = endpoint_address + endpoint.tls_session_cache_key = (endpoint_address, 9042) + ssl_ctx = Mock() + creator = _SSLCreator( + endpoint=endpoint, + ssl_context=ssl_ctx, + ssl_options=None, + check_hostname=check_hostname, + timeout=5, + ssl_session_cache=ssl_session_cache, + ) + + # Build a mock OpenSSL connection + connection = Mock() + subject = Mock() + subject.commonName = cert_cn + cert = Mock() + cert.get_subject.return_value = subject + connection.get_peer_certificate.return_value = cert + session = Mock() + connection.get_session.return_value = session + connection.session_reused.return_value = False + transport = Mock() + connection.get_app_data.return_value = transport + + return creator, connection, transport, session + + def test_hostname_mismatch_does_not_cache(self): + """When hostname verification fails, the session must NOT be cached.""" + cache = SSLSessionCache() + creator, connection, transport, session = self._make_creator( + check_hostname=True, + endpoint_address='good.example.com', + cert_cn='evil.example.com', + 
ssl_session_cache=cache, + ) + + creator.info_callback(connection, self.SSL.SSL_CB_HANDSHAKE_DONE, 0) + + transport.failVerification.assert_called_once() + assert cache.size() == 0, "Session was cached despite hostname mismatch" + + def test_hostname_match_caches_session(self): + """When hostname matches, the session should be cached.""" + cache = SSLSessionCache() + creator, connection, transport, session = self._make_creator( + check_hostname=True, + endpoint_address='good.example.com', + cert_cn='good.example.com', + ssl_session_cache=cache, + ) + + creator.info_callback(connection, self.SSL.SSL_CB_HANDSHAKE_DONE, 0) + + transport.failVerification.assert_not_called() + assert cache.size() == 1 + assert cache.get(('good.example.com', 9042)) is session + + def test_no_check_hostname_caches_session(self): + """When check_hostname is False, always cache regardless of CN.""" + cache = SSLSessionCache() + creator, connection, transport, session = self._make_creator( + check_hostname=False, + endpoint_address='good.example.com', + cert_cn='evil.example.com', + ssl_session_cache=cache, + ) + + creator.info_callback(connection, self.SSL.SSL_CB_HANDSHAKE_DONE, 0) + + transport.failVerification.assert_not_called() + assert cache.size() == 1 diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 4942fd4d69..0ffc742117 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -23,6 +23,7 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import SSLSessionCache from cassandra.pool import Host from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy from cassandra.query import SimpleStatement, 
named_tuple_factory, tuple_factory @@ -634,3 +635,53 @@ def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self): ) patched_logger.warning.assert_not_called() + + +class TestSSLSessionCacheAutoCreation(unittest.TestCase): + + def test_cache_created_when_ssl_context_set(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx) + assert isinstance(cluster.ssl_session_cache, SSLSessionCache) + + def test_no_cache_when_only_ssl_options_set(self): + cluster = Cluster(contact_points=['127.0.0.1'], ssl_options={'ca_certs': '/dev/null'}) + assert cluster.ssl_session_cache is None + + def test_no_cache_when_tls_not_enabled(self): + cluster = Cluster(contact_points=['127.0.0.1']) + assert cluster.ssl_session_cache is None + + def test_explicit_none_disables_cache(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + ssl_session_cache=None) + assert cluster.ssl_session_cache is None + + def test_explicit_custom_cache_used(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + custom = SSLSessionCache() + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx, + ssl_session_cache=custom) + assert cluster.ssl_session_cache is custom + + def test_cache_passed_to_connection_factory(self): + import ssl + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + endpoint = Mock(address='127.0.0.1') + with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: + cluster = Cluster(contact_points=['127.0.0.1'], ssl_context=ctx) + cluster.connection_factory(endpoint) + + assert 
factory.call_args.kwargs['ssl_session_cache'] is cluster.ssl_session_cache diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..820ebb1c78 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,9 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator, + SniEndPoint, + SSLSessionCache) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -571,3 +573,410 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class TestSSLSessionCache(unittest.TestCase): + + @staticmethod + def _key(address, port, server_hostname=None): + return (address, port, server_hostname) + + def test_get_returns_none_when_empty(self): + cache = SSLSessionCache() + assert cache.get(self._key('127.0.0.1', 9042)) is None + + def test_set_and_get(self): + cache = SSLSessionCache() + session = object() # stand-in for ssl.SSLSession + cache.set(self._key('127.0.0.1', 9042), session) + assert cache.get(self._key('127.0.0.1', 9042)) is session + + def test_different_keys_are_independent(self): + cache = SSLSessionCache() + s1 = object() + s2 = object() + cache.set(self._key('127.0.0.1', 9042), s1) + cache.set(self._key('127.0.0.2', 9042), s2) + assert cache.get(self._key('127.0.0.1', 9042)) is s1 + assert cache.get(self._key('127.0.0.2', 9042)) is s2 + assert cache.get(self._key('127.0.0.1', 9043)) is None + + def test_sni_keys_are_independent_for_same_proxy(self): + 
cache = SSLSessionCache() + s1 = object() + s2 = object() + + cache.set(self._key('proxy.example.com', 9042, 'node-a'), s1) + cache.set(self._key('proxy.example.com', 9042, 'node-b'), s2) + + assert cache.get(self._key('proxy.example.com', 9042, 'node-a')) is s1 + assert cache.get(self._key('proxy.example.com', 9042, 'node-b')) is s2 + + def test_overwrite_existing_entry(self): + cache = SSLSessionCache() + old = object() + new = object() + cache.set(self._key('127.0.0.1', 9042), old) + cache.set(self._key('127.0.0.1', 9042), new) + assert cache.get(self._key('127.0.0.1', 9042)) is new + + def test_thread_safety(self): + """Concurrent set/get operations must not raise.""" + import threading + cache = SSLSessionCache() + errors = [] + + def writer(addr_suffix): + try: + for i in range(200): + cache.set(self._key('127.0.0.%d' % addr_suffix, 9042), object()) + except Exception as e: + errors.append(e) + + def reader(addr_suffix): + try: + for i in range(200): + cache.get(self._key('127.0.0.%d' % addr_suffix, 9042)) + except Exception as e: + errors.append(e) + + threads = [] + for n in range(5): + threads.append(threading.Thread(target=writer, args=(n,))) + threads.append(threading.Thread(target=reader, args=(n,))) + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + + def test_ttl_expiration(self): + """Sessions expire after TTL.""" + cache = SSLSessionCache(max_size=10, ttl=1) + session = object() + cache.set(self._key('127.0.0.1', 9042), session) + assert cache.get(self._key('127.0.0.1', 9042)) is session + + time.sleep(1.1) + assert cache.get(self._key('127.0.0.1', 9042)) is None + + def test_lru_eviction(self): + """LRU eviction when cache reaches max_size.""" + cache = SSLSessionCache(max_size=3, ttl=60) + + s1, s2, s3, s4 = object(), object(), object(), object() + cache.set(self._key('host1', 9042), s1) + cache.set(self._key('host2', 9042), s2) + cache.set(self._key('host3', 9042), s3) + assert cache.size() == 3 + + # Access 
# host2 to make it recently used  -- tail of a comment cut at the chunk boundary;
# the enclosing LRU-eviction test's def line is outside this view.
        cache.get(self._key('host2', 9042))

        # Adding host4 should evict host1 (LRU)
        cache.set(self._key('host4', 9042), s4)
        assert cache.size() == 3
        assert cache.get(self._key('host1', 9042)) is None
        assert cache.get(self._key('host2', 9042)) is s2
        assert cache.get(self._key('host3', 9042)) is s3
        assert cache.get(self._key('host4', 9042)) is s4

    def test_none_session_not_cached(self):
        """None sessions should be silently ignored."""
        cache = SSLSessionCache()
        cache.set(self._key('127.0.0.1', 9042), None)
        assert cache.size() == 0

    def test_clear(self):
        """clear() removes all entries."""
        cache = SSLSessionCache()
        cache.set(self._key('host1', 9042), object())
        cache.set(self._key('host2', 9042), object())
        assert cache.size() == 2
        cache.clear()
        assert cache.size() == 0

    def test_clear_expired(self):
        """clear_expired() removes only expired entries."""
        cache = SSLSessionCache(max_size=10, ttl=1)
        cache.set(self._key('host1', 9042), object())
        # Let host1's 1-second TTL lapse before adding host2, so only
        # host1 is expired when clear_expired() runs.
        time.sleep(1.1)
        cache.set(self._key('host2', 9042), object())
        assert cache.size() == 2
        cache.clear_expired()
        assert cache.size() == 1
        assert cache.get(self._key('host1', 9042)) is None
        assert cache.get(self._key('host2', 9042)) is not None

    def test_automatic_expired_cleanup(self):
        """Expired sessions are cleaned during set() periodically."""
        cache = SSLSessionCache(max_size=10, ttl=1)
        # Lower the cleanup interval so the periodic sweep triggers
        # within this test's small number of set() calls.
        cache._EXPIRY_CLEANUP_INTERVAL = 5

        for i in range(3):
            cache.set(self._key('host%d' % i, 9042), object())
        assert cache.size() == 3

        time.sleep(1.1)

        # Add sessions until cleanup triggers (at 5 operations)
        for i in range(5):
            cache.set(self._key('new%d' % i, 9042), object())

        # Expired sessions should have been cleaned
        assert cache.size() == 5

    def test_custom_max_size_and_ttl(self):
        """Cache respects custom max_size and ttl parameters."""
        cache = SSLSessionCache(max_size=50, ttl=7200)
        assert cache.max_size == 50
        assert cache.ttl == 7200

    def test_max_size_zero_raises(self):
        """max_size=0 must raise ValueError."""
        with self.assertRaises(ValueError):
            SSLSessionCache(max_size=0)

    def test_max_size_negative_raises(self):
        """Negative max_size must raise ValueError."""
        with self.assertRaises(ValueError):
            SSLSessionCache(max_size=-1)

    def test_ttl_zero_raises(self):
        """ttl=0 must raise ValueError."""
        with self.assertRaises(ValueError):
            SSLSessionCache(ttl=0)

    def test_ttl_negative_raises(self):
        """Negative ttl must raise ValueError."""
        with self.assertRaises(ValueError):
            SSLSessionCache(ttl=-10)

    def test_max_size_one_works(self):
        """max_size=1 is the smallest valid cache — ensure it works."""
        cache = SSLSessionCache(max_size=1, ttl=60)
        s1, s2 = object(), object()
        cache.set(self._key('host1', 9042), s1)
        assert cache.get(self._key('host1', 9042)) is s1
        # Adding a second entry should evict the first
        cache.set(self._key('host2', 9042), s2)
        assert cache.size() == 1
        assert cache.get(self._key('host1', 9042)) is None
        assert cache.get(self._key('host2', 9042)) is s2


class TestEndPointTLSSessionCacheKey(unittest.TestCase):
    """Tests for tls_session_cache_key on endpoint classes."""

    def test_default_endpoint_key(self):
        """DefaultEndPoint key is the (address, port) pair."""
        endpoint = DefaultEndPoint('10.0.0.1', 9042)
        assert endpoint.tls_session_cache_key == ('10.0.0.1', 9042)

    def test_default_endpoint_different_ports(self):
        """Same address on different ports must produce distinct keys."""
        ep1 = DefaultEndPoint('10.0.0.1', 9042)
        ep2 = DefaultEndPoint('10.0.0.1', 9043)
        assert ep1.tls_session_cache_key != ep2.tls_session_cache_key

    def test_sni_endpoint_includes_server_name(self):
        """SniEndPoint key adds the SNI server name, so two nodes behind
        the same proxy address do not share cached TLS sessions."""
        ep1 = SniEndPoint('proxy.example.com', 'server1', 9042)
        ep2 = SniEndPoint('proxy.example.com', 'server2', 9042)
        assert ep1.tls_session_cache_key == ('proxy.example.com', 9042, 'server1')
        assert ep2.tls_session_cache_key == ('proxy.example.com', 9042, 'server2')
        assert ep1.tls_session_cache_key != ep2.tls_session_cache_key

    def test_unix_socket_endpoint_key(self):
        """UnixSocketEndPoint key is a 1-tuple of the socket path."""
        from cassandra.connection import UnixSocketEndPoint
        ep = UnixSocketEndPoint('/var/run/scylla.sock')
        assert ep.tls_session_cache_key == ('/var/run/scylla.sock',)

    def test_unix_socket_different_paths(self):
        """Different socket paths must produce distinct keys."""
        from cassandra.connection import UnixSocketEndPoint
        ep1 = UnixSocketEndPoint('/var/run/scylla.sock')
        ep2 = UnixSocketEndPoint('/tmp/scylla.sock')
        assert ep1.tls_session_cache_key != ep2.tls_session_cache_key


class TestConnectionSSLSessionRestore(unittest.TestCase):
    """Tests that _wrap_socket_from_context restores a cached TLS session
    onto the wrapped socket (and tolerates every failure mode)."""

    @patch.object(Connection, '_connect_socket')
    @patch.object(Connection, '_send_options_message')
    def test_wrap_socket_restores_cached_session(self, _send, _connect):
        """_wrap_socket_from_context sets ssl_sock.session from cache."""
        import ssl as _ssl

        mock_ssl_sock = Mock()
        mock_ctx = Mock(spec=_ssl.SSLContext)
        mock_ctx.check_hostname = False
        mock_ctx.wrap_socket.return_value = mock_ssl_sock

        cached = Mock(name='cached_session')
        cache = SSLSessionCache()
        cache.set(('10.0.0.1', 9042), cached)

        # __new__ bypasses Connection.__init__ (which would try to connect);
        # only the attributes _wrap_socket_from_context reads are assigned.
        conn = Connection.__new__(Connection)
        conn.endpoint = DefaultEndPoint('10.0.0.1', 9042)
        conn.ssl_context = mock_ctx
        conn.ssl_options = {}
        conn._ssl_session_cache = cache

        result = conn._wrap_socket_from_context()
        assert result is mock_ssl_sock
        assert mock_ssl_sock.session == cached

    @patch.object(Connection, '_connect_socket')
    @patch.object(Connection, '_send_options_message')
    def test_wrap_socket_tolerates_missing_cache(self, _send, _connect):
        """No error when _ssl_session_cache is None."""
        import ssl as _ssl

        mock_ssl_sock = Mock()
        mock_ctx = Mock(spec=_ssl.SSLContext)
        mock_ctx.check_hostname = False
        mock_ctx.wrap_socket.return_value = mock_ssl_sock

        conn = Connection.__new__(Connection)
        conn.endpoint = DefaultEndPoint('10.0.0.1', 9042)
        conn.ssl_context = mock_ctx
        conn.ssl_options = {}
        conn._ssl_session_cache = None

        result = conn._wrap_socket_from_context()
        assert result is mock_ssl_sock

    @patch.object(Connection, '_connect_socket')
    @patch.object(Connection, '_send_options_message')
    def test_wrap_socket_handles_set_session_failure(self, _send, _connect):
        """If setting session raises ssl.SSLError, it is silently ignored."""
        import ssl as _ssl

        mock_ssl_sock = Mock()
        # Make the 'session' attribute raise on assignment to simulate a
        # session the SSL layer rejects.
        type(mock_ssl_sock).session = property(
            fget=lambda self: None,
            fset=Mock(side_effect=_ssl.SSLError("bad session")),
        )
        mock_ctx = Mock(spec=_ssl.SSLContext)
        mock_ctx.check_hostname = False
        mock_ctx.wrap_socket.return_value = mock_ssl_sock

        cache = SSLSessionCache()
        cache.set(('10.0.0.1', 9042), Mock(name='bad_cached'))

        conn = Connection.__new__(Connection)
        conn.endpoint = DefaultEndPoint('10.0.0.1', 9042)
        conn.ssl_context = mock_ctx
        conn.ssl_options = {}
        conn._ssl_session_cache = cache

        # Should NOT raise
        result = conn._wrap_socket_from_context()
        assert result is mock_ssl_sock

    @patch.object(Connection, '_connect_socket')
    @patch.object(Connection, '_send_options_message')
    def test_wrap_socket_handles_value_error_on_different_context(self, _send, _connect):
        """If setting session raises ValueError (different SSLContext), it is silently ignored."""
        import ssl as _ssl

        mock_ssl_sock = Mock()
        type(mock_ssl_sock).session = property(
            fget=lambda self: None,
            fset=Mock(side_effect=ValueError("Session refers to a different SSLContext")),
        )
        mock_ctx = Mock(spec=_ssl.SSLContext)
        mock_ctx.check_hostname = False
        mock_ctx.wrap_socket.return_value = mock_ssl_sock

        cache = SSLSessionCache()
        cache.set(('10.0.0.1', 9042), Mock(name='stale_cached'))

        conn = Connection.__new__(Connection)
        conn.endpoint = DefaultEndPoint('10.0.0.1', 9042)
        conn.ssl_context = mock_ctx
        conn.ssl_options = {}
        conn._ssl_session_cache = cache

        # Should NOT raise — falls back to full handshake
        result = conn._wrap_socket_from_context()
        assert result is mock_ssl_sock

    @patch.object(Connection, '_connect_socket')
    @patch.object(Connection, '_send_options_message')
    def test_wrap_socket_uses_sni_specific_cached_session(self, _send, _connect):
        """With an SniEndPoint, the session cached under this node's SNI key
        is restored — not another node's session behind the same proxy."""
        import ssl as _ssl

        mock_ssl_sock = Mock()
        mock_ctx = Mock(spec=_ssl.SSLContext)
        mock_ctx.check_hostname = False
        mock_ctx.wrap_socket.return_value = mock_ssl_sock

        expected = Mock(name='node_b_session')
        cache = SSLSessionCache()
        cache.set(('proxy.example.com', 9042, 'node-a'), Mock(name='node_a_session'))
        cache.set(('proxy.example.com', 9042, 'node-b'), expected)

        conn = Connection.__new__(Connection)
        conn.endpoint = SniEndPoint('proxy.example.com', 'node-b', 9042)
        conn.ssl_context = mock_ctx
        conn.ssl_options = {'server_hostname': 'node-b'}
        conn._ssl_session_cache = cache

        result = conn._wrap_socket_from_context()
        assert result is mock_ssl_sock
        assert mock_ssl_sock.session == expected


class TestConnectionCacheTLSSession(unittest.TestCase):
    """Tests for _cache_tls_session_if_needed: storing the negotiated TLS
    session (read from the socket) under the endpoint's cache key."""

    def _make_conn(self):
        # Minimal Connection: __new__ skips __init__, then only the
        # attributes _cache_tls_session_if_needed touches are assigned.
        conn = Connection.__new__(Connection)
        conn.endpoint = DefaultEndPoint('10.0.0.1', 9042)
        conn.ssl_context = Mock()
        conn._ssl_session_cache = SSLSessionCache()
        conn._socket = Mock()
        return conn

    def test_cache_tls_session_stores_session(self):
        """The socket's session object is stored under the endpoint key."""
        conn = self._make_conn()
        fake_session = Mock(name='ssl_session')
        conn._socket.session = fake_session

        conn._cache_tls_session_if_needed()
        assert conn._ssl_session_cache.get(('10.0.0.1', 9042)) is fake_session

    def test_cache_tls_session_no_op_when_session_none(self):
        """Nothing is cached when the socket reports no session."""
        conn = self._make_conn()
        conn._socket.session = None

        conn._cache_tls_session_if_needed()
        assert conn._ssl_session_cache.get(('10.0.0.1', 9042)) is None

    def test_cache_tls_session_no_op_when_cache_none(self):
        """Nothing happens (and nothing raises) without a cache instance."""
        conn = self._make_conn()
        conn._ssl_session_cache = None
        conn._socket.session = Mock()

        # Should not raise
        conn._cache_tls_session_if_needed()

    def test_cache_tls_session_no_op_when_no_ssl_context(self):
        """Nothing is cached on a non-TLS connection (ssl_context is None)."""
        conn = self._make_conn()
        conn.ssl_context = None
        conn._socket.session = Mock()

        conn._cache_tls_session_if_needed()
        assert conn._ssl_session_cache.get(('10.0.0.1', 9042)) is None

    def test_cache_tls_session_uses_sni_specific_key(self):
        """An SniEndPoint's session is stored under its SNI-qualified key
        only, not under a sibling node's key."""
        conn = Connection.__new__(Connection)
        conn.endpoint = SniEndPoint('proxy.example.com', 'node-b', 9042)
        conn.ssl_context = Mock()
        conn.ssl_options = {'server_hostname': 'node-b'}
        conn._ssl_session_cache = SSLSessionCache()
        conn._socket = Mock()
        fake_session = Mock(name='ssl_session')
        conn._socket.session = fake_session

        conn._cache_tls_session_if_needed()
        assert conn._ssl_session_cache.get(('proxy.example.com', 9042, 'node-b')) is fake_session
        assert conn._ssl_session_cache.get(('proxy.example.com', 9042, 'node-a')) is None