Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions agent_assembly/adapters/crewai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import dataclass
from functools import wraps
from threading import local
from typing import Any, Literal
from typing import Any, Literal, cast

from agent_assembly.core.spawn import _SPAWN_CTX, SpawnContext, spawn_context_scope

Expand Down Expand Up @@ -227,31 +227,33 @@ def _unknown_decision(enforce: bool) -> tuple[Literal["allow", "deny", "pending"
return "allow", None


_KNOWN_STATUSES: frozenset[str] = frozenset({"allow", "deny", "pending"})


def _coerce_known_status(value: str) -> Literal["allow", "deny", "pending"] | None:
"""Return the verdict literal for a recognized status string, else ``None``."""
if value in _KNOWN_STATUSES:
return cast("Literal['allow', 'deny', 'pending']", value)
return None


def _normalize_decision(
decision: object,
*,
enforce: bool = False,
) -> tuple[Literal["allow", "deny", "pending"], str | None]:
if isinstance(decision, str):
normalized = decision.strip().lower()
if normalized == "allow":
return "allow", None
if normalized == "deny":
return "deny", None
if normalized == "pending":
return "pending", None
status = _coerce_known_status(decision.strip().lower())
if status is not None:
return status, None
return _unknown_decision(enforce)

if isinstance(decision, Mapping):
raw_status = str(decision.get("status", "")).strip().lower()
reason_value = decision.get("reason")
reason = str(reason_value) if reason_value is not None else None
if raw_status == "allow":
return "allow", reason
if raw_status == "deny":
return "deny", reason
if raw_status == "pending":
return "pending", reason
status = _coerce_known_status(str(decision.get("status", "")).strip().lower())
if status is not None:
return status, reason
unknown_status, unknown_reason = _unknown_decision(enforce)
return unknown_status, reason if reason is not None else unknown_reason

Expand Down
9 changes: 6 additions & 3 deletions agent_assembly/adapters/google_adk/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from collections.abc import Mapping
from dataclasses import dataclass
from functools import wraps
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from agent_assembly.exceptions import PolicyViolationError

from agent_assembly.adapters.crewai.patch import (
_get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds,
Expand Down Expand Up @@ -381,14 +384,14 @@ async def _record_async_tool_result(
await recorded


def _build_denied_error(tool_name: str, reason: str | None) -> Exception:
def _build_denied_error(tool_name: str, reason: str | None) -> PolicyViolationError:
from agent_assembly.exceptions import PolicyViolationError

reason_text = reason or "No reason provided."
return PolicyViolationError(f"Tool '{tool_name}' blocked by governance policy: {reason_text}")


def _build_pending_rejected_error(tool_name: str, reason: str | None) -> Exception:
def _build_pending_rejected_error(tool_name: str, reason: str | None) -> PolicyViolationError:
from agent_assembly.exceptions import PolicyViolationError

reason_text = reason or "No reason provided."
Expand Down
29 changes: 15 additions & 14 deletions agent_assembly/adapters/langchain/callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from agent_assembly.exceptions import ToolExecutionBlockedError

_KNOWN_STATUSES: frozenset[str] = frozenset({"allow", "deny", "pending"})


class _FallbackBaseCallbackHandler:
"""Fallback base type when langchain-core is not installed."""
Expand Down Expand Up @@ -60,30 +62,29 @@ def _unknown_decision(self) -> tuple[Literal["allow", "deny", "pending"], str |
return "deny", self._UNKNOWN_DECISION_REASON
return "allow", None

@staticmethod
def _coerce_known_status(value: str) -> Literal["allow", "deny", "pending"] | None:
"""Return the verdict literal for a recognized status string, else ``None``."""
if value in _KNOWN_STATUSES:
return cast("Literal['allow', 'deny', 'pending']", value)
return None

def _normalize_decision(
self,
decision: object,
) -> tuple[Literal["allow", "deny", "pending"], str | None]:
if isinstance(decision, str):
normalized = decision.strip().lower()
if normalized == "allow":
return "allow", None
if normalized == "deny":
return "deny", None
if normalized == "pending":
return "pending", None
status = self._coerce_known_status(decision.strip().lower())
if status is not None:
return status, None
return self._unknown_decision()

if isinstance(decision, Mapping):
raw_status = str(decision.get("status", "")).strip().lower()
reason_value = decision.get("reason")
reason = str(reason_value) if reason_value is not None else None
if raw_status == "allow":
return "allow", reason
if raw_status == "deny":
return "deny", reason
if raw_status == "pending":
return "pending", reason
status = self._coerce_known_status(str(decision.get("status", "")).strip().lower())
if status is not None:
return status, reason
unknown_status, unknown_reason = self._unknown_decision()
return unknown_status, reason if reason is not None else unknown_reason

Expand Down
9 changes: 2 additions & 7 deletions agent_assembly/adapters/langgraph/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ def _is_tool_node(node_executor: Any) -> bool:


def _wrap_tool_node_subgraphs(
node_name: str,
tool_node: Any,
process_agent_id: str | None,
) -> bool:
Expand All @@ -253,11 +252,10 @@ def _wrap_tool_node_subgraphs(
if not isinstance(tools_by_name, dict):
return False
wrapped_any = False
for tool_name, tool in list(tools_by_name.items()):
for tool_name, tool in tools_by_name.items():
tool_func = getattr(tool, "func", None)
if tool_func is not None and _is_compiled_subgraph(tool_func):
wrapper = _make_subgraph_spawn_wrapper(
str(tool_name),
tool_func,
process_agent_id,
spawned_by_tool=str(tool_name),
Expand All @@ -272,7 +270,6 @@ def _wrap_tool_node_subgraphs(


def _make_subgraph_spawn_wrapper(
node_name: str,
subgraph: Any,
process_agent_id: str | None,
*,
Expand Down Expand Up @@ -325,7 +322,6 @@ def _wrap_subgraph_spawn_node(node_map: Any, node_name: Any, node_executor: Any,
"""Wrap a compiled-subgraph node (spawn point) for lineage. Return True when wrapped."""
node_delegation_reason = f"langgraph_node:{node_name}"
sync_wrapper = _make_subgraph_spawn_wrapper(
str(node_name),
node_executor,
process_agent_id,
spawned_by_tool=None,
Expand All @@ -335,7 +331,6 @@ def _wrap_subgraph_spawn_node(node_map: Any, node_name: Any, node_executor: Any,
node_map[node_name] = sync_wrapper
if hasattr(node_executor, "ainvoke") and not getattr(node_executor, "_agent_assembly_ainvoke_spawned", False):
async_wrapper = _make_subgraph_spawn_wrapper(
str(node_name),
node_executor,
process_agent_id,
async_=True,
Expand Down Expand Up @@ -369,7 +364,7 @@ def _wrap_node_entry(
# ToolNode: intercept any compiled-subgraph tools it holds.
# Must come before the callable() check since ToolNode is also callable.
if _is_tool_node(node_executor):
return _wrap_tool_node_subgraphs(str(node_name), node_executor, process_agent_id)
return _wrap_tool_node_subgraphs(node_executor, process_agent_id)

if callable(node_executor):
return _wrap_callable_node_executor(node_map, node_name, node_executor, callback_handler)
Expand Down
13 changes: 7 additions & 6 deletions agent_assembly/adapters/mcp/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import importlib.util
import inspect
from dataclasses import dataclass
from typing import Any, Literal, Mapping
from typing import TYPE_CHECKING, Any, Literal, Mapping

if TYPE_CHECKING:
from agent_assembly.exceptions import MCPToolBlockedError

from agent_assembly.adapters.crewai.patch import (
_get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds,
Expand Down Expand Up @@ -224,16 +227,14 @@ def _build_blocked_error(
server_identifier: str,
reason: str | None,
is_pending_rejection: bool,
) -> Exception:
) -> MCPToolBlockedError:
from agent_assembly.exceptions import MCPToolBlockedError

reason_text = reason or "No reason provided."
if is_pending_rejection:
message = f"MCP tool '{tool_name}' on server '{server_identifier}' " f"rejected during approval: {reason_text}"
message = f"MCP tool '{tool_name}' on server '{server_identifier}' rejected during approval: {reason_text}"
else:
message = (
f"MCP tool '{tool_name}' on server '{server_identifier}' " f"blocked by governance policy: {reason_text}"
)
message = f"MCP tool '{tool_name}' on server '{server_identifier}' blocked by governance policy: {reason_text}"

return MCPToolBlockedError(
message,
Expand Down
9 changes: 6 additions & 3 deletions agent_assembly/adapters/pydantic_ai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from collections.abc import Mapping
from dataclasses import dataclass
from functools import wraps
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
from agent_assembly.exceptions import PolicyViolationError

from agent_assembly.adapters.crewai.patch import (
_get_pending_tool_approval_timeout_seconds as _resolve_pending_timeout_seconds,
Expand Down Expand Up @@ -593,14 +596,14 @@ async def _record_async_tool_result(
await recorded


def _build_denied_error(tool_name: str, reason: str | None) -> Exception:
def _build_denied_error(tool_name: str, reason: str | None) -> PolicyViolationError:
from agent_assembly.exceptions import PolicyViolationError

reason_text = reason or "No reason provided."
return PolicyViolationError(f"Tool '{tool_name}' blocked by governance policy: {reason_text}")


def _build_pending_rejected_error(tool_name: str, reason: str | None) -> Exception:
def _build_pending_rejected_error(tool_name: str, reason: str | None) -> PolicyViolationError:
from agent_assembly.exceptions import PolicyViolationError

reason_text = reason or "No reason provided."
Expand Down
10 changes: 3 additions & 7 deletions agent_assembly/cli/adapter_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
pass

from agent_assembly.adapters.base import FrameworkAdapter, GovernanceInterceptor

Expand Down Expand Up @@ -147,7 +143,7 @@ def _check_register_hooks_signature(cls: type) -> AdapterValidationResult:
return AdapterValidationResult(
check_name="register_hooks_signature",
passed=False,
message=(f"register_hooks() first parameter annotated as {annotation}, " f"expected GovernanceInterceptor."),
message=f"register_hooks() first parameter annotated as {annotation}, expected GovernanceInterceptor.",
)


Expand All @@ -162,7 +158,7 @@ def _check_unregister_hooks_idempotent(
return AdapterValidationResult(
check_name="unregister_hooks_idempotent",
passed=False,
message=(f"unregister_hooks() is not idempotent: " f"second call raised {type(exc).__name__}: {exc}"),
message=f"unregister_hooks() is not idempotent: second call raised {type(exc).__name__}: {exc}",
)
return AdapterValidationResult(
check_name="unregister_hooks_idempotent",
Expand Down Expand Up @@ -231,7 +227,7 @@ def _check_entry_point_metadata(cls: type, path_or_module: str) -> AdapterValida
return AdapterValidationResult(
check_name="entry_point_metadata",
passed=False,
message=(f"No entry point references {class_qualname}. " f"Found: {entry_points}."),
message=f"No entry point references {class_qualname}. Found: {entry_points}.",
)


Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Global options:

[mypy]
# The pydantic plugin teaches mypy that BaseModel fields with Field(default=...) /
# Field(None) defaults are optional at construction; without it every model with
# defaulted fields reports spurious "Missing named argument" call-arg errors.
plugins = pydantic.mypy
packages = agent_assembly,test
exclude = (?x)(
test/unit_test.{1,64}.py # Ignore the code of unit test because of the usage of mock
Expand Down
16 changes: 7 additions & 9 deletions test/unit/adapters/langchain/test_langgraph_spawn_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_sync_wrapper_sets_spawn_ctx_then_resets(self) -> None:
subgraph = MagicMock()
subgraph.invoke = original_invoke

wrapper = _make_subgraph_spawn_wrapper("subnode", subgraph, "parent-agent")
wrapper = _make_subgraph_spawn_wrapper(subgraph, "parent-agent")
assert _SPAWN_CTX.get() is None

result = wrapper({"input": "x"})
Expand All @@ -60,7 +60,6 @@ def test_delegation_reason_passed_through(self) -> None:
subgraph.invoke = MagicMock(side_effect=lambda *_args, **_kwargs: captured.append(_SPAWN_CTX.get()) or "r") # type: ignore[func-returns-value]

wrapper = _make_subgraph_spawn_wrapper(
"mynode",
subgraph,
"parent",
delegation_reason="langgraph_node:mynode",
Expand All @@ -78,7 +77,6 @@ def test_spawned_by_tool_passed_through_for_tool_node(self) -> None:
subgraph.invoke = MagicMock(side_effect=lambda *_args, **_kwargs: captured.append(_SPAWN_CTX.get()) or "r") # type: ignore[func-returns-value]

wrapper = _make_subgraph_spawn_wrapper(
"tool_x",
subgraph,
"parent",
spawned_by_tool="tool_x",
Expand All @@ -103,7 +101,7 @@ async def fake_ainvoke(*args: object, **kwargs: object) -> str:
subgraph.ainvoke = fake_ainvoke
subgraph.invoke = MagicMock()

wrapper = _make_subgraph_spawn_wrapper("asyncnode", subgraph, "parent-async", async_=True)
wrapper = _make_subgraph_spawn_wrapper(subgraph, "parent-async", async_=True)
result = await wrapper({"input": "y"})

assert result == "async-result"
Expand All @@ -122,7 +120,7 @@ def test_spawn_ctx_depth_increments_when_already_in_ctx(self) -> None:
outer = SpawnContext(parent_agent_id="grandparent", depth=1)
token = _SPAWN_CTX.set(outer)
try:
wrapper = _make_subgraph_spawn_wrapper("child", subgraph, "parent-agent")
wrapper = _make_subgraph_spawn_wrapper(subgraph, "parent-agent")
wrapper({})
finally:
_SPAWN_CTX.reset(token)
Expand All @@ -135,7 +133,7 @@ def test_wrapper_is_pass_through_on_exception(self) -> None:
subgraph = MagicMock()
subgraph.invoke = MagicMock(side_effect=RuntimeError("graph error"))

wrapper = _make_subgraph_spawn_wrapper("err_node", subgraph, "parent")
wrapper = _make_subgraph_spawn_wrapper(subgraph, "parent")
with pytest.raises(RuntimeError, match="graph error"):
wrapper({})
# Token must still be reset
Expand Down Expand Up @@ -231,7 +229,7 @@ def test_wraps_compiled_subgraph_tool_inside_tool_node(self) -> None:
tool_node = MagicMock(spec=["tools_by_name"])
tool_node.tools_by_name = {"search": tool}

result = _wrap_tool_node_subgraphs("tool_node", tool_node, "parent-agent")
result = _wrap_tool_node_subgraphs(tool_node, "parent-agent")

assert result is True
assert tool.func is not subgraph
Expand All @@ -245,7 +243,7 @@ def test_returns_false_when_no_compiled_subgraph_tools(self) -> None:
tool_node = MagicMock(spec=["tools_by_name"])
tool_node.tools_by_name = {"search": tool}

result = _wrap_tool_node_subgraphs("tool_node", tool_node, "parent-agent")
result = _wrap_tool_node_subgraphs(tool_node, "parent-agent")

assert result is False

Expand All @@ -260,7 +258,7 @@ def test_spawned_by_tool_and_delegation_reason_set_for_tool_node_path(self) -> N
tool_node = MagicMock(spec=["tools_by_name"])
tool_node.tools_by_name = {"retriever": tool}

_wrap_tool_node_subgraphs("tool_node", tool_node, "parent-agent")
_wrap_tool_node_subgraphs(tool_node, "parent-agent")

# Invoke the wrapped function to verify spawn context values
tool.func({})
Expand Down
Loading