Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions agent_assembly/core/runtime_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import os
from typing import Any

from agent_assembly.exceptions import OpTerminatedError

ENV_RUNTIME_SOCKET = "AA_RUNTIME_SOCKET"
ACTION_TYPE_TOOL_CALL = "tool_call"
ENFORCE_MODE = "enforce"
Expand Down Expand Up @@ -89,6 +91,25 @@ def _extract_tool_name(serialized: Any, kwargs: dict[str, Any]) -> str | None:
return None


def _extract_op_id(kwargs: dict[str, Any]) -> str | None:
"""Resolve the op id ("{trace_id}:{span_id}") from a check_tool_start call.

Prefers an explicit ``op_id`` kwarg; otherwise composes it from
``trace_id`` / ``span_id`` when an adapter supplies them. Returns ``None``
when no trace identity is present (the call is not part of a tracked op, so
there is nothing for the kill switch to address).
"""
op_id = kwargs.get("op_id")
if isinstance(op_id, str) and op_id:
return op_id
trace_id = kwargs.get("trace_id")
if isinstance(trace_id, str) and trace_id:
span_id = kwargs.get("span_id")
span = span_id if isinstance(span_id, str) else ""
return f"{trace_id}:{span}"
return None


def _extract_tool_args_json(input_str: Any, kwargs: dict[str, Any]) -> str | None:
"""Serialize the tool arguments to JSON for the native ``query_policy``.

Expand Down Expand Up @@ -121,11 +142,25 @@ class RuntimeQueryInterceptor:
proceed (fail open), preserving the observe / disabled behavior.
"""

def __init__(self, client: Any, runtime_client: Any, agent_id: str, *, enforce: bool = False) -> None:
def __init__(
self,
client: Any,
runtime_client: Any,
agent_id: str,
*,
enforce: bool = False,
op_control: Any | None = None,
) -> None:
self._client = client
self._runtime_client = runtime_client
self._agent_id = agent_id
self._enforce = enforce
# Optional live op-control consumer (AAASM-3491). When wired, the
# gateway's kill switch is honored *in this tool path*: a terminate
# fast-fails the call and a pause blocks it cooperatively before the
# runtime is even queried. Without it, op control only reaches the agent
# via the native runtime's own OpControlStream consumer.
self._op_control = op_control

def __getattr__(self, name: str) -> Any:
# Delegate anything not defined here (e.g. report_event, on_tool_end,
Expand All @@ -152,7 +187,17 @@ def check_tool_start(
* A raising ``query_policy`` or an error-sentinel ``decision``
(``query_failed`` / ``channel_closed`` / ``shutdown``) → ``deny`` under
``enforce`` (fail closed, AAASM-3106), else ``allow`` (fail open).

Before any of the above, the live op-control kill switch (AAASM-3491) is
consulted when an ``op_id`` is supplied and a subscriber is wired: a
terminated op is denied immediately and a paused op blocks here until
the gateway resumes it, so an operator terminate/pause reaches this tool
path directly rather than relying solely on the native runtime.
"""
op_block = self._check_op_control(_extract_op_id(kwargs))
if op_block is not None:
return op_block

tool_name = _extract_tool_name(serialized, kwargs)
tool_args_json = _extract_tool_args_json(input_str, kwargs)

Expand Down Expand Up @@ -187,6 +232,24 @@ def _on_query_failure(self, reason: str) -> dict[str, str]:
return {"status": "deny", "reason": reason}
return {"status": "allow"}

def _check_op_control(self, op_id: str | None) -> dict[str, str] | None:
"""Consult the live op-control kill switch for ``op_id`` (AAASM-3491).

Returns a ``deny`` status dict when the op has been terminated, ``None``
otherwise β€” including when no subscriber is wired or no ``op_id`` is
available, so the call proceeds to the normal runtime query. When the op
is *paused*, ``await_op`` blocks here until the gateway resumes (or
terminates) it; this is the cooperative-pause point on the Python tool
path.
"""
if self._op_control is None or not op_id:
return None
try:
self._op_control.await_op(op_id)
except OpTerminatedError as exc:
return {"status": "deny", "reason": str(exc)}
return None


class _FailClosedInterceptor:
"""Deny-all interceptor used under ``enforce`` when no runtime is reachable.
Expand Down Expand Up @@ -286,6 +349,7 @@ def build_governance_interceptor(
*,
runtime_client: Any | None = None,
native_available: bool | None = None,
op_control: Any | None = None,
) -> Any:
"""Return the interceptor adapters should use for pre-execution checks.

Expand Down Expand Up @@ -328,4 +392,4 @@ def build_governance_interceptor(
return _FailClosedInterceptor(client, "runtime unreachable")
return client

return RuntimeQueryInterceptor(client, runtime_client, agent_id, enforce=enforce)
return RuntimeQueryInterceptor(client, runtime_client, agent_id, enforce=enforce, op_control=op_control)
14 changes: 11 additions & 3 deletions agent_assembly/op_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@
If a signal arrives for an ``op_id`` no one is currently awaiting, it's
buffered into the per-op slot so the next ``await_op`` call sees it.

Out of scope for PR-E (deferred):
Wiring (AAASM-3491): a subscriber can be handed to
:func:`agent_assembly.core.runtime_interceptor.build_governance_interceptor`
(``op_control=...``); the resulting interceptor calls :meth:`await_op` for the
tool call's ``op_id`` in ``check_tool_start``, so an operator terminate/pause
reaches the running tool path. ``init_assembly`` does not construct the
subscriber automatically yet — callers opt in by passing one.

Out of scope (deferred):
- Reconnection / heartbeat on stream close (caller observes via
``stream_alive`` and re-instantiates if desired).
- Auto-wiring into ``init_assembly`` / adapter ``check_action`` hooks
(separate sub-task once the adapter surface is stable).
- Automatic construction inside ``init_assembly`` (the consumer is wired into
the interceptor; auto-instantiation is a follow-up once the gateway-url and
composite-id resolution at init time is settled).
"""

from __future__ import annotations
Expand Down
82 changes: 81 additions & 1 deletion test/unit/core/test_runtime_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
RuntimeQueryInterceptor,
build_governance_interceptor,
)
from agent_assembly.exceptions import ToolExecutionBlockedError
from agent_assembly.exceptions import OpTerminatedError, ToolExecutionBlockedError


class _FakeRuntimeClient:
Expand Down Expand Up @@ -337,3 +337,83 @@ def test_resolve_socket_path_prefers_env(monkeypatch: pytest.MonkeyPatch) -> Non
def test_resolve_socket_path_default(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("AA_RUNTIME_SOCKET", raising=False)
assert runtime_interceptor._resolve_runtime_socket_path("agent-001") == "/tmp/aa-runtime-agent-001.sock"


# ── Live op-control kill switch (AAASM-3491) ──────────────────────────────────


class _FakeOpControl:
"""Minimal OpControlSubscriber stand-in driving await_op behavior.

``terminated`` op_ids raise OpTerminatedError; any other op_id records the
await and returns β€” modelling the real subscriber once a pause has been
resumed (it blocks on a threading.Event then returns).
"""

def __init__(self, *, terminated: set[str] | None = None, paused: set[str] | None = None) -> None:
self._terminated = terminated or set()
# `paused` is accepted for call-site readability; the fake models the
# post-resume return for both paused and non-paused ops.
self._paused = paused or set()
self.awaited: list[str] = []

def await_op(self, op_id: str, **_kwargs: Any) -> None:
self.awaited.append(op_id)
if op_id in self._terminated:
raise OpTerminatedError(f"op {op_id} was terminated by the gateway", op_id=op_id)


def test_terminated_op_denies_before_runtime_query() -> None:
    """A terminate addressed to the call's op denies the tool call, and the
    kill switch short-circuits so the runtime is never queried."""
    runtime = _FakeRuntimeClient("allow")
    kill_switch = _FakeOpControl(terminated={"trace-1:span-1"})
    interceptor = RuntimeQueryInterceptor(
        _FakeGatewayClient(),
        runtime,
        "agent-001",
        op_control=kill_switch,
    )

    verdict = interceptor.check_tool_start(
        serialized={"name": "web_search"},
        input_str="query",
        trace_id="trace-1",
        span_id="span-1",
    )

    # The op id was composed from trace_id/span_id and handed to await_op.
    assert kill_switch.awaited == ["trace-1:span-1"]
    assert verdict["status"] == "deny"
    assert "terminated" in verdict["reason"]
    # Short-circuited: no runtime query may have happened for a terminated op.
    assert runtime.calls == []


def test_paused_op_consults_await_then_proceeds() -> None:
    """A paused op waits inside await_op; when that returns (resume) the call
    carries on to the normal runtime allow."""
    paused_op = "trace-2:span-2"
    runtime = _FakeRuntimeClient("allow")
    kill_switch = _FakeOpControl(paused={paused_op})
    interceptor = RuntimeQueryInterceptor(
        _FakeGatewayClient(),
        runtime,
        "agent-001",
        op_control=kill_switch,
    )

    verdict = interceptor.check_tool_start(
        serialized={"name": "t"},
        input_str="i",
        op_id=paused_op,
    )

    assert verdict == {"status": "allow"}
    assert kill_switch.awaited == [paused_op]
    assert len(runtime.calls) == 1


def test_no_op_id_skips_op_control() -> None:
    """With no trace identity there is no op to address, so the subscriber is
    never consulted and the call proceeds normally."""
    kill_switch = _FakeOpControl(terminated={"trace-x:span-x"})
    gateway = _FakeGatewayClient()
    runtime = _FakeRuntimeClient("allow")
    interceptor = RuntimeQueryInterceptor(gateway, runtime, "agent-001", op_control=kill_switch)

    verdict = interceptor.check_tool_start(serialized={"name": "t"}, input_str="i")

    assert verdict == {"status": "allow"}
    assert kill_switch.awaited == []


def test_op_id_composed_from_trace_and_span() -> None:
    """_extract_op_id prefers an explicit op_id, composes "trace:span"
    otherwise, and yields None when no identity is supplied."""
    extract = runtime_interceptor._extract_op_id
    expectations = (
        ({"trace_id": "t", "span_id": "s"}, "t:s"),
        ({"op_id": "explicit"}, "explicit"),
        ({"trace_id": "t"}, "t:"),
    )
    for call_kwargs, expected in expectations:
        assert extract(call_kwargs) == expected
    assert extract({}) is None