diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 10de7c14d3..167328847c 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -126,6 +126,7 @@ def __init__( a2a_request_meta_provider: Optional[ Callable[[InvocationContext, A2AMessage], dict[str, Any]] ] = None, + full_history_when_stateless: bool = False, **kwargs: Any, ) -> None: """Initialize RemoteA2aAgent. @@ -142,6 +143,10 @@ def __init__( a2a_request_meta_provider: Optional callable that takes InvocationContext and A2AMessage and returns a metadata object to attach to the A2A request. + full_history_when_stateless: If True, stateless agents (those that do not + return Tasks or context IDs) will receive all session events on every + request. If False, the default behavior of sending only events since the + last reply from the agent will be used. **kwargs: Additional arguments passed to BaseAgent Raises: @@ -168,6 +173,7 @@ def __init__( self._a2a_part_converter = a2a_part_converter self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory self._a2a_request_meta_provider = a2a_request_meta_provider + self._full_history_when_stateless = full_history_when_stateless # Validate and store agent card reference if isinstance(agent_card, AgentCard): @@ -365,7 +371,14 @@ def _construct_message_parts_from_session( if event.custom_metadata: metadata = event.custom_metadata context_id = metadata.get(A2A_METADATA_PREFIX + "context_id") - break + # Historical note: this behavior originally always applied, regardless + # of whether the agent was stateful or stateless. However, only stateful + # agents can be expected to have previous events in the remote session. + # For backwards compatibility, we maintain this behavior when + # _full_history_when_stateless is false (the default) or if the agent + # is stateful (i.e. returned a context ID). + if not self._full_history_when_stateless or context_id: + break events_to_process.append(event) for event in reversed(events_to_process): diff --git a/src/google/adk/features/_feature_registry.py b/src/google/adk/features/_feature_registry.py index 2ab0130639..c6584e85df 100644 --- a/src/google/adk/features/_feature_registry.py +++ b/src/google/adk/features/_feature_registry.py @@ -41,6 +41,7 @@ class FeatureName(str, Enum): PUBSUB_TOOLSET = "PUBSUB_TOOLSET" SPANNER_TOOLSET = "SPANNER_TOOLSET" SPANNER_TOOL_SETTINGS = "SPANNER_TOOL_SETTINGS" + SPANNER_VECTOR_STORE = "SPANNER_VECTOR_STORE" TOOL_CONFIG = "TOOL_CONFIG" TOOL_CONFIRMATION = "TOOL_CONFIRMATION" @@ -120,6 +121,9 @@ class FeatureConfig: FeatureName.SPANNER_TOOL_SETTINGS: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), + FeatureName.SPANNER_VECTOR_STORE: FeatureConfig( + FeatureStage.EXPERIMENTAL, default_on=True + ), FeatureName.TOOL_CONFIG: FeatureConfig( FeatureStage.EXPERIMENTAL, default_on=True ), diff --git a/src/google/adk/sessions/migration/_schema_check_utils.py b/src/google/adk/sessions/migration/_schema_check_utils.py index 284f82afe7..249161c84c 100644 --- a/src/google/adk/sessions/migration/_schema_check_utils.py +++ b/src/google/adk/sessions/migration/_schema_check_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Database schema version check utility.""" + from __future__ import annotations import logging @@ -32,8 +33,11 @@ def _get_schema_version_impl(inspector, connection) -> str: """Gets DB schema version using inspector and connection.""" if inspector.has_table("adk_internal_metadata"): try: + key_col = inspector.dialect.identifier_preparer.quote("key") result = connection.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + text( + f"SELECT value FROM adk_internal_metadata WHERE {key_col} = :key" + ), {"key": SCHEMA_VERSION_KEY}, ).fetchone() if result: @@ -49,6 +53,7 @@ def _get_schema_version_impl(inspector, connection) -> str: e, ) raise + # Metadata table doesn't exist, check for v0 schema. # V0 schema has an 'events' table with an 'actions' column. if inspector.has_table("events"): diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index ae7f6371aa..6ca693b235 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -55,6 +55,74 @@ class QueryResultMode(Enum): """ +class TableColumn(BaseModel): + """Represents column configuration, to be used as part of create DDL statement for a new vector store table set up.""" + + name: str + """Required. The name of the column.""" + + type: str + """Required. The type of the column. + + For example, + + - GoogleSQL: 'STRING(MAX)', 'INT64', 'FLOAT64', 'BOOL', etc. + - PostgreSQL: 'text', 'int8', 'float8', 'boolean', etc. + """ + + is_nullable: bool = True + """Optional. Whether the column is nullable. By default, the column is nullable.""" + + +class VectorSearchIndexSettings(BaseModel): + """Settings for the index for use with Approximate Nearest Neighbor (ANN) vector similarity search.""" + + index_name: str + """Required. The name of the vector similarity search index.""" + + additional_key_columns: Optional[list[str]] = None + """Optional. The list of the additional key column names in the vector similarity search index. + + To further speed up filtering for highly selective filtering columns, organize + them as additional keys in the vector index after the embedding column. + For example: `category` as additional key column. + `CREATE VECTOR INDEX ON documents(embedding, category);` + """ + + additional_storing_columns: Optional[list[str]] = None + """Optional. The list of the storing column names in the vector similarity search index. + + This enables filtering while walking the vector index, removing unqualified + rows early. + For example: `category` as storing column. + `CREATE VECTOR INDEX ON documents(embedding) STORING (category);` + """ + + tree_depth: int = 2 + """Required. The tree depth (level). This value can be either 2 or 3. + + A tree with 2 levels only has leaves (num_leaves) as nodes. + If the dataset has more than 100 million rows, + then you can use a tree with 3 levels and add branches (num_branches) to + further partition the dataset. + """ + + num_leaves: int = 1000 + """Required. The number of leaves (i.e. potential partitions) for the vector data. + + You can designate num_leaves for trees with 2 or 3 levels. + We recommend that the number of leaves is number_of_rows_in_dataset/1000. + """ + + num_branches: Optional[int] = None + """Optional. The number of branches to further parititon the vector data. + + You can only designate num_branches for trees with 3 levels. 
+  The number of branches must be fewer than the number of leaves.
+  We recommend that the number of branches is between 1000 and sqrt(number_of_rows_in_dataset).
+  """
+
+
 class SpannerVectorStoreSettings(BaseModel):
   """Settings for Spanner Vector Store.
 
@@ -86,18 +154,19 @@ class SpannerVectorStoreSettings(BaseModel):
   vertex_ai_embedding_model_name: str
   """Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search.
 
-  For example, 'text-embedding-005'.
-  Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
-  Otherwise, a runtime error might be raised during a query.
+  For example, 'text-embedding-005'.
+
+  Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field.
+  Otherwise, a runtime error might be raised during a query.
   """
 
-  selected_columns: List[str] = []
+  selected_columns: list[str] = []
   """Required. The vector store table columns to return in the vector similarity search result.
 
-  By default, only the `content_column` value and the distance value are returned.
-  If sepecified, the list of selected columns and the distance value are returned.
-  For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value.
+  By default, only the `content_column` value and the distance value are returned.
+  If specified, the list of selected columns and the distance value are returned.
+  For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of the 'col1' and 'col2' columns and the distance value.
   """
 
   nearest_neighbors_algorithm: NearestNeighborsAlgorithm = (
@@ -105,8 +174,8 @@
   )
   """The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS.
 
-  For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
-  For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
+  For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors
+  For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors
   """
 
   top_k: int = 4
@@ -118,16 +187,41 @@
 
   num_leaves_to_search: Optional[int] = None
   """Optional. This option specifies how many leaf nodes of the index are searched.
 
-  Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
-  For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
+  Note: This option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
+  For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices
   """
 
   additional_filter: Optional[str] = None
   """Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query."""
 
+  vector_search_index_settings: Optional[VectorSearchIndexSettings] = None
+  """Optional.
 Settings for the index used with Approximate Nearest Neighbor (ANN) search in the vector store.
+
+  Note: This option is only required when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS.
+  For more details, see https://docs.cloud.google.com/spanner/docs/vector-indexes
+  """
+
+  additional_columns_to_setup: Optional[list[TableColumn]] = None
+  """Optional. A list of supplemental columns to be created when initializing a new vector store table or inserting content rows.
+
+  Note: This configuration is only utilized during the initial table setup
+  or when inserting content rows.
+  """
+
+  primary_key_columns: Optional[list[str]] = None
+  """Optional. Specifies the column names to be used as the primary key for a new vector store table.
+
+  If provided, every column name listed here must be defined within
+  `additional_columns_to_setup`. If this field is omitted (set to `None`),
+  the table defaults to a single primary key column named `id` that
+  automatically generates UUIDs for each entry.
+
+  Note: This field is only used during the creation phase of a new vector store.
+  """
+
   @model_validator(mode="after")
   def __post_init__(self):
-    """Validate the embedding settings."""
+    """Validate the vector store settings."""
     if not self.vector_length or self.vector_length <= 0:
       raise ValueError(
           "Invalid vector length in the Spanner vector store settings."
@@ -136,6 +230,17 @@ def __post_init__(self):
     if not self.selected_columns:
       self.selected_columns = [self.content_column]
 
+    if self.primary_key_columns:
+      cols = {self.content_column, self.embedding_column}
+      if self.additional_columns_to_setup:
+        cols.update({c.name for c in self.additional_columns_to_setup})
+
+      for pk in self.primary_key_columns:
+        if pk not in cols:
+          raise ValueError(
+              f"Primary key column '{pk}' not found in column definitions."
+          )
+
     return self
diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py
index adde521954..ff2531aaf5 100644
--- a/src/google/adk/tools/spanner/utils.py
+++ b/src/google/adk/tools/spanner/utils.py
@@ -14,16 +14,31 @@
 
 from __future__ import annotations
 
+import asyncio
+import itertools
 import json
+import logging
+from typing import Generator
+from typing import Iterable
 from typing import Optional
+from typing import TYPE_CHECKING
 
 from google.auth.credentials import Credentials
 from google.cloud.spanner_admin_database_v1.types import DatabaseDialect
 
 from . import client
+from ...features import experimental
+from ...features import FeatureName
 from ..tool_context import ToolContext
 from .settings import QueryResultMode
 from .settings import SpannerToolSettings
+from .settings import SpannerVectorStoreSettings
+
+if TYPE_CHECKING:
+  from google.cloud import spanner
+  from google.genai import Client
+
+logger = logging.getLogger("google_adk." 
+ __name__)
 
 DEFAULT_MAX_EXECUTED_QUERY_RESULT_ROWS = 50
@@ -115,17 +130,18 @@
 def embed_contents(
     vertex_ai_embedding_model_name: str,
     contents: list[str],
     output_dimensionality: Optional[int] = None,
+    genai_client: Client | None = None,
 ) -> list[list[float]]:
   """Embed the given contents into list of vectors using the Vertex AI embedding model endpoint."""
   try:
     from google.genai import Client
     from google.genai.types import EmbedContentConfig
 
-    client = Client()
+    genai_client = genai_client or Client()
     config = EmbedContentConfig()
     if output_dimensionality:
       config.output_dimensionality = output_dimensionality
-    response = client.models.embed_content(
+    response = genai_client.models.embed_content(
         model=vertex_ai_embedding_model_name,
         contents=contents,
         config=config,
@@ -133,3 +149,601 @@ def embed_contents(
     return [list(e.values) for e in response.embeddings]
   except Exception as ex:
     raise RuntimeError(f"Failed to embed content: {ex!r}") from ex
+
+
+async def embed_contents_async(
+    vertex_ai_embedding_model_name: str,
+    contents: list[str],
+    output_dimensionality: Optional[int] = None,
+    genai_client: Client | None = None,
+) -> list[list[float]]:
+  """Embed the given contents into a list of vectors using the Vertex AI embedding model endpoint."""
+  try:
+    from google.genai import Client
+    from google.genai.types import EmbedContentConfig
+
+    genai_client = genai_client or Client()
+    config = EmbedContentConfig()
+    if output_dimensionality:
+      config.output_dimensionality = output_dimensionality
+    response = await genai_client.aio.models.embed_content(
+        model=vertex_ai_embedding_model_name,
+        contents=contents,
+        config=config,
+    )
+    return [list(e.values) for e in response.embeddings]
+  except Exception as ex:
+    raise RuntimeError(f"Failed to embed content: {ex!r}") from ex
+
+
+@experimental(FeatureName.SPANNER_VECTOR_STORE)
+class SpannerVectorStore:
+  """A class for orchestrating a Spanner vector store.
+
+  This class provides utility functions for setting up and adding contents to a
+  vector store table in a Google Cloud Spanner database, based on the given
+  Spanner tool settings.
+  """
+
+  DEFAULT_VECTOR_STORE_ID_COLUMN_NAME = "id"
+  SPANNER_VECTOR_STORE_USER_AGENT = "adk-spanner-vector-store"
+
+  def __init__(
+      self,
+      settings: SpannerToolSettings,
+      credentials: Credentials | None = None,
+      spanner_client: spanner.Client | None = None,
+      genai_client: Client | None = None,
+  ):
+    """Initializes the SpannerVectorStore with validated settings and clients.
+
+    This constructor sets up the connection to a specific Spanner database and
+    configures the necessary clients for vector operations.
+
+    Args:
+      settings (SpannerToolSettings): The settings for the tool.
+      credentials (Credentials | None): Credentials for Spanner operations. This
+        is used to initialize a new Spanner client only if `spanner_client`
+        is not explicitly provided.
+      spanner_client (spanner.Client | None): A pre-configured `spanner.Client`
+        instance. If not provided, a new client will be created.
+      genai_client (Client | None): Google GenAI client used for
+        generating vector embeddings.
+    """
+
+    if not settings.vector_store_settings:
+      raise ValueError("Spanner vector store settings are not set.")
+
+    self._settings = settings
+
+    if not spanner_client:
+      self._spanner_client = client.get_spanner_client(
+          project=self._vector_store_settings.project_id,
+          credentials=credentials,
+      )
+    else:
+      self._spanner_client = spanner_client
+      client_user_agent = self._spanner_client._client_info.user_agent
+      if not client_user_agent:
+        self._spanner_client._client_info.user_agent = client.USER_AGENT
+      elif client.USER_AGENT not in client_user_agent:
+        self._spanner_client._client_info.user_agent = " ".join(
+            [client_user_agent, client.USER_AGENT]
+        )
+    self._spanner_client._client_info.user_agent = " ".join([
+        self._spanner_client._client_info.user_agent,
+        self.SPANNER_VECTOR_STORE_USER_AGENT,
+    ])
+
+    instance = self._spanner_client.instance(
+        self._vector_store_settings.instance_id
+    )
+    if not instance.exists():
+      raise ValueError(
+          "Instance id {} doesn't exist.".format(
+              self._vector_store_settings.instance_id
+          )
+      )
+    self._database = instance.database(self._vector_store_settings.database_id)
+    if not self._database.exists():
+      raise ValueError(
+          "Database id {} doesn't exist.".format(
+              self._vector_store_settings.database_id
+          )
+      )
+
+    self._genai_client = genai_client
+
+  @property
+  def _vector_store_settings(self) -> SpannerVectorStoreSettings:
+    """Returns the Spanner vector store settings."""
+
+    if self._settings.vector_store_settings is None:
+      raise ValueError("Spanner vector store settings are not set.")
+    return self._settings.vector_store_settings
+
+  def _create_vector_store_table_ddl(
+      self,
+      dialect: DatabaseDialect,
+  ) -> str:
+    """Creates the DDL statement necessary to define a vector store table in Spanner.
+
+    The vector store table is created based on the given settings.
+    - **id_column** (STRING or text): The default primary key, typically a UUID.
+      Note: This column is only included in the DDL when `primary_key_columns`
+      is not specified in the settings.
+    - **content_column** (STRING or text): The source text content used to
+      generate the embedding.
+    - **embedding_column** (ARRAY<FLOAT32> or float4[]): The vector embedding
+      column corresponding to the content.
+    - **additional_columns_to_setup** (provided in the settings): Additional
+      columns to be added to the vector store table.
+
+    Args:
+      dialect: The database dialect (e.g., GOOGLE_STANDARD_SQL or POSTGRESQL)
+        governing the DDL syntax.
+
+    Returns:
+      A DDL statement string defining the vector store table.
+    """
+
+    primary_key_columns = self._vector_store_settings.primary_key_columns or [
+        self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME
+    ]
+
+    column_definitions = []
+
+    if self._vector_store_settings.primary_key_columns is None:
+      if dialect == DatabaseDialect.POSTGRESQL:
+        column_definitions.append(
+            f"{self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME} varchar(36) DEFAULT"
+            " spanner.generate_uuid()"
+        )
+      else:
+        column_definitions.append(
+            f"{self.DEFAULT_VECTOR_STORE_ID_COLUMN_NAME} STRING(36) DEFAULT"
+            " (GENERATE_UUID())"
+        )
+
+    # Additional Columns
+    if self._vector_store_settings.additional_columns_to_setup:
+      for column in self._vector_store_settings.additional_columns_to_setup:
+        null_stmt = "" if column.is_nullable else " NOT NULL"
+        column_definitions.append(f"{column.name} {column.type}{null_stmt}")
+
+    # Content and Embedding Columns
+    if dialect == DatabaseDialect.POSTGRESQL:
+      column_definitions.append(
+          f"{self._vector_store_settings.content_column} text"
+      )
+      column_definitions.append(
+          f"{self._vector_store_settings.embedding_column} float4[] "
+          f"VECTOR LENGTH {self._vector_store_settings.vector_length}"
+      )
+    else:
+      column_definitions.append(
+          f"{self._vector_store_settings.content_column} STRING(MAX)"
+      )
+      column_definitions.append(
+          f"{self._vector_store_settings.embedding_column} "
+          f"ARRAY<FLOAT32>(vector_length=>{self._vector_store_settings.vector_length})"
+      )
+
+    inner_ddl = ",\n  ".join(column_definitions)
+    pk_stmt = ", ".join(primary_key_columns)
+
+    if dialect == DatabaseDialect.POSTGRESQL:
+      return (
+          f"CREATE TABLE IF NOT EXISTS {self._vector_store_settings.table_name}"
+          f" (\n  {inner_ddl},\n  PRIMARY KEY({pk_stmt})\n)"
+      )
+    else:
+      return (
+          f"CREATE TABLE IF NOT EXISTS {self._vector_store_settings.table_name}"
+          f" (\n  {inner_ddl}\n) PRIMARY KEY({pk_stmt})"
+      )
+
+  def _create_ann_vector_search_index_ddl(
+      self,
+      dialect: DatabaseDialect,
+  ) -> str:
+    """Creates the DDL statement for an ANN vector search index.
+
+    Args:
+      dialect: The database dialect (e.g., GOOGLE_STANDARD_SQL or POSTGRESQL)
+        governing the DDL syntax.
+
+    Returns:
+      A DDL statement string to create the vector search index.
+    """
+
+    # This is only required when the nearest neighbors search algorithm is
+    # APPROXIMATE_NEAREST_NEIGHBORS.
+    if not self._vector_store_settings.vector_search_index_settings:
+      raise ValueError("Vector search index settings are not set.")
+
+    if dialect != DatabaseDialect.GOOGLE_STANDARD_SQL:
+      raise ValueError(
+          "ANN is only supported for the Google Standard SQL dialect."
+ ) + + index_columns = [self._vector_store_settings.embedding_column] + if ( + self._vector_store_settings.vector_search_index_settings.additional_key_columns + ): + index_columns.extend( + self._vector_store_settings.vector_search_index_settings.additional_key_columns + ) + + statement = ( + "CREATE VECTOR INDEX IF NOT EXISTS" + f" {self._vector_store_settings.vector_search_index_settings.index_name}\n\tON" + f" {self._vector_store_settings.table_name}({', '.join(index_columns)})" + ) + + if ( + self._vector_store_settings.vector_search_index_settings.additional_storing_columns + ): + statement += ( + "\n\tSTORING" + f" ({', '.join(self._vector_store_settings.vector_search_index_settings.additional_storing_columns)})" + ) + + statement += ( + f"\n\tWHERE {self._vector_store_settings.embedding_column} IS NOT NULL" + ) + + options_segments = [ + f"distance_type='{self._vector_store_settings.distance_type}'" + ] + + if ( + getattr( + self._vector_store_settings.vector_search_index_settings, + "tree_depth", + 0, + ) + > 0 + ): + tree_depth = ( + self._vector_store_settings.vector_search_index_settings.tree_depth + ) + if tree_depth not in (2, 3): + raise ValueError( + f"Vector search index settings: tree_depth: {tree_depth} must be" + " either 2 or 3" + ) + options_segments.append( + f"tree_depth={self._vector_store_settings.vector_search_index_settings.tree_depth}" + ) + + if ( + self._vector_store_settings.vector_search_index_settings.num_branches + is not None + and self._vector_store_settings.vector_search_index_settings.num_branches + > 0 + ): + options_segments.append( + f"num_branches={self._vector_store_settings.vector_search_index_settings.num_branches}" + ) + + if self._vector_store_settings.vector_search_index_settings.num_leaves > 0: + options_segments.append( + f"num_leaves={self._vector_store_settings.vector_search_index_settings.num_leaves}" + ) + + statement += "\n\tOPTIONS(" + ", ".join(options_segments) + ")" + + return statement.strip() + + def create_vector_store(self): + """Creates a new vector store within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. + """ + try: + ddl = self._create_vector_store_table_ddl(self._database.database_dialect) + logger.debug( + "Executing DDL statement to create vector store table: %s", ddl + ) + operation = self._database.update_ddl([ddl]) + + # Wait for completion + logger.info("Waiting for update database operation to complete...") + operation.result() + + logger.debug( + "Successfully created the vector store table: %s in Spanner" + " database: projects/%s/instances/%s/databases/%s", + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + except Exception as e: + logger.error("Failed to create the vector store. Error: %s", e) + raise + + def create_vector_search_index(self): + """Creates a vector search index within the Google Cloud Spanner database. + + Raises: + RuntimeError: If the DDL statement execution against Spanner fails. 
+    """
+    try:
+      if not self._vector_store_settings.vector_search_index_settings:
+        logger.warning("No vector search index settings found.")
+        return
+
+      ddl = self._create_ann_vector_search_index_ddl(
+          self._database.database_dialect
+      )
+      logger.debug(
+          "Executing DDL statement to create vector search index: %s", ddl
+      )
+      operation = self._database.update_ddl([ddl])
+
+      # Wait for completion
+      logger.info("Waiting for update database operation to complete...")
+      operation.result()
+
+      logger.debug(
+          "Successfully created the vector search index: %s in Spanner"
+          " database: projects/%s/instances/%s/databases/%s",
+          self._vector_store_settings.vector_search_index_settings.index_name,
+          self._vector_store_settings.project_id,
+          self._vector_store_settings.instance_id,
+          self._vector_store_settings.database_id,
+      )
+
+    except Exception as e:
+      logger.error("Failed to create the vector search index. Error: %s", e)
+      raise
+
+  async def create_vector_store_async(self):
+    """Asynchronously creates a new vector store within the Google Cloud Spanner database.
+
+    Raises:
+      RuntimeError: If the DDL statement execution against Spanner fails.
+    """
+    await asyncio.to_thread(self.create_vector_store)
+
+  async def create_vector_search_index_async(self):
+    """Asynchronously creates a vector search index within the Google Cloud Spanner database.
+
+    Raises:
+      RuntimeError: If the DDL statement execution against Spanner fails.
+    """
+    await asyncio.to_thread(self.create_vector_search_index)
+
+  def _prepare_and_validate_batches(
+      self,
+      contents: Iterable[str],
+      additional_columns_values: Iterable[dict] | None,
+      batch_size: int,
+  ) -> Generator[tuple[list[str], list[dict], int], None, None]:
+    """Prepares and validates batches of contents and additional columns for insertion into the vector store."""
+    content_iter = iter(contents)
+
+    value_iter = (
+        iter(additional_columns_values)
+        if additional_columns_values is not None
+        else itertools.repeat({})
+    )
+
+    batches = iter(lambda: list(itertools.islice(content_iter, batch_size)), [])
+
+    for index, content_batch in enumerate(batches):
+      actual_index = index * batch_size
+      value_batch = list(itertools.islice(value_iter, len(content_batch)))
+
+      if len(value_batch) < len(content_batch):
+        raise ValueError(
+            f"Data mismatch: ended at index {actual_index}. Expected"
+            f" {len(content_batch)} values for this batch, but got"
+            f" {len(value_batch)}."
+        )
+
+      yield (content_batch, value_batch, actual_index)
+
+    if additional_columns_values is not None:
+      if next(value_iter, None) is not None:
+        raise ValueError(
+            "additional_columns_values contains more items than contents."
+        )
+
+  def add_contents(
+      self,
+      contents: Iterable[str],
+      *,
+      additional_columns_values: Iterable[dict] | None = None,
+      batch_size: int = 200,
+  ):
+    """Adds text contents to the vector store.
+
+    Performs batch embedding generation and subsequent insertion of the contents
+    into the vector store table in the Google Cloud Spanner database.
+
+    Args:
+      contents (Iterable[str]): An iterable collection of string contents to
+        be added to the vector store.
+      additional_columns_values (Iterable[dict] | None): An optional iterable
+        of dictionaries containing values for additional columns to be stored
+        with the content row. Keys must match column names.
+      batch_size (int): The maximum number of items to process and insert in a
+        single batch. Defaults to 200.
+    """
+    total_rows = 0
+    try:
+      self._database.reload()
+
+      cols = [
+          c.name
+          for c in self._vector_store_settings.additional_columns_to_setup or []
+      ]
+
+      batch_gen = self._prepare_and_validate_batches(
+          contents, additional_columns_values, batch_size
+      )
+
+      for content_b, extra_b, batch_index in batch_gen:
+        logger.debug(
+            "Embedding content batch %d to %d (size: %d)...",
+            batch_index,
+            batch_index + len(content_b),
+            len(content_b),
+        )
+        embeddings = embed_contents(
+            self._vector_store_settings.vertex_ai_embedding_model_name,
+            content_b,
+            self._vector_store_settings.vector_length,
+            self._genai_client,
+        )
+
+        logger.debug(
+            "Committing batch mutation %d to %d (size: %d).",
+            batch_index,
+            batch_index + len(content_b),
+            len(content_b),
+        )
+        mutation_rows = [
+            # [content, embedding, ...additional_columns]
+            [c, e, *map(extra.get, cols)]
+            for c, e, extra in zip(content_b, embeddings, extra_b)
+        ]
+        with self._database.batch() as batch:
+          batch.insert_or_update(
+              table=self._vector_store_settings.table_name,
+              columns=[
+                  self._vector_store_settings.content_column,
+                  self._vector_store_settings.embedding_column,
+              ]
+              + cols,
+              values=mutation_rows,
+          )
+
+        total_rows += len(mutation_rows)
+
+      logger.debug(
+          "Successfully added %d contents to the vector store table: %s in"
+          " Spanner database: projects/%s/instances/%s/databases/%s",
+          total_rows,
+          self._vector_store_settings.table_name,
+          self._vector_store_settings.project_id,
+          self._vector_store_settings.instance_id,
+          self._vector_store_settings.database_id,
+      )
+
+    except Exception as e:
+      logger.error(
+          "Failed to finish adding contents to the vector store table: %s in"
+          " Spanner database: projects/%s/instances/%s/databases/%s. Total"
+          " rows added: %d. Error: %s",
+          self._vector_store_settings.table_name,
+          self._vector_store_settings.project_id,
+          self._vector_store_settings.instance_id,
+          self._vector_store_settings.database_id,
+          total_rows,
+          e,
+      )
+      raise
+
+  async def add_contents_async(
+      self,
+      contents: Iterable[str],
+      *,
+      additional_columns_values: Iterable[dict] | None = None,
+      batch_size: int = 200,
+  ):
+    """Asynchronously adds text contents to the vector store.
+
+    Performs batch embedding generation and subsequent insertion of the contents
+    into the vector store table in the Google Cloud Spanner database.
+
+    Args:
+      contents (Iterable[str]): An iterable collection of string contents to
+        be added to the vector store.
+      additional_columns_values (Iterable[dict] | None): An optional iterable
+        of dictionaries containing values for additional columns to be stored
+        with the content row. Keys must match column names.
+      batch_size (int): The maximum number of items to process and insert in a
+        single batch. Defaults to 200.
+ """ + total_rows = 0 + try: + await asyncio.to_thread(self._database.reload) + + cols = [ + c.name + for c in self._vector_store_settings.additional_columns_to_setup or [] + ] + + batch_gen = self._prepare_and_validate_batches( + contents, additional_columns_values, batch_size + ) + + for content_b, extra_b, batch_index in batch_gen: + logger.debug( + "Embedding content batch %d to %d (size: %d)...", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + embeddings = await embed_contents_async( + self._vector_store_settings.vertex_ai_embedding_model_name, + content_b, + self._vector_store_settings.vector_length, + self._genai_client, + ) + + logger.debug( + "Committing batch mutation %d to %d (size: %d).", + batch_index, + batch_index + len(content_b), + len(content_b), + ) + mutation_rows = [ + # [content, embedding, ...additional_columns] + [c, e, *map(extra.get, cols)] + for c, e, extra in zip(content_b, embeddings, extra_b) + ] + + def _commit_batch(columns, rows_to_commit): + with self._database.batch() as batch: + batch.insert_or_update( + table=self._vector_store_settings.table_name, + columns=[ + self._vector_store_settings.content_column, + self._vector_store_settings.embedding_column, + ] + + columns, + values=rows_to_commit, + ) + + await asyncio.to_thread(_commit_batch, cols, mutation_rows) + total_rows += len(mutation_rows) + + logger.debug( + "Successfully added %d contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s", + total_rows, + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + ) + + except Exception as e: + logger.error( + "Failed to finish adding contents to the vector store table: %s in" + " Spanner database: projects/%s/instances/%s/databases/%s. Total" + " rows added: %d. 
Error: %s", + self._vector_store_settings.table_name, + self._vector_store_settings.project_id, + self._vector_store_settings.instance_id, + self._vector_store_settings.database_id, + total_rows, + e, + ) + raise diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 8bd4a22f20..d395a5516f 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -665,6 +665,157 @@ def test_construct_message_parts_from_session_empty_events(self): assert parts == [] assert context_id is None + def test_construct_message_parts_from_session_stops_on_agent_reply(self): + """Test message parts construction stops on agent reply by default.""" + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateless_full_history(self): + """Test full history for stateless agent when enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = None + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 3 + assert parts[0].text == "User 1" + assert parts[1].text == "Agent 1" + assert parts[2].text == "User 2" + assert context_id is None + + def test_construct_message_parts_from_session_stateful_partial_history(self): + """Test partial history for stateful agent when full history is enabled.""" + self.agent._full_history_when_stateless = True + part1 = Mock() + part1.text = "User 1" + 
content1 = Mock() + content1.parts = [part1] + user1 = Mock() + user1.content = content1 + user1.author = "user" + user1.custom_metadata = None + + part2 = Mock() + part2.text = "Agent 1" + content2 = Mock() + content2.parts = [part2] + agent1 = Mock() + agent1.content = content2 + agent1.author = self.agent.name + agent1.custom_metadata = {A2A_METADATA_PREFIX + "context_id": "ctx-1"} + + part3 = Mock() + part3.text = "User 2" + content3 = Mock() + content3.parts = [part3] + user2 = Mock() + user2.content = content3 + user2.author = "user" + user2.custom_metadata = None + + self.mock_session.events = [user1, agent1, user2] + + def mock_converter(part): + mock_a2a_part = Mock() + mock_a2a_part.text = part.text + return mock_a2a_part + + self.mock_genai_part_converter.side_effect = mock_converter + + with patch( + "google.adk.agents.remote_a2a_agent._present_other_agent_message" + ) as mock_present: + mock_present.side_effect = lambda event: event + parts, context_id = self.agent._construct_message_parts_from_session( + self.mock_context + ) + assert len(parts) == 1 + assert parts[0].text == "User 2" + assert context_id == "ctx-1" + @pytest.mark.asyncio async def test_handle_a2a_response_success_with_message(self): """Test successful A2A response handling with message.""" diff --git a/tests/unittests/sessions/migration/test_database_schema.py b/tests/unittests/sessions/migration/test_database_schema.py index 239da2f1e2..d08bb97ba0 100644 --- a/tests/unittests/sessions/migration/test_database_schema.py +++ b/tests/unittests/sessions/migration/test_database_schema.py @@ -51,12 +51,18 @@ async def test_new_db_uses_latest_schema(tmp_path): lambda sync_conn: inspect(sync_conn).has_table('adk_internal_metadata') ) assert has_metadata_table - schema_version = await conn.run_sync( - lambda sync_conn: sync_conn.execute( - text('SELECT value FROM adk_internal_metadata WHERE key = :key'), - {'key': _schema_check_utils.SCHEMA_VERSION_KEY}, - ).scalar_one_or_none() - ) + + def get_schema_version(sync_conn): + inspector = inspect(sync_conn) + key_col = inspector.dialect.identifier_preparer.quote('key') + return sync_conn.execute( + text( + f'SELECT value FROM adk_internal_metadata WHERE {key_col} = :key' + ), + {'key': _schema_check_utils.SCHEMA_VERSION_KEY}, + ).scalar_one_or_none() + + schema_version = await conn.run_sync(get_schema_version) assert schema_version == _schema_check_utils.LATEST_SCHEMA_VERSION # Verify events table columns for v1 diff --git a/tests/unittests/tools/spanner/test_utils.py b/tests/unittests/tools/spanner/test_utils.py new file mode 100644 index 0000000000..56c04a2dce --- /dev/null +++ b/tests/unittests/tools/spanner/test_utils.py @@ -0,0 +1,384 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from unittest import mock + +from google.adk.tools.spanner import utils as spanner_utils +from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from google.adk.tools.spanner.settings import TableColumn +from google.adk.tools.spanner.settings import VectorSearchIndexSettings +from google.cloud.spanner_admin_database_v1.types import DatabaseDialect +from google.cloud.spanner_v1 import batch as spanner_batch +from google.cloud.spanner_v1 import client as spanner_client_v1 +from google.cloud.spanner_v1 import database as spanner_database +from google.cloud.spanner_v1 import instance as spanner_instance +import pytest + + +@pytest.fixture +def vector_store_settings(): + """Fixture for SpannerVectorStoreSettings.""" + return SpannerVectorStoreSettings( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + table_name="test_vector_store", + content_column="content", + embedding_column="embedding", + vector_length=768, + vertex_ai_embedding_model_name="textembedding", + ) + + +@pytest.fixture +def spanner_tool_settings(vector_store_settings): + """Fixture for SpannerToolSettings.""" + return SpannerToolSettings(vector_store_settings=vector_store_settings) + + +@pytest.fixture +def mock_spanner_database(): + """Fixture for a mocked spanner database.""" + mock_database = mock.create_autospec(spanner_database.Database, instance=True) + mock_database.exists.return_value = True + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + return mock_database + + +@pytest.fixture +def mock_spanner_instance(mock_spanner_database): + """Fixture for a mocked spanner instance.""" + mock_instance = mock.create_autospec(spanner_instance.Instance, instance=True) + mock_instance.exists.return_value = True + mock_instance.database.return_value = mock_spanner_database + return mock_instance + + +@pytest.fixture +def mock_spanner_client(mock_spanner_instance): + """Fixture for a mocked spanner client.""" + mock_client = mock.create_autospec(spanner_client_v1.Client, instance=True) + mock_client.instance.return_value = mock_spanner_instance + mock_client._client_info = mock.Mock(user_agent="test-agent") + return mock_client + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_successful( + mock_embed_contents, + spanner_tool_settings, + mock_spanner_client, + mock_spanner_database, + mocker, +): + """Test that add_contents successfully adds content.""" + mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]] + mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True) + mock_batch.__enter__.return_value = mock_batch + mock_spanner_database.batch.return_value = mock_batch + + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store._database = mock_spanner_database + contents = ["content1", "content2"] + vector_store.add_contents(contents=contents) + + mock_spanner_database.reload.assert_called_once() + mock_spanner_database.batch.assert_called_once() + mock_batch.insert_or_update.assert_called_once_with( + table="test_vector_store", + columns=["content", "embedding"], + values=[ + ["content1", [1.0, 2.0]], + ["content2", [3.0, 4.0]], + ], + ) + mock_embed_contents.assert_called_once_with( + "textembedding", contents, 768, 
mock.ANY + ) + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_with_metadata( + mock_embed_contents, + spanner_tool_settings, + mock_spanner_client, + mock_spanner_database, + mocker, +): + """Test that add_contents successfully adds content with metadata.""" + mock_embed_contents.return_value = [[1.0, 2.0], [3.0, 4.0]] + mock_batch = mocker.create_autospec(spanner_batch.Batch, instance=True) + mock_batch.__enter__.return_value = mock_batch + mock_spanner_database.batch.return_value = mock_batch + spanner_tool_settings.vector_store_settings.additional_columns_to_setup = [ + TableColumn(name="metadata", type="JSON") + ] + + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store._database = mock_spanner_database + contents = ["content1", "content2"] + additional_columns_values = [ + {"metadata": {"meta1": "val1"}}, + {"metadata": {"meta2": "val2"}}, + ] + vector_store.add_contents( + contents=contents, + additional_columns_values=additional_columns_values, + ) + + mock_spanner_database.batch.assert_called_once() + mock_batch.insert_or_update.assert_called_once_with( + table="test_vector_store", + columns=["content", "embedding", "metadata"], + values=[ + ["content1", [1.0, 2.0], {"meta1": "val1"}], + ["content2", [3.0, 4.0], {"meta2": "val2"}], + ], + ) + + +def test_add_contents_empty_contents( + spanner_tool_settings, mock_spanner_client, mock_spanner_database +): + """Test that add_contents does nothing when contents is empty.""" + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store.add_contents(contents=[]) + mock_spanner_database.batch.assert_not_called() + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_additional_columns_list_mismatch( + mock_embed_contents, spanner_tool_settings, mock_spanner_client +): + """Test that add_contents raises an error if additional_columns_values and contents lengths differ.""" + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises( + ValueError, + match="additional_columns_values contains more items than contents.", + ): + vector_store.add_contents( + contents=["content1"], + additional_columns_values=[ + {"col1": "val1"}, + {"col1": "val2"}, + ], + ) + + +@mock.patch.object(spanner_utils, "embed_contents", autospec=True) +def test_add_contents_embedding_fails( + mock_embed_contents, spanner_tool_settings, mock_spanner_client +): + """Test that add_contents fails if embedding fails.""" + mock_embed_contents.side_effect = RuntimeError("Embedding failed") + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises(RuntimeError, match="Embedding failed"): + vector_store.add_contents(contents=["content1", "content2"]) + + +def test_init_raises_error_if_vector_store_settings_not_set(): + """Test that SpannerVectorStore raises an error if vector_store_settings is not set.""" + settings = SpannerToolSettings() + with 
pytest.raises(
+      ValueError, match="Spanner vector store settings are not set."
+  ):
+    spanner_utils.SpannerVectorStore(settings)
+
+
+@pytest.mark.parametrize(
+    "dialect, expected_ddl",
+    [
+        (
+            DatabaseDialect.GOOGLE_STANDARD_SQL,
+            (
+                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
+                "  id STRING(36) DEFAULT (GENERATE_UUID()),\n"
+                "  content STRING(MAX),\n"
+                "  embedding ARRAY<FLOAT32>(vector_length=>768)\n"
+                ") PRIMARY KEY(id)"
+            ),
+        ),
+        (
+            DatabaseDialect.POSTGRESQL,
+            (
+                "CREATE TABLE IF NOT EXISTS test_vector_store (\n"
+                "  id varchar(36) DEFAULT spanner.generate_uuid(),\n"
+                "  content text,\n"
+                "  embedding float4[] VECTOR LENGTH 768,\n"
+                "  PRIMARY KEY(id)\n"
+                ")"
+            ),
+        ),
+    ],
+)
+def test_create_vector_store_table_ddl(
+    spanner_tool_settings, mock_spanner_client, dialect, expected_ddl
+):
+  """Test DDL creation for different SQL dialects."""
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    ddl = vector_store._create_vector_store_table_ddl(dialect)
+    assert ddl == expected_ddl
+
+
+def test_create_ann_vector_search_index_ddl_raises_error_for_postgresql(
+    spanner_tool_settings, vector_store_settings, mock_spanner_client
+):
+  """Test that creating an ANN index raises an error for PostgreSQL."""
+  vector_store_settings.vector_search_index_settings = mock.Mock()
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    with pytest.raises(
+        ValueError,
+        match="ANN is only supported for the Google Standard SQL dialect.",
+    ):
+      vector_store._create_ann_vector_search_index_ddl(
+          DatabaseDialect.POSTGRESQL
+      )
+
+
+def test_create_vector_store(
+    spanner_tool_settings, mock_spanner_client, mock_spanner_database
+):
+  """Test the vector store creation process."""
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    vector_store.create_vector_store()
+    mock_spanner_database.update_ddl.assert_called_once()
+    ddl_statement = mock_spanner_database.update_ddl.call_args[0][0]
+    assert "CREATE TABLE IF NOT EXISTS test_vector_store" in ddl_statement[0]
+
+
+def test_create_vector_search_index_no_settings(
+    spanner_tool_settings, mock_spanner_client, mock_spanner_database
+):
+  """Test that create_vector_search_index does nothing if settings are not present."""
+  spanner_tool_settings.vector_store_settings.vector_search_index_settings = (
+      None
+  )
+  with mock.patch.object(
+      spanner_utils.client,
+      "get_spanner_client",
+      autospec=True,
+      return_value=mock_spanner_client,
+  ):
+    vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings)
+    vector_store.create_vector_search_index()
+    mock_spanner_database.update_ddl.assert_not_called()
+
+
+def test_create_vector_search_index_successful_google_sql(
+    spanner_tool_settings,
+    vector_store_settings,
+    mock_spanner_client,
+    mock_spanner_database,
+):
+  """Test that create_vector_search_index successfully creates index for Google SQL."""
+  mock_spanner_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL
+  vector_store_settings.vector_search_index_settings = (
+      VectorSearchIndexSettings(
+          index_name="test_vector_index",
+          tree_depth=3,
+          num_branches=10,
+          
num_leaves=20, + ) + ) + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + vector_store.create_vector_search_index() + mock_spanner_database.update_ddl.assert_called_once() + ddl_statement = mock_spanner_database.update_ddl.call_args[0][0] + expected_ddl = ( + "CREATE VECTOR INDEX IF NOT EXISTS test_vector_index\n" + "\tON test_vector_store(embedding)\n" + "\tWHERE embedding IS NOT NULL\n" + "\tOPTIONS(distance_type='COSINE', tree_depth=3, num_branches=10, " + "num_leaves=20)" + ) + assert ddl_statement[0] == expected_ddl + + +def test_create_vector_search_index_fails( + spanner_tool_settings, + vector_store_settings, + mock_spanner_client, + mock_spanner_database, +): + """Test that create_vector_search_index raises an error if DDL execution fails.""" + mock_spanner_database.update_ddl.side_effect = RuntimeError("DDL failed") + vector_store_settings.vector_search_index_settings = ( + VectorSearchIndexSettings(index_name="test_vector_index") + ) + with mock.patch.object( + spanner_utils.client, + "get_spanner_client", + autospec=True, + return_value=mock_spanner_client, + ): + vector_store = spanner_utils.SpannerVectorStore(spanner_tool_settings) + with pytest.raises(RuntimeError, match="DDL failed"): + vector_store.create_vector_search_index()
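
Usage sketches for the changes above. These are illustrative notes, not part of the patch; any name flagged as an assumption does not come from this diff.

First, a minimal sketch of opting into the new `full_history_when_stateless` flag on `RemoteA2aAgent`. The agent name and card URL are hypothetical:

```python
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent

# Hypothetical remote agent; any reachable A2A agent card works here.
remote_agent = RemoteA2aAgent(
    name="fact_checker",
    agent_card="http://localhost:8001/.well-known/agent-card.json",
    # Opt in: a stateless remote (one that never returns a Task or context
    # ID) now receives the full session history on every request, instead
    # of only the events since its last reply.
    full_history_when_stateless=True,
)
```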
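Next, a sketch of the intended end-to-end flow for the new `SpannerVectorStore`, pieced together from the settings classes and the tests above. The resource names and the `category` column are assumptions, and the instance and database must already exist:

```python
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.settings import SpannerVectorStoreSettings
from google.adk.tools.spanner.settings import TableColumn
from google.adk.tools.spanner.settings import VectorSearchIndexSettings
from google.adk.tools.spanner.utils import SpannerVectorStore

settings = SpannerToolSettings(
    vector_store_settings=SpannerVectorStoreSettings(
        project_id="my-project",
        instance_id="my-instance",
        database_id="my-database",
        table_name="documents",
        content_column="content",
        embedding_column="embedding",
        vector_length=768,
        vertex_ai_embedding_model_name="text-embedding-005",
        additional_columns_to_setup=[
            TableColumn(name="category", type="STRING(MAX)")
        ],
        vector_search_index_settings=VectorSearchIndexSettings(
            index_name="documents_embedding_idx"
        ),
    )
)

store = SpannerVectorStore(settings)
store.create_vector_store()  # CREATE TABLE IF NOT EXISTS documents ...
store.create_vector_search_index()  # CREATE VECTOR INDEX IF NOT EXISTS ...
store.add_contents(
    ["Spanner is a distributed SQL database."],
    additional_columns_values=[{"category": "databases"}],
)
```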
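The tree-shape fields translate directly into the index OPTIONS clause. A hypothetical sizing for a ~200M-row table, following the docstring guidance (num_leaves ~= rows/1000; num_branches below num_leaves and between 1000 and sqrt(rows)):

```python
from google.adk.tools.spanner.settings import VectorSearchIndexSettings

rows = 200_000_000  # assumed dataset size
index_settings = VectorSearchIndexSettings(
    index_name="documents_embedding_idx",
    tree_depth=3,  # datasets over ~100M rows can use a 3-level tree
    num_leaves=rows // 1000,  # 200_000 leaves
    num_branches=2_000,  # between 1000 and sqrt(rows) ~= 14_142
)
```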
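Finally, the new `genai_client` parameter makes `embed_contents` testable without Vertex AI credentials. A sketch with a mocked client; the fake embedding values are arbitrary:

```python
from unittest import mock

from google.adk.tools.spanner import utils

# Injecting a fake client keeps the call hermetic: no Vertex AI request is made.
fake_client = mock.Mock()
fake_client.models.embed_content.return_value = mock.Mock(
    embeddings=[mock.Mock(values=[0.1, 0.2])]
)
vectors = utils.embed_contents(
    "text-embedding-005",
    ["hello world"],
    output_dimensionality=2,
    genai_client=fake_client,
)
assert vectors == [[0.1, 0.2]]
```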