Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
import com.google.adk.events.Event;
import com.google.adk.plugins.Plugin;
import com.google.adk.telemetry.Tracing;
import com.google.adk.telemetry.Instrumentation;
import com.google.adk.telemetry.Instrumentation.AgentInvocation;
import com.google.adk.utils.AgentEnums.AgentOrigin;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
Expand Down Expand Up @@ -322,11 +323,13 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
private Flowable<Event> run(
InvocationContext parentContext,
Function<InvocationContext, Flowable<Event>> runImplementation) {
Context parentSpanContext = Context.current();
return Flowable.defer(
() -> {
InvocationContext invocationContext = createInvocationContext(parentContext);

Context otelContext = Context.current();
return Flowable.using(
() ->
Instrumentation.recordAgentInvocation(
createInvocationContext(parentContext), this, otelContext),
agentInvocation -> {
InvocationContext invocationContext = agentInvocation.getCtx();
Flowable<Event> mainAndAfterEvents =
Flowable.defer(() -> runImplementation.apply(invocationContext))
.concatWith(
Expand All @@ -350,14 +353,10 @@ private Flowable<Event> run(
return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents);
})
.switchIfEmpty(mainAndAfterEvents)
.compose(
Tracing.<Event>trace("invoke_agent " + name())
.setParent(parentSpanContext)
.configure(
span ->
Tracing.traceAgentInvocation(
span, name(), description(), invocationContext)));
});
.doOnNext(agentInvocation::addEvent)
.doOnError(agentInvocation::setError);
},
AgentInvocation::close);
}

/**
Expand Down
20 changes: 8 additions & 12 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -479,12 +479,10 @@ private Flowable<Event> runOneStep(Context spanContext, InvocationContext contex
"Agent not found: " + agentToTransfer)));
}
return postProcessedEvents.concatWith(
Flowable.defer(
() -> {
try (Scope s = spanContext.makeCurrent()) {
return nextAgent.get().runAsync(context);
}
}));
nextAgent
.get()
.runAsync(context)
.compose(Tracing.withContext(spanContext)));
}
return postProcessedEvents;
});
Expand Down Expand Up @@ -666,12 +664,10 @@ public void onError(Throwable e) {
"Agent not found: " + event.actions().transferToAgent().get());
}
Flowable<Event> nextAgentEvents =
Flowable.defer(
() -> {
try (Scope s = spanContext.makeCurrent()) {
return nextAgent.get().runLive(invocationContext);
}
});
nextAgent
.get()
.runLive(invocationContext)
.compose(Tracing.withContext(spanContext));
events = Flowable.concat(events, nextAgentEvents);
}
return events;
Expand Down
36 changes: 22 additions & 14 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.events.ToolConfirmation;
import com.google.adk.telemetry.Instrumentation;
import com.google.adk.telemetry.Instrumentation.ToolExecution;
import com.google.adk.telemetry.Tracing;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.FunctionTool;
Expand Down Expand Up @@ -430,6 +432,25 @@ private static Maybe<Event> postProcessFunctionResult(
ToolContext toolContext,
boolean isLive,
Context parentContext) {
return Maybe.using(
() ->
Instrumentation.recordToolExecution(
tool, invocationContext.agent(), functionArgs, parentContext),
toolExecution ->
processFunctionResult(
maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive)
.doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event))
.doOnError(toolExecution::setError),
ToolExecution::close);
}

private static Maybe<Event> processFunctionResult(
Maybe<Map<String, Object>> maybeFunctionResult,
InvocationContext invocationContext,
BaseTool tool,
Map<String, Object> functionArgs,
ToolContext toolContext,
boolean isLive) {
return maybeFunctionResult
.map(Optional::of)
.defaultIfEmpty(Optional.empty())
Expand Down Expand Up @@ -467,20 +488,7 @@ private static Maybe<Event> postProcessFunctionResult(
tool, finalFunctionResult, toolContext, invocationContext);
return Maybe.just(event);
});
})
.compose(
Tracing.<Event>trace("execute_tool [" + tool.name() + "]")
.setParent(parentContext)
.onSuccess(
(span, event) ->
Tracing.traceToolExecution(
span,
tool.name(),
tool.description(),
tool.getClass().getSimpleName(),
functionArgs,
event,
null)));
});
}

private static Optional<Event> mergeParallelFunctionResponseEvents(
Expand Down
31 changes: 25 additions & 6 deletions core/src/main/java/com/google/adk/telemetry/Instrumentation.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ public static final class AgentInvocation extends ClosableTelemetryScope {
private final InvocationContext ctx;
private final List<Event> events = Collections.synchronizedList(new ArrayList<>());

public AgentInvocation(InvocationContext ctx, BaseAgent agent) {
super(Tracing.getTracer().spanBuilder("invoke_agent " + agent.name()).startSpan());
public AgentInvocation(InvocationContext ctx, BaseAgent agent, Context parentContext) {
super(
Tracing.getTracer()
.spanBuilder("invoke_agent " + agent.name())
.setParent(parentContext)
.startSpan());
this.agent = agent;
this.ctx = ctx;
Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx);
Expand Down Expand Up @@ -160,8 +164,13 @@ public static final class ToolExecution extends ClosableTelemetryScope {
private final BaseAgent agent;
private final Map<String, Object> functionArgs;

public ToolExecution(BaseTool tool, BaseAgent agent, Map<String, Object> functionArgs) {
super(Tracing.getTracer().spanBuilder("execute_tool " + tool.name()).startSpan());
public ToolExecution(
BaseTool tool, BaseAgent agent, Map<String, Object> functionArgs, Context parentContext) {
super(
Tracing.getTracer()
.spanBuilder("execute_tool " + tool.name())
.setParent(parentContext)
.startSpan());
this.tool = tool;
this.agent = agent;
this.functionArgs = functionArgs;
Expand Down Expand Up @@ -196,12 +205,22 @@ protected void handleMetricsError(RuntimeException e) {

/** Creates an AgentInvocation context to record agent invocation telemetry. */
public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) {
return new AgentInvocation(ctx, agent);
return recordAgentInvocation(ctx, agent, Context.current());
}

public static AgentInvocation recordAgentInvocation(
InvocationContext ctx, BaseAgent agent, Context parentContext) {
return new AgentInvocation(ctx, agent, parentContext);
}

/** Creates a ToolExecution context to record tool execution telemetry. */
public static ToolExecution recordToolExecution(
BaseTool tool, BaseAgent agent, Map<String, Object> functionArgs) {
return new ToolExecution(tool, agent, functionArgs);
return recordToolExecution(tool, agent, functionArgs, Context.current());
}

public static ToolExecution recordToolExecution(
BaseTool tool, BaseAgent agent, Map<String, Object> functionArgs, Context parentContext) {
return new ToolExecution(tool, agent, functionArgs, parentContext);
}
}
84 changes: 83 additions & 1 deletion core/src/test/java/com/google/adk/agents/BaseAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,42 @@
import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
import com.google.adk.events.Event;
import com.google.adk.telemetry.Metrics;
import com.google.adk.testing.TestBaseAgent;
import com.google.adk.testing.TestCallback;
import com.google.adk.testing.TestUtils;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.metrics.Meter;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import io.opentelemetry.sdk.metrics.SdkMeterProvider;
import io.opentelemetry.sdk.metrics.data.HistogramPointData;
import io.opentelemetry.sdk.metrics.data.MetricData;
import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader;
import io.opentelemetry.sdk.testing.time.TestClock;
import io.opentelemetry.sdk.trace.SdkTracerProvider;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class BaseAgentTest {

private static final String TEST_AGENT_NAME = "testAgent";
private static final String TEST_AGENT_DESCRIPTION = "A test agent";

private InMemoryMetricReader inMemoryMetricReader;
private TestClock testClock;
private Meter originalMeter;

private static class ClosableTestAgent extends TestBaseAgent {
final AtomicBoolean closed = new AtomicBoolean(false);

Expand All @@ -56,6 +72,35 @@ public Completable close() {
}
}

@Before
public void setUp() {
GlobalOpenTelemetry.resetForTest();
testClock = TestClock.create();
inMemoryMetricReader = InMemoryMetricReader.create();
SdkMeterProvider sdkMeterProvider =
SdkMeterProvider.builder()
.registerMetricReader(inMemoryMetricReader)
.setClock(testClock)
.build();

OpenTelemetrySdk openTelemetrySdk =
OpenTelemetrySdk.builder()
.setTracerProvider(SdkTracerProvider.builder().build())
.setMeterProvider(sdkMeterProvider)
.build();

GlobalOpenTelemetry.set(openTelemetrySdk);
originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent");
Metrics.setMeterForTesting(openTelemetrySdk.getMeter("gcp.vertex.agent"));
}

@After
public void tearDown() {
if (originalMeter != null) {
Metrics.setMeterForTesting(originalMeter);
}
}

@Test
public void constructor_setsNameAndDescription() {
String name = "testName";
Expand Down Expand Up @@ -173,6 +218,36 @@ public void runAsync_noCallbacks_invokesRunAsyncImpl() {
assertThat(results).hasSize(1);
assertThat(results.get(0).content()).hasValue(runAsyncImplContent);
assertThat(runAsyncImpl.wasCalled()).isTrue();
MetricData durationMetric = findMetricByName("gen_ai.agent.invocation.duration");
assertThat(durationMetric.getUnit()).isEqualTo("ms");
HistogramPointData durationPoint =
durationMetric.getHistogramData().getPoints().iterator().next();
assertThat(durationPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name")))
.isEqualTo("testAgent");

MetricData reqSizeMetric = findMetricByName("gen_ai.agent.request.size");
assertThat(reqSizeMetric.getUnit()).isEqualTo("By");
HistogramPointData reqSizePoint =
reqSizeMetric.getHistogramData().getPoints().iterator().next();
assertThat(reqSizePoint.getSum()).isEqualTo(12.0);
assertThat(reqSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name")))
.isEqualTo("testAgent");

MetricData respSizeMetric = findMetricByName("gen_ai.agent.response.size");
assertThat(respSizeMetric.getUnit()).isEqualTo("By");
HistogramPointData respSizePoint =
respSizeMetric.getHistogramData().getPoints().iterator().next();
assertThat(respSizePoint.getSum()).isEqualTo(11.0);
assertThat(respSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name")))
.isEqualTo("testAgent");

MetricData workflowStepsMetric = findMetricByName("gen_ai.agent.workflow.steps");
assertThat(workflowStepsMetric.getUnit()).isEqualTo("1");
HistogramPointData workflowStepsPoint =
workflowStepsMetric.getHistogramData().getPoints().iterator().next();
assertThat(workflowStepsPoint.getSum()).isEqualTo(1.0);
assertThat(workflowStepsPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name")))
.isEqualTo("testAgent");
}

@Test
Expand Down Expand Up @@ -627,4 +702,11 @@ public void close_twoLevelsSubAgents_closesAllSubAgents() {
assertThat(subAgent.closed.get()).isTrue();
assertThat(subSubAgent.closed.get()).isTrue();
}

private MetricData findMetricByName(String name) {
return inMemoryMetricReader.collectAllMetrics().stream()
.filter(m -> m.getName().equals(name))
.findFirst()
.orElseThrow(() -> new AssertionError("Metric not found: " + name));
}
}
2 changes: 1 addition & 1 deletion core/src/test/java/com/google/adk/agents/LlmAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException {
List<SpanData> spans = openTelemetryRule.getSpans();
SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent");
List<SpanData> llmSpans = findSpansByName(spans, "call_llm");
List<SpanData> toolSpans = findSpansByName(spans, "execute_tool [echo_tool]");
List<SpanData> toolSpans = findSpansByName(spans, "execute_tool echo_tool");

assertThat(llmSpans).hasSize(2);
assertThat(toolSpans).hasSize(1);
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ public void runAsync_createsToolSpansWithCorrectParent() {
List<SpanData> spans = openTelemetryRule.getSpans();
List<SpanData> llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList();
List<SpanData> toolSpans =
spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList();
spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList();

assertThat(llmSpans).hasSize(2);
assertThat(toolSpans).hasSize(1);
Expand Down Expand Up @@ -1401,7 +1401,7 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception {
List<SpanData> spans = openTelemetryRule.getSpans();
List<SpanData> llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList();
List<SpanData> toolSpans =
spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList();
spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList();

// In runLive, there is one call_llm span for the execution
assertThat(llmSpans).hasSize(1);
Expand Down
Loading
Loading