diff --git a/fastasyncpg/core.py b/fastasyncpg/core.py index 938be07..100fbe1 100644 --- a/fastasyncpg/core.py +++ b/fastasyncpg/core.py @@ -734,12 +734,15 @@ async def claim_one(self:Table, status_col='status', pending='pending', complete p = _ClaimCtx() async with self.db.transaction() as txn: tbl = txn.t[self.name] - p.db = txn p.evt = await tbl.claim(where=f'"{status_col}"=$1', where_args=[pending], order_by=order_by) - try: yield p - except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc() + p.db = txn + async with txn.conn.transaction(): + try: yield p # create save point in case tx is aborted + except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc() if p.evt is None: return pk = self.pks[0] new = failed if p.failed else (None if p.retry else completed) stmt = f'UPDATE {self} SET "{status_col}"=$1 WHERE "{pk}"=$2' - if new: await txn.execute(stmt, new, get_field(p.evt, pk)) + if new: + await txn.execute(stmt, new, get_field(p.evt, pk)) + p.evt = await tbl.selectone(f'"{pk}"=$1', [get_field(p.evt, pk)]) diff --git a/nbs/00_core.ipynb b/nbs/00_core.ipynb index d790f8e..00148fd 100644 --- a/nbs/00_core.ipynb +++ b/nbs/00_core.ipynb @@ -5091,15 +5091,18 @@ " p = _ClaimCtx()\n", " async with self.db.transaction() as txn:\n", " tbl = txn.t[self.name]\n", - " p.db = txn\n", " p.evt = await tbl.claim(where=f'\"{status_col}\"=$1', where_args=[pending], order_by=order_by)\n", - " try: yield p\n", - " except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc()\n", + " p.db = txn\n", + " async with txn.conn.transaction():\n", + " try: yield p # create save point in case tx is aborted\n", + " except Exception as e: p.failed, p.exc, p.tb = True, e, traceback.format_exc()\n", " if p.evt is None: return\n", " pk = self.pks[0]\n", " new = failed if p.failed else (None if p.retry else completed)\n", " stmt = f'UPDATE {self} SET \"{status_col}\"=$1 WHERE \"{pk}\"=$2'\n", - " if new: await txn.execute(stmt, new, get_field(p.evt, pk))" + " if new: \n", + " await txn.execute(stmt, new, get_field(p.evt, pk))\n", + " p.evt = await tbl.selectone(f'\"{pk}\"=$1', [get_field(p.evt, pk)])" ] }, { @@ -5293,10 +5296,79 @@ "await jobs(\"payload=$1\", [pl])" ] }, + { + "cell_type": "markdown", + "id": "a866bdaa", + "metadata": {}, + "source": [ + "Test that when the transaction is aborted the claimed event row is still set to failed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1e45f08", + "metadata": {}, + "outputs": [], + "source": [ + "await jobs.delete_where()\n", + "await jobs.inserts([Job(payload='repro', status='pending')])\n", + "\n", + "async with jobs.claim_one(order_by='id') as p:\n", + " await p.db.execute('select definitely_not_a_real_function()')\n", + " await p.db.t.job.update(p.evt, status='processing')\n", + "\n", + "test_eq(p.evt.status, 'failed')\n", + "test_eq(p.failed, True)" + ] + }, + { + "cell_type": "markdown", + "id": "81bb8295", + "metadata": {}, + "source": [ + "Test that if the `txn.conn.transaction` fails, we allow the exception to surface rather than having the generator swallow it and return a `RuntimeError(\"generator didn't yield\")`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77029d63", + "metadata": {}, + "outputs": [], + "source": [ + "await jobs.delete_where()\n", + "await jobs.inserts([Job(payload='x', status='pending')])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccacb161", + "metadata": {}, + "outputs": [], + "source": [ + "async with db.acquire() as c: ConnType = type(c.conn)\n", + "orig = ConnType.transaction\n", + "calls = []\n", + "from unittest.mock import patch as mock_patch\n", + "\n", + "def boom(self, *a, **k):\n", + " calls.append(1)\n", + " if len(calls) >= 2: raise ConnectionError(\"savepoint boom\")\n", + " return orig(self, *a, **k)\n", + "\n", + "with mock_patch.object(ConnType, 'transaction', boom):\n", + " try:\n", + " async with jobs.claim_one(order_by='id') as p: pass\n", + " assert False, \"expected ConnectionError\"\n", + " except ConnectionError: pass" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "c7e5cb98", + "id": "089f1a43", "metadata": {}, "outputs": [], "source": [