From cb2307a84bc96b4fc6efa0cb9d12555efce93b7b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 8 Jun 2026 14:08:32 -0400 Subject: [PATCH 1/6] PYTHON-5867 - Close sockets on interruption or cancellation during async connection creation --- pymongo/network_layer.py | 10 ++++++ pymongo/pool_shared.py | 75 +++++++++++++++++++++++++--------------- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 7c62a251f8..ce8d4b6392 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -437,6 +437,16 @@ def get_conn(self) -> PyMongoProtocol: def sock(self) -> socket.socket: return self.conn[0].get_extra_info("socket") + def __del__(self) -> None: + # Synchronously release the raw socket in case the event loop is already closed + # or this connection was orphaned. + # Safe even if asyncio has already closed the socket. + try: + if self.sock is not None: + self.sock.close() + except Exception: # noqa: S110 + pass + class NetworkingInterface(NetworkingInterfaceBase): def __init__(self, conn: Union[socket.socket, _sslConn]): diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index a6f434885b..8a7eb33a6d 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s sock.setblocking(False) await asyncio.get_running_loop().sock_connect(sock, host) return sock - except OSError: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak sock.close() raise @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s except OSError as e: sock.close() err = e # type: ignore[assignment] + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise if err is not None: raise err @@ -282,19 +287,25 @@ async def _async_configured_socket( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore] - except _CertificateError: - ssl_sock.close() - raise - ssl_sock.settimeout(options.socket_timeout) - return ssl_sock + ssl_sock.settimeout(options.socket_timeout) + return ssl_sock + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the raw socket would otherwise leak. + ssl_sock.close() + raise async def _configured_protocol_interface( @@ -311,11 +322,16 @@ async def _configured_protocol_interface( timeout = options.socket_timeout if ssl_context is None: - return AsyncNetworkingInterface( - await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock + try: + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock + ) ) - ) + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise host = address[0] try: @@ -337,18 +353,23 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - if ( - ssl_context.verify_mode - and not ssl_context.check_hostname - and not options.tls_allow_invalid_hostnames - ): - try: + except BaseException: + # Protect against cancellation or interruption where the raw socket would otherwise leak + sock.close() + raise + try: + if ( + ssl_context.verify_mode + and not ssl_context.check_hostname + and not options.tls_allow_invalid_hostnames + ): ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore] - except _CertificateError: - transport.abort() - raise - - return AsyncNetworkingInterface((transport, protocol)) + return AsyncNetworkingInterface((transport, protocol)) + except BaseException: + # Protect against cancellation, _CertificateError, or interruption + # where the transport would otherwise leak. + transport.abort() + raise def _create_connection(address: _Address, options: PoolOptions) -> socket.socket: From 04996429b2df18755e5afeb2014ecbeb66c4dd09 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 8 Jun 2026 16:35:50 -0400 Subject: [PATCH 2/6] Fix in-use socket re-use --- pymongo/network_layer.py | 10 ----- pymongo/pool_shared.py | 17 ++------ test/asynchronous/test_async_cancellation.py | 42 ++++++++++++++++++++ 3 files changed, 46 insertions(+), 23 deletions(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index ce8d4b6392..7c62a251f8 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -437,16 +437,6 @@ def get_conn(self) -> PyMongoProtocol: def sock(self) -> socket.socket: return self.conn[0].get_extra_info("socket") - def __del__(self) -> None: - # Synchronously release the raw socket in case the event loop is already closed - # or this connection was orphaned. - # Safe even if asyncio has already closed the socket. - try: - if self.sock is not None: - self.sock.close() - except Exception: # noqa: S110 - pass - class NetworkingInterface(NetworkingInterfaceBase): def __init__(self, conn: Union[socket.socket, _sslConn]): diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index 8a7eb33a6d..74dc2b5c10 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -322,16 +322,11 @@ async def _configured_protocol_interface( timeout = options.socket_timeout if ssl_context is None: - try: - return AsyncNetworkingInterface( - await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock - ) + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock ) - except BaseException: - # Protect against cancellation or interruption where the raw socket would otherwise leak - sock.close() - raise + ) host = address[0] try: @@ -353,10 +348,6 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - except BaseException: - # Protect against cancellation or interruption where the raw socket would otherwise leak - sock.close() - raise try: if ( ssl_context.verify_mode diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index f450ea23cc..2c62e4cde3 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -16,14 +16,18 @@ from __future__ import annotations import asyncio +import socket as _socket import sys from test.asynchronous.utils import async_get_pool from test.utils_shared import delay, one +from unittest.mock import patch sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, connected +from pymongo import pool_shared + class TestAsyncCancellation(AsyncIntegrationTest): async def test_async_cancellation_closes_connection(self): @@ -127,3 +131,41 @@ async def task(): await task self.assertTrue(change_stream._closed) + + async def test_cancellation_closes_socket_during_create_connection(self): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + + created_sockets: list[_socket.socket] = [] + real_socket_cls = _socket.socket + + def tracking_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + started = asyncio.Event() + block_forever = asyncio.Event() + + async def slow_sock_connect(sock, addr): + started.set() + await block_forever.wait() + + with ( + patch.object(_socket, "socket", tracking_socket), + patch.object(loop, "sock_connect", slow_sock_connect), + ): + task = asyncio.create_task(pool_shared._async_create_connection(address, options)) + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + ) From 7822321c5df60175c5633edec0a9f1a7cdca24e4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 9 Jun 2026 09:12:09 -0400 Subject: [PATCH 3/6] Isolate test patching Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- test/asynchronous/test_async_cancellation.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 2c62e4cde3..2425978e1c 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -138,30 +138,36 @@ async def test_cancellation_closes_socket_during_create_connection(self): created_sockets: list[_socket.socket] = [] real_socket_cls = _socket.socket + target_task = None def tracking_socket(*args, **kwargs): s = real_socket_cls(*args, **kwargs) - created_sockets.append(s) + if asyncio.current_task() is target_task: + created_sockets.append(s) return s loop = asyncio.get_running_loop() + real_sock_connect = loop.sock_connect started = asyncio.Event() block_forever = asyncio.Event() async def slow_sock_connect(sock, addr): - started.set() - await block_forever.wait() + if asyncio.current_task() is target_task: + started.set() + await block_forever.wait() + return None + return await real_sock_connect(sock, addr) with ( patch.object(_socket, "socket", tracking_socket), patch.object(loop, "sock_connect", slow_sock_connect), ): task = asyncio.create_task(pool_shared._async_create_connection(address, options)) + target_task = task await asyncio.wait_for(started.wait(), timeout=5) task.cancel() with self.assertRaises(asyncio.CancelledError): await task - self.assertTrue(created_sockets, "expected at least one socket to be created") for sock in created_sockets: self.assertEqual( From e5930f0c23598996c34e2bf689dd6be602e5b9e6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 9 Jun 2026 10:11:53 -0400 Subject: [PATCH 4/6] Fix test --- test/asynchronous/test_async_cancellation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 2425978e1c..0cac2f8210 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -152,7 +152,7 @@ def tracking_socket(*args, **kwargs): block_forever = asyncio.Event() async def slow_sock_connect(sock, addr): - if asyncio.current_task() is target_task: + if sock in created_sockets: started.set() await block_forever.wait() return None From 784c6cd10c5e7ba1f59a0ae101ef2ac0d6649240 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 9 Jun 2026 15:40:12 -0400 Subject: [PATCH 5/6] More tests --- pymongo/pool_shared.py | 19 +++- test/asynchronous/test_async_cancellation.py | 111 +++++++++++++++++++ 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index 74dc2b5c10..e6ee439573 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -322,11 +322,17 @@ async def _configured_protocol_interface( timeout = options.socket_timeout if ssl_context is None: - return AsyncNetworkingInterface( - await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock + try: + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock + ) ) - ) + except BaseException: + # Protect against cancellation or interruption before the transport + # takes ownership of the raw socket. + sock.close() + raise host = address[0] try: @@ -348,6 +354,11 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) + except BaseException: + # Protect against cancellation or interruption before the transport + # takes ownership of the raw socket. + sock.close() + raise try: if ( ssl_context.verify_mode diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 0cac2f8210..9c870d4aab 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -16,7 +16,10 @@ from __future__ import annotations import asyncio +import contextlib +import functools import socket as _socket +import ssl as _ssl import sys from test.asynchronous.utils import async_get_pool from test.utils_shared import delay, one @@ -175,3 +178,111 @@ async def slow_sock_connect(sock, addr): -1, f"socket leaked across cancellation: {sock!r}", ) + + async def _assert_cancellation_closes_socket( + self, + *, + connection_creator, + loop_method_name, + make_slow, + ssl_context=None, + ): + address = (await async_client_context.host, await async_client_context.port) + options = (await async_get_pool(self.client)).opts + + created_sockets = [] + real_socket_cls = _socket.socket + target_task = None + + def is_target(): + return asyncio.current_task() is target_task + + def tracked_socket(*args, **kwargs): + s = real_socket_cls(*args, **kwargs) + if is_target(): + created_sockets.append(s) + return s + + loop = asyncio.get_running_loop() + started = asyncio.Event() + block_forever = asyncio.Event() + slow_method = make_slow(getattr(loop, loop_method_name), started, block_forever, is_target) + + with contextlib.ExitStack() as stack: + stack.enter_context(patch.object(_socket, "socket", tracked_socket)) + stack.enter_context(patch.object(loop, loop_method_name, slow_method)) + if ssl_context is not None: + stack.enter_context(patch.object(options, "_PoolOptions__ssl_context", ssl_context)) + task = asyncio.create_task(connection_creator(address, options)) + target_task = task + await asyncio.wait_for(started.wait(), timeout=5) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(created_sockets, "expected at least one socket to be created") + for sock in created_sockets: + self.assertEqual( + sock.fileno(), + -1, + f"socket leaked across cancellation: {sock!r}", + ) + + async def test_cancellation_closes_socket_during_protocol_create_connection(self): + def make_slow(real, started, block_forever, is_target): + async def slow_create_connection(*args, **kwargs): + if is_target(): + started.set() + await block_forever.wait() + return await real(*args, **kwargs) + + return slow_create_connection + + await self._assert_cancellation_closes_socket( + connection_creator=pool_shared._configured_protocol_interface, + loop_method_name="create_connection", + make_slow=make_slow, + ) + + async def test_cancellation_closes_socket_during_ssl_wrap_socket(self): + fake_ssl_context = _ssl.create_default_context() + + def make_slow(real, started, _, is_target): + def slow_run_in_executor(executor, func, *args): + # Need to unwrap the SNI branch here if present + inner = func.func if isinstance(func, functools.partial) else func + # Each `ctx.wrap_socket` access returns a fresh bound-method + # object, so we check the bound instance (__self__) instead + if getattr(inner, "__self__", None) is fake_ssl_context and is_target(): + started.set() + # Return a future that never completes for cancellation. + return asyncio.get_running_loop().create_future() + return real(executor, func, *args) + + return slow_run_in_executor + + await self._assert_cancellation_closes_socket( + connection_creator=pool_shared._async_configured_socket, + loop_method_name="run_in_executor", + make_slow=make_slow, + ssl_context=fake_ssl_context, + ) + + async def test_cancellation_closes_socket_during_ssl_create_connection(self): + fake_ssl_context = _ssl.create_default_context() + + def make_slow(real, started, block_forever, is_target): + async def slow_create_connection(*args, **kwargs): + if kwargs.get("ssl") is fake_ssl_context and is_target(): + started.set() + await block_forever.wait() + return await real(*args, **kwargs) + + return slow_create_connection + + await self._assert_cancellation_closes_socket( + connection_creator=pool_shared._configured_protocol_interface, + loop_method_name="create_connection", + make_slow=make_slow, + ssl_context=fake_ssl_context, + ) From 5b92f742bd2303b8093bd70865adba925fe8b889 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 10 Jun 2026 10:36:59 -0400 Subject: [PATCH 6/6] Remove racy closes --- pymongo/pool_shared.py | 19 +--- test/asynchronous/test_async_cancellation.py | 110 +++++-------------- 2 files changed, 31 insertions(+), 98 deletions(-) diff --git a/pymongo/pool_shared.py b/pymongo/pool_shared.py index e6ee439573..74dc2b5c10 100644 --- a/pymongo/pool_shared.py +++ b/pymongo/pool_shared.py @@ -322,17 +322,11 @@ async def _configured_protocol_interface( timeout = options.socket_timeout if ssl_context is None: - try: - return AsyncNetworkingInterface( - await asyncio.get_running_loop().create_connection( - lambda: PyMongoProtocol(timeout=timeout), sock=sock - ) + return AsyncNetworkingInterface( + await asyncio.get_running_loop().create_connection( + lambda: PyMongoProtocol(timeout=timeout), sock=sock ) - except BaseException: - # Protect against cancellation or interruption before the transport - # takes ownership of the raw socket. - sock.close() - raise + ) host = address[0] try: @@ -354,11 +348,6 @@ async def _configured_protocol_interface( # mismatch, will be turned into ServerSelectionTimeoutErrors later. details = _get_timeout_details(options) _raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details) - except BaseException: - # Protect against cancellation or interruption before the transport - # takes ownership of the raw socket. - sock.close() - raise try: if ( ssl_context.verify_mode diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 9c870d4aab..c6ddf277a2 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -16,7 +16,6 @@ from __future__ import annotations import asyncio -import contextlib import functools import socket as _socket import ssl as _ssl @@ -179,41 +178,45 @@ async def slow_sock_connect(sock, addr): f"socket leaked across cancellation: {sock!r}", ) - async def _assert_cancellation_closes_socket( - self, - *, - connection_creator, - loop_method_name, - make_slow, - ssl_context=None, - ): + async def test_cancellation_closes_socket_during_ssl_wrap_socket(self): address = (await async_client_context.host, await async_client_context.port) options = (await async_get_pool(self.client)).opts + fake_ssl_context = _ssl.create_default_context() - created_sockets = [] + created_sockets: list[_socket.socket] = [] real_socket_cls = _socket.socket target_task = None - def is_target(): - return asyncio.current_task() is target_task - - def tracked_socket(*args, **kwargs): + def tracking_socket(*args, **kwargs): s = real_socket_cls(*args, **kwargs) - if is_target(): + if asyncio.current_task() is target_task: created_sockets.append(s) return s loop = asyncio.get_running_loop() + real_run_in_executor = loop.run_in_executor started = asyncio.Event() - block_forever = asyncio.Event() - slow_method = make_slow(getattr(loop, loop_method_name), started, block_forever, is_target) - - with contextlib.ExitStack() as stack: - stack.enter_context(patch.object(_socket, "socket", tracked_socket)) - stack.enter_context(patch.object(loop, loop_method_name, slow_method)) - if ssl_context is not None: - stack.enter_context(patch.object(options, "_PoolOptions__ssl_context", ssl_context)) - task = asyncio.create_task(connection_creator(address, options)) + + def slow_run_in_executor(executor, func, *args): + # Need to unwrap the SNI branch here if present + inner = func.func if isinstance(func, functools.partial) else func + # Each `ctx.wrap_socket` access returns a fresh bound-method + # object, so we check the bound instance (__self__) instead + if ( + getattr(inner, "__self__", None) is fake_ssl_context + and asyncio.current_task() is target_task + ): + started.set() + # Return a future that never completes for cancellation. + return asyncio.get_running_loop().create_future() + return real_run_in_executor(executor, func, *args) + + with ( + patch.object(_socket, "socket", tracking_socket), + patch.object(loop, "run_in_executor", slow_run_in_executor), + patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context), + ): + task = asyncio.create_task(pool_shared._async_configured_socket(address, options)) target_task = task await asyncio.wait_for(started.wait(), timeout=5) task.cancel() @@ -227,62 +230,3 @@ def tracked_socket(*args, **kwargs): -1, f"socket leaked across cancellation: {sock!r}", ) - - async def test_cancellation_closes_socket_during_protocol_create_connection(self): - def make_slow(real, started, block_forever, is_target): - async def slow_create_connection(*args, **kwargs): - if is_target(): - started.set() - await block_forever.wait() - return await real(*args, **kwargs) - - return slow_create_connection - - await self._assert_cancellation_closes_socket( - connection_creator=pool_shared._configured_protocol_interface, - loop_method_name="create_connection", - make_slow=make_slow, - ) - - async def test_cancellation_closes_socket_during_ssl_wrap_socket(self): - fake_ssl_context = _ssl.create_default_context() - - def make_slow(real, started, _, is_target): - def slow_run_in_executor(executor, func, *args): - # Need to unwrap the SNI branch here if present - inner = func.func if isinstance(func, functools.partial) else func - # Each `ctx.wrap_socket` access returns a fresh bound-method - # object, so we check the bound instance (__self__) instead - if getattr(inner, "__self__", None) is fake_ssl_context and is_target(): - started.set() - # Return a future that never completes for cancellation. - return asyncio.get_running_loop().create_future() - return real(executor, func, *args) - - return slow_run_in_executor - - await self._assert_cancellation_closes_socket( - connection_creator=pool_shared._async_configured_socket, - loop_method_name="run_in_executor", - make_slow=make_slow, - ssl_context=fake_ssl_context, - ) - - async def test_cancellation_closes_socket_during_ssl_create_connection(self): - fake_ssl_context = _ssl.create_default_context() - - def make_slow(real, started, block_forever, is_target): - async def slow_create_connection(*args, **kwargs): - if kwargs.get("ssl") is fake_ssl_context and is_target(): - started.set() - await block_forever.wait() - return await real(*args, **kwargs) - - return slow_create_connection - - await self._assert_cancellation_closes_socket( - connection_creator=pool_shared._configured_protocol_interface, - loop_method_name="create_connection", - make_slow=make_slow, - ssl_context=fake_ssl_context, - )