From 753e5c88622cb56be4d6f70bcb5bf0f352fd7ce1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 19 May 2026 10:26:54 -0700 Subject: [PATCH] feat: Add telemetry and metrics recording capabilities The key changes include: 1. **New `Instrumentation.java` class:** This class provides a unified context manager for instrumenting agent invocations and tool executions. It uses OpenTelemetry to create trace spans, record exceptions, and manage the scope of telemetry contexts. It includes inner classes `AgentInvocation` and `ToolExecution`, both implementing `AutoCloseable` to automatically handle the lifecycle of spans and metric recording. 2. **New `Metrics.java` class:** This utility class is responsible for defining and recording various OpenTelemetry metrics (histograms) related to ADK components. These metrics cover: * Agent invocation duration, request size, response size, and workflow steps. * Tool execution duration, request size, and response size. The class uses `GlobalOpenTelemetry` to get a `Meter` and defines static histogram instances. It includes methods to record values for these metrics, often including attributes like agent name, tool name, and error type. 3. **New Unit Tests:** * `InstrumentationTest.java`: Contains unit tests for the `Instrumentation` class, verifying that spans are created correctly, contexts are managed, and metrics are recorded for both successful and error scenarios during agent invocations and tool executions. It uses `OpenTelemetryRule` for testing. * `MetricsTest.java`: Contains unit tests for the `Metrics` class, ensuring that the various static methods correctly record histogram data with the expected values and attributes. It also uses `OpenTelemetryRule`. PiperOrigin-RevId: 917904663 --- .../google/adk/telemetry/Instrumentation.java | 207 ++++++++++++++ .../com/google/adk/telemetry/Metrics.java | 258 ++++++++++++++++++ .../adk/telemetry/InstrumentationTest.java | 201 ++++++++++++++ .../com/google/adk/telemetry/MetricsTest.java | 211 ++++++++++++++ 4 files changed, 877 insertions(+) create mode 100644 core/src/main/java/com/google/adk/telemetry/Instrumentation.java create mode 100644 core/src/main/java/com/google/adk/telemetry/Metrics.java create mode 100644 core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java create mode 100644 core/src/test/java/com/google/adk/telemetry/MetricsTest.java diff --git a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java new file mode 100644 index 000000000..a2c62ba12 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java @@ -0,0 +1,207 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.telemetry; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.tools.BaseTool; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Unified context manager utility class for agent and tool execution telemetry in ADK. */ +public final class Instrumentation { + + private static final Logger logger = LoggerFactory.getLogger(Instrumentation.class); + + private Instrumentation() {} + + /** Stores all telemetry related state. */ + public static final class TelemetryContext { + private final Context otelContext; + private @Nullable Event functionResponseEvent; + + public TelemetryContext(Context otelContext) { + this.otelContext = otelContext; + } + + public Context otelContext() { + return otelContext; + } + + public @Nullable Event functionResponseEvent() { + return functionResponseEvent; + } + + public void setFunctionResponseEvent(@Nullable Event functionResponseEvent) { + this.functionResponseEvent = functionResponseEvent; + } + } + + /** Base class for AutoCloseable telemetry tracking scopes. */ + public abstract static class ClosableTelemetryScope implements AutoCloseable { + protected final long startTimeNanos; + protected final Span span; + protected final Scope scope; + protected final TelemetryContext telemetryContext; + protected @Nullable Throwable caughtError; + protected final AtomicBoolean closed = new AtomicBoolean(false); + + @SuppressWarnings("MustBeClosedChecker") + ClosableTelemetryScope(Span span) { + this.startTimeNanos = System.nanoTime(); + this.span = span; + this.scope = span.makeCurrent(); + this.telemetryContext = new TelemetryContext(Context.current()); + } + + public TelemetryContext context() { + return telemetryContext; + } + + public void setError(Throwable caughtError) { + this.caughtError = caughtError; + span.recordException(caughtError); + span.setStatus(StatusCode.ERROR, caughtError.getMessage()); + } + + @Override + public final void close() { + if (closed.getAndSet(true)) { + return; + } + try { + beforeSpanEnd(); + span.end(); + Duration elapsed = Duration.ofNanos(System.nanoTime() - startTimeNanos); + try { + recordMetrics(elapsed, caughtError); + } catch (RuntimeException e) { + handleMetricsError(e); + } + } finally { + scope.close(); + } + } + + /** Hook for subclasses to run code before span ends. */ + protected void beforeSpanEnd() {} + + /** Hook for subclasses to record metrics. */ + protected abstract void recordMetrics(Duration elapsed, @Nullable Throwable error); + + /** Hook for subclasses to handle metrics recording errors. */ + protected abstract void handleMetricsError(RuntimeException e); + } + + /** AutoCloseable telemetry tracking scope for agent invocations. */ + public static final class AgentInvocation extends ClosableTelemetryScope { + private final BaseAgent agent; + private final InvocationContext ctx; + private final List events = Collections.synchronizedList(new ArrayList<>()); + + public AgentInvocation(InvocationContext ctx, BaseAgent agent) { + super(Tracing.getTracer().spanBuilder("invoke_agent " + agent.name()).startSpan()); + this.agent = agent; + this.ctx = ctx; + Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx); + } + + public InvocationContext getCtx() { + return ctx; + } + + public void addEvent(Event event) { + events.add(event); + } + + @Override + protected void recordMetrics(Duration elapsed, @Nullable Throwable error) { + Metrics.recordAgentInvocationDuration(agent.name(), elapsed, error); + Metrics.recordAgentRequestSize(agent.name(), ctx.userContent().orElse(null)); + Metrics.recordAgentResponseSize(agent.name(), events); + Metrics.recordAgentWorkflowSteps(agent.name(), events); + } + + @Override + protected void handleMetricsError(RuntimeException e) { + logger.error("Failed to record agent metrics for agent {}", agent.name(), e); + } + } + + /** AutoCloseable telemetry tracking scope for tool executions. */ + public static final class ToolExecution extends ClosableTelemetryScope { + private final BaseTool tool; + private final BaseAgent agent; + private final Map functionArgs; + + public ToolExecution(BaseTool tool, BaseAgent agent, Map functionArgs) { + super(Tracing.getTracer().spanBuilder("execute_tool " + tool.name()).startSpan()); + this.tool = tool; + this.agent = agent; + this.functionArgs = functionArgs; + } + + @Override + protected void beforeSpanEnd() { + Event responseEvent = caughtError == null ? context().functionResponseEvent() : null; + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + responseEvent, + caughtError); + } + + @Override + protected void recordMetrics(Duration elapsed, @Nullable Throwable error) { + Metrics.recordToolExecutionDuration(tool.name(), agent.name(), elapsed, error); + Metrics.recordToolRequestSize(tool.name(), agent.name(), functionArgs); + Event responseEvent = error == null ? context().functionResponseEvent() : null; + Metrics.recordToolResponseSize(tool.name(), agent.name(), responseEvent); + } + + @Override + protected void handleMetricsError(RuntimeException e) { + logger.error("Failed to record tool execution duration for tool {}", tool.name(), e); + } + } + + /** Creates an AgentInvocation context to record agent invocation telemetry. */ + public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) { + return new AgentInvocation(ctx, agent); + } + + /** Creates a ToolExecution context to record tool execution telemetry. */ + public static ToolExecution recordToolExecution( + BaseTool tool, BaseAgent agent, Map functionArgs) { + return new ToolExecution(tool, agent, functionArgs); + } +} diff --git a/core/src/main/java/com/google/adk/telemetry/Metrics.java b/core/src/main/java/com/google/adk/telemetry/Metrics.java new file mode 100644 index 000000000..79cf80e57 --- /dev/null +++ b/core/src/main/java/com/google/adk/telemetry/Metrics.java @@ -0,0 +1,258 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.telemetry; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.metrics.DoubleHistogram; +import io.opentelemetry.api.metrics.LongHistogram; +import io.opentelemetry.api.metrics.Meter; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.jspecify.annotations.Nullable; + +/** Utility class for recording OpenTelemetry metrics within the ADK. */ +public final class Metrics { + + private static final AttributeKey GEN_AI_AGENT_NAME = + AttributeKey.stringKey("gen_ai.agent.name"); + private static final AttributeKey GEN_AI_TOOL_NAME = + AttributeKey.stringKey("gen_ai.tool.name"); + private static final AttributeKey ERROR_TYPE = AttributeKey.stringKey("error.type"); + + @SuppressWarnings("NonFinalStaticField") + private static Meter meter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); + + @SuppressWarnings("NonFinalStaticField") + private static DoubleHistogram agentInvocationDuration = + meter + .histogramBuilder("gen_ai.agent.invocation.duration") + .setUnit("ms") + .setDescription("Duration of agent invocations.") + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static DoubleHistogram toolExecutionDuration = + meter + .histogramBuilder("gen_ai.tool.execution.duration") + .setUnit("ms") + .setDescription("Duration of tool executions.") + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static LongHistogram agentRequestSize = + meter + .histogramBuilder("gen_ai.agent.request.size") + .setUnit("By") + .setDescription("Size of agent requests.") + .ofLongs() + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static LongHistogram agentResponseSize = + meter + .histogramBuilder("gen_ai.agent.response.size") + .setUnit("By") + .setDescription("Size of agent responses.") + .ofLongs() + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static LongHistogram agentWorkflowSteps = + meter + .histogramBuilder("gen_ai.agent.workflow.steps") + .setUnit("1") + .setDescription("Length of agentic workflow (# of events).") + .ofLongs() + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static LongHistogram toolRequestSize = + meter + .histogramBuilder("gen_ai.tool.request.size") + .setUnit("By") + .setDescription("Size of tool requests.") + .ofLongs() + .build(); + + @SuppressWarnings("NonFinalStaticField") + private static LongHistogram toolResponseSize = + meter + .histogramBuilder("gen_ai.tool.response.size") + .setUnit("By") + .setDescription("Size of tool responses.") + .ofLongs() + .build(); + + private Metrics() {} + + /** Sets the OpenTelemetry Meter to be used for metrics. This is for testing purposes only. */ + public static void setMeterForTesting(Meter meter) { + Metrics.meter = meter; + Metrics.agentInvocationDuration = + meter + .histogramBuilder("gen_ai.agent.invocation.duration") + .setUnit("ms") + .setDescription("Duration of agent invocations.") + .build(); + Metrics.toolExecutionDuration = + meter + .histogramBuilder("gen_ai.tool.execution.duration") + .setUnit("ms") + .setDescription("Duration of tool executions.") + .build(); + Metrics.agentRequestSize = + meter + .histogramBuilder("gen_ai.agent.request.size") + .setUnit("By") + .setDescription("Size of agent requests.") + .ofLongs() + .build(); + Metrics.agentResponseSize = + meter + .histogramBuilder("gen_ai.agent.response.size") + .setUnit("By") + .setDescription("Size of agent responses.") + .ofLongs() + .build(); + Metrics.agentWorkflowSteps = + meter + .histogramBuilder("gen_ai.agent.workflow.steps") + .setUnit("1") + .setDescription("Length of agentic workflow (# of events).") + .ofLongs() + .build(); + Metrics.toolRequestSize = + meter + .histogramBuilder("gen_ai.tool.request.size") + .setUnit("By") + .setDescription("Size of tool requests.") + .ofLongs() + .build(); + Metrics.toolResponseSize = + meter + .histogramBuilder("gen_ai.tool.response.size") + .setUnit("By") + .setDescription("Size of tool responses.") + .ofLongs() + .build(); + } + + /** Records the duration of the agent invocation. */ + public static void recordAgentInvocationDuration( + String agentName, Duration duration, @Nullable Throwable error) { + AttributesBuilder attrs = Attributes.builder().put(GEN_AI_AGENT_NAME, agentName); + if (error != null) { + attrs.put(ERROR_TYPE, error.getClass().getSimpleName()); + } + agentInvocationDuration.record((double) duration.toMillis(), attrs.build()); + } + + /** Records the size of the agent request. */ + public static void recordAgentRequestSize(String agentName, @Nullable Content userContent) { + long size = getContentSize(userContent); + Attributes attrs = Attributes.of(GEN_AI_AGENT_NAME, agentName); + agentRequestSize.record(size, attrs); + } + + /** Records the size of the agent response by extracting content from events. */ + public static void recordAgentResponseSize(String agentName, @Nullable List events) { + Content responseContent = null; + if (events != null) { + for (int i = events.size() - 1; i >= 0; i--) { + Event event = events.get(i); + if (agentName.equals(event.author()) && event.content().isPresent()) { + responseContent = event.content().get(); + break; + } + } + } + long size = getContentSize(responseContent); + Attributes attrs = Attributes.of(GEN_AI_AGENT_NAME, agentName); + agentResponseSize.record(size, attrs); + } + + /** Records the number of steps in the agent workflow by counting the number of events. */ + public static void recordAgentWorkflowSteps(String agentName, List events) { + Attributes attrs = Attributes.of(GEN_AI_AGENT_NAME, agentName); + long count = + events.stream() + .map(e -> e.stream().filter(event -> agentName.equals(event.author())).count()) + .orElse(0L); + agentWorkflowSteps.record(count, attrs); + } + + /** Records the duration of the tool execution. */ + public static void recordToolExecutionDuration( + String toolName, String agentName, Duration duration, @Nullable Throwable error) { + AttributesBuilder attrs = + Attributes.builder().put(GEN_AI_AGENT_NAME, agentName).put(GEN_AI_TOOL_NAME, toolName); + if (error != null) { + attrs.put(ERROR_TYPE, error.getClass().getSimpleName()); + } + toolExecutionDuration.record((double) duration.toMillis(), attrs.build()); + } + + /** Records the size of the tool request. */ + public static void recordToolRequestSize( + String toolName, String agentName, Map functionArgs) { + long size = 0; + for (Object value : functionArgs.values()) { + if (value instanceof String s) { + size += s.getBytes(UTF_8).length; + } + } + Attributes attrs = Attributes.of(GEN_AI_TOOL_NAME, toolName, GEN_AI_AGENT_NAME, agentName); + toolRequestSize.record(size, attrs); + } + + /** Records the size of the tool response. */ + public static void recordToolResponseSize( + String toolName, String agentName, @Nullable Event responseEvent) { + long size = 0; + if (responseEvent != null) { + size = getContentSize(responseEvent.content().orElse(null)); + } + Attributes attrs = Attributes.of(GEN_AI_TOOL_NAME, toolName, GEN_AI_AGENT_NAME, agentName); + toolResponseSize.record(size, attrs); + } + + private static long getContentSize(@Nullable Content content) { + if (content == null) { + return 0; + } + long size = 0; + for (Part part : content.parts().orElse(ImmutableList.of())) { + size += part.text().map(s -> s.getBytes(UTF_8).length).orElse(0); + size += + part.inlineData() + .flatMap(inlineData -> inlineData.data()) + .map(data -> data.length) + .orElse(0); + } + return size; + } +} diff --git a/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java b/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java new file mode 100644 index 000000000..d99a6878e --- /dev/null +++ b/core/src/test/java/com/google/adk/telemetry/InstrumentationTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.telemetry; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.sessions.SessionKey; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.sdk.metrics.data.HistogramPointData; +import io.opentelemetry.sdk.metrics.data.MetricData; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.opentelemetry.sdk.trace.data.SpanData; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.util.List; +import java.util.Map; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class InstrumentationTest { + + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + + private Tracer originalTracer; + private Meter originalMeter; + private TestAgent testAgent; + private InvocationContext invocationContext; + + private static class TestAgent extends BaseAgent { + TestAgent() { + super("my-agent", "my-agent-description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext context) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext context) { + return Flowable.empty(); + } + } + + private static class TestTool extends BaseTool { + TestTool() { + super("my-tool", "my-tool-description"); + } + + @Override + public Single> runAsync(Map args, ToolContext context) { + return Single.just(args); + } + } + + @Before + public void setup() { + this.originalTracer = Tracing.getTracer(); + this.originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); + Tracing.setTracerForTesting( + openTelemetryRule.getOpenTelemetry().getTracer("InstrumentationTest")); + Metrics.setMeterForTesting( + openTelemetryRule.getOpenTelemetry().getMeter("InstrumentationTest")); + + testAgent = new TestAgent(); + + SessionKey sessionKey = new SessionKey("test-app", "test-user", "test-session"); + Session session = Session.builder(sessionKey).events(ImmutableList.of()).build(); + + invocationContext = + InvocationContext.builder() + .sessionService(new InMemorySessionService()) + .session(session) + .agent(testAgent) + .invocationId("test-invocation-id") + .build(); + } + + @After + public void tearDown() { + Tracing.setTracerForTesting(originalTracer); + Metrics.setMeterForTesting(originalMeter); + } + + @Test + public void recordAgentInvocation_success() { + try (Instrumentation.AgentInvocation invocation = + Instrumentation.recordAgentInvocation(invocationContext, testAgent)) { + assertThat(invocation.context()).isNotNull(); + assertThat(invocation.context().otelContext()).isNotNull(); + } + + // Verify trace span + List spans = openTelemetryRule.getSpans(); + assertThat(spans).hasSize(1); + SpanData span = spans.get(0); + assertThat(span.getName()).isEqualTo("invoke_agent my-agent"); + assertThat(span.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + + // Verify metrics + MetricData metric = findMetricByName("gen_ai.agent.invocation.duration"); + List points = + (List) metric.getHistogramData().getPoints(); + assertThat(points).hasSize(1); + HistogramPointData point = points.get(0); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + } + + @Test + public void recordAgentInvocation_withError() { + RuntimeException testException = new RuntimeException("test error"); + try (Instrumentation.AgentInvocation invocation = + Instrumentation.recordAgentInvocation(invocationContext, testAgent)) { + invocation.setError(testException); + } + + List spans = openTelemetryRule.getSpans(); + assertThat(spans).hasSize(1); + SpanData span = spans.get(0); + assertThat(span.getName()).isEqualTo("invoke_agent my-agent"); + + MetricData metric = findMetricByName("gen_ai.agent.invocation.duration"); + HistogramPointData point = metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getAttributes().get(AttributeKey.stringKey("error.type"))) + .isEqualTo("RuntimeException"); + } + + @Test + public void recordToolExecution_success() { + TestTool testTool = new TestTool(); + + try (Instrumentation.ToolExecution execution = + Instrumentation.recordToolExecution( + testTool, testAgent, ImmutableMap.of("arg1", "value1"))) { + assertThat(execution.context()).isNotNull(); + } + + List spans = openTelemetryRule.getSpans(); + assertThat(spans).hasSize(1); + SpanData span = spans.get(0); + assertThat(span.getName()).isEqualTo("execute_tool my-tool"); + Attributes attrs = span.getAttributes(); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))).isEqualTo("my-tool"); + + MetricData metric = findMetricByName("gen_ai.tool.execution.duration"); + HistogramPointData point = metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))) + .isEqualTo("my-tool"); + + metric = findMetricByName("gen_ai.tool.request.size"); + point = (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))) + .isEqualTo("my-tool"); + + metric = findMetricByName("gen_ai.tool.response.size"); + point = metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))) + .isEqualTo("my-tool"); + } + + private MetricData findMetricByName(String name) { + return openTelemetryRule.getMetrics().stream() + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Metric not found: " + name)); + } +} diff --git a/core/src/test/java/com/google/adk/telemetry/MetricsTest.java b/core/src/test/java/com/google/adk/telemetry/MetricsTest.java new file mode 100644 index 000000000..26a999008 --- /dev/null +++ b/core/src/test/java/com/google/adk/telemetry/MetricsTest.java @@ -0,0 +1,211 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.telemetry; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.adk.events.Event; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.sdk.metrics.data.HistogramPointData; +import io.opentelemetry.sdk.metrics.data.MetricData; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import java.time.Duration; +import java.util.List; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class MetricsTest { + + @Rule public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + + private Meter originalMeter; + + @Before + public void setup() { + this.originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); + Metrics.setMeterForTesting(openTelemetryRule.getOpenTelemetry().getMeter("MetricsTest")); + } + + @After + public void tearDown() { + Metrics.setMeterForTesting(originalMeter); + } + + @Test + public void recordAgentInvocationDuration_success() { + Metrics.recordAgentInvocationDuration("my-agent", Duration.ofMillis(123), null); + + MetricData metric = findMetricByName("gen_ai.agent.invocation.duration"); + assertThat(metric.getUnit()).isEqualTo("ms"); + assertThat(metric.getDescription()).isEqualTo("Duration of agent invocations."); + + List points = + (List) metric.getHistogramData().getPoints(); + assertThat(points).hasSize(1); + HistogramPointData point = points.get(0); + assertThat(point.getSum()).isEqualTo(123.0); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + } + + @Test + public void recordAgentInvocationDuration_withError() { + Metrics.recordAgentInvocationDuration( + "my-agent", Duration.ofMillis(500), new IllegalArgumentException("bad arg")); + + MetricData metric = findMetricByName("gen_ai.agent.invocation.duration"); + HistogramPointData point = metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(500.0); + Attributes attrs = point.getAttributes(); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))).isEqualTo("my-agent"); + assertThat(attrs.get(AttributeKey.stringKey("error.type"))) + .isEqualTo("IllegalArgumentException"); + } + + @Test + public void recordToolExecutionDuration_success() { + Metrics.recordToolExecutionDuration("my-tool", "my-agent", Duration.ofMillis(12), null); + + MetricData metric = findMetricByName("gen_ai.tool.execution.duration"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(12.0); + Attributes attrs = point.getAttributes(); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))).isEqualTo("my-agent"); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))).isEqualTo("my-tool"); + } + + @Test + public void recordToolExecutionDuration_withError() { + Metrics.recordToolExecutionDuration( + "my-tool", "my-agent", Duration.ofMillis(45), new NullPointerException()); + + MetricData metric = findMetricByName("gen_ai.tool.execution.duration"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(45.0); + Attributes attrs = point.getAttributes(); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.agent.name"))).isEqualTo("my-agent"); + assertThat(attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))).isEqualTo("my-tool"); + assertThat(attrs.get(AttributeKey.stringKey("error.type"))).isEqualTo("NullPointerException"); + } + + @Test + public void recordAgentRequestSize_success() { + Content userContent = + Content.builder() + .parts( + Part.fromText("hello"), + Part.builder() + .inlineData( + Blob.builder().data("world".getBytes(UTF_8)).mimeType("text/plain").build()) + .build()) + .build(); + + Metrics.recordAgentRequestSize("my-agent", userContent); + + MetricData metric = findMetricByName("gen_ai.agent.request.size"); + assertThat(metric.getUnit()).isEqualTo("By"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(10); // "hello" is 5, "world" is 5. Total 10. + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + } + + @Test + public void recordAgentResponseSize_success() { + Content responseContent = Content.fromParts(Part.fromText("response")); + Event mockEvent1 = + Event.builder().author("user").content(Content.fromParts(Part.fromText("hi"))).build(); + Event mockEvent2 = Event.builder().author("my-agent").content(responseContent).build(); + + Metrics.recordAgentResponseSize("my-agent", ImmutableList.of(mockEvent1, mockEvent2)); + + MetricData metric = findMetricByName("gen_ai.agent.response.size"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(8); // "response" is 8. + } + + @Test + public void recordAgentWorkflowSteps_success() { + Event event1 = Event.builder().author("my-agent").build(); + Event event2 = Event.builder().author("user").build(); + Event event3 = Event.builder().author("my-agent").build(); + + Metrics.recordAgentWorkflowSteps("my-agent", ImmutableList.of(event1, event2, event3)); + + MetricData metric = findMetricByName("gen_ai.agent.workflow.steps"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(2); // 2 events by "my-agent". + } + + @Test + public void recordToolRequestSize_success() { + Metrics.recordToolRequestSize("my-tool", "my-agent", ImmutableMap.of("arg1", "value1")); + + MetricData metric = findMetricByName("gen_ai.tool.request.size"); + assertThat(metric.getUnit()).isEqualTo("By"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(6); // "value1" is 6. + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))) + .isEqualTo("my-tool"); + } + + @Test + public void recordToolResponseSize_success() { + Content responseContent = Content.fromParts(Part.fromText("response")); + Event responseEvent = Event.builder().author("my-tool").content(responseContent).build(); + + Metrics.recordToolResponseSize("my-tool", "my-agent", responseEvent); + + MetricData metric = findMetricByName("gen_ai.tool.response.size"); + HistogramPointData point = + (HistogramPointData) metric.getHistogramData().getPoints().iterator().next(); + assertThat(point.getSum()).isEqualTo(8); // "response" is 8. + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("my-agent"); + assertThat(point.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))) + .isEqualTo("my-tool"); + } + + private MetricData findMetricByName(String name) { + return openTelemetryRule.getMetrics().stream() + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Metric not found: " + name)); + } +}