diff --git a/core/pom.xml b/core/pom.xml index eefbcda79..765caf03b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -189,6 +189,25 @@ opentelemetry-sdk-testing test + + com.google.cloud + google-cloud-bigquery + + + com.google.cloud + google-cloud-bigquerystorage + + + org.apache.arrow + arrow-vector + 18.1.0 + + + org.apache.arrow + arrow-memory-netty + 18.1.0 + runtime + diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java new file mode 100644 index 000000000..9ea842225 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BatchProcessor.java @@ -0,0 +1,269 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.api.core.ApiFuture; +import com.google.api.core.ApiFutureCallback; +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializationError; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.common.annotations.VisibleForTesting; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import 
org.apache.arrow.vector.types.pojo.Schema; + +/** Handles asynchronous batching and writing of events to BigQuery. */ +class BatchProcessor implements AutoCloseable { + private static final Logger logger = Logger.getLogger(BatchProcessor.class.getName()); + + private final StreamWriter writer; + private final int batchSize; + private final Duration flushInterval; + private final BlockingQueue> queue; + private final ScheduledExecutorService executor; + @VisibleForTesting final BufferAllocator allocator; + final AtomicBoolean flushLock = new AtomicBoolean(false); + private final Schema arrowSchema; + + public BatchProcessor( + StreamWriter writer, + int batchSize, + Duration flushInterval, + int queueMaxSize, + ScheduledExecutorService executor) { + this.writer = writer; + this.batchSize = batchSize; + this.flushInterval = flushInterval; + this.queue = new LinkedBlockingQueue<>(queueMaxSize); + this.executor = executor; + this.allocator = new RootAllocator(Long.MAX_VALUE); + this.arrowSchema = BigQuerySchema.getArrowSchema(); + } + + public void start() { + @SuppressWarnings("unused") + var unused = + executor.scheduleWithFixedDelay( + () -> { + try { + flush(); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Error in background flush", e); + } + }, + flushInterval.toMillis(), + flushInterval.toMillis(), + MILLISECONDS); + } + + public void append(Map row) { + if (!queue.offer(row)) { + logger.warning("BigQuery event queue is full, dropping event."); + return; + } + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + + public void flush() { + // Acquire the flushLock. If another flush is already in progress, return immediately. 
+ if (!flushLock.compareAndSet(false, true)) { + return; + } + + try { + if (queue.isEmpty()) { + return; + } + + List> batch = new ArrayList<>(); + queue.drainTo(batch, batchSize); + + if (batch.isEmpty()) { + return; + } + try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) { + root.allocateNew(); + int rowCount = batch.size(); + + for (int i = 0; i < rowCount; i++) { + Map row = batch.get(i); + for (Field field : arrowSchema.getFields()) { + populateVector(root.getVector(field.getName()), i, row.get(field.getName())); + } + } + root.setRowCount(rowCount); + + try (ArrowRecordBatch recordBatch = new VectorUnloader(root).getRecordBatch()) { + ApiFuture future = writer.append(recordBatch); + ApiFutures.addCallback( + future, + new ApiFutureCallback() { + @Override + public void onFailure(Throwable t) { + logger.log(Level.SEVERE, "Failed to write batch to BigQuery", t); + if (t instanceof AppendSerializationError ase) { + Map rowIndexToErrorMessage = ase.getRowIndexToErrorMessage(); + + if (rowIndexToErrorMessage != null && !rowIndexToErrorMessage.isEmpty()) { + logger.severe("Row-level errors found:"); + for (Map.Entry entry : rowIndexToErrorMessage.entrySet()) { + logger.severe( + String.format( + "Row error at index %d: %s", entry.getKey(), entry.getValue())); + } + } else { + logger.severe( + "AppendSerializationError occurred, but no row-specific errors were" + + " provided."); + } + } + } + + @Override + public void onSuccess(AppendRowsResponse result) { + if (result.hasError()) { + logger.severe("BigQuery append error: " + result.getError().getMessage()); + for (var error : result.getRowErrorsList()) { + logger.severe( + String.format( + "Row error at index %d: %s", error.getIndex(), error.getMessage())); + } + } else { + logger.fine("Successfully wrote " + batch.size() + " rows to BigQuery."); + } + } + }, + directExecutor()); + } + } catch (RuntimeException e) { + logger.log(Level.SEVERE, "Failed to append rows to StreamWriter", 
e); + } + } finally { + flushLock.set(false); + if (queue.size() >= batchSize && !flushLock.get()) { + executor.execute(this::flush); + } + } + } + + private void populateVector(FieldVector vector, int index, Object value) { + if (value == null || (value instanceof JsonNode jsonNode && jsonNode.isNull())) { + vector.setNull(index); + return; + } + + if (vector instanceof VarCharVector varCharVector) { + String strValue = (value instanceof JsonNode jsonNode) ? jsonNode.asText() : value.toString(); + varCharVector.setSafe(index, strValue.getBytes(UTF_8)); + } else if (vector instanceof BigIntVector bigIntVector) { + long longValue; + if (value instanceof JsonNode jsonNode) { + longValue = jsonNode.asLong(); + } else if (value instanceof Number number) { + longValue = number.longValue(); + } else { + longValue = Long.parseLong(value.toString()); + } + bigIntVector.setSafe(index, longValue); + } else if (vector instanceof BitVector bitVector) { + boolean boolValue = + (value instanceof JsonNode jsonNode) ? jsonNode.asBoolean() : (Boolean) value; + bitVector.setSafe(index, boolValue ? 
1 : 0); + } else if (vector instanceof TimeStampVector timeStampVector) { + if (value instanceof Instant instant) { + long micros = + SECONDS.toMicros(instant.getEpochSecond()) + NANOSECONDS.toMicros(instant.getNano()); + timeStampVector.setSafe(index, micros); + } else if (value instanceof JsonNode jsonNode) { + timeStampVector.setSafe(index, jsonNode.asLong()); + } else if (value instanceof Long longValue) { + timeStampVector.setSafe(index, longValue); + } + } else if (vector instanceof ListVector listVector) { + int start = listVector.startNewValue(index); + if (value instanceof ArrayNode arrayNode) { + for (int i = 0; i < arrayNode.size(); i++) { + populateVector(listVector.getDataVector(), start + i, arrayNode.get(i)); + } + listVector.endValue(index, arrayNode.size()); + } else if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < list.size(); i++) { + populateVector(listVector.getDataVector(), start + i, list.get(i)); + } + listVector.endValue(index, list.size()); + } + } else if (vector instanceof StructVector structVector) { + structVector.setIndexDefined(index); + if (value instanceof ObjectNode objectNode) { + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, objectNode.get(child.getName())); + } + } else if (value instanceof Map) { + Map map = (Map) value; + for (FieldVector child : structVector.getChildrenFromFields()) { + populateVector(child, index, map.get(child.getName())); + } + } + } + } + + @Override + public void close() { + flush(); + if (writer != null) { + writer.close(); + } + if (allocator != null) { + allocator.close(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java new file mode 100644 index 000000000..5d486f31e --- /dev/null +++ 
b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -0,0 +1,435 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.api.gax.retrying.RetrySettings; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryException; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Clustering; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardTableDefinition; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.TableInfo; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import 
com.google.common.annotations.VisibleForTesting; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.threeten.bp.Duration; + +/** + * BigQuery Agent Analytics Plugin for Java. + * + *

Logs agent execution events directly to a BigQuery table using the Storage Write API. + */ +public class BigQueryAgentAnalyticsPlugin extends BasePlugin { + private static final Logger logger = + Logger.getLogger(BigQueryAgentAnalyticsPlugin.class.getName()); + private static final ImmutableList DEFAULT_AUTH_SCOPES = + ImmutableList.of("https://www.googleapis.com/auth/cloud-platform"); + private static final AtomicLong threadCounter = new AtomicLong(0); + + private final BigQueryLoggerConfig config; + private final BigQuery bigQuery; + private final BigQueryWriteClient writeClient; + private final ScheduledExecutorService executor; + private final Object tableEnsuredLock = new Object(); + @VisibleForTesting final BatchProcessor batchProcessor; + private volatile boolean tableEnsured = false; + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config) throws IOException { + this(config, createBigQuery(config)); + } + + public BigQueryAgentAnalyticsPlugin(BigQueryLoggerConfig config, BigQuery bigQuery) + throws IOException { + super("bigquery_agent_analytics"); + this.config = config; + this.bigQuery = bigQuery; + ThreadFactory threadFactory = + r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); + this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.writeClient = createWriteClient(config); + + if (config.enabled()) { + StreamWriter writer = createWriter(config); + this.batchProcessor = + new BatchProcessor( + writer, + config.batchSize(), + config.batchFlushInterval(), + config.queueMaxSize(), + executor); + this.batchProcessor.start(); + } else { + this.batchProcessor = null; + } + } + + private static BigQuery createBigQuery(BigQueryLoggerConfig config) throws IOException { + BigQueryOptions.Builder builder = BigQueryOptions.newBuilder(); + if (config.credentials() != null) { + builder.setCredentials(config.credentials()); + } else { + builder.setCredentials( + 
GoogleCredentials.getApplicationDefault().createScoped(DEFAULT_AUTH_SCOPES)); + } + return builder.build().getService(); + } + + private void ensureTableExistsOnce() { + if (!tableEnsured) { + synchronized (tableEnsuredLock) { + if (!tableEnsured) { + // Table creation is expensive, so we only do it once per plugin instance. + tableEnsured = true; + ensureTableExists(bigQuery, config); + } + } + } + } + + private void ensureTableExists(BigQuery bigQuery, BigQueryLoggerConfig config) { + TableId tableId = TableId.of(config.projectId(), config.datasetId(), config.tableName()); + Schema schema = BigQuerySchema.getEventsSchema(); + try { + Table table = bigQuery.getTable(tableId); + logger.info("BigQuery table: " + tableId); + if (table == null) { + logger.info("Creating BigQuery table: " + tableId); + StandardTableDefinition.Builder tableDefinitionBuilder = + StandardTableDefinition.newBuilder().setSchema(schema); + if (!config.clusteringFields().isEmpty()) { + tableDefinitionBuilder.setClustering( + Clustering.newBuilder().setFields(config.clusteringFields()).build()); + } + TableInfo tableInfo = TableInfo.newBuilder(tableId, tableDefinitionBuilder.build()).build(); + bigQuery.create(tableInfo); + } else if (config.autoSchemaUpgrade()) { + // TODO(vmaliuta): Implement auto-schema upgrade. + logger.info("BigQuery table already exists and auto-schema upgrade is enabled: " + tableId); + logger.info("Auto-schema upgrade is not implemented yet."); + } + } catch (BigQueryException e) { + if (e.getMessage().contains("invalid_grant")) { + logger.log( + Level.SEVERE, + "Failed to authenticate with BigQuery. 
Please run 'gcloud auth application-default" + + " login' to refresh your credentials or provide valid credentials in" + + " BigQueryLoggerConfig.", + e); + } else { + logger.log( + Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to check or create/upgrade BigQuery table: " + tableId, e); + } + } + + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) throws IOException { + if (config.credentials() != null) { + return BigQueryWriteClient.create( + BigQueryWriteSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(config.credentials())) + .build()); + } + return BigQueryWriteClient.create(); + } + + protected String getStreamName(BigQueryLoggerConfig config) { + return String.format( + "projects/%s/datasets/%s/tables/%s/streams/_default", + config.projectId(), config.datasetId(), config.tableName()); + } + + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + BigQueryLoggerConfig.RetryConfig rc = config.retryConfig(); + RetrySettings retrySettings = + RetrySettings.newBuilder() + .setMaxAttempts(rc.maxRetries()) + .setInitialRetryDelay(Duration.ofMillis(rc.initialDelay().toMillis())) + .setRetryDelayMultiplier(rc.multiplier()) + .setMaxRetryDelay(Duration.ofMillis(rc.maxDelay().toMillis())) + .build(); + + String streamName = getStreamName(config); + try { + return StreamWriter.newBuilder(streamName, writeClient) + .setRetrySettings(retrySettings) + .setWriterSchema(BigQuerySchema.getArrowSchema()) + .build(); + } catch (Exception e) { + throw new VerifyException("Failed to create StreamWriter for " + streamName, e); + } + } + + private void logEvent( + String eventType, + InvocationContext invocationContext, + Optional callbackContext, + Object content, + Map extraAttributes) { + if (batchProcessor == null) { + return; + } + + ensureTableExistsOnce(); + + Map row = new HashMap<>(); + 
row.put("timestamp", Instant.now()); + row.put("event_type", eventType); + row.put( + "agent", + callbackContext.map(CallbackContext::agentName).orElse(invocationContext.agent().name())); + row.put("session_id", invocationContext.session().id()); + row.put("invocation_id", invocationContext.invocationId()); + row.put("user_id", invocationContext.userId()); + + if (content instanceof Content contentParts) { + row.put( + "content_parts", + JsonFormatter.formatContentParts(Optional.of(contentParts), config.maxContentLength())); + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } else if (content != null) { + row.put( + "content", JsonFormatter.smartTruncate(content, config.maxContentLength()).toString()); + } + + Map attributes = new HashMap<>(config.customTags()); + if (extraAttributes != null) { + attributes.putAll(extraAttributes); + } + row.put( + "attributes", + JsonFormatter.smartTruncate(attributes, config.maxContentLength()).toString()); + + addTraceDetails(row); + batchProcessor.append(row); + } + + private void addTraceDetails(Map row) { + SpanContext spanContext = Span.current().getSpanContext(); + if (spanContext.isValid()) { + row.put("trace_id", spanContext.getTraceId()); + row.put("span_id", spanContext.getSpanId()); + } + } + + @Override + public Completable close() { + if (batchProcessor != null) { + batchProcessor.close(); + } + if (writeClient != null) { + writeClient.close(); + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + return Completable.complete(); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.fromAction( + () -> logEvent("USER_MESSAGE", invocationContext, Optional.empty(), userMessage, null)); + } 
+ + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.fromAction( + () -> logEvent("INVOCATION_START", invocationContext, Optional.empty(), null, null)); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("event_author", event.author()); + logEvent( + "EVENT", invocationContext, Optional.empty(), event.content().orElse(null), attrs); + }); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.fromAction( + () -> { + logEvent("INVOCATION_END", invocationContext, Optional.empty(), null, null); + batchProcessor.flush(); + }); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_START", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.fromAction( + () -> + logEvent( + "AGENT_END", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + null)); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + LlmRequest req = llmRequest.build(); + attrs.put("model", req.model().orElse("unknown")); + logEvent( + "MODEL_REQUEST", + callbackContext.invocationContext(), + Optional.of(callbackContext), + req, + attrs); + }); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + llmResponse.usageMetadata().ifPresent(u -> attrs.put("usage_metadata", u)); + logEvent( + "MODEL_RESPONSE", + 
callbackContext.invocationContext(), + Optional.of(callbackContext), + llmResponse, + attrs); + }); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("error_message", error.getMessage()); + logEvent( + "MODEL_ERROR", + callbackContext.invocationContext(), + Optional.of(callbackContext), + null, + attrs); + }); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_START", + toolContext.invocationContext(), + Optional.of(toolContext), + toolArgs, + attrs); + }); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + logEvent( + "TOOL_END", toolContext.invocationContext(), Optional.of(toolContext), result, attrs); + }); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.fromAction( + () -> { + Map attrs = new HashMap<>(); + attrs.put("tool_name", tool.name()); + attrs.put("error_message", error.getMessage()); + logEvent( + "TOOL_ERROR", toolContext.invocationContext(), Optional.of(toolContext), null, attrs); + }); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java new file mode 100644 index 000000000..6d1195ea9 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -0,0 +1,183 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.google.auth.Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; +import javax.annotation.Nullable; + +/** Configuration for the BigQueryAgentAnalyticsPlugin. */ +@AutoValue +public abstract class BigQueryLoggerConfig { + + public abstract boolean enabled(); + + // TODO(vmaliuta): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventAllowlist(); + + // TODO(vmaliuta): Implement allowlist/denylist for event types. + @Nullable + public abstract ImmutableList eventDenylist(); + + public abstract int maxContentLength(); + + public abstract String projectId(); + + public abstract String datasetId(); + + public abstract String tableName(); + + public abstract ImmutableList clusteringFields(); + + // TODO(vmaliuta): Implement logging of multi-modal content. + public abstract boolean logMultiModalContent(); + + public abstract RetryConfig retryConfig(); + + public abstract int batchSize(); + + public abstract Duration batchFlushInterval(); + + public abstract Duration shutdownTimeout(); + + public abstract int queueMaxSize(); + + // TODO(vmaliuta): Implement content formatter. 
+ @Nullable + public abstract BiFunction contentFormatter(); + + // TODO(vmaliuta): Implement connection id. + public abstract Optional connectionId(); + + // TODO(vmaliuta): Implement logging of session metadata. + public abstract boolean logSessionMetadata(); + + public abstract ImmutableMap customTags(); + + public abstract boolean autoSchemaUpgrade(); + + @Nullable + public abstract Credentials credentials(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig.Builder() + .setEnabled(true) + .setMaxContentLength(500 * 1024) + .setProjectId("") + .setDatasetId("agent_analytics") + .setTableName("events") + .setClusteringFields(ImmutableList.of("event_type", "agent", "user_id")) + .setLogMultiModalContent(true) + .setRetryConfig(RetryConfig.builder().build()) + .setBatchSize(1) + .setBatchFlushInterval(Duration.ofSeconds(1)) + .setShutdownTimeout(Duration.ofSeconds(10)) + .setQueueMaxSize(10000) + .setLogSessionMetadata(true) + .setCustomTags(ImmutableMap.of()) + .setAutoSchemaUpgrade(true); + } + + /** Builder for {@link BigQueryLoggerConfig}. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setEnabled(boolean enabled); + + public abstract Builder setEventAllowlist(@Nullable List eventAllowlist); + + public abstract Builder setEventDenylist(@Nullable List eventDenylist); + + public abstract Builder setMaxContentLength(int maxContentLength); + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setTableName(String tableName); + + public abstract Builder setClusteringFields(List clusteringFields); + + public abstract Builder setLogMultiModalContent(boolean logMultiModalContent); + + public abstract Builder setRetryConfig(RetryConfig retryConfig); + + public abstract Builder setBatchSize(int batchSize); + + public abstract Builder setBatchFlushInterval(Duration batchFlushInterval); + + public abstract Builder setShutdownTimeout(Duration shutdownTimeout); + + public abstract Builder setQueueMaxSize(int queueMaxSize); + + public abstract Builder setContentFormatter( + @Nullable BiFunction contentFormatter); + + public abstract Builder setConnectionId(String connectionId); + + public abstract Builder setLogSessionMetadata(boolean logSessionMetadata); + + public abstract Builder setCustomTags(Map customTags); + + public abstract Builder setAutoSchemaUpgrade(boolean autoSchemaUpgrade); + + public abstract Builder setCredentials(Credentials credentials); + + public abstract BigQueryLoggerConfig build(); + } + + /** Retry configuration for BigQuery writes. 
*/ + @AutoValue + public abstract static class RetryConfig { + public abstract int maxRetries(); + + public abstract Duration initialDelay(); + + public abstract double multiplier(); + + public abstract Duration maxDelay(); + + public static Builder builder() { + return new AutoValue_BigQueryLoggerConfig_RetryConfig.Builder() + .setMaxRetries(3) + .setInitialDelay(Duration.ofSeconds(1)) + .setMultiplier(2.0) + .setMaxDelay(Duration.ofSeconds(10)); + } + + /** Builder for {@link RetryConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setMaxRetries(int maxRetries); + + public abstract Builder setInitialDelay(Duration initialDelay); + + public abstract Builder setMultiplier(double multiplier); + + public abstract Builder setMaxDelay(Duration maxDelay); + + public abstract RetryConfig build(); + } + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java new file mode 100644 index 000000000..a18d90678 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQuerySchema.java @@ -0,0 +1,323 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.cloud.bigquery.Field; +import com.google.cloud.bigquery.FieldList; +import com.google.cloud.bigquery.Schema; +import com.google.cloud.bigquery.StandardSQLTypeName; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema; +import com.google.cloud.bigquery.storage.v1.TableFieldSchema.Mode; +import com.google.cloud.bigquery.storage.v1.TableSchema; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; + +/** Utility for defining the BigQuery events table schema. */ +public final class BigQuerySchema { + + private BigQuerySchema() {} + + /** Returns the BigQuery schema for the events table. 
*/ + public static Schema getEventsSchema() { + return Schema.of( + Field.newBuilder("timestamp", StandardSQLTypeName.TIMESTAMP) + .setMode(Field.Mode.REQUIRED) + .setDescription("The UTC timestamp when the event occurred.") + .build(), + Field.newBuilder("event_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The category of the event.") + .build(), + Field.newBuilder("agent", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The name of the agent that generated this event.") + .build(), + Field.newBuilder("session_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for the entire conversation session.") + .build(), + Field.newBuilder("invocation_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("A unique identifier for a single turn or execution.") + .build(), + Field.newBuilder("user_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The identifier of the end-user.") + .build(), + Field.newBuilder("trace_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry trace ID.") + .build(), + Field.newBuilder("span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry span ID.") + .build(), + Field.newBuilder("parent_span_id", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("OpenTelemetry parent span ID.") + .build(), + Field.newBuilder("content", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("The primary payload of the event.") + .build(), + Field.newBuilder( + "content_parts", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("mime_type", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The MIME type of the content part.") + .build(), + Field.newBuilder("uri", StandardSQLTypeName.STRING) + 
.setMode(Field.Mode.NULLABLE) + .setDescription("The URI of the content part if stored externally.") + .build(), + Field.newBuilder( + "object_ref", + StandardSQLTypeName.STRUCT, + FieldList.of( + Field.newBuilder("uri", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("version", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("authorizer", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .build(), + Field.newBuilder("details", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .build())) + .setMode(Field.Mode.NULLABLE) + .setDescription("The ObjectRef of the content part if stored externally.") + .build(), + Field.newBuilder("text", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The raw text content.") + .build(), + Field.newBuilder("part_index", StandardSQLTypeName.INT64) + .setMode(Field.Mode.NULLABLE) + .setDescription("The zero-based index of this part.") + .build(), + Field.newBuilder("part_attributes", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Additional metadata as a JSON object string.") + .build(), + Field.newBuilder("storage_mode", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates how the content part is stored.") + .build())) + .setMode(Field.Mode.REPEATED) + .setDescription("Multi-modal events content parts.") + .build(), + Field.newBuilder("attributes", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing arbitrary key-value pairs.") + .build(), + Field.newBuilder("latency_ms", StandardSQLTypeName.JSON) + .setMode(Field.Mode.NULLABLE) + .setDescription("A JSON object containing latency measurements.") + .build(), + Field.newBuilder("status", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("The outcome of the event.") + .build(), + 
Field.newBuilder("error_message", StandardSQLTypeName.STRING) + .setMode(Field.Mode.NULLABLE) + .setDescription("Detailed error message if the status is 'ERROR'.") + .build(), + Field.newBuilder("is_truncated", StandardSQLTypeName.BOOL) + .setMode(Field.Mode.NULLABLE) + .setDescription("Indicates if the 'content' field was truncated.") + .build()); + } + + /** Returns the Arrow schema for the events table. */ + public static org.apache.arrow.vector.types.pojo.Schema getArrowSchema() { + return new org.apache.arrow.vector.types.pojo.Schema( + getEventsSchema().getFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList())); + } + + /** Returns the serialized Arrow schema for the events table. */ + public static ByteString getSerializedArrowSchema() { + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), getArrowSchema()); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new VerifyException("Failed to serialize arrow schema", e); + } + } + + private static org.apache.arrow.vector.types.pojo.Field convertToArrowField(Field field) { + ArrowType arrowType = convertTypeToArrow(field.getType().getStandardType()); + ImmutableList children = null; + if (field.getSubFields() != null) { + children = + field.getSubFields().stream() + .map(BigQuerySchema::convertToArrowField) + .collect(toImmutableList()); + } + + ImmutableMap metadata = null; + if (field.getType().getStandardType() == StandardSQLTypeName.JSON) { + metadata = ImmutableMap.of("ARROW:extension:name", "google:sqlType:json"); + } + + FieldType fieldType = + new FieldType(field.getMode() != Field.Mode.REQUIRED, arrowType, null, metadata); + org.apache.arrow.vector.types.pojo.Field arrowField = + new org.apache.arrow.vector.types.pojo.Field(field.getName(), fieldType, children); + + if (field.getMode() == Field.Mode.REPEATED) { + return new 
org.apache.arrow.vector.types.pojo.Field( + field.getName(), + new FieldType(false, new ArrowType.List(), null), + ImmutableList.of( + new org.apache.arrow.vector.types.pojo.Field( + "element", arrowField.getFieldType(), arrowField.getChildren()))); + } + return arrowField; + } + + private static ArrowType convertTypeToArrow(StandardSQLTypeName type) { + return switch (type) { + case BOOL -> new ArrowType.Bool(); + case BYTES -> new ArrowType.Binary(); + case DATE -> new ArrowType.Date(DateUnit.DAY); + case DATETIME -> + // Arrow doesn't have a direct DATETIME, often mapped to Timestamp or Utf8 + new ArrowType.Utf8(); + case FLOAT64 -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case INT64 -> new ArrowType.Int(64, true); + case NUMERIC, BIGNUMERIC -> new ArrowType.Decimal(38, 9, 128); + case GEOGRAPHY, STRING, JSON -> new ArrowType.Utf8(); + case STRUCT -> new ArrowType.Struct(); + case TIME -> new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case TIMESTAMP -> new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"); + default -> new ArrowType.Null(); + }; + } + + /** Returns names of fields to cluster by default. */ + public static ImmutableList getDefaultClusteringFields() { + return ImmutableList.of("event_type", "agent", "user_id"); + } + + /** Returns the BigQuery TableSchema for the events table (Storage Write API). */ + public static TableSchema getEventsTableSchema() { + return convertTableSchema(getEventsSchema()); + } + + private static TableSchema convertTableSchema(Schema schema) { + TableSchema.Builder result = TableSchema.newBuilder(); + for (int i = 0; i < schema.getFields().size(); i++) { + result.addFields(i, convertFieldSchema(schema.getFields().get(i))); + } + return result.build(); + } + + private static TableFieldSchema convertFieldSchema(Field field) { + TableFieldSchema.Builder result = TableFieldSchema.newBuilder(); + Field.Mode mode = field.getMode() != null ? 
field.getMode() : Field.Mode.NULLABLE;
+
+    result.setMode(Mode.valueOf(mode.name())).setName(field.getName());
+    result.setType(convertType(field.getType().getStandardType()));
+
+    if (field.getDescription() != null) {
+      result.setDescription(field.getDescription());
+    }
+    // Nested (STRUCT) fields are converted recursively and appended in declaration order.
+    if (field.getSubFields() != null) {
+      for (Field subField : field.getSubFields()) {
+        result.addFields(convertFieldSchema(subField));
+      }
+    }
+    return result.build();
+  }
+
+  /**
+   * Translates a BigQuery standard SQL type into the Storage Write API field type.
+   *
+   * @param type the standard SQL type of a field
+   * @return the equivalent {@link TableFieldSchema.Type}, or {@code TYPE_UNSPECIFIED} when the
+   *     type has no mapping here
+   */
+  private static TableFieldSchema.Type convertType(StandardSQLTypeName type) {
+    return switch (type) {
+      case BOOL -> TableFieldSchema.Type.BOOL;
+      case BYTES -> TableFieldSchema.Type.BYTES;
+      case DATE -> TableFieldSchema.Type.DATE;
+      case DATETIME -> TableFieldSchema.Type.DATETIME;
+      case FLOAT64 -> TableFieldSchema.Type.DOUBLE;
+      case GEOGRAPHY -> TableFieldSchema.Type.GEOGRAPHY;
+      case INT64 -> TableFieldSchema.Type.INT64;
+      case NUMERIC -> TableFieldSchema.Type.NUMERIC;
+      case STRING -> TableFieldSchema.Type.STRING;
+      case STRUCT -> TableFieldSchema.Type.STRUCT;
+      case TIME -> TableFieldSchema.Type.TIME;
+      case TIMESTAMP -> TableFieldSchema.Type.TIMESTAMP;
+      case BIGNUMERIC -> TableFieldSchema.Type.BIGNUMERIC;
+      case JSON -> TableFieldSchema.Type.JSON;
+      case INTERVAL -> TableFieldSchema.Type.INTERVAL;
+      default -> TableFieldSchema.Type.TYPE_UNSPECIFIED;
+    };
+  }
+}
diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java
new file mode 100644
index 000000000..5e33ea574
--- /dev/null
+++
b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.List; +import java.util.Optional; + +/** Utility for formatting and truncating content for BigQuery logging using Jackson. */ +final class JsonFormatter { + private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + + private JsonFormatter() {} + + /** Formats Content parts into an ArrayNode for BigQuery logging. 
*/ + public static ArrayNode formatContentParts(Optional content, int maxLength) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty() || content.get().parts() == null) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncateString(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", "[BINARY DATA]"); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } + + /** Recursively truncates long strings inside an object and returns a Jackson JsonNode. 
*/ + public static JsonNode smartTruncate(Object obj, int maxLength) { + if (obj == null) { + return mapper.nullNode(); + } + try { + return recursiveSmartTruncate(mapper.valueToTree(obj), maxLength); + } catch (IllegalArgumentException e) { + // Fallback for types that mapper can't handle directly as a tree + return mapper.valueToTree(String.valueOf(obj)); + } + } + + private static JsonNode recursiveSmartTruncate(JsonNode node, int maxLength) { + if (node.isTextual()) { + return mapper.valueToTree(truncateString(node.asText(), maxLength)); + } else if (node.isObject()) { + ObjectNode newNode = mapper.createObjectNode(); + node.properties() + .iterator() + .forEachRemaining( + entry -> { + newNode.set(entry.getKey(), recursiveSmartTruncate(entry.getValue(), maxLength)); + }); + return newNode; + } else if (node.isArray()) { + ArrayNode newNode = mapper.createArrayNode(); + for (JsonNode element : node) { + newNode.add(recursiveSmartTruncate(element, maxLength)); + } + return newNode; + } + return node; + } + + private static String truncateString(String s, int maxLength) { + if (s == null || s.length() <= maxLength) { + return s; + } + return s.substring(0, maxLength) + "...[truncated]"; + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java new file mode 100644 index 000000000..7580b095d --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BatchProcessorTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BatchProcessorTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private StreamWriter mockWriter; + private ScheduledExecutorService executor; + private BatchProcessor batchProcessor; + private Schema schema; + + @Before + public void setUp() { + executor = Executors.newScheduledThreadPool(1); + batchProcessor = new BatchProcessor(mockWriter, 10, 
Duration.ofMinutes(1), 100, executor); + schema = BigQuerySchema.getArrowSchema(); + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + } + + @After + public void tearDown() { + batchProcessor.close(); + executor.shutdown(); + } + + @Test + public void flush_populatesTimestampFieldCorrectly() throws Exception { + Instant now = Instant.parse("2026-03-02T19:11:49.631Z"); + Map row = new HashMap<>(); + row.put("timestamp", now); + row.put("event_type", "TEST_EVENT"); + + final boolean[] checksPassed = {false}; + final String[] failureMessage = {null}; + + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenAnswer( + invocation -> { + ArrowRecordBatch recordedBatch = invocation.getArgument(0); + try (VectorSchemaRoot root = + VectorSchemaRoot.create(schema, batchProcessor.allocator)) { + VectorLoader loader = new VectorLoader(root); + loader.load(recordedBatch); + + if (root.getRowCount() != 1) { + failureMessage[0] = "Expected 1 row, got " + root.getRowCount(); + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + + var timestampVector = root.getVector("timestamp"); + if (!(timestampVector instanceof TimeStampMicroTZVector tzVector)) { + failureMessage[0] = "Vector should be an instance of TimeStampMicroTZVector"; + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + } + if (tzVector.isNull(0)) { + failureMessage[0] = "Timestamp should NOT be null"; + } else if (tzVector.get(0) != now.toEpochMilli() * 1000) { + failureMessage[0] = + "Expected " + (now.toEpochMilli() * 1000) + ", got " + tzVector.get(0); + } else { + checksPassed[0] = true; + } + } catch (RuntimeException e) { + failureMessage[0] = "Exception during check: " + e.getMessage(); + } + return ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance()); + }); + + batchProcessor.append(row); + batchProcessor.flush(); + + 
verify(mockWriter).append(any(ArrowRecordBatch.class)); + assertTrue(failureMessage[0], checksPassed[0]); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java new file mode 100644 index 000000000..cb6ac6a88 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -0,0 +1,204 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.sessions.Session; +import com.google.api.core.ApiFutures; +import com.google.auth.Credentials; +import com.google.cloud.bigquery.BigQuery; +import com.google.cloud.bigquery.BigQueryOptions; +import com.google.cloud.bigquery.Table; +import com.google.cloud.bigquery.TableId; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import java.time.Duration; +import java.util.Optional; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class BigQueryAgentAnalyticsPluginTest { + @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock private BigQuery mockBigQuery; + @Mock private StreamWriter mockWriter; + @Mock private BigQueryWriteClient mockWriteClient; + @Mock private InvocationContext 
mockInvocationContext; + private BaseAgent fakeAgent; + + private BigQueryLoggerConfig config; + private BigQueryAgentAnalyticsPlugin plugin; + + @Before + public void setUp() throws Exception { + fakeAgent = new FakeAgent("agent_name"); + config = + BigQueryLoggerConfig.builder() + .setEnabled(true) + .setDatasetId("dataset") + .setTableName("table") + .setBatchSize(10) + .setBatchFlushInterval(Duration.ofSeconds(10)) + .setAutoSchemaUpgrade(false) + .setCredentials(mock(Credentials.class)) + .build(); + + when(mockBigQuery.getOptions()) + .thenReturn(BigQueryOptions.newBuilder().setProjectId("test-project").build()); + when(mockBigQuery.getTable(any(TableId.class))).thenReturn(mock(Table.class)); + when(mockWriter.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.getDefaultInstance())); + + plugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter(BigQueryLoggerConfig config) { + return mockWriter; + } + }; + + Session session = Session.builder("session_id").build(); + when(mockInvocationContext.session()).thenReturn(session); + when(mockInvocationContext.invocationId()).thenReturn("invocation_id"); + when(mockInvocationContext.agent()).thenReturn(fakeAgent); + when(mockInvocationContext.userId()).thenReturn("user_id"); + } + + @Test + public void onUserMessageCallback_appendsToWriter() throws Exception { + Content content = Content.builder().build(); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void beforeRunCallback_appendsToWriter() throws Exception { + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + 
verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void afterRunCallback_flushesAndAppends() throws Exception { + System.out.println("flushLock1: " + plugin.batchProcessor.flushLock.get()); + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + + @Test + public void getStreamName_returnsCorrectFormat() { + BigQueryLoggerConfig config = + BigQueryLoggerConfig.builder() + .setProjectId("test-project") + .setDatasetId("test-dataset") + .setTableName("test-table") + .build(); + + String streamName = plugin.getStreamName(config); + + assertEquals( + "projects/test-project/datasets/test-dataset/tables/test-table/streams/_default", + streamName); + } + + @Test + public void formatContentParts_populatesCorrectFields() { + Content content = Content.fromParts(Part.fromText("hello")); + ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100); + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals(0, node.get("part_index").asInt()); + assertEquals("INLINE", node.get("storage_mode").asText()); + assertEquals("hello", node.get("text").asText()); + assertEquals("text/plain", node.get("mime_type").asText()); + } + + @Test + public void arrowSchema_hasJsonMetadata() { + Schema schema = BigQuerySchema.getArrowSchema(); + Field contentField = schema.findField("content"); + assertNotNull(contentField); + assertEquals("google:sqlType:json", contentField.getMetadata().get("ARROW:extension:name")); + } + + @Test + public void complexType_appendsToWriter() throws Exception { + Part part = Part.fromText("test text"); + Content content = Content.fromParts(part); + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + plugin.batchProcessor.flush(); + + verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); + } + 
+ private static class FakeAgent extends BaseAgent { + FakeAgent(String name) { + super(name, "description", null, null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return Flowable.empty(); + } + } +} diff --git a/pom.xml b/pom.xml index d3f2ba432..e5de92d1b 100644 --- a/pom.xml +++ b/pom.xml @@ -323,6 +323,8 @@ + + @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED plain