From 285a718f4b4de18b30a468eef06f2243b7830897 Mon Sep 17 00:00:00 2001 From: Pol Alvarez Date: Fri, 29 May 2026 15:04:11 +0200 Subject: [PATCH 1/2] creating a nested tx during claim_one so that sql aborts do not prevent use from setting status to failed --- fastasyncpg/core.py | 10 +++++++--- nbs/00_core.ipynb | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/fastasyncpg/core.py b/fastasyncpg/core.py index 938be07..dbc0b4b 100644 --- a/fastasyncpg/core.py +++ b/fastasyncpg/core.py @@ -734,12 +734,16 @@ async def claim_one(self:Table, status_col='status', pending='pending', complete p = _ClaimCtx() async with self.db.transaction() as txn: tbl = txn.t[self.name] - p.db = txn p.evt = await tbl.claim(where=f'"{status_col}"=$1', where_args=[pending], order_by=order_by) - try: yield p + p.db = txn + try: + async with txn.conn.transaction(): yield p # create save point in case tx is aborted except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc() if p.evt is None: return pk = self.pks[0] new = failed if p.failed else (None if p.retry else completed) stmt = f'UPDATE {self} SET "{status_col}"=$1 WHERE "{pk}"=$2' - if new: await txn.execute(stmt, new, get_field(p.evt, pk)) + if new: + await txn.execute(stmt, new, get_field(p.evt, pk)) + p.evt = await tbl.selectone(f'"{pk}"=$1', [get_field(p.evt, pk)]) + diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index d790f8e..5ae4485 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -5091,15 +5091,18 @@ " p = _ClaimCtx()\n", " async with self.db.transaction() as txn:\n", " tbl = txn.t[self.name]\n", - " p.db = txn\n", " p.evt = await tbl.claim(where=f'\"{status_col}\"=$1', where_args=[pending], order_by=order_by)\n", - " try: yield p\n", + " p.db = txn\n", + " try:\n", + " async with txn.conn.transaction(): yield p # create save point in case tx is aborted\n", " except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc()\n", " if p.evt is None: return\n", " pk = self.pks[0]\n", " new = failed if p.failed else (None if p.retry else completed)\n", " stmt = f'UPDATE {self} SET \"{status_col}\"=$1 WHERE \"{pk}\"=$2'\n", - " if new: await txn.execute(stmt, new, get_field(p.evt, pk))" + " if new: \n", + " await txn.execute(stmt, new, get_field(p.evt, pk))\n", + " p.evt = await tbl.selectone(f'\"{pk}\"=$1', [get_field(p.evt, pk)])\n" ] }, { @@ -5293,6 +5296,31 @@ "await jobs(\"payload=$1\", [pl])" ] }, + { + "cell_type": "markdown", + "id": "a866bdaa", + "metadata": {}, + "source": [ + "Test that when the transaction is aborted the claimed event row is still set to failed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1e45f08", + "metadata": {}, + "outputs": [], + "source": [ + "await jobs.delete_where()\n", + "await jobs.inserts([Job(payload='repro', status='pending')])\n", + "\n", + "async with jobs.claim_one(order_by='id') as p:\n", + " await p.db.execute('select definitely_not_a_real_function()')\n", + "\n", + "test_eq(p.evt.status, 'failed')\n", + "test_eq(p.failed, True)" + ] + }, { "cell_type": "code", "execution_count": null, From dcd84d7dcbfe693dcff2cd9c8fbdfd6a678b02a1 Mon Sep 17 00:00:00 2001 From: Pol Alvarez Date: Wed, 3 Jun 2026 17:09:49 +0200 Subject: [PATCH 2/2] allowing claim_one tx.conn errors to surface --- fastasyncpg/core.py | 7 +++--- nbs/00_core.ipynb | 54 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 52 insertions(+), 9 deletions(-) diff --git a/fastasyncpg/core.py b/fastasyncpg/core.py index dbc0b4b..100fbe1 100644 --- a/fastasyncpg/core.py +++ b/fastasyncpg/core.py @@ -736,9 +736,9 @@ async def claim_one(self:Table, status_col='status', pending='pending', complete tbl = txn.t[self.name] p.evt = await tbl.claim(where=f'"{status_col}"=$1', where_args=[pending], order_by=order_by) p.db = txn - try: - async with txn.conn.transaction(): yield p # create save point in case tx is aborted - except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc() + async with txn.conn.transaction(): + try: yield p # create save point in case tx is aborted + except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc() if p.evt is None: return pk = self.pks[0] new = failed if p.failed else (None if p.retry else completed) @@ -746,4 +746,3 @@ async def claim_one(self:Table, status_col='status', pending='pending', complete if new: await txn.execute(stmt, new, get_field(p.evt, pk)) p.evt = await tbl.selectone(f'"{pk}"=$1', [get_field(p.evt, pk)]) - diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index 5ae4485..00148fd 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -5093,16 +5093,16 @@ " tbl = txn.t[self.name]\n", " p.evt = await tbl.claim(where=f'\"{status_col}\"=$1', where_args=[pending], order_by=order_by)\n", " p.db = txn\n", - " try:\n", - " async with txn.conn.transaction(): yield p # create save point in case tx is aborted\n", - " except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc()\n", + " async with txn.conn.transaction():\n", + " try: yield p # create save point in case tx is aborted\n", + " except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc()\n", " if p.evt is None: return\n", " pk = self.pks[0]\n", " new = failed if p.failed else (None if p.retry else completed)\n", " stmt = f'UPDATE {self} SET \"{status_col}\"=$1 WHERE \"{pk}\"=$2'\n", " if new: \n", " await txn.execute(stmt, new, get_field(p.evt, pk))\n", - " p.evt = await tbl.selectone(f'\"{pk}\"=$1', [get_field(p.evt, pk)])\n" + " p.evt = await tbl.selectone(f'\"{pk}\"=$1', [get_field(p.evt, pk)])" ] }, { @@ -5316,15 +5316,59 @@ "\n", "async with jobs.claim_one(order_by='id') as p:\n", " await p.db.execute('select definitely_not_a_real_function()')\n", + " await p.db.t.job.update(p.evt, status='processing')\n", "\n", "test_eq(p.evt.status, 'failed')\n", "test_eq(p.failed, True)" ] }, + { + "cell_type": "markdown", + "id": "81bb8295", + "metadata": {}, + "source": [ + "Test that if the `txn.conn.transaction` fails, we allow the exception to surface rather than having the generator swallow it and return a `RuntimeError(\"generator didn't yield\")`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77029d63", + "metadata": {}, + "outputs": [], + "source": [ + "await jobs.delete_where()\n", + "await jobs.inserts([Job(payload='x', status='pending')])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccacb161", + "metadata": {}, + "outputs": [], + "source": [ + "async with db.acquire() as c: ConnType = type(c.conn)\n", + "orig = ConnType.transaction\n", + "calls = []\n", + "from unittest.mock import patch as mock_patch\n", + "\n", + "def boom(self, *a, **k):\n", + " calls.append(1)\n", + " if len(calls) >= 2: raise ConnectionError(\"savepoint boom\")\n", + " return orig(self, *a, **k)\n", + "\n", + "with mock_patch.object(ConnType, 'transaction', boom):\n", + " try:\n", + " async with jobs.claim_one(order_by='id') as p: pass\n", + " assert False, \"expected ConnectionError\"\n", + " except ConnectionError: pass" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "c7e5cb98", + "id": "089f1a43", "metadata": {}, "outputs": [], "source": [