From 59eda98eaecf69ee092bf954da3fae75eb090f2b Mon Sep 17 00:00:00 2001
From: Google Team Member
Date: Fri, 9 Jan 2026 16:57:36 -0800
Subject: [PATCH 1/3] feat: add SpannerVectorStore for orchestrating and
 providing utility functions for a Spanner vector store

PiperOrigin-RevId: 854392465
---
 src/google/adk/features/_feature_registry.py |   4 +
 src/google/adk/tools/spanner/settings.py     | 129 +++-
 src/google/adk/tools/spanner/utils.py        | 618 ++++++++++++++++++-
 tests/unittests/tools/spanner/test_utils.py  | 384 ++++++++++++
 4 files changed, 1121 insertions(+), 14 deletions(-)
 create mode 100644 tests/unittests/tools/spanner/test_utils.py

diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py
index 2ab0130639..c6584e85df 100644
--- a/src/google/adk/features/_feature_registry.py
+++ b/src/google/adk/features/_feature_registry.py
@@ -41,6 +41,7 @@ class FeatureName(str, Enum):
   PUBSUB_TOOLSET = "PUBSUB_TOOLSET"
   SPANNER_TOOLSET = "SPANNER_TOOLSET"
   SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS"
+  SPANNER_VECTOR_STORE = "SPANNER_VECTOR_STORE"
   TOOL_CONFIG = "TOOL_CONFIG"
   TOOL_CONFIRMATION = "TOOL_CONFIRMATION"

@@ -120,6 +121,9 @@ class FeatureConfig:
     FeatureName.SPANNER_TOOL_SETTINGS: FeatureConfig(
         FeatureStage.EXPERIMENTAL, default_on=True
     ),
+    FeatureName.SPANNER_VECTOR_STORE: FeatureConfig(
+        FeatureStage.EXPERIMENTAL, default_on=True
+    ),
     FeatureName.TOOL_CONFIG: FeatureConfig(
         FeatureStage.EXPERIMENTAL, default_on=True
     ),
diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py
index ae7f6371aa..6ca693b235 100644
--- a/src/google/adk/tools/spanner/settings.py
+++ b/src/google/adk/tools/spanner/settings.py
@@ -55,6 +55,74 @@ class QueryResultMode(Enum):
   """


+class TableColumn(BaseModel):
+  """Represents a column configuration, used in the CREATE TABLE DDL statement when setting up a new vector store table."""
+
+  name: str
+  """Required. The name of the column."""
+
+  type: str
+  """Required. The type of the column.
+
+  For example,
+
+  - GoogleSQL: 'STRING(MAX)', 'INT64', 'FLOAT64', 'BOOL', etc.
+  - PostgreSQL: 'text', 'int8', 'float8', 'boolean', etc.
+  """
+
+  is_nullable: bool = True
+  """Optional. Whether the column is nullable. By default, the column is nullable."""
+
+
+class VectorSearchIndexSettings(BaseModel):
+  """Settings for the index for use with Approximate Nearest Neighbor (ANN) vector similarity search."""
+
+  index_name: str
+  """Required. The name of the vector similarity search index."""
+
+  additional_key_columns: Optional[list[str]] = None
+  """Optional. The list of the additional key column names in the vector similarity search index.
+
+  To further speed up filtering for highly selective filtering columns, organize
+  them as additional keys in the vector index after the embedding column.
+  For example: `category` as additional key column.
+  `CREATE VECTOR INDEX ON documents(embedding, category);`
+  """
+
+  additional_storing_columns: Optional[list[str]] = None
+  """Optional. The list of the storing column names in the vector similarity search index.
+
+  This enables filtering while walking the vector index, removing unqualified
+  rows early.
+  For example: `category` as storing column.
+  `CREATE VECTOR INDEX ON documents(embedding) STORING (category);`
+  """
+
+  tree_depth: int = 2
+  """Required. The tree depth (level). This value can be either 2 or 3.
+
+  A tree with 2 levels only has leaves (num_leaves) as nodes.
+  If the dataset has more than 100 million rows,
+  then you can use a tree with 3 levels and add branches (num_branches) to
+  further partition the dataset.
+  """
+
+  num_leaves: int = 1000
+  """Required. The number of leaves (i.e. potential partitions) for the vector data.
+
+  You can designate num_leaves for trees with 2 or 3 levels.
+  We recommend that the number of leaves is number_of_rows_in_dataset/1000.
+  """
+
+  num_branches: Optional[int] = None
+  """Optional. The number of branches to further partition the vector data.
+
+  You can only designate num_branches for trees with 3 levels.
+  The number of branches must be fewer than the number of leaves.
+  We recommend that the number of branches is between 1000 and sqrt(number_of_rows_in_dataset).
+  """
+
+
 class SpannerVectorStoreSettings(BaseModel):
   """Settings for Spanner Vector Store.

@@ -86,18 +154,19 @@ class SpannerVectorStoreSettings(BaseModel):
   vertex_ai_embedding_model_name: str
   """Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.

-  For example, 'text-embedding-005'.
-  Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
-  Otherwise, a runtime error might be raised during a query.
+  For example, 'text-embedding-005'.
+
+  Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
+  Otherwise, a runtime error might be raised during a query.
   """

-  selected_columns: List[str] = []
+  selected_columns: list[str] = []
   """Required. The vector store table columns to return in the vector similarity search result.

-  By default, only the `content_column` value and the distance value are returned.
-  If sepecified, the list of selected columns and the distance value are returned.
-  For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
+  By default, only the `content_column` value and the distance value are returned.
+  If specified, the list of selected columns and the distance value are returned.
+  For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
   """

   nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
@@ -105,8 +174,8 @@
   )
   """The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.

-  For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
-  For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
+  For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
+  For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
   """

   top_k: int = 4
@@ -118,16 +187,41 @@
   num_leaves_to_search: Optional[int] = None
   """Optional. This option specifies how many leaf nodes of the index are searched.

-  Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
- For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices + Note: This option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS. + For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices """ additional_filter: Optional[str] = None """Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query.""" + vector_search_index_settings: Optional[VectorSearchIndexSettings] = None + """Optional. Settings for the index for use with Approximate Nearest Neighbor (ANN) in the vector store. + + Note: This option is only required when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS. + For more details, see https://docs.cloud.google.com/spanner/docs/vector-indexes + """ + + additional_columns_to_setup: Optional[list[TableColumn]] = None + """Optional. A list of supplemental columns to be created when initializing a new vector store table or inserting content rows. + + Note: This configuration is only utilized during the initial table setup + or when inserting content rows. + """ + + primary_key_columns: Optional[list[str]] = None + """Optional. Specifies the column names to be used as the primary key for a new vector store table. + + If provided, every column name listed here must be defined within + `additional_columns_to_setup`. If this field is omitted (set to `None`), + defaults to a single primary key column named `id` which automatically + generates UUIDs for each entry. + + Note: This field is only used during the creation phase of a new vector store. + """ + @model_validator(mode="after") def __post_init__(self): - """Validate the embedding settings.""" + """Validate the vector store settings.""" if not self.vector_length or self.vector_length <= 0: raise ValueError( "Invalid vector length in the Spanner vector store settings." @@ -136,6 +230,17 @@ def __post_init__(self): if not self.selected_columns: self.selected_columns = [self.content_column] + if self.primary_key_columns: + cols = {self.content_column, self.embedding_column} + if self.additional_columns_to_setup: + cols.update({c.name for c in self.additional_columns_to_setup}) + + for pk in self.primary_key_columns: + if pk not in cols: + raise ValueError( + f"Primary key column '{pk}' not found in column definitions." + ) + return self diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py index adde521954..ff2531aaf5 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -14,16 +14,31 @@ from __future__ import annotations +import asyncio +import itertools import json +import logging +from typing import Generator +from typing import Iterable from typing import Optional +from typing import TYPE_CHECKING from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from . import client +from ...features import experimental +from ...features import FeatureName from ..tool_context import ToolContext from .settings import QueryResultMode from .settings import SpannerToolSettings +from .settings import SpannerVectorStoreSettings + +if TYPE_CHECKING: + from google.cloud import spanner + from google.genai import Client + +logger = logging.getLogger("google_adk." 
+ __name__)

 DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50

@@ -115,17 +130,18 @@ def embed_contents(
     vertex_ai_embedding_model_name: str,
     contents: list[str],
     output_dimensionality: Optional[int] = None,
+    genai_client: Client | None = None,
 ) -> list[list[float]]:
   """Embed the given contents into list of vectors using the Vertex AI embedding model endpoint."""
   try:
     from google.genai import Client
     from google.genai.types import EmbedContentConfig

-    client = Client()
+    genai_client = genai_client or Client()
     config = EmbedContentConfig()
     if output_dimensionality:
       config.output_dimensionality = output_dimensionality
-    response = client.models.embed_content(
+    response = genai_client.models.embed_content(
         model=vertex_ai_embedding_model_name,
         contents=contents,
         config=config,
@@ -133,3 +149,601 @@ def embed_contents(
     return [list(e.values) for e in response.embeddings]
   except Exception as ex:
     raise RuntimeError(f"Failed to embed content: {ex!r}") from ex
+
+
+async def embed_contents_async(
+    vertex_ai_embedding_model_name: str,
+    contents: list[str],
+    output_dimensionality: Optional[int] = None,
+    genai_client: Client | None = None,
+) -> list[list[float]]:
+  """Embed the given contents into a list of vectors using the Vertex AI embedding model endpoint."""
+  try:
+    from google.genai import Client
+    from google.genai.types import EmbedContentConfig
+
+    genai_client = genai_client or Client()
+    config = EmbedContentConfig()
+    if output_dimensionality:
+      config.output_dimensionality = output_dimensionality
+    response = await genai_client.aio.models.embed_content(
+        model=vertex_ai_embedding_model_name,
+        contents=contents,
+        config=config,
+    )
+    return [list(e.values) for e in response.embeddings]
+  except Exception as ex:
+    raise RuntimeError(f"Failed to embed content: {ex!r}") from ex
+
+
+@experimental(FeatureName.SPANNER_VECTOR_STORE)
+class SpannerVectorStore:
+  """A class for orchestrating and providing utility functions for a Spanner vector store.
+
+  This class provides utility functions for setting up and adding contents to a
+  vector store table in a Google Cloud Spanner database, based on the given
+  Spanner tool settings.
+  """
+
+  DEFAULT_VECTOR_STORE_ID_COLUMN_NAME = "id"
+  SPANNER_VECTOR_STORE_USER_AGENT = "adk-spanner-vector-store"
+
+  def __init__(
+      self,
+      settings: SpannerToolSettings,
+      credentials: Credentials | None = None,
+      spanner_client: spanner.Client | None = None,
+      genai_client: Client | None = None,
+  ):
+    """Initializes the SpannerVectorStore with validated settings and clients.
+
+    This constructor sets up the connection to a specific Spanner database and
+    configures the necessary clients for vector operations.
+
+    Args:
+      settings (SpannerToolSettings): The settings for the tool.
+      credentials (Credentials | None): Credentials for Spanner operations.
+        This is used to initialize a new Spanner client only if
+        `spanner_client` is not explicitly provided.
+      spanner_client (spanner.Client | None): A pre-configured `spanner.Client`
+        instance. If not provided, a new client will be created.
+      genai_client (Client | None): Google GenAI client used for
+        generating vector embeddings.
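+
+    Example (illustrative sketch only; the IDs and model name below are
+    placeholders mirroring the unit-test fixture, not a real deployment):
+
+      settings = SpannerToolSettings(
+          vector_store_settings=SpannerVectorStoreSettings(
+              project_id="test-project",
+              instance_id="test-instance",
+              database_id="test-database",
+              table_name="test_vector_store",
+              content_column="content",
+              embedding_column="embedding",
+              vector_length=768,
+              vertex_ai_embedding_model_name="text-embedding-005",
+          )
+      )
+      store = SpannerVectorStore(settings)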
+ """ + + if not settings.vector_store_settings: + raise ValueError("Spanner vector store settings are not set.") + + self._settings = settings + + if not spanner_client: + self._spanner_client = client.get_spanner_client( + project=self._vector_store_settings.project_id, + credentials=credentials, + ) + else: + self._spanner_client = spanner_client + client_user_agent = self._spanner_client._client_info.user_agent + if not client_user_agent: + self._spanner_client._client_info.user_agent = client.USER_AGENT + elif client.USER_AGENT not in client_user_agent: + self._spanner_client._client_info.user_agent = " ".join( + [client_user_agent, client.USER_AGENT] + ) + self._spanner_client._client_info.user_agent = " ".join([ + self._spanner_client._client_info.user_agent, + self.SPANNER_VECTOR_STORE_USER_AGENT, + ]) + + instance = self._spanner_client.instance( + self._vector_store_settings.instance_id + ) + if not instance.exists(): + raise ValueError( + "Instance id {} doesn't exist.".format( + self._vector_store_settings.instance_id + ) + ) + self._database = instance.database(self._vector_store_settings.database_id) + if not self._database.exists(): + raise ValueError( + "Database id {} doesn't exist.".format( + self._vector_store_settings.database_id + ) + ) + + self._genai_client = genai_client + + @property + def _vector_store_settings(self) -> SpannerVectorStoreSettings: + """Returns the Spanner vector store settings.""" + + if self._settings.vector_store_settings is None: + raise ValueError("Spanner vector store settings are not set.") + return self._settings.vector_store_settings + + def _create_vector_store_table_ddl( + self, + dialect: DatabaseDialect, + ) -> str: + """Creates the DDL statements necessary to define a vector store table in Spanner. + + The vector store table is created based on the given settings. + - **id_column** (STRING or text): The default primary key, typically a UUID. + Note: This column is only included in the DDL when `primary_key_columns` + is not specified in the settings. + - **content_column** (STRING or text): The source text content used to + generate the embedding. + - **embedding_column** (ARRAY or float4[]): The vector embedding + column corresponding to the content. + - **additional_columns_to_setup** (provided in the settings): Additional + columns to be added to the vector store table. + + Args: + dialect: The database dialect (e.g., GOOGLE_STANDARD_SQL or POSTGRESQL) + governing the DDL syntax. + + Returns: + A DDL statement string defining the vector store table. 
+ """ + + primary_key_columns = self._vector_store_settings.primary_key_columns or [ + self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME + ] + + column_definitions = [] + + if self._vector_store_settings.primary_key_columns is None: + if dialect == DatabaseDialect.POSTGRESQL: + column_definitions.append( + f"{self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME} varchar(36) DEFAULT" + " spanner.generate_uuid()" + ) + else: + column_definitions.append( + f"{self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME} STRING(36) DEFAULT" + " (GENERATE_UUID())" + ) + + # Additional Columns + if self._vector_store_settings.additional_columns_to_setup: + for column in self._vector_store_settings.additional_columns_to_setup: + null_stmt = "" if column.is_nullable else " NOT NULL" + column_definitions.append(f"{column.name} {column.type}{null_stmt}") + + # Content and Embedding Columns + if dialect == DatabaseDialect.POSTGRESQL: + column_definitions.append( + f"{self._vector_store_settings.content_column} text" + ) + column_definitions.append( + f"{self._vector_store_settings.embedding_column} float4[] " + f"VECTOR LENGTH {self._vector_store_settings.vector_length}" + ) + else: + column_definitions.append( + f"{self._vector_store_settings.content_column} STRING(MAX)" + ) + column_definitions.append( + f"{self._vector_store_settings.embedding_column} " + f"ARRAY(vector_length=>{self._vector_store_settings.vector_length})" + ) + + inner_ddl = ",\n ".join(column_definitions) + pk_stmt = ", ".join(primary_key_columns) + + if dialect == DatabaseDialect.POSTGRESQL: + return ( + f"CREATE TABLE IF NOT EXISTS {self._vector_store_settings.table_name}" + f" (\n {inner_ddl},\n PRIMARY KEY({pk_stmt})\n)" + ) + else: + return ( + f"CREATE TABLE IF NOT EXISTS {self._vector_store_settings.table_name}" + f" (\n {inner_ddl}\n) PRIMARY KEY({pk_stmt})" + ) + + def _create_ann_vector_search_index_ddl( + self, + dialect: DatabaseDialect, + ) -> str: + """Create a DDL statement to create a vector search index for ANN. + + Args: + dialect: The database dialect (e.g., GOOGLE_STANDARD_SQL or POSTGRESQL) + governing the DDL syntax. + + Returns: + A DDL statement string to create the vector search index. + """ + + # This is only required when the nearest neighbors search algorithm is + # APPROXIMATE_NEAREST_NEIGHBORS. + if not self._vector_store_settings.vector_search_index_settings: + raise ValueError("Vector search index settings are not set.") + + if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL: + raise ValueError( + "ANN is only supported for the Google Standard SQL dialect." 
+ ) + + index_columns = [self._vector_store_settings.embedding_column] + if ( + self._vector_store_settings.vector_search_index_settings.additional_key_columns + ): + index_columns.extend( + self._vector_store_settings.vector_search_index_settings.additional_key_columns + ) + + statement = ( + "CREATE VECTOR INDEX IF NOT EXISTS" + f" {self._vector_store_settings.vector_search_index_settings.index_name}\n\tON" + f" {self._vector_store_settings.table_name}({', '.join(index_columns)})" + ) + + if ( + self._vector_store_settings.vector_search_index_settings.additional_storing_columns + ): + statement += ( + "\n\tSTORING" + f" ({', '.join(self._vector_store_settings.vector_search_index_settings.additional_storing_columns)})" + ) + + statement += ( + f"\n\tWHERE {self._vector_store_settings.embedding_column} IS NOT NULL" + ) + + options_segments = [ + f"distance_type='{self._vector_store_settings.distance_type}'" + ] + + if ( + getattr( + self._vector_store_settings.vector_search_index_settings, + "tree_depth", + 0, + ) + > 0 + ): + tree_depth = ( + self._vector_store_settings.vector_search_index_settings.tree_depth + ) + if tree_depth not in (2, 3): + raise ValueError( + f"Vector search index settings: tree_depth: {tree_depth} must be" + " either 2 or 3" + ) + options_segments.append( + f"tree_depth={self._vector_store_settings.vector_search_index_settings.tree_depth}" + ) + + if ( + self._vector_store_settings.vector_search_index_settings.num_branches + is not None + and self._vector_store_settings.vector_search_index_settings.num_branches + > 0 + ): + options_segments.append( + f"num_branches={self._vector_store_settings.vector_search_index_settings.num_branches}" + ) + + if self._vector_store_settings.vector_search_index_settings.num_leaves > 0: + options_segments.append( + f"num_leaves={self._vector_store_settings.vector_search_index_settings.num_leaves}" + ) + + statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")" + + return statement.strip() + + def create_vector_store(self): + """Creates a new vector store within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. + """ + try: + ddl = self._create_vector_store_table_ddl(self._database.database_dialect) + logger.debug( + "Executing DDL statement to create vector store table: %s", ddl + ) + operation = self._database.update_ddl([ddl]) + + # Wait for completion + logger.info("Waiting for update database operation to complete...") + operation.result() + + logger.debug( + "Successfully created the vector store table: %s in Spanner" + " database: projects/%s/instances/%s/databases/%s", + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + except Exception as e: + logger.error("Failed to create the vector store. Error: %s", e) + raise + + def create_vector_search_index(self): + """Creates a vector search index within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. 
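+
+    Example DDL issued (illustrative; mirrors the index settings used in the
+    unit tests below):
+
+      CREATE VECTOR INDEX IF NOT EXISTS test_vector_index
+       ON test_vector_store(embedding)
+       WHERE embedding IS NOT NULL
+       OPTIONS(distance_type='COSINE', tree_depth=3, num_branches=10, num_leaves=20)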
+ """ + try: + if not self._vector_store_settings.vector_search_index_settings: + logger.warning("No vector search index settings found.") + return + + ddl = self._create_ann_vector_search_index_ddl( + self._database.database_dialect + ) + logger.debug( + "Executing DDL statement to create vector search index: %s", ddl + ) + operation = self._database.update_ddl([ddl]) + + # Wait for completion + logger.info("Waiting for update database operation to complete...") + operation.result() + + logger.debug( + "Successfully created the vector search index: %s in Spanner" + " database: projects/%s/instances/%s/databases/%s", + self._vector_store_settings.vector_search_index_settings.index_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + + except Exception as e: + logger.error("Failed to create the vector search index. Error: %s", e) + raise + + async def create_vector_store_async(self): + """Asynchronously creates a new vector store within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. + """ + await asyncio.to_thread(self.create_vector_store) + + async def create_vector_search_index_async(self): + """Asynchronously creates a vector search index within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. + """ + await asyncio.to_thread(self.create_vector_search_index) + + def _prepare_and_validate_batches( + self, + contents: Iterable[str], + additional_columns_values: Iterable[dict] | None, + batch_size: int, + ) -> Generator[tuple[list[str], list[dict], int], None, None]: + """Prepares and validates batches of contents and additional columns for insertion into the vector store.""" + content_iter = iter(contents) + + value_iter = ( + iter(additional_columns_values) + if additional_columns_values is not None + else itertools.repeat({}) + ) + + batches = iter(lambda: list(itertools.islice(content_iter, batch_size)), []) + + for index, content_batch in enumerate(batches): + actual_index = index * batch_size + value_batch = list(itertools.islice(value_iter, len(content_batch))) + + if len(value_batch) < len(content_batch): + raise ValueError( + f"Data mismatch: ended at index {actual_index}. Expected" + f" {len(content_batch)} values for this batch, but got" + f" {len(value_batch)}." + ) + + yield (content_batch, value_batch, actual_index) + + if additional_columns_values is not None: + if next(value_iter, None) is not None: + raise ValueError( + "additional_columns_values contains more items than contents." + ) + + def add_contents( + self, + contents: Iterable[str], + *, + additional_columns_values: Iterable[dict] | None = None, + batch_size: int = 200, + ): + """Adds text contents to the vector store. + + Performs batch embedding generation and subsequent insertion of the contents + into the vector store table in the Google Cloud Spanner database. + + Args: + contents (Iterable[str]): An iterable collection of string contents to + be added to the vector store. + additional_columns_values (Iterable[dict] | None): An optional iterable + of dictionary containing values for additional columns to be stored + with the content row. Keys must match column names. + batch_size (int): The maximum number of items to process and insert in a + single batch. Defaults to 200. 
+ """ + total_rows = 0 + try: + self._database.reload() + + cols = [ + c.name + for c in self._vector_store_settings.additional_columns_to_setup or [] + ] + + batch_gen = self._prepare_and_validate_batches( + contents, additional_columns_values, batch_size + ) + + for content_b, extra_b, batch_index in batch_gen: + logger.debug( + "Embedding content batch %d to %d (size: %d)...", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + embeddings = embed_contents( + self._vector_store_settings.vertex_ai_embedding_model_name, + content_b, + self._vector_store_settings.vector_length, + self._genai_client, + ) + + logger.debug( + "Committing batch mutation %d to %d (size: %d).", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + mutation_rows = [ + # [content, embedding, ...additional_columns] + [c, e, *map(extra.get, cols)] + for c, e, extra in zip(content_b, embeddings, extra_b) + ] + with self._database.batch() as batch: + batch.insert_or_update( + table=self._vector_store_settings.table_name, + columns=[ + self._vector_store_settings.content_column, + self._vector_store_settings.embedding_column, + ] + + cols, + values=mutation_rows, + ) + + total_rows += len(mutation_rows) + + logger.debug( + "Successfully added %d contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s", + total_rows, + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + + except Exception as e: + logger.error( + "Failed to finish adding contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s. Total" + " rows added: %d. Error: %s", + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + total_rows, + e, + ) + raise + + async def add_contents_async( + self, + contents: Iterable[str], + *, + additional_columns_values: Iterable[dict] | None = None, + batch_size: int = 200, + ): + """Asynchronously adds text contents to the vector store. + + Performs batch embedding generation and subsequent insertion of the contents + into the vector store table in the Google Cloud Spanner database. + + Args: + contents (Iterable[str]): An iterable collection of string contents to + be added to the vector store. + additional_columns_values (Iterable[dict] | None): An optional iterable + of dictionary containing values for additional columns to be stored + with the content row. Keys must match column names. + batch_size (int): The maximum number of items to process and insert in a + single batch. Defaults to 200. 
+ """ + total_rows = 0 + try: + await asyncio.to_thread(self._database.reload) + + cols = [ + c.name + for c in self._vector_store_settings.additional_columns_to_setup or [] + ] + + batch_gen = self._prepare_and_validate_batches( + contents, additional_columns_values, batch_size + ) + + for content_b, extra_b, batch_index in batch_gen: + logger.debug( + "Embedding content batch %d to %d (size: %d)...", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + embeddings = await embed_contents_async( + self._vector_store_settings.vertex_ai_embedding_model_name, + content_b, + self._vector_store_settings.vector_length, + self._genai_client, + ) + + logger.debug( + "Committing batch mutation %d to %d (size: %d).", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + mutation_rows = [ + # [content, embedding, ...additional_columns] + [c, e, *map(extra.get, cols)] + for c, e, extra in zip(content_b, embeddings, extra_b) + ] + + def _commit_batch(columns, rows_to_commit): + with self._database.batch() as batch: + batch.insert_or_update( + table=self._vector_store_settings.table_name, + columns=[ + self._vector_store_settings.content_column, + self._vector_store_settings.embedding_column, + ] + + columns, + values=rows_to_commit, + ) + + await asyncio.to_thread(_commit_batch, cols, mutation_rows) + total_rows += len(mutation_rows) + + logger.debug( + "Successfully added %d contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s", + total_rows, + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + + except Exception as e: + logger.error( + "Failed to finish adding contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s. Total" + " rows added: %d. Error: %s", + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + total_rows, + e, + ) + raise diff --git a/tests/unittests/tools/spanner/test_utils.py b/tests/unittests/tools/spanner/test_utils.py new file mode 100644 index 0000000000..56c04a2dce --- /dev/null +++ b/tests/unittests/tools/spanner/test_utils.py @@ -0,0 +1,384 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.spanner import utils as spanner_utils +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from google.adk.tools.spanner.settings import TableColumn +from google.adk.tools.spanner.settings import VectorSearchIndexSettings +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect +from google.cloud.spanner_v1 import batch as spanner_batch +from google.cloud.spanner_v1 import client as spanner_client_v1 +from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import instance as spanner_instance +import pytest + + +@pytest.fixture +def vector_store_settings(): + """Fixture for SpannerVectorStoreSettings.""" + return SpannerVectorStoreSettings( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + table_name="test_vector_store", + content_column="content", + embedding_column="embedding", + vector_length=768, + vertex_ai_embedding_model_name="textembedding", + ) + + +@pytest.fixture +def spanner_tool_settings(vector_store_settings): + """Fixture for SpannerToolSettings.""" + return SpannerToolSettings(vector_store_settings=vector_store_settings) + + +@pytest.fixture +def mock_spanner_database(): + """Fixture for a mocked spanner database.""" + mock_database = mock.create_autospec(spanner_database.Database, instance=True) + mock_database.exists.return_value = True + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + return mock_database + + +@pytest.fixture +def mock_spanner_instance(mock_spanner_database): + """Fixture for a mocked spanner instance.""" + mock_instance = mock.create_autospec(spanner_instance.Instance, instance=True) + mock_instance.exists.return_value = True + mock_instance.database.return_value = mock_spanner_database + return mock_instance + + +@pytest.fixture +def mock_spanner_client(mock_spanner_instance): + """Fixture for a mocked spanner client.""" + mock_client = mock.create_autospec(spanner_client_v1.Client, instance=True) + mock_client.instance.return_value = mock_spanner_instance + mock_client._client_info = mock.Mock(user_agent="test-agent") + return mock_client + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_successful( + mock_embed_contents, + spanner_tool_settings, + mock_spanner_client, + mock_spanner_database, + mocker, +): + """Test that add_contents successfully adds content.""" + mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]] + mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True) + mock_batch.__enter__.return_value = mock_batch + mock_spanner_database.batch.return_value = mock_batch + + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store._database = mock_spanner_database + contents = ["content1", "content2"] + vector_store.add_contents(contents=contents) + + mock_spanner_database.reload.assert_called_once() + mock_spanner_database.batch.assert_called_once() + mock_batch.insert_or_update.assert_called_once_with( + table="test_vector_store", + columns=["content", "embedding"], + values=[ + ["content1", [1.0, 2.0]], + ["content2", [3.0, 4.0]], + ], + ) + mock_embed_contents.assert_called_once_with( + "textembedding", contents, 768, 
mock.ANY + ) + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_with_metadata( + mock_embed_contents, + spanner_tool_settings, + mock_spanner_client, + mock_spanner_database, + mocker, +): + """Test that add_contents successfully adds content with metadata.""" + mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]] + mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True) + mock_batch.__enter__.return_value = mock_batch + mock_spanner_database.batch.return_value = mock_batch + spanner_tool_settings.vector_store_settings.additional_columns_to_setup = [ + TableColumn(name="metadata", type="JSON") + ] + + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store._database = mock_spanner_database + contents = ["content1", "content2"] + additional_columns_values = [ + {"metadata": {"meta1": "val1"}}, + {"metadata": {"meta2": "val2"}}, + ] + vector_store.add_contents( + contents=contents, + additional_columns_values=additional_columns_values, + ) + + mock_spanner_database.batch.assert_called_once() + mock_batch.insert_or_update.assert_called_once_with( + table="test_vector_store", + columns=["content", "embedding", "metadata"], + values=[ + ["content1", [1.0, 2.0], {"meta1": "val1"}], + ["content2", [3.0, 4.0], {"meta2": "val2"}], + ], + ) + + +def test_add_contents_empty_contents( + spanner_tool_settings, mock_spanner_client, mock_spanner_database +): + """Test that add_contents does nothing when contents is empty.""" + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store.add_contents(contents=[]) + mock_spanner_database.batch.assert_not_called() + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_additional_columns_list_mismatch( + mock_embed_contents, spanner_tool_settings, mock_spanner_client +): + """Test that add_contents raises an error if additional_columns_values and contents lengths differ.""" + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises( + ValueError, + match="additional_columns_values contains more items than contents.", + ): + vector_store.add_contents( + contents=["content1"], + additional_columns_values=[ + {"col1": "val1"}, + {"col1": "val2"}, + ], + ) + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_embedding_fails( + mock_embed_contents, spanner_tool_settings, mock_spanner_client +): + """Test that add_contents fails if embedding fails.""" + mock_embed_contents.side_effect = RuntimeError("Embedding failed") + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises(RuntimeError, match="Embedding failed"): + vector_store.add_contents(contents=["content1", "content2"]) + + +def test_init_raises_error_if_vector_store_settings_not_set(): + """Test that SpannerVectorStore raises an error if vector_store_settings is not set.""" + settings = SpannerToolSettings() + with 
pytest.raises(
+      ValueError, match="Spanner vector store settings are not set."
+  ):
+    spanner_utils.SpannerVectorStore(settings)
+
+
+@pytest.mark.parametrize(
+    "dialect, expected_ddl",
+    [
+        (
+            DatabaseDialect.GOOGLE_STANDARD_SQL,
+            (
+                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
+                " id STRING(36) DEFAULT (GENERATE_UUID()),\n"
+                " content STRING(MAX),\n"
+                " embedding ARRAY<FLOAT32>(vector_length=>768)\n"
+                ") PRIMARY KEY(id)"
+            ),
+        ),
+        (
+            DatabaseDialect.POSTGRESQL,
+            (
+                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
+                " id varchar(36) DEFAULT spanner.generate_uuid(),\n"
+                " content text,\n"
+                " embedding float4[] VECTOR LENGTH 768,\n"
+                " PRIMARY KEY(id)\n"
+                ")"
+            ),
+        ),
+    ],
+)
+def test_create_vector_store_table_ddl(
+    spanner_tool_settings, mock_spanner_client, dialect, expected_ddl
+):
+  """Test DDL creation for different SQL dialects."""
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    ddl = vector_store._create_vector_store_table_ddl(dialect)
+    assert ddl == expected_ddl
+
+
+def test_create_ann_vector_search_index_ddl_raises_error_for_postgresql(
+    spanner_tool_settings, vector_store_settings, mock_spanner_client
+):
+  """Test that creating an ANN index raises an error for PostgreSQL."""
+  vector_store_settings.vector_search_index_settings = mock.Mock()
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    with pytest.raises(
+        ValueError,
+        match="ANN is only supported for the Google Standard SQL dialect.",
+    ):
+      vector_store._create_ann_vector_search_index_ddl(
+          DatabaseDialect.POSTGRESQL
+      )
+
+
+def test_create_vector_store(
+    spanner_tool_settings, mock_spanner_client, mock_spanner_database
+):
+  """Test the vector store creation process."""
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    vector_store.create_vector_store()
+    mock_spanner_database.update_ddl.assert_called_once()
+    ddl_statement = mock_spanner_database.update_ddl.call_args[0][0]
+    assert "CREATE TABLE IF NOT EXISTS test_vector_store" in ddl_statement[0]
+
+
+def test_create_vector_search_index_no_settings(
+    spanner_tool_settings, mock_spanner_client, mock_spanner_database
+):
+  """Test that create_vector_search_index does nothing if settings are not present."""
+  spanner_tool_settings.vector_store_settings.vector_search_index_settings = (
+      None
+  )
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    vector_store.create_vector_search_index()
+    mock_spanner_database.update_ddl.assert_not_called()
+
+
+def test_create_vector_search_index_successful_google_sql(
+    spanner_tool_settings,
+    vector_store_settings,
+    mock_spanner_client,
+    mock_spanner_database,
+):
+  """Test that create_vector_search_index successfully creates index for Google SQL."""
+  mock_spanner_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
+  vector_store_settings.vector_search_index_settings = (
+      VectorSearchIndexSettings(
+          index_name="test_vector_index",
+          tree_depth=3,
+          num_branches=10,
num_leaves=20, + ) + ) + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store.create_vector_search_index() + mock_spanner_database.update_ddl.assert_called_once() + ddl_statement = mock_spanner_database.update_ddl.call_args[0][0] + expected_ddl = ( + "CREATE VECTOR INDEX IF NOT EXISTS test_vector_index\n" + "\tON test_vector_store(embedding)\n" + "\tWHERE embedding IS NOT NULL\n" + "\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=10, " + "num_leaves=20)" + ) + assert ddl_statement[0] == expected_ddl + + +def test_create_vector_search_index_fails( + spanner_tool_settings, + vector_store_settings, + mock_spanner_client, + mock_spanner_database, +): + """Test that create_vector_search_index raises an error if DDL execution fails.""" + mock_spanner_database.update_ddl.side_effect = RuntimeError("DDL failed") + vector_store_settings.vector_search_index_settings = ( + VectorSearchIndexSettings(index_name="test_vector_index") + ) + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises(RuntimeError, match="DDL failed"): + vector_store.create_vector_search_index() From 2bd984adb3c553438fa1fe9ac05b3f96bd508d5f Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 9 Jan 2026 17:31:42 -0800 Subject: [PATCH 2/3] feat: Add option to send full history to stateless RemoteA2aAgents Introduces `full_history_when_stateless` to RemoteA2aAgent. When True, stateless agents will receive all session events on each request, instead of only events since their last reply. This allows stateless agents to have access to the complete conversation history. PiperOrigin-RevId: 854400798 --- src/google/adk/agents/remote_a2a_agent.py | 15 +- .../unittests/agents/test_remote_a2a_agent.py | 151 ++++++++++++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 10de7c14d3..167328847c 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -126,6 +126,7 @@ def __init__( a2a_request_meta_provider: Optional[ Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, + full_history_when_stateless: bool = False, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -142,6 +143,10 @@ def __init__( a2a_request_meta_provider: Optional callable that takes InvocationContext and A2AMessage and returns a metadata object to attach to the A2A request. + full_history_when_stateless: If True, stateless agents (those that do not + return Tasks or context IDs) will receive all session events on every + request. If False, the default behavior of sending only events since the + last reply from the agent will be used. 
**kwargs: Additional arguments passed to BaseAgent Raises: @@ -168,6 +173,7 @@ def __init__( self._a2a_part_converter = a2a_part_converter self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider + self._full_history_when_stateless = full_history_when_stateless # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -365,7 +371,14 @@ def _construct_message_parts_from_session( if event.custom_metadata: metadata = event.custom_metadata context_id = metadata.get(A2A_METADATA_PREFIX + "context_id") - break + # Historical note: this behavior originally always applied, regardless + # of whether the agent was stateful or stateless. However, only stateful + # agents can be expected to have previous events in the remote session. + # For backwards compatibility, we maintain this behavior when + # _full_history_when_stateless is false (the default) or if the agent + # is stateful (i.e. returned a context ID). + if not self._full_history_when_stateless or context_id: + break events_to_process.append(event) for event in reversed(events_to_process): diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 8bd4a22f20..d395a5516f 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -665,6 +665,157 @@ def test_construct_message_parts_from_session_empty_events(self): assert parts == [] assert context_id is None + def test_construct_message_parts_from_session_stops_on_agent_reply(self): + """Test message parts construction stops on agent reply by default.""" + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateless_full_history(self): + """Test full history for stateless agent when enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + 
user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 3 + assert parts[0].text == "User 1" + assert parts[1].text == "Agent 1" + assert parts[2].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateful_partial_history(self): + """Test partial history for stateful agent when full history is enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = {A2A_METADATA_PREFIX + "context_id": "ctx-1"} + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id == "ctx-1" + @pytest.mark.asyncio async def test_handle_a2a_response_success_with_message(self): """Test successful A2A response handling with message.""" From 94d48fce32a1f07cef967d50e82f2b1975b4abd9 Mon Sep 17 00:00:00 2001 From: Dinesh Thumma <160909147+DineshThumma9@users.noreply.github.com> Date: Fri, 9 Jan 2026 18:20:20 -0800 Subject: [PATCH 3/3] fix: Database reserved keyword issue The fix will use quotes to escape "key", which is column name in the metadata table. Should work for different database types. Merge https://github.com/google/adk-python/pull/4106 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/4106 from DineshThumma9:fix/mysql-reserved-keyword-issue e39d0d02f3695d6890bc3267417b5dad58f7e8ee PiperOrigin-RevId: 854411915 --- .../sessions/migration/_schema_check_utils.py | 7 ++++++- .../sessions/migration/test_database_schema.py | 18 ++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/google/adk/sessions/migration/_schema_check_utils.py b/src/google/adk/sessions/migration/_schema_check_utils.py index 284f82afe7..249161c84c 100644 --- a/src/google/adk/sessions/migration/_schema_check_utils.py +++ b/src/google/adk/sessions/migration/_schema_check_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Database schema version check utility.""" + from __future__ import annotations import logging @@ -32,8 +33,11 @@ def _get_schema_version_impl(inspector, connection) -> str: """Gets DB schema version using inspector and connection.""" if inspector.has_table("adk_internal_metadata"): try: + key_col = inspector.dialect.identifier_preparer.quote("key") result = connection.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + text( + f"SELECT value FROM adk_internal_metadata WHERE {key_col} = :key" + ), {"key": SCHEMA_VERSION_KEY}, ).fetchone() if result: @@ -49,6 +53,7 @@ def _get_schema_version_impl(inspector, connection) -> str: e, ) raise + # Metadata table doesn't exist, check for v0 schema. # V0 schema has an 'events' table with an 'actions' column. if inspector.has_table("events"): diff --git a/tests/unittests/sessions/migration/test_database_schema.py b/tests/unittests/sessions/migration/test_database_schema.py index 239da2f1e2..d08bb97ba0 100644 --- a/tests/unittests/sessions/migration/test_database_schema.py +++ b/tests/unittests/sessions/migration/test_database_schema.py @@ -51,12 +51,18 @@ async def test_new_db_uses_latest_schema(tmp_path): lambda sync_conn: inspect(sync_conn).has_table('adk_internal_metadata') ) assert has_metadata_table - schema_version = await conn.run_sync( - lambda sync_conn: sync_conn.execute( - text('SELECT value FROM adk_internal_metadata WHERE key = :key'), - {'key': _schema_check_utils.SCHEMA_VERSION_KEY}, - ).scalar_one_or_none() - ) + + def get_schema_version(sync_conn): + inspector = inspect(sync_conn) + key_col = inspector.dialect.identifier_preparer.quote('key') + return sync_conn.execute( + text( + f'SELECT value FROM adk_internal_metadata WHERE {key_col} = :key' + ), + {'key': _schema_check_utils.SCHEMA_VERSION_KEY}, + ).scalar_one_or_none() + + schema_version = await conn.run_sync(get_schema_version) assert schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION # Verify events table columns for v1