Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock.setblocking(False)
await asyncio.get_running_loop().sock_connect(sock, host)
return sock
except OSError:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
Comment thread
NoahStapp marked this conversation as resolved.

Expand Down Expand Up @@ -231,6 +232,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

if err is not None:
raise err
Expand Down Expand Up @@ -282,19 +287,25 @@ async def _async_configured_socket(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise

ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the raw socket would otherwise leak.
ssl_sock.close()
raise


async def _configured_protocol_interface(
Expand All @@ -311,11 +322,17 @@ async def _configured_protocol_interface(
timeout = options.socket_timeout

if ssl_context is None:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
try:
return AsyncNetworkingInterface(
await asyncio.get_running_loop().create_connection(
lambda: PyMongoProtocol(timeout=timeout), sock=sock
)
)
)
except BaseException:
# Protect against cancellation or interruption before the transport
# takes ownership of the raw socket.
sock.close()
raise

host = address[0]
try:
Expand All @@ -337,18 +354,24 @@ async def _configured_protocol_interface(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption before the transport
# takes ownership of the raw socket.
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

return AsyncNetworkingInterface((transport, protocol))
return AsyncNetworkingInterface((transport, protocol))
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the transport would otherwise leak.
transport.abort()
raise


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
159 changes: 159 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@
from __future__ import annotations

import asyncio
import contextlib
import functools
import socket as _socket
import ssl as _ssl
import sys
from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one
from unittest.mock import patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected

from pymongo import pool_shared


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
Expand Down Expand Up @@ -127,3 +134,155 @@ async def task():
await task

self.assertTrue(change_stream._closed)

async def test_cancellation_closes_socket_during_create_connection(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts

created_sockets: list[_socket.socket] = []
real_socket_cls = _socket.socket
target_task = None

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if asyncio.current_task() is target_task:
created_sockets.append(s)
return s
Comment on lines +146 to +150

loop = asyncio.get_running_loop()
real_sock_connect = loop.sock_connect
started = asyncio.Event()
block_forever = asyncio.Event()

async def slow_sock_connect(sock, addr):
if sock in created_sockets:
started.set()
await block_forever.wait()
return None
return await real_sock_connect(sock, addr)

with (
patch.object(_socket, "socket", tracking_socket),
patch.object(loop, "sock_connect", slow_sock_connect),
):
task = asyncio.create_task(pool_shared._async_create_connection(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)

async def _assert_cancellation_closes_socket(
self,
*,
connection_creator,
loop_method_name,
make_slow,
ssl_context=None,
):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts

created_sockets = []
real_socket_cls = _socket.socket
target_task = None

def is_target():
return asyncio.current_task() is target_task

Comment on lines +197 to +199
def tracked_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if is_target():
created_sockets.append(s)
return s

loop = asyncio.get_running_loop()
started = asyncio.Event()
block_forever = asyncio.Event()
slow_method = make_slow(getattr(loop, loop_method_name), started, block_forever, is_target)

with contextlib.ExitStack() as stack:
stack.enter_context(patch.object(_socket, "socket", tracked_socket))
stack.enter_context(patch.object(loop, loop_method_name, slow_method))
if ssl_context is not None:
stack.enter_context(patch.object(options, "_PoolOptions__ssl_context", ssl_context))
task = asyncio.create_task(connection_creator(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)

async def test_cancellation_closes_socket_during_protocol_create_connection(self):
def make_slow(real, started, block_forever, is_target):
async def slow_create_connection(*args, **kwargs):
if is_target():
started.set()
await block_forever.wait()
return await real(*args, **kwargs)

return slow_create_connection

await self._assert_cancellation_closes_socket(
connection_creator=pool_shared._configured_protocol_interface,
loop_method_name="create_connection",
make_slow=make_slow,
)

async def test_cancellation_closes_socket_during_ssl_wrap_socket(self):
fake_ssl_context = _ssl.create_default_context()

def make_slow(real, started, _, is_target):
def slow_run_in_executor(executor, func, *args):
# Need to unwrap the SNI branch here if present
inner = func.func if isinstance(func, functools.partial) else func
# Each `ctx.wrap_socket` access returns a fresh bound-method
# object, so we check the bound instance (__self__) instead
if getattr(inner, "__self__", None) is fake_ssl_context and is_target():
started.set()
# Return a future that never completes for cancellation.
return asyncio.get_running_loop().create_future()
return real(executor, func, *args)

return slow_run_in_executor

await self._assert_cancellation_closes_socket(
connection_creator=pool_shared._async_configured_socket,
loop_method_name="run_in_executor",
make_slow=make_slow,
ssl_context=fake_ssl_context,
)

async def test_cancellation_closes_socket_during_ssl_create_connection(self):
fake_ssl_context = _ssl.create_default_context()

def make_slow(real, started, block_forever, is_target):
async def slow_create_connection(*args, **kwargs):
if kwargs.get("ssl") is fake_ssl_context and is_target():
started.set()
await block_forever.wait()
return await real(*args, **kwargs)

return slow_create_connection

await self._assert_cancellation_closes_socket(
connection_creator=pool_shared._configured_protocol_interface,
loop_method_name="create_connection",
make_slow=make_slow,
ssl_context=fake_ssl_context,
)
Loading