Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
return NetworkTimeout(str(error))
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error


_T = TypeVar("_T")
Expand Down Expand Up @@ -804,15 +809,17 @@ async def callback(session, custom_arg, custom_kwarg=None):
await self.commit_transaction()
except PyMongoError as exc:
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue

if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise
Expand Down
15 changes: 11 additions & 4 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool:
def _make_timeout_error(error: BaseException) -> PyMongoError:
"""Convert error to a NetworkTimeout or ExecutionTimeout as appropriate."""
if _csot.remaining() is not None:
return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50})
timeout_error: PyMongoError = ExecutionTimeout(
str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}
)
else:
return NetworkTimeout(str(error))
timeout_error = NetworkTimeout(str(error))
if isinstance(error, PyMongoError):
timeout_error._error_labels = error._error_labels.copy()
return timeout_error


_T = TypeVar("_T")
Expand Down Expand Up @@ -800,15 +805,17 @@ def callback(session, custom_arg, custom_kwarg=None):
self.commit_transaction()
except PyMongoError as exc:
last_error = exc
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
if exc.has_error_label(
"UnknownTransactionCommitResult"
) and not _max_time_expired_error(exc):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the commit.
continue

if exc.has_error_label("TransientTransactionError"):
if not _within_time_limit(start_time):
raise _make_timeout_error(last_error) from exc
# Retry the entire transaction.
break
raise
Expand Down
12 changes: 9 additions & 3 deletions test/asynchronous/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,12 @@ async def callback(session):
listener.reset()
async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))

@async_client_context.require_test_commands
@async_client_context.require_transactions
Expand Down Expand Up @@ -534,10 +536,12 @@ async def callback(session):

async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))

@async_client_context.require_test_commands
@async_client_context.require_transactions
Expand Down Expand Up @@ -565,14 +569,16 @@ async def callback(session):

async with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
await s.with_transaction(callback)

# One insert for the callback and two commits (includes the automatic
# retry).
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))

@async_client_context.require_transactions
async def test_callback_not_retried_after_csot_timeout(self):
Expand Down
12 changes: 9 additions & 3 deletions test/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,12 @@ def callback(session):
listener.reset()
with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))

@client_context.require_test_commands
@client_context.require_transactions
Expand Down Expand Up @@ -524,10 +526,12 @@ def callback(session):

with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)

self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"])
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))

@client_context.require_test_commands
@client_context.require_transactions
Expand All @@ -553,14 +557,16 @@ def callback(session):

with client.start_session() as s:
with PatchSessionTimeout(0):
with self.assertRaises(NetworkTimeout):
with self.assertRaises(NetworkTimeout) as context:
s.with_transaction(callback)

# One insert for the callback and two commits (includes the automatic
# retry).
self.assertEqual(
listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"]
)
# Assert that the timeout error has the same labels as the error it wraps.
self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult"))

@client_context.require_transactions
def test_callback_not_retried_after_csot_timeout(self):
Expand Down
Loading