diff --git a/src/provably/__init__.py b/src/provably/__init__.py index 092eb87..d7b490f 100644 --- a/src/provably/__init__.py +++ b/src/provably/__init__.py @@ -22,6 +22,7 @@ init_interceptor, intercept_context, is_enabled, + preprocess_ms, set_intercept_body_hook, set_intercept_url_allowlist, take_last_intercept_row_id, @@ -61,6 +62,7 @@ "intercept_context", "is_enabled", "is_trusted_endpoint", + "preprocess_ms", "list_trusted_endpoints", "normalize_url_for_trust", "outcome_from_trace", diff --git a/src/provably/handoff/_query_records.py b/src/provably/handoff/_query_records.py index 86a3e48..ccf586f 100644 --- a/src/provably/handoff/_query_records.py +++ b/src/provably/handoff/_query_records.py @@ -18,8 +18,10 @@ from __future__ import annotations +import requests + from provably.handoff._http import org_id as env_org_id -from provably.handoff._http import post_json, query_record_page_url +from provably.handoff._http import post_json, post_raw, query_record_page_url from provably.handoff._preprocess import wait_for_proof_completed from provably.handoff._resources import extract_id from provably.log import get_logger @@ -31,6 +33,10 @@ def _sql_escape(s: str) -> str: return s.replace("'", "''") +def _proof_already_exists(resp: requests.Response) -> bool: + return resp.status_code == 400 and "already exists" in (resp.text or "").lower() + + def create_query_record_for_intercept( action_name: str, *, @@ -76,9 +82,7 @@ def create_query_record_for_intercept( oid = (org_id or env_org_id()).strip() if not oid or not middleware_id or not collection_id: - raise ValueError( - "create_query_record_for_intercept requires org_id, middleware_id, collection_id" - ) + raise ValueError("create_query_record_for_intercept requires org_id, middleware_id, collection_id") if row_id is not None: # Single integer equality — accepted by all Provably engine versions. @@ -92,7 +96,9 @@ def create_query_record_for_intercept( ) query_id = extract_id(query_rec if isinstance(query_rec, dict) else {}, ["query_id", "id"]) - post_json(f"/api/v1/organizations/{oid}/queries/{query_id}/generate_proof", {}) + proof_resp = post_raw(f"/api/v1/organizations/{oid}/queries/{query_id}/generate_proof", {}) + if not proof_resp.ok and not _proof_already_exists(proof_resp): + proof_resp.raise_for_status() wait_for_proof_completed(oid, query_id, timeout_s=proof_timeout_s) url = query_record_page_url(oid, query_id) diff --git a/src/provably/intercept/__init__.py b/src/provably/intercept/__init__.py index 1cb474b..ce68614 100644 --- a/src/provably/intercept/__init__.py +++ b/src/provably/intercept/__init__.py @@ -2,6 +2,7 @@ from ._loader import load_latest_intercept_payload as load_latest_intercept_payload from ._self_egress import provably_self_egress as provably_self_egress +from ._storage import preprocess_ms as preprocess_ms from .interceptor import ( clear_intercept_row_ids as clear_intercept_row_ids, ) @@ -25,6 +26,7 @@ "init_interceptor", "intercept_context", "is_enabled", + "preprocess_ms", "provably_self_egress", "set_intercept_body_hook", "set_intercept_url_allowlist", diff --git a/src/provably/intercept/_storage.py b/src/provably/intercept/_storage.py index d8d6faa..dd5c39c 100644 --- a/src/provably/intercept/_storage.py +++ b/src/provably/intercept/_storage.py @@ -5,6 +5,8 @@ import hashlib import json import os +import time +from contextvars import ContextVar, Token from typing import Any import psycopg2 @@ -14,6 +16,23 @@ from provably.log import get_logger from provably.trusted_endpoints import ensure_trusted_endpoints_table, is_trusted_endpoint +# Cumulative ms spent in preprocess during the current intercept_context, so callers can +# attribute that time to the SDK rather than to the wrapped tool/HTTP call. +_ctx_preprocess_ms: ContextVar[float] = ContextVar("provably_preprocess_ms", default=0.0) + + +def reset_preprocess_ms() -> Token[float]: + return _ctx_preprocess_ms.set(0.0) + + +def restore_preprocess_ms(token: Token[float]) -> None: + _ctx_preprocess_ms.reset(token) + + +def preprocess_ms() -> float: + """Cumulative preprocess time (ms) recorded so far in the current intercept_context.""" + return _ctx_preprocess_ms.get() + _log = get_logger(__name__) _DDL_DONE = False @@ -167,7 +186,9 @@ def _write_row( conn.commit() _log.info("intercept_stored", agent_id=agent_id, action_name=action_name, url=url, method=method) if row_id is not None: + t0 = time.perf_counter() preprocess_after_intercept_write() + _ctx_preprocess_ms.set(_ctx_preprocess_ms.get() + (time.perf_counter() - t0) * 1000.0) return row_id finally: conn.close() diff --git a/src/provably/intercept/interceptor.py b/src/provably/intercept/interceptor.py index b3cf4d4..6719ba8 100644 --- a/src/provably/intercept/interceptor.py +++ b/src/provably/intercept/interceptor.py @@ -28,6 +28,8 @@ from provably.intercept._storage import ( insert_intercept_row, request_payload_dict, + reset_preprocess_ms, + restore_preprocess_ms, ) from provably.trusted_endpoints import normalize_url_for_trust @@ -117,9 +119,11 @@ def get_temperature(): t_agent = _ctx_agent_id.set(agent_id) t_action = _ctx_action_name.set(action_name) t_index = _ctx_intercept_index.set(intercept_index) + t_pp = reset_preprocess_ms() try: yield finally: + restore_preprocess_ms(t_pp) _ctx_intercept_index.reset(t_index) _ctx_action_name.reset(t_action) _ctx_agent_id.reset(t_agent) diff --git a/tests/unit/test_query_records.py b/tests/unit/test_query_records.py index 4e8ab0f..ea4429e 100644 --- a/tests/unit/test_query_records.py +++ b/tests/unit/test_query_records.py @@ -3,10 +3,22 @@ from typing import Any import pytest +import requests from provably.handoff._query_records import create_query_record_for_intercept +class _FakeResp: + def __init__(self, status_code: int = 200, text: str = "") -> None: + self.status_code = status_code + self.text = text + self.ok = 200 <= status_code < 300 + + def raise_for_status(self) -> None: + if not self.ok: + raise requests.HTTPError(self.text) + + @pytest.fixture def fake_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("PROVABLY_RUST_BE_URL", "https://api.test") @@ -14,6 +26,12 @@ def fake_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("PROVABLY_ORG_ID", "org-1") +@pytest.fixture(autouse=True) +def ok_generate_proof(monkeypatch: pytest.MonkeyPatch) -> None: + """Default generate_proof (post_raw) to a 200 so tests don't hit the network.""" + monkeypatch.setattr("provably.handoff._query_records.post_raw", lambda *_a, **_kw: _FakeResp(200)) + + def test_creates_query_record_and_returns_id_url( fake_env: None, monkeypatch: pytest.MonkeyPatch, @@ -31,7 +49,14 @@ def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any def fake_wait(_org: str, _qid: str, timeout_s: float = 180.0) -> None: return None + raw_posted: list[str] = [] + + def fake_raw(path: str, payload: dict[str, Any] | None = None) -> _FakeResp: + raw_posted.append(path) + return _FakeResp(200) + monkeypatch.setattr("provably.handoff._query_records.post_json", fake_post) + monkeypatch.setattr("provably.handoff._query_records.post_raw", fake_raw) monkeypatch.setattr("provably.handoff._query_records.wait_for_proof_completed", fake_wait) qid, qurl = create_query_record_for_intercept( @@ -43,12 +68,10 @@ def fake_wait(_org: str, _qid: str, timeout_s: float = 180.0) -> None: assert qid == "q-uuid" assert qurl == "https://app.test/org/org-1/query-record/q-uuid" - paths = [p for p, _ in posted] - assert paths == [ - "/api/v1/organizations/org-1/middlewares/mw-1/query", - "/api/v1/organizations/org-1/queries/q-uuid/generate_proof", - ] - assert not any(p.endswith("/verify") for p in paths), ( + json_paths = [p for p, _ in posted] + assert json_paths == ["/api/v1/organizations/org-1/middlewares/mw-1/query"] + assert raw_posted == ["/api/v1/organizations/org-1/queries/q-uuid/generate_proof"] + assert not any(p.endswith("/verify") for p in json_paths + raw_posted), ( "create_query_record_for_intercept must not call /verify; that runs in evaluator (cluster B)" ) @@ -212,6 +235,60 @@ def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any assert "WHERE action_name = 'O''Reilly'" in captured["body"]["query"] +def test_proof_already_exists_is_treated_as_success( + fake_env: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A 400 'proof already exists' means the proof is there — proceed to wait, don't raise.""" + waited: dict[str, str] = {} + + def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: + return {"id": "q-uuid"} if path.endswith("/query") else {} + + def fake_raw(path: str, payload: dict[str, Any] | None = None) -> _FakeResp: + return _FakeResp(400, "Proof already exists for this query. Use verify endpoint instead.") + + def fake_wait(_org: str, qid: str, timeout_s: float = 180.0) -> None: + waited["qid"] = qid + + monkeypatch.setattr("provably.handoff._query_records.post_json", fake_post) + monkeypatch.setattr("provably.handoff._query_records.post_raw", fake_raw) + monkeypatch.setattr("provably.handoff._query_records.wait_for_proof_completed", fake_wait) + + qid, _qurl = create_query_record_for_intercept( + "endpoint_0", + agent_id="fetch_and_claim", + middleware_id="mw-1", + collection_id="coll-1", + ) + assert qid == "q-uuid" + assert waited["qid"] == "q-uuid" + + +def test_proof_generation_hard_error_raises( + fake_env: None, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A non-'already exists' failure still raises rather than silently continuing.""" + + def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]: + return {"id": "q-uuid"} if path.endswith("/query") else {} + + monkeypatch.setattr("provably.handoff._query_records.post_json", fake_post) + monkeypatch.setattr( + "provably.handoff._query_records.post_raw", + lambda *_a, **_kw: _FakeResp(500, "boom"), + ) + + with pytest.raises(requests.HTTPError): + create_query_record_for_intercept( + "endpoint_0", + agent_id="fetch_and_claim", + middleware_id="mw-1", + collection_id="coll-1", + ) + + @pytest.mark.parametrize( ("agent_id", "action_name"), [("", "act"), ("ag", "")],