diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index a6f434885b..74dc2b5c10 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s sock.setblocking(False) await asyncio.get_running_loop().sock_connect(sock, host) return sock - except OSError: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak sock.close() raise @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s except OSError as e: sock.close() err = e # type: ignore[assignment] + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise if err is not None: raise err @@ -282,19 +287,25 @@ async def _async_configured_socket( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] - except _CertificateError: - ssl_sock.close() - raise - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the raw socket would otherwise leak. + ssl_sock.close() + raise async def _configured_protocol_interface( @@ -337,18 +348,19 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] - except _CertificateError: - transport.abort() - raise - - return AsyncNetworkingInterface((transport, protocol)) + return AsyncNetworkingInterface((transport, protocol)) + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the transport would otherwise leak. + transport.abort() + raise def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index f450ea23cc..c6ddf277a2 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -16,14 +16,20 @@ from __future__ import annotations import asyncio +import functools +import socket as _socket +import ssl as _ssl import sys from test.asynchronous.utils import async_get_pool from test.utils_shared import delay, one +from unittest.mock import patch sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, connected +from pymongo import pool_shared + class TestAsyncCancellation(AsyncIntegrationTest): async def test_async_cancellation_closes_connection(self): @@ -127,3 +133,100 @@ async def task(): await task self.assertTrue(change_stream._closed) + + async def test_cancellation_closes_socket_during_create_connection(self): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + + created_sockets: list[_socket.socket] = [] + real_socket_cls = _socket.socket + target_task = None + + def tracking_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + if asyncio.current_task() is target_task: + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + real_sock_connect = loop.sock_connect + started = asyncio.Event() + block_forever = asyncio.Event() + + async def slow_sock_connect(sock, addr): + if sock in created_sockets: + started.set() + await block_forever.wait() + return None + return await real_sock_connect(sock, addr) + + with ( + patch.object(_socket, "socket", tracking_socket), + patch.object(loop, "sock_connect", slow_sock_connect), + ): + task = asyncio.create_task(pool_shared._async_create_connection(address, options)) + target_task = task + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + ) + + async def test_cancellation_closes_socket_during_ssl_wrap_socket(self): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + fake_ssl_context = _ssl.create_default_context() + + created_sockets: list[_socket.socket] = [] + real_socket_cls = _socket.socket + target_task = None + + def tracking_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + if asyncio.current_task() is target_task: + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + real_run_in_executor = loop.run_in_executor + started = asyncio.Event() + + def slow_run_in_executor(executor, func, *args): + # Need to unwrap the SNI branch here if present + inner = func.func if isinstance(func, functools.partial) else func + # Each `ctx.wrap_socket` access returns a fresh bound-method + # object, so we check the bound instance (__self__) instead + if ( + getattr(inner, "__self__", None) is fake_ssl_context + and asyncio.current_task() is target_task + ): + started.set() + # Return a future that never completes for cancellation. + return asyncio.get_running_loop().create_future() + return real_run_in_executor(executor, func, *args) + + with ( + patch.object(_socket, "socket", tracking_socket), + patch.object(loop, "run_in_executor", slow_run_in_executor), + patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context), + ): + task = asyncio.create_task(pool_shared._async_configured_socket(address, options)) + target_task = task + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + )