diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index cbceceed2..5b154862e 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -24,7 +24,8 @@ import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; -import com.google.adk.telemetry.Tracing; +import com.google.adk.telemetry.Instrumentation; +import com.google.adk.telemetry.Instrumentation.AgentInvocation; import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -322,11 +323,13 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context parentSpanContext = Context.current(); - return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - + Context otelContext = Context.current(); + return Flowable.using( + () -> + Instrumentation.recordAgentInvocation( + createInvocationContext(parentContext), this, otelContext), + agentInvocation -> { + InvocationContext invocationContext = agentInvocation.getCtx(); Flowable mainAndAfterEvents = Flowable.defer(() -> runImplementation.apply(invocationContext)) .concatWith( @@ -350,14 +353,10 @@ private Flowable run( return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) .switchIfEmpty(mainAndAfterEvents) - .compose( - Tracing.trace("invoke_agent " + name()) - .setParent(parentSpanContext) - .configure( - span -> - Tracing.traceAgentInvocation( - span, name(), description(), invocationContext))); - }); + .doOnNext(agentInvocation::addEvent) + .doOnError(agentInvocation::setError); + }, + AgentInvocation::close); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ef7dce75a..dffba0e80 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -479,12 +479,10 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer( - () -> { - try (Scope s = spanContext.makeCurrent()) { - return nextAgent.get().runAsync(context); - } - })); + nextAgent + .get() + .runAsync(context) + .compose(Tracing.withContext(spanContext))); } return postProcessedEvents; }); @@ -666,12 +664,10 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - Flowable.defer( - () -> { - try (Scope s = spanContext.makeCurrent()) { - return nextAgent.get().runLive(invocationContext); - } - }); + nextAgent + .get() + .runLive(invocationContext) + .compose(Tracing.withContext(spanContext)); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 4aa20798d..8c60ebf76 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -29,6 +29,8 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; +import com.google.adk.telemetry.Instrumentation; +import com.google.adk.telemetry.Instrumentation.ToolExecution; import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; @@ -430,6 +432,25 @@ private static Maybe postProcessFunctionResult( ToolContext toolContext, boolean isLive, Context parentContext) { + return Maybe.using( + () -> + Instrumentation.recordToolExecution( + tool, invocationContext.agent(), functionArgs, parentContext), + toolExecution -> + processFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive) + .doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event)) + .doOnError(toolExecution::setError), + ToolExecution::close); + } + + private static Maybe processFunctionResult( + Maybe> maybeFunctionResult, + InvocationContext invocationContext, + BaseTool tool, + Map functionArgs, + ToolContext toolContext, + boolean isLive) { return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -467,20 +488,7 @@ private static Maybe postProcessFunctionResult( tool, finalFunctionResult, toolContext, invocationContext); return Maybe.just(event); }); - }) - .compose( - Tracing.trace("execute_tool [" + tool.name() + "]") - .setParent(parentContext) - .onSuccess( - (span, event) -> - Tracing.traceToolExecution( - span, - tool.name(), - tool.description(), - tool.getClass().getSimpleName(), - functionArgs, - event, - null))); + }); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java index a2c62ba12..fd27878c9 100644 --- a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java +++ b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java @@ -125,8 +125,12 @@ public static final class AgentInvocation extends ClosableTelemetryScope { private final InvocationContext ctx; private final List events = Collections.synchronizedList(new ArrayList<>()); - public AgentInvocation(InvocationContext ctx, BaseAgent agent) { - super(Tracing.getTracer().spanBuilder("invoke_agent " + agent.name()).startSpan()); + public AgentInvocation(InvocationContext ctx, BaseAgent agent, Context parentContext) { + super( + Tracing.getTracer() + .spanBuilder("invoke_agent " + agent.name()) + .setParent(parentContext) + .startSpan()); this.agent = agent; this.ctx = ctx; Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx); @@ -160,8 +164,13 @@ public static final class ToolExecution extends ClosableTelemetryScope { private final BaseAgent agent; private final Map functionArgs; - public ToolExecution(BaseTool tool, BaseAgent agent, Map functionArgs) { - super(Tracing.getTracer().spanBuilder("execute_tool " + tool.name()).startSpan()); + public ToolExecution( + BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { + super( + Tracing.getTracer() + .spanBuilder("execute_tool " + tool.name()) + .setParent(parentContext) + .startSpan()); this.tool = tool; this.agent = agent; this.functionArgs = functionArgs; @@ -196,12 +205,22 @@ protected void handleMetricsError(RuntimeException e) { /** Creates an AgentInvocation context to record agent invocation telemetry. */ public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) { - return new AgentInvocation(ctx, agent); + return recordAgentInvocation(ctx, agent, Context.current()); + } + + public static AgentInvocation recordAgentInvocation( + InvocationContext ctx, BaseAgent agent, Context parentContext) { + return new AgentInvocation(ctx, agent, parentContext); } /** Creates a ToolExecution context to record tool execution telemetry. */ public static ToolExecution recordToolExecution( BaseTool tool, BaseAgent agent, Map functionArgs) { - return new ToolExecution(tool, agent, functionArgs); + return recordToolExecution(tool, agent, functionArgs, Context.current()); + } + + public static ToolExecution recordToolExecution( + BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { + return new ToolExecution(tool, agent, functionArgs, parentContext); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 5e2fa5792..a3436e6cb 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -22,26 +22,42 @@ import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; +import com.google.adk.telemetry.Metrics; import com.google.adk.testing.TestBaseAgent; import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.metrics.data.HistogramPointData; +import io.opentelemetry.sdk.metrics.data.MetricData; +import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; +import io.opentelemetry.sdk.testing.time.TestClock; +import io.opentelemetry.sdk.trace.SdkTracerProvider; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class BaseAgentTest { - private static final String TEST_AGENT_NAME = "testAgent"; private static final String TEST_AGENT_DESCRIPTION = "A test agent"; + private InMemoryMetricReader inMemoryMetricReader; + private TestClock testClock; + private Meter originalMeter; + private static class ClosableTestAgent extends TestBaseAgent { final AtomicBoolean closed = new AtomicBoolean(false); @@ -56,6 +72,35 @@ public Completable close() { } } + @Before + public void setUp() { + GlobalOpenTelemetry.resetForTest(); + testClock = TestClock.create(); + inMemoryMetricReader = InMemoryMetricReader.create(); + SdkMeterProvider sdkMeterProvider = + SdkMeterProvider.builder() + .registerMetricReader(inMemoryMetricReader) + .setClock(testClock) + .build(); + + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder() + .setTracerProvider(SdkTracerProvider.builder().build()) + .setMeterProvider(sdkMeterProvider) + .build(); + + GlobalOpenTelemetry.set(openTelemetrySdk); + originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); + Metrics.setMeterForTesting(openTelemetrySdk.getMeter("gcp.vertex.agent")); + } + + @After + public void tearDown() { + if (originalMeter != null) { + Metrics.setMeterForTesting(originalMeter); + } + } + @Test public void constructor_setsNameAndDescription() { String name = "testName"; @@ -173,6 +218,36 @@ public void runAsync_noCallbacks_invokesRunAsyncImpl() { assertThat(results).hasSize(1); assertThat(results.get(0).content()).hasValue(runAsyncImplContent); assertThat(runAsyncImpl.wasCalled()).isTrue(); + MetricData durationMetric = findMetricByName("gen_ai.agent.invocation.duration"); + assertThat(durationMetric.getUnit()).isEqualTo("ms"); + HistogramPointData durationPoint = + durationMetric.getHistogramData().getPoints().iterator().next(); + assertThat(durationPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData reqSizeMetric = findMetricByName("gen_ai.agent.request.size"); + assertThat(reqSizeMetric.getUnit()).isEqualTo("By"); + HistogramPointData reqSizePoint = + reqSizeMetric.getHistogramData().getPoints().iterator().next(); + assertThat(reqSizePoint.getSum()).isEqualTo(12.0); + assertThat(reqSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData respSizeMetric = findMetricByName("gen_ai.agent.response.size"); + assertThat(respSizeMetric.getUnit()).isEqualTo("By"); + HistogramPointData respSizePoint = + respSizeMetric.getHistogramData().getPoints().iterator().next(); + assertThat(respSizePoint.getSum()).isEqualTo(11.0); + assertThat(respSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData workflowStepsMetric = findMetricByName("gen_ai.agent.workflow.steps"); + assertThat(workflowStepsMetric.getUnit()).isEqualTo("1"); + HistogramPointData workflowStepsPoint = + workflowStepsMetric.getHistogramData().getPoints().iterator().next(); + assertThat(workflowStepsPoint.getSum()).isEqualTo(1.0); + assertThat(workflowStepsPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); } @Test @@ -627,4 +702,11 @@ public void close_twoLevelsSubAgents_closesAllSubAgents() { assertThat(subAgent.closed.get()).isTrue(); assertThat(subSubAgent.closed.get()).isTrue(); } + + private MetricData findMetricByName(String name) { + return inMemoryMetricReader.collectAllMetrics().stream() + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Metric not found: " + name)); + } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 35cf12f6f..26843bb56 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,7 +494,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool echo_tool"); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 3abfbdc20..00d5d63bf 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1366,7 +1366,7 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); @@ -1401,7 +1401,7 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 44877e972..331ae77b2 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -471,7 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ └── execute_tool [search_flights] + // │ └── execute_tool search_flights // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -499,7 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool search_flights"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -515,7 +515,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ └── execute_tool [search_flights] + // │ └── execute_tool search_flights assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); @@ -546,7 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ └── execute_tool [transfer_to_agent] + // │ └── execute_tool transfer_to_agent // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -573,7 +573,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool transfer_to_agent"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); List callLlmSpans = @@ -586,10 +586,17 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData agentACallLlm1 = callLlmSpans.get(0); SpanData agentBCallLlm = callLlmSpans.get(1); + // Assert hierarchy: + // invocation + // └── invoke_agent AgentA assertParent(invocation, agentASpan); + // └── call_llm 1 assertParent(agentASpan, agentACallLlm1); + // ├── execute_tool transfer_to_agent assertParent(agentACallLlm1, executeTool); - assertParent(agentASpan, agentBSpan); + // └── invoke_agent AgentB + assertParent(agentACallLlm1, agentBSpan); + // └── call_llm 2 assertParent(agentBSpan, agentBCallLlm); } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 6f35f5a3c..403e3874a 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -91,7 +91,7 @@ public Supplier> asRunAsyncImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); + return Flowable.just(Event.builder().author("testAgent").content(content).build()); }); } @@ -111,7 +111,7 @@ public Supplier> asRunLiveImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); + return Flowable.just(Event.builder().author("testAgent").content(content).build()); }); }