diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 9bbe6d7b23..906b481d16 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -168,16 +168,57 @@ def _format_content( return " | ".join(parts), truncated -def _get_tool_origin(tool: "BaseTool") -> str: +def _find_transfer_target(agent, agent_name: str): + """Find a transfer target agent by name in the accessible agent tree. + + Searches the current agent's sub-agents, parent, and peer agents + to locate the transfer target. + + Args: + agent: The current agent executing the transfer. + agent_name: The name of the transfer target to find. + + Returns: + The matching agent object, or None if not found. + """ + for sub in getattr(agent, "sub_agents", []): + if sub.name == agent_name: + return sub + parent = getattr(agent, "parent_agent", None) + if parent is not None and parent.name == agent_name: + return parent + if parent is not None: + for peer in getattr(parent, "sub_agents", []): + if peer.name == agent_name and peer.name != agent.name: + return peer + return None + + +def _get_tool_origin( + tool: "BaseTool", + tool_args: Optional[dict[str, Any]] = None, + tool_context: Optional["ToolContext"] = None, +) -> str: """Returns the provenance category of a tool. Uses lazy imports to avoid circular dependencies. + For ``TransferToAgentTool`` the classification is **call-level**: when + *tool_args* and *tool_context* are supplied the selected + ``agent_name`` is resolved against the agent tree so that transfers + to a ``RemoteA2aAgent`` are labelled ``TRANSFER_A2A`` rather than + the generic ``TRANSFER_AGENT``. + Args: tool: The tool instance. + tool_args: Optional tool arguments, used for call-level + classification of TransferToAgentTool. + tool_context: Optional tool context, used to access the agent + tree for TransferToAgentTool classification. Returns: - One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN. + One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, + TRANSFER_A2A, or UNKNOWN. """ # Import lazily to avoid circular dependencies. # pylint: disable=g-import-not-at-top @@ -199,6 +240,15 @@ def _get_tool_origin(tool: "BaseTool") -> str: if McpTool is not None and isinstance(tool, McpTool): return "MCP" if isinstance(tool, TransferToAgentTool): + if RemoteA2aAgent is not None and tool_args and tool_context: + agent_name = tool_args.get("agent_name") + if agent_name: + target = _find_transfer_target( + tool_context._invocation_context.agent, + agent_name, + ) + if target is not None and isinstance(target, RemoteA2aAgent): + return "TRANSFER_A2A" return "TRANSFER_AGENT" if isinstance(tool, AgentTool): if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent): @@ -3228,7 +3278,7 @@ async def before_tool_callback( args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - tool_origin = _get_tool_origin(tool) + tool_origin = _get_tool_origin(tool, tool_args, tool_context) content_dict = { "tool": tool.name, "args": args_truncated, @@ -3262,7 +3312,7 @@ async def after_tool_callback( resp_truncated, is_truncated = _recursive_smart_truncate( result, self.config.max_content_length ) - tool_origin = _get_tool_origin(tool) + tool_origin = _get_tool_origin(tool, tool_args, tool_context) content_dict = { "tool": tool.name, "result": resp_truncated, @@ -3307,7 +3357,7 @@ async def on_tool_error_callback( args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - tool_origin = _get_tool_origin(tool) + tool_origin = _get_tool_origin(tool, tool_args, tool_context) content_dict = { "tool": tool.name, "args": args_truncated, diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 4dd986386f..8cfcfe439e 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4486,6 +4486,228 @@ def test_transfer_tool_returns_transfer_agent(self): result = bigquery_agent_analytics_plugin._get_tool_origin(tool) assert result == "TRANSFER_AGENT" + def test_transfer_tool_without_args_returns_transfer_agent(self): + """TransferToAgentTool without tool_args falls back to TRANSFER_AGENT.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + tool = TransferToAgentTool(agent_names=["remote_a2a"]) + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, tool_args=None, tool_context=None + ) + assert result == "TRANSFER_AGENT" + + def test_transfer_to_remote_a2a_sub_agent_returns_transfer_a2a(self): + """Transfer to a RemoteA2aAgent sub-agent is classified TRANSFER_A2A.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote_a2a" + + current_agent = mock.MagicMock() + current_agent.name = "root" + current_agent.sub_agents = [remote_agent] + current_agent.parent_agent = None + + inv_ctx = mock.MagicMock() + inv_ctx.agent = current_agent + tool_context = mock.MagicMock() + tool_context._invocation_context = inv_ctx + + tool = TransferToAgentTool(agent_names=["remote_a2a"]) + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, + tool_args={"agent_name": "remote_a2a"}, + tool_context=tool_context, + ) + assert result == "TRANSFER_A2A" + + def test_transfer_to_local_sub_agent_returns_transfer_agent(self): + """Transfer to a local sub-agent is still classified TRANSFER_AGENT.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + local_agent = mock.MagicMock() + local_agent.name = "local_sub" + + current_agent = mock.MagicMock() + current_agent.name = "root" + current_agent.sub_agents = [local_agent] + current_agent.parent_agent = None + + inv_ctx = mock.MagicMock() + inv_ctx.agent = current_agent + tool_context = mock.MagicMock() + tool_context._invocation_context = inv_ctx + + tool = TransferToAgentTool(agent_names=["local_sub"]) + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, + tool_args={"agent_name": "local_sub"}, + tool_context=tool_context, + ) + assert result == "TRANSFER_AGENT" + + def test_transfer_to_a2a_peer_returns_transfer_a2a(self): + """Transfer to a RemoteA2aAgent peer is classified TRANSFER_A2A.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_peer = mock.MagicMock(spec=RemoteA2aAgent) + remote_peer.name = "remote_peer" + + current_agent = mock.MagicMock() + current_agent.name = "child" + current_agent.sub_agents = [] + + parent_agent = mock.MagicMock() + parent_agent.name = "parent" + parent_agent.sub_agents = [current_agent, remote_peer] + current_agent.parent_agent = parent_agent + + inv_ctx = mock.MagicMock() + inv_ctx.agent = current_agent + tool_context = mock.MagicMock() + tool_context._invocation_context = inv_ctx + + tool = TransferToAgentTool( + agent_names=["remote_peer"], + ) + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, + tool_args={"agent_name": "remote_peer"}, + tool_context=tool_context, + ) + assert result == "TRANSFER_A2A" + + def test_transfer_mixed_targets_classifies_per_call(self): + """A single TransferToAgentTool with mixed targets classifies per call.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote_a2a" + local_agent = mock.MagicMock() + local_agent.name = "local_sub" + + current_agent = mock.MagicMock() + current_agent.name = "root" + current_agent.sub_agents = [remote_agent, local_agent] + current_agent.parent_agent = None + + inv_ctx = mock.MagicMock() + inv_ctx.agent = current_agent + tool_context = mock.MagicMock() + tool_context._invocation_context = inv_ctx + + tool = TransferToAgentTool( + agent_names=["remote_a2a", "local_sub"], + ) + + # Transfer to remote target → TRANSFER_A2A + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, + tool_args={"agent_name": "remote_a2a"}, + tool_context=tool_context, + ) + assert result == "TRANSFER_A2A" + + # Transfer to local target → TRANSFER_AGENT + result = bigquery_agent_analytics_plugin._get_tool_origin( + tool, + tool_args={"agent_name": "local_sub"}, + tool_context=tool_context, + ) + assert result == "TRANSFER_AGENT" + + @pytest.mark.asyncio + async def test_tool_error_callback_classifies_a2a_transfer( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + dummy_arrow_schema, + mock_asyncio_to_thread, + ): + """on_tool_error_callback produces TRANSFER_A2A for RemoteA2aAgent.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote_a2a" + + mock_agent = mock.MagicMock(spec=base_agent.BaseAgent) + mock_agent.name = "root" + mock_agent.instruction = "" + mock_agent.sub_agents = [remote_agent] + mock_agent.parent_agent = None + + mock_s = mock.create_autospec( + session_lib.Session, instance=True, spec_set=True + ) + type(mock_s).id = mock.PropertyMock(return_value="sess-1") + type(mock_s).user_id = mock.PropertyMock(return_value="user-1") + type(mock_s).app_name = mock.PropertyMock(return_value="test_app") + type(mock_s).state = mock.PropertyMock(return_value={}) + + inv_ctx = InvocationContext( + agent=mock_agent, + session=mock_s, + invocation_id="inv-err", + session_service=mock.create_autospec( + base_session_service_lib.BaseSessionService, + instance=True, + spec_set=True, + ), + plugin_manager=mock.create_autospec( + plugin_manager_lib.PluginManager, + instance=True, + spec_set=True, + ), + ) + tool_ctx = tool_context_lib.ToolContext(invocation_context=inv_ctx) + tool = TransferToAgentTool(agent_names=["remote_a2a"]) + + async with managed_plugin( + PROJECT_ID, DATASET_ID, table_id=TABLE_ID + ) as plugin: + await plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + bigquery_agent_analytics_plugin.TraceManager.push_span(tool_ctx, "tool") + await plugin.on_tool_error_callback( + tool=tool, + tool_args={"agent_name": "remote_a2a"}, + tool_context=tool_ctx, + error=RuntimeError("connection refused"), + ) + await asyncio.sleep(0.01) + + rows = await _get_captured_rows_async( + mock_write_client, dummy_arrow_schema + ) + + assert len(rows) == 1 + assert rows[0]["event_type"] == "TOOL_ERROR" + content = json.loads(rows[0]["content"]) + assert content["tool_origin"] == "TRANSFER_A2A" + def test_mcp_tool_returns_mcp(self): try: from google.adk.tools.mcp_tool.mcp_tool import McpTool