Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .evergreen/remove-unimplemented-tests.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/bin/bash
PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)")

rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894
rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918
Expand Down
7 changes: 6 additions & 1 deletion pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.client_session import (
AsyncClientSession,
_validate_session_write_concern,
)
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.bulk_shared import (
_COMMANDS,
Expand Down Expand Up @@ -271,6 +274,8 @@ async def write_command(
if bwc.publish:
bwc._start(cmd, request_id, docs)
try:
if bwc.session is not None and bwc.session._starting_transaction:
bwc.session._transaction.set_in_progress()
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
Expand Down
7 changes: 6 additions & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
from pymongo.asynchronous.client_session import (
AsyncClientSession,
_validate_session_write_concern,
)
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.database import AsyncDatabase
Expand Down Expand Up @@ -258,6 +261,8 @@ async def write_command(
if bwc.publish:
bwc._start(cmd, request_id, op_docs, ns_docs)
try:
if bwc.session is not None and bwc.session._starting_transaction:
bwc.session._transaction.set_in_progress()
reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
Expand Down
8 changes: 7 additions & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient[
self.attempt = 0
self.client = client
self.has_completed_command = False
self.has_sent_command = False

def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
Expand All @@ -443,6 +444,11 @@ def starting(self) -> bool:
def set_starting(self) -> None:
self.state = _TxnState.STARTING

def set_in_progress(self) -> None:
if self.state == _TxnState.STARTING:
self.has_sent_command = True
self.state = _TxnState.IN_PROGRESS

@property
def pinned_conn(self) -> Optional[AsyncConnection]:
if self.active() and self.conn_mgr:
Expand All @@ -469,6 +475,7 @@ async def reset(self) -> None:
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
self.has_sent_command = False

def __del__(self) -> None:
if self.conn_mgr:
Expand Down Expand Up @@ -1135,7 +1142,6 @@ def _apply_to(

if self._transaction.state == _TxnState.STARTING:
# First command begins a new transaction.
self._transaction.state = _TxnState.IN_PROGRESS
command["startTransaction"] = True

assert self._transaction.opts
Expand Down
8 changes: 4 additions & 4 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2870,8 +2870,8 @@ async def run(self) -> T:
self._last_error = exc
self._attempt_number += 1

# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
# Revert back to starting state only if the first
# transactional command was never completed.
if (
overloaded
and self._session is not None
Expand Down Expand Up @@ -2921,8 +2921,8 @@ async def run(self) -> T:
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
# Revert back to starting state only if the first
# transactional command was never completed.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ async def command(
unacknowledged = bool(write_concern and not write_concern.acknowledged)
self._raise_if_not_writable(unacknowledged)
try:
if session is not None and session._starting_transaction:
session._transaction.set_in_progress()
return await command(
self,
dbname,
Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ async def run_operation(
if more_to_come:
reply = await conn.receive_message(None)
else:
if operation.session is not None and operation.session._starting_transaction:
operation.session._transaction.set_in_progress()
await conn.send_message(data, max_doc_size)
reply = await conn.receive_message(request_id)

Expand Down
7 changes: 6 additions & 1 deletion pymongo/synchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@
_randint,
)
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.client_session import (
ClientSession,
_validate_session_write_concern,
)
from pymongo.synchronous.helpers import _handle_reauth
from pymongo.write_concern import WriteConcern

Expand Down Expand Up @@ -271,6 +274,8 @@ def write_command(
if bwc.publish:
bwc._start(cmd, request_id, docs)
try:
if bwc.session is not None and bwc.session._starting_transaction:
bwc.session._transaction.set_in_progress()
reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
Expand Down
7 changes: 6 additions & 1 deletion pymongo/synchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
from bson.objectid import ObjectId
from bson.raw_bson import RawBSONDocument
from pymongo import _csot, common
from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
from pymongo.synchronous.client_session import (
ClientSession,
_validate_session_write_concern,
)
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
Expand Down Expand Up @@ -258,6 +261,8 @@ def write_command(
if bwc.publish:
bwc._start(cmd, request_id, op_docs, ns_docs)
try:
if bwc.session is not None and bwc.session._starting_transaction:
bwc.session._transaction.set_in_progress()
reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type]
duration = datetime.datetime.now() - bwc.start_time
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
Expand Down
8 changes: 7 additions & 1 deletion pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient[Any])
self.attempt = 0
self.client = client
self.has_completed_command = False
self.has_sent_command = False

def active(self) -> bool:
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
Expand All @@ -441,6 +442,11 @@ def starting(self) -> bool:
def set_starting(self) -> None:
self.state = _TxnState.STARTING

def set_in_progress(self) -> None:
if self.state == _TxnState.STARTING:
self.has_sent_command = True
self.state = _TxnState.IN_PROGRESS

@property
def pinned_conn(self) -> Optional[Connection]:
if self.active() and self.conn_mgr:
Expand All @@ -467,6 +473,7 @@ def reset(self) -> None:
self.recovery_token = None
self.attempt = 0
self.has_completed_command = False
self.has_sent_command = False

def __del__(self) -> None:
if self.conn_mgr:
Expand Down Expand Up @@ -1131,7 +1138,6 @@ def _apply_to(

if self._transaction.state == _TxnState.STARTING:
# First command begins a new transaction.
self._transaction.state = _TxnState.IN_PROGRESS
command["startTransaction"] = True

assert self._transaction.opts
Expand Down
8 changes: 4 additions & 4 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2860,8 +2860,8 @@ def run(self) -> T:
self._last_error = exc
self._attempt_number += 1

# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
# Revert back to starting state only if the first
# transactional command was never completed.
if (
overloaded
and self._session is not None
Expand Down Expand Up @@ -2911,8 +2911,8 @@ def run(self) -> T:
self._last_error = exc
if self._last_error is None:
self._last_error = exc
# Revert back to starting state if we're in a transaction but haven't completed the first
# command.
# Revert back to starting state only if the first
# transactional command was never completed.
if overloaded and self._session is not None and self._session.in_transaction:
transaction = self._session._transaction
if not transaction.has_completed_command:
Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ def command(
unacknowledged = bool(write_concern and not write_concern.acknowledged)
self._raise_if_not_writable(unacknowledged)
try:
if session is not None and session._starting_transaction:
session._transaction.set_in_progress()
return command(
self,
dbname,
Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def run_operation(
if more_to_come:
reply = conn.receive_message(None)
else:
if operation.session is not None and operation.session._starting_transaction:
operation.session._transaction.set_in_progress()
conn.send_message(data, max_doc_size)
reply = conn.receive_message(request_id)

Expand Down
3 changes: 0 additions & 3 deletions test/asynchronous/test_unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
os.path.join(TEST_PATH, "valid-pass"),
module=__name__,
class_name_prefix="UnifiedTestFormat",
expected_failures=[
"Client side error in command starting transaction", # PYTHON-1894
],
)
)

Expand Down
2 changes: 0 additions & 2 deletions test/asynchronous/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,6 @@ def maybe_skip_test(self, spec):
class_name = self.__class__.__name__.lower()
description = spec["description"].lower()

if "client side error in command starting transaction" in description:
self.skipTest("Implement PYTHON-1894")
if "type=symbol" in description:
self.skipTest("PyMongo does not support the symbol type")
if "timeoutms applied to entire download" in description:
Expand Down
3 changes: 0 additions & 3 deletions test/test_unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
os.path.join(TEST_PATH, "valid-pass"),
module=__name__,
class_name_prefix="UnifiedTestFormat",
expected_failures=[
"Client side error in command starting transaction", # PYTHON-1894
],
)
)

Expand Down
2 changes: 0 additions & 2 deletions test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,6 @@ def maybe_skip_test(self, spec):
class_name = self.__class__.__name__.lower()
description = spec["description"].lower()

if "client side error in command starting transaction" in description:
self.skipTest("Implement PYTHON-1894")
if "type=symbol" in description:
self.skipTest("PyMongo does not support the symbol type")
if "timeoutms applied to entire download" in description:
Expand Down
Loading