Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/provably/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
init_interceptor,
intercept_context,
is_enabled,
preprocess_ms,
set_intercept_body_hook,
set_intercept_url_allowlist,
take_last_intercept_row_id,
Expand Down Expand Up @@ -61,6 +62,7 @@
"intercept_context",
"is_enabled",
"is_trusted_endpoint",
"preprocess_ms",
"list_trusted_endpoints",
"normalize_url_for_trust",
"outcome_from_trace",
Expand Down
16 changes: 11 additions & 5 deletions src/provably/handoff/_query_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

from __future__ import annotations

import requests

from provably.handoff._http import org_id as env_org_id
from provably.handoff._http import post_json, query_record_page_url
from provably.handoff._http import post_json, post_raw, query_record_page_url
from provably.handoff._preprocess import wait_for_proof_completed
from provably.handoff._resources import extract_id
from provably.log import get_logger
Expand All @@ -31,6 +33,10 @@ def _sql_escape(s: str) -> str:
return s.replace("'", "''")


def _proof_already_exists(resp: requests.Response) -> bool:
return resp.status_code == 400 and "already exists" in (resp.text or "").lower()


def create_query_record_for_intercept(
action_name: str,
*,
Expand Down Expand Up @@ -76,9 +82,7 @@ def create_query_record_for_intercept(

oid = (org_id or env_org_id()).strip()
if not oid or not middleware_id or not collection_id:
raise ValueError(
"create_query_record_for_intercept requires org_id, middleware_id, collection_id"
)
raise ValueError("create_query_record_for_intercept requires org_id, middleware_id, collection_id")

if row_id is not None:
# Single integer equality — accepted by all Provably engine versions.
Expand All @@ -92,7 +96,9 @@ def create_query_record_for_intercept(
)
query_id = extract_id(query_rec if isinstance(query_rec, dict) else {}, ["query_id", "id"])

post_json(f"/api/v1/organizations/{oid}/queries/{query_id}/generate_proof", {})
proof_resp = post_raw(f"/api/v1/organizations/{oid}/queries/{query_id}/generate_proof", {})
if not proof_resp.ok and not _proof_already_exists(proof_resp):
proof_resp.raise_for_status()
wait_for_proof_completed(oid, query_id, timeout_s=proof_timeout_s)

url = query_record_page_url(oid, query_id)
Expand Down
2 changes: 2 additions & 0 deletions src/provably/intercept/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ._loader import load_latest_intercept_payload as load_latest_intercept_payload
from ._self_egress import provably_self_egress as provably_self_egress
from ._storage import preprocess_ms as preprocess_ms
from .interceptor import (
clear_intercept_row_ids as clear_intercept_row_ids,
)
Expand All @@ -25,6 +26,7 @@
"init_interceptor",
"intercept_context",
"is_enabled",
"preprocess_ms",
"provably_self_egress",
"set_intercept_body_hook",
"set_intercept_url_allowlist",
Expand Down
21 changes: 21 additions & 0 deletions src/provably/intercept/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import hashlib
import json
import os
import time
from contextvars import ContextVar, Token
from typing import Any

import psycopg2
Expand All @@ -14,6 +16,23 @@
from provably.log import get_logger
from provably.trusted_endpoints import ensure_trusted_endpoints_table, is_trusted_endpoint

# Cumulative milliseconds spent in preprocess during the current intercept_context, so
# callers can attribute that time to the SDK rather than to the wrapped tool/HTTP call.
# Reads as 0.0 until something has been recorded in the current context.
_ctx_preprocess_ms: ContextVar[float] = ContextVar("provably_preprocess_ms", default=0.0)


def reset_preprocess_ms() -> Token[float]:
    """Zero the preprocess-time accumulator for the current context.

    Returns the ContextVar token that restore_preprocess_ms() needs in order
    to put the previous value back.
    """
    token = _ctx_preprocess_ms.set(0.0)
    return token


def restore_preprocess_ms(token: Token[float]) -> None:
    """Put the preprocess-time accumulator back to its value from before *token* was issued.

    *token* is the value returned by the matching reset_preprocess_ms() call.
    """
    _ctx_preprocess_ms.reset(token)
    return None


def preprocess_ms() -> float:
    """Return the preprocess time, in milliseconds, accumulated so far in the current intercept_context."""
    elapsed_ms = _ctx_preprocess_ms.get()
    return elapsed_ms

_log = get_logger(__name__)
_DDL_DONE = False

Expand Down Expand Up @@ -167,7 +186,9 @@ def _write_row(
conn.commit()
_log.info("intercept_stored", agent_id=agent_id, action_name=action_name, url=url, method=method)
if row_id is not None:
t0 = time.perf_counter()
preprocess_after_intercept_write()
_ctx_preprocess_ms.set(_ctx_preprocess_ms.get() + (time.perf_counter() - t0) * 1000.0)
return row_id
finally:
conn.close()
4 changes: 4 additions & 0 deletions src/provably/intercept/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from provably.intercept._storage import (
insert_intercept_row,
request_payload_dict,
reset_preprocess_ms,
restore_preprocess_ms,
)
from provably.trusted_endpoints import normalize_url_for_trust

Expand Down Expand Up @@ -117,9 +119,11 @@ def get_temperature():
t_agent = _ctx_agent_id.set(agent_id)
t_action = _ctx_action_name.set(action_name)
t_index = _ctx_intercept_index.set(intercept_index)
t_pp = reset_preprocess_ms()
try:
yield
finally:
restore_preprocess_ms(t_pp)
_ctx_intercept_index.reset(t_index)
_ctx_action_name.reset(t_action)
_ctx_agent_id.reset(t_agent)
Expand Down
89 changes: 83 additions & 6 deletions tests/unit/test_query_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,35 @@
from typing import Any

import pytest
import requests

from provably.handoff._query_records import create_query_record_for_intercept


class _FakeResp:
def __init__(self, status_code: int = 200, text: str = "") -> None:
self.status_code = status_code
self.text = text
self.ok = 200 <= status_code < 300

def raise_for_status(self) -> None:
if not self.ok:
raise requests.HTTPError(self.text)


@pytest.fixture
def fake_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Set the PROVABLY_* environment variables to fixed test values."""
    env_vars = (
        ("PROVABLY_RUST_BE_URL", "https://api.test"),
        ("PROVABLY_API_KEY", "k"),
        ("PROVABLY_ORG_ID", "org-1"),
    )
    for name, value in env_vars:
        monkeypatch.setenv(name, value)


@pytest.fixture(autouse=True)
def ok_generate_proof(monkeypatch: pytest.MonkeyPatch) -> None:
    """Default generate_proof (post_raw) to a 200 so tests don't hit the network."""

    def _always_ok(*_args: Any, **_kwargs: Any) -> _FakeResp:
        # Accept and ignore whatever the caller passes; always report success.
        return _FakeResp(200)

    monkeypatch.setattr("provably.handoff._query_records.post_raw", _always_ok)


def test_creates_query_record_and_returns_id_url(
fake_env: None,
monkeypatch: pytest.MonkeyPatch,
Expand All @@ -31,7 +49,14 @@ def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any
def fake_wait(_org: str, _qid: str, timeout_s: float = 180.0) -> None:
return None

raw_posted: list[str] = []

def fake_raw(path: str, payload: dict[str, Any] | None = None) -> _FakeResp:
raw_posted.append(path)
return _FakeResp(200)

monkeypatch.setattr("provably.handoff._query_records.post_json", fake_post)
monkeypatch.setattr("provably.handoff._query_records.post_raw", fake_raw)
monkeypatch.setattr("provably.handoff._query_records.wait_for_proof_completed", fake_wait)

qid, qurl = create_query_record_for_intercept(
Expand All @@ -43,12 +68,10 @@ def fake_wait(_org: str, _qid: str, timeout_s: float = 180.0) -> None:
assert qid == "q-uuid"
assert qurl == "https://app.test/org/org-1/query-record/q-uuid"

paths = [p for p, _ in posted]
assert paths == [
"/api/v1/organizations/org-1/middlewares/mw-1/query",
"/api/v1/organizations/org-1/queries/q-uuid/generate_proof",
]
assert not any(p.endswith("/verify") for p in paths), (
json_paths = [p for p, _ in posted]
assert json_paths == ["/api/v1/organizations/org-1/middlewares/mw-1/query"]
assert raw_posted == ["/api/v1/organizations/org-1/queries/q-uuid/generate_proof"]
assert not any(p.endswith("/verify") for p in json_paths + raw_posted), (
"create_query_record_for_intercept must not call /verify; that runs in evaluator (cluster B)"
)

Expand Down Expand Up @@ -212,6 +235,60 @@ def fake_post(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any
assert "WHERE action_name = 'O''Reilly'" in captured["body"]["query"]


def test_proof_already_exists_is_treated_as_success(
    fake_env: None,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A 400 'proof already exists' means the proof is there — proceed to wait, don't raise."""
    waited: dict[str, str] = {}

    def stub_post_json(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        # Only the query-creation call hands back an id.
        if path.endswith("/query"):
            return {"id": "q-uuid"}
        return {}

    def stub_post_raw(path: str, payload: dict[str, Any] | None = None) -> _FakeResp:
        # generate_proof is rejected because the proof is already there.
        body = "Proof already exists for this query. Use verify endpoint instead."
        return _FakeResp(400, body)

    def stub_wait(_org: str, qid: str, timeout_s: float = 180.0) -> None:
        # Record which query the code under test went on to wait for.
        waited["qid"] = qid

    monkeypatch.setattr("provably.handoff._query_records.post_json", stub_post_json)
    monkeypatch.setattr("provably.handoff._query_records.post_raw", stub_post_raw)
    monkeypatch.setattr("provably.handoff._query_records.wait_for_proof_completed", stub_wait)

    query_id, _query_url = create_query_record_for_intercept(
        "endpoint_0",
        agent_id="fetch_and_claim",
        middleware_id="mw-1",
        collection_id="coll-1",
    )

    assert query_id == "q-uuid"
    assert waited["qid"] == "q-uuid"


def test_proof_generation_hard_error_raises(
    fake_env: None,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A non-'already exists' failure still raises rather than silently continuing."""

    def stub_post_json(path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
        # Only the query-creation call hands back an id.
        if path.endswith("/query"):
            return {"id": "q-uuid"}
        return {}

    def failing_post_raw(*_args: Any, **_kwargs: Any) -> _FakeResp:
        # generate_proof fails with a server error unrelated to "already exists".
        return _FakeResp(500, "boom")

    monkeypatch.setattr("provably.handoff._query_records.post_json", stub_post_json)
    monkeypatch.setattr("provably.handoff._query_records.post_raw", failing_post_raw)

    with pytest.raises(requests.HTTPError):
        create_query_record_for_intercept(
            "endpoint_0",
            agent_id="fetch_and_claim",
            middleware_id="mw-1",
            collection_id="coll-1",
        )


@pytest.mark.parametrize(
("agent_id", "action_name"),
[("", "act"), ("ag", "")],
Expand Down
Loading