Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,57 @@ def _format_content(
return " | ".join(parts), truncated


def _get_tool_origin(tool: "BaseTool") -> str:
def _find_transfer_target(agent, agent_name: str):
"""Find a transfer target agent by name in the accessible agent tree.

Searches the current agent's sub-agents, parent, and peer agents
to locate the transfer target.

Args:
agent: The current agent executing the transfer.
agent_name: The name of the transfer target to find.

Returns:
The matching agent object, or None if not found.
"""
for sub in getattr(agent, "sub_agents", []):
if sub.name == agent_name:
return sub
parent = getattr(agent, "parent_agent", None)
if parent is not None and parent.name == agent_name:
return parent
if parent is not None:
for peer in getattr(parent, "sub_agents", []):
if peer.name == agent_name and peer.name != agent.name:
return peer
return None


def _get_tool_origin(
tool: "BaseTool",
tool_args: Optional[dict[str, Any]] = None,
tool_context: Optional["ToolContext"] = None,
) -> str:
"""Returns the provenance category of a tool.

Uses lazy imports to avoid circular dependencies.

For ``TransferToAgentTool`` the classification is **call-level**: when
*tool_args* and *tool_context* are supplied the selected
``agent_name`` is resolved against the agent tree so that transfers
to a ``RemoteA2aAgent`` are labelled ``TRANSFER_A2A`` rather than
the generic ``TRANSFER_AGENT``.

Args:
tool: The tool instance.
tool_args: Optional tool arguments, used for call-level
classification of TransferToAgentTool.
tool_context: Optional tool context, used to access the agent
tree for TransferToAgentTool classification.

Returns:
One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN.
One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT,
TRANSFER_A2A, or UNKNOWN.
"""
# Import lazily to avoid circular dependencies.
# pylint: disable=g-import-not-at-top
Expand All @@ -199,6 +240,15 @@ def _get_tool_origin(tool: "BaseTool") -> str:
if McpTool is not None and isinstance(tool, McpTool):
return "MCP"
if isinstance(tool, TransferToAgentTool):
if RemoteA2aAgent is not None and tool_args and tool_context:
agent_name = tool_args.get("agent_name")
if agent_name:
target = _find_transfer_target(
tool_context._invocation_context.agent,
agent_name,
)
if target is not None and isinstance(target, RemoteA2aAgent):
return "TRANSFER_A2A"
return "TRANSFER_AGENT"
if isinstance(tool, AgentTool):
if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent):
Expand Down Expand Up @@ -3228,7 +3278,7 @@ async def before_tool_callback(
args_truncated, is_truncated = _recursive_smart_truncate(
tool_args, self.config.max_content_length
)
tool_origin = _get_tool_origin(tool)
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
content_dict = {
"tool": tool.name,
"args": args_truncated,
Expand Down Expand Up @@ -3262,7 +3312,7 @@ async def after_tool_callback(
resp_truncated, is_truncated = _recursive_smart_truncate(
result, self.config.max_content_length
)
tool_origin = _get_tool_origin(tool)
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
content_dict = {
"tool": tool.name,
"result": resp_truncated,
Expand Down Expand Up @@ -3307,7 +3357,7 @@ async def on_tool_error_callback(
args_truncated, is_truncated = _recursive_smart_truncate(
tool_args, self.config.max_content_length
)
tool_origin = _get_tool_origin(tool)
tool_origin = _get_tool_origin(tool, tool_args, tool_context)
content_dict = {
"tool": tool.name,
"args": args_truncated,
Expand Down
222 changes: 222 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,6 +4486,228 @@ def test_transfer_tool_returns_transfer_agent(self):
result = bigquery_agent_analytics_plugin._get_tool_origin(tool)
assert result == "TRANSFER_AGENT"

def test_transfer_tool_without_args_returns_transfer_agent(self):
"""TransferToAgentTool without tool_args falls back to TRANSFER_AGENT."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

tool = TransferToAgentTool(agent_names=["remote_a2a"])
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool, tool_args=None, tool_context=None
)
assert result == "TRANSFER_AGENT"

def test_transfer_to_remote_a2a_sub_agent_returns_transfer_a2a(self):
"""Transfer to a RemoteA2aAgent sub-agent is classified TRANSFER_A2A."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

try:
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
except ImportError:
pytest.skip("A2A agent not available")

remote_agent = mock.MagicMock(spec=RemoteA2aAgent)
remote_agent.name = "remote_a2a"

current_agent = mock.MagicMock()
current_agent.name = "root"
current_agent.sub_agents = [remote_agent]
current_agent.parent_agent = None

inv_ctx = mock.MagicMock()
inv_ctx.agent = current_agent
tool_context = mock.MagicMock()
tool_context._invocation_context = inv_ctx

tool = TransferToAgentTool(agent_names=["remote_a2a"])
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool,
tool_args={"agent_name": "remote_a2a"},
tool_context=tool_context,
)
assert result == "TRANSFER_A2A"

def test_transfer_to_local_sub_agent_returns_transfer_agent(self):
"""Transfer to a local sub-agent is still classified TRANSFER_AGENT."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

local_agent = mock.MagicMock()
local_agent.name = "local_sub"

current_agent = mock.MagicMock()
current_agent.name = "root"
current_agent.sub_agents = [local_agent]
current_agent.parent_agent = None

inv_ctx = mock.MagicMock()
inv_ctx.agent = current_agent
tool_context = mock.MagicMock()
tool_context._invocation_context = inv_ctx

tool = TransferToAgentTool(agent_names=["local_sub"])
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool,
tool_args={"agent_name": "local_sub"},
tool_context=tool_context,
)
assert result == "TRANSFER_AGENT"

def test_transfer_to_a2a_peer_returns_transfer_a2a(self):
"""Transfer to a RemoteA2aAgent peer is classified TRANSFER_A2A."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

try:
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
except ImportError:
pytest.skip("A2A agent not available")

remote_peer = mock.MagicMock(spec=RemoteA2aAgent)
remote_peer.name = "remote_peer"

current_agent = mock.MagicMock()
current_agent.name = "child"
current_agent.sub_agents = []

parent_agent = mock.MagicMock()
parent_agent.name = "parent"
parent_agent.sub_agents = [current_agent, remote_peer]
current_agent.parent_agent = parent_agent

inv_ctx = mock.MagicMock()
inv_ctx.agent = current_agent
tool_context = mock.MagicMock()
tool_context._invocation_context = inv_ctx

tool = TransferToAgentTool(
agent_names=["remote_peer"],
)
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool,
tool_args={"agent_name": "remote_peer"},
tool_context=tool_context,
)
assert result == "TRANSFER_A2A"

def test_transfer_mixed_targets_classifies_per_call(self):
"""A single TransferToAgentTool with mixed targets classifies per call."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

try:
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
except ImportError:
pytest.skip("A2A agent not available")

remote_agent = mock.MagicMock(spec=RemoteA2aAgent)
remote_agent.name = "remote_a2a"
local_agent = mock.MagicMock()
local_agent.name = "local_sub"

current_agent = mock.MagicMock()
current_agent.name = "root"
current_agent.sub_agents = [remote_agent, local_agent]
current_agent.parent_agent = None

inv_ctx = mock.MagicMock()
inv_ctx.agent = current_agent
tool_context = mock.MagicMock()
tool_context._invocation_context = inv_ctx

tool = TransferToAgentTool(
agent_names=["remote_a2a", "local_sub"],
)

# Transfer to remote target → TRANSFER_A2A
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool,
tool_args={"agent_name": "remote_a2a"},
tool_context=tool_context,
)
assert result == "TRANSFER_A2A"

# Transfer to local target → TRANSFER_AGENT
result = bigquery_agent_analytics_plugin._get_tool_origin(
tool,
tool_args={"agent_name": "local_sub"},
tool_context=tool_context,
)
assert result == "TRANSFER_AGENT"

@pytest.mark.asyncio
async def test_tool_error_callback_classifies_a2a_transfer(
self,
mock_auth_default,
mock_bq_client,
mock_write_client,
mock_to_arrow_schema,
dummy_arrow_schema,
mock_asyncio_to_thread,
):
"""on_tool_error_callback produces TRANSFER_A2A for RemoteA2aAgent."""
from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool

try:
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
except ImportError:
pytest.skip("A2A agent not available")

remote_agent = mock.MagicMock(spec=RemoteA2aAgent)
remote_agent.name = "remote_a2a"

mock_agent = mock.MagicMock(spec=base_agent.BaseAgent)
mock_agent.name = "root"
mock_agent.instruction = ""
mock_agent.sub_agents = [remote_agent]
mock_agent.parent_agent = None

mock_s = mock.create_autospec(
session_lib.Session, instance=True, spec_set=True
)
type(mock_s).id = mock.PropertyMock(return_value="sess-1")
type(mock_s).user_id = mock.PropertyMock(return_value="user-1")
type(mock_s).app_name = mock.PropertyMock(return_value="test_app")
type(mock_s).state = mock.PropertyMock(return_value={})

inv_ctx = InvocationContext(
agent=mock_agent,
session=mock_s,
invocation_id="inv-err",
session_service=mock.create_autospec(
base_session_service_lib.BaseSessionService,
instance=True,
spec_set=True,
),
plugin_manager=mock.create_autospec(
plugin_manager_lib.PluginManager,
instance=True,
spec_set=True,
),
)
tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx)
tool = TransferToAgentTool(agent_names=["remote_a2a"])

async with managed_plugin(
PROJECT_ID, DATASET_ID, table_id=TABLE_ID
) as plugin:
await plugin._ensure_started()
mock_write_client.append_rows.reset_mock()

bigquery_agent_analytics_plugin.TraceManager.push_span(tool_ctx, "tool")
await plugin.on_tool_error_callback(
tool=tool,
tool_args={"agent_name": "remote_a2a"},
tool_context=tool_ctx,
error=RuntimeError("connection refused"),
)
await asyncio.sleep(0.01)

rows = await _get_captured_rows_async(
mock_write_client, dummy_arrow_schema
)

assert len(rows) == 1
assert rows[0]["event_type"] == "TOOL_ERROR"
content = json.loads(rows[0]["content"])
assert content["tool_origin"] == "TRANSFER_A2A"

def test_mcp_tool_returns_mcp(self):
try:
from google.adk.tools.mcp_tool.mcp_tool import McpTool
Expand Down