From 0fb959d53e55f46b166d3607f0f8f36adb19c550 Mon Sep 17 00:00:00 2001
From: Haiyuan Cao
Date: Thu, 19 Feb 2026 22:28:08 -0800
Subject: [PATCH] feat: Add schema auto-upgrade, tool provenance, HITL tracing,
 and span hierarchy fix to BigQuery Agent Analytics plugin
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This CL adds four enhancements to the BigQuery Agent Analytics plugin and
fixes a span hierarchy corruption bug.

- **Schema Auto-Upgrade:** Additive-only schema migration that automatically
  adds missing columns to existing BQ tables on startup. An
  `adk_schema_version` label on the table (starting at `"1"`, bumped with
  each schema change) makes the check idempotent — the diff runs at most once
  per version. Enabled by default (`auto_schema_upgrade=True`) because
  upgrades are additive-only and fail-safe. Pre-versioning tables (no label)
  are treated as outdated, diffed, and stamped. No previous schema versions
  need to be stored; the logic diffs actual columns against the current
  canonical schema. (See the configuration sketch below.)

- **Tool Provenance:** Adds `tool_origin` to TOOL_* event content,
  distinguishing six origin types — `LOCAL` (FunctionTool), `MCP` (McpTool),
  `A2A` (AgentTool wrapping RemoteA2aAgent), `SUB_AGENT` (AgentTool),
  `TRANSFER_AGENT` (TransferToAgentTool), and `UNKNOWN` (fallback) — via
  `isinstance()` checks with lazy imports to avoid circular dependencies.
  (See the classification sketch below.)

- **HITL Tracing:** Emits dedicated HITL event types
  (`HITL_CONFIRMATION_REQUEST`, `HITL_CREDENTIAL_REQUEST`,
  `HITL_INPUT_REQUEST` + `_COMPLETED` variants) for human-in-the-loop
  interactions. Detection lives in `on_event_callback` (for synthetic
  `adk_request_*` FunctionCall events emitted by the framework) and
  `on_user_message_callback` (for `adk_request_*` FunctionResponse
  completions sent by the user), not in tool callbacks — because
  `adk_request_*` names are synthetic function calls that bypass
  `before_tool_callback`/`after_tool_callback` entirely. (See the
  resume-flow sketch below.)

- **Span Hierarchy Fix (#4561):** Removes `context.attach()`/`context.detach()`
  calls from `TraceManager.push_span()`, `attach_current_span()`, and
  `pop_span()`. The plugin was injecting its spans into the shared OTel
  context, which corrupted the framework's span hierarchy when an external
  exporter (e.g. `opentelemetry-instrumentation-vertexai`) was active —
  causing `call_llm` to be re-parented under `llm_request` and parent spans
  to show shorter durations than children. The plugin now tracks
  span_id/parent_span_id via its internal contextvar stack without mutating
  the ambient OTel context. (See the isolation demo below.)

Total tests: 139 (up from 105).
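A minimal configuration sketch for the new flag (illustrative only, not part
of the diff below; the project/dataset/table ids are placeholders):

```python
from google.adk.plugins.bigquery_agent_analytics_plugin import (
    BigQueryAgentAnalyticsPlugin,
    BigQueryLoggerConfig,
)

# auto_schema_upgrade defaults to True. Setting it to False pins the table
# schema: the plugin then never calls update_table on an existing table.
config = BigQueryLoggerConfig(auto_schema_upgrade=False)

plugin = BigQueryAgentAnalyticsPlugin(
    project_id="my-project",       # placeholder ids
    dataset_id="agent_analytics",
    table_id="events",
    config=config,
)
```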
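A classification sketch for the provenance helper (illustrative;
`lookup_order` is a hypothetical local tool, and the `AgentTool.__new__`
construction mirrors the unit tests in this CL):

```python
from unittest import mock

from google.adk.plugins import bigquery_agent_analytics_plugin as bqa
from google.adk.tools.agent_tool import AgentTool
from google.adk.tools.function_tool import FunctionTool


def lookup_order(order_id: str) -> dict:
  """Hypothetical locally defined tool function."""
  return {"status": "shipped"}


# A plain FunctionTool is classified as LOCAL ...
assert bqa._get_tool_origin(FunctionTool(lookup_order)) == "LOCAL"

# ... while an AgentTool wrapping an ordinary (non-A2A) agent is SUB_AGENT.
agent_tool = AgentTool.__new__(AgentTool)
agent_tool.agent = mock.MagicMock()
assert bqa._get_tool_origin(agent_tool) == "SUB_AGENT"
```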
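A resume-flow sketch for HITL confirmation (illustrative, mirroring the
end-to-end test in this CL; `confirmation_fc_id` is a placeholder for the id
carried by the paused `adk_request_confirmation` call from the previous turn):

```python
from google.genai.types import FunctionResponse
from google.genai.types import Part

# Turn 1 pauses with a synthetic adk_request_confirmation FunctionCall;
# the plugin logs HITL_CONFIRMATION_REQUEST. In turn 2 the user answers
# with a matching FunctionResponse, which the plugin detects and logs as
# HITL_CONFIRMATION_REQUEST_COMPLETED.
confirmation_fc_id = "fc-123"  # placeholder: copy the id from the paused call
user_reply = Part(
    function_response=FunctionResponse(
        id=confirmation_fc_id,
        name="adk_request_confirmation",
        response={"confirmed": True},
    )
)
```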
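An isolation demo for the span fix (plain OTel SDK, no plugin code):
`tracer.start_span()` creates a span without making it current, so the
ambient span, and therefore the parent of any framework span started
afterwards, is unchanged:

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

tracer = TracerProvider().get_tracer("demo")

before = trace.get_current_span()
span = tracer.start_span("llm_request")    # created, but NOT made current
assert trace.get_current_span() is before  # ambient context is untouched
span.end()

# The removed code did the equivalent of
#   token = context.attach(trace.set_span_in_context(span))
# which made the plugin span the ambient parent of every span the framework
# started afterwards (e.g. call_llm), the root cause of #4561.
```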
Closes #4554 Fixes #4561 Co-Authored-By: Claude Opus 4.6 --- .../bigquery_agent_analytics_plugin.py | 289 ++++++- .../test_bigquery_agent_analytics_plugin.py | 712 ++++++++++++++++++ 2 files changed, 970 insertions(+), 31 deletions(-) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 5b0fcf55e9..dd9a5871cc 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -51,7 +51,6 @@ from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types -from opentelemetry import context from opentelemetry import trace import pyarrow as pa @@ -71,6 +70,24 @@ "google.adk.plugins.bigquery_agent_analytics", __version__ ) +# Bumped when the schema changes (1 → 2 → 3 …). Used as a table +# label for governance and to decide whether auto-upgrade should run. +_SCHEMA_VERSION = "1" +_SCHEMA_VERSION_LABEL_KEY = "adk_schema_version" + +# Human-in-the-loop (HITL) tool names that receive additional +# dedicated event types alongside the normal TOOL_* events. +_HITL_TOOL_NAMES = frozenset({ + "adk_request_credential", + "adk_request_confirmation", + "adk_request_input", +}) +_HITL_EVENT_MAP = MappingProxyType({ + "adk_request_credential": "HITL_CREDENTIAL_REQUEST", + "adk_request_confirmation": "HITL_CONFIRMATION_REQUEST", + "adk_request_input": "HITL_INPUT_REQUEST", +}) + def _safe_callback(func): """Decorator that catches and logs exceptions in plugin callbacks. @@ -132,6 +149,47 @@ def _format_content( return " | ".join(parts), truncated +def _get_tool_origin(tool: "BaseTool") -> str: + """Returns the provenance category of a tool. + + Uses lazy imports to avoid circular dependencies. + + Args: + tool: The tool instance. + + Returns: + One of LOCAL, MCP, A2A, SUB_AGENT, TRANSFER_AGENT, or UNKNOWN. + """ + # Import lazily to avoid circular dependencies. + # pylint: disable=g-import-not-at-top + from ..tools.agent_tool import AgentTool # pytype: disable=import-error + from ..tools.function_tool import FunctionTool # pytype: disable=import-error + from ..tools.transfer_to_agent_tool import TransferToAgentTool # pytype: disable=import-error + + try: + from ..tools.mcp_tool.mcp_tool import McpTool # pytype: disable=import-error + except ImportError: + McpTool = None + + try: + from ..agents.remote_a2a_agent import RemoteA2aAgent # pytype: disable=import-error + except ImportError: + RemoteA2aAgent = None + + # Order matters: TransferToAgentTool is a subclass of FunctionTool. + if McpTool is not None and isinstance(tool, McpTool): + return "MCP" + if isinstance(tool, TransferToAgentTool): + return "TRANSFER_AGENT" + if isinstance(tool, AgentTool): + if RemoteA2aAgent is not None and isinstance(tool.agent, RemoteA2aAgent): + return "A2A" + return "SUB_AGENT" + if isinstance(tool, FunctionTool): + return "LOCAL" + return "UNKNOWN" + + def _recursive_smart_truncate( obj: Any, max_len: int, seen: Optional[set[int]] = None ) -> tuple[Any, bool]: @@ -435,6 +493,11 @@ class BigQueryLoggerConfig: log_session_metadata: bool = True # Static custom tags (e.g. {"agent_role": "sales"}) custom_tags: dict[str, Any] = field(default_factory=dict) + # Automatically add new columns to existing tables when the plugin + # schema evolves. Only additive changes are made (columns are never + # dropped or altered). 
Safe to leave enabled; a version label on the + # table ensures the diff runs at most once per schema version. + auto_schema_upgrade: bool = True # ============================================================================== @@ -450,12 +513,17 @@ class BigQueryLoggerConfig: class _SpanRecord: """A single record on the unified span stack. - Consolidates span, token, id, ownership, and timing into one object + Consolidates span, id, ownership, and timing into one object so all stacks stay in sync by construction. + + Note: The plugin intentionally does NOT attach its spans to the + ambient OTel context (no ``context.attach``). This prevents the + plugin from corrupting the framework's span hierarchy when an + external OTel exporter (e.g. ``opentelemetry-instrumentation-vertexai``) + is active. See https://github.com/google/adk-python/issues/4561. """ span: trace.Span - token: Any # opentelemetry context token span_id: str owns_span: bool start_time_ns: int @@ -513,17 +581,26 @@ def get_trace_id(callback_context: CallbackContext) -> Optional[str]: @staticmethod def push_span( - callback_context: CallbackContext, span_name: Optional[str] = "adk-span" + callback_context: CallbackContext, + span_name: Optional[str] = "adk-span", ) -> str: """Starts a new span and pushes it onto the stack. - If OTel is not configured (returning non-recording spans), a UUID fallback - is generated to ensure span_id and parent_span_id are populated in logs. + The span is created but NOT attached to the ambient OTel context, + so it cannot corrupt the framework's own span hierarchy. The + plugin tracks span_id / parent_span_id internally via its own + contextvar stack. + + If OTel is not configured (returning non-recording spans), a UUID + fallback is generated to ensure span_id and parent_span_id are + populated in BigQuery logs. """ TraceManager.init_trace(callback_context) + # Create the span without attaching it to the ambient context. + # This avoids re-parenting framework spans like ``call_llm`` + # or ``execute_tool``. See #4561. span = tracer.start_span(span_name) - token = context.attach(trace.set_span_in_context(span)) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -532,7 +609,6 @@ def push_span( record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=True, start_time_ns=time.time_ns(), @@ -548,11 +624,14 @@ def push_span( def attach_current_span( callback_context: CallbackContext, ) -> str: - """Attaches the current OTEL span to the stack without owning it.""" + """Records the current OTel span on the stack without owning it. + + The span is NOT re-attached to the ambient context; it is only + tracked internally for span_id / parent_span_id resolution. + """ TraceManager.init_trace(callback_context) span = trace.get_current_span() - token = context.attach(trace.set_span_in_context(span)) if span.get_span_context().is_valid: span_id_str = format(span.get_span_context().span_id, "016x") @@ -561,7 +640,6 @@ def attach_current_span( record = _SpanRecord( span=span, - token=token, span_id=span_id_str, owns_span=False, start_time_ns=time.time_ns(), @@ -575,7 +653,11 @@ def attach_current_span( @staticmethod def pop_span() -> tuple[Optional[str], Optional[int]]: - """Ends the current span and pops it from the stack.""" + """Ends the current span and pops it from the stack. + + No ambient OTel context is detached because we never attached + one in the first place (see ``push_span``). 
+ """ records = _span_records_ctx.get() if not records: return None, None @@ -595,8 +677,6 @@ def pop_span() -> tuple[Optional[str], Optional[int]]: if record.owns_span: record.span.end() - context.detach(record.token) - return record.span_id, duration_ms @staticmethod @@ -1822,16 +1902,25 @@ def _atexit_cleanup(batch_processor: "BatchProcessor") -> None: ) def _ensure_schema_exists(self) -> None: - """Ensures the BigQuery table exists with the correct schema.""" + """Ensures the BigQuery table exists with the correct schema. + + When ``config.auto_schema_upgrade`` is True and the table already + exists, missing columns are added automatically (additive only). + A ``adk_schema_version`` label is written for governance. + """ try: - self.client.get_table(self.full_table_id) + existing_table = self.client.get_table(self.full_table_id) + if self.config.auto_schema_upgrade: + self._maybe_upgrade_schema(existing_table) except cloud_exceptions.NotFound: logger.info("Table %s not found, creating table.", self.full_table_id) tbl = bigquery.Table(self.full_table_id, schema=self._schema) tbl.time_partitioning = bigquery.TimePartitioning( - type_=bigquery.TimePartitioningType.DAY, field="timestamp" + type_=bigquery.TimePartitioningType.DAY, + field="timestamp", ) tbl.clustering_fields = self.config.clustering_fields + tbl.labels = {_SCHEMA_VERSION_LABEL_KEY: _SCHEMA_VERSION} try: self.client.create_table(tbl) except cloud_exceptions.Conflict: @@ -1851,6 +1940,50 @@ def _ensure_schema_exists(self) -> None: exc_info=True, ) + def _maybe_upgrade_schema(self, existing_table: bigquery.Table) -> None: + """Adds missing columns to an existing table (additive only). + + Args: + existing_table: The current BigQuery table object. + """ + stored_version = (existing_table.labels or {}).get( + _SCHEMA_VERSION_LABEL_KEY + ) + if stored_version == _SCHEMA_VERSION: + return + + existing_names = {f.name for f in existing_table.schema} + new_fields = [f for f in self._schema if f.name not in existing_names] + + updated = False + if new_fields: + merged = list(existing_table.schema) + new_fields + existing_table.schema = merged + updated = True + logger.info( + "Auto-upgrading table %s: adding columns %s", + self.full_table_id, + [f.name for f in new_fields], + ) + + # Always stamp the version label so we skip on next run. + labels = dict(existing_table.labels or {}) + labels[_SCHEMA_VERSION_LABEL_KEY] = _SCHEMA_VERSION + existing_table.labels = labels + updated = True + + if updated: + try: + update_fields = ["schema", "labels"] + self.client.update_table(existing_table, update_fields) + except Exception as e: + logger.error( + "Schema auto-upgrade failed for %s: %s", + self.full_table_id, + e, + exc_info=True, + ) + async def shutdown(self, timeout: float | None = None) -> None: """Shuts down the plugin and releases resources. @@ -2123,16 +2256,42 @@ async def on_user_message_callback( ) -> None: """Parity with V1: Logs USER_MESSAGE_RECEIVED event. + Also detects HITL completion responses (user-sent + ``FunctionResponse`` parts with ``adk_request_*`` names) and emits + dedicated ``HITL_*_COMPLETED`` events. + Args: invocation_context: The context of the current invocation. user_message: The message content received from the user. """ + callback_ctx = CallbackContext(invocation_context) await self._log_event( "USER_MESSAGE_RECEIVED", - CallbackContext(invocation_context), + callback_ctx, raw_content=user_message, ) + # Detect HITL completion responses in the user message. 
+ if user_message and user_message.parts: + for part in user_message.parts: + if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + @_safe_callback async def on_event_callback( self, @@ -2140,24 +2299,76 @@ async def on_event_callback( invocation_context: InvocationContext, event: "Event", ) -> None: - """Logs state changes from events to BigQuery. + """Logs state changes and HITL events from the event stream. - Checks each event for a non-empty state_delta and logs it as a - STATE_DELTA event. This captures state changes from all sources - (tools, agents, LLM, manual), not just tool callbacks. + - Checks each event for a non-empty state_delta and logs it as a + STATE_DELTA event. + - Detects synthetic ``adk_request_*`` function calls (HITL pause + events) and their corresponding function responses (HITL + completions) and emits dedicated HITL event types. + + The HITL detection must happen here (not in tool callbacks) because + ``adk_request_credential``, ``adk_request_confirmation``, and + ``adk_request_input`` are synthetic function calls injected by the + framework — they never go through ``before_tool_callback`` / + ``after_tool_callback``. Args: invocation_context: The context for the current invocation. event: The event raised by the runner. """ + callback_ctx = CallbackContext(invocation_context) + + # --- State delta logging --- if event.actions and event.actions.state_delta: await self._log_event( "STATE_DELTA", - CallbackContext(invocation_context), + callback_ctx, event_data=EventData( extra_attributes={"state_delta": dict(event.actions.state_delta)} ), ) + + # --- HITL event logging --- + if event.content and event.content.parts: + for part in event.content.parts: + # Detect HITL function calls (request events). + if part.function_call: + hitl_event = _HITL_EVENT_MAP.get(part.function_call.name) + if hitl_event: + args_truncated, is_truncated = _recursive_smart_truncate( + part.function_call.args or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_call.name, + "args": args_truncated, + } + await self._log_event( + hitl_event, + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + # Detect HITL function responses (completion events). 
+ if part.function_response: + hitl_event = _HITL_EVENT_MAP.get(part.function_response.name) + if hitl_event: + resp_truncated, is_truncated = _recursive_smart_truncate( + part.function_response.response or {}, + self.config.max_content_length, + ) + content_dict = { + "tool": part.function_response.name, + "result": resp_truncated, + } + await self._log_event( + hitl_event + "_COMPLETED", + callback_ctx, + raw_content=content_dict, + is_truncated=is_truncated, + ) + return None async def on_state_change_callback( @@ -2460,7 +2671,12 @@ async def before_tool_callback( args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } TraceManager.push_span(tool_context, "tool") await self._log_event( "TOOL_STARTING", @@ -2489,20 +2705,26 @@ async def after_tool_callback( resp_truncated, is_truncated = _recursive_smart_truncate( result, self.config.max_content_length ) - content_dict = {"tool": tool.name, "result": resp_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "result": resp_truncated, + "tool_origin": tool_origin, + } span_id, duration = TraceManager.pop_span() parent_span_id, _ = TraceManager.get_current_span_and_parent() + event_data = EventData( + latency_ms=duration, + span_id_override=span_id, + parent_span_id_override=parent_span_id, + ) await self._log_event( "TOOL_COMPLETED", tool_context, raw_content=content_dict, is_truncated=is_truncated, - event_data=EventData( - latency_ms=duration, - span_id_override=span_id, - parent_span_id_override=parent_span_id, - ), + event_data=event_data, ) @_safe_callback @@ -2525,7 +2747,12 @@ async def on_tool_error_callback( args_truncated, is_truncated = _recursive_smart_truncate( tool_args, self.config.max_content_length ) - content_dict = {"tool": tool.name, "args": args_truncated} + tool_origin = _get_tool_origin(tool) + content_dict = { + "tool": tool.name, + "args": args_truncated, + "tool_origin": tool_origin, + } _, duration = TraceManager.pop_span() await self._log_event( "TOOL_ERROR", diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index e9f617c400..d3618fb94a 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -3949,3 +3949,715 @@ async def test_multi_turn_multi_subagent_full_sequence( # All rows share the same session for row in rows: assert row["session_id"] == "session-multi" + + +class TestSchemaAutoUpgrade: + """Tests for _ensure_schema_exists with auto_schema_upgrade.""" + + def _make_plugin(self, auto_schema_upgrade=False): + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig( + auto_schema_upgrade=auto_schema_upgrade, + ) + with mock.patch("google.cloud.bigquery.Client"): + plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + config=config, + ) + plugin.client = mock.MagicMock() + plugin.full_table_id = f"{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}" + plugin._schema = bigquery_agent_analytics_plugin._get_events_schema() + return plugin + + def test_create_table_sets_version_label(self): + """New tables get the schema version label.""" + plugin = self._make_plugin() 
+ plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin._ensure_schema_exists() + plugin.client.create_table.assert_called_once() + tbl = plugin.client.create_table.call_args[0][0] + assert ( + tbl.labels[bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_no_upgrade_when_disabled(self): + """Auto-upgrade disabled: existing table is not modified.""" + plugin = self._make_plugin(auto_schema_upgrade=False) + existing = mock.MagicMock() + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_adds_missing_columns(self): + """Auto-upgrade adds columns missing from existing table.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {"other": "label"} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = {f.name for f in updated_table.schema} + assert "event_type" in updated_names + assert "agent" in updated_names + assert "content" in updated_names + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_skip_upgrade_when_version_matches(self): + """No update when stored version matches current.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = plugin._schema + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_upgrade_error_is_logged_not_raised(self): + """Schema upgrade errors are logged, not propagated.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin.client.update_table.side_effect = Exception("boom") + # Should not raise + plugin._ensure_schema_exists() + + def test_upgrade_preserves_existing_columns(self): + """Existing columns are never dropped or altered during upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + # Simulate a table with a subset of canonical columns plus a + # user-added custom column that is NOT in the canonical schema. + custom_field = bigquery.SchemaField("my_custom_col", "STRING") + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + custom_field, + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + updated_table = plugin.client.update_table.call_args[0][0] + updated_names = [f.name for f in updated_table.schema] + # Original columns are still present and in original order. 
+ assert updated_names[0] == "timestamp" + assert updated_names[1] == "event_type" + assert updated_names[2] == "my_custom_col" + # New canonical columns were appended after existing ones. + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_from_no_label_treats_as_outdated(self): + """A table with no version label is treated as needing upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = list(plugin._schema) # All columns present + existing.labels = {} # No version label + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + # update_table should be called to stamp the version label even + # though no new columns were needed. + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + + def test_upgrade_from_older_version_label(self): + """A table with an older version label triggers upgrade.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("event_type", "STRING"), + ] + # Simulate a table stamped with an older version. + existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: "0", + } + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + plugin.client.update_table.assert_called_once() + updated_table = plugin.client.update_table.call_args[0][0] + # Version label should be updated to current. + assert ( + updated_table.labels[ + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY + ] + == bigquery_agent_analytics_plugin._SCHEMA_VERSION + ) + # Missing columns should have been added. + updated_names = {f.name for f in updated_table.schema} + assert "agent" in updated_names + assert "content" in updated_names + + def test_upgrade_is_idempotent(self): + """Calling _ensure_schema_exists twice doesn't double-update.""" + plugin = self._make_plugin(auto_schema_upgrade=True) + + # First call: table exists with old schema. + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + assert plugin.client.update_table.call_count == 1 + + # Second call: table now has current version label. 
+ existing.labels = { + bigquery_agent_analytics_plugin._SCHEMA_VERSION_LABEL_KEY: ( + bigquery_agent_analytics_plugin._SCHEMA_VERSION + ), + } + plugin.client.update_table.reset_mock() + plugin._ensure_schema_exists() + plugin.client.update_table.assert_not_called() + + def test_update_table_receives_schema_and_labels_fields(self): + """update_table is called with update_fields=['schema', 'labels'].""" + plugin = self._make_plugin(auto_schema_upgrade=True) + existing = mock.MagicMock(spec=bigquery.Table) + existing.schema = [ + bigquery.SchemaField("timestamp", "TIMESTAMP"), + ] + existing.labels = {} + plugin.client.get_table.return_value = existing + plugin._ensure_schema_exists() + + call_args = plugin.client.update_table.call_args + update_fields = call_args[0][1] + assert "schema" in update_fields + assert "labels" in update_fields + + def test_auto_schema_upgrade_defaults_to_true(self): + """Default config has auto_schema_upgrade enabled.""" + config = bigquery_agent_analytics_plugin.BigQueryLoggerConfig() + assert config.auto_schema_upgrade is True + + def test_create_table_conflict_is_ignored(self): + """Race condition (Conflict) during create_table is silently handled.""" + plugin = self._make_plugin() + plugin.client.get_table.side_effect = cloud_exceptions.NotFound("not found") + plugin.client.create_table.side_effect = cloud_exceptions.Conflict( + "already exists" + ) + # Should not raise. + plugin._ensure_schema_exists() + + +class TestToolProvenance: + """Tests for _get_tool_origin helper.""" + + def test_function_tool_returns_local(self): + from google.adk.tools.function_tool import FunctionTool + + def dummy(): + pass + + tool = FunctionTool(dummy) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "LOCAL" + + def test_agent_tool_returns_sub_agent(self): + from google.adk.tools.agent_tool import AgentTool + + agent = mock.MagicMock() + agent.name = "sub" + tool = AgentTool.__new__(AgentTool) + tool.agent = agent + tool._name = "sub" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "SUB_AGENT" + + def test_transfer_tool_returns_transfer_agent(self): + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + tool = TransferToAgentTool(agent_names=["other"]) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "TRANSFER_AGENT" + + def test_mcp_tool_returns_mcp(self): + try: + from google.adk.tools.mcp_tool.mcp_tool import McpTool + except ImportError: + pytest.skip("MCP not installed") + tool = McpTool.__new__(McpTool) + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "MCP" + + def test_a2a_agent_tool_returns_a2a(self): + from google.adk.tools.agent_tool import AgentTool + + try: + from google.adk.agents.remote_a2a_agent import RemoteA2aAgent + except ImportError: + pytest.skip("A2A agent not available") + + remote_agent = mock.MagicMock(spec=RemoteA2aAgent) + remote_agent.name = "remote" + remote_agent.description = "remote a2a agent" + tool = AgentTool.__new__(AgentTool) + tool.agent = remote_agent + tool._name = "remote" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "A2A" + + def test_unknown_tool_returns_unknown(self): + tool = mock.MagicMock(spec=base_tool_lib.BaseTool) + tool.name = "mystery" + result = bigquery_agent_analytics_plugin._get_tool_origin(tool) + assert result == "UNKNOWN" + + +class TestHITLTracing: + """Tests for HITL-specific event emission via on_event_callback. 
+ + HITL events (``adk_request_credential``, ``adk_request_confirmation``, + ``adk_request_input``) are synthetic function calls injected by the + framework — they never pass through ``before_tool_callback`` / + ``after_tool_callback``. Detection therefore lives in + ``on_event_callback``, which inspects the event stream for these + function calls and their corresponding function responses. + """ + + def _make_fc_event(self, fc_name, args=None): + """Build a mock Event containing a function call.""" + event = mock.MagicMock(spec=event_lib.Event) + fc = types.FunctionCall(name=fc_name, args=args or {}) + part = types.Part(function_call=fc) + event.content = types.Content(role="model", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + def _make_fr_event(self, fr_name, response=None): + """Build a mock Event containing a function response.""" + event = mock.MagicMock(spec=event_lib.Event) + fr = types.FunctionResponse(name=fr_name, response=response or {}) + part = types.Part(function_response=fr) + event.content = types.Content(role="user", parts=[part]) + event.actions = event_actions_lib.EventActions() + return event + + @pytest.mark.asyncio + async def test_hitl_confirmation_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_confirmation", {"confirm": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_credential_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("adk_request_credential", {"auth": "oauth2"}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CREDENTIAL_REQUEST" in event_types + + @pytest.mark.asyncio + async def test_hitl_completion_emits_additional_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fr_event("adk_request_confirmation", {"confirmed": True}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + assert "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + + @pytest.mark.asyncio + async def test_regular_tool_no_hitl_event( + self, + bq_plugin_inst, + mock_write_client, + invocation_context, + dummy_arrow_schema, + ): + event = self._make_fc_event("regular_tool", {"x": 1}) + await bq_plugin_inst.on_event_callback( + invocation_context=invocation_context, event=event + ) + await asyncio.sleep(0.05) + # No HITL events should be emitted for non-HITL function calls. + # on_event_callback only logs STATE_DELTA and HITL events; a regular + # function call produces neither. 
+ assert mock_write_client.append_rows.call_count == 0 + + +# ============================================================================== +# TEST CLASS: Span Hierarchy Isolation (Issue #4561) +# ============================================================================== + + +class TestSpanHierarchyIsolation: + """Regression tests for https://github.com/google/adk-python/issues/4561. + + ``push_span()`` must NOT attach its span to the ambient OTel context. + If it does, any subsequent ``tracer.start_as_current_span()`` in the + framework (e.g. ``call_llm``, ``execute_tool``) will be incorrectly + re-parented under the plugin's span. + """ + + def test_push_span_does_not_change_ambient_context(self, callback_context): + """push_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_attach_current_span_does_not_change_ambient_context( + self, callback_context + ): + """attach_current_span must not mutate the current OTel span.""" + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.attach_current_span( + callback_context + ) + + span_after = trace.get_current_span() + assert span_after is span_before + + # Cleanup + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + def test_pop_span_does_not_change_ambient_context(self, callback_context): + """pop_span must not mutate the current OTel span.""" + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "test_span" + ) + span_before = trace.get_current_span() + + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + span_after = trace.get_current_span() + assert span_after is span_before + + def test_push_span_with_real_tracer_does_not_reparent(self, callback_context): + """With a real OTel tracer, plugin spans must not become parents + of subsequently created framework spans.""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + exporter = InMemorySpanExporter() + provider = TracerProvider() + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + provider.add_span_processor(SimpleSpanProcessor(exporter)) + framework_tracer = provider.get_tracer("test-framework") + + # Simulate: plugin pushes a span BEFORE the framework span + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "llm_request" + ) + + # Framework creates its own span via start_as_current_span + with framework_tracer.start_as_current_span("call_llm") as fw_span: + fw_context = fw_span.get_span_context() + + # Pop the plugin span + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + provider.shutdown() + + # Verify the framework span was NOT re-parented under the + # plugin's llm_request span + finished = exporter.get_finished_spans() + call_llm_spans = [s for s in finished if s.name == "call_llm"] + assert len(call_llm_spans) == 1 + fw_finished = call_llm_spans[0] + + # The framework span's parent should NOT be the plugin's + # llm_request span. With the fix, the plugin never + # attaches to the ambient context, so ``call_llm`` will + # have whatever parent existed before (None in this test). 
+ assert fw_finished.parent is None + + def test_multiple_push_pop_cycles_leave_context_clean(self, callback_context): + """Multiple push/pop cycles must not leak context changes.""" + original_span = trace.get_current_span() + + for _ in range(5): + bigquery_agent_analytics_plugin.TraceManager.push_span( + callback_context, "cycle_span" + ) + bigquery_agent_analytics_plugin.TraceManager.pop_span() + + assert trace.get_current_span() is original_span + + +# ============================================================================== +# TEST CLASS: End-to-End HITL Tracing via Runner +# ============================================================================== + + +def _hitl_my_action( + tool_context: tool_context_lib.ToolContext, +) -> dict[str, str]: + """Tool function used by HITL end-to-end tests.""" + return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"} + + +class TestHITLTracingEndToEnd: + """End-to-end tests that run the full Runner + Plugin pipeline with + ``FunctionTool(require_confirmation=True)`` and verify that HITL events + are logged alongside normal TOOL_* events in the BQ analytics plugin. + """ + + @pytest.fixture + def _mock_bq_infra( + self, + mock_auth_default, + mock_bq_client, + mock_write_client, + mock_to_arrow_schema, + mock_asyncio_to_thread, + ): + """Bundle all BQ mocking fixtures.""" + yield mock_write_client + + @pytest.mark.asyncio + async def test_confirmation_flow_emits_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """Full Runner pipeline: tool with require_confirmation emits + HITL_CONFIRMATION_REQUEST and HITL_CONFIRMATION_REQUEST_COMPLETED. + """ + from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import FunctionResponse + from google.genai.types import Part + + from .. 
import testing_utils + + mock_write_client = _mock_bq_infra + + tool = FunctionTool(func=_hitl_my_action, require_confirmation=True) + + # -- Mock LLM: first response calls the tool, second is final text -- + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[Part(text="Done, action confirmed.")] + ) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + # -- Build the plugin -- + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + # -- Build agent + runner WITH the plugin -- + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="hitl_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + # -- Turn 1: user query → LLM calls tool → HITL pause -- + events_turn1 = await runner.run_async( + testing_utils.UserContent("run my_action") + ) + + # Find the adk_request_confirmation function call + confirmation_fc_id = None + for ev in events_turn1: + if ev.content and ev.content.parts: + for part in ev.content.parts: + if ( + hasattr(part, "function_call") + and part.function_call + and part.function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + confirmation_fc_id = part.function_call.id + break + if confirmation_fc_id: + break + + assert ( + confirmation_fc_id is not None + ), "Expected adk_request_confirmation function call in turn 1" + + # -- Turn 2: user sends confirmation → tool re-executes -- + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=confirmation_fc_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + events_turn2 = await runner.run_async(user_confirmation) + + # -- Give the async BQ writer a moment to flush -- + await asyncio.sleep(0.2) + + # -- Collect all BQ rows -- + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # -- Verify standard events are present -- + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # -- Verify HITL-specific events are present -- + assert ( + "HITL_CONFIRMATION_REQUEST" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST in {event_types}" + assert ( + "HITL_CONFIRMATION_REQUEST_COMPLETED" in event_types + ), f"Expected HITL_CONFIRMATION_REQUEST_COMPLETED in {event_types}" + + # -- Verify HITL events have correct tool name in content -- + hitl_rows = [r for r in rows if r["event_type"].startswith("HITL_")] + for row in hitl_rows: + content = json.loads(row["content"]) if row["content"] else {} + assert content.get("tool") == "adk_request_confirmation", ( + "HITL event should reference 'adk_request_confirmation'," + f" got {content.get('tool')}" + ) + + await bq_plugin.shutdown() + + @pytest.mark.asyncio + async def test_regular_tool_does_not_emit_hitl_events( + self, + _mock_bq_infra, + dummy_arrow_schema, + ): + """A tool WITHOUT require_confirmation should not produce HITL events.""" + from google.adk.tools.function_tool import FunctionTool + from google.genai.types import FunctionCall + from google.genai.types import Part + + from .. 
import testing_utils + + mock_write_client = _mock_bq_infra + + def regular_tool() -> str: + return "done" + + tool = FunctionTool(func=regular_tool) + + llm_responses = [ + testing_utils.LlmResponse( + content=testing_utils.ModelContent( + parts=[ + Part(function_call=FunctionCall(name=tool.name, args={})) + ] + ) + ), + testing_utils.LlmResponse( + content=testing_utils.ModelContent(parts=[Part(text="All done.")]) + ), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + bq_plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + table_id=TABLE_ID, + ) + await bq_plugin._ensure_started() + mock_write_client.append_rows.reset_mock() + + from google.adk.agents.llm_agent import LlmAgent + + agent = LlmAgent(name="regular_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent, plugins=[bq_plugin]) + + await runner.run_async(testing_utils.UserContent("run regular_tool")) + await asyncio.sleep(0.2) + + rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) + event_types = [r["event_type"] for r in rows] + + # Standard tool events should be present + assert "TOOL_STARTING" in event_types + assert "TOOL_COMPLETED" in event_types + + # No HITL events + hitl_events = [et for et in event_types if et.startswith("HITL_")] + assert ( + hitl_events == [] + ), f"Expected no HITL events for regular tool, got {hitl_events}" + + await bq_plugin.shutdown()
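
For manual spot-checking of the new HITL rows against a real table, a hedged
query sketch (placeholder project/dataset/table ids; the column names match
the schema exercised by the tests above):

```python
from google.cloud import bigquery

client = bigquery.Client(project="my-project")  # placeholder project id
sql = """
    SELECT timestamp, event_type, content
    FROM `my-project.agent_analytics.events`
    WHERE event_type LIKE 'HITL_%'
    ORDER BY timestamp
"""
for row in client.query(sql).result():
  print(row["event_type"], row["content"])
```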