diff --git a/agent_assembly/core/runtime_interceptor.py b/agent_assembly/core/runtime_interceptor.py index b9194ba..ce1928d 100644 --- a/agent_assembly/core/runtime_interceptor.py +++ b/agent_assembly/core/runtime_interceptor.py @@ -36,6 +36,8 @@ import os from typing import Any +from agent_assembly.exceptions import OpTerminatedError + ENV_RUNTIME_SOCKET = "AA_RUNTIME_SOCKET" ACTION_TYPE_TOOL_CALL = "tool_call" ENFORCE_MODE = "enforce" @@ -89,6 +91,25 @@ def _extract_tool_name(serialized: Any, kwargs: dict[str, Any]) -> str | None: return None +def _extract_op_id(kwargs: dict[str, Any]) -> str | None: + """Resolve the op id ("{trace_id}:{span_id}") from a check_tool_start call. + + Prefers an explicit ``op_id`` kwarg; otherwise composes it from + ``trace_id`` / ``span_id`` when an adapter supplies them. Returns ``None`` + when no trace identity is present (the call is not part of a tracked op, so + there is nothing for the kill switch to address). + """ + op_id = kwargs.get("op_id") + if isinstance(op_id, str) and op_id: + return op_id + trace_id = kwargs.get("trace_id") + if isinstance(trace_id, str) and trace_id: + span_id = kwargs.get("span_id") + span = span_id if isinstance(span_id, str) else "" + return f"{trace_id}:{span}" + return None + + def _extract_tool_args_json(input_str: Any, kwargs: dict[str, Any]) -> str | None: """Serialize the tool arguments to JSON for the native ``query_policy``. @@ -121,11 +142,25 @@ class RuntimeQueryInterceptor: proceed (fail open), preserving the observe / disabled behavior. """ - def __init__(self, client: Any, runtime_client: Any, agent_id: str, *, enforce: bool = False) -> None: + def __init__( + self, + client: Any, + runtime_client: Any, + agent_id: str, + *, + enforce: bool = False, + op_control: Any | None = None, + ) -> None: self._client = client self._runtime_client = runtime_client self._agent_id = agent_id self._enforce = enforce + # Optional live op-control consumer (AAASM-3491). When wired, the + # gateway's kill switch is honored *in this tool path*: a terminate + # fast-fails the call and a pause blocks it cooperatively before the + # runtime is even queried. Without it, op control only reaches the agent + # via the native runtime's own OpControlStream consumer. + self._op_control = op_control def __getattr__(self, name: str) -> Any: # Delegate anything not defined here (e.g. report_event, on_tool_end, @@ -152,7 +187,17 @@ def check_tool_start( * A raising ``query_policy`` or an error-sentinel ``decision`` (``query_failed`` / ``channel_closed`` / ``shutdown``) → ``deny`` under ``enforce`` (fail closed, AAASM-3106), else ``allow`` (fail open). + + Before any of the above, the live op-control kill switch (AAASM-3491) is + consulted when an ``op_id`` is supplied and a subscriber is wired: a + terminated op is denied immediately and a paused op blocks here until + the gateway resumes it, so an operator terminate/pause reaches this tool + path directly rather than relying solely on the native runtime. """ + op_block = self._check_op_control(_extract_op_id(kwargs)) + if op_block is not None: + return op_block + tool_name = _extract_tool_name(serialized, kwargs) tool_args_json = _extract_tool_args_json(input_str, kwargs) @@ -187,6 +232,24 @@ def _on_query_failure(self, reason: str) -> dict[str, str]: return {"status": "deny", "reason": reason} return {"status": "allow"} + def _check_op_control(self, op_id: str | None) -> dict[str, str] | None: + """Consult the live op-control kill switch for ``op_id`` (AAASM-3491). + + Returns a ``deny`` status dict when the op has been terminated, ``None`` + otherwise — including when no subscriber is wired or no ``op_id`` is + available, so the call proceeds to the normal runtime query. When the op + is *paused*, ``await_op`` blocks here until the gateway resumes (or + terminates) it; this is the cooperative-pause point on the Python tool + path. + """ + if self._op_control is None or not op_id: + return None + try: + self._op_control.await_op(op_id) + except OpTerminatedError as exc: + return {"status": "deny", "reason": str(exc)} + return None + class _FailClosedInterceptor: """Deny-all interceptor used under ``enforce`` when no runtime is reachable. @@ -286,6 +349,7 @@ def build_governance_interceptor( *, runtime_client: Any | None = None, native_available: bool | None = None, + op_control: Any | None = None, ) -> Any: """Return the interceptor adapters should use for pre-execution checks. @@ -328,4 +392,4 @@ def build_governance_interceptor( return _FailClosedInterceptor(client, "runtime unreachable") return client - return RuntimeQueryInterceptor(client, runtime_client, agent_id, enforce=enforce) + return RuntimeQueryInterceptor(client, runtime_client, agent_id, enforce=enforce, op_control=op_control) diff --git a/agent_assembly/op_control.py b/agent_assembly/op_control.py index 4b8e9b3..e96d892 100644 --- a/agent_assembly/op_control.py +++ b/agent_assembly/op_control.py @@ -15,11 +15,19 @@ If a signal arrives for an ``op_id`` no one is currently awaiting, it's buffered into the per-op slot so the next ``await_op`` call sees it. -Out of scope for PR-E (deferred): +Wiring (AAASM-3491): a subscriber can be handed to +:func:`agent_assembly.core.runtime_interceptor.build_governance_interceptor` +(``op_control=...``); the resulting interceptor calls :meth:`await_op` for the +tool call's ``op_id`` in ``check_tool_start``, so an operator terminate/pause +reaches the running tool path. ``init_assembly`` does not construct the +subscriber automatically yet — callers opt in by passing one. + +Out of scope (deferred): - Reconnection / heartbeat on stream close (caller observes via ``stream_alive`` and re-instantiates if desired). - - Auto-wiring into ``init_assembly`` / adapter ``check_action`` hooks - (separate sub-task once the adapter surface is stable). + - Automatic construction inside ``init_assembly`` (the consumer is wired into + the interceptor; auto-instantiation is a follow-up once the gateway-url and + composite-id resolution at init time is settled). """ from __future__ import annotations diff --git a/test/unit/core/test_runtime_interceptor.py b/test/unit/core/test_runtime_interceptor.py index 2df1151..092053d 100644 --- a/test/unit/core/test_runtime_interceptor.py +++ b/test/unit/core/test_runtime_interceptor.py @@ -23,7 +23,7 @@ RuntimeQueryInterceptor, build_governance_interceptor, ) -from agent_assembly.exceptions import ToolExecutionBlockedError +from agent_assembly.exceptions import OpTerminatedError, ToolExecutionBlockedError class _FakeRuntimeClient: @@ -337,3 +337,83 @@ def test_resolve_socket_path_prefers_env(monkeypatch: pytest.MonkeyPatch) -> Non def test_resolve_socket_path_default(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("AA_RUNTIME_SOCKET", raising=False) assert runtime_interceptor._resolve_runtime_socket_path("agent-001") == "/tmp/aa-runtime-agent-001.sock" + + +# ── Live op-control kill switch (AAASM-3491) ────────────────────────────────── + + +class _FakeOpControl: + """Minimal OpControlSubscriber stand-in driving await_op behavior. + + ``terminated`` op_ids raise OpTerminatedError; any other op_id records the + await and returns — modelling the real subscriber once a pause has been + resumed (it blocks on a threading.Event then returns). + """ + + def __init__(self, *, terminated: set[str] | None = None, paused: set[str] | None = None) -> None: + self._terminated = terminated or set() + # `paused` is accepted for call-site readability; the fake models the + # post-resume return for both paused and non-paused ops. + self._paused = paused or set() + self.awaited: list[str] = [] + + def await_op(self, op_id: str, **_kwargs: Any) -> None: + self.awaited.append(op_id) + if op_id in self._terminated: + raise OpTerminatedError(f"op {op_id} was terminated by the gateway", op_id=op_id) + + +def test_terminated_op_denies_before_runtime_query() -> None: + """A terminate for the call's op halts the tool — and the runtime is never + even queried (the kill switch short-circuits).""" + runtime_client = _FakeRuntimeClient("allow") + op_control = _FakeOpControl(terminated={"trace-1:span-1"}) + interceptor = RuntimeQueryInterceptor(_FakeGatewayClient(), runtime_client, "agent-001", op_control=op_control) + + result = interceptor.check_tool_start( + serialized={"name": "web_search"}, + input_str="query", + trace_id="trace-1", + span_id="span-1", + ) + + assert result["status"] == "deny" + assert "terminated" in result["reason"] + assert op_control.awaited == ["trace-1:span-1"] + # Short-circuited: the runtime query must not have run for a terminated op. + assert runtime_client.calls == [] + + +def test_paused_op_consults_await_then_proceeds() -> None: + """A paused op blocks in await_op; once it returns (resume) the tool + proceeds to the normal runtime allow.""" + runtime_client = _FakeRuntimeClient("allow") + op_control = _FakeOpControl(paused={"trace-2:span-2"}) + interceptor = RuntimeQueryInterceptor(_FakeGatewayClient(), runtime_client, "agent-001", op_control=op_control) + + result = interceptor.check_tool_start(serialized={"name": "t"}, input_str="i", op_id="trace-2:span-2") + + assert result == {"status": "allow"} + assert op_control.awaited == ["trace-2:span-2"] + assert len(runtime_client.calls) == 1 + + +def test_no_op_id_skips_op_control() -> None: + """Without a trace identity there is no op to address — the subscriber is + never consulted and the call proceeds normally.""" + op_control = _FakeOpControl(terminated={"trace-x:span-x"}) + interceptor = RuntimeQueryInterceptor( + _FakeGatewayClient(), _FakeRuntimeClient("allow"), "agent-001", op_control=op_control + ) + + result = interceptor.check_tool_start(serialized={"name": "t"}, input_str="i") + + assert result == {"status": "allow"} + assert op_control.awaited == [] + + +def test_op_id_composed_from_trace_and_span() -> None: + assert runtime_interceptor._extract_op_id({"trace_id": "t", "span_id": "s"}) == "t:s" + assert runtime_interceptor._extract_op_id({"op_id": "explicit"}) == "explicit" + assert runtime_interceptor._extract_op_id({"trace_id": "t"}) == "t:" + assert runtime_interceptor._extract_op_id({}) is None